├── .clang-format ├── .github ├── dependabot.yml └── workflows │ └── build-and-test.yml ├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── MANIFEST.in ├── README.md ├── VERSION_NUMBER ├── cmake ├── ONNXOptimizerConfig.cmake.in ├── ONNXOptimizerConfigVersion.cmake.in └── utils.cmake ├── examples └── onnx_optimizer_exec.cpp ├── onnxoptimizer ├── __init__.py ├── __main__.py ├── c_api │ ├── onnxoptimizer_c_api.cc │ └── onnxoptimizer_c_api.h ├── cpp2py_export.cc ├── model_util.cc ├── model_util.h ├── onnxoptimizer_main.py ├── optimize.cc ├── optimize.h ├── pass.cc ├── pass.h ├── pass_manager.cc ├── pass_manager.h ├── pass_registry.cc ├── pass_registry.h ├── passes │ ├── adjust_add.h │ ├── adjust_slice_and_matmul.h │ ├── bitscast.h │ ├── cse_util.h │ ├── data_type.h │ ├── eliminate_common_subexpression.h │ ├── eliminate_consecutive_idempotent_ops.h │ ├── eliminate_deadend.h │ ├── eliminate_duplicate_initializer.h │ ├── eliminate_identity.h │ ├── eliminate_if_with_const_cond.h │ ├── eliminate_nop_cast.h │ ├── eliminate_nop_concat.h │ ├── eliminate_nop_dropout.h │ ├── eliminate_nop_expand.h │ ├── eliminate_nop_flatten.h │ ├── eliminate_nop_monotone_argmax.h │ ├── eliminate_nop_pad.h │ ├── eliminate_nop_reshape.h │ ├── eliminate_nop_split.h │ ├── eliminate_nop_transpose.h │ ├── eliminate_nop_with_unit.h │ ├── eliminate_shape_gather.h │ ├── eliminate_shape_op.h │ ├── eliminate_slice_after_shape.h │ ├── eliminate_unused_initializer.h │ ├── extract_constant_to_initializer.h │ ├── fuse_add_bias_into_conv.h │ ├── fuse_bn_into_conv.h │ ├── fuse_concat_into_reshape.h │ ├── fuse_consecutive_concats.h │ ├── fuse_consecutive_log_softmax.h │ ├── fuse_consecutive_reduce_unsqueeze.h │ ├── fuse_consecutive_slices.h │ ├── fuse_consecutive_squeezes.h │ ├── fuse_consecutive_transposes.h │ ├── fuse_consecutive_unsqueezes.h │ ├── fuse_matmul_add_bias_into_gemm.h │ ├── fuse_pad_into_conv.h │ ├── fuse_pad_into_pool.h │ ├── fuse_qkv.h │ ├── fuse_transpose_into_gemm.h │ ├── 
lift_lexical_references.h │ ├── logging.h │ ├── nop.h │ ├── pass_util.cc │ ├── pass_util.h │ ├── rename_input_output.h │ ├── replace_einsum_with_matmul.h │ ├── rewrite_input_dtype.h │ ├── set_unique_name_for_nodes.h │ ├── split.h │ ├── string_utils.h │ ├── tensor_util.cc │ └── tensor_util.h └── test │ └── optimizer_test.py ├── setup.cfg ├── setup.py └── tools └── mypy-onnx.py /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: Google 2 | AllowShortBlocksOnASingleLine: false 3 | AllowShortCaseLabelsOnASingleLine: false 4 | AllowShortFunctionsOnASingleLine: Empty 5 | AllowShortLoopsOnASingleLine: false 6 | AllowShortIfStatementsOnASingleLine: false 7 | 8 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) ONNX Project Contributors 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # To get started with Dependabot version updates, you'll need to specify which 6 | # package ecosystems to update and where the package manifests are located. 
7 | # Please see the documentation for all configuration options: 8 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 9 | 10 | version: 2 11 | updates: 12 | - package-ecosystem: "pip" # See documentation for possible values 13 | directory: "/" # Location of package manifests 14 | schedule: 15 | interval: "monthly" 16 | ignore: 17 | # Only update them manually since updating them might break compatibility 18 | - dependency-name: "numpy" 19 | - dependency-name: "protobuf" 20 | open-pull-requests-limit: 10 21 | 22 | - package-ecosystem: "github-actions" 23 | # Workflow files stored in the 24 | # default location of `.github/workflows` 25 | directory: "/" 26 | schedule: 27 | interval: "monthly" 28 | open-pull-requests-limit: 20 29 | 30 | 31 | -------------------------------------------------------------------------------- /.github/workflows/build-and-test.yml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build_wheels: 7 | env: 8 | CIBW_ARCHS_MACOS: x86_64 universal2 9 | MACOSX_DEPLOYMENT_TARGET: "10.15" 10 | CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_28 11 | CIBW_BEFORE_ALL_LINUX: WD=`pwd` && /opt/python/cp38-cp38/bin/python -m pip install --target tmp_cmake cmake && cp tmp_cmake/bin/cmake /usr/local/bin/cmake && rm -rf tmp_cmake && /opt/python/cp38-cp38/bin/python -m pip install cmake && cmake --version && whereis cmake 12 | CIBW_BEFORE_ALL_MACOS: WD=`pwd` && pip install cmake 13 | CIBW_BEFORE_BUILD_LINUX: pip install protobuf 14 | CIBW_BEFORE_BUILD_WINDOWS: python -m pip install protobuf 15 | CIBW_BEFORE_BUILD_MACOS: pip install protobuf 16 | CIBW_TEST_REQUIRES_LINUX: pytest pytest-xdist flake8 mypy onnxruntime==1.19.2 17 | CIBW_TEST_REQUIRES_MACOS: pytest pytest-xdist 18 | CIBW_TEST_REQUIRES_WINDOWS: pytest pytest-xdist 19 | CIBW_BEFORE_TEST_LINUX: pip install torch==2.2.0+cpu torchvision==0.17.0+cpu 
-f https://download.pytorch.org/whl/torch_stable.html 20 | CIBW_TEST_COMMAND: pytest {project}/onnxoptimizer/test 21 | CIBW_TEST_COMMAND_LINUX: cd {project} && flake8 && pytest 22 | # Python3.11 doesn't have torchvision prebuilt wheel 23 | CIBW_TEST_SKIP: "cp311-* *_arm64 *_universal2:arm64" 24 | CIBW_ENVIRONMENT: CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON -DONNX_USE_LITE_PROTO=ON -DCMAKE_POLICY_VERSION_MINIMUM=3.5" 25 | CIBW_ENVIRONMENT_WINDOWS: USE_MSVC_STATIC_RUNTIME=0 CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON -DONNX_USE_LITE_PROTO=ON" 26 | CIBW_BUILD: "${{ matrix.python }}-*" 27 | CIBW_SKIP: "*-win32 *-manylinux_i686 *-musllinux_*" 28 | name: Build wheels on ${{ matrix.os }} 29 | runs-on: ${{ matrix.os }} 30 | strategy: 31 | matrix: 32 | os: [ubuntu-22.04, windows-2019, macos-15] 33 | python: ["cp39", "cp310", "cp311", "cp312"] 34 | steps: 35 | - uses: actions/checkout@v4 36 | with: 37 | submodules: recursive 38 | - name: Build wheels 39 | uses: pypa/cibuildwheel@v2.23.3 40 | - uses: actions/upload-artifact@v4 41 | with: 42 | name: artifact-${{ matrix.os }}-${{ matrix.python }} 43 | path: ./wheelhouse/*.whl 44 | 45 | build_sdist: 46 | name: Build source distribution 47 | runs-on: ubuntu-latest 48 | steps: 49 | - uses: actions/checkout@v4 50 | with: 51 | submodules: recursive 52 | 53 | - run: python3 -m pip install protobuf 54 | 55 | - name: Build sdist 56 | run: pipx run build --sdist 57 | 58 | - name: Install and test sdist 59 | run: | 60 | # It's important to leave the project directory where a 'onnxoptimizer' subdirectory exists 61 | cd dist 62 | python3 -m pip install *.tar.gz 63 | python3 -c "import onnxoptimizer; print(onnxoptimizer.get_fuse_and_elimination_passes())" 64 | 65 | - uses: actions/upload-artifact@v4 66 | with: 67 | path: dist/*.tar.gz 68 | 69 | upload_pypi: 70 | name: Upload to PyPI 71 | needs: [build_wheels, build_sdist] 72 | runs-on: ubuntu-latest 73 | steps: 74 | - 
uses: actions/download-artifact@v4 75 | with: 76 | name: artifact 77 | path: dist 78 | 79 | - name: Publish distribution 📦 to Test PyPI 80 | uses: pypa/gh-action-pypi-publish@release/v1 81 | with: 82 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 83 | repository-url: https://test.pypi.org/legacy/ 84 | skip-existing: true 85 | 86 | - name: Publish distribution 📦 to PyPI 87 | if: startsWith(github.ref, 'refs/tags/v') 88 | uses: pypa/gh-action-pypi-publish@release/v1 89 | with: 90 | password: ${{ secrets.PYPI_API_TOKEN }} 91 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## General 2 | 3 | # Compiled Object files 4 | *.slo 5 | *.lo 6 | *.o 7 | *.cuo 8 | 9 | # Compiled Dynamic libraries 10 | *.so 11 | *.dylib 12 | *.pyd 13 | 14 | # Compiled Static libraries 15 | *.lai 16 | *.la 17 | *.a 18 | 19 | # Compiled python 20 | *.pyc 21 | 22 | # Compiled MATLAB 23 | *.mex* 24 | 25 | # IPython notebook checkpoints 26 | .ipynb_checkpoints 27 | 28 | # Editor temporaries 29 | *.swn 30 | *.swo 31 | *.swp 32 | *~ 33 | 34 | # Sublime Text settings 35 | *.sublime-workspace 36 | *.sublime-project 37 | 38 | # Eclipse Project settings 39 | *.*project 40 | .settings 41 | 42 | # QtCreator files 43 | *.user 44 | 45 | # PyCharm files 46 | .idea 47 | 48 | # Visual Studio Code files 49 | .vscode 50 | 51 | # OSX dir files 52 | .DS_Store 53 | 54 | ## ONNX 55 | 56 | # build, distribute, and bins (+ python proto bindings) 57 | build/ 58 | build_*/ 59 | .build_debug/* 60 | .build_release/* 61 | .setuptools-cmake-build*/* 62 | 63 | # setup.py intermediates 64 | .eggs 65 | dist 66 | *.egg-info 67 | *.ninja 68 | .ninja_deps 69 | .ninja_log 70 | compile_commands.json 71 | 72 | # generated files 73 | onnxoptimizer/version.py 74 | compile_commands.json 75 | 76 | # autocomplete 77 | .ycm_extra_conf.py 78 | 79 | # test coverage data files 80 | *.gcov 81 | 82 | .mypy_cache 83 
| virtualenv 84 | venv 85 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/onnx"] 2 | path = third_party/onnx 3 | url = https://github.com/daquexian/onnx.git 4 | [submodule "third_party/protobuf"] 5 | path = third_party/protobuf 6 | url = https://github.com/protocolbuffers/protobuf.git 7 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.22) 2 | 3 | # For std::filesystem 4 | # Must be a cache variable and be set before project() 5 | # Reference: https://cmake.org/cmake/help/latest/variable/CMAKE_OSX_DEPLOYMENT_TARGET.html 6 | # Maybe it can be a normal variable if policy CMP0126 is set to NEW? 7 | set(CMAKE_OSX_DEPLOYMENT_TARGET 10.15 CACHE STRING "Minimum OS X deployment version") 8 | 9 | project(onnx_optimizer C CXX) 10 | 11 | set(CMAKE_CXX_STANDARD 17) 12 | 13 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 14 | 15 | include(cmake/utils.cmake) 16 | 17 | # For integration with onnxruntime_webassembly etc. 
18 | if (NOT DEFINED ONNX_TARGET_NAME) 19 | set(ONNX_TARGET_NAME onnx) 20 | endif() 21 | 22 | option(ONNX_OPT_USE_SYSTEM_PROTOBUF "" OFF) 23 | if(NOT ONNX_OPT_USE_SYSTEM_PROTOBUF) 24 | option(protobuf_BUILD_TESTS "" OFF) 25 | option(protobuf_MSVC_STATIC_RUNTIME "" ${ONNX_USE_MSVC_STATIC_RUNTIME}) 26 | add_subdirectory_if_no_target(${PROJECT_SOURCE_DIR}/third_party/protobuf/cmake libprotobuf) 27 | endif() 28 | 29 | 30 | set(ONNX_ROOT ${PROJECT_SOURCE_DIR}/third_party/onnx) 31 | add_subdirectory_if_no_target(${ONNX_ROOT} ${ONNX_TARGET_NAME}) 32 | 33 | file(READ "${PROJECT_SOURCE_DIR}/VERSION_NUMBER" ONNX_OPTIMIZER_VERSION) 34 | string(STRIP "${ONNX_OPTIMIZER_VERSION}" ONNX_OPTIMIZER_VERSION) 35 | 36 | file(GLOB onnx_opt_srcs "onnxoptimizer/*.cc" 37 | "onnxoptimizer/*.h" 38 | "onnxoptimizer/passes/*.cc" 39 | "onnxoptimizer/passes/*.h" 40 | ) 41 | list(REMOVE_ITEM onnx_opt_srcs "${PROJECT_SOURCE_DIR}/onnxoptimizer/cpp2py_export.cc") 42 | 43 | onnxopt_add_library(onnx_optimizer ${onnx_opt_srcs}) 44 | target_link_libraries(onnx_optimizer PUBLIC ${ONNX_TARGET_NAME}) 45 | target_include_directories(onnx_optimizer PUBLIC 46 | $ 47 | $ 48 | ) 49 | 50 | onnxopt_add_executable(onnx_optimizer_exec examples/onnx_optimizer_exec.cpp) 51 | target_link_libraries(onnx_optimizer_exec onnx_optimizer) 52 | 53 | 54 | file(GLOB onnx_opt_c_api_srcs "onnxoptimizer/c_api/*.cc" 55 | "onnxoptimizer/c_api/*.h" 56 | ) 57 | 58 | onnxopt_add_library(onnx_optimizer_c_api ${onnx_opt_c_api_srcs}) 59 | target_link_libraries(onnx_optimizer_c_api PRIVATE onnx_optimizer) 60 | target_include_directories(onnx_optimizer_c_api PUBLIC 61 | $ 62 | $ 63 | ) 64 | 65 | if(BUILD_ONNX_PYTHON) 66 | if("${PY_EXT_SUFFIX}" STREQUAL "") 67 | if(MSVC) 68 | set(PY_EXT_SUFFIX ".pyd") 69 | else() 70 | set(PY_EXT_SUFFIX ".so") 71 | endif() 72 | endif() 73 | find_package(Python COMPONENTS Interpreter REQUIRED) 74 | 75 | onnxopt_add_library(onnx_opt_cpp2py_export MODULE "onnxoptimizer/cpp2py_export.cc") 76 | 
set_target_properties(onnx_opt_cpp2py_export PROPERTIES PREFIX "") 77 | set_target_properties(onnx_opt_cpp2py_export 78 | PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") 79 | set_target_properties(onnx_opt_cpp2py_export PROPERTIES SUFFIX ${PY_EXT_SUFFIX}) 80 | set_target_properties(onnx_opt_cpp2py_export 81 | PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) 82 | target_include_directories(onnx_opt_cpp2py_export PRIVATE 83 | $ 84 | $ 85 | ${Python_INCLUDE_DIR}) 86 | # pybind11 is a header only lib 87 | find_package(pybind11 2.2) 88 | if(pybind11_FOUND) 89 | target_include_directories(onnx_opt_cpp2py_export PUBLIC 90 | ${pybind11_INCLUDE_DIRS}) 91 | else() 92 | if(EXISTS ${ONNX_ROOT}/third_party/pybind11/include/pybind11/pybind11.h) 93 | target_include_directories(onnx_opt_cpp2py_export PUBLIC 94 | ${ONNX_ROOT}/third_party/pybind11/include) 95 | else() 96 | message(FATAL_ERROR "cannot find pybind") 97 | endif() 98 | endif() 99 | 100 | if(APPLE) 101 | set_target_properties(onnx_opt_cpp2py_export 102 | PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") 103 | target_link_libraries(onnx_opt_cpp2py_export 104 | PRIVATE -Wl,-force_load,$) 105 | elseif(MSVC) 106 | # In MSVC, we will add whole archive in default 107 | target_link_libraries(onnx_opt_cpp2py_export 108 | PRIVATE -WHOLEARCHIVE:$) 109 | elseif(CMAKE_SYSTEM_NAME STREQUAL "AIX") 110 | # whole-archive linker option not available on AIX 111 | target_sources(onnx_opt_cpp2py_export 112 | PRIVATE $) 113 | else() 114 | # Assume everything else is like gcc 115 | target_link_libraries(onnx_opt_cpp2py_export 116 | PRIVATE "-Wl,--whole-archive" $ 117 | "-Wl,--no-whole-archive") 118 | set_target_properties(onnx_opt_cpp2py_export 119 | PROPERTIES LINK_FLAGS "-Wl,--exclude-libs,ALL") 120 | endif() 121 | 122 | target_link_libraries(onnx_opt_cpp2py_export PRIVATE onnx_optimizer) 123 | 124 | if(MSVC) 125 | find_package(Python COMPONENTS Interpreter Development REQUIRED) 126 | target_link_libraries(onnx_opt_cpp2py_export 
PRIVATE ${Python_LIBRARIES}) 127 | target_compile_options(onnx_opt_cpp2py_export 128 | PRIVATE /MP 129 | /WX 130 | /wd4800 # disable warning type' : forcing 131 | # value to bool 'true' or 'false' 132 | # (performance warning) 133 | /wd4503 # identifier' : decorated name length 134 | # exceeded, name was truncated 135 | /wd4146 # unary minus operator applied to 136 | # unsigned type, result still 137 | # unsigned from include\google\protob 138 | # uf\wire_format_lite.h 139 | /wd4244 # 'argument': conversion from 'google:: 140 | # protobuf::uint64' to 'int', possible 141 | # loss of data 142 | /wd4267 # Conversion from 'size_t' to 'int', 143 | # possible loss of data 144 | /wd4996 # The second parameter is ignored. 145 | ${EXTRA_FLAGS}) 146 | if(ONNX_USE_PROTOBUF_SHARED_LIBS) 147 | target_compile_options(onnx_opt_cpp2py_export 148 | PRIVATE /wd4251 # 'identifier' : class 'type1' needs to 149 | # have dll-interface to be used by 150 | # clients of class 'type2' 151 | ) 152 | endif() 153 | endif() 154 | endif() 155 | 156 | include(GNUInstallDirs) 157 | 158 | install(DIRECTORY ${PROJECT_SOURCE_DIR}/onnxoptimizer 159 | DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} 160 | FILES_MATCHING 161 | PATTERN "*.h") 162 | 163 | configure_file( 164 | ${PROJECT_SOURCE_DIR}/cmake/ONNXOptimizerConfigVersion.cmake.in 165 | ${PROJECT_BINARY_DIR}/ONNXOptimizerConfigVersion.cmake 166 | @ONLY) 167 | configure_file( 168 | ${PROJECT_SOURCE_DIR}/cmake/ONNXOptimizerConfig.cmake.in 169 | ${PROJECT_BINARY_DIR}/ONNXOptimizerConfig.cmake 170 | @ONLY) 171 | install(FILES 172 | ${PROJECT_BINARY_DIR}/ONNXOptimizerConfigVersion.cmake 173 | ${PROJECT_BINARY_DIR}/ONNXOptimizerConfig.cmake 174 | DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ONNXOptimizer 175 | COMPONENT dev) 176 | install(EXPORT ONNXOptimizerTargets DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/ONNXOptimizer") 177 | install(TARGETS 178 | onnx_optimizer onnx_optimizer_c_api 179 | EXPORT ONNXOptimizerTargets DESTINATION ${CMAKE_INSTALL_LIBDIR}) 180 | 
181 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include onnxoptimizer *.h *.c *.cc *.proto 2 | recursive-include examples * 3 | recursive-include cmake * 4 | recursive-include tools * 5 | recursive-include third_party * 6 | include VERSION_NUMBER 7 | include CMakeLists.txt 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # ONNX Optimizer 4 | 5 | [![PyPI version](https://img.shields.io/pypi/v/onnxoptimizer.svg)](https://pypi.python.org/pypi/onnxoptimizer/) 6 | [![PyPI license](https://img.shields.io/pypi/l/onnxoptimizer.svg)](https://pypi.python.org/pypi/onnxoptimizer/) 7 | [![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](https://github.com/onnx/optimizer/pulls) 8 | 9 | ## 🛠 Maintainer Wanted 10 | 11 | We are currently **looking for a new maintainer** to help support and evolve the `onnxoptimizer` project. 12 | 13 | If you're passionate about ONNX, graph optimizations, or contributing to the open source machine learning ecosystem, we'd love to hear from you! This is a great opportunity to contribute to a widely used project and collaborate with the ONNX community. 14 | 15 | **To express interest:** 16 | Please open an issue or comment on [this thread](https://github.com/onnx/optimizer/issues) and let us know about your interest and background. 17 | 18 | ## Introduction 19 | 20 | ONNX provides a C++ library for performing arbitrary optimizations on ONNX models, as well as a growing list of prepackaged optimization passes. 21 | 22 | The primary motivation is to share work between the many ONNX backend implementations. 
Not all possible optimizations can be directly implemented on ONNX graphs - some will need additional backend-specific information - but many can, and our aim is to provide all such passes along with ONNX so that they can be re-used with a single function call. 23 | 24 | You may be interested in invoking the provided passes, or in implementing new ones (or both). 25 | 26 | ## Installation 27 | 28 | You can install onnxoptimizer from PyPI: 29 | 30 | ```bash 31 | pip3 install onnxoptimizer 32 | ``` 33 | 34 | Note that you may need to upgrade your pip first if you have trouble: 35 | 36 | ```bash 37 | pip3 install -U pip 38 | ``` 39 | 40 | If you want to build from source: 41 | 42 | ```bash 43 | git clone --recursive https://github.com/onnx/optimizer onnxoptimizer 44 | cd onnxoptimizer 45 | pip3 install -e . 46 | ``` 47 | 48 | Note that you need to install protobuf before building from source. 49 | 50 | 51 | ## Command-line API 52 | Now you can use command-line api in terminal instead of python script. 53 | 54 | ``` 55 | python -m onnxoptimizer input_model.onnx output_model.onnx 56 | ``` 57 | 58 | Arguments list is following: 59 | ``` 60 | # python3 -m onnxoptimizer -h 61 | usage: python -m onnxoptimizer input_model.onnx output_model.onnx 62 | 63 | onnxoptimizer command-line api 64 | 65 | optional arguments: 66 | -h, --help show this help message and exit 67 | --print_all_passes print all available passes 68 | --print_fuse_elimination_passes 69 | print all fuse and elimination passes 70 | -p [PASSES ...], --passes [PASSES ...] 
71 | list of optimization passes name, if no set, fuse_and_elimination_passes will be used 72 | --fixed_point fixed point 73 | ``` 74 | ## Roadmap 75 | 76 | * More built-in pass 77 | * Separate graph rewriting and constant folding (or a pure graph rewriting mode, see [issue #9](https://github.com/onnx/optimizer/issues/9) for the details) 78 | 79 | ## Relevant tools 80 | 81 | * [onnx-simplifier](https://github.com/daquexian/onnx-simplifier): A handy and popular tool based on onnxoptimizer 82 | 83 | ## Code of Conduct 84 | 85 | [ONNX Open Source Code of Conduct](https://onnx.ai/codeofconduct.html) 86 | -------------------------------------------------------------------------------- /VERSION_NUMBER: -------------------------------------------------------------------------------- 1 | 0.3.19 2 | -------------------------------------------------------------------------------- /cmake/ONNXOptimizerConfig.cmake.in: -------------------------------------------------------------------------------- 1 | # - Config file for the ONNX Optimizer package 2 | # It defines the following variable(s) 3 | # ONNX_OPTIMIZER_INCLUDE_DIRS - include directories for onnx optimizer 4 | # as well as ONNX Optimizer targets for other cmake libraries to use. 5 | 6 | # library version information 7 | set(ONNX_OPTIMIZER_VERSION "@ONNX_OPTIMIZER_VERSION@") 8 | 9 | # import targets 10 | include ("${CMAKE_CURRENT_LIST_DIR}/ONNXOptimizerTargets.cmake") 11 | 12 | # include directory. 13 | # 14 | # Newer versions of CMake set the INTERFACE_INCLUDE_DIRECTORIES property 15 | # of the imported targets. It is hence not necessary to add this path 16 | # manually to the include search path for targets which link to gflags. 17 | # The following lines are here for backward compatibility, in case one 18 | # would like to use the old-style include path. 
19 | get_filename_component( 20 | CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) 21 | get_filename_component( 22 | _INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) 23 | set(ONNX_OPTIMIZER_INCLUDE_DIRS "${_INSTALL_PREFIX}/include") 24 | 25 | -------------------------------------------------------------------------------- /cmake/ONNXOptimizerConfigVersion.cmake.in: -------------------------------------------------------------------------------- 1 | set(PACKAGE_VERSION "@ONNX_OPTIMIZER_VERSION@") 2 | 3 | # Check whether the requested PACKAGE_FIND_VERSION is compatible 4 | if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}") 5 | set(PACKAGE_VERSION_COMPATIBLE FALSE) 6 | else() 7 | set(PACKAGE_VERSION_COMPATIBLE TRUE) 8 | if ("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}") 9 | set(PACKAGE_VERSION_EXACT TRUE) 10 | endif() 11 | endif() 12 | 13 | -------------------------------------------------------------------------------- /cmake/utils.cmake: -------------------------------------------------------------------------------- 1 | include(${PROJECT_SOURCE_DIR}/third_party/onnx/cmake/Utils.cmake) 2 | 3 | # Poor man's FetchContent 4 | function(add_subdirectory_if_no_target dir target) 5 | if (NOT TARGET ${target}) 6 | add_subdirectory(${dir}) 7 | endif() 8 | endfunction() 9 | 10 | function(onnxopt_add_library) 11 | add_library(${ARGV}) 12 | if (MSVC) 13 | add_msvc_runtime_flag(${ARGV0}) 14 | endif() 15 | endfunction() 16 | 17 | function(onnxopt_add_executable) 18 | add_executable(${ARGV}) 19 | if (MSVC) 20 | add_msvc_runtime_flag(${ARGV0}) 21 | endif() 22 | endfunction() 23 | -------------------------------------------------------------------------------- /examples/onnx_optimizer_exec.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | void 
printUsage() { 14 | std::string usage = 15 | R"(Usage: onnx_optimizer_exec [model.onnx] [model_out.onnx] [optional: model_data_out.data])"; 16 | std::cout << usage << std::endl; 17 | } 18 | 19 | int main(int argc, char** argv) { 20 | if (argc != 3 && argc != 4) { 21 | printUsage(); 22 | return -1; 23 | } 24 | std::string model_in_path(argv[1]); 25 | std::string model_out_path(argv[2]); 26 | std::string model_data_path{}; 27 | if (argc == 4) { 28 | model_data_path = std::filesystem::relative( 29 | std::string(argv[3]), 30 | std::filesystem::path(model_out_path).parent_path()).string(); 31 | } 32 | 33 | try { 34 | ONNX_NAMESPACE::ModelProto model; 35 | onnx::optimization::loadModel(&model, model_in_path, true); 36 | onnx::checker::check_model(model); 37 | auto new_model = onnx::optimization::Optimize( 38 | model, onnx::optimization::GetFuseAndEliminationPass()); 39 | onnx::checker::check_model(new_model); 40 | bool save_external_data = !model_data_path.empty(); 41 | onnx::optimization::saveModel(&new_model, model_out_path, 42 | save_external_data, model_data_path); 43 | 44 | } catch (std::exception& e) { 45 | std::cout << e.what() << std::endl; 46 | return -1; 47 | } 48 | return 0; 49 | } 50 | -------------------------------------------------------------------------------- /onnxoptimizer/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # ATTENTION: The code in this file is highly EXPERIMENTAL. 4 | # Adventurous users should note that the APIs will probably change. 5 | 6 | """onnx optimizer 7 | 8 | This enables users to optimize their models. 
9 | """ 10 | 11 | import onnx 12 | import onnxoptimizer.onnx_opt_cpp2py_export as C 13 | from .version import version as __version__ # noqa 14 | from onnx import ModelProto 15 | from typing import Text, Sequence, Optional 16 | from onnxoptimizer.onnxoptimizer_main import main 17 | import tempfile 18 | import os 19 | 20 | get_available_passes = C.get_available_passes 21 | 22 | get_fuse_and_elimination_passes = C.get_fuse_and_elimination_passes 23 | 24 | 25 | def optimize(model, passes=None, fixed_point=False): # type: (ModelProto, Optional[Sequence[Text]], bool) -> ModelProto 26 | """Apply the optimization on the serialized ModelProto. 27 | 28 | Arguments: 29 | model (ModelProto): model 30 | passes (list of string): list of optimization names 31 | 32 | Return: 33 | return (ModelProto) optimized model 34 | """ 35 | 36 | if passes is None: 37 | passes = get_fuse_and_elimination_passes() 38 | if not isinstance(model, ModelProto): 39 | raise ValueError( 40 | 'Optimizer only accepts ModelProto, incorrect type: {}'.format(type(model))) 41 | try: 42 | model_str = model.SerializeToString() 43 | if fixed_point: 44 | optimized_model_str = C.optimize_fixedpoint(model_str, passes) 45 | else: 46 | optimized_model_str = C.optimize(model_str, passes) 47 | 48 | return onnx.load_from_string(optimized_model_str) 49 | except ValueError: 50 | file_src = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) 51 | file_dest = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) 52 | data_file_src = tempfile.NamedTemporaryFile(delete=False) 53 | data_file_dest = tempfile.NamedTemporaryFile(delete=False) 54 | data_src_rel_filename = os.path.relpath(data_file_src.name, os.path.dirname(file_src.name)) 55 | data_dest_rel_filename = os.path.relpath(data_file_dest.name, os.path.dirname(file_dest.name)) 56 | try: 57 | onnx.save(model, file_src.name, save_as_external_data=True, location=data_src_rel_filename, convert_attribute=True,) 58 | if fixed_point: 59 | 
C.optimize_fixedpoint_from_path(file_src.name, file_dest.name, passes, data_dest_rel_filename) 60 | else: 61 | C.optimize_from_path(file_src.name, file_dest.name, passes, data_dest_rel_filename) 62 | return onnx.load(file_dest, load_external_data=True) 63 | finally: 64 | os.remove(file_src.name) 65 | os.remove(file_dest.name) 66 | os.remove(data_file_src.name) 67 | os.remove(data_file_dest.name) 68 | 69 | 70 | __all__ = ['optimize', 'get_available_passes', 'get_fuse_and_elimination_passes', 'main'] 71 | -------------------------------------------------------------------------------- /onnxoptimizer/__main__.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # ATTENTION: The code in this file is highly EXPERIMENTAL. 4 | # Adventurous users should note that the APIs will probably change. 5 | 6 | """onnx optimizer 7 | 8 | This enables users to optimize their models. 9 | """ 10 | 11 | from . import main 12 | 13 | 14 | if __name__ == '__main__': 15 | main() 16 | -------------------------------------------------------------------------------- /onnxoptimizer/c_api/onnxoptimizer_c_api.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "onnx/onnx_pb.h" 6 | #include "onnx/proto_utils.h" 7 | #include "onnxoptimizer/model_util.h" 8 | #include "onnxoptimizer/optimize.h" 9 | #include "onnxoptimizer_c_api.h" 10 | 11 | static const char** CopyPasses(const std::vector& passes) { 12 | size_t n = passes.size(); 13 | char** res_passes = static_cast(malloc(sizeof(char*) * (n + 1))); 14 | if (!res_passes) { 15 | return NULL; 16 | } 17 | int valid_count = 0; 18 | for (const auto& pass : passes) { 19 | const auto* from = pass.c_str(); 20 | const auto n = strlen(from); 21 | char* to = static_cast(malloc(n + 1)); 22 | if (!to) { 23 | continue; 24 | } else { 25 | memcpy(to, from, n); 26 | to[n] = '\0'; 27 | 
res_passes[valid_count++] = to; 28 | } 29 | } 30 | while (valid_count <= n) { 31 | res_passes[valid_count++] = NULL; 32 | } 33 | return const_cast(res_passes); 34 | } 35 | 36 | const char** C_API_GetAvailablePasses() { 37 | return CopyPasses(ONNX_NAMESPACE::optimization::GetAvailablePasses()); 38 | } 39 | 40 | const char** C_API_GetFuseAndEliminationPass() { 41 | return CopyPasses(ONNX_NAMESPACE::optimization::GetFuseAndEliminationPass()); 42 | } 43 | 44 | void C_API_ReleasePasses(const char*** passes) { 45 | if (!passes) { 46 | return; 47 | } 48 | const char* p = passes[0][0]; 49 | while (p) { 50 | void* cur = reinterpret_cast(const_cast(p)); 51 | p++; 52 | free(cur); 53 | } 54 | free(passes[0]); 55 | passes[0] = NULL; 56 | } 57 | 58 | static bool SerializeProtoAndCopy(const ONNX_NAMESPACE::ModelProto& p, 59 | void** buffer, size_t* size) { 60 | std::string out; 61 | p.SerializeToString(&out); 62 | void* buf = malloc(sizeof(char) * out.size()); 63 | if (!buf) { 64 | return false; 65 | } 66 | memcpy(buf, out.c_str(), out.size()); 67 | *size = out.size(); 68 | *buffer = buf; 69 | return true; 70 | } 71 | 72 | static std::pair Optimize( 73 | const ONNX_NAMESPACE::ModelProto& proto, const char** passes, 74 | const bool fix_point) { 75 | std::vector names; 76 | const char* p = passes[0]; 77 | while (p) { 78 | names.push_back(std::string(p)); 79 | p++; 80 | } 81 | if (names.empty()) { 82 | return std::make_pair(false, ONNX_NAMESPACE::ModelProto()); 83 | } 84 | try { 85 | if (fix_point) { 86 | auto result = ONNX_NAMESPACE::optimization::OptimizeFixed(proto, names); 87 | return std::make_pair(true, result); 88 | } else { 89 | auto result = ONNX_NAMESPACE::optimization::Optimize(proto, names); 90 | return std::make_pair(true, result); 91 | } 92 | } catch (std::exception& e) { 93 | std::cerr << e.what(); 94 | return std::make_pair(false, ONNX_NAMESPACE::ModelProto()); 95 | } 96 | } 97 | 98 | bool C_API_Optimize(const char* mp_in_buffer, const size_t mp_in_size, 99 | const 
char** passes, const bool fix_point, 100 | void** mp_out_buffer, size_t* mp_out_size) { 101 | if (!mp_in_buffer || mp_in_size == 0 || !passes || !mp_out_buffer || 102 | !mp_out_size) { 103 | return false; 104 | } 105 | 106 | ONNX_NAMESPACE::ModelProto proto{}; 107 | if (!ONNX_NAMESPACE::ParseProtoFromBytes(&proto, mp_in_buffer, mp_in_size)) { 108 | return false; 109 | } 110 | bool ok = false; 111 | ONNX_NAMESPACE::ModelProto result{}; 112 | std::tie(ok, result) = Optimize(proto, passes, fix_point); 113 | if (!ok) { 114 | return false; 115 | } 116 | return SerializeProtoAndCopy(result, mp_out_buffer, mp_out_size); 117 | } 118 | 119 | bool C_API_OtimizeFromFile(const char* import_model_path, 120 | const char* export_model_path, const char** passes, 121 | const bool fix_point, const bool save_external_data, 122 | const char* data_file_name) { 123 | if (!import_model_path || !export_model_path || !passes || 124 | (save_external_data && !data_file_name)) { 125 | return false; 126 | } 127 | try { 128 | ONNX_NAMESPACE::ModelProto proto{}; 129 | ONNX_NAMESPACE::optimization::loadModel( 130 | &proto, std::string(import_model_path), true); 131 | bool ok = false; 132 | ONNX_NAMESPACE::ModelProto result{}; 133 | std::tie(ok, result) = Optimize(proto, passes, fix_point); 134 | if (!ok) { 135 | return false; 136 | } 137 | ONNX_NAMESPACE::optimization::saveModel( 138 | &result, std::string(export_model_path), save_external_data, 139 | std::string(data_file_name)); 140 | return true; 141 | } catch (std::exception& e) { 142 | std::cerr << e.what(); 143 | return false; 144 | } 145 | } -------------------------------------------------------------------------------- /onnxoptimizer/c_api/onnxoptimizer_c_api.h: -------------------------------------------------------------------------------- 1 | #ifndef ONNXOPTIMIZER_C_API_H 2 | #define ONNXOPTIMIZER_C_API_H 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | /// caller must call C_API_ReleasePasses to free memory 9 | const 
char** C_API_GetAvailablePasses(); 10 | 11 | /// caller must call C_API_ReleasePasses to free memory 12 | const char** C_API_GetFuseAndEliminationPass(); 13 | 14 | void C_API_ReleasePasses(const char*** passes); 15 | 16 | // caller must call free to release mp_out buffer 17 | bool C_API_Optimize(const char* mp_in, const size_t mp_in_size, 18 | const char** passes, const bool fix_point, void** mp_out, 19 | size_t* mp_out_size); 20 | 21 | bool C_API_OtimizeFromFile(const char* import_model_path, 22 | const char* export_model_path, const char** passes, 23 | const bool fix_point, const bool save_external_data, 24 | const char* data_file_name); 25 | 26 | #ifdef __cplusplus 27 | } 28 | #endif 29 | 30 | #endif -------------------------------------------------------------------------------- /onnxoptimizer/cpp2py_export.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include 6 | #include 7 | 8 | #include "onnx/py_utils.h" 9 | #include "onnxoptimizer/model_util.h" 10 | #include "onnxoptimizer/optimize.h" 11 | 12 | namespace ONNX_NAMESPACE { 13 | namespace py = pybind11; 14 | using namespace pybind11::literals; 15 | PYBIND11_MODULE(onnx_opt_cpp2py_export, onnx_opt_cpp2py_export) { 16 | onnx_opt_cpp2py_export.doc() = "ONNX Optimizer"; 17 | 18 | onnx_opt_cpp2py_export.def( 19 | "optimize", 20 | [](const py::bytes& bytes, const std::vector& names) { 21 | ModelProto proto{}; 22 | ParseProtoFromPyBytes(&proto, bytes); 23 | auto const result = optimization::Optimize(proto, names); 24 | std::string out; 25 | result.SerializeToString(&out); 26 | return py::bytes(out); 27 | }); 28 | 29 | onnx_opt_cpp2py_export.def( 30 | "optimize_fixedpoint", 31 | [](const py::bytes& bytes, const std::vector& names) { 32 | ModelProto proto{}; 33 | ParseProtoFromPyBytes(&proto, bytes); 34 | auto const result = optimization::OptimizeFixed(proto, names); 35 | std::string out; 36 | 
result.SerializeToString(&out); 37 | return py::bytes(out); 38 | }); 39 | 40 | onnx_opt_cpp2py_export.def( 41 | "optimize_from_path", [](const std::string& import_model_path, 42 | const std::string& export_model_path, 43 | const std::vector& names, 44 | const std::string& export_data_file_name) { 45 | ModelProto proto{}; 46 | optimization::loadModel(&proto, import_model_path, true); 47 | auto result = optimization::Optimize(proto, names); 48 | optimization::saveModel(&result, export_model_path, true, 49 | export_data_file_name); 50 | }); 51 | 52 | onnx_opt_cpp2py_export.def( 53 | "optimize_fixedpoint_from_path", 54 | [](const std::string& import_model_path, 55 | const std::string& export_model_path, 56 | const std::vector& names, 57 | const std::string& export_data_file_name) { 58 | ModelProto proto{}; 59 | optimization::loadModel(&proto, import_model_path, true); 60 | auto result = optimization::OptimizeFixed(proto, names); 61 | optimization::saveModel(&result, export_model_path, true, 62 | export_data_file_name); 63 | }); 64 | onnx_opt_cpp2py_export.def("get_available_passes", 65 | &optimization::GetAvailablePasses); 66 | onnx_opt_cpp2py_export.def("get_fuse_and_elimination_passes", 67 | &optimization::GetFuseAndEliminationPass); 68 | } 69 | } // namespace ONNX_NAMESPACE 70 | -------------------------------------------------------------------------------- /onnxoptimizer/model_util.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #pragma once 9 | 10 | #include "onnx/onnx_pb.h" 11 | 12 | namespace ONNX_NAMESPACE { 13 | namespace optimization { 14 | 15 | void loadModel(ModelProto* m, const std::string& model_path, 16 | const bool load_external_data = false); 17 | 18 | void saveModel(ModelProto* m, const std::string& model_path, 19 | const bool save_external_data = false, 20 | const std::string& data_file_name = {}); 21 | 22 | } // namespace optimization 23 | } // namespace ONNX_NAMESPACE -------------------------------------------------------------------------------- /onnxoptimizer/onnxoptimizer_main.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # ATTENTION: The code in this file is highly EXPERIMENTAL. 4 | # Adventurous users should note that the APIs will probably change. 5 | 6 | """onnx optimizer 7 | 8 | This enables users to optimize their models. 9 | """ 10 | 11 | import onnx 12 | import onnx.checker 13 | import argparse 14 | import sys 15 | import onnxoptimizer 16 | import pathlib 17 | 18 | usage = 'python -m onnxoptimizer input_model.onnx output_model.onnx ' 19 | 20 | 21 | def format_argv(argv): 22 | argv_ = argv[1:] 23 | if len(argv_) == 1: 24 | return argv_ 25 | elif len(argv_) >= 2: 26 | return argv_[2:] 27 | else: 28 | print('please check arguments!') 29 | sys.exit(1) 30 | 31 | 32 | def main(): 33 | parser = argparse.ArgumentParser( 34 | prog='onnxoptimizer', 35 | usage=usage, 36 | description='onnxoptimizer command-line api') 37 | parser.add_argument('--print_all_passes', action='store_true', default=False, help='print all available passes') 38 | parser.add_argument('--print_fuse_elimination_passes', action='store_true', default=False, help='print all fuse and elimination passes') 39 | parser.add_argument('-p', '--passes', nargs='*', default=None, help='list of optimization passes name, if no set, fuse_and_elimination_passes will be used') 40 | parser.add_argument('--fixed_point', 
action='store_true', default=False, help='fixed point') 41 | argv = sys.argv.copy() 42 | args = parser.parse_args(format_argv(sys.argv)) 43 | 44 | all_available_passes = onnxoptimizer.get_available_passes() 45 | fuse_and_elimination_passes = onnxoptimizer.get_fuse_and_elimination_passes() 46 | 47 | if args.print_all_passes: 48 | print(*all_available_passes) 49 | sys.exit(0) 50 | 51 | if args.print_fuse_elimination_passes: 52 | print(*fuse_and_elimination_passes) 53 | sys.exit(0) 54 | 55 | passes = args.passes 56 | if args.passes is None: 57 | passes = fuse_and_elimination_passes 58 | 59 | if len(argv[1:]) < 2: 60 | print('usage:{}'.format(usage)) 61 | print('please check arguments!') 62 | sys.exit(1) 63 | 64 | input_file = argv[1] 65 | output_file = argv[2] 66 | 67 | if not pathlib.Path(input_file).exists(): 68 | print("input file: {0} no exist!".format(input_file)) 69 | sys.exit(1) 70 | 71 | model = onnx.load(input_file) 72 | 73 | # when model size large than 2G bytes, onnx.checker.check_model(model) will fail. 74 | # we use onnx.check.check_model(input_file) as workaround 75 | onnx.checker.check_model(input_file) 76 | model = onnxoptimizer.optimize(model=model, passes=passes, fixed_point=args.fixed_point) 77 | if model is None: 78 | print('onnxoptimizer failed') 79 | sys.exit(1) 80 | try: 81 | onnx.save(proto=model, f=output_file) 82 | except: 83 | onnx.save(proto=model, f=output_file, save_as_external_data=True) 84 | onnx.checker.check_model(output_file) 85 | -------------------------------------------------------------------------------- /onnxoptimizer/optimize.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #include "onnxoptimizer/optimize.h" 9 | 10 | namespace ONNX_NAMESPACE { 11 | namespace optimization { 12 | 13 | GlobalPassRegistry Optimizer::passes; 14 | 15 | Optimizer::Optimizer( 16 | const std::vector& names, 17 | const bool fixed_point) { 18 | if (fixed_point) { 19 | this->pass_manager = 20 | std::shared_ptr(new FixedPointPassManager()); 21 | } else { 22 | this->pass_manager = 23 | std::shared_ptr(new GeneralPassManager()); 24 | } 25 | for (const auto& name : names) { 26 | auto pass = passes.find(name); 27 | this->pass_manager->add(pass); 28 | } 29 | } 30 | Optimizer::~Optimizer() {} 31 | 32 | ModelProto Optimize( 33 | const ModelProto& mp_in, 34 | const std::vector& names) { 35 | Optimizer current_opt(names, false); 36 | return current_opt.optimize(mp_in); 37 | } 38 | ModelProto OptimizeFixed( 39 | const ModelProto& mp_in, 40 | const std::vector& names) { 41 | Optimizer current_opt(names, true); 42 | return current_opt.optimize(mp_in); 43 | } 44 | const std::vector GetAvailablePasses() { 45 | return Optimizer::passes.GetAvailablePasses(); 46 | } 47 | const std::vector GetFuseAndEliminationPass() { 48 | return Optimizer::passes.GetFuseAndEliminationPass(); 49 | } 50 | 51 | } // namespace optimization 52 | } // namespace ONNX_NAMESPACE 53 | -------------------------------------------------------------------------------- /onnxoptimizer/optimize.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #pragma once 9 | 10 | #include "onnx/common/ir.h" 11 | #include "onnx/common/ir_pb_converter.h" 12 | #include "onnx/proto_utils.h" 13 | 14 | #include "onnxoptimizer/pass_manager.h" 15 | #include "onnxoptimizer/pass_registry.h" 16 | 17 | #include "vector" 18 | 19 | namespace ONNX_NAMESPACE { 20 | namespace optimization { 21 | 22 | struct Optimizer { 23 | static GlobalPassRegistry passes; 24 | 25 | public: 26 | Optimizer(const std::vector &names, const bool fixed_point); 27 | ~Optimizer(); 28 | 29 | ModelProto optimize(const ModelProto &_mp_in) { 30 | ModelProto mp_in = _mp_in; 31 | if (mp_in.ir_version() == 3) { 32 | // Upgrade ir_version to 4 so that initializer can be not in input 33 | mp_in.set_ir_version(4); 34 | } 35 | std::shared_ptr g(ImportModelProto(mp_in)); 36 | 37 | if (g.get() == nullptr) { 38 | std::cerr << "Warning: onnx optimizer is unable to parse input model. " 39 | << "(The IR version of the ONNX model may be too old.)" 40 | << std::endl; 41 | // If we can't parse the file, just return the input. 
42 | return mp_in; 43 | } 44 | 45 | ModelProto mp_out = PrepareOutput(mp_in); 46 | this->pass_manager->run(*g); 47 | ExportModelProto(&mp_out, g); 48 | return mp_out; 49 | } 50 | 51 | private: 52 | std::shared_ptr pass_manager; 53 | 54 | ModelProto AddInitializerToInput(const ModelProto &original_model) { 55 | ModelProto model = original_model; 56 | std::vector input_names; 57 | for (const auto &x : model.graph().input()) { 58 | input_names.push_back(x.name()); 59 | } 60 | for (const auto &x : model.graph().initializer()) { 61 | if (std::find(input_names.begin(), input_names.end(), x.name()) == 62 | input_names.end()) { 63 | auto *value_info = model.mutable_graph()->add_input(); 64 | value_info->set_name(x.name()); 65 | TypeProto *type = value_info->mutable_type(); 66 | auto *tensor = type->mutable_tensor_type(); 67 | tensor->set_elem_type(x.data_type()); 68 | auto *shape = tensor->mutable_shape(); 69 | for (const auto &dim : x.dims()) { 70 | TensorShapeProto::Dimension *new_dim = shape->add_dim(); 71 | new_dim->set_dim_value(dim); 72 | } 73 | } 74 | } 75 | return model; 76 | } 77 | }; 78 | 79 | const std::vector GetAvailablePasses(); 80 | 81 | const std::vector GetFuseAndEliminationPass(); 82 | 83 | ModelProto Optimize(const ModelProto &mp_in, 84 | const std::vector &names); 85 | 86 | ModelProto OptimizeFixed(const ModelProto &mp_in, 87 | const std::vector &names); 88 | } // namespace optimization 89 | } // namespace ONNX_NAMESPACE 90 | -------------------------------------------------------------------------------- /onnxoptimizer/pass.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "onnx/common/assertions.h" 6 | 7 | #include "onnxoptimizer/pass.h" 8 | 9 | namespace ONNX_NAMESPACE { 10 | namespace optimization { 11 | 12 | Pass::Pass( 13 | PassType pass_type, 14 | PassEfficiency pass_efficiency, 15 | PassOptimizationType pass_optimization_type) { 16 | 
this->pass_type = pass_type; 17 | this->pass_efficiency = pass_efficiency; 18 | this->pass_optimization_type = pass_optimization_type; 19 | } 20 | 21 | Pass::~Pass() {} 22 | 23 | unsigned int Pass::DescendOnGraphAttributesAndCount( 24 | Node* n, 25 | std::function fn) { 26 | unsigned int num_changes = 0; 27 | for (auto name : n->attributeNames()) { 28 | auto kind = n->kindOf(name); 29 | if (kind == AttributeKind::g) { 30 | num_changes += fn(*n->g(name)); 31 | } 32 | if (kind == AttributeKind::gs) { 33 | for (auto& g : n->gs(name)) { 34 | num_changes += fn(*g); 35 | } 36 | } 37 | } 38 | return num_changes; 39 | } 40 | 41 | void Pass::DescendOnGraphAttributesUnconstrained( 42 | Node* n, 43 | std::function fn) { 44 | for (auto name : n->attributeNames()) { 45 | auto kind = n->kindOf(name); 46 | if (kind == AttributeKind::g) { 47 | fn(*n->g(name)); 48 | } 49 | if (kind == AttributeKind::gs) { 50 | for (auto& g : n->gs(name)) { 51 | fn(*g); 52 | } 53 | } 54 | } 55 | } 56 | 57 | PredicateBasedPass::~PredicateBasedPass() {} 58 | 59 | unsigned int PredicateBasedPass::_runPassInternal(Graph& graph) { 60 | unsigned int num_changes = false; 61 | for (auto it = graph.begin(); it != graph.end(); ++it) { 62 | auto* n = *it; 63 | num_changes += this->DescendOnGraphAttributesAndCount( 64 | n, [this](Graph& g) { return _runPassInternal(g); }); 65 | if (this->patternMatchPredicate(n)) { 66 | NodeDestroyType destroy_type = NodeDestroyType::DestroyZero; 67 | num_changes += this->runTransform(n, graph, destroy_type); 68 | 69 | if (destroy_type == NodeDestroyType::DestroyOne) { 70 | it.destroyCurrent(); 71 | } 72 | } 73 | } 74 | return num_changes; 75 | } 76 | 77 | PassAnalysisType PredicateBasedPass::getPassAnalysisType() const { 78 | return PassAnalysisType::CountBased; 79 | } 80 | 81 | std::shared_ptr PredicateBasedPass::runPass(Graph& graph) { 82 | bool initialized_pass = this->initializePass(graph); 83 | unsigned int touched_optimizations = this->_runPassInternal(graph); 84 | bool 
finalized_pass = this->finalizePass(graph); 85 | 86 | return std::shared_ptr(new CountBasedPassAnalysis( 87 | this, touched_optimizations, initialized_pass, finalized_pass)); 88 | } 89 | 90 | CountBasedPassAnalysis::CountBasedPassAnalysis( 91 | Pass* pass, 92 | unsigned int num_positive_transforms, 93 | bool initialization_done, 94 | bool finalization_done) { 95 | this->pass = pass; 96 | this->num_positive_transforms = num_positive_transforms; 97 | this->initialization_done = initialization_done; 98 | this->finalization_done = finalization_done; 99 | } 100 | 101 | FullGraphBasedPass::~FullGraphBasedPass() {} 102 | 103 | } // namespace optimization 104 | } // namespace ONNX_NAMESPACE 105 | -------------------------------------------------------------------------------- /onnxoptimizer/pass_manager.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "onnxoptimizer/pass_manager.h" 6 | 7 | namespace ONNX_NAMESPACE { 8 | namespace optimization { 9 | 10 | PassManager::PassManager() {} 11 | PassManager::~PassManager() {} 12 | 13 | GeneralPassManager::~GeneralPassManager() { 14 | this->passes.clear(); 15 | } 16 | void GeneralPassManager::add(std::shared_ptr pass) { 17 | this->passes.push_back(std::move(pass)); 18 | } 19 | 20 | std::shared_ptr GeneralPassManager::run(Graph& graph) { 21 | for (const std::shared_ptr& pass : this->passes) { 22 | auto pass_analysis = pass->runPass(graph); 23 | } 24 | return std::shared_ptr(new EmptyPassManagerAnalysis()); 25 | } 26 | 27 | std::shared_ptr FixedPointPassManager::run(Graph& graph) { 28 | bool fixed_point_optimization_done; 29 | 30 | do { 31 | fixed_point_optimization_done = false; 32 | for (const std::shared_ptr& pass : this->passes) { 33 | std::shared_ptr analysis = pass->runPass(graph); 34 | if (pass->getPassAnalysisType() == PassAnalysisType::Empty) { 35 | continue; 36 | } 37 | std::shared_ptr count_analysis = 38 | 
std::static_pointer_cast(analysis); 39 | 40 | while (count_analysis->fixedPointOptimizationNeeded()) { 41 | count_analysis = std::static_pointer_cast( 42 | pass->runPass(graph)); 43 | fixed_point_optimization_done = true; 44 | } 45 | } 46 | } while (fixed_point_optimization_done); 47 | 48 | return std::shared_ptr(new EmptyPassManagerAnalysis()); 49 | } 50 | } // namespace optimization 51 | } // namespace ONNX_NAMESPACE 52 | -------------------------------------------------------------------------------- /onnxoptimizer/pass_manager.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 7 | // Adventurous users should note that the APIs will probably change. 8 | 9 | #include 10 | #include "onnxoptimizer/pass.h" 11 | 12 | namespace ONNX_NAMESPACE { 13 | namespace optimization { 14 | 15 | // An analysis returned from the run done by a manager 16 | struct PassManagerAnalysis {}; 17 | struct EmptyPassManagerAnalysis : PassManagerAnalysis {}; 18 | 19 | // Base class of all PassManager's. The class should be able to add new passes 20 | // as well as run the passes given a graph. 21 | class PassManager { 22 | public: 23 | PassManager(); 24 | virtual ~PassManager(); 25 | 26 | virtual void add(std::shared_ptr P) = 0; 27 | virtual std::shared_ptr run(Graph& graph) = 0; 28 | }; 29 | 30 | // The GeneralPassManager has no restriction on type of Pass and runs the passes 31 | // once in a linear fashion. 
32 | class GeneralPassManager : public PassManager { 33 | public: 34 | GeneralPassManager() {} 35 | ~GeneralPassManager() override; 36 | 37 | void add(std::shared_ptr pass) override; 38 | std::shared_ptr run(Graph& graph) override; 39 | 40 | protected: 41 | // use vector here to ensure the order of the passes 42 | // for some pass, order is critical, for example, 43 | // split_init and split_predict should be the last in the list 44 | std::vector> passes; 45 | }; 46 | 47 | // Exhibits the same behavior as GeneralPassManager but will instead check 48 | // whether or not fixed point optimization is needed. 49 | class FixedPointPassManager : public GeneralPassManager { 50 | std::shared_ptr run(Graph& graph) override; 51 | }; 52 | 53 | } // namespace optimization 54 | } // namespace ONNX_NAMESPACE 55 | -------------------------------------------------------------------------------- /onnxoptimizer/pass_registry.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #include "onnxoptimizer/pass_registry.h" 9 | 10 | namespace ONNX_NAMESPACE { 11 | namespace optimization { 12 | 13 | const std::vector GlobalPassRegistry::GetFuseAndEliminationPass() { 14 | std::vector names; 15 | for (const auto& name : this->pass_names) { 16 | const auto pass_type = this->passes.at(name)->getPassType(); 17 | if (pass_type == PassType::Fuse || pass_type == PassType::Nop) { 18 | names.push_back(name); 19 | } 20 | } 21 | return names; 22 | } 23 | 24 | } // namespace optimization 25 | } // namespace ONNX_NAMESPACE 26 | -------------------------------------------------------------------------------- /onnxoptimizer/pass_registry.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | #include 11 | #include 12 | 13 | #include "onnx/common/ir.h" 14 | #include "onnx/common/ir_pb_converter.h" 15 | #include "onnx/proto_utils.h" 16 | #include "onnxoptimizer/passes/adjust_add.h" 17 | #include "onnxoptimizer/passes/adjust_slice_and_matmul.h" 18 | #include "onnxoptimizer/passes/eliminate_consecutive_idempotent_ops.h" 19 | #include "onnxoptimizer/passes/eliminate_deadend.h" 20 | #include "onnxoptimizer/passes/eliminate_duplicate_initializer.h" 21 | #include "onnxoptimizer/passes/eliminate_identity.h" 22 | #include "onnxoptimizer/passes/eliminate_if_with_const_cond.h" 23 | #include "onnxoptimizer/passes/eliminate_nop_cast.h" 24 | #include "onnxoptimizer/passes/eliminate_nop_concat.h" 25 | #include "onnxoptimizer/passes/eliminate_nop_dropout.h" 26 | #include "onnxoptimizer/passes/eliminate_nop_expand.h" 27 | #include "onnxoptimizer/passes/eliminate_nop_flatten.h" 28 | #include "onnxoptimizer/passes/eliminate_nop_monotone_argmax.h" 29 | #include "onnxoptimizer/passes/eliminate_nop_pad.h" 30 | #include 
"onnxoptimizer/passes/eliminate_nop_reshape.h" 31 | #include "onnxoptimizer/passes/eliminate_nop_split.h" 32 | #include "onnxoptimizer/passes/eliminate_nop_transpose.h" 33 | #include "onnxoptimizer/passes/eliminate_shape_gather.h" 34 | #include "onnxoptimizer/passes/eliminate_shape_op.h" 35 | #include "onnxoptimizer/passes/eliminate_slice_after_shape.h" 36 | #include "onnxoptimizer/passes/eliminate_unused_initializer.h" 37 | #include "onnxoptimizer/passes/extract_constant_to_initializer.h" 38 | #include "onnxoptimizer/passes/fuse_add_bias_into_conv.h" 39 | #include "onnxoptimizer/passes/fuse_bn_into_conv.h" 40 | #include "onnxoptimizer/passes/fuse_concat_into_reshape.h" 41 | #include "onnxoptimizer/passes/fuse_consecutive_concats.h" 42 | #include "onnxoptimizer/passes/fuse_consecutive_log_softmax.h" 43 | #include "onnxoptimizer/passes/fuse_consecutive_reduce_unsqueeze.h" 44 | #include "onnxoptimizer/passes/fuse_consecutive_squeezes.h" 45 | #include "onnxoptimizer/passes/fuse_consecutive_transposes.h" 46 | #include "onnxoptimizer/passes/fuse_matmul_add_bias_into_gemm.h" 47 | #include "onnxoptimizer/passes/fuse_pad_into_conv.h" 48 | #include "onnxoptimizer/passes/fuse_pad_into_pool.h" 49 | #include "onnxoptimizer/passes/fuse_transpose_into_gemm.h" 50 | #include "onnxoptimizer/passes/lift_lexical_references.h" 51 | #include "onnxoptimizer/passes/nop.h" 52 | #include "onnxoptimizer/passes/rename_input_output.h" 53 | #include "onnxoptimizer/passes/replace_einsum_with_matmul.h" 54 | #include "onnxoptimizer/passes/set_unique_name_for_nodes.h" 55 | #include "onnxoptimizer/passes/split.h" 56 | #include "onnxoptimizer/passes/fuse_consecutive_slices.h" 57 | #include "onnxoptimizer/passes/eliminate_common_subexpression.h" 58 | #include "onnxoptimizer/passes/fuse_qkv.h" 59 | #include "onnxoptimizer/passes/fuse_consecutive_unsqueezes.h" 60 | #include "onnxoptimizer/passes/eliminate_nop_with_unit.h" 61 | #include "onnxoptimizer/passes/rewrite_input_dtype.h" 62 | 63 | namespace 
ONNX_NAMESPACE { 64 | namespace optimization { 65 | 66 | // Registry containing all passes available in ONNX. 67 | struct GlobalPassRegistry { 68 | std::map> passes; 69 | std::vector pass_names; 70 | 71 | GlobalPassRegistry() { 72 | // Register the optimization passes to the optimizer. 73 | registerPass(); 74 | registerPass(); 75 | registerPass(); 76 | registerPass(); 77 | registerPass(); 78 | registerPass(); 79 | registerPass(); 80 | registerPass(); 81 | registerPass(); 82 | registerPass(); 83 | registerPass(); 84 | registerPass(); 85 | registerPass(); 86 | registerPass(); 87 | registerPass(); 88 | registerPass(); 89 | registerPass(); 90 | registerPass(); 91 | registerPass(); 92 | registerPass(); 93 | registerPass(); 94 | registerPass(); 95 | registerPass(); 96 | registerPass(); 97 | registerPass(); 98 | registerPass(); 99 | registerPass(); 100 | registerPass(); 101 | registerPass(); 102 | registerPass(); 103 | registerPass(); 104 | registerPass(); 105 | registerPass(); 106 | registerPass(); 107 | registerPass(); 108 | registerPass(); 109 | registerPass(); 110 | registerPass(); 111 | registerPass(); 112 | registerPass(); 113 | registerPass(); 114 | registerPass(); 115 | registerPass(); 116 | registerPass(); 117 | registerPass(); 118 | registerPass(); 119 | registerPass(); 120 | } 121 | 122 | ~GlobalPassRegistry() { 123 | this->passes.clear(); 124 | } 125 | 126 | std::shared_ptr find(std::string pass_name) { 127 | auto it = this->passes.find(pass_name); 128 | ONNX_ASSERTM(it != this->passes.end(), "pass %s is unknown.", 129 | pass_name.c_str()); 130 | return it->second; 131 | } 132 | const std::vector GetAvailablePasses() { 133 | return pass_names; 134 | } 135 | 136 | const std::vector GetFuseAndEliminationPass(); 137 | 138 | template 139 | void registerPass() { 140 | static_assert(std::is_base_of::value, "T must inherit from Pass"); 141 | std::shared_ptr pass(new T()); 142 | passes[pass->getPassName()] = pass; 143 | pass_names.emplace_back(pass->getPassName()); 
144 | } 145 | }; 146 | } // namespace optimization 147 | } // namespace ONNX_NAMESPACE 148 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/adjust_add.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | /* 11 | 12 | Before: 13 | bias MatMul(X,weights) 14 | | | 15 | | | 16 | | | 17 | Add 18 | 19 | This case could not be fused in TensorRT, but adjust bias to the second input, 20 | it will be fused Gemm+biasAdd. 21 | 22 | After: 23 | 24 | MatMul(X,weights) bias 25 | | | 26 | | | 27 | | | 28 | Add 29 | 30 | */ 31 | 32 | #include 33 | 34 | #include "onnx/defs/tensor_util.h" 35 | #include "onnxoptimizer/pass.h" 36 | #include "onnxoptimizer/passes/pass_util.h" 37 | namespace ONNX_NAMESPACE { 38 | namespace optimization { 39 | 40 | struct AdjustAdd final : public PredicateBasedPass { 41 | explicit AdjustAdd() 42 | : PredicateBasedPass(PassType::Immutable, PassEfficiency::Complete, 43 | PassOptimizationType::Compute) {} 44 | std::string getPassName() const override { 45 | return "adjust_add"; 46 | } 47 | 48 | bool patternMatchPredicate(Node* node) override { 49 | return CheckKind(node, kAdd) && IsConstantTensor(node, 0) && 50 | !IsConstantTensor(node, 1); 51 | } 52 | 53 | bool runTransform(Node* n, Graph& graph, 54 | NodeDestroyType& destroy_current) override { 55 | auto* old = n->replaceInput(0, n->inputs()[1]); 56 | n->replaceInput(1, old); 57 | destroy_current = NodeDestroyType::DestroyZero; 58 | return true; 59 | } 60 | }; 61 | 62 | } // namespace optimization 63 | } // namespace ONNX_NAMESPACE 64 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/adjust_slice_and_matmul.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | /* 11 | Before 12 | Y = Matmul(Slice(data, start, end, axes) ,rhs) , where data and rhs are 13 | constant tensor, axes of slice should be constant tensor as well as the value of 14 | axes should not represent the last shape. 15 | 16 | After 17 | Y = Slice(Matmul(data, rhs), start, end, axes) , where Matmul can be folded. 18 | */ 19 | 20 | #include 21 | 22 | #include "onnx/defs/tensor_util.h" 23 | #include "onnxoptimizer/pass.h" 24 | #include "onnxoptimizer/passes/pass_util.h" 25 | 26 | namespace ONNX_NAMESPACE { 27 | namespace optimization { 28 | 29 | struct AdjustSliceAndMatmul final : public PredicateBasedPass { 30 | explicit AdjustSliceAndMatmul() 31 | : PredicateBasedPass(PassType::Replace, PassEfficiency::Complete, 32 | PassOptimizationType::Compute) {} 33 | std::string getPassName() const override { 34 | return "adjust_slice_and_matmul"; 35 | } 36 | 37 | bool patternMatchPredicate(Node* node) override { 38 | int64_t slice_axis; 39 | const bool result = CheckKind(node, kMatMul, 0, kSlice) && 40 | // rhs should be constant tensor 41 | IsConstantTensor(node, 1) && 42 | // lhs should be constant tensor 43 | IsConstantTensor(node, 0, 0) && 44 | // slice should have explicit axes 45 | GetInputsOfPreNode(node, 0).size() >= 4 && 46 | // axes of slice should be constant tensor 47 | IsConstantTensor(node, 0, 3) && 48 | node->inputs()[0]->uses().size() == 1; 49 | if (!result) { 50 | return false; 51 | } 52 | Node* slice = PrevNode(node, 0); 53 | const int64_t rank = slice->input(0)->sizes().size(); 54 | std::vector axes = GetIntsFromValue(slice->input(3)); 55 | return std::none_of(axes.cbegin(), axes.cend(), [&rank](int64_t d) { 56 | return AddYIfNegative(d, rank) == 
rank - 1; 57 | }); 58 | } 59 | 60 | bool runTransform(Node* n, Graph& graph, 61 | NodeDestroyType& destroy_current) override { 62 | Value* slice_value = n->inputs()[0]; 63 | Value* mat_y = n->inputs()[1]; 64 | 65 | Node* slice = slice_value->node(); 66 | Value* mat_x = slice->inputs()[0]; 67 | 68 | Node* new_matmul = graph.create(kMatMul, 1); 69 | new_matmul->addInput(mat_x); 70 | new_matmul->addInput(mat_y); 71 | 72 | Node* new_slice = graph.create(kSlice, 1); 73 | new_slice->addInput(new_matmul->output()); 74 | for (int i = 1; i < slice->inputs().size(); ++i) { 75 | new_slice->addInput(slice->inputs()[i]); 76 | } 77 | 78 | new_slice->insertBefore(n); 79 | new_matmul->insertBefore(new_slice); 80 | 81 | const bool replacing_success = tryReplacingAllUsesWith(n, new_slice); 82 | if (!replacing_success) { 83 | return false; 84 | } 85 | destroy_current = NodeDestroyType::DestroyOne; 86 | return true; 87 | } 88 | }; 89 | 90 | } // namespace optimization 91 | } // namespace ONNX_NAMESPACE -------------------------------------------------------------------------------- /onnxoptimizer/passes/bitscast.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #pragma once 9 | #include 10 | #include 11 | 12 | namespace ONNX_NAMESPACE { 13 | namespace optimization { 14 | 15 | inline float FP32FromBits(uint32_t bits) { 16 | union { 17 | uint32_t as_bits; 18 | float as_value; 19 | } fp32{bits}; 20 | return fp32.as_value; 21 | } 22 | 23 | inline uint32_t FP32ToBits(float value) { 24 | union { 25 | float as_value; 26 | uint32_t as_bits; 27 | } fp32{value}; 28 | return fp32.as_bits; 29 | } 30 | 31 | } // namespace optimization 32 | } // namespace ONNX_NAMESPACE 33 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/data_type.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/passes/bitscast.h" 11 | #include "onnxoptimizer/passes/logging.h" 12 | 13 | namespace ONNX_NAMESPACE { 14 | namespace optimization { 15 | 16 | struct Complex64 { 17 | using base_type = float; 18 | Complex64() : real_part(0.f), imaginary_part(0.f) {} 19 | Complex64(float r, float i) : real_part(r), imaginary_part(i) {} 20 | 21 | bool operator==(const Complex64& rhs) const { 22 | return real_part == rhs.real_part && imaginary_part == rhs.imaginary_part; 23 | } 24 | 25 | bool operator!=(const Complex64& rhs) const { 26 | return !(*this == rhs); 27 | } 28 | 29 | float real_part; 30 | float imaginary_part; 31 | }; 32 | 33 | struct Complex128 { 34 | using base_type = double; 35 | Complex128() : real_part(0.0), imaginary_part(0.0) {} 36 | Complex128(double r, double i) : real_part(r), imaginary_part(i) {} 37 | 38 | bool operator==(const Complex128& rhs) const { 39 | return real_part == rhs.real_part && imaginary_part == rhs.imaginary_part; 40 | } 41 | bool operator!=(const Complex128& rhs) const { 42 | return !(*this == 
rhs); 43 | } 44 | double real_part; 45 | double imaginary_part; 46 | }; 47 | 48 | /// The IEEE 754 specifies a half-precision as having format: 1 bit for sign, 5 49 | /// bits for the exponet and 11 bits for the mantissa. 50 | 51 | struct Float16 { 52 | Float16() : bits(0) {} 53 | // bit-wise convert 54 | Float16(int32_t v) : bits(static_cast(v)) {} 55 | Float16(uint16_t v) : bits(v) {} 56 | bool operator==(const Float16& rhs) const { 57 | return bits == rhs.bits; 58 | } 59 | bool operator!=(const Float16& rhs) const { 60 | return !(*this == rhs); 61 | } 62 | uint16_t bits; 63 | }; 64 | 65 | /// Bfloat16 representation uses 1 bit for the sign, 8 bits for the exponent 66 | /// and 7 bits for the mantissa. It is assumed that floats are in IEEE 754 67 | /// format so the representation is just bits 16-31 of a single precision float. 68 | struct BFloat16 { 69 | BFloat16() : bits(0) {} 70 | // bit-wise convert 71 | BFloat16(int32_t v) : bits(static_cast(v)) {} 72 | BFloat16(uint16_t v) : bits(v) {} 73 | BFloat16(float v) : bits(static_cast(FP32ToBits(v) >> 16)) {} 74 | 75 | /// to float 76 | operator float() { 77 | return FP32FromBits(static_cast(bits) << 16); 78 | } 79 | /// from float 80 | BFloat16& operator=(float val) { 81 | bits = static_cast(FP32ToBits(val) >> 16); 82 | return *this; 83 | } 84 | 85 | bool operator==(const BFloat16& rhs) const { 86 | return bits == rhs.bits; 87 | } 88 | bool operator!=(const BFloat16& rhs) const { 89 | return !(*this == rhs); 90 | } 91 | uint16_t bits; 92 | }; 93 | 94 | template 95 | std::size_t ComplexHashHelper(const Complex& complex) { 96 | auto hasher = std::hash(); 97 | auto r = hasher(complex.real_part); 98 | auto i = hasher(complex.imaginary_part); 99 | return r ^= i + 0x9e3779b9 + (r << 6) + (r >> 2); 100 | } 101 | 102 | } // namespace optimization 103 | } // namespace ONNX_NAMESPACE 104 | 105 | namespace std { 106 | 107 | #define DEFINE_COMPLEX_HASH(type) \ 108 | template <> \ 109 | struct hash { \ 110 | std::size_t 
operator()(const type& complex) const { \ 111 | return ONNX_NAMESPACE::optimization::ComplexHashHelper(complex); \ 112 | } \ 113 | }; 114 | 115 | DEFINE_COMPLEX_HASH(ONNX_NAMESPACE::optimization::Complex128) 116 | DEFINE_COMPLEX_HASH(ONNX_NAMESPACE::optimization::Complex64) 117 | #undef DEFINE_COMPLEX_HASH 118 | 119 | template <> 120 | struct hash { 121 | std::size_t operator()(const ONNX_NAMESPACE::optimization::Float16& v) const { 122 | return hash{}(v.bits); 123 | }; 124 | }; 125 | 126 | template <> 127 | struct hash { 128 | std::size_t operator()( 129 | const ONNX_NAMESPACE::optimization::BFloat16& v) const { 130 | return hash{}(v.bits); 131 | }; 132 | }; 133 | 134 | } // namespace std 135 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_common_subexpression.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | #pragma once 8 | 9 | #include 10 | 11 | #include "onnx/defs/tensor_util.h" 12 | #include "onnxoptimizer/pass.h" 13 | #include "onnxoptimizer/passes/cse_util.h" 14 | #include "onnxoptimizer/passes/logging.h" 15 | #include "onnxoptimizer/passes/pass_util.h" 16 | #include "onnxoptimizer/passes/string_utils.h" 17 | 18 | namespace ONNX_NAMESPACE { 19 | namespace optimization { 20 | 21 | struct EliminateCommonSubexpression final : public FullGraphBasedPass { 22 | explicit EliminateCommonSubexpression() 23 | : FullGraphBasedPass(PassType::Nop, PassEfficiency::Complete, 24 | PassOptimizationType::Compute) {} 25 | std::string getPassName() const override { 26 | return "eliminate_common_subexpression"; 27 | } 28 | PassAnalysisType getPassAnalysisType() const override { 29 | return PassAnalysisType::CountBased; 30 | } 31 | 32 | unsigned int EliminateCommonSubexpressions(Graph &graph) { 33 | auto node_list = graph.nodes(); 34 | unsigned int cse_removed = 0; 35 | std::unordered_map hash_map; 36 | for (auto it = node_list.begin(); it != node_list.end(); ++it) { 37 | auto node = *it; 38 | auto kind = node->kind(); 39 | if (!node->hasUses() || !IsSupportedByCSE(node)) { 40 | continue; 41 | } 42 | VLOG(2) << Str("kind: ", kind.toString(), ", ", node->name(), 43 | " is processing"); 44 | if (hash_map.find(node) == hash_map.end()) { 45 | hash_map[node] = node; 46 | } else { 47 | auto other = hash_map.at(node); 48 | auto outputs = other->outputs(); 49 | auto replaced_outputs = node->outputs(); 50 | for (int i = 0; i < outputs.size(); ++i) { 51 | if (tryReplacingAllUsesWith(replaced_outputs[i], outputs[i])) { 52 | VLOG(1) << Str("kind: ", kind.toString(), ", ", node->name(), " [", 53 | i, "] output has been replaced by ", other->name()); 54 | cse_removed++; 55 | } 56 | } 57 | } 58 | } 59 | return cse_removed; 60 | } 61 | 62 | std::shared_ptr runPass(Graph &graph) override { 63 | auto cse_removed = this->EliminateCommonSubexpressions(graph); 64 | VLOG(1) << Str("cse_removed count: 
", cse_removed); 65 | return std::shared_ptr( 66 | new CountBasedPassAnalysis(this, cse_removed, false, false)); 67 | } 68 | }; 69 | } // namespace optimization 70 | } // namespace ONNX_NAMESPACE 71 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_consecutive_idempotent_ops.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | #include "onnxoptimizer/passes/pass_util.h" 12 | 13 | namespace ONNX_NAMESPACE { 14 | namespace optimization { 15 | 16 | struct EliminateConsecutiveIdempotentOps final : public PredicateBasedPass { 17 | explicit EliminateConsecutiveIdempotentOps() 18 | : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete, 19 | PassOptimizationType::Compute) {} 20 | 21 | std::string getPassName() const override { 22 | return "eliminate_consecutive_idempotent_ops"; 23 | } 24 | 25 | bool patternMatchPredicate(Node* node) override { 26 | static const std::unordered_set idempotent_ops = { 27 | "Ceil", "Floor", "Round", "Relu", "Reshape"}; 28 | for (const auto& op : idempotent_ops) { 29 | // TODO: support uses().size() > 1 for ops except Reshape 30 | if (CheckKind(node, Symbol(op), 0, Symbol(op)) && 31 | node->input(0)->uses().size() == 1) { 32 | return true; 33 | } 34 | } 35 | return false; 36 | } 37 | bool runTransform(Node* node, Graph& graph, 38 | NodeDestroyType& destroy_current) override { 39 | Node* previous_node = node->input(0)->node(); 40 | std::vector sizes = previous_node->input(0)->sizes(); 41 | bool replacing_success = 42 | tryReplacingAllUsesWith(node->input(0), previous_node->input(0)); 43 | if (replacing_success) { 44 | if (node->kind() == kReshape) { 45 | // restore the correct sizes 
46 | previous_node->input(0)->setSizes(sizes); 47 | } 48 | return true; 49 | } 50 | return false; 51 | } 52 | }; 53 | 54 | } // namespace optimization 55 | } // namespace ONNX_NAMESPACE 56 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_deadend.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | #pragma once 8 | #include "onnxoptimizer/pass.h" 9 | namespace ONNX_NAMESPACE { 10 | namespace optimization { 11 | struct EliminateDeadEnd final : public FullGraphBasedPass { 12 | explicit EliminateDeadEnd() 13 | : FullGraphBasedPass(PassType::Nop, PassEfficiency::Complete, 14 | PassOptimizationType::Compute) {} 15 | std::string getPassName() const override { 16 | return "eliminate_deadend"; 17 | } 18 | PassAnalysisType getPassAnalysisType() const override { 19 | return PassAnalysisType::CountBased; 20 | } 21 | unsigned int EliminateDead(Graph& graph) { 22 | unsigned int nodes_removed = 0; 23 | auto nodes = graph.nodes().reverse(); 24 | for (auto it = nodes.begin(); it != nodes.end(); it++) { 25 | auto node = *it; 26 | if (!node->hasUses()) { 27 | nodes_removed++; 28 | it.destroyCurrent(); 29 | } 30 | } 31 | return nodes_removed; 32 | } 33 | std::shared_ptr runPass(Graph& graph) override { 34 | auto nodes_removed = this->EliminateDead(graph); 35 | return std::shared_ptr( 36 | new CountBasedPassAnalysis(this, nodes_removed, false, false)); 37 | } 38 | }; 39 | } // namespace optimization 40 | } // namespace ONNX_NAMESPACE 41 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_duplicate_initializer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: 
Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#pragma once

// Eliminates initializers that are duplicates of another initializer by
// rewiring every use to a single surviving copy.
//
// Before:
//   A, B are in the initializer list, and A is equal to B
//   E = Add(D, A)
//   F = Add(F, B)
//   G = Add(E, F)
// After:
//   A is in the initializer list
//   E = Add(D, A)
//   F = Add(F, A)
//   G = Add(E, F)
//
// NOTE: ONNX IR has a bug that an initializer must also
// be a graph input. Currently we are using a workaround
// that adds initializers to inputs before optimization
// and removes the added initializers from inputs after
// optimization. That makes us unable to distinguish the
// initializers really in the inputs from the initializers
// not in the inputs. While only the latter can be eliminated,
// we eliminate all duplicated initializers instead. That
// may cause unexpected behavior in some rare cases.

// NOTE(review): the two header names below were lost in the source
// extraction (likely <string> and <unordered_map>/<vector>) — confirm
// against upstream.
#include
#include

#include "onnx/defs/tensor_util.h"
#include "onnxoptimizer/pass.h"
#include "onnxoptimizer/passes/cse_util.h"

namespace ONNX_NAMESPACE {
namespace optimization {

struct EliminateDuplicateInitializer final : public FullGraphBasedPass {
  explicit EliminateDuplicateInitializer()
      : FullGraphBasedPass(PassType::Nop, PassEfficiency::Complete,
                           PassOptimizationType::Memory) {}
  std::string getPassName() const override {
    return "eliminate_duplicate_initializer";
  }
  PassAnalysisType getPassAnalysisType() const override {
    return PassAnalysisType::CountBased;
  }

  // Linear scan over the initializer node's outputs for the Value whose
  // unique name matches `name`; returns nullptr when absent.
  Value *findInitializerValueByName(Node *initializer_node,
                                    const std::string &name) {
    for (size_t i = 0; i < initializer_node->outputs().size(); i++) {
      if (initializer_node->outputs()[i]->uniqueName() == name) {
        return initializer_node->outputs()[i];
      }
    }
    return nullptr;
  }

  // Returns the number of initializers removed from the graph.
  unsigned int EliminateInitializer(Graph &graph) {
    unsigned int initializers_removed = 0;
    const std::vector &initializers = graph.initializers();

    // Make {name : Value} map
    // Names of real graph inputs — these must never be eliminated.
    std::unordered_set input_set;
    for (auto inp : graph.inputs()) {
      if (inp->has_unique_name()) {
        input_set.emplace(inp->uniqueName());
      }
    }

    // Names of graph outputs — also protected from elimination.
    std::unordered_set output_set;
    for (auto out : graph.outputs()) {
      if (out->has_unique_name()) {
        output_set.emplace(out->uniqueName());
      }
    }
    // Maps each distinct tensor *content* to the name of its first
    // occurrence, so later equal tensors can be redirected to it.
    // NOTE(review): the template arguments (key/hash/equality comparing
    // tensors by content, presumably from cse_util.h) were lost in the
    // source extraction — confirm against upstream.
    std::unordered_map
        initializer_map;
    // Pairs of (duplicate name, surviving name).
    std::vector> replaced_table;
    for (const auto& initializer : initializers) {
      if (!initializer.hasName()) {
        continue;
      }
      const auto &name = initializer.name();
      // Ignore initializer which is an input
      if (input_set.find(name) != input_set.end()) {
        continue;
      }
      // Ignore initializer which is output
      if (output_set.find(name) != output_set.end()) {
        continue;
      }
      if (initializer_map.count(&initializer) == 0) {
        initializer_map[&initializer] = name;
      } else {
        replaced_table.emplace_back(
            std::make_pair(name, initializer_map.at(&initializer)));
      }
    }
    if (replaced_table.empty()) {
      return initializers_removed;
    }
    // workaround to fetch initializer_node_ pointer in graph: adding a dummy
    // initializer yields a Value whose node() is the shared initializer node.
    Tensor dummy_tensor;
    dummy_tensor.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
    Node *initializer_node =
        graph.addInitializerAndCreateValue(dummy_tensor)->node();
    VLOG(1) << Str("====== Graph: ", graph.name(), "=====");
    for (const auto &p : replaced_table) {
      VLOG(1) << Str("<", p.first, ",", p.second, ">");
      Value *old_value = findInitializerValueByName(initializer_node, p.first);
      Value *new_value = findInitializerValueByName(initializer_node, p.second);
      if (!old_value || !new_value) {
        continue;
      }
      old_value->replaceAllUsesWith(new_value);
      graph.eraseInitializerAndInput(old_value);
      initializers_removed++;
    }
    VLOG(1) << Str("====== Graph: ", graph.name(),
                   "=====, removed: ", initializers_removed);
    // Clean up the dummy tensor added by the workaround above.
    graph.eraseInitializer(dummy_tensor.name());
    return initializers_removed;
  }
  std::shared_ptr runPass(Graph &graph) override {
    auto initializers_removed = this->EliminateInitializer(graph);
    return std::shared_ptr(
        new CountBasedPassAnalysis(this, initializers_removed, false, false));
  }
};
}  // namespace optimization
}  // namespace ONNX_NAMESPACE

// =============================================================================
// onnxoptimizer/passes/eliminate_identity.h
// =============================================================================

/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | 12 | namespace ONNX_NAMESPACE { 13 | namespace optimization { 14 | 15 | struct EliminateIdentity final : public PredicateBasedPass { 16 | explicit EliminateIdentity() 17 | : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete, 18 | PassOptimizationType::Compute) {} 19 | 20 | std::string getPassName() const override { 21 | return "eliminate_identity"; 22 | } 23 | 24 | bool patternMatchPredicate(Node* node) override { 25 | return node->kind() == kIdentity; 26 | } 27 | bool runTransform(Node* node, Graph& graph, 28 | NodeDestroyType& destroy_current) override { 29 | const bool replacing_success = 30 | tryReplacingAllUsesWith(node->output(), node->input()); 31 | if (!replacing_success) { 32 | return false; 33 | } 34 | destroy_current = NodeDestroyType::DestroyOne; 35 | return true; 36 | } 37 | }; 38 | 39 | } // namespace optimization 40 | } // namespace ONNX_NAMESPACE 41 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_if_with_const_cond.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | #include "onnxoptimizer/passes/pass_util.h" 12 | 13 | namespace ONNX_NAMESPACE { 14 | namespace optimization { 15 | 16 | // This optimization works well especially when used together with 17 | // constant folding (onnx-simplifier), for example, the if node 18 | // introduced by PyTorch squeeze op will be eliminated when the input 19 | // shape is known. 
20 | // Ideally eliminate_if_with_const_cond + eliminate_deadend + constant 21 | // folding can be replaced by the more powerful sparse conditional 22 | // constant propagation, which obviously cannot be implemented in 23 | // the current optimizer framework. 24 | 25 | struct EliminateIfWithConstCond final : public PredicateBasedPass { 26 | explicit EliminateIfWithConstCond() 27 | : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete, 28 | PassOptimizationType::Compute) {} 29 | 30 | std::string getPassName() const override { 31 | return "eliminate_if_with_const_cond"; 32 | } 33 | 34 | // step 1: find "if" node with constant cond (i.e. const true or false) 35 | bool patternMatchPredicate(Node *node) override { 36 | if (node->kind() == kIf) { 37 | const auto cond_value = node->input(); 38 | if ((cond_value->node()->kind() == kConstant || 39 | cond_value->owningGraph()->is_constant_initializer(cond_value))) { 40 | return true; 41 | } 42 | } 43 | return false; 44 | } 45 | 46 | // step 2: inline the subgraph (for example, inline then_branch when cond === 47 | // true) 48 | // by re-creating all subgraph nodes in parent graph 49 | // note: handle captured value 50 | // step 3: Delete "if" node itself 51 | bool runTransform(Node *if_node, Graph &graph, 52 | NodeDestroyType &destroy_current) override { 53 | const auto cond_value = if_node->input(); 54 | const Tensor *cond_tensor = FetchConstantTensor(cond_value); 55 | const bool cond = ParseTensorData(cond_tensor)[0]; 56 | auto &parent_graph = graph; 57 | const auto subgraph = if_node->g(cond ? 
kthen_branch : kelse_branch); 58 | 59 | std::unordered_map unique_name_to_value_in_parent; 60 | 61 | for (auto *x : parent_graph.nodes()) { 62 | for (auto *x_output : x->outputs()) { 63 | unique_name_to_value_in_parent[x_output->uniqueName()] = x_output; 64 | } 65 | } 66 | std::unordered_map value_dict; 67 | for (auto *node : subgraph->nodes()) { 68 | auto *new_node = 69 | parent_graph.create(node->kind(), node->outputs().size()); 70 | new_node->insertBefore(if_node); 71 | new_node->copyAttributes(*node); 72 | for (const auto *input : node->inputs()) { 73 | const auto &unique_name = input->uniqueName(); 74 | if (value_dict.find(unique_name) == value_dict.end()) { 75 | if (input->node()->kind() == kCaptured) { 76 | auto it = unique_name_to_value_in_parent.find(unique_name); 77 | if (it == unique_name_to_value_in_parent.end()) { 78 | // a value from the parent graph of parent_graph 79 | auto *captured_node = parent_graph.create(kCaptured, 1); 80 | captured_node->output()->setUniqueName(unique_name); 81 | new_node->addInput(captured_node->output()); 82 | } else { 83 | new_node->addInput(it->second); 84 | } 85 | } else if (input->node()->kind() == kParam) { 86 | ONNX_ASSERT(subgraph->is_constant_initializer(input)); 87 | const Tensor &initializer_subgraph = 88 | *subgraph->getInitializer(input->uniqueName()); 89 | // copy a new tensor 90 | Tensor initializer_parent_graph = initializer_subgraph; 91 | new_node->addInput(parent_graph.addInitializerAndCreateValue( 92 | initializer_parent_graph)); 93 | } else { 94 | ONNX_ASSERTM( 95 | false, 96 | "input node not in value_dict can only be captured or param"); 97 | } 98 | } else { 99 | new_node->addInput(value_dict[unique_name]); 100 | } 101 | } 102 | for (int i = 0; i < node->outputs().size(); i++) { 103 | const auto *output_in_subgraph = node->outputs()[i]; 104 | auto *output_in_parent_graph = new_node->outputs()[i]; 105 | value_dict[output_in_subgraph->uniqueName()] = output_in_parent_graph; 106 | } 107 | } 108 | const 
auto &subgraph_outputs = subgraph->outputs(); 109 | for (int i = 0; i < subgraph_outputs.size(); i++) { 110 | auto *new_output = value_dict[subgraph_outputs[i]->uniqueName()]; 111 | auto *if_output = if_node->outputs()[i]; 112 | if_output->replaceAllUsesWith(new_output); 113 | } 114 | destroy_current = DestroyOne; 115 | return true; 116 | } 117 | }; 118 | 119 | } // namespace optimization 120 | } // namespace ONNX_NAMESPACE 121 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_nop_cast.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | 12 | namespace ONNX_NAMESPACE { 13 | namespace optimization { 14 | 15 | struct EliminateNopCast final : public PredicateBasedPass { 16 | explicit EliminateNopCast() 17 | : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete, 18 | PassOptimizationType::Compute) {} 19 | 20 | std::string getPassName() const override { 21 | return "eliminate_nop_cast"; 22 | } 23 | 24 | bool patternMatchPredicate(Node* node) override { 25 | return (node->kind() == kCast && node->hasAttribute(kto) && 26 | node->input()->elemType() == node->i(kto)); 27 | } 28 | 29 | bool runTransform(Node* node, Graph& graph, 30 | NodeDestroyType& destroy_current) override { 31 | const bool replacing_success = 32 | tryReplacingAllUsesWith(node->output(), node->input()); 33 | if (!replacing_success) { 34 | return false; 35 | } 36 | destroy_current = NodeDestroyType::DestroyOne; 37 | return true; 38 | } 39 | }; 40 | 41 | } // namespace optimization 42 | } // namespace ONNX_NAMESPACE 43 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_nop_concat.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | 12 | namespace ONNX_NAMESPACE { 13 | namespace optimization { 14 | 15 | struct EliminateNopConcat final : public PredicateBasedPass { 16 | explicit EliminateNopConcat() 17 | : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete, 18 | PassOptimizationType::Memory) {} 19 | 20 | std::string getPassName() const override { 21 | return "eliminate_nop_concat"; 22 | } 23 | 24 | bool patternMatchPredicate(Node* node) override { 25 | return node->kind() == kConcat && node->inputs().size() == 1; 26 | } 27 | 28 | bool runTransform(Node* node, Graph& graph, 29 | NodeDestroyType& destroy_current) override { 30 | const bool replacing_success = 31 | tryReplacingAllUsesWith(node->output(), node->input()); 32 | if (!replacing_success) { 33 | return false; 34 | } 35 | destroy_current = NodeDestroyType::DestroyOne; 36 | return true; 37 | } 38 | }; 39 | 40 | } // namespace optimization 41 | } // namespace ONNX_NAMESPACE 42 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_nop_dropout.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | 12 | namespace ONNX_NAMESPACE { 13 | namespace optimization { 14 | 15 | struct EliminateNopDropout final : public PredicateBasedPass { 16 | explicit EliminateNopDropout() 17 | : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete, 18 | PassOptimizationType::Compute) {} 19 | 20 | std::string getPassName() const override { 21 | return "eliminate_nop_dropout"; 22 | } 23 | 24 | bool patternMatchPredicate(Node* node) override { 25 | // in opset 12, ratio is an input of Dropout rather than an attribute, 26 | // however we don't want to to remove Dropout fro opset 12+, since it 27 | // supports training-friendly models, for which the Dropout ops are required 28 | return (node->kind() == kDropout && node->hasAttribute(kratio)) && 29 | node->f(kratio) == 0.0; 30 | } 31 | 32 | bool runTransform(Node* node, Graph& graph, 33 | NodeDestroyType& destroy_current) override { 34 | // Don't assume that theres only one output. 35 | for (size_t i = 0; i < node->outputs().size(); ++i) { 36 | const bool replacing_success = 37 | tryReplacingAllUsesWith(node->outputs()[i], node->input()); 38 | if (!replacing_success) { 39 | return false; 40 | } 41 | } 42 | destroy_current = NodeDestroyType::DestroyOne; 43 | return true; 44 | } 45 | }; 46 | 47 | } // namespace optimization 48 | } // namespace ONNX_NAMESPACE 49 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_nop_expand.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | #include "pass_util.h" 12 | 13 | namespace ONNX_NAMESPACE { 14 | namespace optimization { 15 | 16 | struct EliminateNopExpand final : public PredicateBasedPass { 17 | explicit EliminateNopExpand() 18 | : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete, 19 | PassOptimizationType::Compute) {} 20 | 21 | std::string getPassName() const override { 22 | return "eliminate_nop_expand"; 23 | } 24 | 25 | bool patternMatchPredicate(Node* node) override { 26 | return node->kind() == kExpand && IsConstantTensor(node, 1); 27 | } 28 | 29 | bool runTransform(Node* node, Graph& graph, 30 | NodeDestroyType& destroy_current) override { 31 | auto& input_value = node->inputs()[0]; 32 | const auto* shape_tensor = FetchConstantTensor(node->input(1)); 33 | 34 | if (!shape_tensor || 35 | !isABroadcastToB(ParseTensorData(shape_tensor), 36 | input_value->sizes()) || 37 | !tryReplacingAllUsesWith(node->output(), input_value)) { 38 | return false; 39 | } 40 | 41 | destroy_current = NodeDestroyType::DestroyOne; 42 | return true; 43 | } 44 | }; 45 | 46 | } // namespace optimization 47 | } // namespace ONNX_NAMESPACE 48 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_nop_flatten.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | #include "onnxoptimizer/passes/pass_util.h" 12 | 13 | namespace ONNX_NAMESPACE { 14 | namespace optimization { 15 | 16 | struct EliminateNopFlatten final : public PredicateBasedPass { 17 | explicit EliminateNopFlatten() 18 | : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete, 19 | PassOptimizationType::Compute) {} 20 | 21 | std::string getPassName() const override { 22 | return "eliminate_nop_flatten"; 23 | } 24 | 25 | bool patternMatchPredicate(Node *node) override { 26 | if (!CheckKind(node, "Flatten")) { 27 | return false; 28 | } 29 | const Value *input = node->input(); 30 | if (!input->has_sizes()) { 31 | return false; 32 | } 33 | const auto input_shape = input->sizes(); 34 | const int64_t axis = GetValueFromAttrWithDefault(node, kaxis, 1); 35 | if (input_shape.size() == 2) { 36 | if (axis == 1 || axis == -1) { 37 | return true; 38 | } 39 | if (input_shape[0].is_int && input_shape[0].dim == 1 && axis == 0) { 40 | return true; 41 | } 42 | } 43 | 44 | return false; 45 | } 46 | 47 | bool runTransform(Node *node, Graph &graph, 48 | NodeDestroyType &destroy_current) override { 49 | const bool replacing_success = 50 | tryReplacingAllUsesWith(node->output(), node->input()); 51 | if (!replacing_success) { 52 | return false; 53 | } 54 | destroy_current = NodeDestroyType::DestroyOne; 55 | return true; 56 | } 57 | }; 58 | 59 | } // namespace optimization 60 | } // namespace ONNX_NAMESPACE 61 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_nop_monotone_argmax.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | #pragma once 8 | 9 | #include "onnxoptimizer/pass.h" 10 | 11 | namespace ONNX_NAMESPACE { 12 | namespace optimization { 13 | 14 | // Note for log and sqrt this optimization is not always right, 15 | // because it is a undefined behavior when the input is negative 16 | const std::unordered_set monotone_node_no_axis_kind{kLog, kExp, 17 | kSqrt}; 18 | 19 | const std::unordered_set monotone_node_axis_kind{kSoftmax, 20 | kLogSoftmax}; 21 | 22 | struct EliminateNopMonotoneArgmax final : public PredicateBasedPass { 23 | explicit EliminateNopMonotoneArgmax() 24 | : PredicateBasedPass(PassType::Nop, PassEfficiency::Partial, 25 | PassOptimizationType::Compute) {} 26 | 27 | std::string getPassName() const override { 28 | return "eliminate_nop_monotone_argmax"; 29 | } 30 | 31 | static inline bool satisfies_monotone_condition(int64_t axis, Node* node) { 32 | if (monotone_node_no_axis_kind.find(node->kind()) != 33 | monotone_node_no_axis_kind.end()) { 34 | return true; 35 | } 36 | if (monotone_node_axis_kind.find(node->kind()) != 37 | monotone_node_axis_kind.end()) { 38 | if (node->hasAttribute(kaxis)) { 39 | return axis == node->i(kaxis); 40 | } 41 | } 42 | return false; 43 | } 44 | 45 | bool patternMatchPredicate(Node* node) override { 46 | if (node->kind() == kArgMax) { 47 | if (node->hasAttribute(kaxis)) { 48 | auto node_axis = node->i(kaxis); 49 | return node->inputs().size() == 1 && 50 | satisfies_monotone_condition(node_axis, node->input()->node()); 51 | } 52 | } 53 | return false; 54 | } 55 | 56 | bool runTransform(Node* node, Graph&, NodeDestroyType&) override { 57 | Node* monotone_node = node->input()->node(); 58 | if (monotone_node->output()->uses().size() == 1) { 59 | const bool replacing_success = tryReplacingAllUsesWith( 60 | monotone_node->output(), monotone_node->input()); 61 | if (!replacing_success) { 62 | return false; 63 | } 64 | monotone_node->destroy(); 65 | return true; 66 | } 67 | return false; 68 | } 69 | }; 70 | 71 | } // namespace optimization 72 | 
} // namespace ONNX_NAMESPACE 73 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_nop_pad.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | #include "onnx/defs/tensor_util.h" 11 | #include "onnxoptimizer/pass.h" 12 | #include "onnxoptimizer/passes/pass_util.h" 13 | #include "onnxoptimizer/passes/string_utils.h" 14 | #include "onnxoptimizer/passes/logging.h" 15 | 16 | 17 | namespace ONNX_NAMESPACE { 18 | namespace optimization { 19 | 20 | struct EliminateNopPad final : public PredicateBasedPass { 21 | explicit EliminateNopPad() 22 | : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete, 23 | PassOptimizationType::Compute) {} 24 | 25 | std::string getPassName() const override { 26 | return "eliminate_nop_pad"; 27 | } 28 | 29 | static bool is_nop_pad(Node* node, Graph& graph) { 30 | std::vector pads; 31 | if (!GetValueFromAttrOrInput(node, kpads, 1, pads) || pads.empty()) { 32 | return false; 33 | } 34 | VLOG(1) << Str("pads",pads); 35 | for (const auto& p : pads) { 36 | if (p != 0) { 37 | return false; 38 | } 39 | } 40 | return true; 41 | } 42 | 43 | bool patternMatchPredicate(Node* node) override { 44 | return node->kind() == kPad; 45 | } 46 | 47 | bool runTransform(Node* node, Graph& graph, 48 | NodeDestroyType& destroy_current) override { 49 | if (!is_nop_pad(node, graph)) 50 | return false; 51 | const bool replacing_success = 52 | tryReplacingAllUsesWith(node->output(), node->inputs()[0]); 53 | if (!replacing_success) { 54 | return false; 55 | } 56 | destroy_current = NodeDestroyType::DestroyOne; 57 | return true; 58 | } 59 | }; 60 | 61 | } // namespace optimization 62 | } // namespace ONNX_NAMESPACE 63 | 
-------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_nop_reshape.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | #include "pass_util.h" 12 | 13 | namespace ONNX_NAMESPACE { 14 | namespace optimization { 15 | 16 | struct EliminateNopReshape final : public PredicateBasedPass { 17 | explicit EliminateNopReshape() 18 | : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete, 19 | PassOptimizationType::Compute) {} 20 | 21 | std::string getPassName() const override { 22 | return "eliminate_nop_reshape"; 23 | } 24 | 25 | bool patternMatchPredicate(Node *node) override { 26 | return node->kind() == kReshape && !node->inputs()[0]->sizes().empty() && 27 | IsConstantTensor(node, 1); 28 | } 29 | 30 | bool runTransform(Node *node, Graph &graph, 31 | NodeDestroyType &destroy_current) override { 32 | const auto &old_shape = node->inputs()[0]->sizes(); 33 | const auto *new_shape_input = node->inputs()[1]; 34 | 35 | const Tensor *new_shape_tensor = FetchConstantTensor(new_shape_input); 36 | if (!new_shape_tensor) { 37 | return false; 38 | } 39 | 40 | if (new_shape_tensor->elem_type() != 41 | ONNX_NAMESPACE::TensorProto_DataType_INT64) { 42 | return false; 43 | } 44 | const auto new_shape = ParseTensorData(new_shape_tensor); 45 | 46 | if (new_shape.size() != old_shape.size()) { 47 | return false; 48 | } 49 | 50 | int unknown_dim_count = 0; 51 | for (int i = 0; i < new_shape.size(); ++i) { 52 | const auto new_dim = new_shape[i]; 53 | // the dim can be copied from the input only when allowzero == 0 54 | if (new_dim == 0 && !(node->hasAttribute(Symbol("allowzero")) && 55 | node->i(Symbol("allowzero")) == 1)) { 56 | continue; 57 
| } 58 | if (old_shape[i].is_int) { 59 | if (new_dim == -1) { 60 | unknown_dim_count++; 61 | } else if (old_shape[i].dim != new_dim) { 62 | return false; 63 | } 64 | } else { 65 | unknown_dim_count++; 66 | } 67 | } 68 | if (unknown_dim_count > 1) { 69 | return false; 70 | } 71 | 72 | const bool replacing_success = 73 | tryReplacingAllUsesWith(node->output(), node->inputs()[0]); 74 | if (!replacing_success) { 75 | return false; 76 | } 77 | destroy_current = NodeDestroyType::DestroyOne; 78 | return true; 79 | } 80 | }; 81 | 82 | } // namespace optimization 83 | } // namespace ONNX_NAMESPACE 84 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_nop_split.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | #include "pass_util.h" 12 | 13 | namespace ONNX_NAMESPACE { 14 | namespace optimization { 15 | 16 | struct EliminateNopSplit final : public PredicateBasedPass { 17 | explicit EliminateNopSplit() 18 | : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete, 19 | PassOptimizationType::Memory) {} 20 | 21 | std::string getPassName() const override { 22 | return "eliminate_nop_split"; 23 | } 24 | 25 | bool patternMatchPredicate(Node* node) override { 26 | return CheckKind(node, "Split") && node->inputs()[0]->has_sizes() && 27 | node->outputs().size() == 1; 28 | } 29 | 30 | bool runTransform(Node* node, Graph& graph, 31 | NodeDestroyType& destroy_current) override { 32 | auto* input = node->inputs()[0]; 33 | const auto& sizes = input->sizes(); 34 | int64_t axis = GetValueFromAttrWithDefault(node, kaxis, int64_t{0}); 35 | axis = AddYIfNegative(axis, static_cast(sizes.size())); 36 | std::vector split; 37 | if (GetValueFromAttrOrInput(node, ksplit, 1, split) && !split.empty() && 38 | (!sizes[axis].is_int || sizes[axis].dim != split[0])) { 39 | return false; 40 | } 41 | 42 | const bool replacing_success = 43 | tryReplacingAllUsesWith(node->output(), input); 44 | if (!replacing_success) { 45 | return false; 46 | } 47 | destroy_current = NodeDestroyType::DestroyOne; 48 | return true; 49 | } 50 | }; 51 | 52 | } // namespace optimization 53 | } // namespace ONNX_NAMESPACE 54 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_nop_transpose.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | 12 | namespace ONNX_NAMESPACE { 13 | namespace optimization { 14 | 15 | struct EliminateNopTranspose final : public PredicateBasedPass { 16 | explicit EliminateNopTranspose() 17 | : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete, 18 | PassOptimizationType::Compute) {} 19 | 20 | std::string getPassName() const override { 21 | return "eliminate_nop_transpose"; 22 | } 23 | 24 | static bool is_nop_transpose(const std::vector& perm) { 25 | for (size_t i = 0; i < perm.size(); i++) 26 | if (perm[i] != (int)i) 27 | return false; 28 | return true; 29 | } 30 | 31 | bool patternMatchPredicate(Node* node) override { 32 | return (node->kind() == kTranspose && node->hasAttribute(kperm)) && 33 | is_nop_transpose(node->is(kperm)); 34 | } 35 | 36 | bool runTransform(Node* node, Graph& graph, 37 | NodeDestroyType& destroy_current) override { 38 | const bool replacing_success = 39 | tryReplacingAllUsesWith(node->output(), node->input()); 40 | if (!replacing_success) { 41 | return false; 42 | } 43 | destroy_current = NodeDestroyType::DestroyOne; 44 | return true; 45 | } 46 | }; 47 | 48 | } // namespace optimization 49 | } // namespace ONNX_NAMESPACE 50 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_nop_with_unit.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #pragma once 9 | 10 | #include "onnx/defs/tensor_util.h" 11 | #include "onnxoptimizer/pass.h" 12 | #include "onnxoptimizer/passes/pass_util.h" 13 | 14 | namespace ONNX_NAMESPACE { 15 | namespace optimization { 16 | 17 | struct EliminateOpWithUnit final : public PredicateBasedPass { 18 | explicit EliminateOpWithUnit() 19 | : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete, 20 | PassOptimizationType::Compute) {} 21 | 22 | std::string getPassName() const override { 23 | return "eliminate_nop_with_unit"; 24 | } 25 | 26 | #define PROTO_DTYPE_LIST(_) \ 27 | _(TensorProto_DataType_BFLOAT16) \ 28 | _(TensorProto_DataType_FLOAT16) \ 29 | _(TensorProto_DataType_FLOAT) \ 30 | _(TensorProto_DataType_DOUBLE) \ 31 | _(TensorProto_DataType_UINT8) \ 32 | _(TensorProto_DataType_INT8) \ 33 | _(TensorProto_DataType_UINT16) \ 34 | _(TensorProto_DataType_INT16) \ 35 | _(TensorProto_DataType_UINT32) \ 36 | _(TensorProto_DataType_INT32) \ 37 | _(TensorProto_DataType_UINT64) \ 38 | _(TensorProto_DataType_INT64) \ 39 | _(TensorProto_DataType_BOOL) 40 | 41 | bool isAllOf(const Tensor& tensor, int value) { 42 | int elem_type = tensor.elem_type(); 43 | #define CASE_BRANCH_CONTENT(pb_dtype) \ 44 | case pb_dtype: { \ 45 | using cpp_dtype = ToCppType::type; \ 46 | const std::vector data = ParseTensorData(&tensor); \ 47 | return std::all_of( \ 48 | data.cbegin(), data.cend(), \ 49 | [value](const cpp_dtype& v) { return v == cpp_dtype(value); }); \ 50 | } 51 | 52 | switch (elem_type) { PROTO_DTYPE_LIST(CASE_BRANCH_CONTENT) } 53 | #undef CASE_BRANCH_CONTENT 54 | #undef PROTO_DTYPE_LIST 55 | return false; 56 | } 57 | 58 | bool isAllOne(const Tensor& tensor) { 59 | if (tensor.elem_type() == TensorProto_DataType_FLOAT16) { 60 | return isAllOf(tensor, 0x3c00); 61 | } 62 | if (tensor.elem_type() == TensorProto_DataType_BFLOAT16) { 63 | return isAllOf(tensor, BFloat16(1.f)); 64 | } 65 | return isAllOf(tensor, 1); 66 | } 67 | 68 | bool patternMatchPredicate(Node* node) override { 69 | 
return true; 70 | }; 71 | 72 | bool isUnit(const Tensor& tensor, NodeKind kind, int index) { 73 | if (kind == Symbol("And") || kind == kMul) { 74 | return isAllOne(tensor); 75 | } 76 | if (kind == Symbol("Or") || kind == kAdd) { 77 | return isAllOf(tensor, 0); 78 | } 79 | if (kind == kSub) { 80 | return index == 1 && isAllOf(tensor, 0); 81 | } 82 | if (kind == kDiv || kind == kPow) { 83 | return index == 1 && isAllOne(tensor); 84 | } 85 | if (kind == kConcat) { 86 | return ElemCntOfTensor(tensor) == 0; 87 | } 88 | return false; 89 | } 90 | bool isBroadcastBinaryOp(NodeKind kind) { 91 | return kind == kAdd || kind == kMul || kind == kDiv || kind == kSub || 92 | kind == kPow || kind == Symbol("And") || kind == Symbol("Or"); 93 | } 94 | 95 | bool runTransform(Node* node, Graph& graph, 96 | NodeDestroyType& destroy_current) override { 97 | for (int i = 0; i < node->inputs().size(); i++) { 98 | auto* input = node->inputs()[i]; 99 | if (auto* tensor = FetchConstantTensor(input)) { 100 | NodeKind kind = node->kind(); 101 | if (isUnit(*tensor, kind, i)) { 102 | if (isBroadcastBinaryOp(kind)) { 103 | // replace the node with the other input 104 | auto* other_input = node->inputs()[1 - i]; 105 | if (isABroadcastToB(tensor->sizes(), other_input->sizes())) { 106 | return tryReplacingAllUsesWith(node->output(), other_input); 107 | } 108 | } 109 | if (kind == kConcat) { 110 | node->removeInput(i); 111 | return true; 112 | } 113 | } 114 | } 115 | } 116 | 117 | return false; 118 | } 119 | }; 120 | 121 | } // namespace optimization 122 | } // namespace ONNX_NAMESPACE 123 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_shape_gather.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #pragma once 9 | 10 | #include "onnx/defs/tensor_util.h" 11 | #include "onnxoptimizer/pass.h" 12 | #include "onnxoptimizer/passes/pass_util.h" 13 | 14 | namespace ONNX_NAMESPACE { 15 | namespace optimization { 16 | 17 | struct EliminateShapeGather final : public PredicateBasedPass { 18 | explicit EliminateShapeGather() 19 | : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete, 20 | PassOptimizationType::Compute) {} 21 | 22 | std::string getPassName() const override { 23 | return "eliminate_shape_gather"; 24 | } 25 | 26 | bool patternMatchPredicate(Node *node) override { 27 | return CheckKind(node, "Gather", 0, "Shape") && IsConstantTensor(node, 1) && 28 | HasDimsOfInputOfNode(PrevNode(node, 0), 0); 29 | } 30 | 31 | bool runTransform(Node *node, Graph &graph, 32 | NodeDestroyType &destroy_current) override { 33 | auto *x = node->inputs()[0]; 34 | auto *indices = node->inputs()[1]; 35 | Node *shape = x->node(); 36 | const auto &dims = shape->input()->sizes(); 37 | 38 | int64_t indices_val; 39 | if (!FetchSoleIntValueOfTensor(indices, indices_val)) { 40 | return false; 41 | } 42 | 43 | const auto [start, end] = FetchStartAndEndAttrOfShape(shape); 44 | 45 | indices_val = AddYIfNegative(indices_val, end - start); 46 | indices_val += start; 47 | 48 | ONNX_ASSERT(indices_val < dims.size()); 49 | 50 | if (!dims[indices_val].is_int || dims[indices_val].dim == -1) { 51 | return false; 52 | } 53 | 54 | Tensor tensor; 55 | if (indices->sizes().size() == 1) { 56 | tensor.sizes().push_back(1); 57 | } 58 | tensor.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_INT64; 59 | tensor.int64s().push_back(dims[indices_val].dim); 60 | Value *value = graph.addInitializerAndCreateValue(tensor); 61 | 62 | const bool replacing_success = 63 | tryReplacingAllUsesWith(node->output(), value); 64 | if (!replacing_success) { 65 | return false; 66 | } 67 | destroy_current = NodeDestroyType::DestroyOne; 68 | return true; 69 | } 70 | }; 71 | 72 | } // namespace optimization 73 | } 
// namespace ONNX_NAMESPACE 74 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_shape_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | #include "onnx/common/tensor.h" 11 | #include "onnxoptimizer/pass.h" 12 | #include "onnxoptimizer/passes/pass_util.h" 13 | 14 | namespace ONNX_NAMESPACE { 15 | namespace optimization { 16 | 17 | struct EliminateShapeOp final : public PredicateBasedPass { 18 | explicit EliminateShapeOp() 19 | : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete, 20 | PassOptimizationType::Compute) {} 21 | 22 | std::string getPassName() const override { 23 | return "eliminate_shape_op"; 24 | } 25 | 26 | bool patternMatchPredicate(Node *node) override { 27 | if (!CheckKind(node, "Shape") || !HasDimsOfInputOfNode(node, 0)) { 28 | return false; 29 | } 30 | const Value *input = node->input(); 31 | 32 | const auto [start, end] = FetchStartAndEndAttrOfShape(node); 33 | 34 | return std::all_of( 35 | input->sizes().cbegin() + start, input->sizes().cbegin() + end, 36 | [](const auto &dim) { return dim.is_int && dim.dim >= 0; }); 37 | } 38 | 39 | bool runTransform(Node *node, Graph &graph, 40 | NodeDestroyType &destroy_current) override { 41 | const Value *input = node->input(); 42 | const auto [start, end] = FetchStartAndEndAttrOfShape(node); 43 | 44 | Tensor tensor; 45 | tensor.sizes().push_back(end - start); 46 | tensor.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_INT64; 47 | std::transform(input->sizes().begin() + start, input->sizes().begin() + end, 48 | std::back_inserter(tensor.int64s()), 49 | [](const auto &dim) { return dim.dim; }); 50 | 51 | Value *value = graph.addInitializerAndCreateValue(tensor); 52 | 53 | 
const bool replacing_success = 54 | tryReplacingAllUsesWith(node->output(), value); 55 | if (!replacing_success) { 56 | return false; 57 | } 58 | destroy_current = NodeDestroyType::DestroyOne; 59 | return true; 60 | } 61 | }; 62 | 63 | } // namespace optimization 64 | } // namespace ONNX_NAMESPACE 65 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_slice_after_shape.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | #include "onnx/defs/tensor_util.h" 11 | #include "onnxoptimizer/pass.h" 12 | #include "onnxoptimizer/passes/pass_util.h" 13 | 14 | namespace ONNX_NAMESPACE { 15 | namespace optimization { 16 | 17 | struct EliminateSliceAfterShape final : public PredicateBasedPass { 18 | explicit EliminateSliceAfterShape() 19 | : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete, 20 | PassOptimizationType::Compute) {} 21 | 22 | std::string getPassName() const override { 23 | return "eliminate_slice_after_shape"; 24 | } 25 | 26 | bool patternMatchPredicate(Node *node) override { 27 | return CheckKind(node, kSlice, 0, "Shape") && 28 | HasDimsOfInputOfNode(PrevNode(node, 0), 0); 29 | } 30 | 31 | bool runTransform(Node *node, Graph &graph, 32 | NodeDestroyType &destroy_current) override { 33 | Node *shape_node = PrevNode(node, 0); 34 | const auto &dims_of_shape_node_input = shape_node->input()->sizes(); 35 | std::vector result_of_shape_op; 36 | 37 | { 38 | const auto [shape_start, shape_end] = 39 | FetchStartAndEndAttrOfShape(shape_node); 40 | 41 | for (int i = shape_start; i < shape_end; ++i) { 42 | result_of_shape_op.push_back(dims_of_shape_node_input[i]); 43 | } 44 | } 45 | 46 | int64_t slice_start = 0, slice_end = result_of_shape_op.size(), 
47 | slice_step = 1; 48 | const int opset_version = getOpsetVersion(graph); 49 | if (opset_version < 10) { 50 | if (!FetchSoleIntValueOfAttr(node, kstarts, slice_start) || 51 | !FetchSoleIntValueOfAttr(node, kends, slice_end)) { 52 | return false; 53 | } 54 | } else { 55 | if (!FetchSoleIntValueOfTensor(node->inputs()[1], slice_start) || 56 | !FetchSoleIntValueOfTensor(node->inputs()[2], slice_end) || 57 | (node->inputs().size() == 5 && 58 | !FetchSoleIntValueOfTensor(node->inputs()[4], slice_step)) || 59 | slice_step == 0) { 60 | return false; 61 | } 62 | } 63 | 64 | slice_start = 65 | AddYIfNegative(slice_start, result_of_shape_op.size()); 66 | slice_end = AddYIfNegative(slice_end, result_of_shape_op.size()); 67 | 68 | std::vector result_of_slice_op; 69 | if (slice_step > 0) { 70 | slice_start = 71 | std::clamp(slice_start, 0, result_of_shape_op.size()); 72 | slice_end = std::clamp(slice_end, 0, result_of_shape_op.size()); 73 | for (; slice_start < slice_end; slice_start += slice_step) { 74 | assert(slice_start < result_of_shape_op.size()); 75 | const auto &d = dims_of_shape_node_input[slice_start]; 76 | if (!d.is_int) { 77 | return false; 78 | } 79 | result_of_slice_op.push_back(d.dim); 80 | } 81 | } else { 82 | slice_start = 83 | std::clamp(slice_start, 0, result_of_shape_op.size() - 1); 84 | slice_end = std::clamp(slice_end, -1, result_of_shape_op.size()); 85 | for (; slice_start > slice_end; slice_start += slice_step) { 86 | assert(slice_start < result_of_shape_op.size() && slice_start >= 0); 87 | const auto &d = dims_of_shape_node_input[slice_start]; 88 | if (!d.is_int) { 89 | return false; 90 | } 91 | result_of_slice_op.push_back(d.dim); 92 | } 93 | } 94 | Tensor tensor; 95 | tensor.sizes().push_back(result_of_slice_op.size()); 96 | tensor.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_INT64; 97 | tensor.int64s().swap(result_of_slice_op); 98 | Value *value = graph.addInitializerAndCreateValue(tensor); 99 | 100 | const bool replacing_success = 101 | 
tryReplacingAllUsesWith(node->output(), value); 102 | if (!replacing_success) { 103 | return false; 104 | } 105 | destroy_current = NodeDestroyType::DestroyOne; 106 | return true; 107 | } 108 | }; 109 | 110 | } // namespace optimization 111 | } // namespace ONNX_NAMESPACE 112 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/eliminate_unused_initializer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | // Before: 11 | // A, B, C are in the initializer list 12 | // D = Add(B, C) 13 | // After: 14 | // B, C are in the initializer list and A is removed 15 | // D = Add(B, C) 16 | // 17 | // this pass can handle the case satisfy all following conditions: 18 | // condition 1: A is not used as any node's input 19 | // condition 2: A is not an output 20 | 21 | #include "onnxoptimizer/pass.h" 22 | 23 | namespace ONNX_NAMESPACE { 24 | namespace optimization { 25 | 26 | struct EliminateUnusedInitializer final : public FullGraphBasedPass { 27 | explicit EliminateUnusedInitializer() 28 | : FullGraphBasedPass(PassType::Nop, PassEfficiency::Complete, 29 | PassOptimizationType::Memory) {} 30 | 31 | std::string getPassName() const override { 32 | return "eliminate_unused_initializer"; 33 | } 34 | 35 | PassAnalysisType getPassAnalysisType() const override { 36 | return PassAnalysisType::Empty; 37 | } 38 | 39 | void erase_used_initializers( 40 | Graph& g, std::unordered_set* initializer_names) { 41 | for (auto output : g.outputs()) { 42 | initializer_names->erase(output->uniqueName()); 43 | } 44 | for (auto it = g.begin(); it != g.end(); ++it) { 45 | auto* n = *it; 46 | DescendOnGraphAttributesUnconstrained( 47 | n, [this, initializer_names](Graph& graph) { 48 | 
erase_used_initializers(graph, initializer_names); 49 | }); 50 | for (auto* input : n->inputs()) { 51 | initializer_names->erase(input->uniqueName()); 52 | } 53 | } 54 | } 55 | 56 | void eliminate_unused_initializer(Graph& graph) { 57 | std::unordered_set initializer_names( 58 | graph.initializer_names().begin(), graph.initializer_names().end()); 59 | erase_used_initializers(graph, &initializer_names); 60 | 61 | // remove initializer and input if need 62 | for (std::string name : initializer_names) { 63 | graph.eraseInitializer(name); 64 | auto iter = std::find_if( 65 | graph.inputs().begin(), graph.inputs().end(), 66 | [&name](Value* input) { return input->uniqueName() == name; }); 67 | if (iter != graph.inputs().end()) { 68 | graph.eraseInput(std::distance(graph.inputs().begin(), iter)); 69 | } 70 | } 71 | } 72 | 73 | std::shared_ptr runPass(Graph& graph) override { 74 | eliminate_unused_initializer(graph); 75 | return std::shared_ptr(new PostPassAnalysis()); 76 | } 77 | }; 78 | 79 | } // namespace optimization 80 | } // namespace ONNX_NAMESPACE 81 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/extract_constant_to_initializer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
#pragma once

// Before:
//   A = Constant()
// After:
//   A is in the initializer list
//
// this pass can handle the case satisfy all following conditions:
//   condition 1: A is the output of a Constant node
#include "onnx/common/assertions.h"
#include "onnxoptimizer/pass.h"

namespace ONNX_NAMESPACE {
namespace optimization {

// Moves the payload of every Constant node into the graph initializer list
// and deletes the node, so later passes can treat it as a plain initializer.
struct ExtractConstantToInitializer final : public PredicateBasedPass {
  explicit ExtractConstantToInitializer()
      : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete,
                          PassOptimizationType::Memory) {}

  std::string getPassName() const override {
    return "extract_constant_to_initializer";
  }

  bool patternMatchPredicate(Node* node) override {
    return node->kind() == kConstant;
  }

  bool runTransform(Node* node, Graph& graph,
                    NodeDestroyType& destroy_current) override {
    // Copy the constant payload out of the node's `value` attribute.
    Tensor t = node->t(kvalue);
    Value* new_init;
    if (node->output()->has_unique_name() &&
        std::find(graph.outputs().rbegin(), graph.outputs().rend(),
                  node->output()) == graph.outputs().rend()) {
      // The constant has a name and is not a graph output: keep that name on
      // the new initializer, and give the soon-to-die output a fresh unique
      // name so the two never collide.
      t.setName(node->output()->uniqueName());
      new_init = graph.addInitializerAndCreateValue(t);
      node->output()->setUniqueName(
          ONNX_NAMESPACE::to_string(graph.getNextUnique()), false);
    } else {
      // Graph-output constants (or unnamed ones) get an auto-named
      // initializer instead.
      new_init = graph.addInitializerAndCreateValue(t);
    }
    const bool replacing_success =
        tryReplacingAllUsesWith(node->output(), new_init);
    if (!replacing_success) {
      return false;
    }
    destroy_current = NodeDestroyType::DestroyOne;
    return true;
  }
};

}  // namespace optimization
}  // namespace ONNX_NAMESPACE

// ===== onnxoptimizer/passes/fuse_add_bias_into_conv.h =====

/*
 * SPDX-License-Identifier:
Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

#pragma once

// Before:
//   Z = Conv(X, Y)
//   B = Z + A
// After:
//   B = Conv(X, Y, A)
//
// the pass can handle the following cases:
//   case 1: A is 1D tensor and A.dim[0] == Z.dim[1]
//   case 2: A is 1-element 1D tensor

// NOTE(review): angle-bracket contents were stripped from this copy of the
// file — the bare `#include` below (likely <numeric>, for std::iota) and the
// `std::vector`/`static_cast` occurrences are missing their template
// arguments. Code tokens are preserved as-is; confirm against the upstream
// source before compiling.
#include 

#include "onnx/common/assertions.h"
#include "onnxoptimizer/pass.h"
#include "onnxoptimizer/passes/pass_util.h"

namespace ONNX_NAMESPACE {
namespace optimization {

// Folds `Conv(X, W) + A` into `Conv(X, W, bias)` when A is a constant that
// can be reshaped (squeezed/unsqueezed/tiled) into a per-channel bias.
struct FuseAddBiasIntoConv final : public PredicateBasedPass {
  explicit FuseAddBiasIntoConv()
      : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete,
                          PassOptimizationType::Compute) {}
  std::string getPassName() const override {
    return "fuse_add_bias_into_conv";
  }
  bool patternMatchPredicate(Node *node) override {
    // Add whose first operand comes from a bias-less (2-input) Conv.
    return CheckKind(node, kAdd, 0, kConv) &&
           GetInputsOfPreNode(node, 0).size() == 2;
  }
  // Creates a Squeeze or Unsqueeze (per `k`) of `input` over `axes`, inserted
  // before `target_node`. Axes go in the attribute for opset < 13 and in a
  // second (initializer) input from opset 13 on.
  static Node *makeSqueezeOrUnsqueeze(Graph &graph, std::vector &axes,
                                      Value *input, Node *target_node,
                                      BuiltinSymbol k) {
    assert(k == kSqueeze || k == kUnsqueeze);
    Node *squeeze = graph.create(k, 1);
    int opset_version = getOpsetVersion(graph);
    squeeze->addInput(input);
    int version_threshold = 13;
    // opset_version == 0 means "unknown"; treat it like the new scheme.
    if (opset_version < version_threshold && opset_version != 0) {
      squeeze->is_(kaxes, std::move(axes));
    } else {
      Tensor t;
      t.sizes().push_back(axes.size());
      t.int64s() = axes;
      t.elem_type() = TensorProto_DataType_INT64;
      Value *tv = graph.addInitializerAndCreateValue(t);
      squeeze->addInput(tv);
    }
    squeeze->insertBefore(target_node);
    return squeeze;
  }
  bool runTransform(Node *n, Graph &graph,
                    NodeDestroyType &destroy_current) override {
    // due to current broadcasting's constraint, Conv has to be the first
    // operand
    destroy_current = NodeDestroyType::DestroyZero;
    auto orig_conv = n->inputs()[0];
    auto orig_bias = n->inputs()[1];
    // check if bias is Const or in graph's initializers
    if (orig_bias->node()->kind() != kConstant &&
        orig_bias->node()->kind() != kParam) {
      return false;
    }
    // check if conv is only used by Add
    if (orig_conv->uses().size() > 1) {
      return false;
    }
    auto conv_shape = orig_conv->sizes();
    auto bias_shape = orig_bias->sizes();
    auto weight_shape = orig_conv->node()->inputs()[1]->sizes();
    // M = output channel count, rank = rank of the Conv output/weight.
    int64_t M = -1;
    int64_t rank = -1;
    // try to get feature M and rank from conv_shape
    if (conv_shape.size() > 1 && conv_shape[1].is_int) {
      M = conv_shape[1].dim;
      rank = conv_shape.size();
    }
    // try to get feature M and rank from weight_shape
    if (weight_shape.size() > 0 && weight_shape[0].is_int) {
      ONNX_ASSERT(M == -1 || M == weight_shape[0].dim);
      M = weight_shape[0].dim;
      ONNX_ASSERT(rank == -1 ||
                  rank == static_cast(weight_shape.size()));
      rank = weight_shape.size();
    }
    // num_el = total element count of the bias; -1 when any dim is symbolic.
    int64_t num_el = 1;
    for (int i = 0; i < static_cast(bias_shape.size()); ++i) {
      if (bias_shape[i].is_int) {
        num_el *= bias_shape[i].dim;
      } else {
        num_el = -1;
        return false;
      }
    }
    if (M == -1 || num_el == -1) {
      // No enough information, bail out
      return false;
    }
    if (rank < static_cast(bias_shape.size())) {
      return false;
    }
    if (num_el == 1) {
      // Case 2: scalar-like bias — squeeze/unsqueeze it to 1-D, then Tile it
      // to M channels when needed.
      if (orig_bias->node()->kind() != kParam &&
          orig_conv->node()->isBefore(orig_bias->node())) {
        // Keep topological order: the bias producer must precede the Conv.
        orig_bias->node()->moveBefore(orig_conv->node());
      }
      Value *conv_3rd_input = orig_bias;
      if (bias_shape.size() > 1) {
        // Squeeze all leading singleton dims down to a 1-D tensor.
        std::vector axes(bias_shape.size() - 1);
        std::iota(axes.begin(), axes.end(), 0);
        Node *squeeze = makeSqueezeOrUnsqueeze(graph, axes, conv_3rd_input,
                                               orig_conv->node(), kSqueeze);
        conv_3rd_input = squeeze->output();
      } else if (bias_shape.size() == 0) {
        // Scalar: unsqueeze to rank 1, as Conv bias must be 1-D.
        std::vector axes = {0};
        Node *unsqueeze = makeSqueezeOrUnsqueeze(graph, axes, conv_3rd_input,
                                                 orig_conv->node(), kUnsqueeze);
        conv_3rd_input = unsqueeze->output();
      }
      if (M > 1) {
        // Tile the single value across the M output channels.
        Node *constant = graph.create(kConstant, 1);
        Tensor t;
        t.sizes().push_back(static_cast(1));
        t.int64s().push_back(M);
        t.elem_type() = TensorProto_DataType_INT64;
        Symbol sym = Symbol("value");
        constant->t_(sym, t);
        std::vector s = {1};
        constant->output()->setSizes(s);
        constant->output()->setElemType(TensorProto_DataType_INT64);
        constant->insertBefore(orig_conv->node());
        Node *tile = graph.create(kTile, 1);
        tile->addInput(conv_3rd_input);
        tile->addInput(constant->output());
        conv_3rd_input = tile->output();
        tile->insertBefore(orig_conv->node());
      }
      orig_conv->node()->addInput(conv_3rd_input);
    } else if (rank > static_cast(bias_shape.size()) + 1) {
      return false;
    } else if (num_el == M &&
               bias_shape[1 + bias_shape.size() - static_cast(rank)]
                       .dim == M) {
      // Case 1: bias has M elements on what broadcasts to the channel dim —
      // squeeze every other (singleton) dim away, keeping only the channel.
      ONNX_ASSERT(bias_shape.size() > 1);
      if (orig_bias->node()->kind() != kParam &&
          orig_conv->node()->isBefore(orig_bias->node())) {
        orig_bias->node()->moveBefore(orig_conv->node());
      }
      std::vector axes(bias_shape.size());
      std::iota(axes.begin(), axes.end(), static_cast(0));
      axes.erase(axes.begin() +
                 (1 + bias_shape.size() - static_cast(rank)));
      Node *squeeze = makeSqueezeOrUnsqueeze(graph, axes, orig_bias,
                                             orig_conv->node(), kSqueeze);
      orig_conv->node()->addInput(squeeze->output());
    } else {
      return false;
    }
    // Propagate shape/dtype info from the Add output onto the Conv output,
    // since the Conv now replaces the Add.
    if (orig_conv->sizes().size() == 0 && n->output()->sizes().size() > 0) {
      orig_conv->setSizes(n->output()->sizes());
    }
    if (n->output()->elemType() != TensorProto_DataType_UNDEFINED) {
      orig_conv->setElemType(n->output()->elemType());
    }
    const bool replacing_success =
        tryReplacingAllUsesWith(n, orig_conv->node());
    if (!replacing_success) {
      return false;
    }
    destroy_current = NodeDestroyType::DestroyOne;
    return true;
  }
};

}  // namespace optimization
}  // namespace ONNX_NAMESPACE

