├── sio4onnx ├── __main__.py ├── __init__.py └── onnx_input_output_variable_changer.py ├── .gitignore ├── .github └── workflows │ ├── python-publish.yml │ └── codeql-analysis.yml ├── setup.py ├── LICENSE └── README.md /sio4onnx/__main__.py: -------------------------------------------------------------------------------- 1 | from . import main 2 | 3 | if __name__ == '__main__': 4 | main() -------------------------------------------------------------------------------- /sio4onnx/__init__.py: -------------------------------------------------------------------------------- 1 | from sio4onnx.onnx_input_output_variable_changer import io_change, main 2 | 3 | __version__ = '1.0.3' 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | saved_model/ 3 | build/ 4 | dist/ 5 | sio4onnx.egg-info/ 6 | sio4onnx/debug/ 7 | sio4onnx/saved_model/ 8 | __pycache__/ 9 | test.py 10 | 11 | *.onnx 12 | *.npy -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | deploy: 9 | 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v3 14 | - name: Set up Python 15 | uses: actions/setup-python@v3 16 | with: 17 | python-version: '3.x' 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install setuptools wheel pipenv 22 | - name: Build 23 | run: | 24 | python setup.py sdist bdist_wheel 25 | - name: Publish a Python distribution to PyPI 26 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 27 | with: 28 | user: __token__ 29 | password: ${{ secrets.PYPI_API_TOKEN }} 30 | 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | 2 | from setuptools import setup, find_packages 3 | from os import path 4 | import re 5 | 6 | package_name="sio4onnx" 7 | root_dir = path.abspath(path.dirname(__file__)) 8 | 9 | with open("README.md") as f: 10 | long_description = f.read() 11 | 12 | with open(path.join(root_dir, package_name, '__init__.py')) as f: 13 | init_text = f.read() 14 | version = re.search(r'__version__\s*=\s*[\'\"](.+?)[\'\"]', init_text).group(1) 15 | 16 | setup( 17 | name=package_name, 18 | version=version, 19 | description=\ 20 | "Simple tool to change the INPUT and OUTPUT shape of ONNX.", 21 | long_description=long_description, 22 | long_description_content_type="text/markdown", 23 | author="Katsuya Hyodo", 24 | author_email="rmsdh122@yahoo.co.jp", 25 | url="https://github.com/PINTO0309/sio4onnx", 26 | license="MIT License", 27 | packages=find_packages(), 28 | platforms=["linux", "unix"], 29 | python_requires=">=3.6", 30 | entry_points={ 31 | 'console_scripts': [ 32 | "sio4onnx=sio4onnx:main" 33 | ] 34 | } 35 | ) 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Katsuya Hyodo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial 
portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 
11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "main" ] 20 | schedule: 21 | - cron: '27 0 * * 2' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | 52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 
65 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v2 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sio4onnx 2 | Simple tool to change the INPUT and OUTPUT shape of ONNX. 3 | 4 | https://github.com/PINTO0309/simple-onnx-processing-tools 5 | 6 | [![Downloads](https://static.pepy.tech/personalized-badge/sio4onnx?period=total&units=none&left_color=grey&right_color=brightgreen&left_text=Downloads)](https://pepy.tech/project/sio4onnx) ![GitHub](https://img.shields.io/github/license/PINTO0309/sio4onnx?color=2BAF2B) [![PyPI](https://img.shields.io/pypi/v/sio4onnx?color=2BAF2B)](https://pypi.org/project/sio4onnx/) [![CodeQL](https://github.com/PINTO0309/sio4onnx/workflows/CodeQL/badge.svg)](https://github.com/PINTO0309/sio4onnx/actions?query=workflow%3ACodeQL) 7 | 8 |

9 | 10 |

