├── sio4onnx
├── __main__.py
├── __init__.py
└── onnx_input_output_variable_changer.py
├── .gitignore
├── .github
└── workflows
│ ├── python-publish.yml
│ └── codeql-analysis.yml
├── setup.py
├── LICENSE
└── README.md
/sio4onnx/__main__.py:
--------------------------------------------------------------------------------
1 | from . import main
2 |
3 | if __name__ == '__main__':
4 | main()
--------------------------------------------------------------------------------
/sio4onnx/__init__.py:
--------------------------------------------------------------------------------
1 | from sio4onnx.onnx_input_output_variable_changer import io_change, main
2 |
3 | __version__ = '1.0.3'
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | saved_model/
3 | build/
4 | dist/
5 | sio4onnx.egg-info/
6 | sio4onnx/debug/
7 | sio4onnx/saved_model/
8 | __pycache__/
9 | test.py
10 |
11 | *.onnx
12 | *.npy
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | name: Upload Python Package
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | jobs:
8 | deploy:
9 |
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - uses: actions/checkout@v3
14 | - name: Set up Python
15 | uses: actions/setup-python@v3
16 | with:
17 | python-version: '3.x'
18 | - name: Install dependencies
19 | run: |
20 | python -m pip install --upgrade pip
21 | pip install setuptools wheel pipenv
22 | - name: Build
23 | run: |
24 | python setup.py sdist bdist_wheel
25 | - name: Publish a Python distribution to PyPI
26 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
27 | with:
28 | user: __token__
29 | password: ${{ secrets.PYPI_API_TOKEN }}
30 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 |
# Packaging script for sio4onnx.
#
# The version is read from sio4onnx/__init__.py with a regular expression so
# that the package does not have to be imported at build time.
from setuptools import setup, find_packages
from os import path
import re

package_name = "sio4onnx"
root_dir = path.abspath(path.dirname(__file__))

# Resolve README.md against this file's directory (not the current working
# directory) and decode it explicitly as UTF-8 so the build does not depend on
# where it is invoked from or on the platform's default encoding.
with open(path.join(root_dir, "README.md"), encoding="utf-8") as f:
    long_description = f.read()

with open(path.join(root_dir, package_name, "__init__.py"), encoding="utf-8") as f:
    init_text = f.read()

version_match = re.search(r'__version__\s*=\s*[\'\"](.+?)[\'\"]', init_text)
if version_match is None:
    # Fail with a clear message instead of an AttributeError on None.
    raise RuntimeError(
        f"Unable to find __version__ in {package_name}/__init__.py"
    )
version = version_match.group(1)