// ===== onnxoptimizer/passes/fuse_bn_into_conv.h =====

/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

#pragma once

// Before:
//   conv = Conv()
//   bn = BatchNormalization()
//
// After:
//   bn is deleted
//   new inputs/initializers to conv are added to graph
//   any no longer used inputs/initializers are erased from graph
//
// this pass can handle the case satisfy all following conditions:
//   condition 1: Run in testing mode
//   condition 2: Inputs 1 - 4 of bn are all initializer_size
//   condition 3: Output of initial conv has no other uses
//   condition 3: Currently works for only DOUBLE, FLOAT32 tensor types
//
// Formula for transformation
// $$ X_{bn} = \frac{s(X - m)}{\sqrt{\sigma + \epsilon}} + b_{bn}$$
// $$ X_{conv} = X * W + b_{conv} $$
// thus, substituting $X$ with $X_{conv}$ in the BN equation we get:
// $$X_{bn} = X * \frac{sW}{\sqrt{\sigma + \epsilon}} + \frac{s(b_{conv} -
// m)}{\sqrt{\sigma + \epsilon}} + b_{bn}$$ or
// $$ W' = W\frac{s}{\sqrt{\sigma + \epsilon}}$$
// $$ b' = (b_{conv} - m)\frac{s}{\sqrt{\sigma + \epsilon}} + b_{bn}$$
#include "onnx/common/assertions.h"
#include "onnxoptimizer/pass.h"
#include "onnxoptimizer/passes/pass_util.h"