11 | 12 | ## 1. Setup 13 | ### 1-1. HostPC 14 | ```bash 15 | ### option 16 | $ echo export PATH="~/.local/bin:$PATH" >> ~/.bashrc \ 17 | && source ~/.bashrc 18 | 19 | ### run 20 | $ pip install -U onnx \ 21 | && pip install -U sio4onnx 22 | ``` 23 | ### 1-2. Docker 24 | https://github.com/PINTO0309/simple-onnx-processing-tools#docker 25 | 26 | ## 2. CLI Usage 27 | ```bash 28 | $ sio4onnx -h 29 | 30 | usage: 31 | sio4onnx [-h] 32 | -if INPUT_ONNX_FILE_PATH 33 | -of OUTPUT_ONNX_FILE_PATH 34 | -i INPUT_NAMES 35 | -is INPUT_SHAPES [INPUT_SHAPES ...] 36 | -o OUTPUT_NAMES 37 | -os OUTPUT_SHAPES [OUTPUT_SHAPES ...] 38 | [-n] 39 | 40 | optional arguments: 41 | -h, --help 42 | Show this help message and exit. 43 | 44 | -if INPUT_ONNX_FILE_PATH, --input_onnx_file_path INPUT_ONNX_FILE_PATH 45 | INPUT ONNX file path 46 | 47 | -of OUTPUT_ONNX_FILE_PATH, --output_onnx_file_path OUTPUT_ONNX_FILE_PATH 48 | OUTPUT ONNX file path 49 | 50 | -i INPUT_NAMES, --input_names INPUT_NAMES 51 | List of input OP names. All input OPs of the model must be specified. 52 | The order is unspecified, but must match the order specified for input_shapes. 53 | e.g. 54 | --input_names "input.A" \ 55 | --input_names "input.B" \ 56 | --input_names "input.C" 57 | 58 | -is INPUT_SHAPES [INPUT_SHAPES ...], --input_shapes INPUT_SHAPES [INPUT_SHAPES ...] 59 | List of input OP shapes. All input OPs of the model must be specified. 60 | The order is unspecified, but must match the order specified for input_names. 61 | e.g. 62 | --input_shapes 1 3 "H" "W" \ 63 | --input_shapes "N" 3 "H" "W" \ 64 | --input_shapes "-1" 3 480 640 65 | 66 | -o OUTPUT_NAMES, --output_names OUTPUT_NAMES 67 | List of output OP names. All output OPs of the model must be specified. 68 | The order is unspecified, but must match the order specified for output_shapes. 69 | e.g. 
70 | --output_names "output.a" \ 71 | --output_names "output.b" \ 72 | --output_names "output.c" 73 | 74 | -os OUTPUT_SHAPES [OUTPUT_SHAPES ...], --output_shapes OUTPUT_SHAPES [OUTPUT_SHAPES ...] 75 | List of output OP shapes. All output OPs of the model must be specified. 76 | The order is unspecified, but must match the order specified for output_names. 77 | e.g. 78 | --output_shapes 1 3 "H" "W" \ 79 | --output_shapes "N" 3 "H" "W" \ 80 | --output_shapes "-1" 3 480 640 81 | 82 | -n, --non_verbose 83 | Do not show all information logs. Only error logs are displayed. 84 | ``` 85 | 86 | ## 3. In-script Usage 87 | ```python 88 | >>> from sio4onnx import io_change 89 | >>> help(io_change) 90 | 91 | Help on function io_change in module sio4onnx.onnx_input_output_variable_changer: 92 | 93 | io_change( 94 | input_onnx_file_path: Union[str, NoneType] = '', 95 | onnx_graph: Union[onnx.onnx_ml_pb2.ModelProto, NoneType] = None, 96 | output_onnx_file_path: Union[str, NoneType] = '', 97 | input_names: Union[List[str], NoneType] = [], 98 | input_shapes: Union[List[Union[int, str]], NoneType] = [], 99 | output_names: Union[List[str], NoneType] = [], 100 | output_shapes: Union[List[Union[int, str]], NoneType] = [], 101 | non_verbose: Union[bool, NoneType] = False, 102 | ) -> onnx.onnx_ml_pb2.ModelProto 103 | 104 | Parameters 105 | ---------- 106 | input_onnx_file_path: Optional[str] 107 | Input onnx file path. 108 | Either input_onnx_file_path or onnx_graph must be specified. 109 | Default: '' 110 | 111 | onnx_graph: Optional[onnx.ModelProto] 112 | onnx.ModelProto. 113 | Either input_onnx_file_path or onnx_graph must be specified. 114 | If onnx_graph is specified, input_onnx_file_path is ignored and onnx_graph is processed. 115 | 116 | output_onnx_file_path: Optional[str] 117 | Output onnx file path. If not specified, no ONNX file is output. 118 | Default: '' 119 | 120 | input_names: Optional[List[str]] 121 | List of input OP names. All input OPs of the model must be specified.
122 | The order is unspecified, but must match the order specified for input_shapes. 123 | e.g. ['input.A', 'input.B', 'input.C'] 124 | 125 | input_shapes: Optional[List[Union[int, str]]] 126 | List of input OP shapes. All input OPs of the model must be specified. 127 | The order is unspecified, but must match the order specified for input_names. 128 | e.g. 129 | [ 130 | [1, 3, 'H', 'W'], 131 | ['N', 3, 'H', 'W'], 132 | ['-1', 3, 480, 640], 133 | ] 134 | 135 | output_names: Optional[List[str]] 136 | List of output OP names. All output OPs of the model must be specified. 137 | The order is unspecified, but must match the order specified for output_shapes. 138 | e.g. ['output.a', 'output.b', 'output.c'] 139 | 140 | output_shapes: Optional[List[Union[int, str]]] 141 | List of output OP shapes. All output OPs of the model must be specified. 142 | The order is unspecified, but must match the order specified for output_names. 143 | e.g. 144 | [ 145 | [1, 3, 'H', 'W'], 146 | ['N', 3, 'H', 'W'], 147 | ['-1', 3, 480, 640], 148 | ] 149 | 150 | non_verbose: Optional[bool] 151 | Do not show all information logs. Only error logs are displayed. 152 | Default: False 153 | 154 | Returns 155 | ------- 156 | io_changed_graph: onnx.ModelProto 157 | onnx ModelProto with modified INPUT and OUTPUT shapes. 158 | ``` 159 | 160 | ## 4. CLI Execution 161 | ```bash 162 | $ sio4onnx \ 163 | --input_onnx_file_path yolov3-10.onnx \ 164 | --output_onnx_file_path yolov3-10_upd.onnx \ 165 | --input_names "input_1" \ 166 | --input_names "image_shape" \ 167 | --input_shapes "batch" 3 "H" "W" \ 168 | --input_shapes "batch" 2 \ 169 | --output_names "yolonms_layer_1/ExpandDims_1:0" \ 170 | --output_names "yolonms_layer_1/ExpandDims_3:0" \ 171 | --output_names "yolonms_layer_1/concat_2:0" \ 172 | --output_shapes 1 "boxes" 4 \ 173 | --output_shapes 1 "classes" "boxes" \ 174 | --output_shapes "boxes" 3 175 | ``` 176 | 177 | ## 5.
In-script Execution 178 | ```python 179 | from sio4onnx import io_change 180 | 181 | io_changed_graph = io_change( 182 | input_onnx_file_path="yolov3-10.onnx", 183 | output_onnx_file_path="yolov3-10_upd.onnx", 184 | input_names=[ 185 | "input_1", 186 | "image_shape", 187 | ], 188 | input_shapes=[ 189 | ["batch", 3, "H", "W"], 190 | ["batch", 2], 191 | ], 192 | output_names=[ 193 | "yolonms_layer_1/ExpandDims_1:0", 194 | "yolonms_layer_1/ExpandDims_3:0", 195 | "yolonms_layer_1/concat_2:0", 196 | ], 197 | output_shapes=[ 198 | [1, "boxes", 4], 199 | [1, "classes", "boxes"], 200 | ["boxes", 3], 201 | ], 202 | ) 203 | ``` 204 | ## 6. Sample 205 | ### Before 206 | ![image](https://user-images.githubusercontent.com/33194443/178515405-42d2bd01-f5fa-41be-95e3-3a229b0c8ae9.png) 207 | ### After 208 | ![image](https://user-images.githubusercontent.com/33194443/178515314-ecbf7f85-5c1d-4626-ac8b-3558432f6e9b.png) 209 | -------------------------------------------------------------------------------- /sio4onnx/onnx_input_output_variable_changer.py: -------------------------------------------------------------------------------- 1 | #! 
#!/usr/bin/env python

import sys
import onnx
from onnx import ModelProto, ValueInfoProto, TensorShapeProto
from typing import Optional, List, Union, Any, Dict, Set, Iterable
from argparse import ArgumentParser
from ast import literal_eval


class Color:
    BLACK = '\033[30m'
    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
    MAGENTA = '\033[35m'
    CYAN = '\033[36m'
    WHITE = '\033[37m'
    COLOR_DEFAULT = '\033[39m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    INVISIBLE = '\033[08m'
    REVERCE = '\033[07m'  # (sic) name kept unchanged for existing users
    BG_BLACK = '\033[40m'
    BG_RED = '\033[41m'
    BG_GREEN = '\033[42m'
    BG_YELLOW = '\033[43m'
    BG_BLUE = '\033[44m'
    BG_MAGENTA = '\033[45m'
    BG_CYAN = '\033[46m'
    BG_WHITE = '\033[47m'
    BG_DEFAULT = '\033[49m'
    RESET = '\033[0m'


def update_inputs_outputs_dims(
    model: ModelProto,
    input_dims: Dict[str, List[Any]],
    output_dims: Dict[str, List[Any]],
) -> ModelProto:
    """Update, in place, the dimensions of the model's graph inputs and outputs.

    Every graph input must have an entry in input_dims and every graph output
    an entry in output_dims. Each dimension value may be:

    * a non-negative int: stored as ``dim_value``. If the axis already has a
      different ``dim_value``, a ValueError is raised.
    * a negative int: a unique ``dim_param`` named ``<tensor name>_<axis>``
      is generated.
    * a str: stored as ``dim_param``.

    If a tensor has no dimensions yet, one dimension is created per entry.

    Example. if we have the following shape for inputs and outputs:

    * shape(input_1) = ('b', 3, 'w', 'h')
    * shape(input_2) = ('b', 4)
    * shape(output) = ('b', 'd', 5)

    The parameters can be provided as:

    ::

        input_dims = {
            "input_1": ['b', 3, 'w', 'h'],
            "input_2": ['b', 4],
        }
        output_dims = {
            "output": ['b', -1, 5]
        }

    Putting it together:

    ::

        model = onnx.load('model.onnx')
        updated_model = update_inputs_outputs_dims(model, input_dims, output_dims)
        onnx.save(updated_model, 'model.onnx')

    Returns
    -------
    The same ``model`` object, modified in place and validated with
    ``onnx.checker.check_model``.

    Raises
    ------
    KeyError
        A graph input/output has no entry in input_dims/output_dims.
    ValueError
        A dimension has an unsupported type, contradicts an existing fixed
        value, its generated dim_param is already in use, or more dimensions
        are given than the tensor already has.
    Errors raised by ``onnx.checker.check_model`` propagate unchanged.
    """
    # Symbolic names already used by the graph. A generated name must not
    # collide with any of them.
    dim_param_set: Set[str] = {
        dim.dim_param
        for value_infos in (model.graph.input, model.graph.output, model.graph.value_info)
        for info in value_infos
        for dim in info.type.tensor_type.shape.dim
        if dim.HasField("dim_param")
    }

    def update_dim(tensor: ValueInfoProto, dim: Any, j: int, name: str) -> None:
        """Write one requested dimension value into axis j of tensor."""
        dim_proto: TensorShapeProto.Dimension = tensor.type.tensor_type.shape.dim[j]
        if isinstance(dim, int):
            if dim >= 0:
                if dim_proto.HasField("dim_value") and dim_proto.dim_value != dim:
                    raise ValueError(
                        f"Unable to set dimension value to {dim} for axis {j} of {name}. Contradicts existing dimension value {dim_proto.dim_value}."
                    )
                dim_proto.dim_value = dim
            else:
                generated_dim_param = name + "_" + str(j)
                if generated_dim_param in dim_param_set:
                    raise ValueError(
                        f"Unable to generate unique dim_param for axis {j} of {name}. Please manually provide a dim_param value."
                    )
                dim_proto.dim_param = generated_dim_param
        elif isinstance(dim, str):
            dim_proto.dim_param = dim
        else:
            raise ValueError(
                f"Only int or str is accepted as dimension value, incorrect type: {type(dim)}"
            )

    def apply_dims(tensor: ValueInfoProto, dims: List[Any], name: str) -> None:
        """Apply a full requested shape to tensor, creating axes if it has none."""
        shape = tensor.type.tensor_type.shape
        # len() instead of comparing with [] because a protobuf repeated
        # container does not reliably compare equal to a plain list.
        if len(shape.dim) == 0:
            # Unknown shape: mark the shape as present (so [] yields a scalar)
            # and add one empty axis per requested dimension. Every value then
            # goes through update_dim, so negative ints get the same unique
            # generated name on both paths.
            shape.SetInParent()
            for _ in dims:
                shape.dim.add()
        elif len(dims) > len(shape.dim):
            raise ValueError(
                f"{name} has {len(shape.dim)} dimensions, but {len(dims)} dimensions were given."
            )
        for j, dim in enumerate(dims):
            update_dim(tensor, dim, j, name)

    for value_infos, dims_by_name, kind in (
        (model.graph.input, input_dims, "input"),
        (model.graph.output, output_dims, "output"),
    ):
        for tensor in value_infos:
            if tensor.name not in dims_by_name:
                raise KeyError(
                    f"No shape was given for model {kind} '{tensor.name}'. "
                    f"All {kind}s of the model must be specified."
                )
            apply_dims(tensor, dims_by_name[tensor.name], tensor.name)

    onnx.checker.check_model(model)
    return model


def _error_exit(message: str) -> None:
    """Print message with a red ERROR prefix and terminate with exit status 1."""
    print(f'{Color.RED}ERROR:{Color.RESET} ' + message)
    sys.exit(1)


def _check_io_names(
    value_infos: Iterable[ValueInfoProto],
    dims_by_name: Dict[str, Any],
    kind: str,
    non_verbose: Optional[bool],
) -> None:
    """Validate user-given names against the model's graph inputs/outputs.

    Exits with an error if a model tensor was not specified. Prints a warning
    (unless non_verbose) for specified names that do not exist in the model,
    because those entries would otherwise be silently ignored.
    """
    model_names = [value_info.name for value_info in value_infos]
    missing = [name for name in model_names if name not in dims_by_name]
    if missing:
        _error_exit(
            f'All {kind} OPs of the model must be specified. '
            f'Not specified in {kind}_names: {missing}'
        )
    known = set(model_names)
    unknown = [name for name in dims_by_name if name not in known]
    if unknown and not non_verbose:
        print(
            f'{Color.YELLOW}WARNING:{Color.RESET} '
            f'The following {kind}_names do not exist in the model and are ignored: {unknown}'
        )


def io_change(
    input_onnx_file_path: Optional[str] = '',
    onnx_graph: Optional[onnx.ModelProto] = None,
    output_onnx_file_path: Optional[str] = '',
    input_names: Optional[List[str]] = None,
    input_shapes: Optional[List[List[Union[int, str]]]] = None,
    output_names: Optional[List[str]] = None,
    output_shapes: Optional[List[List[Union[int, str]]]] = None,
    non_verbose: Optional[bool] = False,
) -> onnx.ModelProto:
    """Change the shapes of the INPUT and OUTPUT OPs of an ONNX model.

    Invalid arguments print an error message and terminate the process with
    sys.exit(1).

    Parameters
    ----------
    input_onnx_file_path: Optional[str]
        Input onnx file path.
        Either input_onnx_file_path or onnx_graph must be specified.
        Default: ''

    onnx_graph: Optional[onnx.ModelProto]
        onnx.ModelProto.
        Either input_onnx_file_path or onnx_graph must be specified.
        If onnx_graph is specified, input_onnx_file_path is ignored and
        onnx_graph is processed. onnx_graph is modified in place.

    output_onnx_file_path: Optional[str]
        Output onnx file path. If not specified, no ONNX file is output.
        Default: ''

    input_names: Optional[List[str]]
        List of input OP names. All input OPs of the model must be specified.
        The order is unspecified, but must match the order specified for input_shapes.
        e.g. ['input.A', 'input.B', 'input.C']

    input_shapes: Optional[List[List[Union[int, str]]]]
        List of input OP shapes, one per entry of input_names.
        Non-negative ints become fixed sizes and strings become symbolic
        dimensions. A negative int gets a generated unique symbolic name.
        e.g.
        [
            [1, 3, 'H', 'W'],
            ['N', 3, 'H', 'W'],
            ['-1', 3, 480, 640],
        ]

    output_names: Optional[List[str]]
        List of output OP names. All output OPs of the model must be specified.
        The order is unspecified, but must match the order specified for output_shapes.
        e.g. ['output.a', 'output.b', 'output.c']

    output_shapes: Optional[List[List[Union[int, str]]]]
        List of output OP shapes, one per entry of output_names.
        The order is unspecified, but must match the order specified for output_names.
        e.g.
        [
            [1, 3, 'H', 'W'],
            ['N', 3, 'H', 'W'],
            ['-1', 3, 480, 640],
        ]

    non_verbose: Optional[bool]
        Do not show all information logs. Only error logs are displayed.
        Default: False

    Returns
    -------
    io_changed_graph: onnx.ModelProto
        onnx ModelProto with modified INPUT and OUTPUT shapes.
        If shape inference fails, the modified model without inferred
        intermediate shapes is returned (and saved) instead.
    """

    # Unspecified check for input_onnx_file_path and onnx_graph
    if not input_onnx_file_path and onnx_graph is None:
        _error_exit('One of input_onnx_file_path or onnx_graph must be specified.')

    # Other check
    if not input_names:
        _error_exit('At least one input_names must be specified.')
    if not input_shapes:
        _error_exit('At least one input_shapes must be specified.')
    if len(input_names) != len(input_shapes):
        _error_exit('The number of input_names and input_shapes must match.')
    if not output_names:
        _error_exit('At least one output_names must be specified.')
    if not output_shapes:
        _error_exit('At least one output_shapes must be specified.')
    if len(output_names) != len(output_shapes):
        _error_exit('The number of output_names and output_shapes must match.')

    # Loading Graphs
    # If onnx_graph is given, it takes precedence over input_onnx_file_path.
    if onnx_graph is None:
        onnx_graph = onnx.load(input_onnx_file_path)

    input_dicts = dict(zip(input_names, input_shapes))
    output_dicts = dict(zip(output_names, output_shapes))

    _check_io_names(onnx_graph.graph.input, input_dicts, 'input', non_verbose)
    _check_io_names(onnx_graph.graph.output, output_dicts, 'output', non_verbose)

    updated_model = update_inputs_outputs_dims(
        model=onnx_graph,
        input_dims=input_dicts,
        output_dims=output_dicts,
    )

    # Shape Estimation
    try:
        io_changed_graph = onnx.shape_inference.infer_shapes(updated_model)
    except Exception:
        # Keep the edited graph so it can still be saved and returned;
        # the warning below asks the user to inspect that file.
        io_changed_graph = updated_model
        if not non_verbose:
            print(
                f'{Color.YELLOW}WARNING:{Color.RESET} '+
                'The input shape of the next OP does not match the output shape. '+
                'Be sure to open the .onnx file to verify the certainty of the geometry.'
            )

    # Save
    if output_onnx_file_path:
        onnx.save(io_changed_graph, output_onnx_file_path)

    if not non_verbose:
        print(f'{Color.GREEN}INFO:{Color.RESET} Finish!')

    # Return
    return io_changed_graph


def _parse_shape(tokens: List[str]) -> List[Union[int, str]]:
    """Convert the CLI tokens of one shape into dimension values.

    A token that is a non-negative integer literal becomes an int (fixed
    size). Any other token, including negative numbers such as "-1", is kept
    as the original string and therefore becomes a symbolic dimension.
    """
    shape: List[Union[int, str]] = []
    for token in tokens:
        try:
            value = literal_eval(token)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a Python literal (e.g. "H", "batch"): keep it as a dim name.
            shape.append(token)
            continue
        # bool is a subclass of int; "True"/"False" are names, not sizes.
        if isinstance(value, int) and not isinstance(value, bool) and value >= 0:
            shape.append(value)
        else:
            shape.append(token)
    return shape


def main():
    """Command-line entry point: parse arguments and run io_change."""
    parser = ArgumentParser()
    parser.add_argument(
        '-if',
        '--input_onnx_file_path',
        type=str,
        required=True,
        help='INPUT ONNX file path'
    )
    parser.add_argument(
        '-of',
        '--output_onnx_file_path',
        type=str,
        required=True,
        help='OUTPUT ONNX file path'
    )
    parser.add_argument(
        '-i',
        '--input_names',
        type=str,
        action='append',
        required=True,
        help=(
            'List of input OP names. All input OPs of the model must be specified. '
            'The order is unspecified, but must match the order specified for input_shapes. '
            'e.g. '
            '--input_names "input.A" '
            '--input_names "input.B" '
            '--input_names "input.C"'
        )
    )
    parser.add_argument(
        '-is',
        '--input_shapes',
        type=str,
        nargs='+',
        action='append',
        required=True,
        help=(
            'List of input OP shapes. All input OPs of the model must be specified. '
            'The order is unspecified, but must match the order specified for input_names. '
            'e.g. '
            '--input_shapes 1 3 "H" "W" '
            '--input_shapes "N" 3 "H" "W" '
            '--input_shapes "-1" 3 480 640'
        )
    )
    parser.add_argument(
        '-o',
        '--output_names',
        type=str,
        action='append',
        required=True,
        help=(
            'List of output OP names. All output OPs of the model must be specified. '
            'The order is unspecified, but must match the order specified for output_shapes. '
            'e.g. '
            '--output_names "output.a" '
            '--output_names "output.b" '
            '--output_names "output.c"'
        )
    )
    parser.add_argument(
        '-os',
        '--output_shapes',
        type=str,
        nargs='+',
        action='append',
        required=True,
        help=(
            'List of output OP shapes. All output OPs of the model must be specified. '
            'The order is unspecified, but must match the order specified for output_names. '
            'e.g. '
            '--output_shapes 1 3 "H" "W" '
            '--output_shapes "N" 3 "H" "W" '
            '--output_shapes "-1" 3 480 640'
        )
    )
    parser.add_argument(
        '-n',
        '--non_verbose',
        action='store_true',
        help='Do not show all information logs. Only error logs are displayed.'
    )
    args = parser.parse_args()

    input_shapes = [_parse_shape(tokens) for tokens in args.input_shapes]
    output_shapes = [_parse_shape(tokens) for tokens in args.output_shapes]

    # -of is required, but an empty string still falls back to overwriting
    # the input file.
    output_onnx_file_path = args.output_onnx_file_path or args.input_onnx_file_path

    # Load
    onnx_graph = onnx.load(args.input_onnx_file_path)

    # change
    io_change(
        input_onnx_file_path=None,
        onnx_graph=onnx_graph,
        output_onnx_file_path=output_onnx_file_path,
        input_names=args.input_names,
        input_shapes=input_shapes,
        output_names=args.output_names,
        output_shapes=output_shapes,
        non_verbose=args.non_verbose,
    )


if __name__ == '__main__':
    main()