setup(
    name=package_name,
    version=version,
    description="Simple tool to change the INPUT and OUTPUT shape of ONNX.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="Katsuya Hyodo",
    author_email="rmsdh122@yahoo.co.jp",
    url="https://github.com/PINTO0309/sio4onnx",
    license="MIT License",
    packages=find_packages(),
    platforms=["linux", "unix"],
    python_requires=">=3.6",
    entry_points={
        "console_scripts": [
            "sio4onnx=sio4onnx:main",
        ],
    },
)
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Katsuya Hyodo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/codeql-analysis.yml:
--------------------------------------------------------------------------------
1 | # For most projects, this workflow file will not need changing; you simply need
2 | # to commit it to your repository.
3 | #
4 | # You may wish to alter this file to override the set of languages analyzed,
5 | # or to provide custom queries or build logic.
6 | #
7 | # ******** NOTE ********
8 | # We have attempted to detect the languages in your repository. Please check
9 | # the `language` matrix defined below to confirm you have the correct set of
10 | # supported CodeQL languages.
11 | #
12 | name: "CodeQL"
13 |
14 | on:
15 | push:
16 | branches: [ "main" ]
17 | pull_request:
18 | # The branches below must be a subset of the branches above
19 | branches: [ "main" ]
20 | schedule:
21 | - cron: '27 0 * * 2'
22 |
23 | jobs:
24 | analyze:
25 | name: Analyze
26 | runs-on: ubuntu-latest
27 | permissions:
28 | actions: read
29 | contents: read
30 | security-events: write
31 |
32 | strategy:
33 | fail-fast: false
34 | matrix:
35 | language: [ 'python' ]
36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
38 |
39 | steps:
40 | - name: Checkout repository
41 | uses: actions/checkout@v3
42 |
43 | # Initializes the CodeQL tools for scanning.
44 | - name: Initialize CodeQL
45 | uses: github/codeql-action/init@v2
46 | with:
47 | languages: ${{ matrix.language }}
48 | # If you wish to specify custom queries, you can do so here or in a config file.
49 | # By default, queries listed here will override any specified in a config file.
50 | # Prefix the list here with "+" to use these queries and those in the config file.
51 |
52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
53 | # queries: security-extended,security-and-quality
54 |
55 |
56 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
57 | # If this step fails, then you should remove it and run the build manually (see below)
58 | - name: Autobuild
59 | uses: github/codeql-action/autobuild@v2
60 |
61 | # ℹ️ Command-line programs to run using the OS shell.
62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
63 |
64 | # If the Autobuild fails above, remove it and uncomment the following three lines.
65 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance.
66 |
67 | # - run: |
68 | # echo "Run, Build Application using script"
69 | # ./location_of_script_within_repo/buildscript.sh
70 |
71 | - name: Perform CodeQL Analysis
72 | uses: github/codeql-action/analyze@v2
73 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # sio4onnx
2 | Simple tool to change the INPUT and OUTPUT shape of ONNX.
3 |
4 | https://github.com/PINTO0309/simple-onnx-processing-tools
5 |
6 | [Downloads](https://pepy.tech/project/sio4onnx) [PyPI](https://pypi.org/project/sio4onnx/) [CodeQL](https://github.com/PINTO0309/sio4onnx/actions?query=workflow%3ACodeQL)
7 |
8 |
9 |
10 |
11 |
12 | ## 1. Setup
13 | ### 1-1. HostPC
14 | ```bash
15 | ### option
16 | $ echo export PATH="~/.local/bin:$PATH" >> ~/.bashrc \
17 | && source ~/.bashrc
18 |
19 | ### run
20 | $ pip install -U onnx \
21 | && pip install -U sio4onnx
22 | ```
23 | ### 1-2. Docker
24 | https://github.com/PINTO0309/simple-onnx-processing-tools#docker
25 |
26 | ## 2. CLI Usage
27 | ```bash
28 | $ sio4onnx -h
29 |
30 | usage:
31 | sio4onnx [-h]
32 | -if INPUT_ONNX_FILE_PATH
33 | -of OUTPUT_ONNX_FILE_PATH
34 | -i INPUT_NAMES
35 | -is INPUT_SHAPES [INPUT_SHAPES ...]
36 | -o OUTPUT_NAMES
37 | -os OUTPUT_SHAPES [OUTPUT_SHAPES ...]
38 | [-n]
39 |
40 | optional arguments:
41 | -h, --help
42 | Show this help message and exit.
43 |
44 | -if INPUT_ONNX_FILE_PATH, --input_onnx_file_path INPUT_ONNX_FILE_PATH
45 | INPUT ONNX file path
46 |
47 | -of OUTPUT_ONNX_FILE_PATH, --output_onnx_file_path OUTPUT_ONNX_FILE_PATH
48 | OUTPUT ONNX file path
49 |
50 | -i INPUT_NAMES, --input_names INPUT_NAMES
51 | List of input OP names. All input OPs of the model must be specified.
52 | The order is unspecified, but must match the order specified for input_shapes.
53 | e.g.
54 | --input_names "input.A" \
55 | --input_names "input.B" \
56 | --input_names "input.C"
57 |
58 | -is INPUT_SHAPES [INPUT_SHAPES ...], --input_shapes INPUT_SHAPES [INPUT_SHAPES ...]
59 | List of input OP shapes. All input OPs of the model must be specified.
60 | The order is unspecified, but must match the order specified for input_names.
61 | e.g.
62 | --input_shapes 1 3 "H" "W" \
63 | --input_shapes "N" 3 "H" "W" \
64 | --input_shapes "-1" 3 480 640
65 |
66 | -o OUTPUT_NAMES, --output_names OUTPUT_NAMES
67 | List of output OP names. All output OPs of the model must be specified.
68 | The order is unspecified, but must match the order specified for output_shapes.
69 | e.g.
70 | --output_names "output.a" \
71 | --output_names "output.b" \
72 | --output_names "output.c"
73 |
74 | -os OUTPUT_SHAPES [OUTPUT_SHAPES ...], --output_shapes OUTPUT_SHAPES [OUTPUT_SHAPES ...]
75 | List of output OP shapes. All output OPs of the model must be specified.
76 | The order is unspecified, but must match the order specified for output_names.
77 | e.g.
78 | --output_shapes 1 3 "H" "W" \
79 | --output_shapes "N" 3 "H" "W" \
80 | --output_shapes "-1" 3 480 640
81 |
82 | -n, --non_verbose
83 | Do not show all information logs. Only error logs are displayed.
84 | ```
85 |
86 | ## 3. In-script Usage
87 | ```python
88 | >>> from sio4onnx import io_change
89 | >>> help(io_change)
90 |
91 | Help on function io_change in module sio4onnx.onnx_input_output_variable_changer:
92 |
93 | io_change(
94 | input_onnx_file_path: Union[str, NoneType] = '',
95 | onnx_graph: Union[onnx.onnx_ml_pb2.ModelProto, NoneType] = None,
96 | output_onnx_file_path: Union[str, NoneType] = '',
97 | input_names: Union[List[str], NoneType] = [],
98 | input_shapes: Union[List[Union[int, str]], NoneType] = [],
99 | output_names: Union[List[str], NoneType] = [],
100 | output_shapes: Union[List[Union[int, str]], NoneType] = [],
101 | non_verbose: Union[bool, NoneType] = False,
102 | ) -> onnx.onnx_ml_pb2.ModelProto
103 |
104 | Parameters
105 | ----------
106 | input_onnx_file_path: Optional[str]
107 | Input onnx file path.
108 | Either input_onnx_file_path or onnx_graph must be specified.
109 | Default: ''
110 |
111 | onnx_graph: Optional[onnx.ModelProto]
112 | onnx.ModelProto.
113 | Either input_onnx_file_path or onnx_graph must be specified.
114 | If onnx_graph is specified, input_onnx_file_path is ignored and onnx_graph is processed.
115 |
116 | output_onnx_file_path: Optional[str]
117 | Output onnx file path. If not specified, no ONNX file is output.
118 | Default: ''
119 |
120 | input_names: Optional[List[str]]
121 | List of input OP names. All input OPs of the model must be specified.
122 | The order is unspecified, but must match the order specified for input_shapes.
123 | e.g. ['input.A', 'input.B', 'input.C']
124 |
125 | input_shapes: Optional[List[Union[int, str]]]
126 | List of input OP shapes. All input OPs of the model must be specified.
127 | The order is unspecified, but must match the order specified for input_names.
128 | e.g.
129 | [
130 | [1, 3, 'H', 'W'],
131 | ['N', 3, 'H', 'W'],
132 | ['-1', 3, 480, 640],
133 | ]
134 |
135 | output_names: Optional[List[str]]
136 | List of output OP names. All output OPs of the model must be specified.
137 | The order is unspecified, but must match the order specified for output_shapes.
138 | e.g. ['output.a', 'output.b', 'output.c']
139 |
140 | output_shapes: Optional[List[Union[int, str]]]
141 | List of output OP shapes. All output OPs of the model must be specified.
142 | The order is unspecified, but must match the order specified for output_names.
143 | e.g.
144 | [
145 | [1, 3, 'H', 'W'],
146 | ['N', 3, 'H', 'W'],
147 | ['-1', 3, 480, 640],
148 | ]
149 |
150 | non_verbose: Optional[bool]
151 | Do not show all information logs. Only error logs are displayed.
152 | Default: False
153 |
154 | Returns
155 | -------
156 | io_changed_graph: onnx.ModelProto
157 | onnx ModelProto with modified INPUT and OUTPUT shapes.
158 | ```
159 |
160 | ## 4. CLI Execution
161 | ```bash
162 | $ sio4onnx \
163 | --input_onnx_file_path yolov3-10.onnx \
164 | --output_onnx_file_path yolov3-10_upd.onnx \
165 | --input_names "input_1" \
166 | --input_names "image_shape" \
167 | --input_shapes "batch" 3 "H" "W" \
168 | --input_shapes "batch" 2 \
169 | --output_names "yolonms_layer_1/ExpandDims_1:0" \
170 | --output_names "yolonms_layer_1/ExpandDims_3:0" \
171 | --output_names "yolonms_layer_1/concat_2:0" \
172 | --output_shapes 1 "boxes" 4 \
173 | --output_shapes 1 "classes" "boxes" \
174 | --output_shapes "boxes" 3
175 | ```
176 |
177 | ## 5. In-script Execution
178 | ```python
179 | from sio4onnx import io_change
180 |
181 | io_changed_graph = io_change(
182 | input_onnx_file_path="yolov3-10.onnx",
183 | output_onnx_file_path="yolov3-10_upd.onnx",
184 | input_names=[
185 | "input_1",
186 | "image_shape",
187 | ],
188 | input_shapes=[
189 | ["batch", 3, "H", "W"],
190 | ["batch", 2],
191 | ],
192 | output_names=[
193 | "yolonms_layer_1/ExpandDims_1:0",
194 | "yolonms_layer_1/ExpandDims_3:0",
195 | "yolonms_layer_1/concat_2:0",
196 | ],
197 | output_shapes=[
198 | [1, "boxes", 4],
199 | [1, "classes", "boxes"],
200 | ["boxes", 3],
201 | ],
202 | )
203 | ```
204 | ## 6. Sample
205 | ### Before
206 | 
207 | ### After
208 | 
209 |
--------------------------------------------------------------------------------
/sio4onnx/onnx_input_output_variable_changer.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 |
import sys
from argparse import ArgumentParser
from ast import literal_eval
from typing import Any, Dict, List, Optional, Set, Union

import onnx
from onnx import ModelProto, TensorShapeProto, ValueInfoProto
9 |
class Color:
    """ANSI escape sequences used to colorize console messages.

    Each attribute is a ready-to-print escape string; concatenate it in front
    of the text and finish with ``RESET`` to restore the terminal state.
    """

    # Foreground colors
    BLACK = "\033[30m"
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    MAGENTA = "\033[35m"
    CYAN = "\033[36m"
    WHITE = "\033[37m"
    COLOR_DEFAULT = "\033[39m"

    # Text attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    INVISIBLE = "\033[08m"
    REVERCE = "\033[07m"  # historical spelling, kept because it is the public name

    # Background colors
    BG_BLACK = "\033[40m"
    BG_RED = "\033[41m"
    BG_GREEN = "\033[42m"
    BG_YELLOW = "\033[43m"
    BG_BLUE = "\033[44m"
    BG_MAGENTA = "\033[45m"
    BG_CYAN = "\033[46m"
    BG_WHITE = "\033[47m"
    BG_DEFAULT = "\033[49m"

    # Clears every color and attribute set above
    RESET = "\033[0m"
34 |
35 |
def update_inputs_outputs_dims(
    model: ModelProto,
    input_dims: Dict[str, List[Any]],
    output_dims: Dict[str, List[Any]],
) -> ModelProto:
    """Update the dimensions of the model's graph inputs and outputs in place.

    Every graph input must have an entry in ``input_dims`` and every graph
    output an entry in ``output_dims``. Each entry is a list with one item per
    axis:

    * a non-negative ``int`` sets a fixed ``dim_value``;
    * a negative ``int`` sets a generated ``dim_param`` named
      ``"<tensor name>_<axis>"``;
    * a ``str`` sets that string as the ``dim_param``.

    If a tensor has no dimensions recorded, a new shape is built from the
    requested list; otherwise the existing dimensions are updated axis by axis.

    Example. if we have the following shape for inputs and outputs:

    * shape(input_1) = ('b', 3, 'w', 'h')
    * shape(input_2) = ('b', 4)
    * shape(output) = ('b', 'd', 5)

    The parameters can be provided as:

    ::

        input_dims = {
            "input_1": ['b', 3, 'w', 'h'],
            "input_2": ['b', 4],
        }
        output_dims = {
            "output": ['b', -1, 5]
        }

    Putting it together:

    ::

        model = onnx.load('model.onnx')
        updated_model = update_inputs_outputs_dims(model, input_dims, output_dims)
        onnx.save(updated_model, 'model.onnx')

    Parameters
    ----------
    model: ModelProto
        Model whose graph inputs/outputs are modified in place.

    input_dims: Dict[str, List[Any]]
        Requested dimensions keyed by graph input name.

    output_dims: Dict[str, List[Any]]
        Requested dimensions keyed by graph output name.

    Returns
    -------
    ModelProto
        The same ``model`` object, after ``onnx.checker.check_model`` passed.

    Raises
    ------
    KeyError
        A graph input/output name is missing from ``input_dims``/``output_dims``.
    ValueError
        A fixed value contradicts the existing fixed value of that axis, a
        generated ``dim_param`` name is already used in the model, or a
        dimension is neither ``int`` nor ``str``.
    """
    # dim_param names already present anywhere in the graph; generated names
    # must not collide with them.
    known_dim_params: Set[str] = set()
    for value_infos in (model.graph.input, model.graph.output, model.graph.value_info):
        for info in value_infos:
            for existing_dim in info.type.tensor_type.shape.dim:
                if existing_dim.HasField("dim_param"):
                    known_dim_params.add(existing_dim.dim_param)

    def unique_dim_param(name: str, axis: int) -> str:
        """Return the generated dim_param for a negative dimension request."""
        generated_dim_param = name + "_" + str(axis)
        if generated_dim_param in known_dim_params:
            raise ValueError(
                f"Unable to generate unique dim_param for axis {axis} of {name}. Please manually provide a dim_param value."
            )
        return generated_dim_param

    def update_dim(tensor: ValueInfoProto, dim: Any, axis: int, name: str) -> None:
        """Overwrite one existing dimension of ``tensor``."""
        dim_proto = tensor.type.tensor_type.shape.dim[axis]
        if isinstance(dim, int):
            if dim >= 0:
                if dim_proto.HasField("dim_value") and dim_proto.dim_value != dim:
                    raise ValueError(
                        f"Unable to set dimension value to {dim} for axis {axis} of {name}. Contradicts existing dimension value {dim_proto.dim_value}."
                    )
                dim_proto.dim_value = dim
            else:
                dim_proto.dim_param = unique_dim_param(name, axis)
        elif isinstance(dim, str):
            dim_proto.dim_param = dim
        else:
            raise ValueError(
                f"Only int or str is accepted as dimension value, incorrect type: {type(dim)}"
            )

    def make_dim(tensor: ValueInfoProto, dim_arr: Any, name: str) -> None:
        """Build a complete shape for a ``tensor`` that has no dimensions yet."""
        new_dims = []
        for axis, dim in enumerate(dim_arr):
            if isinstance(dim, int):
                if dim >= 0:
                    new_dims.append(TensorShapeProto.Dimension(dim_value=dim))
                else:
                    # Same naming rule as update_dim, so a negative int yields
                    # the same kind of dim_param on both code paths.
                    new_dims.append(
                        TensorShapeProto.Dimension(
                            dim_param=unique_dim_param(name, axis)
                        )
                    )
            elif isinstance(dim, str):
                new_dims.append(TensorShapeProto.Dimension(dim_param=dim))
            else:
                raise ValueError(
                    f"Only int or str is accepted as dimension value, incorrect type: {type(dim)}"
                )
        tensor.type.tensor_type.shape.MergeFrom(TensorShapeProto(dim=new_dims))

    def apply_dims(value_infos: Any, requested_dims: Dict[str, List[Any]]) -> None:
        """Apply the requested dimensions to every tensor in ``value_infos``."""
        for tensor in value_infos:
            tensor_name = tensor.name
            dim_arr = requested_dims[tensor_name]
            # len() tests the repeated field for emptiness without comparing a
            # protobuf container against a plain list.
            if len(tensor.type.tensor_type.shape.dim) > 0:
                for axis, dim in enumerate(dim_arr):
                    update_dim(tensor, dim, axis, tensor_name)
            else:
                make_dim(tensor, dim_arr, tensor_name)

    apply_dims(model.graph.input, input_dims)
    apply_dims(model.graph.output, output_dims)

    onnx.checker.check_model(model)
    return model
151 |
152 |
def io_change(
    input_onnx_file_path: Optional[str] = '',
    onnx_graph: Optional[onnx.ModelProto] = None,
    output_onnx_file_path: Optional[str] = '',
    input_names: Optional[List[str]] = [],
    input_shapes: Optional[List[List[Union[int, str]]]] = [],
    output_names: Optional[List[str]] = [],
    output_shapes: Optional[List[List[Union[int, str]]]] = [],
    non_verbose: Optional[bool] = False,
) -> onnx.ModelProto:
    """Change the INPUT and OUTPUT shapes of an ONNX model.

    Parameters
    ----------
    input_onnx_file_path: Optional[str]
        Input onnx file path.
        Either input_onnx_file_path or onnx_graph must be specified.
        Default: ''

    onnx_graph: Optional[onnx.ModelProto]
        onnx.ModelProto.
        Either input_onnx_file_path or onnx_graph must be specified.
        If onnx_graph is specified, input_onnx_file_path is ignored and
        onnx_graph is processed.

    output_onnx_file_path: Optional[str]
        Output onnx file path. If not specified, no ONNX file is output.
        Default: ''

    input_names: Optional[List[str]]
        List of input OP names. All input OPs of the model must be specified.
        The order is unspecified, but must match the order specified for input_shapes.
        e.g. ['input.A', 'input.B', 'input.C']

    input_shapes: Optional[List[List[Union[int, str]]]]
        List of input OP shapes. All input OPs of the model must be specified.
        The order is unspecified, but must match the order specified for input_names.
        e.g.
        [
            [1, 3, 'H', 'W'],
            ['N', 3, 'H', 'W'],
            ['-1', 3, 480, 640],
        ]

    output_names: Optional[List[str]]
        List of output OP names. All output OPs of the model must be specified.
        The order is unspecified, but must match the order specified for output_shapes.
        e.g. ['output.a', 'output.b', 'output.c']

    output_shapes: Optional[List[List[Union[int, str]]]]
        List of output OP shapes. All output OPs of the model must be specified.
        The order is unspecified, but must match the order specified for output_names.
        e.g.
        [
            [1, 3, 'H', 'W'],
            ['N', 3, 'H', 'W'],
            ['-1', 3, 480, 640],
        ]

    non_verbose: Optional[bool]
        Do not show all information logs. Only error logs are displayed.
        Default: False

    Returns
    -------
    io_changed_graph: onnx.ModelProto
        onnx ModelProto with modified INPUT and OUTPUT shapes.
        If shape inference fails, the model with only the INPUT/OUTPUT
        shapes rewritten (no re-inferred intermediate shapes) is returned.

    Notes
    -----
    Argument errors are reported on stdout and terminate the process with
    ``sys.exit(1)``. The default list arguments are never mutated.
    """

    def _abort(message: str) -> None:
        # Every argument error uses the same "ERROR:" prefix and exit code.
        print(f'{Color.RED}ERROR:{Color.RESET} ' + message)
        sys.exit(1)

    def _check_names_and_shapes(kind: str, names: Any, shapes: Any) -> None:
        # kind is 'input' or 'output'; the checks run in the order
        # names -> shapes -> length match.
        if names is None or len(names) == 0:
            _abort(f'At least one {kind}_names must be specified.')
        if shapes is None or len(shapes) == 0:
            _abort(f'At least one {kind}_shapes must be specified.')
        if len(names) != len(shapes):
            _abort(f'The number of {kind}_names and {kind}_shapes must match.')

    # Unspecified check for input_onnx_file_path and onnx_graph
    if not input_onnx_file_path and not onnx_graph:
        _abort('One of input_onnx_file_path or onnx_graph must be specified.')

    # Other check
    _check_names_and_shapes('input', input_names, input_shapes)
    _check_names_and_shapes('output', output_names, output_shapes)

    # Loading Graphs
    # If onnx_graph is specified, it takes precedence over the file path
    if not onnx_graph:
        onnx_graph = onnx.load(input_onnx_file_path)

    input_dicts = dict(zip(input_names, input_shapes))
    output_dicts = dict(zip(output_names, output_shapes))

    updated_model = update_inputs_outputs_dims(
        model=onnx_graph,
        input_dims=input_dicts,
        output_dims=output_dicts,
    )

    # Shape Estimation
    try:
        io_changed_graph = onnx.shape_inference.infer_shapes(updated_model)
    except Exception:
        # Shape inference is best-effort. Fall back to the model whose
        # INPUT/OUTPUT shapes were already rewritten, so that saving and the
        # return value never receive None.
        io_changed_graph = updated_model
        if not non_verbose:
            print(
                f'{Color.YELLOW}WARNING:{Color.RESET} '+
                'The input shape of the next OP does not match the output shape. '+
                'Be sure to open the .onnx file to verify the certainty of the geometry.'
            )

    # Save
    if output_onnx_file_path:
        onnx.save(io_changed_graph, output_onnx_file_path)

    if not non_verbose:
        print(f'{Color.GREEN}INFO:{Color.RESET} Finish!')

    # Return
    return io_changed_graph
307 |
308 |
def _parse_shape_args(raw_shapes):
    """Convert command-line shape tokens into lists of int/str dimensions.

    ``raw_shapes`` is a list of token lists, one per repeated shape option.
    A token that evaluates to a non-negative integer literal becomes an
    ``int``; every other token (symbolic names such as "H", negative numbers
    such as "-1", floats, ...) is kept as the original string.
    """
    parsed_shapes = []
    for tokens in raw_shapes:
        shape = []
        for token in tokens:
            try:
                value = literal_eval(token)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Not a Python literal at all: treat it as a symbolic name.
                shape.append(token)
                continue
            # bool is a subclass of int; "True"/"False" must stay symbolic
            # names instead of silently becoming the dimensions 1/0.
            is_fixed_size = (
                isinstance(value, int)
                and not isinstance(value, bool)
                and value >= 0
            )
            shape.append(value if is_fixed_size else token)
        parsed_shapes.append(shape)
    return parsed_shapes


def main():
    """Command-line entry point: parse arguments and run ``io_change``."""
    parser = ArgumentParser()
    parser.add_argument(
        '-if',
        '--input_onnx_file_path',
        type=str,
        required=True,
        help='INPUT ONNX file path'
    )
    parser.add_argument(
        '-of',
        '--output_onnx_file_path',
        type=str,
        required=True,
        help='OUTPUT ONNX file path'
    )
    parser.add_argument(
        '-i',
        '--input_names',
        type=str,
        action='append',
        required=True,
        help=(
            'List of input OP names. All input OPs of the model must be specified. '
            'The order is unspecified, but must match the order specified for input_shapes. '
            'e.g. '
            '--input_names "input.A" '
            '--input_names "input.B" '
            '--input_names "input.C"'
        )
    )
    parser.add_argument(
        '-is',
        '--input_shapes',
        type=str,
        nargs='+',
        action='append',
        required=True,
        help=(
            'List of input OP shapes. All input OPs of the model must be specified. '
            'The order is unspecified, but must match the order specified for input_names. '
            'e.g. '
            '--input_shapes 1 3 "H" "W" '
            '--input_shapes "N" 3 "H" "W" '
            '--input_shapes "-1" 3 480 640'
        )
    )
    parser.add_argument(
        '-o',
        '--output_names',
        type=str,
        action='append',
        required=True,
        help=(
            'List of output OP names. All output OPs of the model must be specified. '
            'The order is unspecified, but must match the order specified for output_shapes. '
            'e.g. '
            '--output_names "output.a" '
            '--output_names "output.b" '
            '--output_names "output.c"'
        )
    )
    parser.add_argument(
        '-os',
        '--output_shapes',
        type=str,
        nargs='+',
        action='append',
        required=True,
        help=(
            'List of output OP shapes. All output OPs of the model must be specified. '
            'The order is unspecified, but must match the order specified for output_names. '
            'e.g. '
            '--output_shapes 1 3 "H" "W" '
            '--output_shapes "N" 3 "H" "W" '
            '--output_shapes "-1" 3 480 640'
        )
    )
    parser.add_argument(
        '-n',
        '--non_verbose',
        action='store_true',
        help='Do not show all information logs. Only error logs are displayed.'
    )
    args = parser.parse_args()

    input_onnx_file_path = args.input_onnx_file_path
    # An explicitly empty output path means "overwrite the input file".
    output_onnx_file_path = args.output_onnx_file_path or input_onnx_file_path

    # Load
    onnx_graph = onnx.load(input_onnx_file_path)

    # change
    io_change(
        input_onnx_file_path=None,
        onnx_graph=onnx_graph,
        output_onnx_file_path=output_onnx_file_path,
        input_names=list(args.input_names),
        input_shapes=_parse_shape_args(args.input_shapes),
        output_names=list(args.output_names),
        output_shapes=_parse_shape_args(args.output_shapes),
        non_verbose=args.non_verbose,
    )
447 |
448 |
449 | if __name__ == '__main__':
450 | main()
--------------------------------------------------------------------------------