namespace ONNX_NAMESPACE {
namespace optimization {
// TODO: Currently broken for complex values and float16
struct FuseBNIntoConv final : public PredicateBasedPass {
  explicit FuseBNIntoConv()
      : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete,
                           PassOptimizationType::Compute) {}

  std::string getPassName() const override {
    return "fuse_bn_into_conv";
  }

  // Rewrites `conv` in place so that it also performs the BatchNormalization
  // `bn` that consumes its output, by building a small constant-foldable
  // subgraph (Cast/Add/Sqrt/Div/Unsqueeze/Mul/Sub/Add) that computes the
  // fused weight W' and bias b' from the formulas in the file header.
  // Returns false (leaving some already-inserted helper nodes behind for
  // later constant folding / DCE) when the element types are inconsistent
  // or an existing conv bias is not constant.
  bool modify_conv(Node* conv, Node* bn, Graph& graph) {
    const auto& bn_inputs = bn->inputs();
    const auto& conv_inputs = conv->inputs();

    // All five tensors are constants — guaranteed by patternMatchPredicate.
    auto bn_scale = *FetchConstantTensor(bn_inputs[1]);
    auto bn_bais = *FetchConstantTensor(bn_inputs[2]);
    auto bn_mean = *FetchConstantTensor(bn_inputs[3]);
    auto bn_var = *FetchConstantTensor(bn_inputs[4]);
    auto conv_W = *FetchConstantTensor(conv_inputs[1]);
    // The copies are re-registered as new initializers below, so give each a
    // fresh unique name to avoid colliding with existing graph values.
    bn_scale.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
    bn_bais.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
    bn_mean.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
    bn_var.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
    conv_W.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));

