├── VERSION
├── onnxsim
│   ├── .clang-format
│   ├── __main__.py
│   ├── __init__.py
│   ├── bin
│   │   ├── onnxsim_option.h
│   │   ├── onnxsim_bin.cpp
│   │   └── onnxsim_option.cpp
│   ├── onnxsim.h
│   ├── model_info.py
│   ├── cpp2py_export.cc
│   ├── model_checking.py
│   ├── onnxsim.cpp
│   ├── onnx_simplifier.py
│   └── cxxopts.hpp
├── imgs
│   ├── comparison.png
│   ├── simple_reshape.png
│   └── complicated_reshape.png
├── requirements.txt
├── MANIFEST.in
├── .github
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   └── workflows
│       └── build-and-test.yml
├── .gitmodules
├── cmake
│   └── build_ort.cmake
├── .gitignore
├── CMakeLists.txt
├── README.md
├── tests
│   └── test_python_api.py
├── LICENSE
└── setup.py

--------------------------------------------------------------------------------
/VERSION:
--------------------------------------------------------------------------------
1 | 0.4.36-post1
--------------------------------------------------------------------------------
/onnxsim/.clang-format:
--------------------------------------------------------------------------------
1 | BasedOnStyle: Google
--------------------------------------------------------------------------------
/onnxsim/__main__.py:
--------------------------------------------------------------------------------
1 | from . import main
2 | 
3 | 
4 | if __name__ == '__main__':
5 |     main()
--------------------------------------------------------------------------------
/imgs/comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsukumijima/onnx-simplifier-prebuilt/master/imgs/comparison.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | onnx
2 | onnxoptimizer >= 0.2.5
3 | onnxruntime >= 1.6.0
4 | protobuf >= 3.7.0
5 | rich != 12.1.0
--------------------------------------------------------------------------------
/imgs/simple_reshape.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsukumijima/onnx-simplifier-prebuilt/master/imgs/simple_reshape.png
--------------------------------------------------------------------------------
/imgs/complicated_reshape.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsukumijima/onnx-simplifier-prebuilt/master/imgs/complicated_reshape.png
--------------------------------------------------------------------------------
/onnxsim/__init__.py:
--------------------------------------------------------------------------------
1 | from onnxsim.onnx_simplifier import simplify, main
2 | 
3 | # register python executor
4 | import onnxsim.onnx_simplifier
5 | import onnxsim.onnxsim_cpp2py_export
6 | x = onnxsim.onnx_simplifier.PyModelExecutor()
7 | onnxsim.onnxsim_cpp2py_export._set_model_executor(x)
8 | 
9 | from .version import version as __version__  # noqa
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include onnxsim *.h *.hpp *.c *.cc *.cpp *.proto
2 | recursive-include cmake *
3 | recursive-include third_party *
4 | recursive-exclude third_party/onnxruntime *
5 | recursive-exclude third_party/onnx-optimizer/build *
6 | recursive-exclude third_party/onnx/build *
7 | recursive-exclude third_party/onnx/onnx/backend *
8 | include CMakeLists.txt
9 | include VERSION
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: "[BUG] "
5 | labels: ''
6 | assignees: ''
7 | 
8 | ---
9 | 
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 | 
13 | **Model**
14 | To reproduce the problem, please post a download link to your model here, or send your model to daquexian566@gmail.com
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "onnxsim/third_party/onnxruntime"]
2 | 	path = third_party/onnxruntime
3 | 	url = git@github.com:microsoft/onnxruntime.git
4 | [submodule "onnxsim/third_party/onnx-optimizer"]
5 | 	path = third_party/onnx-optimizer
6 | 	url = git@github.com:onnx/optimizer.git
7 | [submodule "third_party/pybind11"]
8 | 	path = third_party/pybind11
9 | 	url = git@github.com:pybind/pybind11.git
--------------------------------------------------------------------------------
/onnxsim/bin/onnxsim_option.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include "cxxopts.hpp"
4 | 
5 | class OnnxsimOption {
6 |  public:
7 |   OnnxsimOption() = default;
8 |   OnnxsimOption(int argc, char** argv) { Parse(argc, argv); }
9 |   ~OnnxsimOption() = default;
10 | 
11 |   void Parse(int argc, char** argv);
12 | 
13 |   template <typename T>
14 |   T Get(const std::string& key) const {
15 |     T value = options_[key].as<T>();
16 |     return value;
17 |   }
18 | 
19 |  private:
20 |   cxxopts::ParseResult options_;
21 | };
--------------------------------------------------------------------------------
/onnxsim/bin/onnxsim_bin.cpp:
--------------------------------------------------------------------------------
1 | #include <fstream>
2 | 
3 | #include "onnx/common/file_utils.h"
4 | #include "onnxsim.h"
5 | #include "onnxsim_option.h"
6 | 
7 | int main(int argc, char** argv) {
8 |   // force env initialization to register opset
9 |   InitEnv();
10 |   OnnxsimOption option(argc, argv);
11 |   bool no_opt = option.Get<bool>("no-opt");
12 |   bool no_sim = option.Get<bool>("no-sim");
13 |   bool no_shape_inference = option.Get<bool>("no-shape-inference");
14 |   auto input_model_filename = option.Get<std::string>("input-model");
15 |   auto output_model_filename = option.Get<std::string>("output-model");
16 | 
17 |   onnx::ModelProto model;
18 |   onnx::LoadProtoFromPath(input_model_filename, model);
19 | 
20 |   model = Simplify(
21 |       model,
22 |       no_opt ? std::nullopt : std::make_optional<std::vector<std::string>>({}),
23 |       !no_sim, !no_shape_inference, SIZE_MAX);
24 | 
25 |   std::ofstream ofs(output_model_filename,
26 |                     std::ios::out | std::ios::trunc | std::ios::binary);
27 |   if (!model.SerializeToOstream(&ofs)) {
28 |     throw std::invalid_argument("save model error");
29 |   }
30 |   return 0;
31 | }
--------------------------------------------------------------------------------
/cmake/build_ort.cmake:
--------------------------------------------------------------------------------
1 | # For MessageDifferencer::Equals
2 | option(onnxruntime_USE_FULL_PROTOBUF "" ON)
3 | if (EMSCRIPTEN)
4 |   if (NOT DEFINED ONNX_CUSTOM_PROTOC_EXECUTABLE)
5 |     message(FATAL_ERROR "ONNX_CUSTOM_PROTOC_EXECUTABLE must be set for emscripten")
6 |   endif()
7 | 
8 |   option(onnxruntime_BUILD_WEBASSEMBLY "" ON)
9 |   option(onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB "" ON)
10 |   option(onnxruntime_ENABLE_WEBASSEMBLY_SIMD "" OFF)
11 |   option(onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING "" ON)
12 |   option(onnxruntime_ENABLE_WEBASSEMBLY_THREADS "" OFF)
13 |   option(onnxruntime_BUILD_UNIT_TESTS "" OFF)
14 |   set(onnxruntime_EMSCRIPTEN_SETTINGS "MALLOC=dlmalloc")
15 |   set(onnxruntime_ENABLE_WEBASSEMBLY_THREADS ON)
16 | 
17 |   # For custom onnx target in onnx optimizer
18 |   set(ONNX_TARGET_NAME onnxruntime_webassembly)
19 | else()
20 |   # For native build, only shared libs are OK. Otherwise libonnx.a will be linked twice (in onnxruntime and in onnxsim).
21 |   # For emscripten build, since the libonnxruntime_webassembly.a is bundled by `bundle_static_library`, onnxsim can link
22 |   # to the single libonnxruntime_webassembly.a
23 |   option(onnxruntime_BUILD_SHARED_LIB "" ON)
24 | endif()
25 | add_subdirectory(third_party/onnxruntime/cmake)
26 | 
27 | if (NOT EMSCRIPTEN)
28 |   set(BUILD_SHARED_LIBS ON)
29 | endif()
--------------------------------------------------------------------------------
/onnxsim/onnxsim.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include <memory>
4 | #include <optional>
5 | #include <vector>
6 | 
7 | #include <onnx/onnx_pb.h>
8 | 
9 | struct ModelExecutor {
10 |   virtual ~ModelExecutor() = default;
11 |   static void set_instance(std::shared_ptr<ModelExecutor> instance) {
12 |     instance_ = std::move(instance);
13 |   }
14 |   static std::vector<onnx::TensorProto> Run(
15 |       const onnx::ModelProto& model,
16 |       const std::vector<onnx::TensorProto>& inputs) {
17 |     if (instance_ == nullptr) {
18 |       throw std::runtime_error("empty instance");
19 |     }
20 |     return instance_->_Run(model, inputs);
21 |   }
22 | 
23 |   // public for pybind11
24 |   virtual std::vector<onnx::TensorProto> _Run(
25 |       const onnx::ModelProto& model,
26 |       const std::vector<onnx::TensorProto>& inputs) const = 0;
27 | 
28 |  private:
29 |   static std::shared_ptr<ModelExecutor> instance_;
30 | };
31 | 
32 | void InitEnv();
33 | 
34 | onnx::ModelProto Simplify(
35 |     const onnx::ModelProto& model,
36 |     std::optional<std::vector<std::string>> skip_optimizers,
37 |     bool constant_folding, bool shape_inference, size_t tensor_size_threshold);
38 | 
39 | void SimplifyPath(const std::string& in_path, const std::string& out_path,
40 |                   std::optional<std::vector<std::string>> skip_optimizers,
41 |                   bool constant_folding, bool shape_inference,
42 |                   size_t tensor_size_threshold);
--------------------------------------------------------------------------------
/onnxsim/bin/onnxsim_option.cpp:
--------------------------------------------------------------------------------
1 | #include "onnxsim_option.h"
2 | #include <iostream>
3 | 
4 | void OnnxsimOption::Parse(int argc, char** argv) {
5 |   cxxopts::Options cxx_options("onnxsim", "Simplify your ONNX model");
6 | 
7 |   // clang-format off
8 |   cxx_options.add_options()
9 |       ("h,help", "Print help")
10 |       ("i,input-model", "Input onnx model filename. This argument is required.", cxxopts::value<std::string>())
11 |       ("o,output-model", "Output onnx model filename. This argument is required.", cxxopts::value<std::string>())
12 |       ("no-opt", "No optimization", cxxopts::value<bool>()->default_value("false"))
13 |       ("no-sim", "No simplification", cxxopts::value<bool>()->default_value("false"))
14 |       ("no-shape-inference", "No shape inference", cxxopts::value<bool>()->default_value("false"))
15 |       ;
16 |   // clang-format on
17 | 
18 |   try {
19 |     options_ = cxx_options.parse(argc, argv);
20 |   } catch (const cxxopts::OptionParseException& cxxopts_exception) {
21 |     std::cout << "[Error] Cannot parse your options" << std::endl;
22 |     std::cout << cxx_options.help() << std::endl;
23 |     exit(1);
24 |   }
25 | 
26 |   if (options_.count("help")) {
27 |     std::cout << cxx_options.help() << std::endl;
28 |     exit(0);
29 |   }
30 |   if (!options_.count("input-model") || !options_.count("output-model")) {
31 |     std::cout << cxx_options.help() << std::endl;
32 |     exit(1);
33 |   }
34 | 
35 |   return;
36 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 | 
53 | # Translations
54 | *.mo
55 | *.pot
56 | 
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | 
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 | 
66 | # Scrapy stuff:
67 | .scrapy
68 | 
69 | # Sphinx documentation
70 | docs/_build/
71 | 
72 | # PyBuilder
73 | target/
74 | 
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 | 
78 | # IPython
79 | profile_default/
80 | ipython_config.py
81 | 
82 | # pyenv
83 | .python-version
84 | 
85 | # pipenv
86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
88 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not
89 | # install all needed dependencies.
90 | #Pipfile.lock
91 | 
92 | # celery beat schedule file
93 | celerybeat-schedule
94 | 
95 | # SageMath parsed files
96 | *.sage.py
97 | 
98 | # Environments
99 | .env
100 | .venv
101 | env/
102 | venv/
103 | ENV/
104 | env.bak/
105 | venv.bak/
106 | 
107 | # Spyder project settings
108 | .spyderproject
109 | .spyproject
110 | 
111 | # Rope project settings
112 | .ropeproject
113 | 
114 | # mkdocs documentation
115 | /site
116 | 
117 | # mypy
118 | .mypy_cache/
119 | .dmypy.json
120 | dmypy.json
121 | 
122 | # Pyre type checker
123 | .pyre/
124 | 
125 | .idea/
126 | *.onnx
127 | *.onnx.data
128 | 
129 | compile_commands.json
130 | temp*.py
131 | test*.py
132 | 
133 | .setuptools-cmake-build/
134 | onnxsim/version.py
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.22)
2 | 
3 | # For std::filesystem in onnx optimizer
4 | # Must be a cache variable and be set before project()
5 | # Reference: https://cmake.org/cmake/help/latest/variable/CMAKE_OSX_DEPLOYMENT_TARGET.html
6 | # It can be a normal variable if policy CMP0126 is set to NEW.
7 | set(CMAKE_OSX_DEPLOYMENT_TARGET 10.15 CACHE STRING "Minimum OS X deployment version")
8 | 
9 | project(onnxsim CXX)
10 | 
11 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
12 | set(CMAKE_CXX_STANDARD 17)
13 | 
14 | option(ONNXSIM_PYTHON "" OFF)
15 | option(ONNXSIM_BUILTIN_ORT "" ON)
16 | option(ONNXSIM_WASM_NODE "For node (enable NODERAWFS etc.)" OFF)
17 | 
18 | if (ONNXSIM_PYTHON AND EMSCRIPTEN)
19 |   message(STATUS "python and emscripten cannot be built at the same time")
20 | endif()
21 | 
22 | if (NOT ONNXSIM_BUILTIN_ORT AND EMSCRIPTEN)
23 |   message(STATUS "emscripten needs builtin ort")
24 | endif()
25 | 
26 | add_compile_options(
27 |     $<$<COMPILE_LANGUAGE:CXX>:$<$<CXX_COMPILER_ID:GNU>:-fdiagnostics-color=always>>
28 |     $<$<COMPILE_LANGUAGE:CXX>:$<$<CXX_COMPILER_ID:Clang>:-fcolor-diagnostics>>
29 |     $<$<COMPILE_LANGUAGE:CXX>:$<$<CXX_COMPILER_ID:AppleClang>:-fcolor-diagnostics>>)
30 | if (WIN32)
31 |   add_compile_definitions(NOMINMAX)
32 | endif()
33 | set(CMAKE_POSITION_INDEPENDENT_CODE ON)
34 | 
35 | if (ONNXSIM_BUILTIN_ORT)
36 |   include(cmake/build_ort.cmake)
37 |   if (EMSCRIPTEN)
38 |     set(ORT_NAME onnxruntime_webassembly)
39 |   else()
40 |     set(ORT_NAME onnxruntime)
41 |   endif()
42 | endif()
43 | 
44 | # configure onnx-optimizer after onnxruntime, because they both depend on onnx and onnxruntime has its own flags for onnx
45 | add_subdirectory(third_party/onnx-optimizer)
46 | 
47 | add_library(onnxsim onnxsim/onnxsim.cpp)
48 | if (ONNXSIM_BUILTIN_ORT)
49 |   target_include_directories(onnxsim PRIVATE third_party/onnxruntime/onnxruntime third_party/onnxruntime/include/onnxruntime)
50 | endif()
51 | target_include_directories(onnxsim PUBLIC onnxsim)
52 | if (NOT ONNXSIM_BUILTIN_ORT)
53 |   target_compile_definitions(onnxsim PUBLIC NO_BUILTIN_ORT)
54 | endif()
55 | if (EMSCRIPTEN)
56 |   target_link_libraries(onnxsim ${ORT_NAME} onnx_optimizer)
57 | else()
58 |   target_link_libraries(onnxsim ${ORT_NAME} onnx_optimizer onnx)
59 | endif()
60 | 
61 | add_executable(onnxsim_bin onnxsim/bin/onnxsim_bin.cpp onnxsim/bin/onnxsim_option.cpp)
62 | target_link_libraries(onnxsim_bin onnxsim)
63 | set_target_properties(onnxsim_bin PROPERTIES OUTPUT_NAME onnxsim)
64 | if (EMSCRIPTEN)
65 |   if (ONNXSIM_WASM_NODE)
66 |     set_target_properties(onnxsim_bin PROPERTIES LINK_FLAGS "-s NODERAWFS=1 -s ALLOW_MEMORY_GROWTH=1 -s ASSERTIONS=2 -s STACK_OVERFLOW_CHECK=1 -sNO_DISABLE_EXCEPTION_CATCHING")
67 |   else()
68 |     set_target_properties(onnxsim_bin PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s EXIT_RUNTIME=1 -s FORCE_FILESYSTEM=1 -s MODULARIZE=1 -s 'EXPORT_NAME=\"create_onnxsim\"' -s 'EXPORTED_RUNTIME_METHODS=[FS,ccall,cwrap,callMain]' -s EXPORTED_FUNCTIONS=[_main]")
69 |   endif()
70 | endif()
71 | 
72 | if (ONNXSIM_PYTHON)
73 |   add_subdirectory(third_party/pybind11)
74 |   pybind11_add_module(onnxsim_cpp2py_export onnxsim/cpp2py_export.cc)
75 |   target_link_libraries(onnxsim_cpp2py_export PRIVATE onnxsim)
76 |   if(NOT "${PY_EXT_SUFFIX}" STREQUAL "")
77 |     set_target_properties(onnxsim_cpp2py_export PROPERTIES SUFFIX ${PY_EXT_SUFFIX})
78 |   endif()
79 | endif()
--------------------------------------------------------------------------------
/onnxsim/model_info.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Callable, Any, Optional, Tuple, Dict
3 | 
4 | import onnx
5 | from rich.table import Table
6 | from rich.text import Text
7 | from rich import print
8 | 
9 | 
10 | __all__ = ['ModelInfo', 'print_simplifying_info']
11 | 
12 | 
13 | def human_readable_size(num, suffix="B"):
14 |     for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
15 |         if abs(num) < 1024.0:
16 |             return f"{num:3.1f}{unit}{suffix}"
17 |         num /= 1024.0
18 |     return f"{num:.1f}Yi{suffix}"
19 | 
20 | 
21 | class ModelInfo:
22 |     """
23 |     Model info contains:
24 |     1. Num of every op
25 |     2. Model size
26 |     TODO:
27 |     Based on onnx runtime, get
28 |     1. FLOPs
29 |     2. forward memory footprint
30 |     3. memory access
31 |     4. compute density
32 |     """
33 | 
34 |     def get_info(self, graph: onnx.GraphProto) -> Tuple[Dict[str, int], int]:
35 |         op_nums = defaultdict(int)
36 |         model_size = 0
37 |         for node in graph.node:
38 |             op_nums[node.op_type] += 1
39 |             for attr in node.attribute:
40 |                 sub_graphs = []
41 |                 if attr.g is not None:
42 |                     sub_graphs.append(attr.g)
43 |                 if attr.graphs is not None:
44 |                     sub_graphs.extend(attr.graphs)
45 |                 for sub_graph in sub_graphs:
46 |                     sub_op_nums, sub_model_size = self.get_info(sub_graph)
47 |                     op_nums = defaultdict(int, {k: op_nums[k] + sub_op_nums[k] for k in set(op_nums) | set(sub_op_nums)})
48 |                     model_size += sub_model_size
49 |         op_nums["Constant"] += len(graph.initializer)
50 |         model_size += graph.ByteSize()
51 |         return op_nums, model_size
52 | 
53 |     def __init__(self, model: onnx.ModelProto):
54 |         self.op_nums, self.model_size = self.get_info(model.graph)
55 | 
56 | 
57 | def print_simplifying_info(model_ori: onnx.ModelProto, model_opt: onnx.ModelProto) -> None:
58 |     """
59 |     --------------------------------------------------------
60 |     |            | original model | simplified model |
61 |     --------------------------------------------------------
62 |     | ****       | ****           | ****             |
63 |     --------------------------------------------------------
64 |     | Model Size | ****           | ****             |
65 |     --------------------------------------------------------
66 |     """
67 |     ori_info = ModelInfo(model_ori)
68 |     opt_info = ModelInfo(model_opt)
69 |     table = Table()
70 |     table.add_column('')
71 |     table.add_column('Original Model')
72 |     table.add_column('Simplified Model')
73 | 
74 |     def add_row(table: Table, key, ori_data, opt_data, is_better: Callable[[Any, Any], Any], postprocess: Optional[Callable[[Any], Any]] = None) -> None:
75 |         if postprocess is None:
76 |             postprocess = str
77 |         if is_better(opt_data, ori_data):
78 |             table.add_row(key, postprocess(ori_data), Text(
79 |                 postprocess(opt_data), style='bold green1'))
80 |         else:
81 |             table.add_row(key, postprocess(ori_data), postprocess(opt_data))
82 | 
83 |     for key in sorted(list(set(ori_info.op_nums.keys()) | set(opt_info.op_nums.keys()))):
84 |         add_row(table, key, ori_info.op_nums[key],
85 |                 opt_info.op_nums[key], lambda opt, ori: opt < ori)
86 |     add_row(
87 |         table, 'Model Size', ori_info.model_size, opt_info.model_size, lambda opt, ori: opt < ori, postprocess=human_readable_size)
88 |     print(table)
--------------------------------------------------------------------------------
/onnxsim/cpp2py_export.cc:
--------------------------------------------------------------------------------
1 | /*
2 |  * SPDX-License-Identifier: Apache-2.0
3 |  */
4 | 
5 | #include <pybind11/pybind11.h>
6 | #include <pybind11/stl.h>
7 | 
8 | #include "onnx/py_utils.h"
9 | #include "onnxsim.h"
10 | 
11 | namespace py = pybind11;
12 | using namespace pybind11::literals;
13 | 
14 | struct PyModelExecutor : public ModelExecutor {
15 |   using ModelExecutor::ModelExecutor;
16 | 
17 |   std::vector<onnx::TensorProto> _Run(
18 |       const onnx::ModelProto& model,
19 |       const std::vector<onnx::TensorProto>& inputs) const override {
20 |     std::vector<py::bytes> inputs_bytes;
21 |     std::transform(inputs.begin(), inputs.end(),
22 |                    std::back_inserter(inputs_bytes),
23 |                    [](const onnx::TensorProto& x) {
24 |                      return py::bytes(x.SerializeAsString());
25 |                    });
26 |     std::string model_str = model.SerializeAsString();
27 |     auto output_bytes = _PyRun(py::bytes(model_str), inputs_bytes);
28 |     std::vector<onnx::TensorProto> output_tps;
29 |     std::transform(output_bytes.begin(), output_bytes.end(),
30 |                    std::back_inserter(output_tps), [](const py::bytes& x) {
31 |                      onnx::TensorProto tp;
32 |                      tp.ParseFromString(std::string(x));
33 |                      return tp;
34 |                    });
35 |     return output_tps;
36 |   }
37 | 
38 |   virtual std::vector<py::bytes> _PyRun(
39 |       const py::bytes& model_bytes,
40 |       const std::vector<py::bytes>& inputs_bytes) const = 0;
41 | };
42 | 
43 | struct PyModelExecutorTrampoline : public PyModelExecutor {
44 |   /* Inherit the constructors */
45 |   using PyModelExecutor::PyModelExecutor;
46 | 
47 |   /* Trampoline (need one for each virtual function) */
48 |   std::vector<py::bytes> _PyRun(
49 |       const py::bytes& model_bytes,
50 |       const std::vector<py::bytes>& inputs_bytes) const override {
51 |     PYBIND11_OVERRIDE_PURE_NAME(
52 |         std::vector<py::bytes>, /* Return type */
53 |         PyModelExecutor,        /* Parent class */
54 |         "Run", _PyRun, /* Name of function in C++ (must match Python name) */
55 |         model_bytes, inputs_bytes /* Argument(s) */
56 |     );
57 |   }
58 | };
59 | 
60 | PYBIND11_MODULE(onnxsim_cpp2py_export, m) {
61 |   m.doc() = "ONNX Simplifier";
62 | 
63 |   m.def("simplify",
64 |         [](const py::bytes& model_proto_bytes,
65 |            std::optional<std::vector<std::string>> skip_optimizers,
66 |            bool constant_folding, bool shape_inference,
67 |            size_t tensor_size_threshold) -> py::bytes {
68 |           // force env initialization to register opset
69 |           InitEnv();
70 |           ONNX_NAMESPACE::ModelProto model;
71 |           ParseProtoFromPyBytes(&model, model_proto_bytes);
72 |           auto const result = Simplify(model, skip_optimizers, constant_folding,
73 |                                        shape_inference, tensor_size_threshold);
74 |           std::string out;
75 |           result.SerializeToString(&out);
76 |           return py::bytes(out);
77 |         })
78 |       .def("simplify_path",
79 |            [](const std::string& in_path, const std::string& out_path,
80 |               std::optional<std::vector<std::string>> skip_optimizers,
81 |               bool constant_folding, bool shape_inference,
82 |               size_t tensor_size_threshold) -> bool {
83 |              // force env initialization to register opset
84 |              InitEnv();
85 |              SimplifyPath(in_path, out_path, skip_optimizers, constant_folding,
86 |                           shape_inference, tensor_size_threshold);
87 |              return true;
88 |            })
89 |       .def("_set_model_executor",
90 |            [](std::shared_ptr<PyModelExecutor> executor) {
91 |              ModelExecutor::set_instance(std::move(executor));
92 |            });
93 | 
94 |   py::class_<PyModelExecutor, PyModelExecutorTrampoline,
95 |              std::shared_ptr<PyModelExecutor>>(m, "ModelExecutor")
96 |       .def(py::init<>())
97 |       .def("Run", &PyModelExecutor::_PyRun);
98 | }
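
Editor's note: the `ModelExecutor` binding above is the hook through which the C++ simplifier delegates constant folding to Python. A minimal sketch of how the Python side plugs in, mirroring what `onnxsim/__init__.py` does with the bundled `PyModelExecutor` (the `OrtExecutor` class below and the assumption that inputs arrive as serialized `TensorProto`s in session-input order are illustrative, not part of the repository):

```python
import onnx
import onnx.numpy_helper
import onnxruntime as rt
import onnxsim.onnxsim_cpp2py_export as C


class OrtExecutor(C.ModelExecutor):
    # Called from C++ via PyModelExecutorTrampoline: receives the model and its
    # inputs as serialized protobuf bytes, must return serialized TensorProtos.
    def Run(self, model_bytes, inputs_bytes):
        sess = rt.InferenceSession(model_bytes, providers=["CPUExecutionProvider"])
        feeds = {}
        # Assumption: inputs_bytes is ordered like the session's inputs.
        for sess_input, raw in zip(sess.get_inputs(), inputs_bytes):
            tp = onnx.TensorProto()
            tp.ParseFromString(raw)
            feeds[sess_input.name] = onnx.numpy_helper.to_array(tp)
        outputs = sess.run(None, feeds)
        return [onnx.numpy_helper.from_array(arr).SerializeToString() for arr in outputs]


# Register the executor, as onnxsim/__init__.py does for PyModelExecutor.
C._set_model_executor(OrtExecutor())
```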
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ONNX Simplifier Prebuilt
2 | 
3 | [![PyPI version](https://img.shields.io/pypi/v/onnxsim-prebuilt.svg)](https://pypi.python.org/pypi/onnxsim-prebuilt/)
4 | [![PyPI pyversions](https://img.shields.io/pypi/pyversions/onnxsim-prebuilt.svg)](https://pypi.python.org/pypi/onnxsim-prebuilt/)
5 | [![PyPI license](https://img.shields.io/pypi/l/onnxsim-prebuilt.svg)](https://pypi.python.org/pypi/onnxsim-prebuilt/)
6 | [![Build and Test](https://github.com/tsukumijima/onnx-simplifier-prebuilt/actions/workflows/build-and-test.yml/badge.svg)](https://github.com/tsukumijima/onnx-simplifier-prebuilt/actions/workflows/build-and-test.yml)
7 | 
8 | onnxsim-prebuilt is a fork of [onnxsim](https://github.com/daquexian/onnx-simplifier) that aims to publish prebuilt wheels for Python 3.12 and later to PyPI.
9 | 
10 | ## Changes in this fork
11 | 
12 | - **Changed package name to `onnxsim-prebuilt`**
13 |   - The library name remains unchanged from `onnxsim`, so you can import it as `import onnxsim` just like the original [onnxsim](https://github.com/daquexian/onnx-simplifier)
14 |   - Can be used as a drop-in replacement for the original [onnxsim](https://github.com/daquexian/onnx-simplifier)
15 | - **Publish prebuilt wheels for all platforms (Windows, macOS x64/arm64, Linux x64/arm64) on PyPI**
16 |   - onnx-simplifier depends on C++, CMake, and submodules, making the build environment setup relatively difficult and time-consuming
17 |   - For over a year, onnxsim has not been updated, and prebuilt wheels for Python 3.12/3.13 are not available (ref: [onnxsim/issues/334](https://github.com/daquexian/onnx-simplifier/issues/334), [onnxsim/pull/359](https://github.com/daquexian/onnx-simplifier/pull/359))
18 |   - Various issues arise, such as the need to install build-essentials and CMake in Docker images just for installation, and long build times
19 |   - By publishing prebuilt wheels on PyPI, we aim to enable easy installation even on PCs without a build environment
20 |   - Incorporated the CI improvements proposed in the pull request [onnxsim/pull/359](https://github.com/daquexian/onnx-simplifier/pull/359), and further enhanced it to build and publish prebuilt wheels for Linux aarch64
21 | - **Explicitly added Python 3.12 / 3.13 to supported versions**
22 |   - Changed CI target Python versions to Python 3.10 and above
23 |   - This fork does not support Python 3.9 and below
24 | 
25 | ## Installation
26 | 
27 | You can install the library by running the following command:
28 | 
29 | ```bash
30 | pip install onnxsim-prebuilt
31 | ```
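
As a quick sanity check of the drop-in behavior, you can import the package under its original module name and run one round of simplification (a hedged sketch; `model.onnx` is a placeholder path, and `simplify` returns the simplified model plus a validation flag, as documented below):

```python
import onnx
import onnxsim

print(onnxsim.__version__)  # the package is onnxsim-prebuilt, the module is still onnxsim

model = onnx.load("model.onnx")  # placeholder path
model_simp, check = onnxsim.simplify(model)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, "model_simp.onnx")
```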
32 | 
33 | The documentation below is inherited from the original [onnxsim](https://github.com/daquexian/onnx-simplifier) without any modifications.
34 | There is no guarantee that the content of this documentation applies to onnxsim-prebuilt.
35 | 
36 | -------
37 | 
38 | # ONNX Simplifier
39 | 
40 | [![PyPI version](https://img.shields.io/pypi/v/onnx-simplifier.svg)](https://pypi.python.org/pypi/onnx-simplifier/)
41 | [![PyPI pyversions](https://img.shields.io/pypi/pyversions/onnx-simplifier.svg)](https://pypi.python.org/pypi/onnx-simplifier/)
42 | [![PyPI license](https://img.shields.io/pypi/l/onnx-simplifier.svg)](https://pypi.python.org/pypi/onnx-simplifier/)
43 | [![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](https://github.com/daquexian/onnx-simplifier/pulls)
44 | 
45 | _ONNX is great, but sometimes too complicated._
46 | 
47 | ## Background
48 | 
49 | One day I wanted to export the following simple reshape operation to ONNX:
50 | 
51 | ```python
52 | import torch
53 | 
54 | 
55 | class JustReshape(torch.nn.Module):
56 |     def __init__(self):
57 |         super(JustReshape, self).__init__()
58 | 
59 |     def forward(self, x):
60 |         return x.view((x.shape[0], x.shape[1], x.shape[3], x.shape[2]))
61 | 
62 | 
63 | net = JustReshape()
64 | model_name = 'just_reshape.onnx'
65 | dummy_input = torch.randn(2, 3, 4, 5)
66 | torch.onnx.export(net, dummy_input, model_name, input_names=['input'], output_names=['output'])
67 | ```
68 | 
69 | The input shape in this model is static, so what I expected was
70 | 
71 | ![simple_reshape](imgs/simple_reshape.png)
72 | 
73 | However, I got the following complicated model instead:
74 | 
75 | ![complicated_reshape](imgs/complicated_reshape.png)
76 | 
77 | ## Our solution
78 | 
79 | ONNX Simplifier is a tool to simplify ONNX models. It infers the whole computation graph
80 | and then replaces the redundant operators with their constant outputs (a.k.a. constant folding).
81 | 
82 | ### Web version
83 | 
84 | We have published ONNX Simplifier on [convertmodel.com](https://www.convertmodel.com/#input=onnx&output=onnx). It works out of the box and **doesn't need any installation**. Note that it runs in the browser locally and your model is completely safe.
85 | 
86 | ### Python version
87 | 
88 | 
89 | ```
90 | pip3 install -U pip && pip3 install onnxsim
91 | ```
92 | 
93 | Then
94 | 
95 | ```
96 | onnxsim input_onnx_model output_onnx_model
97 | ```
98 | 
99 | For more advanced features, try the following command for the help message
100 | 
101 | ```
102 | onnxsim -h
103 | ```
104 | 
105 | ## Demonstration
106 | 
107 | An overall comparison between
108 | [a complicated model](https://github.com/JDAI-CV/DNNLibrary/issues/17#issuecomment-455934190)
109 | and its simplified version:
110 | 
111 | ![Comparison between old model and new model](imgs/comparison.png)
112 | 
113 | ## In-script workflow
114 | 
115 | If you would like to embed the ONNX simplifier python package in another script, it is just that simple.
116 | 117 | ```python 118 | import onnx 119 | from onnxsim import simplify 120 | 121 | # load your predefined ONNX model 122 | model = onnx.load(filename) 123 | 124 | # convert model 125 | model_simp, check = simplify(model) 126 | 127 | assert check, "Simplified ONNX model could not be validated" 128 | 129 | # use model_simp as a standard ONNX model object 130 | ``` 131 | 132 | You can see more details of the API in [onnxsim/onnx_simplifier.py](onnxsim/onnx_simplifier.py) 133 | 134 | ## Projects Using ONNX Simplifier 135 | 136 | * [MXNet](https://mxnet.apache.org/versions/1.9.1/api/python/docs/tutorials/deploy/export/onnx.html#Simplify-the-exported-ONNX-model) 137 | * [MMDetection](https://github.com/open-mmlab/mmdetection) 138 | * [YOLOv5](https://github.com/ultralytics/yolov5) 139 | * [ncnn](https://github.com/Tencent/ncnn) 140 | * ... 141 | 142 | ## Chat 143 | 144 | We created a Chinese QQ group for ONNX! 145 | 146 | ONNX QQ Group (Chinese): 1021964010, verification code: nndab. Welcome to join! 147 | 148 | For English users, I'm active on the [ONNX Slack](https://github.com/onnx/onnx#discuss). You can find and chat with me (daquexian) there. 149 | -------------------------------------------------------------------------------- /onnxsim/model_checking.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Dict, Optional, Union 3 | from collections import OrderedDict 4 | 5 | import onnx 6 | import onnx.checker 7 | import numpy as np 8 | import onnxruntime as rt 9 | 10 | Tensors = Dict[str, np.ndarray] 11 | TensorShape = List[int] 12 | TensorShapes = Dict[Optional[str], TensorShape] 13 | 14 | 15 | def compare( 16 | model_opt: Union[str, onnx.ModelProto], 17 | model_ori: Union[str, onnx.ModelProto], 18 | n_times: int = 5, 19 | input_shapes: Optional[TensorShapes] = None, 20 | input_data: Optional[Tensors] = None, 21 | custom_lib: Optional[str] = None, 22 | verbose=True, 23 | ) -> bool: 24 | """ 25 | :param model_opt: The simplified ONNX model 26 | :param model_ori: The original ONNX model 27 | :param n_times: Generate n random inputs 28 | :param input_shapes: Shapes of generated random inputs 29 | :param input_data: User-given data instead of random generated data 30 | :param custom_lib: ONNX Runtime custom lib for custom ops 31 | """ 32 | 33 | def get_shape_from_value_info_proto(v: onnx.ValueInfoProto) -> List[int]: 34 | return [dim.dim_value for dim in v.type.tensor_type.shape.dim] 35 | 36 | def get_value_info_all( 37 | m: onnx.ModelProto, name: str 38 | ) -> Optional[onnx.ValueInfoProto]: 39 | for v in m.graph.value_info: 40 | if v.name == name: 41 | return v 42 | 43 | for v in m.graph.input: 44 | if v.name == name: 45 | return v 46 | 47 | for v in m.graph.output: 48 | if v.name == name: 49 | return v 50 | 51 | return None 52 | 53 | def get_shape(m: onnx.ModelProto, name: str) -> TensorShape: 54 | """ 55 | Note: This method relies on onnx shape inference, which is not reliable. 
So only use it on input or output tensors
56 |         """
57 |         v = get_value_info_all(m, name)
58 |         if v is not None:
59 |             return get_shape_from_value_info_proto(v)
60 |         raise RuntimeError('Cannot get shape of "{}"'.format(name))
61 | 
62 |     def get_elem_type(m: onnx.ModelProto, name: str) -> Optional[int]:
63 |         v = get_value_info_all(m, name)
64 |         if v is not None:
65 |             return v.type.tensor_type.elem_type
66 |         return None
67 | 
68 |     def get_np_type_from_elem_type(elem_type: int) -> type:
69 |         np_types = (  # indexed by onnx.TensorProto elem_type
70 |             None,
71 |             np.float32,
72 |             np.uint8,
73 |             np.int8,
74 |             np.uint16,
75 |             np.int16,
76 |             np.int32,
77 |             np.int64,
78 |             str,
79 |             bool,
80 |             np.float16,
81 |             np.double,
82 |             np.uint32,
83 |             np.uint64,
84 |             np.complex64,
85 |             np.complex128,
86 |             np.float16,
87 |         )
88 |         assert len(np_types) == 17
89 |         np_type = np_types[elem_type]
90 |         assert np_type is not None
91 |         return np_type
92 | 
93 |     def get_input_names(model: onnx.ModelProto) -> List[str]:
94 |         input_names = list(
95 |             set([ipt.name for ipt in model.graph.input])
96 |             - set([x.name for x in model.graph.initializer])
97 |         )
98 |         return input_names
99 | 
100 |     def generate_rand_input(
101 |         model: Union[str, onnx.ModelProto],
102 |         input_shapes: Optional[TensorShapes] = None
103 |     ):
104 |         if input_shapes is None:
105 |             input_shapes = {}
106 |         if isinstance(model, str):
107 |             model = onnx.load(model, load_external_data=False)
108 |         input_names = get_input_names(model)
109 |         full_input_shapes = {ipt: get_shape(model, ipt) for ipt in input_names}
110 |         assert None not in input_shapes
111 |         full_input_shapes.update(input_shapes)  # type: ignore
112 |         for name, shape in full_input_shapes.items():
113 |             if any([dim <= 0 for dim in shape[1:]]):
114 |                 raise RuntimeError(
115 |                     'The shape of input "{}" has dynamic size, '
116 |                     "please set an input shape manually with --test-input-shape".format(name)
117 |                 )
118 |             if len(shape) > 0 and shape[0] <= 0:
119 |                 print(f'shape[0] of input "{name}" is dynamic, we assume it represents the batch size and set it to 1 when testing. If this is not what you want, please set it manually with --test-input-shape (see `onnxsim -h` for details).')
120 |                 shape[0] = 1
121 | 
122 |         inputs = {
123 |             ipt: np.array(
124 |                 np.random.rand(*full_input_shapes[ipt]),
125 |                 dtype=get_np_type_from_elem_type(get_elem_type(model, ipt)),
126 |             )
127 |             for ipt in input_names
128 |         }
129 |         return inputs
130 | 
131 |     def forward(
132 |         model: Union[str, onnx.ModelProto],
133 |         inputs: Tensors,
134 |         custom_lib: Optional[str] = None
135 |     ) -> Dict[str, np.ndarray]:
136 |         sess_options = rt.SessionOptions()
137 |         if custom_lib is not None:
138 |             if os.path.exists(custom_lib):
139 |                 sess_options.register_custom_ops_library(custom_lib)
140 |             else:
141 |                 raise ValueError("No such file '{}'".format(custom_lib))
142 |         sess_options.graph_optimization_level = rt.GraphOptimizationLevel(0)
143 |         sess_options.log_severity_level = 3
144 |         if isinstance(model, onnx.ModelProto):
145 |             model = model.SerializeToString()
146 |         sess = rt.InferenceSession(
147 |             model,
148 |             sess_options=sess_options,
149 |             providers=["CPUExecutionProvider"],
150 |         )
151 |         outputs = [x.name for x in sess.get_outputs()]
152 |         run_options = rt.RunOptions()
153 |         run_options.log_severity_level = 3
154 |         res = OrderedDict(
155 |             zip(outputs, sess.run(outputs, inputs, run_options=run_options))
156 |         )
157 |         return res
158 | 
159 |     if input_shapes is None:
160 |         input_shapes = {}
161 |     onnx.checker.check_model(model_opt)
162 |     for i in range(n_times):
163 |         print(f'Checking {i}/{n_times}...')
164 |         if input_data is None:
165 |             inputs = generate_rand_input(model_opt, input_shapes=input_shapes)
166 |         else:
167 |             inputs = input_data
168 |         res_ori = forward(model_ori, inputs, custom_lib)
169 |         res_opt = forward(model_opt, inputs, custom_lib)
170 | 
171 |         for name in res_opt.keys():
172 |             if not np.allclose(res_opt[name], res_ori[name], rtol=1e-4, atol=1e-5):
173 |                 if verbose:
174 |                     print(
175 |                         "Tensor {} changes after optimization.
The max diff is {}.".format( 176 | name, np.max(np.abs(res_opt[name] - res_ori[name])) 177 | ) 178 | ) 179 | print("After optimization:") 180 | print(res_opt[name]) 181 | print("Before optimization:") 182 | print(res_ori[name]) 183 | print("----------------") 184 | return False 185 | return True 186 | -------------------------------------------------------------------------------- /.github/workflows/build-and-test.yml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | tags: 9 | - v* 10 | pull_request: 11 | workflow_dispatch: 12 | 13 | env: 14 | # This is used to skip tests that are too heavy for CI 15 | ONNXSIM_CI: 1 16 | 17 | jobs: 18 | build_wheels: 19 | env: 20 | CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_28 21 | CIBW_MANYLINUX_AARCH64_IMAGE: manylinux_2_28 22 | CIBW_ENVIRONMENT_PASS_LINUX: ONNXSIM_CI 23 | CIBW_BEFORE_ALL_LINUX: WD=`pwd` && /opt/python/cp38-cp38/bin/python -m pip install --target tmp_cmake cmake && cp tmp_cmake/bin/cmake /usr/local/bin/cmake && rm -rf tmp_cmake && /opt/python/cp38-cp38/bin/python -m pip install cmake && cmake --version && whereis cmake 24 | CIBW_BEFORE_ALL_MACOS: WD=`pwd` && pip install cmake 25 | CIBW_TEST_REQUIRES_LINUX: pytest flake8 onnxruntime 26 | CIBW_TEST_REQUIRES_MACOS: pytest onnxruntime 27 | CIBW_TEST_REQUIRES_WINDOWS: pytest onnxruntime 28 | CIBW_BEFORE_TEST_LINUX: pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu 29 | CIBW_BEFORE_TEST_MACOS: pip install torch torchvision 30 | CIBW_BEFORE_TEST_WINDOWS: pip install torch torchvision 31 | # Skip universal2 x86_64 tests (as PyTorch no longer supports them) 32 | CIBW_TEST_SKIP: "*_universal2:x86_64" 33 | CIBW_TEST_COMMAND: pytest -v {project}/tests/test_python_api.py 34 | # Skip 32-bit or musl or pypy builds 35 | CIBW_SKIP: "*-win32 *-manylinux_i686 *-musllinux_* pp*" 36 | # Only build on Python 3.10 and above 37 | CIBW_PROJECT_REQUIRES_PYTHON: ">=3.10" 38 | name: Build wheel ${{ matrix.name }} on ${{ matrix.os }} 39 | runs-on: ${{ matrix.os }} 40 | strategy: 41 | fail-fast: false 42 | matrix: 43 | include: 44 | - os: windows-2025 45 | archs: AMD64 46 | name: onnxsim-prebuilt 47 | - os: macos-15 48 | # Only build universal2 package 49 | # Related issue: https://github.com/pypa/cibuildwheel/issues/1190 50 | archs: universal2 51 | name: onnxsim-prebuilt 52 | - os: ubuntu-24.04 53 | archs: x86_64 54 | name: onnxsim-prebuilt 55 | - os: ubuntu-24.04-arm 56 | archs: aarch64 57 | name: onnxsim-prebuilt 58 | steps: 59 | - uses: actions/checkout@v4 60 | with: 61 | submodules: recursive 62 | fetch-depth: 0 63 | - name: Build onnxsim wheels 64 | uses: pypa/cibuildwheel@v2.23.3 65 | env: 66 | CIBW_ARCHS_WINDOWS: ${{ contains(matrix.os, 'windows') && matrix.archs || 'AMD64' }} 67 | CIBW_ARCHS_MACOS: ${{ contains(matrix.os, 'macos') && matrix.archs || 'x86_64' }} 68 | CIBW_ARCHS_LINUX: ${{ contains(matrix.os, 'ubuntu') && matrix.archs || 'x86_64' }} 69 | CIBW_ENVIRONMENT: CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON -DCMAKE_POLICY_VERSION_MINIMUM=3.5" ONNXSIM_PKG_NAME=${{ matrix.name }} 70 | CIBW_ENVIRONMENT_WINDOWS: USE_MSVC_STATIC_RUNTIME=0 CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON" ONNXSIM_PKG_NAME=${{ matrix.name }} 71 | CIBW_ENVIRONMENT_MACOS: MACOSX_DEPLOYMENT_TARGET=10.15 ONNXSIM_PKG_NAME=${{ matrix.name }} 72 | - uses: actions/upload-artifact@v4 73 | with: 74 | name: python-dist-${{ 
matrix.os }}-${{ matrix.archs }}-${{ matrix.name }} 75 | path: ./wheelhouse/*.whl 76 | 77 | build_sdist: 78 | name: Build source distribution 79 | runs-on: ubuntu-latest 80 | steps: 81 | - uses: actions/checkout@v4 82 | with: 83 | submodules: recursive 84 | fetch-depth: 0 85 | 86 | - name: Update version 87 | if: startsWith(github.ref, 'refs/tags/v') 88 | run: | 89 | sed -i "s/0.0.0/${GITHUB_REF/refs\/tags\/v/}/" setup.py 90 | 91 | - name: Build sdist 92 | run: | 93 | export ONNXSIM_SDIST=ON 94 | export ONNXSIM_PKG_NAME=onnxsim-prebuilt 95 | pipx run build --sdist 96 | 97 | - name: Install and test sdist 98 | run: | 99 | # It's important to leave the project directory where a 'onnxsim' subdirectory exists 100 | cd dist 101 | python3 -m pip install onnxruntime 102 | python3 -m pip install onnxsim_prebuilt-*.tar.gz 103 | python3 -c "import onnxsim; print(dir(onnxsim))" 104 | python3 -m pip uninstall -y onnxsim-prebuilt 105 | 106 | - uses: actions/upload-artifact@v4 107 | with: 108 | name: python-dist-sdist 109 | path: dist/*.tar.gz 110 | 111 | upload_pypi: 112 | name: Upload to PyPI 113 | needs: [build_wheels, build_sdist] 114 | runs-on: ubuntu-latest 115 | steps: 116 | - uses: actions/download-artifact@v4 117 | with: 118 | # unpacks python-dist artifact into dist/ 119 | # if `name: python-dist` is omitted, the action will create extra parent dir 120 | pattern: python-dist-* 121 | path: dist 122 | merge-multiple: true 123 | - name: Publish distribution 📦 to Test PyPI 124 | if: ${{ github.ref == 'refs/heads/master' }} 125 | uses: pypa/gh-action-pypi-publish@release/v1 126 | with: 127 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 128 | repository-url: https://test.pypi.org/legacy/ 129 | skip-existing: true 130 | - name: Publish distribution 📦 to PyPI 131 | if: startsWith(github.ref, 'refs/tags/v') 132 | uses: pypa/gh-action-pypi-publish@release/v1 133 | with: 134 | password: ${{ secrets.PYPI_API_TOKEN }} 135 | 136 | build_wasm: 137 | name: Build WebAssembly 138 | runs-on: ubuntu-latest 139 | steps: 140 | - uses: actions/checkout@v4 141 | with: 142 | submodules: recursive 143 | - uses: mymindstorm/setup-emsdk@v14 144 | - name: Verify 145 | run: emcc -v 146 | 147 | - run: | 148 | sudo apt update 149 | sudo apt install protobuf-compiler 150 | 151 | - name: Build 152 | run: ./build_wasm.sh 153 | 154 | - uses: actions/upload-artifact@v4 155 | with: 156 | name: wasm 157 | path: | 158 | build-wasm-node-OFF/onnxsim.js 159 | build-wasm-node-OFF/onnxsim.wasm 160 | 161 | - name: Upload to convertmodel.com 162 | # if: ${{ github.ref == 'refs/heads/master' }} 163 | if: false 164 | run: | 165 | wget https://gosspublic.alicdn.com/ossutil/1.7.14/ossutil64 166 | chmod 755 ossutil64 167 | echo "[Credentials]" >> ~/.ossutilconfig 168 | echo "language=EN" >> ~/.ossutilconfig 169 | echo "endpoint=oss-cn-beijing.aliyuncs.com" >> ~/.ossutilconfig 170 | echo "accessKeyID=${{ secrets.OSS_ACCESS_KEY }}" >> ~/.ossutilconfig 171 | echo "accessKeySecret=${{ secrets.OSS_SECRET_KEY }}" >> ~/.ossutilconfig 172 | 173 | gzip -c -9 build-wasm-node-OFF/onnxsim.wasm > onnxsim_gz.wasm 174 | ./ossutil64 --config-file ~/.ossutilconfig cp -u onnxsim_gz.wasm oss://converter-web/onnxsim.wasm --meta=Content-Type:application/wasm#Content-Encoding:gzip 175 | ./ossutil64 --config-file ~/.ossutilconfig cp -u build-wasm-node-OFF/onnxsim.js oss://converter-web/ 176 | 177 | build_wasm_with_noderawfs: 178 | name: Build WebAssembly with NODERAWFS 179 | runs-on: ubuntu-latest 180 | steps: 181 | - uses: actions/checkout@v4 182 | with: 183 | 
          submodules: recursive
184 |       - uses: mymindstorm/setup-emsdk@v14
185 |       - name: Verify
186 |         run: emcc -v
187 | 
188 |       - run: |
189 |           sudo apt update
190 |           sudo apt install protobuf-compiler
191 | 
192 |       - name: Build
193 |         run: ./build_wasm.sh ON
194 | 
195 |       - uses: actions/upload-artifact@v4
196 |         with:
197 |           name: wasm-with-noderawfs
198 |           path: |
199 |             build-wasm-node-ON/onnxsim.js
200 |             build-wasm-node-ON/onnxsim.wasm
201 | 
202 |   build_native:
203 |     name: Build Native
204 |     runs-on: ubuntu-latest
205 |     steps:
206 |       - uses: actions/checkout@v4
207 |         with:
208 |           submodules: recursive
209 | 
210 |       - name: Build
211 |         run: |
212 |           mkdir build-native
213 |           cd build-native
214 |           cmake -GNinja ..
215 |           ninja onnxsim_bin
216 | 
217 |       - uses: actions/upload-artifact@v4
218 |         with:
219 |           name: native
220 |           path: |
221 |             build-native/onnxsim
--------------------------------------------------------------------------------
/tests/test_python_api.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, Optional
2 | import os
3 | import tempfile
4 | 
5 | import numpy as np
6 | import torch
7 | import onnx
8 | import onnxsim
9 | import torchvision as tv
10 | import pytest
11 | 
12 | 
13 | def export_simplify_and_check_by_python_api(
14 |     m: torch.nn.Module,
15 |     input: Any,
16 |     *,
17 |     is_model_valid: Optional[Callable[[Any], bool]] = None,
18 |     export_kwargs: Optional[Dict[str, Any]] = None,
19 |     simplify_kwargs: Optional[Dict[str, Any]] = None,
20 | ) -> onnx.ModelProto:
21 |     if is_model_valid is None:
22 |         is_model_valid = lambda _: True
23 |     if export_kwargs is None:
24 |         export_kwargs = {}
25 |     if simplify_kwargs is None:
26 |         simplify_kwargs = {}
27 |     with tempfile.TemporaryDirectory() as tmpdirname:
28 |         model_fn = os.path.join(tmpdirname, "tmp.onnx")
29 |         torch.onnx.export(m, input, model_fn, **export_kwargs)
30 |         model = onnx.load(model_fn)
31 |         if not is_model_valid(model):
32 |             raise AssertionError(f"model is invalid:\n{model}")
33 |         # read the model from filesystem to support >2GB large model
34 |         sim_model, check_ok = onnxsim.simplify(model_fn, check_n=3, **simplify_kwargs)
35 |         assert check_ok
36 |         return sim_model
37 | 
38 | 
39 | def str_is_logical_positive(x: str) -> bool:
40 |     return x.lower() in ["1", "on", "true"]
41 | 
42 | 
43 | def skip_in_ci():
44 |     return pytest.mark.skipif(
45 |         str_is_logical_positive(os.getenv("ONNXSIM_CI", "")), reason="memory limited"
46 |     )
47 | 
48 | 
49 | def test_just_reshape():
50 |     class JustReshape(torch.nn.Module):
51 |         def __init__(self):
52 |             super(JustReshape, self).__init__()
53 | 
54 |         def forward(self, x):
55 |             return x.view((x.shape[0], x.shape[1], x.shape[3] * x.shape[2]))
56 | 
57 |     net = JustReshape()
58 |     dummy_input = torch.randn(2, 3, 4, 5)
59 |     sim_model = export_simplify_and_check_by_python_api(
60 |         net, dummy_input, export_kwargs={"do_constant_folding": False}
61 |     )
62 |     assert len(sim_model.graph.node) == 1
63 | 
64 | 
65 | def test_a_model_not_need_simplification():
66 |     class ModelNotNeedSimplification(torch.nn.Module):
67 |         def __init__(self):
68 |             super(ModelNotNeedSimplification, self).__init__()
69 | 
70 |         def forward(self, x):
71 |             return x + 1
72 | 
73 |     net = ModelNotNeedSimplification()
74 |     dummy_input = torch.randn(2, 3, 4, 5)
75 |     sim_model = export_simplify_and_check_by_python_api(net, dummy_input)
76 |     assert len(sim_model.graph.node) == 1
77 | 
78 | 
79 | def test_experimental_simplify_subgraph():
80 |     class WithSubGraph(torch.nn.Module):
81 |         def __init__(self):
82 |             super(WithSubGraph, self).__init__()
83 | 
84 |         def forward(self, x):
85 |             if x.sum() > 1.0:
86 |                 # NOTE: even onnxsim cannot simplify it,
87 |                 # a canonical pass in onnx-optimizer is needed for it.
88 |                 # So this test only tests that include_subgraph doesn't
89 |                 # result in an invalid model in this case.
90 |                 return 3 + x + 3
91 |             else:
92 |                 return x + 4
93 | 
94 |     net = torch.jit.script(WithSubGraph())
95 |     dummy_input = torch.randn(2)
96 |     sim_model = export_simplify_and_check_by_python_api(
97 |         net, dummy_input, simplify_kwargs={"include_subgraph": True}
98 |     )
99 |     assert len(sim_model.graph.node) == 3
100 |     assert len(sim_model.graph.node[2].attribute[0].g.node) == 2
101 |     assert len(sim_model.graph.node[2].attribute[1].g.node) == 1
102 | 
103 | 
104 | def test_dynamic_batch_size():
105 |     class SimpleModel(torch.nn.Module):
106 |         def __init__(self):
107 |             super(SimpleModel, self).__init__()
108 | 
109 |         def forward(self, x):
110 |             return x + 2
111 | 
112 |     net = SimpleModel()
113 |     dummy_input = torch.randn(2, 3, 4, 5)
114 |     sim_model = export_simplify_and_check_by_python_api(
115 |         net,
116 |         dummy_input,
117 |         export_kwargs={
118 |             "input_names": ["input"],
119 |             "dynamic_axes": {"input": {0: "batch_size"}},
120 |         },
121 |         simplify_kwargs={"test_input_shapes": {"input": [2, 3, 4, 5]}},
122 |     )
123 |     assert len(sim_model.graph.node) == 1
124 | 
125 | 
126 | # NOTE: `include_subgraph` makes this test fail
127 | @skip_in_ci()
128 | def test_torchvision_fasterrcnn_fpn():
129 |     model = tv.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
130 |     x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
131 |     export_simplify_and_check_by_python_api(
132 |         model, x, export_kwargs={"opset_version": 11}
133 |     )
134 | 
135 | 
136 | # maskrcnn is only supported in opset 11 and higher
137 | @skip_in_ci()
138 | def test_torchvision_maskrcnn_fpn_opset11():
139 |     model = tv.models.detection.maskrcnn_resnet50_fpn(pretrained=False)
140 |     x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
141 |     export_simplify_and_check_by_python_api(
142 |         model, x, export_kwargs={"opset_version": 11}
143 |     )
144 | 
145 | 
146 | # keypointrcnn is only supported in opset 11 and higher
147 | @skip_in_ci()
148 | def test_torchvision_keypointrcnn_fpn():
149 |     model = tv.models.detection.keypointrcnn_resnet50_fpn(pretrained=False)
150 |     x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
151 |     export_simplify_and_check_by_python_api(
152 |         model, x, export_kwargs={"opset_version": 11}
153 |     )
154 | 
155 | 
156 | # shufflenet and mnasnet cause segfaults in CI (perhaps because of memory limit)
157 | # but work locally
158 | @skip_in_ci()
159 | def test_torchvision_shufflenet_v2():
160 |     model = tv.models.shufflenet_v2_x1_0(pretrained=False)
161 |     x = torch.rand(1, 3, 224, 224)
162 |     export_simplify_and_check_by_python_api(model, x)
163 | 
164 | 
165 | @skip_in_ci()
166 | def test_torchvision_mnasnet():
167 |     model = tv.models.mnasnet1_0(pretrained=False)
168 |     x = torch.rand(1, 3, 224, 224)
169 |     export_simplify_and_check_by_python_api(model, x)
170 | 
171 | 
172 | @skip_in_ci()
173 | def test_torchvision_deeplabv3():
174 |     model = tv.models.segmentation.deeplabv3_resnet50(pretrained=False)
175 |     x = torch.rand(1, 3, 224, 224)
176 |     export_simplify_and_check_by_python_api(model, x)
177 | 
178 | 
179 | def test_unused_output():
180 |     class SimpleModel(torch.nn.Module):
181 |         def __init__(self):
182 |             super(SimpleModel, self).__init__()
183 | 
184 |         def forward(self, x):
185 |             x1 = x + 2
186 |             x1 = x1 - 2
187 |             x1
= x1 * 2 188 | x1 = x1 / 2 189 | y1 = x1 190 | x2 = x + 2 191 | x2 = x2 - 2 192 | x2 = x2 * 2 193 | x2 = x2 / 2 194 | y2 = x2 195 | x3 = x + 2 196 | x3 = x3 - 2 197 | x3 = x3 * 2 198 | x3 = x3 / 2 199 | y3 = x3 200 | return y1, y2, y3 201 | 202 | net = SimpleModel() 203 | dummy_input = torch.randn(2, 3, 4, 5) 204 | sim_model = export_simplify_and_check_by_python_api( 205 | net, 206 | dummy_input, 207 | export_kwargs={ 208 | "input_names": ["input"], 209 | "output_names": ["output0", "output1", "output2"], 210 | }, 211 | simplify_kwargs={"unused_output": ["output1", "output2"]}, 212 | ) 213 | assert len(sim_model.graph.node) == 4 214 | 215 | 216 | def test_remove_unused_initializer(): 217 | class SimpleModel(torch.nn.Module): 218 | def __init__(self): 219 | super(SimpleModel, self).__init__() 220 | self.w = torch.nn.Parameter(torch.ones(5, 4)) 221 | 222 | def forward(self, x): 223 | return x + torch.transpose(self.w, 0, 1) 224 | 225 | net = SimpleModel() 226 | dummy_input = torch.randn(2, 3, 4, 5) 227 | sim_model = export_simplify_and_check_by_python_api( 228 | net, 229 | dummy_input, 230 | is_model_valid=lambda model: any( 231 | node.op_type == "Transpose" for node in model.graph.node 232 | ), 233 | export_kwargs={"do_constant_folding": False}, 234 | ) 235 | assert len(sim_model.graph.node) == 1 236 | assert len(sim_model.graph.initializer) == 1 237 | 238 | 239 | @skip_in_ci() 240 | def test_model_larger_than_2gb(): 241 | class SimpleModel(torch.nn.Module): 242 | def __init__(self): 243 | super(SimpleModel, self).__init__() 244 | # a parameter is 500MB 245 | self.w1 = torch.nn.Parameter(torch.ones(125 * 1024 * 1024)) 246 | self.w2 = torch.nn.Parameter(torch.ones(125 * 1024 * 1024)) 247 | self.w3 = torch.nn.Parameter(torch.ones(125 * 1024 * 1024)) 248 | self.w4 = torch.nn.Parameter(torch.ones(125 * 1024 * 1024)) 249 | self.w5 = torch.nn.Parameter(torch.ones(125 * 1024 * 1024)) 250 | 251 | def forward(self, x): 252 | return x + (self.w1 + self.w2 + self.w3 + self.w4 + self.w5) 253 | 254 | net = SimpleModel() 255 | dummy_input = torch.randn(125 * 1024 * 1024) 256 | sim_model = export_simplify_and_check_by_python_api( 257 | net, 258 | dummy_input, 259 | is_model_valid=lambda model: sum( 260 | node.op_type == "Add" for node in model.graph.node 261 | ) 262 | == 5, 263 | export_kwargs={"do_constant_folding": False}, 264 | ) 265 | assert len(sim_model.graph.node) == 1 266 | assert sim_model.graph.node[0].op_type == "Add" 267 | 268 | 269 | def test_unset_optional_input(): 270 | fmap = [] 271 | nodes = [] 272 | initializers = [] 273 | 274 | fmap.append(onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, shape=(1,3,4,4))) 275 | 276 | X = np.random.rand(1,3,2,2).astype(np.float32) 277 | initializers.append(onnx.helper.make_tensor('X', onnx.TensorProto.FLOAT, X.shape, X.copy().tobytes(), raw=True)) 278 | sizes = np.asarray([1,3,4,4]).astype(np.int64) 279 | initializers.append(onnx.helper.make_tensor('sizes', onnx.TensorProto.INT64, sizes.shape, sizes.copy().tobytes(), raw=True)) 280 | 281 | nodes.append(onnx.helper.make_node( 282 | 'Resize', 283 | inputs=['X', '', '', 'sizes'], 284 | outputs=['y'], 285 | mode='linear')) 286 | 287 | graph_def = onnx.helper.make_graph( 288 | nodes, 289 | 'test_unset_optional_input', 290 | [], 291 | [fmap[-1]], 292 | value_info=fmap, 293 | initializer=initializers 294 | ) 295 | 296 | opset_imports = [onnx.helper.make_opsetid("", 14)] 297 | 298 | model = onnx.helper.make_model(graph_def, opset_imports=opset_imports, ir_version=10) 299 | sim_model, check_ok = 
onnxsim.simplify(model, check_n=3) 300 | assert check_ok 301 | assert len(model.graph.node) == 1 302 | assert len(model.graph.initializer) == 2 303 | assert len(sim_model.graph.node) == 0 304 | assert len(sim_model.graph.initializer) == 1 305 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2019] [daquexian] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.spawn import find_executable
2 | from distutils import sysconfig, log
3 | import setuptools
4 | import setuptools.command.build_py
5 | import setuptools.command.develop
6 | import setuptools.command.build_ext
7 | 
8 | from collections import namedtuple
9 | from contextlib import contextmanager
10 | from pathlib import Path
11 | import os
12 | import shlex
13 | import subprocess
14 | import sys
15 | import platform
16 | from textwrap import dedent
17 | import multiprocessing
18 | import re
19 | 
20 | 
21 | TOP_DIR = os.path.realpath(os.path.dirname(__file__))
22 | SRC_DIR = os.path.join(TOP_DIR, 'onnxsim')
23 | CMAKE_BUILD_DIR = os.path.join(TOP_DIR, '.setuptools-cmake-build')
24 | 
25 | WINDOWS = (os.name == 'nt')
26 | MACOS = sys.platform.startswith("darwin")
27 | 
28 | CMAKE = find_executable('cmake')
29 | 
30 | install_requires = []
31 | setup_requires = []
32 | 
33 | USE_MSVC_STATIC_RUNTIME = bool(os.getenv('USE_MSVC_STATIC_RUNTIME', '0') == '1')
34 | ONNX_ML = not bool(os.getenv('ONNX_ML') == '0')
35 | ONNX_VERIFY_PROTO3 = bool(os.getenv('ONNX_VERIFY_PROTO3') == '1')
36 | ONNX_NAMESPACE = os.getenv('ONNX_NAMESPACE', 'onnx')
37 | ONNX_BUILD_TESTS = bool(os.getenv('ONNX_BUILD_TESTS') == '1')
38 | ONNX_OPT_USE_SYSTEM_PROTOBUF = bool(os.getenv('ONNX_OPT_USE_SYSTEM_PROTOBUF', '0') == '1')
39 | 
40 | DEBUG = bool(os.getenv('DEBUG'))
41 | COVERAGE = bool(os.getenv('COVERAGE'))
42 | 
43 | try:
44 |     version = subprocess.check_output(['git', 'describe', '--tags', '--abbrev=0'],
45 |                                       cwd=TOP_DIR).decode('ascii').strip()
46 |     if version[0] == 'v':
47 |         version = version[1:]
48 | except (OSError, subprocess.CalledProcessError):
49 |     with open(os.path.join(TOP_DIR, 'VERSION')) as ver_file:
50 |         version = ver_file.read().strip()
51 | 
52 | try:
53 |     git_version = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
54 |                                           cwd=TOP_DIR).decode('ascii').strip()
55 | except (OSError, subprocess.CalledProcessError):
56 |     git_version = None
57 | 
58 | if os.getenv('ONNXSIM_SDIST') is not None:
59 |     version = '0.0.0'
60 |     git_version = None
61 | 
62 | VersionInfo = namedtuple('VersionInfo', ['version', 'git_version'])(
63 |     version=version,
64 |     git_version=git_version
65 | )
66 | 
67 | assert CMAKE, 'Could not find "cmake" executable!'
68 | 
69 | @contextmanager
70 | def cd(path):
71 |     if not os.path.isabs(path):
72 |         raise RuntimeError('Can only cd to absolute path, got: {}'.format(path))
73 |     orig_path = os.getcwd()
74 |     os.chdir(path)
75 |     try:
76 |         yield
77 |     finally:
78 |         os.chdir(orig_path)
79 | 
80 | 
81 | class ONNXCommand(setuptools.Command):
82 |     user_options = []
83 | 
84 |     def initialize_options(self):
85 |         pass
86 | 
87 |     def finalize_options(self):
88 |         pass
89 | 
90 | 
91 | class create_version(ONNXCommand):
92 |     def run(self):
93 |         with open(os.path.join(SRC_DIR, 'version.py'), 'w') as f:
94 |             f.write(dedent('''\
95 |                 # This file is generated by setup.py. DO NOT EDIT!
96 | 
97 |                 version = '{version}'
98 |                 git_version = '{git_version}'
99 |                 '''.format(**dict(VersionInfo._asdict()))))
100 | 
101 | 
102 | class cmake_build(setuptools.Command):
103 |     """
104 |     Compiles everything when `python setup.py build` is run using cmake.
105 | 
106 |     Custom args can be passed to cmake by specifying the `CMAKE_ARGS`
107 |     environment variable.
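    For example (the flag shown already appears in this script's own
    cmake_args list and is used purely as an illustration):

        CMAKE_ARGS="-DONNX_USE_LITE_PROTO=ON" python3 setup.py build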
108 | 
109 |     The number of CPUs used by `make` can be specified by passing `-j`
110 |     to `setup.py build`. By default all CPUs are used.
111 |     """
112 |     user_options = [
113 |         (str('jobs='), str('j'), str('Specifies the number of jobs to use with make'))
114 |     ]
115 | 
116 |     built = False
117 | 
118 |     def initialize_options(self):
119 |         self.jobs = None
120 | 
121 |     def finalize_options(self):
122 |         self.set_undefined_options('build', ('parallel', 'jobs'))
123 |         if self.jobs is None and os.getenv("MAX_JOBS") is not None:
124 |             self.jobs = os.getenv("MAX_JOBS")
125 |         self.jobs = multiprocessing.cpu_count() if self.jobs is None else int(self.jobs)
126 | 
127 |     def run(self):
128 |         if cmake_build.built:
129 |             return
130 |         cmake_build.built = True
131 |         if not os.path.exists(CMAKE_BUILD_DIR):
132 |             os.makedirs(CMAKE_BUILD_DIR)
133 | 
134 |         with cd(CMAKE_BUILD_DIR):
135 |             build_type = 'Release'
136 |             # configure
137 |             cmake_args = [
138 |                 CMAKE,
139 |                 '-DPython_INCLUDE_DIR={}'.format(sysconfig.get_python_inc()),
140 |                 '-DPython_EXECUTABLE={}'.format(sys.executable),
141 |                 # For pybind11
142 |                 '-DPYTHON_EXECUTABLE={}'.format(sys.executable),
143 |                 '-DBUILD_ONNX_PYTHON=OFF',
144 |                 '-DONNXSIM_PYTHON=ON',
145 |                 '-DONNXSIM_BUILTIN_ORT=OFF',
146 |                 '-DONNX_USE_LITE_PROTO=OFF',
147 |                 '-DCMAKE_EXPORT_COMPILE_COMMANDS=ON',
148 |                 '-DONNX_NAMESPACE={}'.format(ONNX_NAMESPACE),
149 |                 '-DPY_EXT_SUFFIX={}'.format(
150 |                     sysconfig.get_config_var('EXT_SUFFIX') or ''),
151 |                 '-DONNX_OPT_USE_SYSTEM_PROTOBUF={}'.format(
152 |                     'ON' if ONNX_OPT_USE_SYSTEM_PROTOBUF else 'OFF'),
153 |             ]
154 |             if COVERAGE:
155 |                 cmake_args.append('-DONNX_COVERAGE=ON')
156 |             if COVERAGE or DEBUG:
157 |                 # in order to get accurate coverage information, the
158 |                 # build needs to turn off optimizations
159 |                 build_type = 'Debug'
160 |             cmake_args.append('-DCMAKE_BUILD_TYPE=%s' % build_type)
161 |             if WINDOWS:
162 |                 cmake_args.extend([
163 |                     # we need to link with libpython on Windows, so
164 |                     # pass the python version to cmake so that it can
165 |                     # find python
166 |                     '-DPY_VERSION={}'.format('{0}.{1}'.format(* \
167 |                         sys.version_info[:2])),
168 |                 ])
169 |                 if USE_MSVC_STATIC_RUNTIME:
170 |                     cmake_args.append('-DONNX_USE_MSVC_STATIC_RUNTIME=ON')
171 |                 if platform.architecture()[0] == '64bit':
172 |                     cmake_args.extend(['-A', 'x64', '-T', 'host=x64'])
173 |                 else:
174 |                     cmake_args.extend(['-A', 'Win32', '-T', 'host=x86'])
175 |             if MACOS:
176 |                 # Cross-compile support for macOS - respect ARCHFLAGS if set
177 |                 archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
178 |                 if archs:
179 |                     cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]
180 |             if ONNX_ML:
181 |                 cmake_args.append('-DONNX_ML=1')
182 |             if ONNX_VERIFY_PROTO3:
183 |                 cmake_args.append('-DONNX_VERIFY_PROTO3=1')
184 |             if ONNX_BUILD_TESTS:
185 |                 cmake_args.append('-DONNX_BUILD_TESTS=ON')
186 |             if 'CMAKE_ARGS' in os.environ:
187 |                 extra_cmake_args = shlex.split(os.environ['CMAKE_ARGS'])
188 |                 # prevent crossfire with downstream scripts
189 |                 del os.environ['CMAKE_ARGS']
190 |                 log.info('Extra cmake args: {}'.format(extra_cmake_args))
191 |                 cmake_args.extend(extra_cmake_args)
192 |             cmake_args.append(TOP_DIR)
193 |             print(f"Run command {cmake_args}")
194 |             subprocess.check_call(cmake_args)
195 | 
196 |             build_args = [CMAKE, '--build', os.curdir, '--target', 'onnxsim_cpp2py_export']
197 |             if WINDOWS:
198 |                 build_args.extend(['--config', build_type])
199 |                 build_args.extend(['--', '/maxcpucount:{}'.format(self.jobs)])
200 |             else:
201 |                 build_args.extend(['--', '-j',
str(self.jobs)])
202 |             print(f"Run command {build_args}")
203 |             subprocess.check_call(build_args)
204 | 
205 | 
206 | class build_py(setuptools.command.build_py.build_py):
207 |     def run(self):
208 |         self.run_command('create_version')
209 |         return setuptools.command.build_py.build_py.run(self)
210 | 
211 | 
212 | class develop(setuptools.command.develop.develop):
213 |     def run(self):
214 |         self.run_command('build_py')
215 |         setuptools.command.develop.develop.run(self)
216 | 
217 | 
218 | class build_ext(setuptools.command.build_ext.build_ext):
219 |     def run(self):
220 |         self.run_command('cmake_build')
221 |         setuptools.command.build_ext.build_ext.run(self)
222 | 
223 |     def build_extensions(self):
224 |         for ext in self.extensions:
225 |             fullname = self.get_ext_fullname(ext.name)
226 |             filename = os.path.basename(self.get_ext_filename(fullname))
227 | 
228 |             lib_path = CMAKE_BUILD_DIR
229 |             if os.name == 'nt':
230 |                 debug_lib_dir = os.path.join(lib_path, "Debug")
231 |                 release_lib_dir = os.path.join(lib_path, "Release")
232 |                 if os.path.exists(debug_lib_dir):
233 |                     lib_path = debug_lib_dir
234 |                 elif os.path.exists(release_lib_dir):
235 |                     lib_path = release_lib_dir
236 |             src = os.path.join(lib_path, filename)
237 |             dst_dir = os.path.join(os.path.realpath(
238 |                 self.build_lib), "onnxsim")
239 |             dst = os.path.join(dst_dir, filename)
240 |             os.makedirs(dst_dir, exist_ok=True)
241 |             self.copy_file(src, dst)
242 | 
243 | 
244 | cmdclass = {
245 |     'create_version': create_version,
246 |     'cmake_build': cmake_build,
247 |     'build_ext': build_ext,
248 |     'build_py': build_py,
249 |     'develop': develop,
250 | }
251 | 
252 | ext_modules = [
253 |     setuptools.Extension(
254 |         name=str('onnxsim.onnxsim_cpp2py_export'),
255 |         sources=[])
256 | ]
257 | 
258 | # no need to do fancy stuff so far
259 | packages = setuptools.find_packages()
260 | 
261 | # Though we depend on onnxruntime, it has three different packages:
262 | # onnxruntime, onnxruntime-gpu and onnxruntime-noopenmp.
263 | # The solution is: we publish two packages, a wheel named onnxsim-no-ort
264 | # and an sdist named onnxsim. onnxsim depends on onnxsim-no-ort, and also
265 | # checks whether one of the onnxruntime packages is already installed,
266 | # depending on onnxruntime only when none of them is.
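# A sketch of that conditional dependency (illustrative only, not the
# exact published logic). Note that onnxruntime-gpu also installs the
# `onnxruntime` module, so a single find_spec() check covers the variants:
#
#   import importlib.util
#   if importlib.util.find_spec("onnxruntime") is None:
#       install_requires.append("onnxruntime")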
267 | install_requires.extend([ 268 | 'onnx', 269 | 'rich', 270 | ]) 271 | 272 | setup_requires.append('pytest-runner') 273 | 274 | # read the contents of your README file 275 | this_directory = Path(__file__).parent 276 | long_description = (this_directory / "README.md").read_text() 277 | 278 | setuptools.setup( 279 | name=os.getenv("ONNXSIM_PKG_NAME", "onnxsim"), 280 | version=VersionInfo.version, 281 | description='Simplify your ONNX model', 282 | ext_modules=ext_modules, 283 | cmdclass=cmdclass, 284 | packages=packages, 285 | license='Apache License v2.0', 286 | include_package_data=True, 287 | install_requires=install_requires, 288 | setup_requires=setup_requires, 289 | author='ONNX Simplifier Authors', 290 | author_email='daquexian566@gmail.com', 291 | url='https://github.com/daquexian/onnx-simplifier', 292 | keywords='deep-learning ONNX', 293 | long_description=long_description, 294 | long_description_content_type='text/markdown', 295 | classifiers=[ 296 | 'Development Status :: 4 - Beta', 297 | 'Intended Audience :: Developers', 298 | 'License :: OSI Approved :: Apache Software License', 299 | 'Programming Language :: Python :: 3 :: Only', 300 | 'Programming Language :: Python :: 3.7', 301 | 'Programming Language :: Python :: 3.8', 302 | 'Programming Language :: Python :: 3.9', 303 | 'Programming Language :: Python :: 3.10', 304 | 'Programming Language :: Python :: 3.11', 305 | 'Programming Language :: Python :: 3.12', 306 | 'Programming Language :: Python :: 3.13', 307 | 'Topic :: Scientific/Engineering', 308 | 'Topic :: Software Development' 309 | ], 310 | python_requires='>=3.7', 311 | entry_points={ 312 | 'console_scripts': [ 313 | 'onnxsim=onnxsim:main', 314 | ], 315 | }, 316 | ) 317 | -------------------------------------------------------------------------------- /onnxsim/onnxsim.cpp: -------------------------------------------------------------------------------- 1 | #include "onnxsim.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #ifndef NO_BUILTIN_ORT 12 | #include "../third_party/onnxruntime/include/onnxruntime/core/framework/endian.h" 13 | #include "../third_party/onnxruntime/include/onnxruntime/core/session/onnxruntime_cxx_api.h" 14 | #endif 15 | #include "onnx/common/file_utils.h" 16 | #include "onnx/shape_inference/implementation.h" 17 | #include "onnxoptimizer/model_util.h" 18 | #include "onnxoptimizer/optimize.h" 19 | 20 | struct Config { 21 | std::vector optimizer_passes; 22 | // default value is max 23 | size_t tensor_size_threshold = -1; 24 | }; 25 | 26 | Config config; 27 | 28 | std::shared_ptr ModelExecutor::instance_ = nullptr; 29 | 30 | bool IsOfficialOp(const std::string& domain, const std::string& op) { 31 | if (domain != "ai.onnx" && domain != "ai.onnx.ml" && !domain.empty()) { 32 | return false; 33 | } 34 | // these experimental ops were in onnx default domain but are no 35 | // longer supported by onnx now. 
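// Note: IsOfficialOp() gates constant folding in GetConstantNodes() below:
// only nodes from the default onnx domains, minus the removed experimental
// ops listed here, are treated as foldable, since only those are expected
// to be runnable by the bundled executor.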
36 | static std::set experimental_ops = {"ATen", 37 | "Affine", 38 | "ConstantFill", 39 | "Crop", 40 | "DynamicSlice", 41 | "GRUUnit", 42 | "GivenTensorFill", 43 | "ImageScaler", 44 | "ParametricSoftplus", 45 | "Scale", 46 | "ScaledTanh"}; 47 | return experimental_ops.find(op) == experimental_ops.end(); 48 | } 49 | 50 | bool IsDeterministic(const std::string& domain, const std::string& op) { 51 | // Copy from onnxruntime/core/optimizer/utils.cc 52 | constexpr std::array kOnnxDomainNonDeterministicOps{ 53 | "RandomUniform", "RandomNormal", "RandomUniformLike", "RandomNormalLike", 54 | "Multinomial"}; 55 | if (domain == "ai.onnx" || domain == "ai.onnx.ml" || domain.empty()) { 56 | auto iter = std::find(kOnnxDomainNonDeterministicOps.begin(), 57 | kOnnxDomainNonDeterministicOps.end(), op); 58 | return iter == kOnnxDomainNonDeterministicOps.end(); 59 | } 60 | // Unknown domain. Assume the op is not deterministic. 61 | return false; 62 | } 63 | 64 | bool IsQDQ(const std::string& domain, const std::string& op) { 65 | if (domain == "ai.onnx" || domain.empty()) { 66 | return op == "QuantizeLinear" || op == "DequantizeLinear"; 67 | } 68 | return false; 69 | } 70 | 71 | auto FindInitializerByName(const onnx::ModelProto& model, 72 | const std::string& name) { 73 | for (const auto& initializer : model.graph().initializer()) { 74 | if (initializer.name() == name) { 75 | return initializer; 76 | } 77 | } 78 | throw std::invalid_argument("no initializer " + name); 79 | } 80 | 81 | auto FindValueInfoProtoByName(const onnx::ModelProto& model, 82 | const std::string& name) { 83 | for (const auto& vi : model.graph().value_info()) { 84 | if (vi.name() == name) { 85 | return vi; 86 | } 87 | } 88 | for (const auto& initializer : model.graph().initializer()) { 89 | if (initializer.name() == name) { 90 | onnx::ValueInfoProto vi; 91 | for (const auto& dim : initializer.dims()) { 92 | vi.mutable_type() 93 | ->mutable_tensor_type() 94 | ->mutable_shape() 95 | ->add_dim() 96 | ->set_dim_value(dim); 97 | } 98 | vi.mutable_type()->mutable_tensor_type()->set_elem_type( 99 | initializer.data_type()); 100 | vi.set_name(name); 101 | return vi; 102 | } 103 | } 104 | throw std::invalid_argument("no value info " + name); 105 | } 106 | 107 | #ifndef NO_BUILTIN_ORT 108 | onnx::TensorProto TensorToTensorProto(const Ort::Value& tensor) { 109 | onnx::TensorProto tensor_proto; 110 | for (const auto& dim : tensor.GetTensorTypeAndShapeInfo().GetShape()) { 111 | tensor_proto.add_dims(dim); 112 | } 113 | onnx::TensorProto::DataType onnx_dtype = 114 | (onnx::TensorProto::DataType)tensor.GetTensorTypeAndShapeInfo() 115 | .GetElementType(); 116 | tensor_proto.set_data_type(onnx_dtype); 117 | 118 | switch (onnx_dtype) { 119 | #define CASE_DTYPE(onnx_dtype, storage_dtype, cpp_type) \ 120 | case onnx::TensorProto::onnx_dtype: { \ 121 | const auto* dptr = tensor.GetTensorData(); \ 122 | for (size_t i = 0; \ 123 | i < tensor.GetTensorTypeAndShapeInfo().GetElementCount(); i++) { \ 124 | tensor_proto.add_##storage_dtype##_data(dptr[i]); \ 125 | } \ 126 | break; \ 127 | } 128 | 129 | CASE_DTYPE(FLOAT, float, float) 130 | CASE_DTYPE(DOUBLE, double, double) 131 | CASE_DTYPE(INT64, int64, int64_t) 132 | CASE_DTYPE(UINT64, uint64, uint64_t) 133 | CASE_DTYPE(INT32, int32, int32_t) 134 | CASE_DTYPE(UINT8, int32, uint8_t) 135 | CASE_DTYPE(INT8, int32, int8_t) 136 | CASE_DTYPE(UINT16, int32, uint16_t) 137 | CASE_DTYPE(INT16, int32, int16_t) 138 | CASE_DTYPE(BOOL, int32, int8_t) 139 | #undef CASE_DTYPE 140 | default: 141 | throw 
std::invalid_argument("Unknown dtype " + 142 | std::to_string(tensor_proto.data_type())); 143 | } 144 | return tensor_proto; 145 | } 146 | 147 | Ort::Value TensorProtoToTensor(const onnx::TensorProto& tensor_proto) { 148 | Ort::AllocatorWithDefaultOptions allocator; 149 | auto tensor = Ort::Value::CreateTensor( 150 | allocator, tensor_proto.dims().data(), tensor_proto.dims_size(), 151 | (ONNXTensorElementDataType)tensor_proto.data_type()); 152 | if (tensor_proto.has_raw_data()) { 153 | if (onnxruntime::endian::native == onnxruntime::endian::big) { 154 | throw std::invalid_argument("only little endian is supported"); 155 | } 156 | memcpy(tensor.GetTensorMutableData(), tensor_proto.raw_data().data(), 157 | tensor_proto.raw_data().size()); 158 | } else { 159 | switch (tensor_proto.data_type()) { 160 | #define CASE_DTYPE(onnx_dtype, storage_dtype, cpp_type) \ 161 | case onnx::TensorProto::onnx_dtype: { \ 162 | std::vector vec; \ 163 | for (const auto& x : tensor_proto.storage_dtype##_data()) { \ 164 | vec.push_back(x); \ 165 | } \ 166 | memcpy(tensor.GetTensorMutableData(), vec.data(), \ 167 | vec.size() * sizeof(cpp_type)); \ 168 | break; \ 169 | } 170 | CASE_DTYPE(FLOAT, float, float) 171 | CASE_DTYPE(DOUBLE, double, double) 172 | CASE_DTYPE(INT64, int64, int64_t) 173 | CASE_DTYPE(UINT64, uint64, uint64_t) 174 | CASE_DTYPE(INT32, int32, int32_t) 175 | CASE_DTYPE(UINT8, int32, uint8_t) 176 | CASE_DTYPE(INT8, int32, int8_t) 177 | CASE_DTYPE(UINT16, int32, uint16_t) 178 | CASE_DTYPE(INT16, int32, int16_t) 179 | CASE_DTYPE(BOOL, int32, int8_t) 180 | #undef CASE_DTYPE 181 | default: 182 | throw std::invalid_argument("Unknown dtype " + 183 | std::to_string(tensor_proto.data_type())); 184 | } 185 | } 186 | return tensor; 187 | } 188 | 189 | std::shared_ptr GetEnv() { 190 | static std::shared_ptr env = std::make_shared(); 191 | return env; 192 | } 193 | 194 | struct CppModelExecutor : public ModelExecutor { 195 | std::vector _Run( 196 | const onnx::ModelProto& model, 197 | const std::vector& inputs) const override { 198 | std::vector input_name_ptrs; 199 | std::vector output_name_ptrs; 200 | std::transform( 201 | model.graph().input().begin(), model.graph().input().end(), 202 | std::back_inserter(input_name_ptrs), 203 | [](const onnx::ValueInfoProto& x) { return x.name().c_str(); }); 204 | std::transform( 205 | model.graph().output().begin(), model.graph().output().end(), 206 | std::back_inserter(output_name_ptrs), 207 | [](const onnx::ValueInfoProto& x) { return x.name().c_str(); }); 208 | Ort::SessionOptions sess_opts; 209 | sess_opts.SetLogSeverityLevel(3); 210 | sess_opts.SetGraphOptimizationLevel(ORT_DISABLE_ALL); 211 | std::string model_str = model.SerializeAsString(); 212 | Ort::Session session(*GetEnv(), model_str.data(), model_str.size(), 213 | sess_opts); 214 | Ort::RunOptions run_opts; 215 | run_opts.SetRunLogSeverityLevel(3); 216 | std::vector input_tensors; 217 | std::transform(inputs.begin(), inputs.end(), 218 | std::back_inserter(input_tensors), TensorProtoToTensor); 219 | auto output_tensors = session.Run( 220 | run_opts, input_name_ptrs.data(), input_tensors.data(), 221 | input_tensors.size(), output_name_ptrs.data(), output_name_ptrs.size()); 222 | 223 | std::vector output_tps; 224 | std::transform(output_tensors.begin(), output_tensors.end(), 225 | std::back_inserter(output_tps), TensorToTensorProto); 226 | return output_tps; 227 | } 228 | }; 229 | 230 | static int __register_cpp_model_executor __attribute__((unused)) = []() { 231 | ModelExecutor::set_instance(std::make_shared()); 
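      // File-scope registration trick: this immediately-invoked lambda runs
      // during static initialization, before main(), so the ORT-backed
      // CppModelExecutor is installed as the default. set_instance() can be
      // called again later to swap in a different executor (e.g. the
      // Python-side PyModelExecutor defined in onnx_simplifier.py).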
232 | return 0; 233 | }(); 234 | 235 | void InitEnv() { GetEnv(); } 236 | #else 237 | void InitEnv() { 238 | // do nothing 239 | } 240 | #endif 241 | 242 | std::vector RunOp(onnx::ModelProto& model, 243 | const onnx::NodeProto& op) { 244 | std::vector input_names; 245 | std::vector input_tps; 246 | std::set initializer_names; 247 | 248 | onnx::ModelProto op_model; 249 | op_model.set_ir_version(model.ir_version()); 250 | for (const auto& x : model.opset_import()) { 251 | *op_model.add_opset_import() = x; 252 | } 253 | *op_model.mutable_graph()->add_node() = op; 254 | 255 | for (const auto& input : op.input()) { 256 | if (std::find(input_names.begin(), input_names.end(), input) != 257 | input_names.end()) { 258 | continue; 259 | } 260 | // skip "" which represents the unset optional input 261 | if (input.empty()) { 262 | continue; 263 | } 264 | if (initializer_names.find(input) != initializer_names.end()) { 265 | continue; 266 | } 267 | auto in_tp = FindInitializerByName(model, input); 268 | if (in_tp.dims().size() == 1 && in_tp.dims()[0] == 0) { 269 | initializer_names.insert(input); 270 | *op_model.mutable_graph()->add_initializer() = in_tp; 271 | continue; 272 | } 273 | input_names.push_back(input); 274 | input_tps.push_back(in_tp); 275 | } 276 | 277 | for (const auto& x : input_names) { 278 | // skip "" which represents the unset optional input 279 | if (x.empty()) { 280 | continue; 281 | } 282 | *op_model.mutable_graph()->add_input() = FindValueInfoProtoByName(model, x); 283 | } 284 | for (const auto& x : op.output()) { 285 | onnx::ValueInfoProto vi; 286 | // In principle output ValueInfoProto must have type. But it is not checked. 287 | vi.set_name(x); 288 | *op_model.mutable_graph()->add_output() = vi; 289 | } 290 | 291 | auto output_tps = ModelExecutor::Run(op_model, input_tps); 292 | for (size_t i = 0; i < op.output_size(); i++) { 293 | output_tps[i].set_name(op.output(i)); 294 | } 295 | return output_tps; 296 | } 297 | 298 | void RunOpAndAddInitializer(onnx::ModelProto& model, 299 | const onnx::NodeProto& op) { 300 | const auto output_tps = RunOp(model, op); 301 | for (const auto& output_tp : output_tps) { 302 | *model.mutable_graph()->add_initializer() = output_tp; 303 | } 304 | } 305 | 306 | bool HasSubgraph(const onnx::NodeProto& node) { 307 | for (const auto& attr : node.attribute()) { 308 | if (attr.type() == onnx::AttributeProto::GRAPH || 309 | attr.type() == onnx::AttributeProto::GRAPHS) { 310 | return true; 311 | } 312 | } 313 | return false; 314 | } 315 | 316 | size_t size_of_dtype(onnx::TensorProto::DataType dtype) { 317 | switch (dtype) { 318 | case onnx::TensorProto::DataType::TensorProto_DataType_BOOL: 319 | case onnx::TensorProto::DataType::TensorProto_DataType_INT8: 320 | case onnx::TensorProto::DataType::TensorProto_DataType_UINT8: 321 | return 1; 322 | case onnx::TensorProto::DataType::TensorProto_DataType_BFLOAT16: 323 | case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16: 324 | case onnx::TensorProto::DataType::TensorProto_DataType_INT16: 325 | case onnx::TensorProto::DataType::TensorProto_DataType_UINT16: 326 | return 2; 327 | case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: 328 | case onnx::TensorProto::DataType::TensorProto_DataType_INT32: 329 | case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: 330 | return 4; 331 | case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: 332 | case onnx::TensorProto::DataType::TensorProto_DataType_INT64: 333 | case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: 334 | case 
onnx::TensorProto::DataType::TensorProto_DataType_COMPLEX64:
335 |       return 8;
336 |     case onnx::TensorProto::DataType::TensorProto_DataType_COMPLEX128:
337 |       return 16;
338 |     // Don't know the size of string. Just return 16.
339 |     case onnx::TensorProto::DataType::TensorProto_DataType_STRING:
340 |       return 16;
341 |     case onnx::TensorProto::DataType::TensorProto_DataType_UNDEFINED:
342 |       throw std::invalid_argument("Undefined datatype");
343 |   }
344 |   throw std::invalid_argument("Unknown datatype " + std::to_string(dtype));
345 | }
346 | 
347 | bool ProduceLargeTensor(const onnx::ModelProto& model,
348 |                         const onnx::NodeProto& node, size_t threshold) {
349 |   std::set<std::string> large_tensor_ops{"Tile", "ConstantOfShape", "Expand"};
350 |   if (large_tensor_ops.find(node.op_type()) == large_tensor_ops.end()) {
351 |     return false;
352 |   }
353 |   for (const auto& value_info : model.graph().value_info()) {
354 |     if (value_info.name() == node.output(0)) {
355 |       size_t size = size_of_dtype(static_cast<onnx::TensorProto::DataType>(
356 |           value_info.type().tensor_type().elem_type()));
357 |       for (const auto& dim : value_info.type().tensor_type().shape().dim()) {
358 |         size *= dim.dim_value();
359 |       }
360 |       if (size <= threshold) {
361 |         return false;
362 |       }
363 |     }
364 |   }
365 |   // If the output is not in value_info, we assume it is large.
366 |   // There is a possibility that value_info is provided by the shape inference
367 |   // later and `ProduceLargeTensor` is called again and returns false at that
368 |   // time.
369 |   return true;
370 | }
371 | 
372 | std::pair<std::vector<onnx::NodeProto>, std::vector<onnx::NodeProto>>
373 | GetConstantNodes(const onnx::ModelProto& model) {
374 |   // tensor with empty name("") represents the empty value of an optional input
375 |   // so "" should be treated as a name of a constant tensor.
376 |   std::vector<std::string> const_names{""};
377 |   std::vector<onnx::NodeProto> const_nodes;
378 |   std::vector<onnx::NodeProto> non_const_nodes;
379 |   std::transform(
380 |       model.graph().initializer().begin(), model.graph().initializer().end(),
381 |       std::back_inserter(const_names), [](const auto& x) { return x.name(); });
382 |   // node is already topo sorted
383 |   for (const auto& node : model.graph().node()) {
384 |     // clang-format off
385 |     if (IsOfficialOp(node.domain(), node.op_type()) &&
386 |         IsDeterministic(node.domain(), node.op_type()) &&
387 |         !IsQDQ(node.domain(), node.op_type()) &&
388 |         !HasSubgraph(node) &&
389 |         !ProduceLargeTensor(model, node, config.tensor_size_threshold) &&
390 |         // clang-format on
391 |         std::all_of(node.input().begin(), node.input().end(),
392 |                     [&const_names](const auto& x) {
393 |                       return std::find(const_names.begin(), const_names.end(),
394 |                                        x) != const_names.end();
395 |                     })) {
396 |       const_names.insert(const_names.end(), node.output().begin(),
397 |                          node.output().end());
398 |       const_nodes.push_back(node);
399 |     } else {
400 |       non_const_nodes.push_back(node);
401 |     }
402 |   }
403 |   return {const_nodes, non_const_nodes};
404 | }
405 | 
406 | onnx::ModelProto _InferShapes(const onnx::ModelProto& model) {
407 |   onnx::ModelProto result;
408 |   result.CopyFrom(model);
409 |   onnx::shape_inference::InferShapes(result);
410 |   return result;
411 | }
412 | 
413 | onnx::ModelProto _FoldConstant(const onnx::ModelProto& model) {
414 |   const auto& tmp = model;
415 |   {
416 |     onnx::ModelProto model;
417 |     model.CopyFrom(tmp);
418 |     auto [const_nodes, non_const_nodes] = GetConstantNodes(model);
419 |     for (const auto& x : const_nodes) {
420 |       try {
421 |         RunOpAndAddInitializer(model, x);
422 |       } catch (const std::exception& e) {
423 |         std::cerr << "WARNING: failed to run \"" << x.op_type() <<
424 |             "\" op (name is \"" << x.name() << "\"), skip..." << std::endl;
425 |         non_const_nodes.push_back(x);
426 |       }
427 |     }
428 |     model.mutable_graph()->clear_node();
429 |     for (const auto& x : non_const_nodes) {
430 |       *model.mutable_graph()->add_node() = x;
431 |     }
432 |     return model;
433 |   }
434 | }
435 | 
436 | onnx::ModelProto Optimize(const onnx::ModelProto& model) {
437 |   return onnx::optimization::OptimizeFixed(model, config.optimizer_passes);
438 | }
439 | 
440 | template <typename T>
441 | std::function<T(const T&)> FixedPointFn(const std::function<T(const T&)>& f1,
442 |                                         const std::function<T(const T&)>& f2,
443 |                                         size_t max_iters, bool* converged) {
444 |   return [f1, f2, max_iters, converged](const T& x) {
445 |     size_t _max_iters = max_iters;
446 |     T tmp1 = f1(x);
447 |     T tmp2 = f2(tmp1);
448 |     T& y1 = tmp1;
449 |     T& y2 = tmp2;
450 |     while (_max_iters-- > 0) {
451 |       if (google::protobuf::util::MessageDifferencer::Equals(y1, y2)) {
452 |         if (converged) {
453 |           *converged = true;
454 |         }
455 |         return y2;
456 |       }
457 |       y1 = f1(y2);
458 |       if (google::protobuf::util::MessageDifferencer::Equals(y1, y2)) {
459 |         if (converged) {
460 |           *converged = true;
461 |         }
462 |         return y1;
463 |       }
464 |       y2 = f2(y1);
465 |     }
466 | 
467 |     if (converged) {
468 |       *converged = false;
469 |     }
470 |     return y2;
471 |   };
472 | }
473 | 
474 | template <typename T>
475 | std::function<T(const T&)> FixedPointFn(const std::function<T(const T&)>& f1,
476 |                                         const std::function<T(const T&)>& f2,
477 |                                         size_t max_iters) {
478 |   return FixedPointFn(f1, f2, max_iters, nullptr);
479 | }
480 | 
481 | onnx::ModelProto Identity(const onnx::ModelProto& model) { return model; }
482 | 
483 | void Check(const onnx::ModelProto& model) { onnx::checker::check_model(model); }
484 | 
485 | onnx::ModelProto Simplify(
486 |     const onnx::ModelProto& model,
487 |     std::optional<std::vector<std::string>> skip_optimizers,
488 |     bool constant_folding, bool shape_inference, size_t tensor_size_threshold) {
489 |   Check(model);
490 | 
491 |   config.tensor_size_threshold = tensor_size_threshold;
492 |   config.optimizer_passes.clear();
493 |   // skip_optimizers == nullopt means skipping all optimizers, so
494 |   // config.optimizer_passes is empty
495 |   if (skip_optimizers) {
496 |     std::vector<std::string> passes;
497 |     const auto all_passes = onnx::optimization::GetFuseAndEliminationPass();
498 |     for (const auto& pass : all_passes) {
499 |       if (std::find(skip_optimizers->begin(), skip_optimizers->end(), pass) ==
500 |           skip_optimizers->end()) {
501 |         passes.push_back(pass);
502 |       }
503 |     }
504 |     config.optimizer_passes = passes;
505 |   }
506 | 
507 |   auto FoldConstant = constant_folding ? _FoldConstant : Identity;
508 |   auto InferShapes = shape_inference ? _InferShapes : Identity;
509 | 
510 |   int fixed_point_iters =
511 |       std::getenv("ONNXSIM_FIXED_POINT_ITERS")
512 |           ? std::atoi(std::getenv("ONNXSIM_FIXED_POINT_ITERS"))
513 |           : 50;
514 | 
515 |   auto OptAndShape = FixedPointFn(std::function{InferShapes},
516 |                                   std::function{Optimize}, fixed_point_iters);
517 |   bool converged = false;
518 |   auto OptAndShapeAndFold =
519 |       FixedPointFn(std::function{OptAndShape}, std::function{FoldConstant},
520 |                    fixed_point_iters, &converged);
521 |   auto sim_model = OptAndShapeAndFold(model);
522 |   Check(sim_model);
523 |   if (!converged) {
524 |     std::cout << "WARNING: the simplification stopped because of timeout. "
525 |                  "Please set environment variable `ONNXSIM_FIXED_POINT_ITERS` "
526 |                  "to a number higher than "
527 |               << fixed_point_iters << " if you want further simplification."
528 | << std::endl; 529 | } 530 | return sim_model; 531 | } 532 | 533 | void SimplifyPath(const std::string& in_path, const std::string& out_path, 534 | std::optional> skip_optimizers, 535 | bool constant_folding, bool shape_inference, 536 | size_t tensor_size_threshold) { 537 | onnx::ModelProto model; 538 | onnx::optimization::loadModel(&model, in_path, true); 539 | 540 | model = Simplify(model, skip_optimizers, constant_folding, shape_inference, 541 | tensor_size_threshold); 542 | 543 | onnx::optimization::saveModel(&model, out_path, true, ""); 544 | } 545 | -------------------------------------------------------------------------------- /onnxsim/onnx_simplifier.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import copy 4 | import os 5 | import sys 6 | import re 7 | import tempfile 8 | from typing import List, Dict, Union, Optional, Tuple, Sequence 9 | from rich.text import Text 10 | from rich import print 11 | import numpy as np 12 | 13 | import onnx # type: ignore 14 | import onnx.checker # type: ignore 15 | import onnx.helper # type: ignore 16 | import onnx.shape_inference # type: ignore 17 | import onnx.numpy_helper # type: ignore 18 | try: 19 | import onnxruntime as rt # type: ignore 20 | except ImportError: 21 | command = [sys.executable, '-m', 'pip', 'install', 'onnxruntime'] 22 | print(Text(f"Installing onnxruntime by `{' '.join(command)}`, please wait for a moment..", style="bold magenta")) 23 | import subprocess 24 | subprocess.check_call(command) 25 | import onnxruntime as rt 26 | 27 | 28 | import onnxsim.onnxsim_cpp2py_export as C 29 | from . import model_info 30 | from . import model_checking 31 | from . import version 32 | 33 | 34 | TensorShape = List[int] 35 | TensorShapes = Dict[str, TensorShape] 36 | TensorShapesWithOptionalKey = Dict[Optional[str], TensorShape] 37 | 38 | 39 | def get_output_names(model: onnx.ModelProto) -> List[str]: 40 | output_names = [opt.name for opt in model.graph.output] 41 | return output_names 42 | 43 | 44 | def remove_unused_output( 45 | model: onnx.ModelProto, unused_output: Sequence[str] 46 | ) -> onnx.ModelProto: 47 | unused_output_names = unused_output 48 | output_names = get_output_names(model) 49 | for unused_output_name in unused_output_names: 50 | if unused_output_name not in output_names: 51 | raise RuntimeError( 52 | f'The model doesn\'t have output named "{unused_output_name}"' 53 | ) 54 | for graph_output in copy.deepcopy(model.graph.output): 55 | if graph_output.name in unused_output_names: 56 | model.graph.output.remove(graph_output) 57 | return model 58 | 59 | 60 | def remove_initializer_from_input(model: onnx.ModelProto) -> onnx.ModelProto: 61 | initializer_names = [x.name for x in model.graph.initializer] 62 | for graph_input in copy.deepcopy(model.graph.input): 63 | if graph_input.name in initializer_names: 64 | model.graph.input.remove(graph_input) 65 | return model 66 | 67 | 68 | def check_and_update_input_shapes(model: onnx.ModelProto, input_shapes: Optional[TensorShapesWithOptionalKey]) -> Optional[TensorShapes]: 69 | if input_shapes is None: 70 | return None 71 | 72 | def get_inputs(model: onnx.ModelProto) -> List[onnx.ValueInfoProto]: 73 | initializer_names = [x.name for x in model.graph.initializer] 74 | return [ipt for ipt in model.graph.input if ipt.name not in initializer_names] 75 | 76 | def get_input_names(model: onnx.ModelProto) -> List[str]: 77 | input_names = [ipt.name for ipt in get_inputs(model)] 78 | return input_names 79 | 80 | input_names = 
get_input_names(model) 81 | if None in input_shapes: 82 | if len(input_names) == 1: 83 | input_shapes[input_names[0]] = input_shapes[None] 84 | del input_shapes[None] 85 | else: 86 | raise RuntimeError( 87 | 'The model has more than 1 inputs, please use the format "input_name:dim0,dim1,...,dimN" in --input-shape') 88 | for x in input_shapes: 89 | if x not in input_names: 90 | raise RuntimeError( 91 | 'The model doesn\'t have input named "{}"'.format(x)) 92 | 93 | return input_shapes # type: ignore 94 | 95 | 96 | # A very very large threshold 97 | DEFAULT_TENSOR_SIZE_THRESHOLDHOLD = '1.5GB' 98 | 99 | 100 | def simplify( 101 | model: Union[str, onnx.ModelProto], 102 | check_n: int = 0, 103 | perform_optimization: bool = True, 104 | skip_fuse_bn: bool = False, 105 | overwrite_input_shapes=None, 106 | test_input_shapes=None, 107 | skipped_optimizers: Optional[List[str]] = None, 108 | skip_constant_folding=False, 109 | skip_shape_inference=False, 110 | input_data=None, 111 | dynamic_input_shape: bool = False, 112 | custom_lib: Optional[str] = None, 113 | include_subgraph: bool = False, 114 | unused_output: Optional[Sequence[str]] = None, 115 | tensor_size_threshold: str = DEFAULT_TENSOR_SIZE_THRESHOLDHOLD, 116 | mutable_initializer: bool = False, 117 | *, 118 | input_shapes=None, 119 | ) -> Tuple[onnx.ModelProto, bool]: 120 | """ 121 | :param model: onnx ModelProto object or file path 122 | :param check_n: The simplified model will be checked for `check_n` times by random inputs 123 | :param perform_optimization: Whether to run onnx optimizer on the model 124 | :param skip_fuse_bn: Skip fuse_bn_into_conv onnx optimizer 125 | :param overwrite_input_shapes: If the model has dynamic input shape, user must pass a fixed input shape 126 | for generating random inputs and checking equality. 127 | :param test_input_shapes: If the model has dynamic input shape, user must pass a fixed input shape 128 | for generating random inputs and checking equality. 129 | :param skipped_optimizers: Skip some specific onnx optimizers 130 | :param skip_constant_folding: Skip constant folding 131 | :param skip_shape_inference: Skip shape inference (sometimes shape inference will crash) 132 | :param input_data: Feed custom input data for checking if needed 133 | :param dynamic_input_shape: Deprecated. Not needed anymore. 134 | :param custom_lib: onnxruntime custom ops's shared library 135 | :param include_subgraph: Simplify subgraph (e.g. true graph and false graph of "If" operator) instead of only the main graph 136 | :param unused_output: name of unused outputs that will be eliminated from the model 137 | :param input_shapes: Deprecated. Please use `overwrite_input_shapes` and/or `test_input_shapes` instead. 138 | :return: A tuple (simplified model, success(True) or failed(False)) 139 | """ 140 | if dynamic_input_shape: 141 | print( 142 | Text( 143 | "WARNING: The argument `dynamic_input_shape=True` is not needed any more, onnxsim can now support dynamic input shapes natively, please refer to the latest documentation. An error will be raised in the future.", 144 | style="bold red", 145 | ) 146 | ) 147 | if input_shapes is not None: 148 | print( 149 | Text( 150 | "WARNING: The argument `input_shapes` is deprecated. Please use `overwrite_input_shapes` and/or `test_input_shapes` instead. 
An error will be raised in the future.", 151 | style="bold red", 152 | ) 153 | ) 154 | overwrite_input_shapes = input_shapes 155 | test_input_shapes = input_shapes 156 | 157 | if not perform_optimization: 158 | # None means skip all optimizers 159 | skipped_optimizers = None 160 | elif skipped_optimizers is None: 161 | skipped_optimizers = [] 162 | 163 | if skip_fuse_bn and skipped_optimizers is not None: 164 | skipped_optimizers.append("fuse_bn_into_conv") 165 | if isinstance(model, str): 166 | model = onnx.load(model) 167 | if overwrite_input_shapes is None: 168 | overwrite_input_shapes = {} 169 | overwrite_input_shapes = check_and_update_input_shapes( 170 | model, overwrite_input_shapes) 171 | test_input_shapes = check_and_update_input_shapes( 172 | model, test_input_shapes) 173 | 174 | for name, input_shape in overwrite_input_shapes.items(): 175 | for ipt in model.graph.input: 176 | if ipt.name == name: 177 | for i, dim in enumerate(ipt.type.tensor_type.shape.dim): 178 | dim.dim_value = input_shape[i] 179 | if unused_output is not None: 180 | model = remove_unused_output(model, unused_output) 181 | if not mutable_initializer and model.ir_version >= 4: 182 | model = remove_initializer_from_input(model) 183 | 184 | # https://stackoverflow.com/a/60708339 185 | def parse_size(size: str) -> int: 186 | units = {"B": 1, "KB": 2**10, "MB": 2**20, "GB": 2**30, "TB": 2**40} 187 | size = size.upper() 188 | if not re.match(r' ', size): 189 | size = re.sub(r'([KMGT]?B)', r' \1', size) 190 | number, unit = [string.strip() for string in size.split()] 191 | return int(float(number)*units[unit]) 192 | 193 | tensor_size_threshold = parse_size(tensor_size_threshold) 194 | if tensor_size_threshold > 2**31 - 9999: 195 | raise ValueError("tensor_size_threshold should be less than 2GB") 196 | 197 | try: 198 | model_bytes = model.SerializeToString() 199 | model_opt_bytes = C.simplify( 200 | model_bytes, 201 | skipped_optimizers, 202 | not skip_constant_folding, 203 | not skip_shape_inference, 204 | tensor_size_threshold, 205 | ) 206 | if len(model_opt_bytes) == 0: 207 | raise ValueError("Simplified model larger than 2GB") 208 | model_opt = onnx.load_from_string(model_opt_bytes) 209 | check_ok = model_checking.compare( 210 | model_opt, model, check_n, test_input_shapes, input_data, custom_lib 211 | ) 212 | except (ValueError, onnx.onnx_cpp2py_export.checker.ValidationError): 213 | print("[bold magenta]Simplified model larger than 2GB. 
Trying to save as external data...[/bold magenta]") 214 | # large models try to convert through a temporary file 215 | with tempfile.TemporaryDirectory() as tmpdirname: 216 | onnx.save( 217 | copy.deepcopy(model), 218 | os.path.join(tmpdirname, 'model.onnx'), 219 | save_as_external_data=True, 220 | ) 221 | check_ok = C.simplify_path( 222 | os.path.join(tmpdirname, 'model.onnx'), 223 | os.path.join(tmpdirname, 'opt.onnx'), 224 | skipped_optimizers, 225 | not skip_constant_folding, 226 | not skip_shape_inference, 227 | tensor_size_threshold, 228 | ) 229 | check_ok = model_checking.compare( 230 | os.path.join(tmpdirname, 'opt.onnx'), 231 | os.path.join(tmpdirname, 'model.onnx'), 232 | check_n, test_input_shapes, input_data, custom_lib 233 | ) 234 | model_opt = onnx.load(os.path.join(tmpdirname, 'opt.onnx')) 235 | return model_opt, check_ok 236 | 237 | 238 | class PyModelExecutor(C.ModelExecutor): 239 | def Run(self, model_str: str, inputs_str: List[str]): 240 | model = onnx.ModelProto() 241 | model.ParseFromString(model_str) 242 | 243 | def deserialize_tp(tp_str): 244 | tp = onnx.TensorProto() 245 | tp.ParseFromString(tp_str) 246 | return tp 247 | 248 | input_tps = map(deserialize_tp, inputs_str) 249 | input_arrs = map(onnx.numpy_helper.to_array, input_tps) 250 | input_names = [x.name for x in model.graph.input] 251 | inputs = dict(zip(input_names, input_arrs)) 252 | sess_options = rt.SessionOptions() 253 | sess_options.graph_optimization_level = rt.GraphOptimizationLevel(0) 254 | sess_options.log_severity_level = 3 255 | sess = rt.InferenceSession( 256 | model.SerializeToString(), 257 | sess_options=sess_options, 258 | providers=["CPUExecutionProvider"], 259 | ) 260 | output_names = [x.name for x in sess.get_outputs()] 261 | run_options = rt.RunOptions() 262 | run_options.log_severity_level = 3 263 | output_arrs = sess.run(output_names, inputs, run_options=run_options) 264 | return [ 265 | onnx.numpy_helper.from_array(x).SerializeToString() for x in output_arrs 266 | ] 267 | 268 | 269 | def main(): 270 | parser = argparse.ArgumentParser() 271 | parser.add_argument("input_model", help="Input ONNX model") 272 | parser.add_argument("output_model", help="Output ONNX model") 273 | parser.add_argument( 274 | "check_n", 275 | help="Check whether the output is correct with n random inputs", 276 | nargs="?", 277 | type=int, 278 | default=0, 279 | ) 280 | parser.add_argument( 281 | "--enable-fuse-bn", 282 | help="This option is deprecated. Fusing bn into conv is enabled by default.", 283 | action="store_true", 284 | ) 285 | parser.add_argument( 286 | "--skip-fuse-bn", help="Skip fusing batchnorm into conv.", action="store_true" 287 | ) 288 | parser.add_argument( 289 | "--skip-optimization", 290 | help="Skip all ONNX optimizers or some of them. To skip all optimizers, use `onnxsim a.onnx b.onnx --skip-optimization`. To skip some of optimizers, use something like `onnxsim a.onnx b.onnx --skip-optimization fuse_bn_into_conv fuse_pad_into_pool`.", 291 | type=str, 292 | nargs="*", 293 | ) 294 | parser.add_argument("--skip-constant-folding", help="Skip constant folding", action="store_true") 295 | parser.add_argument( 296 | "--input-shape", 297 | help="This argument has been renamed to --overwrite-input-shape, please refer to it", 298 | type=str, 299 | nargs="+", 300 | ) 301 | parser.add_argument( 302 | "--overwrite-input-shape", 303 | help='Overwrite the input shape. 
The format is "input_name:dim0,dim1,...,dimN" or simply "dim0,dim1,...,dimN" when there is only one input, for example, "data:1,3,224,224" or "1,3,224,224". Note: you might want to use some visualization tools like netron to make sure what the input name and dimension ordering (NCHW or NHWC) is.', 304 | type=str, 305 | nargs="+", 306 | ) 307 | parser.add_argument( 308 | "--test-input-shape", 309 | help='The input shape to generated random inputs for test, useful when the input shape is dynamic. The format is "input_name:dim0,dim1,...,dimN" or simply "dim0,dim1,...,dimN" when there is only one input, for example, "data:1,3,224,224" or "1,3,224,224". Note: you might want to use some visualization tools like netron to make sure what the input name and dimension ordering (NCHW or NHWC) is.', 310 | type=str, 311 | nargs="+", 312 | ) 313 | parser.add_argument( 314 | "--skip-optimizer", 315 | help="Deprecated. Refer to --skip-optimization", 316 | type=str, 317 | nargs="+", 318 | ) 319 | parser.add_argument( 320 | "--skip-shape-inference", help="Skip shape inference", action="store_true" 321 | ) 322 | parser.add_argument( 323 | "--enable-onnxruntime-optimization", 324 | help="Enable ONNX Runtime's ORT_ENABLE_BASIC level optimization.", 325 | action="store_true", 326 | ) 327 | parser.add_argument( 328 | "--dynamic-input-shape", 329 | help="Deprecated. Not needed any more.", 330 | action="store_true", 331 | ) 332 | parser.add_argument( 333 | "--input-data-path", 334 | help='input data, The value should be "input_name1:xxx1.bin" "input_name2:xxx2.bin ...", input data should be a binary data file.', 335 | type=str, 336 | nargs="+", 337 | ) 338 | parser.add_argument( 339 | "--custom-lib", help="Deprecated. Not needed any more.", type=str 340 | ) 341 | parser.add_argument( 342 | "--include-subgraph", 343 | help='Experimental feature. Simplify subgraph (e.g. true graph and false graph of "If" operator) instead of only the main graph', 344 | action="store_true", 345 | ) 346 | parser.add_argument( 347 | "--unused-output", 348 | help="Name of unused outputs that will be eliminated from the model", 349 | type=str, 350 | nargs="+", 351 | ) 352 | parser.add_argument( 353 | "--no-large-tensor", 354 | help="Some ops like Tile and ConstantOfShape can produce large tensor and make the model size much larger. Specifying this flag to skip folding these ops, with loss of some optimization chances. It can be followed with a threshold, for example, --no-large-tensor 1M or --no-large-tensor 100KB. A simple '--no-large-tensor' means '--no-large-tensor 1KB'.", 355 | type=str, 356 | const='1KB', 357 | default=DEFAULT_TENSOR_SIZE_THRESHOLDHOLD, 358 | nargs="?", 359 | dest="tensor_size_threshold", 360 | ) 361 | parser.add_argument( 362 | "--mutable-initializer", 363 | help="By ONNX specification, initializers can also serve as inputs. This allows users to overwrite their values during runtime, but some useful optimizations like fuse-conv-and-bn will not be applicable anymore. In almost all cases, having an initializer that is also an input is unintended (usually caused by a out-dated PyTorch). So onnxsim treats all initializers immutable to enabling all optimizations. If it is not wanted, you can specify '--mutable-initializer' to disable this behavior.", 364 | action="store_true", 365 | ) 366 | parser.add_argument( 367 | "--save-as-external-data", 368 | help="Save parameters as external data. 
This will make the .onnx file much smaller, but the .onnx file will depend on the external data file (.data).", 369 | action="store_true", 370 | ) 371 | parser.add_argument('-v', '--version', action='version', version='onnxsim ' + version.version) 372 | 373 | args = parser.parse_args() 374 | 375 | if args.enable_fuse_bn: 376 | print( 377 | Text( 378 | 'WARNING: "--enable-fuse-bn" is not needed any more, because fuse bn is enabled by default. "--enable-fuse-bn" flag is ignored now and will raise an error in the future.', 379 | style="bold red", 380 | ) 381 | ) 382 | if args.dynamic_input_shape: 383 | print( 384 | Text( 385 | 'WARNING: "--dynamic-input-shape" is not needed any more, onnxsim v0.4 now handles dynamic input shapes automatically. "--dynamic-input-shape" flag is ignored now and will raise an error in the future.', 386 | style="bold red", 387 | ) 388 | ) 389 | assert not (args.input_shape is not None and args.overwrite_input_shape is not None) 390 | if args.input_shape: 391 | print( 392 | Text( 393 | 'WARNING: "--input-shape" is renamed to "--overwrite-input-shape". Please use it instead.', 394 | style="bold red", 395 | ) 396 | ) 397 | args.overwrite_input_shape = args.input_shape 398 | if args.include_subgraph: 399 | print( 400 | Text( 401 | "WARNING: subgraph optimization is not supported in v0.4 for now.", 402 | style="bold red", 403 | ) 404 | ) 405 | assert not (args.skip_optimizer is not None and args.skip_optimization is not None) 406 | if args.skip_optimizer: 407 | print( 408 | Text( 409 | 'WARNING: "--skip-optimizer" is renamed to "--skip-optimization". Please use it instead.', 410 | style="bold red", 411 | ) 412 | ) 413 | args.skip_optimization = args.skip_optimizer 414 | if args.skip_optimization is None: 415 | # user doesn't specify --skip-optimization 416 | args.skip_optimization = [] 417 | elif len(args.skip_optimization) == 0: 418 | # user specify --skip-optimization without any certain optimizer name 419 | # set it to None means skip all optimizations 420 | args.skip_optimization = None 421 | if args.skip_fuse_bn and args.skip_optimization is not None: 422 | args.skip_optimization.append("fuse_bn_into_conv") 423 | 424 | perform_optimization = False if args.skip_optimization is None else True 425 | 426 | def parse_shapes(shapes_arg): 427 | shapes = {} 428 | if shapes_arg is not None: 429 | for x in shapes_arg: 430 | if ':' not in x: 431 | shapes[None] = list(map(int, x.split(','))) 432 | else: 433 | pieces = x.split(':') 434 | # for the input name like input:0 435 | name, shape = ':'.join( 436 | pieces[:-1]), list(map(int, pieces[-1].split(','))) 437 | shapes.update({name: shape}) 438 | return shapes 439 | 440 | test_input_shapes = parse_shapes(args.test_input_shape) 441 | overwrite_input_shapes = parse_shapes(args.overwrite_input_shape) 442 | 443 | if args.enable_onnxruntime_optimization: 444 | 445 | tmp_file = tempfile.NamedTemporaryFile() 446 | sess_options = rt.SessionOptions() 447 | # Set graph optimization level 448 | sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_BASIC 449 | # To enable model serialization after graph optimization 450 | sess_options.optimized_model_filepath = tmp_file.name 451 | _ = rt.InferenceSession(args.input_model, sess_options, providers=["CPUExecutionProvider"]) 452 | 453 | model = onnx.load(tmp_file.name) 454 | else: 455 | model = onnx.load(args.input_model) 456 | 457 | if args.tensor_size_threshold == DEFAULT_TENSOR_SIZE_THRESHOLDHOLD: 458 | for node in model.graph.node: 459 | if node.op_type in ["Tile", 
"ConstantOfShape"]: 460 | print( 461 | Text( 462 | 'Your model contains "Tile" ops or/and "ConstantOfShape" ops. Folding these ops can make the simplified model much larger. If it is not expected, please specify "--no-large-tensor" (which will lose some optimization chances)', 463 | style="bold magenta", 464 | ) 465 | ) 466 | break 467 | 468 | if not args.mutable_initializer: 469 | initializer_names = set([x.name for x in model.graph.initializer]) 470 | input_names = set([x.name for x in model.graph.input]) 471 | if len(initializer_names.intersection(input_names)) > 0: 472 | print( 473 | Text( 474 | 'Your model contains initializers that are also inputs. This is usually caused by an out-dated PyTorch. onnxsim treats all initializers immutable to enabling all optimizations. If it is not wanted, please specify "--mutable-initializer" to disable this behavior.', 475 | style="bold magenta", 476 | ) 477 | ) 478 | 479 | input_tensors = None 480 | if args.input_data_path is not None: 481 | input_tensors = {} 482 | for x in args.input_data_path: 483 | pieces = x.split(':') 484 | name, data = ':'.join(pieces[:-1]), pieces[-1] 485 | input_tensors.update({name: np.load(data)}) 486 | 487 | print("Simplifying...") 488 | 489 | model_opt, check_ok = simplify( 490 | model, 491 | args.check_n, 492 | perform_optimization, 493 | False, 494 | overwrite_input_shapes, 495 | test_input_shapes, 496 | args.skip_optimization, 497 | args.skip_constant_folding, 498 | args.skip_shape_inference, 499 | input_tensors, 500 | False, 501 | args.custom_lib, 502 | args.include_subgraph, 503 | args.unused_output, 504 | args.tensor_size_threshold, 505 | args.mutable_initializer, 506 | ) 507 | 508 | try: 509 | if not args.save_as_external_data: 510 | onnx.save(model_opt, args.output_model) 511 | else: 512 | raise ValueError("save_as_external_data") 513 | except ValueError: 514 | # large models (>2GB) which onnx.save doesn't support, 515 | # or explicitly specified --save-as-external-data 516 | external_data_path = os.path.basename(args.output_model) + '.data' 517 | if os.path.exists(external_data_path): 518 | os.remove(external_data_path) 519 | onnx.save( 520 | copy.deepcopy(model_opt), 521 | args.output_model, 522 | save_as_external_data=True, 523 | all_tensors_to_one_file=True, 524 | location=external_data_path, 525 | ) 526 | 527 | if check_ok: 528 | print("Finish! Here is the difference:") 529 | model_info.print_simplifying_info(model, model_opt) 530 | else: 531 | print( 532 | 'Check failed. Please be careful to use the simplified model, or try specifying "--skip-fuse-bn" or "--skip-optimization" (run "onnxsim -h" for details).' 
533 |         )
534 |         print("Here is the difference after simplification:")
535 |         model_info.print_simplifying_info(model, model_opt)
536 |         sys.exit(1)
537 | 
--------------------------------------------------------------------------------
/onnxsim/cxxopts.hpp:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright (c) 2014, 2015, 2016, 2017 Jarryd Beck
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 | The above copyright notice and this permission notice shall be included in
10 | all copies or substantial portions of the Software.
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
17 | THE SOFTWARE.
18 | */
19 | 
20 | #ifndef CXXOPTS_HPP_INCLUDED
21 | #define CXXOPTS_HPP_INCLUDED
22 | 
23 | #include
24 | #include
25 | #include
26 | #include
27 | #include
28 | #include
29 | #include
30 | #include
31 | #include
32 | #include
33 | #include
34 | #include
35 | #include
36 | #include
37 | #include
38 | 
39 | #if defined(__GNUC__) && !defined(__clang__)
40 | # if (__GNUC__ * 10 + __GNUC_MINOR__) < 49
41 | # define CXXOPTS_NO_REGEX true
42 | # endif
43 | #endif
44 | 
45 | #ifndef CXXOPTS_NO_REGEX
46 | # include <regex>
47 | #endif // CXXOPTS_NO_REGEX
48 | 
49 | // Nonstandard before C++17, which is coincidentally what we also need for <optional>
50 | #ifdef __has_include
51 | # if __has_include(<optional>)
52 | # include <optional>
53 | # ifdef __cpp_lib_optional
54 | # define CXXOPTS_HAS_OPTIONAL
55 | # endif
56 | # endif
57 | #endif
58 | 
59 | #if __cplusplus >= 201603L
60 | #define CXXOPTS_NODISCARD [[nodiscard]]
61 | #else
62 | #define CXXOPTS_NODISCARD
63 | #endif
64 | 
65 | #ifndef CXXOPTS_VECTOR_DELIMITER
66 | #define CXXOPTS_VECTOR_DELIMITER ','
67 | #endif
68 | 
69 | #define CXXOPTS__VERSION_MAJOR 3
70 | #define CXXOPTS__VERSION_MINOR 0
71 | #define CXXOPTS__VERSION_PATCH 0
72 | 
73 | #if (__GNUC__ < 10 || (__GNUC__ == 10 && __GNUC_MINOR__ < 1)) && __GNUC__ >= 6
74 | #define CXXOPTS_NULL_DEREF_IGNORE
75 | #endif
76 | 
77 | namespace cxxopts
78 | {
79 | static constexpr struct {
80 |   uint8_t major, minor, patch;
81 | } version = {
82 |   CXXOPTS__VERSION_MAJOR,
83 |   CXXOPTS__VERSION_MINOR,
84 |   CXXOPTS__VERSION_PATCH
85 | };
86 | } // namespace cxxopts
87 | 
88 | //when we ask cxxopts to use Unicode, help strings are processed using ICU,
89 | //which results in the correct lengths being computed for strings when they
90 | //are formatted for the help output
91 | //it is necessary to make sure that <unicode/unistr.h> can be found by the
92 | //compiler, and that icu-uc is linked in to the binary.
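//
// For illustration, a Unicode-enabled build might look roughly like the
// following (main.cpp is a hypothetical consumer of this header; the exact
// ICU flags vary by platform and usually come from pkg-config's icu-uc
// module):
//
//   g++ -DCXXOPTS_USE_UNICODE $(pkg-config --cflags icu-uc) \
//       main.cpp $(pkg-config --libs icu-uc)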
93 | 94 | #ifdef CXXOPTS_USE_UNICODE 95 | #include 96 | 97 | namespace cxxopts 98 | { 99 | using String = icu::UnicodeString; 100 | 101 | inline 102 | String 103 | toLocalString(std::string s) 104 | { 105 | return icu::UnicodeString::fromUTF8(std::move(s)); 106 | } 107 | 108 | #if defined(__GNUC__) 109 | // GNU GCC with -Weffc++ will issue a warning regarding the upcoming class, we want to silence it: 110 | // warning: base class 'class std::enable_shared_from_this' has accessible non-virtual destructor 111 | #pragma GCC diagnostic push 112 | #pragma GCC diagnostic ignored "-Wnon-virtual-dtor" 113 | #pragma GCC diagnostic ignored "-Weffc++" 114 | // This will be ignored under other compilers like LLVM clang. 115 | #endif 116 | class UnicodeStringIterator : public 117 | std::iterator 118 | { 119 | public: 120 | 121 | UnicodeStringIterator(const icu::UnicodeString* string, int32_t pos) 122 | : s(string) 123 | , i(pos) 124 | { 125 | } 126 | 127 | value_type 128 | operator*() const 129 | { 130 | return s->char32At(i); 131 | } 132 | 133 | bool 134 | operator==(const UnicodeStringIterator& rhs) const 135 | { 136 | return s == rhs.s && i == rhs.i; 137 | } 138 | 139 | bool 140 | operator!=(const UnicodeStringIterator& rhs) const 141 | { 142 | return !(*this == rhs); 143 | } 144 | 145 | UnicodeStringIterator& 146 | operator++() 147 | { 148 | ++i; 149 | return *this; 150 | } 151 | 152 | UnicodeStringIterator 153 | operator+(int32_t v) 154 | { 155 | return UnicodeStringIterator(s, i + v); 156 | } 157 | 158 | private: 159 | const icu::UnicodeString* s; 160 | int32_t i; 161 | }; 162 | #if defined(__GNUC__) 163 | #pragma GCC diagnostic pop 164 | #endif 165 | 166 | inline 167 | String& 168 | stringAppend(String&s, String a) 169 | { 170 | return s.append(std::move(a)); 171 | } 172 | 173 | inline 174 | String& 175 | stringAppend(String& s, size_t n, UChar32 c) 176 | { 177 | for (size_t i = 0; i != n; ++i) 178 | { 179 | s.append(c); 180 | } 181 | 182 | return s; 183 | } 184 | 185 | template 186 | String& 187 | stringAppend(String& s, Iterator begin, Iterator end) 188 | { 189 | while (begin != end) 190 | { 191 | s.append(*begin); 192 | ++begin; 193 | } 194 | 195 | return s; 196 | } 197 | 198 | inline 199 | size_t 200 | stringLength(const String& s) 201 | { 202 | return s.length(); 203 | } 204 | 205 | inline 206 | std::string 207 | toUTF8String(const String& s) 208 | { 209 | std::string result; 210 | s.toUTF8String(result); 211 | 212 | return result; 213 | } 214 | 215 | inline 216 | bool 217 | empty(const String& s) 218 | { 219 | return s.isEmpty(); 220 | } 221 | } 222 | 223 | namespace std 224 | { 225 | inline 226 | cxxopts::UnicodeStringIterator 227 | begin(const icu::UnicodeString& s) 228 | { 229 | return cxxopts::UnicodeStringIterator(&s, 0); 230 | } 231 | 232 | inline 233 | cxxopts::UnicodeStringIterator 234 | end(const icu::UnicodeString& s) 235 | { 236 | return cxxopts::UnicodeStringIterator(&s, s.length()); 237 | } 238 | } 239 | 240 | //ifdef CXXOPTS_USE_UNICODE 241 | #else 242 | 243 | namespace cxxopts 244 | { 245 | using String = std::string; 246 | 247 | template 248 | T 249 | toLocalString(T&& t) 250 | { 251 | return std::forward(t); 252 | } 253 | 254 | inline 255 | size_t 256 | stringLength(const String& s) 257 | { 258 | return s.length(); 259 | } 260 | 261 | inline 262 | String& 263 | stringAppend(String&s, const String& a) 264 | { 265 | return s.append(a); 266 | } 267 | 268 | inline 269 | String& 270 | stringAppend(String& s, size_t n, char c) 271 | { 272 | return s.append(n, c); 273 | } 274 | 
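//
// Note: in this non-Unicode branch cxxopts::String is a plain std::string,
// so the stringLength() helper above counts bytes rather than code points.
// For non-ASCII help text this only affects column alignment in the
// formatted help output; option parsing itself is unaffected.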
275 | template 276 | String& 277 | stringAppend(String& s, Iterator begin, Iterator end) 278 | { 279 | return s.append(begin, end); 280 | } 281 | 282 | template 283 | std::string 284 | toUTF8String(T&& t) 285 | { 286 | return std::forward(t); 287 | } 288 | 289 | inline 290 | bool 291 | empty(const std::string& s) 292 | { 293 | return s.empty(); 294 | } 295 | } // namespace cxxopts 296 | 297 | //ifdef CXXOPTS_USE_UNICODE 298 | #endif 299 | 300 | namespace cxxopts 301 | { 302 | namespace 303 | { 304 | #ifdef _WIN32 305 | const std::string LQUOTE("\'"); 306 | const std::string RQUOTE("\'"); 307 | #else 308 | const std::string LQUOTE("‘"); 309 | const std::string RQUOTE("’"); 310 | #endif 311 | } // namespace 312 | 313 | #if defined(__GNUC__) 314 | // GNU GCC with -Weffc++ will issue a warning regarding the upcoming class, we want to silence it: 315 | // warning: base class 'class std::enable_shared_from_this' has accessible non-virtual destructor 316 | #pragma GCC diagnostic push 317 | #pragma GCC diagnostic ignored "-Wnon-virtual-dtor" 318 | #pragma GCC diagnostic ignored "-Weffc++" 319 | // This will be ignored under other compilers like LLVM clang. 320 | #endif 321 | class Value : public std::enable_shared_from_this 322 | { 323 | public: 324 | 325 | virtual ~Value() = default; 326 | 327 | virtual 328 | std::shared_ptr 329 | clone() const = 0; 330 | 331 | virtual void 332 | parse(const std::string& text) const = 0; 333 | 334 | virtual void 335 | parse() const = 0; 336 | 337 | virtual bool 338 | has_default() const = 0; 339 | 340 | virtual bool 341 | is_container() const = 0; 342 | 343 | virtual bool 344 | has_implicit() const = 0; 345 | 346 | virtual std::string 347 | get_default_value() const = 0; 348 | 349 | virtual std::string 350 | get_implicit_value() const = 0; 351 | 352 | virtual std::shared_ptr 353 | default_value(const std::string& value) = 0; 354 | 355 | virtual std::shared_ptr 356 | implicit_value(const std::string& value) = 0; 357 | 358 | virtual std::shared_ptr 359 | no_implicit_value() = 0; 360 | 361 | virtual bool 362 | is_boolean() const = 0; 363 | }; 364 | #if defined(__GNUC__) 365 | #pragma GCC diagnostic pop 366 | #endif 367 | class OptionException : public std::exception 368 | { 369 | public: 370 | explicit OptionException(std::string message) 371 | : m_message(std::move(message)) 372 | { 373 | } 374 | 375 | CXXOPTS_NODISCARD 376 | const char* 377 | what() const noexcept override 378 | { 379 | return m_message.c_str(); 380 | } 381 | 382 | private: 383 | std::string m_message; 384 | }; 385 | 386 | class OptionSpecException : public OptionException 387 | { 388 | public: 389 | 390 | explicit OptionSpecException(const std::string& message) 391 | : OptionException(message) 392 | { 393 | } 394 | }; 395 | 396 | class OptionParseException : public OptionException 397 | { 398 | public: 399 | explicit OptionParseException(const std::string& message) 400 | : OptionException(message) 401 | { 402 | } 403 | }; 404 | 405 | class option_exists_error : public OptionSpecException 406 | { 407 | public: 408 | explicit option_exists_error(const std::string& option) 409 | : OptionSpecException("Option " + LQUOTE + option + RQUOTE + " already exists") 410 | { 411 | } 412 | }; 413 | 414 | class invalid_option_format_error : public OptionSpecException 415 | { 416 | public: 417 | explicit invalid_option_format_error(const std::string& format) 418 | : OptionSpecException("Invalid option format " + LQUOTE + format + RQUOTE) 419 | { 420 | } 421 | }; 422 | 423 | class option_syntax_exception : public 
OptionParseException { 424 | public: 425 | explicit option_syntax_exception(const std::string& text) 426 | : OptionParseException("Argument " + LQUOTE + text + RQUOTE + 427 | " starts with a - but has incorrect syntax") 428 | { 429 | } 430 | }; 431 | 432 | class option_not_exists_exception : public OptionParseException 433 | { 434 | public: 435 | explicit option_not_exists_exception(const std::string& option) 436 | : OptionParseException("Option " + LQUOTE + option + RQUOTE + " does not exist") 437 | { 438 | } 439 | }; 440 | 441 | class missing_argument_exception : public OptionParseException 442 | { 443 | public: 444 | explicit missing_argument_exception(const std::string& option) 445 | : OptionParseException( 446 | "Option " + LQUOTE + option + RQUOTE + " is missing an argument" 447 | ) 448 | { 449 | } 450 | }; 451 | 452 | class option_requires_argument_exception : public OptionParseException 453 | { 454 | public: 455 | explicit option_requires_argument_exception(const std::string& option) 456 | : OptionParseException( 457 | "Option " + LQUOTE + option + RQUOTE + " requires an argument" 458 | ) 459 | { 460 | } 461 | }; 462 | 463 | class option_not_has_argument_exception : public OptionParseException 464 | { 465 | public: 466 | option_not_has_argument_exception 467 | ( 468 | const std::string& option, 469 | const std::string& arg 470 | ) 471 | : OptionParseException( 472 | "Option " + LQUOTE + option + RQUOTE + 473 | " does not take an argument, but argument " + 474 | LQUOTE + arg + RQUOTE + " given" 475 | ) 476 | { 477 | } 478 | }; 479 | 480 | class option_not_present_exception : public OptionParseException 481 | { 482 | public: 483 | explicit option_not_present_exception(const std::string& option) 484 | : OptionParseException("Option " + LQUOTE + option + RQUOTE + " not present") 485 | { 486 | } 487 | }; 488 | 489 | class option_has_no_value_exception : public OptionException 490 | { 491 | public: 492 | explicit option_has_no_value_exception(const std::string& option) 493 | : OptionException( 494 | !option.empty() ? 
495 | ("Option " + LQUOTE + option + RQUOTE + " has no value") : 496 | "Option has no value") 497 | { 498 | } 499 | }; 500 | 501 | class argument_incorrect_type : public OptionParseException 502 | { 503 | public: 504 | explicit argument_incorrect_type 505 | ( 506 | const std::string& arg 507 | ) 508 | : OptionParseException( 509 | "Argument " + LQUOTE + arg + RQUOTE + " failed to parse" 510 | ) 511 | { 512 | } 513 | }; 514 | 515 | class option_required_exception : public OptionParseException 516 | { 517 | public: 518 | explicit option_required_exception(const std::string& option) 519 | : OptionParseException( 520 | "Option " + LQUOTE + option + RQUOTE + " is required but not present" 521 | ) 522 | { 523 | } 524 | }; 525 | 526 | template 527 | void throw_or_mimic(const std::string& text) 528 | { 529 | static_assert(std::is_base_of::value, 530 | "throw_or_mimic only works on std::exception and " 531 | "deriving classes"); 532 | 533 | #ifndef CXXOPTS_NO_EXCEPTIONS 534 | // If CXXOPTS_NO_EXCEPTIONS is not defined, just throw 535 | throw T{text}; 536 | #else 537 | // Otherwise manually instantiate the exception, print what() to stderr, 538 | // and exit 539 | T exception{text}; 540 | std::cerr << exception.what() << std::endl; 541 | std::exit(EXIT_FAILURE); 542 | #endif 543 | } 544 | 545 | namespace values 546 | { 547 | namespace parser_tool 548 | { 549 | struct IntegerDesc 550 | { 551 | std::string negative = ""; 552 | std::string base = ""; 553 | std::string value = ""; 554 | }; 555 | struct ArguDesc { 556 | std::string arg_name = ""; 557 | bool grouping = false; 558 | bool set_value = false; 559 | std::string value = ""; 560 | }; 561 | #ifdef CXXOPTS_NO_REGEX 562 | inline IntegerDesc SplitInteger(const std::string &text) 563 | { 564 | if (text.empty()) 565 | { 566 | throw_or_mimic(text); 567 | } 568 | IntegerDesc desc; 569 | const char *pdata = text.c_str(); 570 | if (*pdata == '-') 571 | { 572 | pdata += 1; 573 | desc.negative = "-"; 574 | } 575 | if (strncmp(pdata, "0x", 2) == 0) 576 | { 577 | pdata += 2; 578 | desc.base = "0x"; 579 | } 580 | if (*pdata != '\0') 581 | { 582 | desc.value = std::string(pdata); 583 | } 584 | else 585 | { 586 | throw_or_mimic(text); 587 | } 588 | return desc; 589 | } 590 | 591 | inline bool IsTrueText(const std::string &text) 592 | { 593 | const char *pdata = text.c_str(); 594 | if (*pdata == 't' || *pdata == 'T') 595 | { 596 | pdata += 1; 597 | if (strncmp(pdata, "rue\0", 4) == 0) 598 | { 599 | return true; 600 | } 601 | } 602 | else if (strncmp(pdata, "1\0", 2) == 0) 603 | { 604 | return true; 605 | } 606 | return false; 607 | } 608 | 609 | inline bool IsFalseText(const std::string &text) 610 | { 611 | const char *pdata = text.c_str(); 612 | if (*pdata == 'f' || *pdata == 'F') 613 | { 614 | pdata += 1; 615 | if (strncmp(pdata, "alse\0", 5) == 0) 616 | { 617 | return true; 618 | } 619 | } 620 | else if (strncmp(pdata, "0\0", 2) == 0) 621 | { 622 | return true; 623 | } 624 | return false; 625 | } 626 | 627 | inline std::pair SplitSwitchDef(const std::string &text) 628 | { 629 | std::string short_sw, long_sw; 630 | const char *pdata = text.c_str(); 631 | if (isalnum(*pdata) && *(pdata + 1) == ',') { 632 | short_sw = std::string(1, *pdata); 633 | pdata += 2; 634 | } 635 | while (*pdata == ' ') { pdata += 1; } 636 | if (isalnum(*pdata)) { 637 | const char *store = pdata; 638 | pdata += 1; 639 | while (isalnum(*pdata) || *pdata == '-' || *pdata == '_') { 640 | pdata += 1; 641 | } 642 | if (*pdata == '\0') { 643 | long_sw = std::string(store, pdata - store); 644 | 
} else { 645 | throw_or_mimic(text); 646 | } 647 | } 648 | return std::pair(short_sw, long_sw); 649 | } 650 | 651 | inline ArguDesc ParseArgument(const char *arg, bool &matched) 652 | { 653 | ArguDesc argu_desc; 654 | const char *pdata = arg; 655 | matched = false; 656 | if (strncmp(pdata, "--", 2) == 0) 657 | { 658 | pdata += 2; 659 | if (isalnum(*pdata)) 660 | { 661 | argu_desc.arg_name.push_back(*pdata); 662 | pdata += 1; 663 | while (isalnum(*pdata) || *pdata == '-' || *pdata == '_') 664 | { 665 | argu_desc.arg_name.push_back(*pdata); 666 | pdata += 1; 667 | } 668 | if (argu_desc.arg_name.length() > 1) 669 | { 670 | if (*pdata == '=') 671 | { 672 | argu_desc.set_value = true; 673 | pdata += 1; 674 | if (*pdata != '\0') 675 | { 676 | argu_desc.value = std::string(pdata); 677 | } 678 | matched = true; 679 | } 680 | else if (*pdata == '\0') 681 | { 682 | matched = true; 683 | } 684 | } 685 | } 686 | } 687 | else if (strncmp(pdata, "-", 1) == 0) 688 | { 689 | pdata += 1; 690 | argu_desc.grouping = true; 691 | while (isalnum(*pdata)) 692 | { 693 | argu_desc.arg_name.push_back(*pdata); 694 | pdata += 1; 695 | } 696 | matched = !argu_desc.arg_name.empty() && *pdata == '\0'; 697 | } 698 | return argu_desc; 699 | } 700 | 701 | #else // CXXOPTS_NO_REGEX 702 | 703 | namespace 704 | { 705 | 706 | std::basic_regex integer_pattern 707 | ("(-)?(0x)?([0-9a-zA-Z]+)|((0x)?0)"); 708 | std::basic_regex truthy_pattern 709 | ("(t|T)(rue)?|1"); 710 | std::basic_regex falsy_pattern 711 | ("(f|F)(alse)?|0"); 712 | 713 | std::basic_regex option_matcher 714 | ("--([[:alnum:]][-_[:alnum:]]+)(=(.*))?|-([[:alnum:]]+)"); 715 | std::basic_regex option_specifier 716 | ("(([[:alnum:]]),)?[ ]*([[:alnum:]][-_[:alnum:]]*)?"); 717 | 718 | } // namespace 719 | 720 | inline IntegerDesc SplitInteger(const std::string &text) 721 | { 722 | std::smatch match; 723 | std::regex_match(text, match, integer_pattern); 724 | 725 | if (match.length() == 0) 726 | { 727 | throw_or_mimic(text); 728 | } 729 | 730 | IntegerDesc desc; 731 | desc.negative = match[1]; 732 | desc.base = match[2]; 733 | desc.value = match[3]; 734 | 735 | if (match.length(4) > 0) 736 | { 737 | desc.base = match[5]; 738 | desc.value = "0"; 739 | return desc; 740 | } 741 | 742 | return desc; 743 | } 744 | 745 | inline bool IsTrueText(const std::string &text) 746 | { 747 | std::smatch result; 748 | std::regex_match(text, result, truthy_pattern); 749 | return !result.empty(); 750 | } 751 | 752 | inline bool IsFalseText(const std::string &text) 753 | { 754 | std::smatch result; 755 | std::regex_match(text, result, falsy_pattern); 756 | return !result.empty(); 757 | } 758 | 759 | inline std::pair SplitSwitchDef(const std::string &text) 760 | { 761 | std::match_results result; 762 | std::regex_match(text.c_str(), result, option_specifier); 763 | if (result.empty()) 764 | { 765 | throw_or_mimic(text); 766 | } 767 | 768 | const std::string& short_sw = result[2]; 769 | const std::string& long_sw = result[3]; 770 | 771 | return std::pair(short_sw, long_sw); 772 | } 773 | 774 | inline ArguDesc ParseArgument(const char *arg, bool &matched) 775 | { 776 | std::match_results result; 777 | std::regex_match(arg, result, option_matcher); 778 | matched = !result.empty(); 779 | 780 | ArguDesc argu_desc; 781 | if (matched) { 782 | argu_desc.arg_name = result[1].str(); 783 | argu_desc.set_value = result[2].length() > 0; 784 | argu_desc.value = result[3].str(); 785 | if (result[4].length() > 0) 786 | { 787 | argu_desc.grouping = true; 788 | argu_desc.arg_name = result[4].str(); 789 | } 
790 | } 791 | 792 | return argu_desc; 793 | } 794 | 795 | #endif // CXXOPTS_NO_REGEX 796 | #undef CXXOPTS_NO_REGEX 797 | } 798 | 799 | namespace detail 800 | { 801 | template 802 | struct SignedCheck; 803 | 804 | template 805 | struct SignedCheck 806 | { 807 | template 808 | void 809 | operator()(bool negative, U u, const std::string& text) 810 | { 811 | if (negative) 812 | { 813 | if (u > static_cast((std::numeric_limits::min)())) 814 | { 815 | throw_or_mimic(text); 816 | } 817 | } 818 | else 819 | { 820 | if (u > static_cast((std::numeric_limits::max)())) 821 | { 822 | throw_or_mimic(text); 823 | } 824 | } 825 | } 826 | }; 827 | 828 | template 829 | struct SignedCheck 830 | { 831 | template 832 | void 833 | operator()(bool, U, const std::string&) const {} 834 | }; 835 | 836 | template 837 | void 838 | check_signed_range(bool negative, U value, const std::string& text) 839 | { 840 | SignedCheck::is_signed>()(negative, value, text); 841 | } 842 | } // namespace detail 843 | 844 | template 845 | void 846 | checked_negate(R& r, T&& t, const std::string&, std::true_type) 847 | { 848 | // if we got to here, then `t` is a positive number that fits into 849 | // `R`. So to avoid MSVC C4146, we first cast it to `R`. 850 | // See https://github.com/jarro2783/cxxopts/issues/62 for more details. 851 | r = static_cast(-static_cast(t-1)-1); 852 | } 853 | 854 | template 855 | void 856 | checked_negate(R&, T&&, const std::string& text, std::false_type) 857 | { 858 | throw_or_mimic(text); 859 | } 860 | 861 | template 862 | void 863 | integer_parser(const std::string& text, T& value) 864 | { 865 | parser_tool::IntegerDesc int_desc = parser_tool::SplitInteger(text); 866 | 867 | using US = typename std::make_unsigned::type; 868 | constexpr bool is_signed = std::numeric_limits::is_signed; 869 | 870 | const bool negative = int_desc.negative.length() > 0; 871 | const uint8_t base = int_desc.base.length() > 0 ? 
16 : 10; 872 | const std::string & value_match = int_desc.value; 873 | 874 | US result = 0; 875 | 876 | for (char ch : value_match) 877 | { 878 | US digit = 0; 879 | 880 | if (ch >= '0' && ch <= '9') 881 | { 882 | digit = static_cast(ch - '0'); 883 | } 884 | else if (base == 16 && ch >= 'a' && ch <= 'f') 885 | { 886 | digit = static_cast(ch - 'a' + 10); 887 | } 888 | else if (base == 16 && ch >= 'A' && ch <= 'F') 889 | { 890 | digit = static_cast(ch - 'A' + 10); 891 | } 892 | else 893 | { 894 | throw_or_mimic(text); 895 | } 896 | 897 | const US next = static_cast(result * base + digit); 898 | if (result > next) 899 | { 900 | throw_or_mimic(text); 901 | } 902 | 903 | result = next; 904 | } 905 | 906 | detail::check_signed_range(negative, result, text); 907 | 908 | if (negative) 909 | { 910 | checked_negate(value, result, text, std::integral_constant()); 911 | } 912 | else 913 | { 914 | value = static_cast(result); 915 | } 916 | } 917 | 918 | template 919 | void stringstream_parser(const std::string& text, T& value) 920 | { 921 | std::stringstream in(text); 922 | in >> value; 923 | if (!in) { 924 | throw_or_mimic(text); 925 | } 926 | } 927 | 928 | template ::value>::type* = nullptr 930 | > 931 | void parse_value(const std::string& text, T& value) 932 | { 933 | integer_parser(text, value); 934 | } 935 | 936 | inline 937 | void 938 | parse_value(const std::string& text, bool& value) 939 | { 940 | if (parser_tool::IsTrueText(text)) 941 | { 942 | value = true; 943 | return; 944 | } 945 | 946 | if (parser_tool::IsFalseText(text)) 947 | { 948 | value = false; 949 | return; 950 | } 951 | 952 | throw_or_mimic(text); 953 | } 954 | 955 | inline 956 | void 957 | parse_value(const std::string& text, std::string& value) 958 | { 959 | value = text; 960 | } 961 | 962 | // The fallback parser. It uses the stringstream parser to parse all types 963 | // that have not been overloaded explicitly. It has to be placed in the 964 | // source code before all other more specialized templates. 
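//
// For example (illustration only): any user-defined type with a stream
// extraction operator is automatically supported by this fallback. A
// hypothetical Point type could be parsed from a string such as "3,4":
//
//   struct Point { int x, y; };
//   std::istream& operator>>(std::istream& in, Point& p)
//   {
//     char sep;
//     return in >> p.x >> sep >> p.y;
//   }
//   // cxxopts::value<Point>() then routes "3,4" through stringstream_parser.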
965 | template ::value>::type* = nullptr 967 | > 968 | void 969 | parse_value(const std::string& text, T& value) { 970 | stringstream_parser(text, value); 971 | } 972 | 973 | template 974 | void 975 | parse_value(const std::string& text, std::vector& value) 976 | { 977 | if (text.empty()) { 978 | T v; 979 | parse_value(text, v); 980 | value.emplace_back(std::move(v)); 981 | return; 982 | } 983 | std::stringstream in(text); 984 | std::string token; 985 | while(!in.eof() && std::getline(in, token, CXXOPTS_VECTOR_DELIMITER)) { 986 | T v; 987 | parse_value(token, v); 988 | value.emplace_back(std::move(v)); 989 | } 990 | } 991 | 992 | #ifdef CXXOPTS_HAS_OPTIONAL 993 | template 994 | void 995 | parse_value(const std::string& text, std::optional& value) 996 | { 997 | T result; 998 | parse_value(text, result); 999 | value = std::move(result); 1000 | } 1001 | #endif 1002 | 1003 | inline 1004 | void parse_value(const std::string& text, char& c) 1005 | { 1006 | if (text.length() != 1) 1007 | { 1008 | throw_or_mimic(text); 1009 | } 1010 | 1011 | c = text[0]; 1012 | } 1013 | 1014 | template 1015 | struct type_is_container 1016 | { 1017 | static constexpr bool value = false; 1018 | }; 1019 | 1020 | template 1021 | struct type_is_container> 1022 | { 1023 | static constexpr bool value = true; 1024 | }; 1025 | 1026 | template 1027 | class abstract_value : public Value 1028 | { 1029 | using Self = abstract_value; 1030 | 1031 | public: 1032 | abstract_value() 1033 | : m_result(std::make_shared()) 1034 | , m_store(m_result.get()) 1035 | { 1036 | } 1037 | 1038 | explicit abstract_value(T* t) 1039 | : m_store(t) 1040 | { 1041 | } 1042 | 1043 | ~abstract_value() override = default; 1044 | 1045 | abstract_value& operator=(const abstract_value&) = default; 1046 | 1047 | abstract_value(const abstract_value& rhs) 1048 | { 1049 | if (rhs.m_result) 1050 | { 1051 | m_result = std::make_shared(); 1052 | m_store = m_result.get(); 1053 | } 1054 | else 1055 | { 1056 | m_store = rhs.m_store; 1057 | } 1058 | 1059 | m_default = rhs.m_default; 1060 | m_implicit = rhs.m_implicit; 1061 | m_default_value = rhs.m_default_value; 1062 | m_implicit_value = rhs.m_implicit_value; 1063 | } 1064 | 1065 | void 1066 | parse(const std::string& text) const override 1067 | { 1068 | parse_value(text, *m_store); 1069 | } 1070 | 1071 | bool 1072 | is_container() const override 1073 | { 1074 | return type_is_container::value; 1075 | } 1076 | 1077 | void 1078 | parse() const override 1079 | { 1080 | parse_value(m_default_value, *m_store); 1081 | } 1082 | 1083 | bool 1084 | has_default() const override 1085 | { 1086 | return m_default; 1087 | } 1088 | 1089 | bool 1090 | has_implicit() const override 1091 | { 1092 | return m_implicit; 1093 | } 1094 | 1095 | std::shared_ptr 1096 | default_value(const std::string& value) override 1097 | { 1098 | m_default = true; 1099 | m_default_value = value; 1100 | return shared_from_this(); 1101 | } 1102 | 1103 | std::shared_ptr 1104 | implicit_value(const std::string& value) override 1105 | { 1106 | m_implicit = true; 1107 | m_implicit_value = value; 1108 | return shared_from_this(); 1109 | } 1110 | 1111 | std::shared_ptr 1112 | no_implicit_value() override 1113 | { 1114 | m_implicit = false; 1115 | return shared_from_this(); 1116 | } 1117 | 1118 | std::string 1119 | get_default_value() const override 1120 | { 1121 | return m_default_value; 1122 | } 1123 | 1124 | std::string 1125 | get_implicit_value() const override 1126 | { 1127 | return m_implicit_value; 1128 | } 1129 | 1130 | bool 1131 | is_boolean() const 
override 1132 | { 1133 | return std::is_same::value; 1134 | } 1135 | 1136 | const T& 1137 | get() const 1138 | { 1139 | if (m_store == nullptr) 1140 | { 1141 | return *m_result; 1142 | } 1143 | return *m_store; 1144 | } 1145 | 1146 | protected: 1147 | std::shared_ptr m_result{}; 1148 | T* m_store{}; 1149 | 1150 | bool m_default = false; 1151 | bool m_implicit = false; 1152 | 1153 | std::string m_default_value{}; 1154 | std::string m_implicit_value{}; 1155 | }; 1156 | 1157 | template 1158 | class standard_value : public abstract_value 1159 | { 1160 | public: 1161 | using abstract_value::abstract_value; 1162 | 1163 | CXXOPTS_NODISCARD 1164 | std::shared_ptr 1165 | clone() const override 1166 | { 1167 | return std::make_shared>(*this); 1168 | } 1169 | }; 1170 | 1171 | template <> 1172 | class standard_value : public abstract_value 1173 | { 1174 | public: 1175 | ~standard_value() override = default; 1176 | 1177 | standard_value() 1178 | { 1179 | set_default_and_implicit(); 1180 | } 1181 | 1182 | explicit standard_value(bool* b) 1183 | : abstract_value(b) 1184 | { 1185 | set_default_and_implicit(); 1186 | } 1187 | 1188 | std::shared_ptr 1189 | clone() const override 1190 | { 1191 | return std::make_shared>(*this); 1192 | } 1193 | 1194 | private: 1195 | 1196 | void 1197 | set_default_and_implicit() 1198 | { 1199 | m_default = true; 1200 | m_default_value = "false"; 1201 | m_implicit = true; 1202 | m_implicit_value = "true"; 1203 | } 1204 | }; 1205 | } // namespace values 1206 | 1207 | template 1208 | std::shared_ptr 1209 | value() 1210 | { 1211 | return std::make_shared>(); 1212 | } 1213 | 1214 | template 1215 | std::shared_ptr 1216 | value(T& t) 1217 | { 1218 | return std::make_shared>(&t); 1219 | } 1220 | 1221 | class OptionAdder; 1222 | 1223 | class OptionDetails 1224 | { 1225 | public: 1226 | OptionDetails 1227 | ( 1228 | std::string short_, 1229 | std::string long_, 1230 | String desc, 1231 | std::shared_ptr val 1232 | ) 1233 | : m_short(std::move(short_)) 1234 | , m_long(std::move(long_)) 1235 | , m_desc(std::move(desc)) 1236 | , m_value(std::move(val)) 1237 | , m_count(0) 1238 | { 1239 | m_hash = std::hash{}(m_long + m_short); 1240 | } 1241 | 1242 | OptionDetails(const OptionDetails& rhs) 1243 | : m_desc(rhs.m_desc) 1244 | , m_value(rhs.m_value->clone()) 1245 | , m_count(rhs.m_count) 1246 | { 1247 | } 1248 | 1249 | OptionDetails(OptionDetails&& rhs) = default; 1250 | 1251 | CXXOPTS_NODISCARD 1252 | const String& 1253 | description() const 1254 | { 1255 | return m_desc; 1256 | } 1257 | 1258 | CXXOPTS_NODISCARD 1259 | const Value& 1260 | value() const { 1261 | return *m_value; 1262 | } 1263 | 1264 | CXXOPTS_NODISCARD 1265 | std::shared_ptr 1266 | make_storage() const 1267 | { 1268 | return m_value->clone(); 1269 | } 1270 | 1271 | CXXOPTS_NODISCARD 1272 | const std::string& 1273 | short_name() const 1274 | { 1275 | return m_short; 1276 | } 1277 | 1278 | CXXOPTS_NODISCARD 1279 | const std::string& 1280 | long_name() const 1281 | { 1282 | return m_long; 1283 | } 1284 | 1285 | CXXOPTS_NODISCARD 1286 | const std::string& 1287 | essential_name() const 1288 | { 1289 | return m_long.empty() ? 
m_short : m_long; 1290 | } 1291 | 1292 | size_t 1293 | hash() const 1294 | { 1295 | return m_hash; 1296 | } 1297 | 1298 | private: 1299 | std::string m_short{}; 1300 | std::string m_long{}; 1301 | String m_desc{}; 1302 | std::shared_ptr m_value{}; 1303 | int m_count; 1304 | 1305 | size_t m_hash{}; 1306 | }; 1307 | 1308 | struct HelpOptionDetails 1309 | { 1310 | std::string s; 1311 | std::string l; 1312 | String desc; 1313 | bool has_default; 1314 | std::string default_value; 1315 | bool has_implicit; 1316 | std::string implicit_value; 1317 | std::string arg_help; 1318 | bool is_container; 1319 | bool is_boolean; 1320 | }; 1321 | 1322 | struct HelpGroupDetails 1323 | { 1324 | std::string name{}; 1325 | std::string description{}; 1326 | std::vector options{}; 1327 | }; 1328 | 1329 | class OptionValue 1330 | { 1331 | public: 1332 | void 1333 | parse 1334 | ( 1335 | const std::shared_ptr& details, 1336 | const std::string& text 1337 | ) 1338 | { 1339 | ensure_value(details); 1340 | ++m_count; 1341 | m_value->parse(text); 1342 | m_long_name = &details->long_name(); 1343 | } 1344 | 1345 | void 1346 | parse_default(const std::shared_ptr& details) 1347 | { 1348 | ensure_value(details); 1349 | m_default = true; 1350 | m_long_name = &details->long_name(); 1351 | m_value->parse(); 1352 | } 1353 | 1354 | void 1355 | parse_no_value(const std::shared_ptr& details) 1356 | { 1357 | m_long_name = &details->long_name(); 1358 | } 1359 | 1360 | #if defined(CXXOPTS_NULL_DEREF_IGNORE) 1361 | #pragma GCC diagnostic push 1362 | #pragma GCC diagnostic ignored "-Wnull-dereference" 1363 | #endif 1364 | 1365 | CXXOPTS_NODISCARD 1366 | size_t 1367 | count() const noexcept 1368 | { 1369 | return m_count; 1370 | } 1371 | 1372 | #if defined(CXXOPTS_NULL_DEREF_IGNORE) 1373 | #pragma GCC diagnostic pop 1374 | #endif 1375 | 1376 | // TODO: maybe default options should count towards the number of arguments 1377 | CXXOPTS_NODISCARD 1378 | bool 1379 | has_default() const noexcept 1380 | { 1381 | return m_default; 1382 | } 1383 | 1384 | template 1385 | const T& 1386 | as() const 1387 | { 1388 | if (m_value == nullptr) { 1389 | throw_or_mimic( 1390 | m_long_name == nullptr ? "" : *m_long_name); 1391 | } 1392 | 1393 | #ifdef CXXOPTS_NO_RTTI 1394 | return static_cast&>(*m_value).get(); 1395 | #else 1396 | return dynamic_cast&>(*m_value).get(); 1397 | #endif 1398 | } 1399 | 1400 | private: 1401 | void 1402 | ensure_value(const std::shared_ptr& details) 1403 | { 1404 | if (m_value == nullptr) 1405 | { 1406 | m_value = details->make_storage(); 1407 | } 1408 | } 1409 | 1410 | 1411 | const std::string* m_long_name = nullptr; 1412 | // Holding this pointer is safe, since OptionValue's only exist in key-value pairs, 1413 | // where the key has the string we point to. 
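// m_value is created lazily by ensure_value() from the option's cloned
// storage, m_count records how many times the option appeared on the
// command line, and m_default marks values that came from a default
// rather than an explicit argument.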
1414 | std::shared_ptr m_value{}; 1415 | size_t m_count = 0; 1416 | bool m_default = false; 1417 | }; 1418 | 1419 | class KeyValue 1420 | { 1421 | public: 1422 | KeyValue(std::string key_, std::string value_) 1423 | : m_key(std::move(key_)) 1424 | , m_value(std::move(value_)) 1425 | { 1426 | } 1427 | 1428 | CXXOPTS_NODISCARD 1429 | const std::string& 1430 | key() const 1431 | { 1432 | return m_key; 1433 | } 1434 | 1435 | CXXOPTS_NODISCARD 1436 | const std::string& 1437 | value() const 1438 | { 1439 | return m_value; 1440 | } 1441 | 1442 | template 1443 | T 1444 | as() const 1445 | { 1446 | T result; 1447 | values::parse_value(m_value, result); 1448 | return result; 1449 | } 1450 | 1451 | private: 1452 | std::string m_key; 1453 | std::string m_value; 1454 | }; 1455 | 1456 | using ParsedHashMap = std::unordered_map; 1457 | using NameHashMap = std::unordered_map; 1458 | 1459 | class ParseResult 1460 | { 1461 | public: 1462 | class Iterator 1463 | { 1464 | public: 1465 | using iterator_category = std::forward_iterator_tag; 1466 | using value_type = KeyValue; 1467 | using difference_type = void; 1468 | using pointer = const KeyValue*; 1469 | using reference = const KeyValue&; 1470 | 1471 | Iterator() = default; 1472 | Iterator(const Iterator&) = default; 1473 | 1474 | Iterator(const ParseResult *pr, bool end=false) 1475 | : m_pr(pr) 1476 | , m_iter(end? pr->m_defaults.end(): pr->m_sequential.begin()) 1477 | { 1478 | } 1479 | 1480 | Iterator& operator++() 1481 | { 1482 | ++m_iter; 1483 | if(m_iter == m_pr->m_sequential.end()) 1484 | { 1485 | m_iter = m_pr->m_defaults.begin(); 1486 | return *this; 1487 | } 1488 | return *this; 1489 | } 1490 | 1491 | Iterator operator++(int) 1492 | { 1493 | Iterator retval = *this; 1494 | ++(*this); 1495 | return retval; 1496 | } 1497 | 1498 | bool operator==(const Iterator& other) const 1499 | { 1500 | return m_iter == other.m_iter; 1501 | } 1502 | 1503 | bool operator!=(const Iterator& other) const 1504 | { 1505 | return !(*this == other); 1506 | } 1507 | 1508 | const KeyValue& operator*() 1509 | { 1510 | return *m_iter; 1511 | } 1512 | 1513 | const KeyValue* operator->() 1514 | { 1515 | return m_iter.operator->(); 1516 | } 1517 | 1518 | private: 1519 | const ParseResult* m_pr; 1520 | std::vector::const_iterator m_iter; 1521 | }; 1522 | 1523 | ParseResult() = default; 1524 | ParseResult(const ParseResult&) = default; 1525 | 1526 | ParseResult(NameHashMap&& keys, ParsedHashMap&& values, std::vector sequential, 1527 | std::vector default_opts, std::vector&& unmatched_args) 1528 | : m_keys(std::move(keys)) 1529 | , m_values(std::move(values)) 1530 | , m_sequential(std::move(sequential)) 1531 | , m_defaults(std::move(default_opts)) 1532 | , m_unmatched(std::move(unmatched_args)) 1533 | { 1534 | } 1535 | 1536 | ParseResult& operator=(ParseResult&&) = default; 1537 | ParseResult& operator=(const ParseResult&) = default; 1538 | 1539 | Iterator 1540 | begin() const 1541 | { 1542 | return Iterator(this); 1543 | } 1544 | 1545 | Iterator 1546 | end() const 1547 | { 1548 | return Iterator(this, true); 1549 | } 1550 | 1551 | size_t 1552 | count(const std::string& o) const 1553 | { 1554 | auto iter = m_keys.find(o); 1555 | if (iter == m_keys.end()) 1556 | { 1557 | return 0; 1558 | } 1559 | 1560 | auto viter = m_values.find(iter->second); 1561 | 1562 | if (viter == m_values.end()) 1563 | { 1564 | return 0; 1565 | } 1566 | 1567 | return viter->second.count(); 1568 | } 1569 | 1570 | const OptionValue& 1571 | operator[](const std::string& option) const 1572 | { 1573 | auto iter = 
m_keys.find(option); 1574 | 1575 | if (iter == m_keys.end()) 1576 | { 1577 | throw_or_mimic(option); 1578 | } 1579 | 1580 | auto viter = m_values.find(iter->second); 1581 | 1582 | if (viter == m_values.end()) 1583 | { 1584 | throw_or_mimic(option); 1585 | } 1586 | 1587 | return viter->second; 1588 | } 1589 | 1590 | const std::vector& 1591 | arguments() const 1592 | { 1593 | return m_sequential; 1594 | } 1595 | 1596 | const std::vector& 1597 | unmatched() const 1598 | { 1599 | return m_unmatched; 1600 | } 1601 | 1602 | const std::vector& 1603 | defaults() const 1604 | { 1605 | return m_defaults; 1606 | } 1607 | 1608 | const std::string 1609 | arguments_string() const 1610 | { 1611 | std::string result; 1612 | for(const auto& kv: m_sequential) 1613 | { 1614 | result += kv.key() + " = " + kv.value() + "\n"; 1615 | } 1616 | for(const auto& kv: m_defaults) 1617 | { 1618 | result += kv.key() + " = " + kv.value() + " " + "(default)" + "\n"; 1619 | } 1620 | return result; 1621 | } 1622 | 1623 | private: 1624 | NameHashMap m_keys{}; 1625 | ParsedHashMap m_values{}; 1626 | std::vector m_sequential{}; 1627 | std::vector m_defaults{}; 1628 | std::vector m_unmatched{}; 1629 | }; 1630 | 1631 | struct Option 1632 | { 1633 | Option 1634 | ( 1635 | std::string opts, 1636 | std::string desc, 1637 | std::shared_ptr value = ::cxxopts::value(), 1638 | std::string arg_help = "" 1639 | ) 1640 | : opts_(std::move(opts)) 1641 | , desc_(std::move(desc)) 1642 | , value_(std::move(value)) 1643 | , arg_help_(std::move(arg_help)) 1644 | { 1645 | } 1646 | 1647 | std::string opts_; 1648 | std::string desc_; 1649 | std::shared_ptr value_; 1650 | std::string arg_help_; 1651 | }; 1652 | 1653 | using OptionMap = std::unordered_map>; 1654 | using PositionalList = std::vector; 1655 | using PositionalListIterator = PositionalList::const_iterator; 1656 | 1657 | class OptionParser 1658 | { 1659 | public: 1660 | OptionParser(const OptionMap& options, const PositionalList& positional, bool allow_unrecognised) 1661 | : m_options(options) 1662 | , m_positional(positional) 1663 | , m_allow_unrecognised(allow_unrecognised) 1664 | { 1665 | } 1666 | 1667 | ParseResult 1668 | parse(int argc, const char* const* argv); 1669 | 1670 | bool 1671 | consume_positional(const std::string& a, PositionalListIterator& next); 1672 | 1673 | void 1674 | checked_parse_arg 1675 | ( 1676 | int argc, 1677 | const char* const* argv, 1678 | int& current, 1679 | const std::shared_ptr& value, 1680 | const std::string& name 1681 | ); 1682 | 1683 | void 1684 | add_to_option(OptionMap::const_iterator iter, const std::string& option, const std::string& arg); 1685 | 1686 | void 1687 | parse_option 1688 | ( 1689 | const std::shared_ptr& value, 1690 | const std::string& name, 1691 | const std::string& arg = "" 1692 | ); 1693 | 1694 | void 1695 | parse_default(const std::shared_ptr& details); 1696 | 1697 | void 1698 | parse_no_value(const std::shared_ptr& details); 1699 | 1700 | private: 1701 | 1702 | void finalise_aliases(); 1703 | 1704 | const OptionMap& m_options; 1705 | const PositionalList& m_positional; 1706 | 1707 | std::vector m_sequential{}; 1708 | std::vector m_defaults{}; 1709 | bool m_allow_unrecognised; 1710 | 1711 | ParsedHashMap m_parsed{}; 1712 | NameHashMap m_keys{}; 1713 | }; 1714 | 1715 | class Options 1716 | { 1717 | public: 1718 | 1719 | explicit Options(std::string program, std::string help_string = "") 1720 | : m_program(std::move(program)) 1721 | , m_help_string(toLocalString(std::move(help_string))) 1722 | , m_custom_help("[OPTION...]") 1723 | 
, m_positional_help("positional parameters") 1724 | , m_show_positional(false) 1725 | , m_allow_unrecognised(false) 1726 | , m_width(76) 1727 | , m_tab_expansion(false) 1728 | , m_options(std::make_shared()) 1729 | { 1730 | } 1731 | 1732 | Options& 1733 | positional_help(std::string help_text) 1734 | { 1735 | m_positional_help = std::move(help_text); 1736 | return *this; 1737 | } 1738 | 1739 | Options& 1740 | custom_help(std::string help_text) 1741 | { 1742 | m_custom_help = std::move(help_text); 1743 | return *this; 1744 | } 1745 | 1746 | Options& 1747 | show_positional_help() 1748 | { 1749 | m_show_positional = true; 1750 | return *this; 1751 | } 1752 | 1753 | Options& 1754 | allow_unrecognised_options() 1755 | { 1756 | m_allow_unrecognised = true; 1757 | return *this; 1758 | } 1759 | 1760 | Options& 1761 | set_width(size_t width) 1762 | { 1763 | m_width = width; 1764 | return *this; 1765 | } 1766 | 1767 | Options& 1768 | set_tab_expansion(bool expansion=true) 1769 | { 1770 | m_tab_expansion = expansion; 1771 | return *this; 1772 | } 1773 | 1774 | ParseResult 1775 | parse(int argc, const char* const* argv); 1776 | 1777 | OptionAdder 1778 | add_options(std::string group = ""); 1779 | 1780 | void 1781 | add_options 1782 | ( 1783 | const std::string& group, 1784 | std::initializer_list