    /// scale bais mean var must be the same shape (C)
    ONNX_ASSERT(bn_scale.sizes() == bn_bais.sizes());
    ONNX_ASSERT(bn_scale.sizes() == bn_mean.sizes());
    ONNX_ASSERT(bn_scale.sizes() == bn_var.sizes());
    ONNX_ASSERT(bn_scale.sizes().size() == 1);
    int64_t C = bn_scale.sizes()[0];
    // Conv weight is (M x C/group x k1 x ... x kn); dim 0 must match C here.
    ONNX_ASSERT(conv_W.sizes().size() > 2 && conv_W.sizes()[0] == C);
    // Mixed element types are not handled — bail out rather than insert casts.
    if (bn_scale.elem_type() != bn_bais.elem_type() ||
        bn_scale.elem_type() != bn_mean.elem_type() ||
        bn_scale.elem_type() != bn_var.elem_type() ||
        bn_scale.elem_type() != conv_W.elem_type()) {
      return false;
    }

    Value* conv_bias = nullptr;
    if (conv_inputs.size() == 3) {
      // Conv already has a bias: it must be constant so it can be folded.
      if (!IsConstantTensor(conv_inputs[2])) {
        return false;
      }
      auto bc_t = *FetchConstantTensor(conv_inputs[2]);
      bc_t.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
      ONNX_ASSERT(bc_t.sizes() == bn_scale.sizes());
      conv_bias = graph.addInitializerAndCreateValue(bc_t);
    } else {
      // No existing bias: synthesize an all-zero FLOAT bias of length C.
      // (The Cast node `cast1` below converts it to bn_mean's type.)
      Tensor bc_t;
      bc_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
      bc_t.sizes().push_back(C);
      for (int i = 0; i < C; ++i) {
        bc_t.floats().push_back(float{0});
      }
      conv_bias = graph.addInitializerAndCreateValue(bc_t);
    }

    /// scalar
    Tensor eps_t;
    eps_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
    eps_t.floats().push_back(GetValueFromAttrWithDefault(bn, kepsilon, 1e-5f));
    Value* eps = graph.addInitializerAndCreateValue(eps_t);

    // scale = bn_scale / sqrt(bn_var + epsilon), built as graph nodes so a
    // later constant-folding pass can evaluate it.
    Node* cast = graph.create(kCast, 1);
    cast->addInput(eps);
    cast->i_(kto, bn_var.elem_type());
    cast->insertBefore(conv);

    Node* var_add = graph.create(kAdd, 1);
    var_add->insertAfter(cast);
    var_add->addInput(graph.addInitializerAndCreateValue(bn_var));
    var_add->addInput(cast->output());

    Node* sqrt = graph.create(kSqrt, 1);
    sqrt->insertAfter(var_add);
    sqrt->addInput(var_add->output());

    Node* scale = graph.create(kDiv, 1);
    scale->insertAfter(sqrt);
    scale->addInput(graph.addInitializerAndCreateValue(bn_scale));
    scale->addInput(sqrt->output());

    // Unsqueeze scale from (C) to (C, 1, ..., 1) so it broadcasts against
    // the weight tensor along dim 0.
    Node* unsqueeze = graph.create(kUnsqueeze, 1);
    unsqueeze->insertAfter(scale);
    unsqueeze->addInput(scale->output());
    std::vector<int64_t> insert_dims;
    for (int i = 1; i < conv_W.sizes().size(); ++i) {
      insert_dims.push_back(i);
    }
    // Since opset 13 Unsqueeze takes `axes` as an input tensor, before that
    // as an attribute.
    if (getOpsetVersion(graph) >= 13) {
      Tensor shape_s_t;
      shape_s_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_INT64;
      shape_s_t.sizes().push_back(insert_dims.size());
      shape_s_t.int64s() = insert_dims;
      unsqueeze->addInput(graph.addInitializerAndCreateValue(shape_s_t));
    } else {
      unsqueeze->is_(kaxes, std::move(insert_dims));
    }

    // W' = W * scale (broadcast over dim 0)
    Node* mul_w = graph.create(kMul, 1);
    mul_w->insertAfter(unsqueeze);
    mul_w->addInput(graph.addInitializerAndCreateValue(conv_W));
    mul_w->addInput(unsqueeze->output());

    Node* cast1 = graph.create(kCast, 1);
    cast1->insertAfter(mul_w);
    cast1->addInput(conv_bias);
    cast1->i_(kto, bn_mean.elem_type());

    // b' = (b_conv - mean) * scale + bn_bias
    Node* sub = graph.create(kSub, 1);
    sub->insertAfter(cast1);
    sub->addInput(cast1->output());
    sub->addInput(graph.addInitializerAndCreateValue(bn_mean));

    Node* mul = graph.create(kMul, 1);
    mul->insertAfter(sub);
    mul->addInput(sub->output());
    mul->addInput(scale->output());

    Node* bias_add = graph.create(kAdd, 1);
    bias_add->insertAfter(mul);
    bias_add->addInput(mul->output());
    bias_add->addInput(graph.addInitializerAndCreateValue(bn_bais));

    // Swap the conv weight for W'; drop the old initializer if now unused.
    Value* old_w_value = conv_inputs[1];
    conv->replaceInput(1, mul_w->output());
    if (old_w_value->uses().size() == 0) {
      graph.eraseInitializerAndInput(old_w_value);
    }

    // Swap (or add) the conv bias for b'.
    if (conv_inputs.size() == 3) {
      Value* old_b_value = conv_inputs[2];
      conv->replaceInput(2, bias_add->output());
      if (old_b_value->uses().size() == 0) {
        graph.eraseInitializerAndInput(old_b_value);
      }
    } else {
      conv->addInput(bias_add->output());
    }
    return true;
  }

  // Matches: BatchNormalization (inference mode, single output, all four
  // stats constant) directly fed by a Conv whose output has no other use
  // and whose weight is constant.
  bool patternMatchPredicate(Node* n) override {
    return CheckKind(n, kBatchNormalization, 0, kConv) &&
           GetValueFromAttrWithDefault(n, "training_mode", (int64_t)0) == 0 &&
           n->input(0)->uses().size() == 1 && n->outputs().size() == 1 &&
           IsConstantTensor(n, 1) && IsConstantTensor(n, 2) &&
           IsConstantTensor(n, 3) && IsConstantTensor(n, 4) &&
           IsConstantTensor(PrevNode(n, 0), 1);
  }
  bool runTransform(Node* n, Graph& graph,
                    NodeDestroyType& destroy_current) override {
    Node* bn = n;
    Node* conv = PrevNode(n, 0);
    // bn's input 0 is the conv output; all bn uses are redirected to it below.
    auto origInput = bn->inputs()[0];
    if (!modify_conv(conv, bn, graph)) {
      destroy_current = NodeDestroyType::DestroyZero;
      return false;
    }
    // clean
    // Detach bn's stat inputs (reverse order keeps indices stable) and erase
    // initializers that bn was the sole user of.
    for (int i = 4; i >= 1; --i) {
      if (bn->inputs()[i]->uses().size() == 1) {
        auto input = bn->inputs()[i];
        bn->removeInput(i);
        graph.eraseInitializerAndInput(input);
      }
    }
    const bool replacing_success =
        tryReplacingAllUsesWith(bn->output(), origInput);
    if (!replacing_success) {
      // NOTE(review): modify_conv has already mutated the graph at this
      // point; a failure here leaves the conv rewritten but bn in place.
      return false;
    }
    destroy_current = NodeDestroyType::DestroyOne;
    return true;
  }
};

}  // namespace optimization
}  // namespace ONNX_NAMESPACE
--------------------------------------------------------------------------------
/onnxoptimizer/passes/fuse_concat_into_reshape.h:
--------------------------------------------------------------------------------
/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

#pragma once

#include "onnx/defs/tensor_util.h"
#include "onnxoptimizer/pass.h"
#include "onnxoptimizer/passes/pass_util.h"

namespace ONNX_NAMESPACE {
namespace optimization {

// before
//   Z = Reshape(X, Concat(...)) or Z = Reshape(X, Cast(Concat(...), to=INT64 ))
// after
//   Z = Reshape(X, Y) , Y is a constant tensor

// this pass can handle the case when:
// 1.
// the number of unknown dims in the result of Concat should be not more
// than 1

struct FuseConcatIntoReshape final : public PredicateBasedPass {
  explicit FuseConcatIntoReshape()
      : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete,
                           PassOptimizationType::Compute) {}

  std::string getPassName() const override {
    return "fuse_concat_into_reshape";
  }

  // Pattern 1: Reshape whose shape input is an axis-0 Concat.
  inline bool matchConcatReshape(Node *node) {
    return CheckKind(node, kReshape, 1, kConcat) &&
           node->input(1)->node()->i(kaxis) == 0;
  }

  // Pattern 2: Reshape whose shape input is Cast(to=INT64) of an axis-0
  // Concat.
  inline bool matchConcatCastReshape(Node *node) {
    return CheckKind(node, kReshape, 1, kCast, 0, kConcat) &&
           node->inputs()[1]->node()->i(kto) ==
               ONNX_NAMESPACE::TensorProto_DataType_INT64 &&
           PrevNode(node, 1, 0)->i(kaxis) == 0;
  }

  bool patternMatchPredicate(Node *node) override {
    return matchConcatReshape(node) || matchConcatCastReshape(node);
  }

  bool runTransform(Node *node, Graph &graph,
                    NodeDestroyType &destroy_current) override {
    const bool has_cast = matchConcatCastReshape(node);

    Node *concat = nullptr;
    if (has_cast) {
      concat = PrevNode(node, 1, 0);
    } else {
      concat = PrevNode(node, 1);
    }

    // Fold every Concat operand into one flat INT64 shape vector.
    std::vector<int64_t> shapes;
    for (const auto *v : concat->inputs()) {
      const Tensor *tensor = FetchConstantTensor(v);

      if (tensor == nullptr) {
        // If the value of v is unknown, and v has only one element, we can
        // represent it with -1 in Reshape op.
        // TODO:
        // support the case that v is a shape and only one of the dims is
        // unknown. Example: Concat [?, 2, 3] and [4] into [-1, 2, 3, 4]
        if (v->sizes().size() != 1 || !v->sizes()[0].is_int ||
            v->sizes()[0].dim != 1) {
          return false;
        }
        shapes.push_back(-1);
        continue;
      }
      // only support INT64 when the pattern is concat->reshape
      if (!has_cast &&
          tensor->elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
        return false;
      }
#define DO_CASE(pb_type, cpp_type)                                           \
  case ONNX_NAMESPACE::TensorProto_DataType_##pb_type: {                     \
    const auto data = ParseTensorData<cpp_type>(tensor);                     \
    std::transform(data.cbegin(), data.cend(), std::back_inserter(shapes),   \
                   [](const auto &v) { return static_cast<int64_t>(v); });   \
    break;                                                                   \
  }

      /// support cast
      switch (tensor->elem_type()) {
        DO_CASE(FLOAT, float)
        DO_CASE(INT32, int32_t)
        DO_CASE(INT64, int64_t)
        DO_CASE(DOUBLE, double)
        DO_CASE(UINT8, uint8_t)
        DO_CASE(INT8, int8_t)
        DO_CASE(UINT16, uint16_t)
        DO_CASE(INT16, int16_t)
        DO_CASE(UINT32, uint32_t)
        DO_CASE(INT64, int64_t)
        default:
          return false;
      }
    }
    // Reshape allows at most one -1 ("infer this dim") entry.
    int unknown_dim_count = 0;
    for (auto dim : shapes) {
      unknown_dim_count += int64_t(dim == -1);
    }
    if (unknown_dim_count > 1) {
      return false;
    }

    Tensor t;
    t.sizes().push_back(shapes.size());
    t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_INT64;
    t.int64s().swap(shapes);
    Value *value = graph.addInitializerAndCreateValue(t);

    // Rewire Reshape to the constant shape; the dangling Concat/Cast chain
    // is left for DCE.
    node->replaceInput(1, value);
    destroy_current = NodeDestroyType::DestroyZero;
    return true;
  }
};

}  // namespace optimization
}  // namespace ONNX_NAMESPACE
--------------------------------------------------------------------------------
/onnxoptimizer/passes/fuse_consecutive_concats.h:
--------------------------------------------------------------------------------
/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

#pragma once

#include "onnxoptimizer/pass.h"

namespace ONNX_NAMESPACE {
namespace optimization {

// Flattens nested Concat nodes: Concat(a, Concat(b, c), d) with matching
// axes becomes Concat(a, b, c, d).
struct FuseConsecutiveConcats final : public PredicateBasedPass {
  explicit FuseConsecutiveConcats()
      : PredicateBasedPass(PassType::Fuse, PassEfficiency::Partial,
                           PassOptimizationType::Compute) {}

  std::string getPassName() const override {
    return "fuse_consecutive_concats";
  }

  // Inserts `value` as input i of `node`, shifting inputs i..end one slot to
  // the right (the last input is duplicated via addInput, then each one is
  // moved up by replaceInput).
  // NOTE(review): `j` is unsigned, so the `j >= i` loop would underflow for
  // i == 0; the only caller passes i + 1 + j, so i >= 1 always holds here.
  void insertInput(Node* node, size_t i, Value* value) {
    const auto input_size = node->inputs().size();
    if (i == input_size) {
      node->addInput(value);
    } else {
      for (size_t j = input_size - 1; j >= i; j--) {
        Value* cur_input = node->input(j);
        if (j == input_size - 1) {
          node->addInput(cur_input);
        } else {
          node->replaceInput(j + 1, cur_input);
        }
      }
      node->replaceInput(i, value);
    }
  }

  bool patternMatchPredicate(Node* node) override {
    // we don't check if our concat node has inputs which are also concat nodes
    // because this requires a for loop through the inputs. If it turns out
    // there is then we still have to do a for loop in the runTransform portion
    // of the code. In order not to waste a loop we don't check the real pattern
    // match condition.
    return node->kind() == kConcat && node->hasAttribute(kaxis);
  }
  bool runTransform(Node* concat_node, Graph&,
                    NodeDestroyType& destroy_current) override {
    destroy_current = NodeDestroyType::DestroyZero;
    bool transform_ran = false;
    for (size_t i = 0; i < concat_node->inputs().size(); i++) {
      Value* cur_input_value = concat_node->inputs()[i];
      Node* cur_input_node = cur_input_value->node();
      // Only fold an inner Concat whose sole consumer is this Concat and
      // whose axis matches.
      if (cur_input_node->kind() == kConcat &&
          cur_input_value->uses().size() == 1 &&
          cur_input_node->hasAttribute(kaxis) &&
          cur_input_node->i(kaxis) == concat_node->i(kaxis)) {
        transform_ran = true;
        // Inserts n inputs of cur_input_node at index i+1~i+1+(n-1),
        // and remove cur_input_node at index i.
        // As a result, cur_input_node is replaced by its inputs inplace,
        // instead of always appending its inputs at the end.
        for (size_t j = 0; j < cur_input_node->inputs().size(); j++) {
          Value* value = cur_input_node->input(j);
          insertInput(concat_node, i + 1 + j, value);
        }
        concat_node->removeInput(i);
        cur_input_node->destroy();
      }
    }
    return transform_ran;
  }
};

}  // namespace optimization
}  // namespace ONNX_NAMESPACE
--------------------------------------------------------------------------------
/onnxoptimizer/passes/fuse_consecutive_log_softmax.h:
--------------------------------------------------------------------------------
/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#pragma once

#include "onnxoptimizer/pass.h"
#include "onnxoptimizer/passes/pass_util.h"

namespace ONNX_NAMESPACE {
namespace optimization {

// Replaces Log(Softmax(X)) with a single LogSoftmax node when the Softmax
// output feeds only the Log.
struct FuseConsecutiveLogSoftmax final : public PredicateBasedPass {
  explicit FuseConsecutiveLogSoftmax()
      : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete,
                           PassOptimizationType::Compute) {}

  std::string getPassName() const override {
    return "fuse_consecutive_log_softmax";
  }

  bool patternMatchPredicate(Node* node) override {
    return CheckKind(node, kLog, 0, kSoftmax) &&
           node->input()->uses().size() == 1;
  }
  bool runTransform(Node* log_node, Graph& graph,
                    NodeDestroyType& destroy_current) override {
    Value* log_node_output = log_node->output();
    Node* softmax_node = PrevNode(log_node, 0);
    Node* log_softmax_node = graph.create(kLogSoftmax, 1);

    // log_softmax_node construction
    log_softmax_node->i_(kaxis, softmax_node->i(kaxis));
    log_softmax_node->addInput(softmax_node->input());
    log_softmax_node->insertBefore(softmax_node);
    // The fused output carries the shape/type of the original Log output.
    log_softmax_node->output()->setSizes(log_node_output->sizes());
    log_softmax_node->output()->setElemType(log_node_output->elemType());

    const bool replacing_success =
        tryReplacingAllUsesWith(log_node, log_softmax_node);
    if (!replacing_success) {
      return false;
    }
    // Detach Log from Softmax so the Softmax becomes dead and DCE removes it.
    log_node->removeAllInputs();
    destroy_current = NodeDestroyType::DestroyOne;
    return true;
  }
};

}  // namespace optimization
}  // namespace ONNX_NAMESPACE
--------------------------------------------------------------------------------
/onnxoptimizer/passes/fuse_consecutive_reduce_unsqueeze.h:
--------------------------------------------------------------------------------
/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

#pragma once

#include "onnx/defs/tensor_proto_util.h"
#include "onnxoptimizer/pass.h"
#include "onnxoptimizer/passes/logging.h"
#include "onnxoptimizer/passes/pass_util.h"

namespace ONNX_NAMESPACE {
namespace optimization {

// The reduction ops whose keepdims=0 + Unsqueeze(same axes) combination is
// equivalent to the reduction alone with keepdims=1.
const std::unordered_set<NodeKind> reduction_operators{
    kReduceL1,   kReduceL2,  kReduceLogSum, kReduceLogSumExp, kReduceMax,
    kReduceMean, kReduceMin, kReduceProd,   kReduceSum,       kReduceSumSquare};

// Replaces Unsqueeze(Reduce*(X, axes, keepdims=0), axes) with
// Reduce*(X, axes, keepdims=1), removing the Unsqueeze.
struct FuseConsecutiveReduceUnsqueeze final : public PredicateBasedPass {
  explicit FuseConsecutiveReduceUnsqueeze()
      : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete,
                           PassOptimizationType::Compute) {}

  std::string getPassName() const override {
    return "fuse_consecutive_reduce_unsqueeze";
  }

  bool patternMatchPredicate(Node *node) override {
    // check that the current node is of type Unsqueeze and has defined axes
    bool cur_node_check = node->kind() == kUnsqueeze;
    if (cur_node_check) {
      Node *prev_node = node->inputs()[0]->node();
      // check that the previous node a reduction operator and has defined
      // axes/keepdims
      bool reduction_node_check = reduction_operators.find(prev_node->kind()) !=
                                      reduction_operators.end() &&
                                  prev_node->hasAttribute(kkeepdims);
      if (reduction_node_check) {
        // insure that keepdims is set to false currently
        return prev_node->i(kkeepdims) == 0;
      }
    }
    return false;
  }
  bool runTransform(Node *node, Graph &graph,
                    NodeDestroyType &destroy_current) override {
    Node *reduction_op = node->inputs()[0]->node();
    // This pass will modify reduction_op, so it must have only one user.
    if (reduction_op->output()->uses().size() != 1) {
      return false;
    }
    // The fusion is only valid when the Unsqueeze re-inserts exactly the
    // axes the reduction removed (axes may come from attribute or input 1).
    std::vector<int64_t> axes;
    std::vector<int64_t> prev_axes;
    if (!GetValueFromAttrOrInput(node, kaxes, 1, axes) ||
        !GetValueFromAttrOrInput(reduction_op, kaxes, 1, prev_axes) ||
        axes != prev_axes) {
      return false;
    }

    const bool replacing_success =
        tryReplacingAllUsesWith(node->output(), node->inputs()[0]);
    if (replacing_success) {
      // set keepdims flag to be true
      reduction_op->i_(kkeepdims, 1);
      // remove unnecessary unsqueeze
      reduction_op->output()->setSizes(node->output()->sizes());
      reduction_op->output()->setElemType(node->output()->elemType());
      destroy_current = NodeDestroyType::DestroyOne;
      return true;
    } else {
      return false;
    }
  }
};

}  // namespace optimization
}  // namespace ONNX_NAMESPACE
--------------------------------------------------------------------------------
/onnxoptimizer/passes/fuse_consecutive_slices.h:
--------------------------------------------------------------------------------
/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
/// this pass can simplify focus network

#pragma once

#include "onnxoptimizer/pass.h"
#include "onnxoptimizer/passes/pass_util.h"

namespace ONNX_NAMESPACE {
namespace optimization {

// Merges two back-to-back Slice nodes operating on disjoint axes into a
// single Slice whose starts/ends/axes/steps are the concatenation of both.
struct FuseConsecutiveSlices final : public PredicateBasedPass {
  explicit FuseConsecutiveSlices()
      : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete,
                           PassOptimizationType::Memory) {}

  std::string getPassName() const override {
    return "fuse_consecutive_slices";
  }

  // Matches Slice(Slice(X)) where both slices carry all five inputs
  // (data/starts/ends/axes/steps), the axes are known constants, and —
  // after normalizing negative axes — the two axis sets do not overlap.
  bool patternMatchPredicate(Node *node) override {
    std::vector<int64_t> slice1_axes, slice2_axes;
    if (CheckKind(node, kSlice, 0, kSlice) && node->inputs().size() == 5 &&
        GetInputsOfPreNode(node, 0).size() == 5 &&
        GetValueFromInput(node, 3, slice1_axes) &&
        GetValueFromInput(PrevNode(node, 0), 3, slice2_axes)) {
      if (!node->input(0)->has_sizes()) {
        return false;
      }
      // Slice preserves rank, so both nodes share the same rank here.
      for (auto& axis : slice1_axes) {
        axis = AddYIfNegative(axis, node->inputs()[0]->sizes().size());
      }
      for (auto& axis : slice2_axes) {
        axis = AddYIfNegative(axis, node->inputs()[0]->sizes().size());
      }
      bool has_intersection = HasIntersection(slice1_axes, slice2_axes);
      return !has_intersection;
    }
    return false;
  }
  bool runTransform(Node *n, Graph &graph,
                    NodeDestroyType &destroy_current) override {
    /*
      X
      |
    Slice2
      |
    Slice1
      |
      Y
    */
    Node *slice1 = n;
    Node *slice2 = PrevNode(n, 0);

    // Concat the per-axis parameter tensors (starts/ends/axes/steps) of the
    // two slices pairwise.
    std::vector<Node *> new_nodes;
    for (int i = 1; i < 5; ++i) {
      Node *node = graph.create(kConcat, 1);
      node->addInput(slice2->input(i));
      node->addInput(slice1->input(i));
      node->i_(kaxis, 0);
      new_nodes.push_back(node);
    }

    Node *new_slice = graph.create(kSlice, 1);
    new_slice->insertBefore(slice1);
    new_slice->addInput(slice2->input(0));
    for (auto *node : new_nodes) {
      new_slice->addInput(node->output());
      node->insertBefore(new_slice);
    }

    const bool replacing_success =
        tryReplacingAllUsesWith(slice1->output(), new_slice->output());
    if (!replacing_success) {
      return false;
    }
    // slice2 becomes dead and is left for DCE.
    destroy_current = NodeDestroyType::DestroyOne;
    return true;
  }
};

}  // namespace optimization
}  // namespace ONNX_NAMESPACE
--------------------------------------------------------------------------------
/onnxoptimizer/passes/fuse_consecutive_squeezes.h:
--------------------------------------------------------------------------------
/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

#pragma once

// Before:
//   X is a tensor with shape=[1, 1, 2, 3, 1, 5, 1]
//   Y = Squeeze(X, axes=[1, 4]) -> shape=[1, 2, 3, 5, 1]
//   Z = Squeeze(Y, axes=[0, 4]) -> shape=[2, 3, 5]
// After:
//   Z = Squeeze(X, axes=[0, 1, 4, 6])
#include "onnxoptimizer/pass.h"
#include "onnxoptimizer/passes/logging.h"
#include "onnxoptimizer/passes/pass_util.h"

namespace ONNX_NAMESPACE {
namespace optimization {

struct FuseConsecutiveSqueezes final : public PredicateBasedPass {
  explicit FuseConsecutiveSqueezes()
      : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete,
                           PassOptimizationType::Compute) {}

  std::string getPassName() const override {
    return "fuse_consecutive_squeezes";
  }
  // Squeeze carries `axes` as an attribute up to opset 12, as an input
  // tensor from opset 13 on (0 = unknown opset, treated as input form).
  static bool IsAxesAnAttr(const Graph &graph) {
    const int opset_version = getOpsetVersion(graph);
    return opset_version <= 12 && opset_version != 0;
  }

  // modify the vector `composed_axes` such that squeeze by it is equivalent
  // to squeeze by `axes_1` and then by `axes_2`
  static bool compose_squeezes(const Node *input_n, const Node *n,
                               const Graph &graph,
                               std::vector<int64_t> &composed_axes) {
    std::vector<int64_t> axes_1;
    std::vector<int64_t> axes_2;
    if (!GetValueFromAttrOrInput(input_n, kaxes, 1, axes_1) ||
        !GetValueFromAttrOrInput(n, kaxes, 1, axes_2)) {
      return false;
    }

    std::vector<int64_t> &ret = composed_axes;
    ret.clear();
    ret.reserve(axes_1.size() + axes_2.size());
    std::vector<int64_t> sorted_axes_1(axes_1.begin(), axes_1.end());
    std::sort(sorted_axes_1.begin(), sorted_axes_1.end());
    std::copy(sorted_axes_1.begin(), sorted_axes_1.end(),
              std::back_inserter(ret));

    // Map each axis of the second squeeze back into the coordinate space of
    // the original (un-squeezed) tensor by counting how many first-squeeze
    // axes precede it.
    for (int64_t i : axes_2) {
      for (auto iter = sorted_axes_1.begin(); iter != sorted_axes_1.end();
           ++iter) {
        // if current axis 1 - prev_num is bigger than axis 2
        // put axis 2 + prev_num as new axis
        int64_t prev_num = std::distance(sorted_axes_1.begin(), iter);
        if (*iter - prev_num > i) {
          ret.push_back(i + prev_num);
          break;
        }
        // if no current axis 1 - prev_num is bigger than axis 2
        // put axis 2 + prev_num + 1 as new axis
        if (std::next(iter) == sorted_axes_1.end()) {
          ret.push_back(i + prev_num + 1);
        }
      }
    }
    std::sort(ret.begin(), ret.end());
    return true;
  }

  bool patternMatchPredicate(Node *node) override {
    return node->kind() == kSqueeze &&
           node->inputs()[0]->node()->kind() == kSqueeze;
  }
  bool runTransform(Node *n, Graph &graph,
                    NodeDestroyType &destroy_current) override {
    auto orig_input = n->inputs()[0];
    std::vector<int64_t> rs;
    bool success = compose_squeezes(orig_input->node(), n, graph, rs);
    if (!success) {
      return false;
    }
    // Bypass the first Squeeze; destroy it once nothing else uses it.
    n->replaceInput(0, orig_input->node()->inputs()[0]);
    if (orig_input->uses().size() == 0) {
      orig_input->node()->destroy();
    }
    if (IsAxesAnAttr(graph)) {
      n->is_(kaxes, std::move(rs));
    } else {
      // Opset >= 13: write the fused axes as a new INT64 initializer and
      // clean up the old axes value (Constant node or initializer).
      Tensor t;
      t.sizes().push_back(rs.size());
      t.int64s() = rs;
      t.elem_type() = TensorProto_DataType_INT64;
      auto axes_v = n->inputs()[1];
      Value *tv = graph.addInitializerAndCreateValue(t);
      n->replaceInput(1, tv);
      if (axes_v->uses().size() == 0) {
        if (axes_v->node()->kind() == kConstant) {
          axes_v->node()->destroy();
        } else {
          graph.eraseInitializerAndInput(axes_v);
        }
      }
    }
    destroy_current = NodeDestroyType::DestroyZero;
    return true;
  }
};

}  // namespace optimization
}  // namespace ONNX_NAMESPACE
--------------------------------------------------------------------------------
/onnxoptimizer/passes/fuse_consecutive_transposes.h:
--------------------------------------------------------------------------------
/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

#pragma once

#include "onnxoptimizer/pass.h"

namespace ONNX_NAMESPACE {
namespace optimization {

struct FuseConsecutiveTransposes final : public PredicateBasedPass {
  explicit FuseConsecutiveTransposes()
      : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete,
                           PassOptimizationType::Compute) {}

  std::string getPassName() const override {
    return "fuse_consecutive_transposes";
  }

  // returns a vector `ret` such that transposing by `ret` is equivalent
  // to transposing by `t1` and then by `t2`
  std::vector<int64_t> compose_transposes(const std::vector<int64_t>& t1,
                                          const std::vector<int64_t>& t2) {
    ONNX_ASSERT(t1.size() == t2.size());
    std::vector<int64_t> ret;
    ret.reserve(t1.size());
    for (size_t i = 0; i < t1.size(); i++) {
      ONNX_ASSERT(t2[i] < static_cast<int64_t>(t1.size()));
      ONNX_ASSERT(t1[static_cast<size_t>(t2[i])] <
                  static_cast<int64_t>(t1.size()));
      ret.push_back(t1[static_cast<size_t>(t2[i])]);
    }
    return ret;
  }

  bool patternMatchPredicate(Node* node) override {
    return node->kind() == kTranspose &&
           node->input()->node()->kind() == kTranspose;
  }

  bool runTransform(Node* n, Graph&,
                    NodeDestroyType& destroy_current) override {
    auto origInput = n->input();
    if (!n->hasAttribute(kperm) && !origInput->node()->hasAttribute(kperm)) {
      // One special case (two consecutive transposes with no perm,
      // since we do not have the shape information here, we have
      // to eliminate two transpose together.
      // (Two perm-less transposes reverse the dims twice, i.e. identity.)
      if (n->output()->has_sizes()) {
        origInput->node()->input()->setSizes(n->output()->sizes());
      }
      const bool replacing_success =
          tryReplacingAllUsesWith(n, origInput->node()->input()->node());
      if (!replacing_success) {
        return false;
      }
      // only destroy the first transpose and DCE will take care of the another
      destroy_current = NodeDestroyType::DestroyOne;
      return true;
    }
    // Mixed case (only one of the two has an explicit perm) is not handled.
    if (!n->hasAttribute(kperm) || !origInput->node()->hasAttribute(kperm)) {
      destroy_current = NodeDestroyType::DestroyZero;
      return false;
    }
    // Fold both permutations into this node and bypass the inner transpose.
    n->is_(kperm,
           compose_transposes(origInput->node()->is(kperm), n->is(kperm)));
    n->replaceInput(0, origInput->node()->input());
    if (origInput->uses().size() == 0) {
      origInput->node()->destroy();
    }
    destroy_current = NodeDestroyType::DestroyZero;
    return false;
  }
};

}  // namespace optimization
}  // namespace ONNX_NAMESPACE
--------------------------------------------------------------------------------
/onnxoptimizer/passes/fuse_consecutive_unsqueezes.h:
--------------------------------------------------------------------------------
/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#pragma once

// NOTE(review): the original #include's header name was lost in extraction;
// <algorithm> matches the std::sort/std::transform usage below — confirm
// against the upstream file.
#include <algorithm>

#include "onnx/defs/tensor_util.h"
#include "onnxoptimizer/pass.h"
#include "onnxoptimizer/passes/pass_util.h"

namespace ONNX_NAMESPACE {
namespace optimization {

// Merges Unsqueeze(Unsqueeze(X, axes_prev), axes) into a single Unsqueeze
// whose axes are both sets expressed in the final output's coordinates.
struct FuseConsecutiveUnsqueezes final : public PredicateBasedPass {
  explicit FuseConsecutiveUnsqueezes()
      : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete,
                           PassOptimizationType::Compute) {}
  std::string getPassName() const override {
    return "fuse_consecutive_unsqueezes";
  }

  bool patternMatchPredicate(Node* node) override {
    return CheckKind(node, kUnsqueeze, 0, kUnsqueeze) &&
           GetInputsOfPreNode(node, 0)[0]->has_sizes();
  }

  bool runTransform(Node* n, Graph& graph,
                    NodeDestroyType& destroy_current) override {
    destroy_current = NodeDestroyType::DestroyZero;
    Node* prev = PrevNode(n, 0);
    bool axes_is_attr = n->hasAttribute(kaxes);
    std::vector<int64_t> axes_of_prev, axes;
    if (!GetValueFromAttrOrInput(n, kaxes, 1, axes) ||
        !GetValueFromAttrOrInput(prev, kaxes, 1, axes_of_prev)) {
      return false;
    }
    const auto dims = prev->input(0)->sizes();
    // Normalize negative axes against each Unsqueeze's own output rank
    // (input rank + number of inserted axes).
    for (auto& axis : axes_of_prev) {
      axis = AddYIfNegative(
          axis, static_cast<int64_t>(dims.size() + axes_of_prev.size()));
    }
    VLOG(1) << "axes of prev node: " << axes_of_prev;
    for (auto& axis : axes) {
      axis = AddYIfNegative(
          axis, static_cast<int64_t>(dims.size() + axes_of_prev.size() +
                                     axes.size()));
    }
    VLOG(1) << "axes : " << axes;
    std::sort(axes_of_prev.begin(), axes_of_prev.end());
    std::sort(axes.begin(), axes.end());

    std::vector<int64_t> fused_axes;
    // Shift each prev axis right once for every second-unsqueeze axis that
    // lands at or before it, so both sets are expressed in the final
    // output's coordinate space.
    // NOTE(review): the loop variable `n` shadows the parameter `n`.
    for (auto& n : axes_of_prev) {
      for (const auto& m : axes) {
        if (m <= n) {
          n++;
        }
      }
    }
    fused_axes = axes_of_prev;
    std::transform(axes.cbegin(), axes.cend(), std::back_inserter(fused_axes),
                   [](const auto& d) { return d; });
    std::sort(fused_axes.begin(), fused_axes.end());
    VLOG(1) << "fused axes: " << fused_axes;
    // Bypass the inner Unsqueeze; it becomes dead and is left for DCE.
    n->replaceInput(0, prev->input(0));

    if (axes_is_attr) {
      n->is_(kaxes, std::move(fused_axes));
    } else {
      // Opset >= 13 form: axes travel as input 1, so register an initializer.
      Tensor axes_t;
      axes_t.sizes().push_back(fused_axes.size());
      axes_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_INT64;
      axes_t.int64s().swap(fused_axes);
      n->replaceInput(1, graph.addInitializerAndCreateValue(axes_t));
    }
    return true;
  }
};

}  // namespace optimization
}  // namespace ONNX_NAMESPACE
--------------------------------------------------------------------------------
/onnxoptimizer/passes/fuse_matmul_add_bias_into_gemm.h:
--------------------------------------------------------------------------------
/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

#pragma once

// Before:
//   Z = MatMul(X, Y)
//   A = Z + Bias
// After:
//   A = Gemm(X, Y, Bias)
//
// the pass can handle the case when:
//   case 1: Bias is 1D tensor and Bias.dim[0] == Z.dim[1]
//   case 2: Bias is 2D tensor and Bias.dim[0] == Z.dim[0] or 1
//           and Bias.dim[1] = Z.dim[1]

// NOTE(review): the original #include's header name was lost in extraction;
// <numeric> is a plausible reconstruction — confirm against the upstream
// file.
#include <numeric>

#include "onnx/common/assertions.h"
#include "onnxoptimizer/pass.h"
#include "onnxoptimizer/passes/pass_util.h"

namespace ONNX_NAMESPACE {
namespace optimization {

struct FuseMatMulAddBiasIntoGemm final : public PredicateBasedPass {
  explicit FuseMatMulAddBiasIntoGemm()
      : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete,
                           PassOptimizationType::Compute) {}
  std::string getPassName() const override {
    return "fuse_matmul_add_bias_into_gemm";
  }
  bool patternMatchPredicate(Node* node) override {
    return CheckKind(node, kAdd, 0, kMatMul);
  }
  bool runTransform(Node* n, Graph& graph,
                    NodeDestroyType& destroy_current) override {
    // due to current broadcasting's constraint, MatMul has to be the first
    // operand
    destroy_current = NodeDestroyType::DestroyZero;
    auto orig_matmul = n->inputs()[0];
    auto orig_bias = n->inputs()[1];

    // check if MatMul is only used by Add
    if (orig_matmul->uses().size() > 1) {
      return false;
    }
    // Gemm is strictly 2-D: both MatMul operands must be rank-2 with the
    // relevant dims statically known.
    auto x_shape = orig_matmul->node()->inputs()[0]->sizes();
    auto y_shape = orig_matmul->node()->inputs()[1]->sizes();
    int64_t z_N = -1;
    int64_t z_M = -1;
    // try to get feature N from x_shape
    if (static_cast<int64_t>(x_shape.size()) == 2 && x_shape[0].is_int) {
      z_N = x_shape[0].dim;
    } else {
      return false;
    }
    // try to get feature M from y_shape
    if (static_cast<int64_t>(y_shape.size()) == 2 && y_shape[1].is_int) {
      z_M = y_shape[1].dim;
    } else {
      return false;
    }
    // check if bias_shape is compatible
    auto bias_shape = orig_bias->sizes();
    auto bias_dim = static_cast<int64_t>(bias_shape.size());
    int64_t bias_N = -1;
    int64_t bias_M = -1;
    if (bias_dim == 1 && bias_shape[0].is_int) {
      bias_N = 1;
      bias_M = bias_shape[0].dim;
    } else if (bias_dim == 2 && bias_shape[0].is_int && bias_shape[1].is_int) {
      bias_N = bias_shape[0].dim;
      bias_M = bias_shape[1].dim;
    } else {
      return false;
    }
    // Gemm broadcasts C over rows only, so bias must be (N, M) or (1, M).
    if ((bias_N != z_N && bias_N != 1) || bias_M != z_M) {
      return false;
    }
    // proceed to fuse MatMul and Add into Gemm
    Node* gemm =
        graph.create(kGemm, orig_matmul->node()->inputs(), n->outputs().size());
    gemm->addInput(n->inputs()[1]);
    for (int i = 0; i < static_cast<int>(gemm->outputs().size()); ++i) {
      gemm->outputs()[i]->copyMetadata(n->outputs()[i]);
    }
    // Plain Gemm: alpha = beta = 1, no transposition.
    gemm->f_(kalpha, 1.0);
    gemm->f_(kbeta, 1.0);
    gemm->i_(ktransA, 0);
    gemm->i_(ktransB, 0);
    gemm->insertBefore(orig_matmul->node());
    const bool replacing_success = tryReplacingAllUsesWith(n, gemm);
    if (!replacing_success) {
      return false;
    }
    // only destroy MatMul here and DCE will take care of the Add
    destroy_current = NodeDestroyType::DestroyOne;
    return true;
  }
};

}  // namespace optimization
}  // namespace ONNX_NAMESPACE
--------------------------------------------------------------------------------
/onnxoptimizer/passes/fuse_pad_into_conv.h:
--------------------------------------------------------------------------------
/*
 * SPDX-License-Identifier: Apache-2.0
 */

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

#pragma once

// Before:
//   P = Pad(X) - opset 10 and below (or) Pad(X, Pads, [Constant_value]) - opset
//   11 and above Z = Conv(P, Y)
// After:
//   Z = Conv(X, Y) with "pads" attribute set
//
// the pass handles the case when Pad is zero-padding the input
// (i.e.
mode=constant and Constant_value=0) 18 | 19 | #include 20 | 21 | #include "onnx/defs/tensor_util.h" 22 | #include "onnxoptimizer/pass.h" 23 | #include "onnxoptimizer/passes/pass_util.h" 24 | 25 | namespace ONNX_NAMESPACE { 26 | namespace optimization { 27 | 28 | struct FusePadIntoConv final : public PredicateBasedPass { 29 | explicit FusePadIntoConv() 30 | : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete, 31 | PassOptimizationType::Compute) {} 32 | std::string getPassName() const override { 33 | return "fuse_pad_into_conv"; 34 | } 35 | bool patternMatchPredicate(Node* node) override { 36 | return CheckKind(node, kConv, 0, kPad); 37 | } 38 | bool runTransform(Node* n, Graph& graph, 39 | NodeDestroyType& destroy_current) override { 40 | destroy_current = NodeDestroyType::DestroyZero; 41 | 42 | // check if Pad is only used by Conv 43 | if (n->inputs()[0]->uses().size() > 1) { 44 | return false; 45 | } 46 | 47 | Node* conv = n; 48 | Node* pad = PrevNode(n, 0); 49 | 50 | // Process 'pads' data 51 | std::vector pads; 52 | if (!GetValueFromAttrOrInput(pad, kpads, 1, pads)) { 53 | return false; 54 | } 55 | 56 | // Process 'mode' 57 | std::string default_pad_mode{"constant"}; 58 | 59 | // cannot fuse if the pad mode is not "Constant" 60 | if (GetValueFromAttrWithDefault(pad, kmode, default_pad_mode) != 61 | default_pad_mode) { 62 | return false; 63 | } 64 | 65 | // Process 'Constant_value' 66 | { 67 | union ConstantValueType { 68 | int32_t i32; 69 | int64_t i64; 70 | float f32; 71 | double f64; 72 | uint8_t ui8; 73 | int8_t i8; 74 | uint16_t ui16; 75 | int16_t i16; 76 | } cv; 77 | 78 | #define Define_GetConstantValueFromInput(token) \ 79 | (GetValueFromInput(pad, 2, cv.token) && cv.token == decltype(cv.token)(0)) 80 | 81 | do { 82 | if (GetValueFromAttr(pad, kvalue, cv.f64) && cv.f64 == double(0)) { 83 | break; 84 | } 85 | if (pad->inputs().size() >= 3) { 86 | if (pad->input(2)->uniqueName().empty()) { 87 | break; 88 | } 89 | if 
(Define_GetConstantValueFromInput(i32) || 90 | Define_GetConstantValueFromInput(i64) || 91 | Define_GetConstantValueFromInput(f32) || 92 | Define_GetConstantValueFromInput(f64) || 93 | Define_GetConstantValueFromInput(ui8) || 94 | Define_GetConstantValueFromInput(i8) || 95 | Define_GetConstantValueFromInput(ui16) || 96 | Define_GetConstantValueFromInput(i16)) { 97 | break; 98 | } 99 | return false; 100 | } 101 | } while (0); 102 | 103 | #undef Define_GetConstantValueFromInput 104 | } 105 | 106 | // check if some values in 'pads' prevents us from fusing it into 'Conv' 107 | // node 108 | int pads_size = static_cast(pads.size()); 109 | 110 | // check if padding is applied only on feature dims 111 | if (pads[0] != 0 || pads[1] != 0 || pads[pads_size / 2] != 0 || 112 | pads[pads_size / 2 + 1] != 0) { 113 | return false; 114 | } 115 | 116 | // check if padding is only positive 117 | if (std::any_of(pads.begin(), pads.end(), 118 | [](int64_t local_value) { return local_value < 0; })) { 119 | return false; 120 | } 121 | 122 | int conv_pads_size = pads_size - 4; 123 | std::vector conv_pads(conv_pads_size, 0); 124 | // Fuse into existing padding, if available 125 | if (conv->hasAttribute(kpads)) { 126 | conv_pads = conv->is(kpads); 127 | } 128 | 129 | for (int i = 2, j = 0; i < pads_size / 2; ++i, ++j) { 130 | conv_pads[j] += pads[i]; 131 | conv_pads[conv_pads_size / 2 + j] += pads[pads_size / 2 + i]; 132 | } 133 | 134 | conv->is_(kpads, std::move(conv_pads)); 135 | conv->replaceInput(0, pad->inputs()[0]); 136 | pad->destroy(); 137 | 138 | return true; 139 | } 140 | }; 141 | 142 | } // namespace optimization 143 | } // namespace ONNX_NAMESPACE 144 | 145 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/fuse_pad_into_pool.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly 
EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | // Before: 11 | // P = Pad(X) - opset 10 and below (or) Pad(X, Pads, [constant_value]) - opset 12 | // 11 and above Z = pool(P, Y) 13 | // After: 14 | // Z = pool(X, Y) with "pads" attribute set 15 | // 16 | // the pass handles the case when Pad is zero-padding the input 17 | // (i.e. mode=constant and constant_value=0) 18 | 19 | #include 20 | 21 | #include "onnx/defs/tensor_util.h" 22 | #include "onnxoptimizer/pass.h" 23 | #include "onnxoptimizer/passes/pass_util.h" 24 | 25 | namespace ONNX_NAMESPACE { 26 | namespace optimization { 27 | 28 | struct FusePadIntoPool final : public PredicateBasedPass { 29 | explicit FusePadIntoPool() 30 | : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete, 31 | PassOptimizationType::Compute) {} 32 | std::string getPassName() const override { 33 | return "fuse_pad_into_pool"; 34 | } 35 | 36 | bool patternMatchPredicate(Node* node) override { 37 | return CheckKind(node, "AveragePool", 0, kPad) || 38 | CheckKind(node, "MaxPool", 0, kPad); 39 | } 40 | 41 | bool runTransform(Node* n, Graph& graph, 42 | NodeDestroyType& destroy_current) override { 43 | destroy_current = NodeDestroyType::DestroyZero; 44 | 45 | // check if Pad is only used by pool 46 | if (n->inputs()[0]->uses().size() > 1) { 47 | return false; 48 | } 49 | 50 | Node* pool = n; 51 | Node* pad = n->inputs()[0]->node(); 52 | 53 | // Process 'pads' data 54 | std::vector pads; 55 | if (!GetValueFromAttrOrInput(pad, kpads, 1, pads)) { 56 | return false; 57 | } 58 | 59 | // Process 'mode' 60 | std::string default_pad_mode{"constant"}; 61 | std::string pad_mode = 62 | GetValueFromAttrWithDefault(pad, kmode, default_pad_mode); 63 | 64 | // cannot fuse if the pad mode is not "Constant" 65 | if (pad_mode != default_pad_mode) { 66 | return false; 67 | } 68 | 69 | // Process 'Constant_value' 70 | { 71 | union ConstantValueType { 72 | int32_t i32; 73 | int64_t 
i64; 74 | float f32; 75 | double f64; 76 | uint8_t ui8; 77 | int8_t i8; 78 | uint16_t ui16; 79 | int16_t i16; 80 | } cv; 81 | 82 | #define Define_GetConstantValueFromInput(token) \ 83 | (GetValueFromInput(pad, 2, cv.token) && cv.token == decltype(cv.token)(0)) 84 | 85 | do { 86 | if (GetValueFromAttr(pad, kvalue, cv.f64) && cv.f64 == double(0)) { 87 | break; 88 | } 89 | if (pad->inputs().size() >= 3) { 90 | if (pad->input(2)->uniqueName().empty()) { 91 | break; 92 | } 93 | if (Define_GetConstantValueFromInput(i32) || 94 | Define_GetConstantValueFromInput(i64) || 95 | Define_GetConstantValueFromInput(f32) || 96 | Define_GetConstantValueFromInput(f64) || 97 | Define_GetConstantValueFromInput(ui8) || 98 | Define_GetConstantValueFromInput(i8) || 99 | Define_GetConstantValueFromInput(ui16) || 100 | Define_GetConstantValueFromInput(i16)) { 101 | break; 102 | } 103 | return false; 104 | } 105 | } while (0); 106 | 107 | #undef Define_GetConstantValueFromInput 108 | } 109 | 110 | // check if some values in 'pads' prevents us from fusing it into 'Conv' 111 | // node 112 | int pads_size = static_cast(pads.size()); 113 | 114 | // check if padding is applied only on feature dims 115 | if (pads[0] != 0 || pads[1] != 0 || pads[pads_size / 2] != 0 || 116 | pads[pads_size / 2 + 1] != 0) { 117 | return false; 118 | } 119 | 120 | // check if padding is only positive 121 | if (std::any_of(pads.begin(), pads.end(), 122 | [](int64_t local_value) { return local_value < 0; })) { 123 | return false; 124 | } 125 | 126 | int pool_pads_size = pads_size - 4; 127 | std::vector pool_pads(pool_pads_size, 0); 128 | // Fuse into existing padding, if available 129 | if (pool->hasAttribute(kpads)) { 130 | pool_pads = pool->is(kpads); 131 | } 132 | 133 | for (int i = 2, j = 0; i < pads_size / 2; ++i, ++j) { 134 | pool_pads[j] += pads[i]; 135 | pool_pads[pool_pads_size / 2 + j] += pads[pads_size / 2 + i]; 136 | } 137 | 138 | if (pool->kind() == Symbol("AveragePool")) { 139 | int64_t count_include_pad = 
1; 140 | pool->i_(kcount_include_pad, count_include_pad); 141 | } 142 | 143 | pool->is_(kpads, std::move(pool_pads)); 144 | pool->replaceInput(0, pad->inputs()[0]); 145 | pad->destroy(); 146 | 147 | return true; 148 | } 149 | }; 150 | 151 | } // namespace optimization 152 | } // namespace ONNX_NAMESPACE 153 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/fuse_qkv.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | #include 11 | 12 | #include "onnx/defs/tensor_util.h" 13 | #include "onnxoptimizer/pass.h" 14 | #include "onnxoptimizer/passes/pass_util.h" 15 | 16 | namespace ONNX_NAMESPACE { 17 | namespace optimization { 18 | 19 | struct FuseQKV final : public PredicateBasedPass { 20 | explicit FuseQKV() 21 | : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete, 22 | PassOptimizationType::Compute) {} 23 | std::string getPassName() const override { 24 | return "fuse_qkv"; 25 | } 26 | 27 | bool patternMatchPredicate(Node* node) override { 28 | return CheckKind(node, kMatMul) && node->input(0)->uses().size() == 3; 29 | } 30 | 31 | bool runTransform(Node* n, Graph& graph, 32 | NodeDestroyType& destroy_current) override { 33 | destroy_current = NodeDestroyType::DestroyZero; 34 | const auto uses = n->input(0)->uses(); 35 | for (const auto& use : uses) { 36 | if (use.offset != 0 || !CheckKind(use.user, kMatMul) || 37 | use.user->output()->uses().size() != 1 || 38 | !IsConstantTensor(use.user, 1)) { 39 | return false; 40 | } 41 | } 42 | /// q k v not a one-to-one correspondence actually 43 | Node* q = uses[0].user; 44 | Node* k = uses[1].user; 45 | Node* v = uses[2].user; 46 | 47 | const Tensor* q_t = FetchConstantTensor(q->input(1)); 48 | const Tensor* k_t 
= FetchConstantTensor(k->input(1)); 49 | const Tensor* v_t = FetchConstantTensor(v->input(1)); 50 | if (q_t->sizes() != k_t->sizes() || q_t->sizes() != v_t->sizes()) { 51 | return false; 52 | } 53 | Node* prev = PrevNode(n, 0); 54 | Node* cat = graph.create(kConcat, 1); 55 | cat->insertAfter(prev); 56 | cat->addInput(q->input(1)); 57 | cat->addInput(k->input(1)); 58 | cat->addInput(v->input(1)); 59 | cat->i_(kaxis, q_t->sizes().size() - 1); 60 | Node* matmul = graph.create(kMatMul, 1); 61 | matmul->insertAfter(cat); 62 | matmul->addInput(q->input(0)); 63 | matmul->addInput(cat->output()); 64 | 65 | Node* split = graph.create("Split"_sym, 3); 66 | split->i_(kaxis, -1); 67 | split->insertAfter(matmul); 68 | split->addInput(matmul->output()); 69 | const auto opset_version = getOpsetVersion(graph); 70 | if (opset_version >= 13) { 71 | Tensor split_t; 72 | split_t.sizes().push_back(3); 73 | split_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_INT64; 74 | split_t.int64s().push_back(q_t->sizes().back()); 75 | split_t.int64s().push_back(k_t->sizes().back()); 76 | split_t.int64s().push_back(v_t->sizes().back()); 77 | split->addInput(graph.addInitializerAndCreateValue(split_t)); 78 | } else { 79 | split->is_(ksplit, 80 | std::vector{q_t->sizes().back(), k_t->sizes().back(), 81 | v_t->sizes().back()}); 82 | } 83 | if (!tryReplacingAllUsesWith(q->output(), split->outputs()[0])) { 84 | return false; 85 | } 86 | if (!tryReplacingAllUsesWith(k->output(), split->outputs()[1])) { 87 | return false; 88 | } 89 | if (!tryReplacingAllUsesWith(v->output(), split->outputs()[2])) { 90 | return false; 91 | } 92 | destroy_current = NodeDestroyType::DestroyOne; 93 | return true; 94 | } 95 | }; 96 | 97 | } // namespace optimization 98 | } // namespace ONNX_NAMESPACE 99 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/fuse_transpose_into_gemm.h: -------------------------------------------------------------------------------- 1 | /* 
2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | 12 | namespace ONNX_NAMESPACE { 13 | namespace optimization { 14 | 15 | struct FuseTransposeIntoGemm final : public PredicateBasedPass { 16 | explicit FuseTransposeIntoGemm() 17 | : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete, 18 | PassOptimizationType::Compute) {} 19 | std::string getPassName() const override { 20 | return "fuse_transpose_into_gemm"; 21 | } 22 | bool patternMatchPredicate(Node* node) override { 23 | return node->kind() == kGemm; 24 | } 25 | bool runTransform(Node* n, Graph&, 26 | NodeDestroyType& destroy_current) override { 27 | const std::vector simple_trans_perm({1, 0}); 28 | destroy_current = NodeDestroyType::DestroyZero; 29 | bool ret_val = false; 30 | for (size_t i : {0, 1}) { 31 | auto inp = n->inputs()[i]; 32 | auto trans = i == 0 ? ktransA : ktransB; 33 | if (inp->node()->kind() == kTranspose && 34 | inp->node()->is(kperm) == simple_trans_perm) { 35 | n->replaceInput(i, inp->node()->input()); 36 | n->i_(trans, n->hasAttribute(trans) ? 
!n->i(trans) : 1); 37 | if (inp->uses().size() == 0) { 38 | inp->node()->destroy(); 39 | ret_val = true; 40 | } 41 | } 42 | } 43 | return ret_val; 44 | } 45 | }; 46 | 47 | } // namespace optimization 48 | } // namespace ONNX_NAMESPACE 49 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/lift_lexical_references.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include "onnxoptimizer/pass.h" 9 | 10 | namespace ONNX_NAMESPACE { 11 | namespace optimization { 12 | 13 | // Lift lexically-scoped references within control operators to be inputs of the 14 | // ops themselves. This transformation yields a graph that does not conform to 15 | // the ONNX spec. 16 | // 17 | // The purpose of this pass is to expose the data dependencies within control 18 | // blocks for frameworks that use those dependencies to schedule parallel 19 | // execution. e.g. caffe2 graph execution. 
20 | // 21 | // Example: 22 | // ******************************** Before ************************************* 23 | // graph test (%X[FLOAT, 5]) { 24 | // %Y = Identity(%X) 25 | // %trip_count = Constant[value = ]() 26 | // %condition = Constant[value = ]() 27 | // %Y2, %Y3 = Loop[body = ](%trip_count, %condition, %) 28 | // return %Y, %Y2 29 | // } 30 | // 31 | // graph body_graph (%i[INT32, scalar], %cond[BOOL, scalar]) { 32 | // %_Y2 = Identity(%X) 33 | // %_Y3 = Identity(%Y) 34 | // return %cond, %_Y2, %_Y3 35 | // } 36 | // 37 | // ******************************** After ************************************** 38 | // graph test (%X[FLOAT, 5]) { 39 | // %Y = Identity(%X) 40 | // %trip_count = Constant[value = ]() 41 | // %condition = Constant[value = ]() 42 | // %Y2, %Y3 = Loop[__control_inputs = ['X', 'Y'], body = ](%trip_count, %condition, %) 44 | // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 45 | // return %Y, %Y2 46 | // } 47 | // 48 | // graph body_graph (%i[INT32, scalar], %cond[BOOL, scalar]) { 49 | // %_Y2 = Identity(%X) 50 | // %_Y3 = Identity(%Y) 51 | // return %cond, %_Y2, %_Y3 52 | // } 53 | // 54 | // ******************************** Continue Docs******************************* 55 | // 56 | // The algorithm is roughly: 57 | // symbol_table_stack = empty stack of symbol tables 58 | // 59 | // liftreferences(graph) 60 | // -> a set of unresolved reference strings: 61 | // unresolved_references = {} 62 | // 63 | // symbol_table_stack.push(new symbol table containing inputs for this 64 | // sub-graph) for each node in the graph: 65 | // for input in node.inputs: 66 | // if input is not in this frame: 67 | // unresolved_references.insert(input) 68 | // if node is a control flow operator: 69 | // for each sub-graph g: 70 | // for each output in g's body: 71 | // if output is defined in current scope: 72 | // control_inputs.insert(output) 73 | // refs = liftreferences(g) 74 | // for each ref in refs: 75 | // if ref is in this frame or any parent frame (control_inputs): 
76 | // control_inputs.insert(ref) 77 | // else: 78 | // unresolved_references.insert(ref) 79 | // set the control inputs attribute to the node 80 | // for output in node.outputs: 81 | // symbol_table_stack.top()[output] = Value* 82 | // return unresolved_references 83 | struct LiftLexicalReferences : public FullGraphBasedPass { 84 | explicit LiftLexicalReferences() 85 | : FullGraphBasedPass(PassType::Separate, PassEfficiency::Complete, 86 | PassOptimizationType::Memory) {} 87 | 88 | std::string getPassName() const override { 89 | return "lift_lexical_references"; 90 | } 91 | PassAnalysisType getPassAnalysisType() const override { 92 | return PassAnalysisType::Empty; 93 | } 94 | 95 | using ValueTable = std::unordered_map; 96 | 97 | // Environment stack, please to store value table and 98 | // controlled inputs 99 | struct Environment { 100 | Environment(std::shared_ptr next = nullptr) : next(next) {} 101 | 102 | std::shared_ptr next; 103 | 104 | Value* findInThisFrame(const std::string& name) { 105 | auto it = value_table.find(name); 106 | if (it != value_table.end()) { 107 | return it->second; 108 | } 109 | return nullptr; 110 | } 111 | 112 | Value* findInParentFrame(const std::string& name) { 113 | return next ? 
next->findInAnyFrame(name) : nullptr; 114 | } 115 | 116 | Value* findInAnyFrame(const std::string& name) { 117 | for (auto runner = this; runner; runner = runner->next.get()) { 118 | if (auto r = runner->findInThisFrame(name)) { 119 | return r; 120 | } 121 | } 122 | return nullptr; 123 | } 124 | 125 | void setVar(const std::string& name, Value* value) { 126 | value_table[name] = value; 127 | } 128 | 129 | private: 130 | ValueTable value_table; 131 | }; 132 | 133 | std::shared_ptr environment_stack; 134 | 135 | // environment stack helper 136 | void pushFrame() { 137 | environment_stack = std::make_shared(environment_stack); 138 | } 139 | 140 | std::shared_ptr popFrame() { 141 | auto old_frame = environment_stack; 142 | environment_stack = environment_stack->next; 143 | return old_frame; 144 | } 145 | 146 | std::set liftReferences(Graph* g) { 147 | std::set unresolved_references; 148 | pushFrame(); 149 | for (auto& inp : g->inputs()) { 150 | environment_stack->setVar(inp->uniqueName(), inp); 151 | } 152 | 153 | for (auto* n : g->nodes()) { 154 | // Skip optional input/captured value node. 155 | if (n->kind() == ONNX_NAMESPACE::kUndefined || 156 | n->kind() == ONNX_NAMESPACE::kCaptured) { 157 | continue; 158 | } 159 | for (auto* inp : n->inputs()) { 160 | // Empty string is 0-input variadic argument. Skip that one. 
161 | if (!inp->uniqueName().empty() && 162 | !environment_stack->findInThisFrame(inp->uniqueName())) { 163 | unresolved_references.insert(inp->uniqueName()); 164 | } 165 | } 166 | 167 | std::set local_unresolved; 168 | 169 | // if a graph body output has already already been emitted outside of the 170 | // subgraph scope, then it must be added as an input to the subgraph 171 | auto add_subgraph_outputs = [&](Graph* body_graph) { 172 | for (auto* out : body_graph->outputs()) { 173 | if (environment_stack->findInAnyFrame(out->uniqueName())) { 174 | local_unresolved.insert(out->uniqueName()); 175 | } 176 | } 177 | }; 178 | 179 | if (n->kind() == ONNX_NAMESPACE::kLoop) { 180 | auto* body_graph = n->g(ONNX_NAMESPACE::kbody).get(); 181 | local_unresolved = liftReferences(body_graph); 182 | add_subgraph_outputs(body_graph); 183 | } else if (n->kind() == ONNX_NAMESPACE::kIf) { 184 | auto* then_graph = n->g(ONNX_NAMESPACE::kthen_branch).get(); 185 | add_subgraph_outputs(then_graph); 186 | auto then_unresolved = liftReferences(then_graph); 187 | local_unresolved.insert(then_unresolved.begin(), then_unresolved.end()); 188 | auto* else_graph = n->g(ONNX_NAMESPACE::kelse_branch).get(); 189 | add_subgraph_outputs(else_graph); 190 | auto else_unresolved = liftReferences(else_graph); 191 | local_unresolved.insert(else_unresolved.begin(), else_unresolved.end()); 192 | } 193 | 194 | std::vector control_inputs; 195 | for (auto& unresolved : local_unresolved) { 196 | if (environment_stack->findInAnyFrame(unresolved)) { 197 | control_inputs.push_back(unresolved); 198 | } else { 199 | unresolved_references.insert(unresolved); 200 | } 201 | } 202 | 203 | // Create this attribute so the backend knows how many of these inputs 204 | // are simply there for control dependencies 205 | if (!control_inputs.empty()) { 206 | n->ss_(ONNX_NAMESPACE::k__control_inputs, std::move(control_inputs)); 207 | } 208 | 209 | for (auto* out : n->outputs()) { 210 | environment_stack->setVar(out->uniqueName(), 
out); 211 | } 212 | } 213 | 214 | popFrame(); 215 | return unresolved_references; 216 | } 217 | 218 | std::shared_ptr runPass(Graph& graph) override { 219 | auto unresolved = liftReferences(&graph); 220 | 221 | if (unresolved.size()) { 222 | std::string errmsg = "Unresolved value references: "; 223 | for (auto& ref : unresolved) { 224 | errmsg += ref + ","; 225 | } 226 | throw std::runtime_error(errmsg); 227 | } 228 | return std::shared_ptr(new PostPassAnalysis()); 229 | } 230 | }; 231 | 232 | } // namespace optimization 233 | } // namespace ONNX_NAMESPACE 234 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/logging.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | #include // min 10 | #include // abort getenv 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include "onnxoptimizer/passes/string_utils.h" 17 | 18 | namespace ONNX_NAMESPACE { 19 | namespace optimization { 20 | 21 | namespace details { 22 | 23 | static constexpr char logging_prefix[] = {'F', 'E', 'W', 'I', 'V'}; 24 | 25 | constexpr int LOG_INFO = 0; 26 | constexpr int LOG_WARNING = 1; 27 | constexpr int LOG_ERROR = 2; 28 | constexpr int LOG_FATAL = 3; 29 | 30 | static std::once_flag read_log_threshold_flag; 31 | static int log_threshold = LOG_INFO; 32 | 33 | static void ReadLogThresholdFromEnv() { 34 | char* threshold = std::getenv("LOG_THRESHOLD"); 35 | if (!threshold) { 36 | return; 37 | } 38 | std::stringstream ss; 39 | ss << threshold; 40 | ss >> log_threshold; 41 | } 42 | 43 | class MessageControl { 44 | public: 45 | MessageControl(const char* file, const char* function, int line, int severity) 46 | : severity_(severity) { 47 | std::call_once(read_log_threshold_flag, 
ReadLogThresholdFromEnv); 48 | stream_ << "[" << logging_prefix[std::min(4, LOG_FATAL - severity_)] 49 | << " " << StripFilename(file) << ":" << line << " " << function 50 | << "]: "; 51 | } 52 | 53 | ~MessageControl() { 54 | if (severity_ < log_threshold) { 55 | return; 56 | } 57 | std::cout << stream_.rdbuf() << std::endl; 58 | if (severity_ == LOG_FATAL) { 59 | std::abort(); 60 | } 61 | } 62 | 63 | std::ostream& Stream() { 64 | return stream_; 65 | } 66 | 67 | private: 68 | int severity_; 69 | std::stringstream stream_; 70 | }; 71 | 72 | } // namespace details 73 | 74 | #define LOG(n) \ 75 | details::MessageControl(__FILE__, __FUNCTION__, __LINE__, details::LOG_##n) \ 76 | .Stream() 77 | 78 | #define VLOG(n) \ 79 | details::MessageControl(__FILE__, __FUNCTION__, __LINE__, -n).Stream() 80 | 81 | #define LOG_IF(n, condition) \ 82 | if (condition) \ 83 | LOG(n) 84 | 85 | #define VLOG_IF(n, condition) \ 86 | if (condition) \ 87 | VLOG(n) 88 | 89 | } // namespace optimization 90 | } // namespace ONNX_NAMESPACE -------------------------------------------------------------------------------- /onnxoptimizer/passes/nop.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "onnxoptimizer/pass.h" 8 | 9 | namespace ONNX_NAMESPACE { 10 | namespace optimization { 11 | 12 | struct NopEmptyPass final : public FullGraphBasedPass { 13 | explicit NopEmptyPass() 14 | : FullGraphBasedPass(PassType::Nop, PassEfficiency::Complete, 15 | PassOptimizationType::None) {} 16 | 17 | std::string getPassName() const override { 18 | return "nop"; 19 | } 20 | PassAnalysisType getPassAnalysisType() const override { 21 | return PassAnalysisType::Empty; 22 | } 23 | std::shared_ptr runPass(Graph&) override { 24 | return std::make_shared(); 25 | } 26 | }; 27 | } // namespace optimization 28 | } // namespace ONNX_NAMESPACE 29 | 
-------------------------------------------------------------------------------- /onnxoptimizer/passes/pass_util.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #include "onnx/defs/tensor_util.h" 9 | #include "pass_util.h" 10 | 11 | namespace ONNX_NAMESPACE { 12 | namespace optimization { 13 | 14 | bool FetchSoleIntValueOfTensor(const Value* t, int64_t& val) { 15 | int32_t i32_val; 16 | const bool r1 = FetchSoleValueOfTensor(t, val); 17 | const bool r2 = FetchSoleValueOfTensor(t, i32_val); 18 | if (r2) { 19 | val = i32_val; 20 | } 21 | return r1 || r2; 22 | } 23 | 24 | } // namespace optimization 25 | } // namespace ONNX_NAMESPACE 26 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/rename_input_output.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | // fetch input/output pattern from environment variable 11 | // OPTIMIZER_RENAME_INPUT_PATTERN(default: input_%d) and 12 | // OPTIMIZER_RENAME_OUTPUT_PATTERN(default: output_%d). 
13 | 14 | #include "onnxoptimizer/pass.h" 15 | 16 | namespace ONNX_NAMESPACE { 17 | namespace optimization { 18 | 19 | struct RenameInputOutput final : public FullGraphBasedPass { 20 | explicit RenameInputOutput() 21 | : FullGraphBasedPass(PassType::Other, PassEfficiency::Complete, 22 | PassOptimizationType::None) {} 23 | 24 | std::string getPassName() const override { 25 | return "rename_input_output"; 26 | } 27 | 28 | PassAnalysisType getPassAnalysisType() const override { 29 | return PassAnalysisType::Empty; 30 | } 31 | 32 | std::vector fetchPatternFromEnv() { 33 | auto fetch_env = [](const std::string& env) -> std::string { 34 | const auto* res = std::getenv(env.c_str()); 35 | return res ? std::string(res) : std::string{}; 36 | }; 37 | auto split_pattern = 38 | [](const std::string& s, const std::string& pattern, 39 | const std::string& default_s) -> std::vector { 40 | std::vector res(2); 41 | const std::string* tmp = &s; 42 | auto n = tmp->find(pattern); 43 | if (n == std::string::npos) { 44 | tmp = &default_s; 45 | n = tmp->find(pattern); 46 | } 47 | 48 | res[0] = tmp->substr(0, n); 49 | res[1] = tmp->substr(n + pattern.size()); 50 | return res; 51 | }; 52 | const std::string pattern{"%d"}; 53 | const std::vector envs{"OPTIMIZER_RENAME_INPUT_PATTERN", 54 | "OPTIMIZER_RENAME_OUTPUT_PATTERN"}; 55 | const std::vector default_str{"input_%d", "output_%d"}; 56 | std::vector result; 57 | for (int i = 0; i < 2; ++i) { 58 | auto env_str = fetch_env(envs[i]); 59 | auto split_str = split_pattern(env_str, pattern, default_str[i]); 60 | std::copy(split_str.begin(), split_str.end(), std::back_inserter(result)); 61 | } 62 | return result; 63 | } 64 | 65 | void rename_input_output(Graph& graph) { 66 | std::unordered_set initializer_names( 67 | graph.initializer_names().begin(), graph.initializer_names().end()); 68 | 69 | const auto rename_patterns = fetchPatternFromEnv(); 70 | 71 | for (int i = 0; i < graph.inputs().size(); ++i) { 72 | auto& value = graph.inputs()[i]; 73 | 
// ignore when input also in initializer 74 | if (initializer_names.count(value->uniqueName()) > 0) { 75 | continue; 76 | } 77 | const std::string current_name = 78 | rename_patterns[0] + std::to_string(i) + rename_patterns[1]; 79 | 80 | value->setUniqueName(current_name); 81 | } 82 | 83 | for (int i = 0; i < graph.outputs().size(); ++i) { 84 | auto& value = graph.outputs()[i]; 85 | const std::string current_name = 86 | rename_patterns[2] + std::to_string(i) + rename_patterns[3]; 87 | value->setUniqueName(current_name); 88 | } 89 | } 90 | 91 | std::shared_ptr runPass(Graph& graph) override { 92 | rename_input_output(graph); 93 | return std::shared_ptr(new PostPassAnalysis()); 94 | } 95 | }; 96 | 97 | } // namespace optimization 98 | } // namespace ONNX_NAMESPACE 99 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/replace_einsum_with_matmul.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #pragma once 9 | 10 | // Before: 11 | // Z = einsum(X,Y) 12 | // After: 13 | // Z = matmul(X,Y) 14 | // or Y1 = transpose(Y), Z = matmul(X,Y1) 15 | // the pass can handle the case when: 16 | // case 1: equation represents matmul, e.g: "bhij,bhjd->bhid" 17 | // case 2: equation represents transpose matmul, e.g: "bhid,bhjd->bhij" 18 | 19 | #include 20 | 21 | #include "onnx/defs/tensor_util.h" 22 | #include "onnxoptimizer/pass.h" 23 | 24 | namespace ONNX_NAMESPACE { 25 | namespace optimization { 26 | 27 | struct ReplaceEinsumWithMatmul final : public PredicateBasedPass { 28 | explicit ReplaceEinsumWithMatmul() 29 | : PredicateBasedPass(PassType::Replace, PassEfficiency::Complete, 30 | PassOptimizationType::Compute) {} 31 | std::string getPassName() const override { 32 | return "replace_einsum_with_matmul"; 33 | } 34 | 35 | bool patternMatchPredicate(Node* node) override { 36 | return CheckKind(node, "Einsum") && node->inputs().size() == 2 && 37 | std::all_of(node->inputs().begin(), node->inputs().end(), 38 | [](const Value* v) { 39 | switch (v->elemType()) { 40 | // matmul support these dtype 41 | case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: 42 | case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: 43 | case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: 44 | case ONNX_NAMESPACE::TensorProto_DataType_INT32: 45 | case ONNX_NAMESPACE::TensorProto_DataType_UINT32: 46 | case ONNX_NAMESPACE::TensorProto_DataType_INT64: 47 | case ONNX_NAMESPACE::TensorProto_DataType_UINT64: 48 | return true; 49 | } 50 | return false; 51 | }); 52 | } 53 | 54 | template 55 | bool isEqual(const T& a, const T& b) { 56 | return a == b; 57 | } 58 | 59 | template 60 | bool isEqual(const T& a, const T& b, const T& c) { 61 | return isEqual(a, b) && isEqual(a, c); 62 | } 63 | 64 | bool runTransform(Node* n, Graph& graph, 65 | NodeDestroyType& destroy_current) override { 66 | ONNX_ASSERT(n->hasAttribute(Symbol("equation"))); 67 | std::string equation = n->s(Symbol("equation")); 68 | // remove 
space 69 | equation.erase(std::remove(equation.begin(), equation.end(), ' '), 70 | equation.end()); 71 | auto mid_index = equation.find("->"); 72 | if (mid_index == std::string::npos) { 73 | return false; 74 | } 75 | const auto left_equation = equation.substr(0, mid_index); 76 | const auto right_equation = equation.substr(mid_index + 2); 77 | mid_index = left_equation.find(","); 78 | if (mid_index == std::string::npos) { 79 | return false; 80 | } 81 | // reference https://github.com/onnx/onnx/blob/main/docs/Operators.md#Einsum 82 | const auto term1 = left_equation.substr(0, mid_index); 83 | const auto term2 = left_equation.substr(mid_index + 1); 84 | 85 | if (term1.size() < 2 || 86 | !isEqual(term1.size(), term2.size(), right_equation.size())) { 87 | return false; 88 | } 89 | 90 | auto is_lower_letter = [](char c) { 91 | return std::isalpha(c) && std::islower(c); 92 | }; 93 | 94 | const int shape_size = term1.size(); 95 | for (int i = 0; i < shape_size; ++i) { 96 | // the term should only contain lower case letters 97 | if (!is_lower_letter(term1[i]) || !is_lower_letter(term2[i]) || 98 | !is_lower_letter(right_equation[i])) { 99 | return false; 100 | } 101 | // the batch dim should be equal 102 | if ((i < shape_size - 2) && 103 | !isEqual(term1[i], term2[i], right_equation[i])) { 104 | return false; 105 | } 106 | } 107 | char term1_m = term1[shape_size - 2]; 108 | char term2_k = term2[shape_size - 2]; 109 | char result_m = right_equation[shape_size - 2]; 110 | 111 | char term1_k = term1[shape_size - 1]; 112 | char term2_n = term2[shape_size - 1]; 113 | char result_n = right_equation[shape_size - 1]; 114 | 115 | bool need_transpose = false; 116 | if (isEqual(term1_m, result_m) && isEqual(term1_k, term2_k) && 117 | isEqual(term2_n, result_n)) { // "ij,jd->id" 118 | need_transpose = false; 119 | } else if (isEqual(term1_m, result_m) && isEqual(term1_k, term2_n) && 120 | isEqual(term2_k, result_n)) { // "id,jd->ij" 121 | need_transpose = true; 122 | } else { 123 | 
return false; 124 | } 125 | 126 | Node* matmul_node = graph.create(kMatMul, 1); 127 | matmul_node->addInput(n->inputs()[0]); 128 | if (need_transpose) { 129 | Node* transpose_node = graph.create(kTranspose, 1); 130 | transpose_node->addInput(n->inputs()[1]); 131 | transpose_node->output()->setUniqueName( 132 | ONNX_NAMESPACE::to_string(graph.getNextUnique())); 133 | std::vector perm(shape_size, 0); 134 | for (int i = 0; i < shape_size - 2; ++i) { 135 | perm[i] = i; 136 | } 137 | perm[shape_size - 1] = shape_size - 2; 138 | perm[shape_size - 2] = shape_size - 1; 139 | transpose_node->is_(kperm, std::move(perm)); 140 | matmul_node->addInput(transpose_node->output()); 141 | transpose_node->insertBefore(n); 142 | } else { 143 | matmul_node->addInput(n->inputs()[1]); 144 | } 145 | matmul_node->insertBefore(n); 146 | 147 | const bool replacing_success = tryReplacingAllUsesWith(n, matmul_node); 148 | if (!replacing_success) { 149 | return false; 150 | } 151 | destroy_current = NodeDestroyType::DestroyOne; 152 | return true; 153 | } 154 | }; 155 | 156 | } // namespace optimization 157 | } // namespace ONNX_NAMESPACE 158 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/rewrite_input_dtype.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | #include "pass_util.h" 12 | 13 | namespace ONNX_NAMESPACE { 14 | namespace optimization { 15 | 16 | struct RewriteInputDtype final : public FullGraphBasedPass { 17 | explicit RewriteInputDtype() 18 | : FullGraphBasedPass(PassType::Other, PassEfficiency::Complete, 19 | PassOptimizationType::None) {} 20 | 21 | std::string getPassName() const override { 22 | return "rewrite_input_dtype"; 23 | } 24 | 25 | PassAnalysisType getPassAnalysisType() const override { 26 | return PassAnalysisType::Empty; 27 | } 28 | 29 | void rewrite_input_dtype(Graph& graph) { 30 | std::unordered_set initializer_names( 31 | graph.initializer_names().begin(), graph.initializer_names().end()); 32 | 33 | for (auto& value : graph.inputs()) { 34 | // ignore when input also in initializer 35 | if (initializer_names.count(value->uniqueName()) > 0 || 36 | value->elemType() != ONNX_NAMESPACE::TensorProto_DataType_INT64) { 37 | continue; 38 | } 39 | auto use_list = value->uses(); 40 | 41 | Node* cast = graph.create(kCast, 1); 42 | cast = graph.appendNode(cast); 43 | cast->i_(kto, static_cast( 44 | ONNX_NAMESPACE::TensorProto_DataType_INT64)); 45 | cast->addInput(value); 46 | cast->output()->setUniqueName( 47 | ONNX_NAMESPACE::to_string(graph.getNextUnique())); 48 | 49 | for (auto& use : use_list) { 50 | if (!cast->isBefore(use.user)) { 51 | cast->moveBefore(use.user); 52 | } 53 | use.user->replaceInput(use.offset, cast->output()); 54 | } 55 | value->setElemType(ONNX_NAMESPACE::TensorProto_DataType_INT32); 56 | } 57 | } 58 | 59 | std::shared_ptr runPass(Graph& graph) override { 60 | rewrite_input_dtype(graph); 61 | return std::shared_ptr(new PostPassAnalysis()); 62 | } 63 | }; 64 | 65 | } // namespace optimization 66 | } // namespace ONNX_NAMESPACE -------------------------------------------------------------------------------- /onnxoptimizer/passes/set_unique_name_for_nodes.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | 12 | namespace ONNX_NAMESPACE { 13 | namespace optimization { 14 | 15 | struct SetUniqueNameForNodes final : public PredicateBasedPass { 16 | explicit SetUniqueNameForNodes() 17 | : PredicateBasedPass(PassType::Other, PassEfficiency::Complete, 18 | PassOptimizationType::None) {} 19 | 20 | std::string getPassName() const override { 21 | return "set_unique_name_for_nodes"; 22 | } 23 | 24 | bool patternMatchPredicate(Node* node) override { 25 | return !node->has_name(); 26 | } 27 | 28 | bool runTransform(Node* node, Graph& graph, 29 | NodeDestroyType& destroy_current) override { 30 | node->setName(ONNX_NAMESPACE::to_string(graph.getNextUnique())); 31 | destroy_current = NodeDestroyType::DestroyZero; 32 | return true; 33 | } 34 | }; 35 | 36 | } // namespace optimization 37 | } // namespace ONNX_NAMESPACE 38 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/split.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. 6 | // Adventurous users should note that the APIs will probably change. 
7 | 8 | #pragma once 9 | 10 | #include "onnxoptimizer/pass.h" 11 | 12 | namespace ONNX_NAMESPACE { 13 | namespace optimization { 14 | 15 | static constexpr const char* impure_operators[] = { 16 | "RandomNormal", "RandomNormalLike", 17 | "RandomUniform", "RandomUniformLike", 18 | "Loop", "If", 19 | "Scan", 20 | }; 21 | 22 | static bool is_pure_operator(Node* n) { 23 | for (auto x : impure_operators) { 24 | if (n->kind() == Symbol(x)) { 25 | return false; 26 | } 27 | } 28 | return true; 29 | } 30 | 31 | static bool has_uses(Node* n){ 32 | 33 | for(auto output: n->outputs()){ 34 | if (!output->uses().empty()) 35 | return true; 36 | } 37 | 38 | return false; 39 | 40 | } 41 | // Split the graph into 'init' and 'predict' nets. This is kind of 42 | // like constant folding, except that rather than actually execute the 43 | // constant computations, we simply split them out into a separate 44 | // graph. Nodes that have any transitive dependency on the 45 | // initializers, or on impure operators, must remain in the predict 46 | // net. All others may be moved to the init net. 47 | // 48 | // This function destructively mutates the graph into either the init 49 | // or the predict net. If you want both, which you probably do, 50 | // arrange to call it twice. 51 | // 52 | // NOTE POTENTIAL BREAKAGE: 53 | // 54 | // The ONNX spec provides no guarantees about "staging", i.e. which 55 | // inputs change on every invocation vs which generally stay the same. 56 | // Here we make the assumption that inputs which have an initializer 57 | // value provided for them vary only between invocations of the init 58 | // net, and are constant across runs of the predict net. 59 | // 60 | static void split_init_and_predict(Graph& graph, bool init, bool predict) { 61 | // The first step is to identify which Values are reachable from 62 | // either of 63 | // - inputs without corresponding initializers 64 | // - impure operators 65 | // Any such Values belong to the predict net. 
Nodes belong to the 66 | // predict net if they are impure or if any of their inputs do. 67 | 68 | std::unordered_set predict_net_values; 69 | 70 | auto value_belongs_to_predict_net = [&](Value* v) { 71 | return predict_net_values.count(v) > 0; 72 | }; 73 | auto node_belongs_to_predict_net = [&](Node* n) { 74 | return !is_pure_operator(n) || 75 | std::any_of(n->inputs().begin(), n->inputs().end(), 76 | value_belongs_to_predict_net); 77 | }; 78 | 79 | { 80 | std::unordered_set initializer_names( 81 | graph.initializer_names().begin(), graph.initializer_names().end()); 82 | 83 | for (Value* v : graph.inputs()) { 84 | if (initializer_names.count(v->uniqueName()) == 0) { 85 | predict_net_values.insert(v); 86 | } 87 | } 88 | } 89 | 90 | for (Node* n : graph.nodes()) { 91 | if (node_belongs_to_predict_net(n)) { 92 | for (Value* v : n->outputs()) { 93 | predict_net_values.insert(v); 94 | } 95 | } 96 | } 97 | 98 | // Any Value which is not itself in the predict net, but which 99 | // is used by a Node which is, becomes an output of the init 100 | // graph and an input of the predict net 101 | std::unordered_set new_interface; 102 | for (Node* n : graph.nodes()) { 103 | if (node_belongs_to_predict_net(n)) { 104 | for (Value* v : n->inputs()) { 105 | if (!value_belongs_to_predict_net(v)) { 106 | new_interface.insert(v); 107 | } 108 | } 109 | } 110 | } 111 | 112 | for (Value* v : graph.outputs()) { 113 | if (!value_belongs_to_predict_net(v)) { 114 | new_interface.insert(v); 115 | } 116 | } 117 | 118 | if (init) { 119 | // Add new outputs corresponding to the boundary between init and 120 | // predict nets, ensuring that we don't duplicate outputs. 121 | for (Value* v : graph.outputs()) { 122 | new_interface.erase(v); 123 | } 124 | for (Value* v : new_interface) { 125 | if (v->node()->kind() == kUndefined) { 126 | continue; 127 | } 128 | graph.registerOutput(v); 129 | } 130 | 131 | // Remove outputs that belong to the predict net. 
132 | for (auto i = graph.outputs().size(); i--;) { 133 | if (value_belongs_to_predict_net(graph.outputs()[i])) { 134 | graph.return_node()->removeInput(i); 135 | } 136 | } 137 | 138 | // Delete nodes that belong to the predict net, in reverse 139 | // topological order. 140 | for (auto it = graph.nodes().rbegin(); it != graph.nodes().rend(); it++) { 141 | if (node_belongs_to_predict_net(*it)) { 142 | it.destroyCurrent(); 143 | } 144 | } 145 | 146 | // Remove inputs that belong to the predict net. 147 | for (auto i = graph.inputs().size(); i--;) { 148 | if (value_belongs_to_predict_net(graph.inputs()[i])) { 149 | graph.eraseInput(i); 150 | } 151 | } 152 | } else if (predict) { 153 | // When creating the predict net, 'undefined' nodes will 154 | // naturally go into the init net. We need to have a place to 155 | // copy the ones we want to keep in the predict net. 156 | auto* optionalInputDummyNode = graph.create(kUndefined, 1); 157 | graph.appendNode(optionalInputDummyNode); 158 | optionalInputDummyNode->outputs()[0]->setUniqueName(""); 159 | 160 | // Add new inputs, ensuring that we don't introduce duplicates. 161 | // Also cut the boundary between init and predict net by replacing 162 | // the Values along the boundary with replaceAllUsesWith. 163 | for (Value* v : graph.inputs()) { 164 | new_interface.erase(v); 165 | } 166 | for (Value* v : new_interface) { 167 | if (v->node()->kind() == kUndefined) { 168 | v->replaceAllUsesWith(optionalInputDummyNode->outputs()[0]); 169 | } else { 170 | Value* newv = graph.addInput()->copyMetadata(v); 171 | v->replaceAllUsesWith(newv); 172 | } 173 | } 174 | 175 | // Delete nodes that aren't in the predict net, in reverse 176 | // topological order. 
177 | for (auto it = graph.nodes().rbegin(); it != graph.nodes().rend(); it++) { 178 | if (*it == optionalInputDummyNode) { 179 | continue; 180 | } 181 | if (node_belongs_to_predict_net(*it)) { 182 | continue; 183 | } 184 | 185 | if (!has_uses(*it)) 186 | it.destroyCurrent(); 187 | } 188 | 189 | // Remove inputs that aren't used by the predict net. 190 | for (auto i = graph.inputs().size(); i--;) { 191 | if (graph.inputs()[i]->uses().empty()) { 192 | graph.eraseInput(i); 193 | } 194 | } 195 | 196 | // Remove all initializers, they are already in the init net. 197 | graph.clearInitializers(); 198 | } 199 | } 200 | 201 | struct SplitInit final : public FullGraphBasedPass { 202 | explicit SplitInit() 203 | : FullGraphBasedPass(PassType::Separate, PassEfficiency::Complete, 204 | PassOptimizationType::Memory) {} 205 | 206 | std::string getPassName() const override { 207 | return "split_init"; 208 | } 209 | PassAnalysisType getPassAnalysisType() const override { 210 | return PassAnalysisType::Empty; 211 | } 212 | std::shared_ptr runPass(Graph& graph) override { 213 | split_init_and_predict(graph, true, false); 214 | return std::shared_ptr(new PostPassAnalysis()); 215 | } 216 | }; 217 | 218 | struct SplitPredict final : public FullGraphBasedPass { 219 | explicit SplitPredict() 220 | : FullGraphBasedPass(PassType::Separate, PassEfficiency::Complete, 221 | PassOptimizationType::Memory) {} 222 | std::string getPassName() const override { 223 | return "split_predict"; 224 | } 225 | PassAnalysisType getPassAnalysisType() const override { 226 | return PassAnalysisType::Empty; 227 | } 228 | std::shared_ptr runPass(Graph& graph) override { 229 | split_init_and_predict(graph, false, true); 230 | return std::shared_ptr(new PostPassAnalysis()); 231 | } 232 | }; 233 | 234 | } // namespace optimization 235 | } // namespace ONNX_NAMESPACE 236 | -------------------------------------------------------------------------------- /onnxoptimizer/passes/string_utils.h: 
// Prints a vector as "[a,b,c]" with a configurable single-char delimiter.
template <char delimiter = ',', typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& datas) {
  os << "[";
  bool first = true;
  for (const auto& d : datas) {
    if (!first) {
      os << delimiter;
    } else {
      first = false;
    }
    os << d;
  }
  os << "]";
  return os;
}

// Returns the final path component of `filepath`.
// NOTE(review): only '/' is treated as a separator — presumably fine because
// this is used for __FILE__-style paths; confirm if Windows '\\' paths can
// reach here.
inline std::string StripFilename(const std::string& filepath) {
  size_t pos = filepath.find_last_of('/');
  return (pos == std::string::npos) ? filepath : filepath.substr(pos + 1);
}

inline std::string StripFilename(const char* filepath) {
  return StripFilename(std::string(filepath));
}

namespace {
// Maps an argument type to the form used for StrWrapper dispatch: string
// literals (char arrays) decay to const char*, everything else is passed by
// const reference.
template <typename T>
struct StrTypeTrait {
  using type = const T&;
};

template <std::size_t N>
struct StrTypeTrait<char[N]> {
  using type = const char*;
};

// Variadic stream folding: base case.
inline std::ostream& _str(std::ostream& os) {
  return os;
}

// Variadic stream folding: stream the head, recurse on the tail.
template <typename T, typename... Args>
std::ostream& _str(std::ostream& os, const T& s, const Args&... args) {
  os << s;
  return _str(os, args...);
}

// Generic case: concatenate all arguments through an ostringstream.
template <typename... Args>
struct StrWrapper {
  static std::string Call(const Args&... args) {
    std::ostringstream os;
    _str(os, args...);
    return os.str();
  }
};

// Fast path: a single std::string is returned as-is, no stream round-trip.
template <>
struct StrWrapper<std::string> {
  static std::string Call(const std::string& s) {
    return s;
  }
};

// Fast path: a single C string is passed through untouched.
template <>
struct StrWrapper<const char*> {
  static const char* Call(const char* s) {
    return s;
  }
};

}  // namespace

// Str(a, b, ...): stream-concatenate the arguments into a string; a lone
// std::string or const char* is returned without copying through a stream.
template <typename... Args>
decltype(auto) Str(const Args&... args) {
  return StrWrapper<typename StrTypeTrait<Args>::type...>::Call(args...);
}
#!/usr/bin/env python

# SPDX-License-Identifier: Apache-2.0


import subprocess
import os


def main():  # type: () -> None
    """Run mypy over the repository root, exiting 0 on success, 1 on errors.

    The repo root is taken to be the parent of this script's directory.
    """
    try:
        root_folder = os.path.realpath(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        os.chdir(root_folder)

        subprocess.check_call(["mypy", "."])
        # NOTE(review): --py2 was removed from mypy together with Python 2
        # support; this call fails on modern mypy — confirm the pinned version.
        subprocess.check_call(["mypy", "--py2", "."])

        # SystemExit is raised directly: the builtin exit() is injected by the
        # `site` module and is not guaranteed to exist in every run mode.
        raise SystemExit(0)
    except subprocess.CalledProcessError:
        # Catch this exception because we don't want it to output a backtrace
        # that would clutter the mypy output.
        raise SystemExit(1)


if __name__ == '__main__':
    main()