├── .gitignore
├── LICENSE
├── onnx2tflite
│   ├── __init__.py
│   ├── __main__.py
│   ├── components
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── dataloader.py
│   │   ├── onnx_loader.py
│   │   └── output_check.py
│   ├── converter.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── activations_layers.py
│   │   ├── common_layers.py
│   │   ├── conv_layers.py
│   │   ├── deformation_layers.py
│   │   └── mathematics_layers.py
│   └── utils
│       ├── __init__.py
│       ├── definitions.py
│       ├── dimension_utils.py
│       ├── graph_tools.py
│       └── op_registry.py
├── readme.md
├── requirements.txt
├── setup.py
└── test
    ├── test_concat.py
    ├── test_reshape_transpose.py
    ├── test_squeeze.py
    └── test_torchvison.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.onnx
2 | *.tflite
3 | __pycache__/
4 | .ipynb_checkpoints/
5 | test.py
6 | gen_model.py
7 | models/
8 | unit_test/
9 | onnx2tflite.egg-info/
10 | build/
11 | dist/
12 | main.py
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [MPolaris] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/onnx2tflite/__init__.py:
--------------------------------------------------------------------------------
1 | __VERSION__ = "2.0"
2 |
3 | from .converter import onnx_converter
--------------------------------------------------------------------------------
/onnx2tflite/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from .converter import onnx_converter
3 |
4 | def parse_opt():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--weights', type=str, required=True, help='onnx model path')
7 | parser.add_argument('--outpath', type=str, default=None, help='tflite model save path')
8 |     parser.add_argument('--input-node-names', nargs="+", default=None, help='input node names; intermediate layers are supported; None uses the original onnx inputs')
9 |     parser.add_argument('--output-node-names', nargs="+", default=None, help='output node names; intermediate layers are supported; None uses the original onnx outputs')
10 | parser.add_argument('--nosimplify', default=False, action='store_true', help='do not simplify model')
11 |     parser.add_argument("--native-groupconv", default=False, action='store_true', help='use the native method for group convolution; only supported for tflite version >= 2.9')
12 |     parser.add_argument('--weightquant', default=False, action='store_true', help='weight-only int8 quantization')
13 | parser.add_argument('--fp16', default=False, action='store_true', help='fp16 quant, include input output')
14 | parser.add_argument('--int8', default=False, action='store_true', help='int8 quant, include input output')
15 |     parser.add_argument('--imgroot', type=str, default=None, help='when --int8 is set, imgroot must be given for computing the calibration statistics')
16 |     parser.add_argument('--int8mean', type=float, nargs='+', default=[123.675, 116.28, 103.53], help='int8 image preprocessing mean, float or list')
17 |     parser.add_argument('--int8std', type=float, nargs='+', default=[58.395, 57.12, 57.375], help='int8 image preprocessing std, float or list')
18 |     parser.add_argument('--formats', nargs='+', default=['keras', 'tflite'], help="available formats are ('keras', 'tflite')")
19 | opt = parser.parse_args()
20 | return opt
21 |
22 | def run():
23 | opt = parse_opt()
24 | onnx_converter(
25 | onnx_model_path = opt.weights,
26 | need_simplify = not opt.nosimplify,
27 | input_node_names = opt.input_node_names,
28 | output_node_names = opt.output_node_names,
29 | output_path = opt.outpath,
30 | target_formats = opt.formats,
31 | native_groupconv = opt.native_groupconv,
32 |         weight_quant=opt.weightquant,
33 | fp16_model=opt.fp16,
34 | int8_model=opt.int8,
35 | int8_mean=opt.int8mean,
36 | int8_std=opt.int8std,
37 | image_root=opt.imgroot
38 | )
39 |
40 | if __name__ == "__main__":
41 | run()
--------------------------------------------------------------------------------
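
With the entry point above, the package can be driven from the command line via `python -m onnx2tflite`. A minimal invocation sketch (paths are hypothetical; the flags correspond to `parse_opt`):

    # plain float conversion to both keras (.h5) and tflite
    python -m onnx2tflite --weights ./model.onnx --outpath ./out --formats keras tflite

    # full int8 quantization, calibrated on images under ./calib_images
    python -m onnx2tflite --weights ./model.onnx --int8 --imgroot ./calib_images
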
/onnx2tflite/components/__init__.py:
--------------------------------------------------------------------------------
1 | from .output_check import get_elements_error
2 | from .onnx_loader import load_onnx_modelproto
3 | from .builder import keras_builder, tflite_builder
4 |
5 | __all__ = ['load_onnx_modelproto', 'keras_builder', 'tflite_builder', 'get_elements_error']
--------------------------------------------------------------------------------
/onnx2tflite/components/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3 |
4 | import tensorflow as tf
5 | from tensorflow import keras
6 | from onnx import numpy_helper
7 | from .dataloader import RandomLoader, ImageLoader
8 |
9 | from onnx2tflite.utils import OPERATOR
10 | from onnx2tflite.layers import conv_layers
11 | from onnx2tflite.utils.definitions import *
12 | from onnx2tflite.utils.graph_tools import build_tf_inputs, decode_node_attribute
13 |
14 | def keras_builder(onnx_model, native_groupconv:bool=False):
15 |
16 | conv_layers.USE_NATIVE_GROUP_CONV = native_groupconv
17 |
18 | model_graph = onnx_model.graph
19 | layout_dict, tf_tensor = {}, {}
20 |
21 | '''
22 |     init onnx model's built-in tensors (initializers, i.e. weights)
23 | '''
24 | onnx_weights = dict()
25 | for initializer in model_graph.initializer:
26 | onnx_weights[initializer.name] = numpy_helper.to_array(initializer)
27 |
28 | '''
29 | build input nodes
30 | '''
31 | input_nodes = build_tf_inputs(model_graph, layout_dict)
32 | tf_tensor.update(input_nodes)
33 |
34 | '''
35 |     build model nodes by iterating over the onnx graph.
36 | '''
37 | for node in model_graph.node:
38 | op_name, node_inputs, node_outputs = node.op_type, node.input, node.output
39 | op_attr = decode_node_attribute(node)
40 |
41 | tf_operator = OPERATOR.get(op_name)
42 | if tf_operator is None:
43 | raise KeyError(f"{op_name} not implemented yet")
44 |
45 | _inputs = None
46 | if len(node_inputs) > 0:
47 | _inputs = tf_tensor[node_inputs[0]] if node_inputs[0] in tf_tensor else onnx_weights[node_inputs[0]]
48 |
49 |         # init layout: outputs inherit the first input's layout; nodes without inputs (e.g. Constant) fall back to Default
50 |         for index in range(len(node_outputs)):
51 |             layout_dict[node_outputs[index]] = layout_dict.get(node_inputs[0], Layout.Default) if len(node_inputs) > 0 else Layout.Default
52 |
53 | res = tf_operator(tf_tensor, onnx_weights, node_inputs, op_attr, node_outputs, layout_dict)(_inputs)
54 | if isinstance(res, list):
55 | for index in range(len(node_outputs)):
56 | tf_tensor[node_outputs[index]] = res[index]
57 | else:
58 | tf_tensor[node_outputs[0]] = res
59 |
60 | '''
61 | build keras model
62 | '''
63 | input_nodes = [tf_tensor[x.name] for x in model_graph.input]
64 | outputs_nodes = [tf_tensor[x.name] for x in model_graph.output]
65 | keras_model = keras.Model(inputs=input_nodes, outputs=outputs_nodes)
66 | keras_model.trainable = False
67 | # keras_model.summary()
68 | # print(layout_dict)
69 | input_layout, output_layout = {}, {}
70 | for inp in model_graph.input:
71 | input_layout[inp.name] = layout_dict[inp.name]
72 | for oup in model_graph.output:
73 | output_layout[oup.name] = layout_dict[oup.name]
74 | return keras_model, input_layout, output_layout
75 |
76 | def tflite_builder(keras_model, weight_quant:bool=False, fp16_model=False, int8_model:bool=False, image_root:str=None,
77 | int8_mean:list or float = [123.675, 116.28, 103.53], int8_std:list or float = [58.395, 57.12, 57.375]):
78 | converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
79 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
80 | if weight_quant or int8_model or fp16_model:
81 | converter.experimental_new_converter = True
82 | converter.optimizations = [tf.lite.Optimize.DEFAULT]
83 |
84 | if fp16_model:
85 | converter.target_spec.supported_types = [tf.float16]
86 | converter.inference_input_type = tf.float32
87 | converter.inference_output_type = tf.float32
88 | elif int8_model:
89 |         assert len(keras_model.inputs) == 1, "int8 quantization only supports single-input models."
90 | shape = list(keras_model.inputs[0].shape)
91 | dataset = RandomLoader(shape) if image_root is None else ImageLoader(image_root, shape, int8_mean, int8_std)
92 | converter.representative_dataset = lambda: dataset
93 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8, tf.lite.OpsSet.SELECT_TF_OPS]
94 | converter.target_spec.supported_types = []
95 | converter.inference_input_type = tf.uint8
96 | converter.inference_output_type = tf.uint8
97 | converter.experimental_new_converter = True
98 |
99 | tflite_model = converter.convert()
100 | return tflite_model
--------------------------------------------------------------------------------
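
The two builders above can also be used directly when the extra bookkeeping in `converter.py` is not needed. A minimal sketch, assuming a `model.onnx` (hypothetical path) whose operators are all registered in `OPERATOR`:

    import onnx
    from onnx2tflite.components import keras_builder, tflite_builder

    # build the keras model plus the per-tensor layout maps (NCHW vs NHWC)
    model_proto = onnx.load("model.onnx")
    keras_model, input_layout, output_layout = keras_builder(model_proto, native_groupconv=False)

    # lower to a tflite flatbuffer, here with weight-only quantization
    tflite_bytes = tflite_builder(keras_model, weight_quant=True)
    with open("model.tflite", "wb") as fp:
        fp.write(tflite_bytes)
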
/onnx2tflite/components/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import logging
4 | import numpy as np
5 |
6 | LOG = logging.getLogger("Quantization DataLoader :")
7 |
8 | class RandomLoader(object):
9 | def __init__(self, target_size):
10 | self.target_size = target_size
11 |         LOG.warning("Generating quantization data randomly; this may hurt quantization accuracy!")
12 |
13 | def __iter__(self):
14 | self.index = 0
15 | return self
16 |
17 | def __next__(self):
18 | if self.index > 5:
19 | raise StopIteration()
20 | self.index += 1
21 | return [np.random.randn(*self.target_size).astype(np.float32)]
22 |
23 | class ImageLoader(object):
24 | '''
25 | generate data for quantization from image datas.
26 | img_quan_data = (img - mean)/std, it's important for accuracy of model.
27 | '''
28 | VALID_FORMAT = ['.jpg', '.png', '.jpeg']
29 |
30 | def __init__(self, img_root, target_size, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) -> None:
31 |         assert os.path.exists(img_root), f"{img_root} does not exist, please check!"
32 | self.fns = os.listdir(img_root)
33 | self.fns = list(filter(lambda fn: os.path.splitext(fn)[-1].lower() in self.VALID_FORMAT, self.fns))
34 | self.nums = len(self.fns)
35 | assert self.nums > 0, f"No images detected in {img_root}."
36 | if self.nums > 100:
37 | LOG.warning(f"{self.nums} images detected, the number of recommended images is less than 100.")
38 | else:
39 | LOG.info(f"{self.nums} images detected.")
40 | self.fns = [os.path.join(img_root, fn) for fn in self.fns]
41 |
42 |         self.batch, self.size = target_size[0], (target_size[2], target_size[1])  # cv2.resize expects (width, height)
43 | if isinstance(mean, list):
44 | mean = np.array(mean, dtype=np.float32)
45 | if isinstance(std, list):
46 | std = np.array(std, dtype=np.float32)
47 | self.mean, self.std = mean, std
48 |
49 | def __iter__(self):
50 | self.index = 0
51 | return self
52 |
53 | def __next__(self):
54 | if self.index >= self.nums:
55 | raise StopIteration()
56 |
57 | _input = cv2.imread(self.fns[self.index])
58 | _input = cv2.resize(_input, self.size)[:, :, ::-1]#BGR->RGB
59 | _input = _input.astype(np.float32)
60 |
61 | if self.mean is not None:
62 | _input = (_input - self.mean)
63 | if self.std is not None:
64 | _input = _input/self.std
65 |
66 | _input = np.expand_dims(_input, axis=0)
67 | if self.batch > 1:
68 | _input = np.repeat(_input, self.batch, axis=0).astype(np.float32)
69 |
70 | self.index += 1
71 | return [_input]
72 |
--------------------------------------------------------------------------------
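
Both loaders are plain iterables that yield single-element lists, which is the shape expected by the `representative_dataset` hook used in `tflite_builder`. A small usage sketch (the calibration folder and input shape are hypothetical):

    from onnx2tflite.components.dataloader import ImageLoader

    # target_size is the NHWC input shape of the converted keras model
    loader = ImageLoader("./calib_images", target_size=[1, 224, 224, 3])
    for sample in loader:
        # each sample is a list holding one float32 array of shape (1, 224, 224, 3)
        print(sample[0].shape, sample[0].dtype)
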
/onnx2tflite/components/onnx_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import onnx
3 | import logging
4 | from onnxsim import simplify
5 |
6 | LOG = logging.getLogger("onnx_loader running:")
7 | LOG.setLevel(logging.INFO)
8 |
9 | def clean_model_input(model_proto):
10 | inputs = model_proto.graph.input
11 | name_to_input = {}
12 | for input in inputs:
13 | name_to_input[input.name] = input
14 |
15 | names = []
16 | for initializer in model_proto.graph.initializer:
17 | if initializer.name in name_to_input:
18 | inputs.remove(name_to_input[initializer.name])
19 | names.append(initializer.name)
20 |
21 | if len(names) > 0:
22 |         LOG.warning(f"[{len(names)}] redundant input nodes are removed. "
23 |                     f"node names: {','.join(names)}")
24 |
25 | def get_onnx_submodel(onnx_model_path:str, input_node_names:list=None, output_node_names:list=None):
26 | '''
27 | cutoff onnx model
28 | '''
29 | model_proto = onnx.load(onnx_model_path)
30 | if input_node_names is None:
31 | input_node_names = []
32 | for inp in model_proto.graph.input:
33 | input_node_names.append(inp.name)
34 |
35 | if output_node_names is None:
36 | output_node_names = []
37 | for oup in model_proto.graph.output:
38 | output_node_names.append(oup.name)
39 | del model_proto
40 |
41 | new_model_path = os.path.splitext(onnx_model_path)[0] + "_sub.onnx"
42 | onnx.utils.extract_model(onnx_model_path, new_model_path, input_node_names, output_node_names)
43 | model_proto = onnx.load(new_model_path)
44 | return model_proto
45 |
46 | def get_proto(onnx_model_path:str, input_node_names:list=None, output_node_names:list=None):
47 | if input_node_names is None and output_node_names is None:
48 | return onnx.load(onnx_model_path)
49 | else:
50 | return get_onnx_submodel(onnx_model_path, input_node_names, output_node_names)
51 |
52 | def load_onnx_modelproto(onnx_model_path:str, input_node_names:list=None, output_node_names:list=None, need_simplify:bool=True):
53 | if not os.path.exists(onnx_model_path):
54 |         LOG.error(f"{onnx_model_path} does not exist.")
55 |         raise FileNotFoundError(f"{onnx_model_path} does not exist.")
56 | model_proto = get_proto(onnx_model_path, input_node_names, output_node_names)
57 | dynamic_input = False
58 | for inp in model_proto.graph.input:
59 | for x in inp.type.tensor_type.shape.dim:
60 | if x.dim_value <= 0:
61 | dynamic_input = True
62 | break
63 | if need_simplify:
64 | success = False
65 | try:
66 | model_proto, success = simplify(model_proto, check_n=1, dynamic_input_shape=dynamic_input)
67 |         except Exception:
68 | success = False
69 | if not success:
70 |             LOG.warning("onnxsim failed; the conversion may fail as a result.")
71 | model_proto = onnx.load(onnx_model_path)
72 | clean_model_input(model_proto)
73 | return model_proto
--------------------------------------------------------------------------------
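
`load_onnx_modelproto` is the single entry point used by `converter.py`; passing node names routes through `get_onnx_submodel`, which crops the graph with `onnx.utils.extract_model` and writes the result next to the original as `*_sub.onnx`. A sketch (model path and tensor names are hypothetical):

    from onnx2tflite.components import load_onnx_modelproto

    # load and simplify the whole model
    proto = load_onnx_modelproto("model.onnx")

    # or cut a sub-graph between intermediate tensors before simplifying
    sub_proto = load_onnx_modelproto(
        "model.onnx",
        input_node_names=["input.1"],
        output_node_names=["/backbone/Concat_output_0"],
    )
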
/onnx2tflite/components/output_check.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 | import onnxruntime as ort
5 | from onnx2tflite.utils.definitions import Layout
6 | from onnx2tflite.utils.dimension_utils import tensor_NDC_to_NCD_format
7 |
8 | def tflite_run(model_path:str) -> np.ndarray:
9 | '''
10 | tflite runtime
11 | '''
12 | tflite_runtime = tf.lite.Interpreter(model_path, num_threads=4)
13 | tflite_runtime.allocate_tensors()
14 | input_details, output_details = tflite_runtime.get_input_details(), tflite_runtime.get_output_details()
15 | for i in range(len(input_details)):
16 | tflite_runtime.set_tensor(input_details[i]['index'], np.ones(input_details[i]['shape'], dtype=np.float32))
17 | tflite_runtime.invoke()
18 |
19 | # only compare one output is ok.
20 | tflite_output = tflite_runtime.get_tensor(output_details[0]['index'])
21 | return tflite_output
22 |
23 | def keras_run(model_path:str) -> np.ndarray:
24 | '''
25 | keras runtime
26 | '''
27 | keras_runtime = tf.keras.models.load_model(model_path)
28 | _input = []
29 | for inp in keras_runtime.inputs:
30 | _input.append(np.ones(list(inp.shape), dtype=np.float32))
31 |
32 | keras_output = keras_runtime.predict(_input)
33 | # only compare one output is ok.
34 | if isinstance(keras_output, list):
35 | keras_output = keras_output[0]
36 | return keras_output
37 |
38 |
39 | def get_elements_error(onnx_proto, keras_model_path:str, tflite_model_path:str, input_layout:dict, output_layout:dict) -> dict:
40 | '''
41 |     run each model on an all-ones input and compare element-wise errors.
42 |     more careful checking is up to your own custom code.
43 | '''
44 | result = {}
45 | # test onnx
46 | onnx_runtime = ort.InferenceSession(onnx_proto.SerializeToString())
47 | onnx_inputs = {}
48 | for inp in onnx_runtime.get_inputs():
49 | shape = inp.shape
50 | if isinstance(shape[0], str) or shape[0] < 1:
51 | shape[0] = 1
52 | onnx_inputs[inp.name] = np.ones(shape, dtype=np.float32)
56 | onnx_outputs = onnx_runtime.run([], onnx_inputs)
57 |
58 | channel_last = False
59 | for oup in onnx_proto.graph.output:
60 | channel_last = output_layout[oup.name] == Layout.Channel_Last
61 | break
62 |
63 | if keras_model_path is not None:
64 | # test keras model
65 | keras_output = keras_run(keras_model_path)
66 | if channel_last:
67 | keras_output = tensor_NDC_to_NCD_format(keras_output)
68 | # get max error
69 | keras_max_error = 1000
70 | for onnx_output in onnx_outputs:
71 | if onnx_output.shape != keras_output.shape:
72 | continue
73 | diff = np.abs(onnx_output - keras_output)
74 | max_diff = np.max(diff)
75 | keras_max_error = min(keras_max_error, max_diff)
76 | result['keras'] = keras_max_error
77 |
78 | if tflite_model_path is not None:
79 | # test tflite
80 | tflite_output = tflite_run(tflite_model_path)
81 | if channel_last:
82 | tflite_output = tensor_NDC_to_NCD_format(tflite_output)
83 | # get max error
84 | tflite_max_error = 1000
85 | for onnx_output in onnx_outputs:
86 | if onnx_output.shape != tflite_output.shape:
87 | continue
88 | diff = np.abs(onnx_output - tflite_output)
89 | max_diff = np.max(diff)
90 | tflite_max_error = min(tflite_max_error, max_diff)
91 | result['tflite'] = tflite_max_error
92 |
93 | return result
--------------------------------------------------------------------------------
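
`get_elements_error` is what `converter.py` calls after export, but it can also be run standalone against already-saved artifacts. A sketch, assuming the files produced by `onnx_converter` (paths hypothetical); the layout dicts map tensor names to `Layout` values, as returned by `keras_builder`:

    import onnx
    from onnx2tflite.components import get_elements_error
    from onnx2tflite.utils.definitions import Layout

    proto = onnx.load("model.onnx")
    out_layouts = {o.name: Layout.Channel_Last for o in proto.graph.output}
    errors = get_elements_error(proto, "model.h5", "model.tflite",
                                input_layout={}, output_layout=out_layouts)
    print(errors)  # e.g. {'keras': 1.2e-06, 'tflite': 3.4e-06}
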
/onnx2tflite/converter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from .components import load_onnx_modelproto, keras_builder, tflite_builder, get_elements_error
4 |
5 | logging.basicConfig(level=logging.INFO)
6 | LOG = logging.getLogger("converter running:")
7 |
8 | def onnx_converter(onnx_model_path:str, output_path:str=None,
9 | input_node_names:list=None, output_node_names:list=None,
10 | need_simplify:bool=True, target_formats:list = ['keras', 'tflite'],
11 | native_groupconv:bool=False,
12 | weight_quant:bool=False, fp16_model:bool=False, int8_model:bool=False, image_root:str=None,
13 |                    int8_mean:list or float = [123.675, 116.28, 103.53], int8_std:list or float = [58.395, 57.12, 57.375])->dict:
14 | """
15 | Converts an ONNX model to various target formats with optional optimizations.
16 |
17 | Parameters:
18 | onnx_model_path (str): Path to the input ONNX model file.
19 | output_path (str, optional): Path to save the converted model(s). If None, the converted model(s) will be saved in the same directory as the input model.
20 | input_node_names (list, optional): List of input node names. If None, the default input nodes of the ONNX model are used.
21 | output_node_names (list, optional): List of output node names. If None, the default output nodes of the ONNX model are used.
22 | need_simplify (bool, optional): If True, the ONNX model will be simplified before conversion. Default is True.
23 | target_formats (list, optional): List of target formats to convert the ONNX model to. Default is ['keras', 'tflite'].
24 | native_groupconv (bool, optional): If True, retains native group convolution operations during conversion. Default is False.
25 | weight_quant (bool, optional): If True, applies weight quantization to the converted model. Default is False.
26 | fp16_model (bool, optional): If True, converts the model to use FP16 precision. Default is False.
27 | int8_model (bool, optional): If True, converts the model to use INT8 precision. Default is False.
28 | image_root (str, optional): Path to the root directory of images for calibration if INT8 quantization is enabled. Default is None.
29 | int8_mean (list or float, optional): Mean values for INT8 quantization. Default is [123.675, 116.28, 103.53].
30 | int8_std (list or float, optional): Standard deviation values for INT8 quantization. Default is [58.395, 57.12, 57.375].
31 |
32 | Returns:
33 |         dict: paths of the converted model(s) and their element-wise max errors ('keras', 'tflite', 'keras_error', 'tflite_error').
34 |
35 | Note:
36 | - The function supports multiple target formats for conversion and allows for various optimizations such as simplification, quantization, and precision reduction.
37 | - When INT8 quantization is enabled, 'image_root', 'int8_mean', and 'int8_std' parameters are used for calibration.
38 | """
39 |     if not isinstance(target_formats, list) or ('keras' not in target_formats and 'tflite' not in target_formats):
40 |         raise KeyError("target_formats must be a list containing 'keras' and/or 'tflite'")
41 |
42 | model_proto = load_onnx_modelproto(onnx_model_path, input_node_names, output_node_names, need_simplify)
43 |
44 | keras_model, input_layout, output_layout = keras_builder(model_proto, native_groupconv)
45 |
46 | if 'tflite' in target_formats:
47 | tflite_model = tflite_builder(keras_model, weight_quant, fp16_model, int8_model, image_root, int8_mean, int8_std)
48 |
49 | onnx_path, model_name = os.path.split(onnx_model_path)
50 | if output_path is None:
51 | output_path = onnx_path
52 | output_path = os.path.join(output_path, model_name.split('.')[0])
53 |
54 | if fp16_model:
55 | output_path = output_path + "_fp16"
56 | elif int8_model:
57 | output_path = output_path + "_int8"
58 |
59 | keras_model_path = None
60 | if 'keras' in target_formats:
61 | keras_model_path = output_path + ".h5"
62 | keras_model.save(keras_model_path)
63 | LOG.info(f"keras model saved in {keras_model_path}")
64 |
65 | tflite_model_path = None
66 | if 'tflite' in target_formats:
67 | tflite_model_path = output_path + ".tflite"
68 | with open(tflite_model_path, "wb") as fp:
69 | fp.write(tflite_model)
70 |
71 | convert_result = {"keras":keras_model_path, "tflite":tflite_model_path, "keras_error":0, "tflite_error":0}
72 | # ignore quantization model
73 | if int8_model:
74 | return convert_result
75 |
76 | error_dict = {}
77 | try:
78 | error_dict = get_elements_error(model_proto, keras_model_path, tflite_model_path, input_layout, output_layout)
79 | keras_error, tflite_error = error_dict.get("keras", None), error_dict.get("tflite", None)
80 |         if keras_error is not None:
81 | if keras_error > 1e-2:
82 | LOG.error("h5 model elements' max error has reached {:^.4E}, but convert is done, please check {} carefully!".format(keras_error, keras_model_path))
83 | elif keras_error > 1e-4:
84 | LOG.warning("h5 model elements' max error is {:^.4E}, pass, h5 saved in {}".format(keras_error, keras_model_path))
85 | else:
86 | LOG.info("h5 model elements' max error is {:^.4E}, pass, h5 saved in {}".format(keras_error, keras_model_path))
87 |         if tflite_error is not None:
88 | if tflite_error > 1e-2:
89 | LOG.error("tflite model elements' max error has reached {:^.4E}, but convert is done, please check {} carefully!".format(tflite_error, tflite_model_path))
90 | elif tflite_error > 1e-4:
91 | LOG.warning("tflite model elements' max error is {:^.4E}, pass, tflite saved in {}".format(tflite_error, tflite_model_path))
92 | else:
93 | LOG.info("tflite model elements' max error is {:^.4E}, pass, tflite saved in {}".format(tflite_error, tflite_model_path))
94 |     except Exception:
95 |         LOG.warning("conversion succeeded, but running the converted model failed, please check carefully!")
96 |
97 | convert_result["keras_error"] = error_dict.get("keras", None)
98 | convert_result["tflite_error"] = error_dict.get("tflite", None)
99 | return convert_result
--------------------------------------------------------------------------------
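
For programmatic use, the docstring above maps one-to-one onto keyword arguments. A minimal sketch (paths hypothetical):

    from onnx2tflite import onnx_converter

    result = onnx_converter(
        onnx_model_path="model.onnx",
        output_path=None,                   # save next to the input model
        target_formats=["keras", "tflite"],
        need_simplify=True,
    )
    print(result["tflite"], result["tflite_error"])
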
/onnx2tflite/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .conv_layers import *
2 | from .common_layers import *
3 | from .activations_layers import *
4 | from .mathematics_layers import *
5 | from .deformation_layers import *
--------------------------------------------------------------------------------
/onnx2tflite/layers/activations_layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from tensorflow import keras
4 |
5 | from onnx2tflite.utils.definitions import Layout
6 | from onnx2tflite.utils import OPERATOR, channel_to_last_dimension, tensor_NCD_to_NDC_format
7 |
8 | @OPERATOR.register_operator("Relu")
9 | class TFRelu():
10 | def __init__(self, *args, **kwargs) -> None:
11 | super().__init__()
12 |
13 | def __call__(self, inputs):
14 | return keras.activations.relu(inputs)
15 |
16 | @OPERATOR.register_operator("HardSigmoid")
17 | class TFHardSigmoid():
18 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
19 | super().__init__()
20 | self.alpha = node_attribute.get("alpha", 0.2)
21 | self.beta = node_attribute.get("beta", 0.5)
22 |
23 | def __call__(self, inputs):
24 | return tf.clip_by_value(self.alpha*inputs+self.beta, 0, 1)
25 |
26 | @OPERATOR.register_operator("HardSwish")
27 | class TFHardSwish():
28 | def __init__(self, *args, **kwargs) -> None:
29 | super().__init__()
30 |
31 | def __call__(self, inputs):
32 | return inputs*tf.clip_by_value(inputs/6+0.5, 0, 1)
33 |
34 | @OPERATOR.register_operator("Mish")
35 | class TFMish():
36 | def __init__(self, *args, **kwargs) -> None:
37 | super().__init__()
38 |
39 | def __call__(self, inputs):
40 | return inputs*tf.tanh(tf.math.log(tf.math.exp(inputs)+1))
41 |
42 | @OPERATOR.register_operator("Sigmoid")
43 | class TFSigmoid():
44 | def __init__(self, *args, **kwargs) -> None:
45 | super().__init__()
46 |
47 | def __call__(self, inputs):
48 | return keras.activations.sigmoid(inputs)
49 |
50 | @OPERATOR.register_operator("LeakyRelu")
51 | class TFLeakyRelu():
52 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
53 | super().__init__()
54 | self.alpha = node_attribute.get('alpha', 0.01)
55 |
56 | def __call__(self, inputs):
57 | return keras.activations.relu(inputs, alpha=self.alpha)
58 |
59 | @OPERATOR.register_operator("PRelu")
60 | class TFPRelu():
61 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
62 | super().__init__()
63 | if 'slope' in node_attribute:
64 | self.slope = node_attribute['slope']
65 | elif node_inputs[1] in node_weights:
66 | self.slope = node_weights[node_inputs[1]]
67 | else:
68 | self.slope = tensor_grap[node_inputs[1]]
69 | input_tensor_shape = tensor_grap[node_inputs[0]].shape
70 | channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
71 | if isinstance(self.slope, np.ndarray):
72 | while self.slope.ndim < input_tensor_shape.ndims:
73 | self.slope = self.slope[np.newaxis, :]
74 | if channel_last:
75 | self.slope = tensor_NCD_to_NDC_format(self.slope)
76 | if self.slope.ndim > 1:
77 | # remove batchsize
78 | self.slope = self.slope[0]
79 | axes = [i for i in range(1, input_tensor_shape.ndims-1)] if channel_last else [i for i in range(2, input_tensor_shape.ndims)]
80 | self.PRelu = tf.keras.layers.PReLU(weights=[self.slope], shared_axes = axes)
81 |
82 | def __call__(self, inputs):
83 | return self.PRelu(inputs)
84 |
85 | @OPERATOR.register_operator("Sin")
86 | class TFSin():
87 | def __init__(self, *args, **kwargs) -> None:
88 | super().__init__()
89 |
90 | def __call__(self, inputs):
91 | return tf.sin(inputs)
92 |
93 | @OPERATOR.register_operator("Sinh")
94 | class TFSinh():
95 | def __init__(self, *args, **kwargs) -> None:
96 | super().__init__()
97 |
98 | def __call__(self, inputs):
99 | return tf.sinh(inputs)
100 |
101 | @OPERATOR.register_operator("Cos")
102 | class TFCos():
103 | def __init__(self, *args, **kwargs) -> None:
104 | super().__init__()
105 |
106 | def __call__(self, inputs):
107 | return tf.cos(inputs)
108 |
109 | @OPERATOR.register_operator("Cosh")
110 | class TFCosh():
111 | def __init__(self, *args, **kwargs) -> None:
112 | super().__init__()
113 |
114 | def __call__(self, inputs):
115 | return tf.cosh(inputs)
116 |
117 | @OPERATOR.register_operator("Tan")
118 | class TFTan():
119 | def __init__(self, *args, **kwargs) -> None:
120 | super().__init__()
121 |
122 | def __call__(self, inputs):
123 | return tf.tan(inputs)
124 |
125 | @OPERATOR.register_operator("Tanh")
126 | class TFTanh():
127 | def __init__(self, *args, **kwargs) -> None:
128 | super().__init__()
129 |
130 | def __call__(self, inputs):
131 | return tf.tanh(inputs)
132 |
133 | @OPERATOR.register_operator("Softmax")
134 | class TFSoftmax():
135 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
136 | super().__init__()
137 | self.axis = node_attribute.get('axis', -1)
138 | if self.axis == -1:
139 | self.axis = len(tensor_grap[node_inputs[0]].shape.as_list()) - 1
140 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
141 | self.axis = channel_to_last_dimension(self.axis)
142 |
143 | def __call__(self, inputs):
144 | return keras.activations.softmax(inputs, axis=self.axis)
145 |
146 | @OPERATOR.register_operator("Softplus")
147 | class TFSoftplus():
148 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
149 | super().__init__()
150 |
151 | def __call__(self, inputs):
152 | return keras.activations.softplus(inputs)
153 |
154 | @OPERATOR.register_operator("Softsign")
155 | class TFSoftsign():
156 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
157 | super().__init__()
158 |
159 | def __call__(self, inputs):
160 | return keras.activations.softsign(inputs)
161 |
162 | @OPERATOR.register_operator("Selu")
163 | class TFSelu():
164 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
165 | super().__init__()
166 |
167 | def __call__(self, inputs):
168 | return keras.activations.selu(inputs)
169 |
170 | @OPERATOR.register_operator("Elu")
171 | class TFElu():
172 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
173 | super().__init__()
174 |
175 | def __call__(self, inputs):
176 | return keras.activations.elu(inputs)
177 |
178 | @OPERATOR.register_operator("Celu")
179 | class TFCelu():
180 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
181 | super().__init__()
182 | self.alpha = node_attribute.get("alpha", 1.0)
183 |
184 | def __call__(self, inputs):
185 | return tf.maximum(inputs, 0) + tf.minimum(0, self.alpha*(tf.exp(inputs/self.alpha)-1))
--------------------------------------------------------------------------------
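
Every class above follows the same registry contract: `@OPERATOR.register_operator(<onnx op type>)` makes the class discoverable by `keras_builder`, the constructor receives `(tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict)` (unused arguments are absorbed by `*args/**kwargs`), and `__call__` maps the input tensor to the output tensor. A sketch of adding an op under that contract, assuming `Erf` is not already covered by one of the other layer files:

    import tensorflow as tf
    from onnx2tflite.utils import OPERATOR

    @OPERATOR.register_operator("Erf")
    class TFErf():
        def __init__(self, *args, **kwargs) -> None:
            super().__init__()

        def __call__(self, inputs):
            # elementwise and layout-independent, so no transpose handling is needed
            return tf.math.erf(inputs)
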
/onnx2tflite/layers/common_layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import logging
3 | import numpy as np
4 | import tensorflow as tf
5 | from tensorflow import keras
6 |
7 | from onnx2tflite.utils.definitions import Layout
8 | from onnx2tflite.utils import OPERATOR, intfloat_to_list, dimension_utils
9 |
10 | LOG = logging.getLogger("common_layers :")
11 |
12 | @OPERATOR.register_operator("BatchNormalization")
13 | class TFBatchNormalization():
14 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
15 | super().__init__()
16 | epsilon = node_attribute.get("epsilon", 1e-5)
17 | momentum = node_attribute.get("momentum", 0.9)
18 |
19 | self.bn = keras.layers.BatchNormalization(
20 | gamma_initializer=keras.initializers.Constant(node_weights[node_inputs[1]]),
21 | beta_initializer=keras.initializers.Constant(node_weights[node_inputs[2]]),
22 | moving_mean_initializer=keras.initializers.Constant(node_weights[node_inputs[3]]),
23 | moving_variance_initializer=keras.initializers.Constant(node_weights[node_inputs[4]]),
24 | epsilon=epsilon,
25 | momentum=momentum)
26 |
27 | def __call__(self, inputs):
28 | return self.bn(inputs)
29 |
30 | @OPERATOR.register_operator("InstanceNormalization")
31 | class TFInstanceNormalization():
32 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
33 | super().__init__()
34 | self.epsilon = node_attribute.get("epsilon", 1e-5)
35 | self.scale = node_weights[node_inputs[1]]
36 | self.bias = node_weights[node_inputs[2]]
37 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
38 |
39 | def __call__(self, inputs):
40 | axes = tuple(range(1, len(inputs.shape)-1)) if self.channel_last else tuple(range(2, len(inputs.shape)))
41 | mean = tf.reduce_mean(inputs, axis=axes, keepdims=True)
42 | var = tf.math.reduce_variance(inputs, axis= axes, keepdims=True)
43 | return self.scale*(inputs - mean)/tf.sqrt(var + self.epsilon) + self.bias
44 |
45 | @OPERATOR.register_operator("Pad")
46 | class TFPad():
47 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
48 | super().__init__()
49 | if node_attribute.get("pads") is not None:
50 | pads = node_attribute['pads']
51 | elif node_inputs[1] in node_weights:
52 | pads = node_weights[node_inputs[1]]
53 | else:
54 | pads = tensor_grap[node_inputs[1]]
55 | self.pad = [[pads[0], pads[4]], [pads[2], pads[6]], [pads[3], pads[7]], [pads[1], pads[5]]]
56 |         self.mode = node_attribute.get("mode", "constant").upper()
57 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
58 | layout_dict[node_outputs[0]] = Layout.Channel_Last
59 |
60 | def __call__(self, inputs):
61 | if not self.channel_last:
62 | inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
63 |         return tf.pad(inputs, self.pad, mode=self.mode)
64 |
65 | @OPERATOR.register_operator("Clip")
66 | class TFClip():
67 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
68 | super().__init__()
69 | if "min" in node_attribute:
70 | self.min = node_attribute.get("min")
71 | else:
72 | self.min = tensor_grap[node_inputs[1]] if node_inputs[1] in tensor_grap else node_weights[node_inputs[1]]
73 | if "max" in node_attribute:
74 | self.max = node_attribute.get("max")
75 | else:
76 | self.max = tensor_grap[node_inputs[2]] if node_inputs[2] in tensor_grap else node_weights[node_inputs[2]]
77 |
78 | def __call__(self, inputs):
79 | if float(self.min) == 0 and float(self.max) == 6:
80 | return tf.nn.relu6(inputs)
81 | return tf.clip_by_value(inputs, self.min, self.max)
82 |
83 | @OPERATOR.register_operator("GlobalMaxPool")
84 | class TFGlobalMaxPool():
85 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
86 | super().__init__()
87 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
88 |
89 | def __call__(self, inputs):
90 | if self.channel_last:
91 | return tf.reduce_max(inputs, axis=[i for i in range(1, len(inputs.shape)-1)], keepdims=True)
92 | else:
93 | return tf.reduce_max(inputs, axis=[i for i in range(2, len(inputs.shape))], keepdims=True)
94 |
95 | @OPERATOR.register_operator("GlobalAveragePool")
96 | class TFGlobalAveragePool():
97 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
98 | super().__init__()
99 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
100 |
101 | def __call__(self, inputs):
102 | if self.channel_last:
103 | return tf.reduce_mean(inputs, axis=[i for i in range(1, len(inputs.shape)-1)], keepdims=True)
104 | else:
105 | return tf.reduce_mean(inputs, axis=[i for i in range(2, len(inputs.shape))], keepdims=True)
106 |
107 | @OPERATOR.register_operator("AveragePool")
108 | class TFAveragePool():
109 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
110 | super().__init__()
111 | kernel_shape = intfloat_to_list(node_attribute.get("kernel_shape", [2, 2]), 2)
112 | strides = intfloat_to_list(node_attribute.get("strides", [1, 1]), 2)
113 | dilations = intfloat_to_list(node_attribute.get("dilations", [1, 1]), 2)
114 | ceil_mode = node_attribute.get("ceil_mode", 0)
115 | pads = intfloat_to_list(node_attribute.get("pads", [0, 0, 0, 0]), 4)
116 |
117 | func = math.floor if ceil_mode == 0 else math.ceil
118 |
119 | pad_mode = "SAME"
120 | input_shape = tensor_grap[node_inputs[0]].shape
121 | for i in range(len(input_shape)-2):
122 | pad_shape = pads[i] + pads[i+2]
123 | onnx_output_shape = func((input_shape[1+i]+pad_shape-((kernel_shape[i]-1)*dilations[i]+1))/strides[i]+1)
124 | tf_output_shape = math.floor((input_shape[1+i] - kernel_shape[i]) / strides[i]) + 1
125 | pads[2+i] = max(onnx_output_shape-tf_output_shape, pads[2+i]) # right_down pad
126 | if pad_mode == "SAME" and onnx_output_shape != input_shape[1+i]:
127 | pad_mode = "VALID"
128 | self.avg_pool = keras.layers.AveragePooling2D(pool_size=kernel_shape, strides=strides, padding=pad_mode)
129 |
130 | self.pad = None
131 |         if pad_mode == "VALID" and pads is not None and np.sum(pads) > 0:
132 |             self.pad = keras.layers.ZeroPadding2D(
133 |                 padding=((pads[0], pads[2]), (pads[1], pads[3])))
134 |
135 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
136 | layout_dict[node_outputs[0]] = Layout.Channel_Last
137 |
138 | def __call__(self, inputs):
139 | if not self.channel_last:
140 | inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
141 | if self.pad:
142 | inputs = self.pad(inputs)
143 | return self.avg_pool(inputs)
144 |
145 | @OPERATOR.register_operator("MaxPool")
146 | class TFMaxPool():
147 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
148 | super().__init__()
149 | kernel_shape = intfloat_to_list(node_attribute.get("kernel_shape", [2, 2]), 2)
150 | strides = intfloat_to_list(node_attribute.get("strides", [1, 1]), 2)
151 | dilations = intfloat_to_list(node_attribute.get("dilations", [1, 1]), 2)
152 | ceil_mode = node_attribute.get("ceil_mode", 0)
153 | pads = intfloat_to_list(node_attribute.get("pads", [0, 0, 0, 0]), 4)
154 |
155 | func = math.floor if ceil_mode == 0 else math.ceil
156 |
157 | pad_mode = "SAME"
158 | input_shape = tensor_grap[node_inputs[0]].shape
159 | for i in range(len(input_shape)-2):
160 | pad_shape = pads[i] + pads[i+2]
161 | onnx_output_shape = func((input_shape[1+i]+pad_shape-((kernel_shape[i]-1)*dilations[i]+1))/strides[i]+1)
162 | tf_output_shape = math.floor((input_shape[1+i] - kernel_shape[i]) / strides[i]) + 1
163 | pads[2+i] = max(onnx_output_shape-tf_output_shape, pads[2+i]) # right_down pad
164 | if pad_mode == "SAME" and onnx_output_shape != input_shape[1+i]:
165 | pad_mode = "VALID"
166 | self.max_pool = keras.layers.MaxPool2D(pool_size=kernel_shape, strides=strides, padding=pad_mode)
167 |
168 | self.pad = None
169 |         if pad_mode == "VALID" and pads is not None and np.sum(pads) > 0:
170 |             self.pad = keras.layers.ZeroPadding2D(
171 |                 padding=((pads[0], pads[2]), (pads[1], pads[3])))
172 |
173 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
174 | layout_dict[node_outputs[0]] = Layout.Channel_Last
175 |
176 | def __call__(self, inputs):
177 | if not self.channel_last:
178 | inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
179 | if self.pad:
180 | inputs = self.pad(inputs)
181 | return self.max_pool(inputs)
182 |
183 | @OPERATOR.register_operator("Upsample")
184 | class TFUpsample():
185 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
186 | super().__init__()
187 | _, h, w, _ = tensor_grap[node_inputs[0]].shape
188 | scale = node_weights[node_inputs[1]]
189 |
190 | self.scale = (int(h*scale[2]), int(w*scale[3]))
191 | if node_attribute.get("mode", "nearest").lower() == 'nearest':
192 | self.method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
193 | else:
194 | self.method = tf.image.ResizeMethod.BILINEAR
195 |
196 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
197 | layout_dict[node_outputs[0]] = Layout.Channel_Last
198 |
199 | def __call__(self, inputs):
200 | if not self.channel_last:
201 | inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
202 | return tf.image.resize(inputs, self.scale, method=self.method)
203 |
204 | @OPERATOR.register_operator("Constant")
205 | class TFConstant():
206 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
207 | super().__init__()
208 | self.val = node_attribute['value']
209 |
210 | def __call__(self, *args, **kwargs):
211 | return self.val
212 |
213 | @OPERATOR.register_operator("ScatterND")
214 | class TFScatterND():
215 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
216 | super().__init__()
217 | self.indices = node_weights[node_inputs[1]]
218 |         self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
219 | if node_inputs[2] in tensor_grap:
220 | self.updates = tensor_grap[node_inputs[2]]
221 |             if self.channel_last:
222 | self.updates = dimension_utils.tensor_NDC_to_NCD_format(self.updates)
223 | else:
224 | self.updates = node_weights[node_inputs[2]]
225 |
226 | layout_dict[node_outputs[0]] = Layout.Channel_First
227 |
228 | def __call__(self, inputs):
229 |         if self.channel_last:
230 | inputs = dimension_utils.tensor_NDC_to_NCD_format(inputs)
231 | inputs = tf.tensor_scatter_nd_update(inputs, self.indices, self.updates)
232 | return inputs
233 |
234 | @OPERATOR.register_operator("Resize")
235 | class TFResize():
236 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
237 | super().__init__()
238 | if node_inputs[-1] in node_weights:
239 | _, _, nh, nw = node_weights[node_inputs[-1]]
240 | if len(node_inputs) != 4:
241 | _, h, w, _ = tensor_grap[node_inputs[0]].shape
242 | nh, nw = int(h*nh), int(w*nw)
243 | self.scale = (nh, nw)
244 | else:
245 | scales = tensor_grap[node_inputs[0]].shape[1:3]*tensor_grap[node_inputs[2]][2:3]
246 | self.scale = scales
247 |
248 | if node_attribute.get("mode", "nearest").lower() == 'nearest':
249 | self.method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
250 | else:
251 | self.method = tf.image.ResizeMethod.BILINEAR
252 |
253 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
254 | layout_dict[node_outputs[0]] = Layout.Channel_Last
255 |
256 | def __call__(self, inputs):
257 | if not self.channel_last:
258 | inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
259 | return tf.image.resize(inputs, self.scale, method=self.method)
260 |
261 | @OPERATOR.register_operator("Gemm")
262 | class TFGemm():
263 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
264 | super().__init__()
265 | if len(node_inputs) > 2:
266 | weights = [node_weights[node_inputs[1]].T, node_weights[node_inputs[2]]]
267 | else:
268 | weights = [node_weights[node_inputs[1]].T]
269 |
270 | self.dense = keras.layers.Dense(weights[0].shape[1],
271 | weights=weights,
272 | use_bias=len(weights)==2)
273 |
274 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
275 | layout_dict[node_outputs[0]] = Layout.Channel_Last
276 |
277 | def __call__(self, inputs):
278 | if not self.channel_last:
279 | inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
280 | return self.dense(inputs)
281 |
282 | @OPERATOR.register_operator("Identity")
283 | class TFIdentity():
284 | def __init__(self, *args, **kwargs):
285 | super().__init__()
286 |
287 | def __call__(self, inputs):
288 | return inputs
289 |
290 | @OPERATOR.register_operator("Dropout")
291 | class TFDropout():
292 | '''
293 | Dropout will be ignored in deployment.
294 | '''
295 | def __init__(self, *args, **kwargs):
296 | super().__init__()
297 |
298 | def __call__(self, inputs):
299 | return inputs
300 |
301 | @OPERATOR.register_operator("TopK")
302 | class TFTopK():
303 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
304 |
305 | self.axis = node_attribute.get("axis", -1)
306 | self.largest = node_attribute.get("largest", 1)
307 | self.sorted = bool(node_attribute.get("sorted", 1))
308 | self.K = node_attribute.get('K') if len(node_inputs)==1 else node_weights[node_inputs[1]][0]
309 |
310 | def __call__(self, inputs):
311 | res = tf.math.top_k(inputs, k=self.K, sorted=self.sorted)
312 | return [res[0], res[1]]
313 |
314 | @OPERATOR.register_operator("Cast")
315 | class TFCast():
316 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
317 | super().__init__()
318 | self.cast_to = int(node_attribute.get("to", 1))
319 | assert self.cast_to > 0 and self.cast_to < 12, f"Unknown cast type [{self.cast_to}]"
320 | self.np_cast_map = {
321 | 1: np.float32,
322 | 2: np.uint8,
323 | 3: np.int8,
324 | 5: np.int16,
325 | 6: np.int32,
326 | 7: np.int64,
327 | 9: np.bool_,
328 | 10: np.float16,
329 | 11: np.double,
330 | }
331 | self.tf_cast_map = {
332 | 1: tf.float32,
333 | 2: tf.uint8,
334 | 3: tf.int8,
335 | 5: tf.int16,
336 | 6: tf.int32,
337 | 7: tf.int64,
338 | 9: tf.bool,
339 | 10: tf.float16,
340 | 11: tf.double,
341 | }
342 |
343 | def __call__(self, inputs):
344 | if isinstance(inputs, list):
345 | for i in range(len(inputs)):
346 | if isinstance(inputs[i], np.ndarray) or isinstance(inputs[i], np.generic):
347 | inputs[i] = self.np_cast_map[self.cast_to](inputs[i])
348 | else:
349 |                 inputs[i] = tf.cast(inputs[i], dtype=self.tf_cast_map[self.cast_to])
350 | else:
351 | if isinstance(inputs, np.ndarray) or isinstance(inputs, np.generic):
352 | inputs = self.np_cast_map[self.cast_to](inputs)
353 | else:
354 | inputs = tf.cast(inputs, dtype=self.tf_cast_map[self.cast_to])
355 |
356 | return inputs
357 |
--------------------------------------------------------------------------------
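The integer keys in TFCast's dtype maps above follow the ONNX `TensorProto.DataType` enum. A quick cross-check, a minimal sketch assuming only the `onnx` package from requirements.txt:

```python
import onnx

# The Cast "to" attribute uses onnx.TensorProto.DataType codes; the keys of
# TFCast's np_cast_map / tf_cast_map correspond to these constants.
assert onnx.TensorProto.FLOAT == 1     # -> np.float32 / tf.float32
assert onnx.TensorProto.UINT8 == 2     # -> np.uint8   / tf.uint8
assert onnx.TensorProto.INT8 == 3      # -> np.int8    / tf.int8
assert onnx.TensorProto.INT32 == 6     # -> np.int32   / tf.int32
assert onnx.TensorProto.INT64 == 7     # -> np.int64   / tf.int64
assert onnx.TensorProto.BOOL == 9      # -> np.bool_   / tf.bool
assert onnx.TensorProto.FLOAT16 == 10  # -> np.float16 / tf.float16
assert onnx.TensorProto.DOUBLE == 11   # -> np.double  / tf.double
```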
/onnx2tflite/layers/conv_layers.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: MPolaris && yutaka329 && lkdci
3 |
4 | Thanks for yutaka329 with your pad tricks.
5 | https://github.com/MPolaris/onnx2tflite/issues/5
6 |
7 | Thanks for lkdci with your native method of group conv
8 | https://github.com/MPolaris/onnx2tflite/issues/19
9 | '''
10 | import logging
11 | import tensorflow as tf
12 | from tensorflow import keras
13 | from onnx2tflite.utils.op_registry import OPERATOR
14 | from onnx2tflite.utils.definitions import Layout
15 | from onnx2tflite.utils.dimension_utils import tensor_NCD_to_NDC_format as NCD2NDC
16 |
17 | LOG = logging.getLogger("convolution_layers :")
18 |
19 | # Whether to implement grouped convolution using the native `keras.layers.Conv2D` class with groups !=1 argument.
20 | # This implementation is supported only with tflite version >= 2.9.
21 | # If set to `False`, the grouped convolution is built using regular conv per group then concatenated as a workaround
22 | # to support older version of tflite.
23 | # Using the native keras implementation results in a simplified tflite graph and supposed to run faster.
24 | # See https://github.com/MPolaris/onnx2tflite/issues/19 for more details.
25 | USE_NATIVE_GROUP_CONV = False
26 |
27 | @OPERATOR.register_operator("ConvTranspose")
28 | class TFConvTranspose():
29 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
30 | super().__init__()
31 | # out_channel, in_channel = node_weights[node_inputs[1]].shape[:2]
32 | dilations, group = node_attribute.get('dilations', 1), node_attribute.get('group', 1)
33 | pads = node_attribute['pads'] if "pads" in node_attribute else None
34 | kernel_shape, strides = node_attribute.get('kernel_shape', 1), node_attribute.get('strides', 1)
35 |
36 | weights = node_weights[node_inputs[1]].transpose(2,3,1,0)
37 | bias = node_weights[node_inputs[2]] if len(node_inputs) == 3 else None
38 | height, width, n_filters, channels = weights.shape
39 | self.pad = None
40 | self.conv = keras.layers.Conv2DTranspose(filters=n_filters, kernel_size=(height, width), strides=strides, padding='VALID', use_bias=False if bias is None else True,
41 | weights=[weights] if bias is None else [weights, bias],
42 | dilation_rate=dilations)
43 | if pads is not None and max(pads) != 0:
44 | padding = None
45 | if len(pads) == 2 and (pads[0] > 0 or pads[1] > 0):
46 | padding = (pads[0], pads[1])
47 | elif len(pads) == 4 and (pads[0] > 0 or pads[1] > 0 or pads[2] > 0 or pads[3] > 0):
48 | padding = ((pads[0], pads[2]), (pads[1], pads[3]))
49 | self.pad = keras.layers.Cropping2D(padding)
50 |
51 | for nop in node_outputs:
52 | layout_dict[nop] = Layout.Channel_Last
53 |
54 | self.need_trans = layout_dict[node_inputs[0]] != Layout.Channel_Last
55 |
56 | def __call__(self, inputs):
57 | if self.need_trans:
58 | inputs = NCD2NDC(inputs)
59 | inputs = self.conv(inputs)
60 | if self.pad:
61 | inputs = self.pad(inputs)
62 | return inputs
63 |
64 | @OPERATOR.register_operator("Conv")
65 | class Convlution():
66 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
67 | super().__init__()
68 |         # out/in channels are derived below from the weight tensor, which may live in node_weights or tensor_grap
69 | dilations, group = node_attribute.get('dilations', 1), node_attribute.get('group', 1)
70 | pads = node_attribute['pads'] if "pads" in node_attribute else None
71 | kernel_shape, strides = node_attribute.get('kernel_shape', 1), node_attribute.get('strides', 1)
72 |
73 | weights = node_weights[node_inputs[1]] if node_inputs[1] in node_weights else tensor_grap[node_inputs[1]]
74 | out_channel, in_channel = weights.shape[:2]
75 |
76 | channel_sequence = [2+i for i in range(len(weights.shape)-2)] + [1, 0]
77 | weights = weights.transpose(*channel_sequence)
78 |
79 | bias = None
80 | if len(node_inputs) == 3:
81 | bias = node_weights[node_inputs[2]] if node_inputs[2] in node_weights else tensor_grap[node_inputs[2]]
82 |
83 | if group == 1:
84 | self.conv = TFConv(in_channel, out_channel, kernel_shape, strides, dilations, pads, weights, bias)
85 | elif group == out_channel:
86 | self.conv = TFDepthwiseConv(kernel_shape, strides, dilations, pads, weights, bias)
87 | else:
88 | if USE_NATIVE_GROUP_CONV:
89 | self.conv = TFConv(in_channel, out_channel, kernel_shape, strides, dilations, pads, weights, bias, group=group)
90 |                 LOG.warning("Group convolution detected, using the native method; this requires tflite >= 2.9. "
91 |                     "If a compatibility error occurs, set USE_NATIVE_GROUP_CONV=False!")
92 | else:
93 | self.conv = TFGroupConv(in_channel, out_channel, kernel_shape, strides, dilations, pads, weights, bias, group=group)
94 |
95 | for nop in node_outputs:
96 | layout_dict[nop] = Layout.Channel_Last
97 |
98 | self.need_trans = layout_dict[node_inputs[0]] != Layout.Channel_Last
99 |
100 | def __call__(self, inputs):
101 | if self.need_trans:
102 | inputs = NCD2NDC(inputs)
103 | return self.conv(inputs)
104 |
105 | class TFConv():
106 | # Standard convolution
107 | def __init__(self, in_channel_num, out_channel_num, kernel_size=1,
108 | strides=1, dilations=1, pads=None, weights=None, bias=None, group=1):
109 | super().__init__()
110 |
111 | if len(weights.shape) == 3:
112 | self.conv1d_init(in_channel_num, out_channel_num, kernel_size, strides, dilations, pads, weights, bias, group)
113 | elif len(weights.shape) == 4:
114 | self.conv2d_init(in_channel_num, out_channel_num, kernel_size, strides, dilations, pads, weights, bias, group)
115 | elif len(weights.shape) == 5:
116 | self.conv3d_init(in_channel_num, out_channel_num, kernel_size, strides, dilations, pads, weights, bias, group)
117 | else:
118 | raise NotImplementedError(f"Conv{len(weights.shape)-2}d is not implemented")
119 |
120 | def conv1d_init(self, in_channel_num, out_channel_num, kernel_size=1,
121 | strides=1, dilations=1, pads=None, weights=None, bias=None, group=1):
122 |         self.pad = None
123 | if pads is not None and max(pads) == 1 and max(strides) == 1:
124 | self.conv = keras.layers.Conv1D(
125 | out_channel_num, kernel_size, strides, "SAME", use_bias=False if bias is None else True,
126 | weights=[weights] if bias is None else [weights, bias],
127 | dilation_rate=dilations, groups=group)
128 | else:
129 | self.conv = keras.layers.Conv1D(
130 | out_channel_num, kernel_size, strides, "VALID", use_bias=False if bias is None else True,
131 | weights=[weights] if bias is None else [weights, bias],
132 | dilation_rate=dilations, groups=group)
133 | if pads is not None and max(pads) != 0:
134 | self.pad = keras.layers.ZeroPadding1D(padding=pads)
135 |
136 | def conv2d_init(self, in_channel_num, out_channel_num, kernel_size=1,
137 | strides=1, dilations=1, pads=None, weights=None, bias=None, group=1):
138 | if isinstance(dilations, int):
139 | dilations = (dilations, dilations)
140 | if isinstance(strides, int):
141 | strides = (strides, strides)
142 | if dilations[0] != 1 and strides[0] != 1:
143 | raise Exception("Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1.")
144 |
145 |         self.pad = None
146 | if pads is not None and max(pads) == 1 and max(strides) == 1:
147 | self.conv = keras.layers.Conv2D(
148 | out_channel_num, kernel_size, strides, "SAME", use_bias=False if bias is None else True,
149 | weights=[weights] if bias is None else [weights, bias],
150 | dilation_rate=dilations, groups=group)
151 | else:
152 | self.conv = keras.layers.Conv2D(
153 | out_channel_num, kernel_size, strides, "VALID", use_bias=False if bias is None else True,
154 | weights=[weights] if bias is None else [weights, bias],
155 | dilation_rate=dilations, groups=group)
156 | if pads is not None and max(pads) != 0:
157 | padding = None
158 | if len(pads) == 2 and (pads[0] > 0 or pads[1] > 0):
159 | padding = (pads[0], pads[1])
160 | elif len(pads) == 4 and (pads[0] > 0 or pads[1] > 0 or pads[2] > 0 or pads[3] > 0):
161 | padding = ((pads[0], pads[2]), (pads[1], pads[3]))
162 | self.pad = keras.layers.ZeroPadding2D(padding=padding)
163 |
164 | def conv3d_init(self, in_channel_num, out_channel_num, kernel_size=1,
165 | strides=1, dilations=1, pads=None, weights=None, bias=None, group=1):
166 | raise NotImplementedError("Conv3d is not implemented")
167 |
168 | def __call__(self, inputs):
169 | if self.pad:
170 | inputs = self.pad(inputs)
171 | return self.conv(inputs)
172 |
173 | class TFGroupConv():
174 | '''
175 |     Group convolution implemented by splitting the input channels, convolving each group, and concatenating; a non-native workaround.
176 | '''
177 | def __init__(self, in_channel_num, out_channel_num, kernel_size=1,
178 | strides=1, dilations=1, pads=None, weights=None, bias=None, group=1):
179 | super().__init__()
180 |
181 | if len(weights.shape) == 3:
182 | self.groupconv1d_init(in_channel_num, out_channel_num, kernel_size, strides, dilations, pads, weights, bias, group)
183 | elif len(weights.shape) == 4:
184 | self.groupconv2d_init(in_channel_num, out_channel_num, kernel_size, strides, dilations, pads, weights, bias, group)
185 | else:
186 | raise NotImplementedError(f"GroupConv{len(weights.shape)-2}d is not implemented")
187 |
188 | def groupconv1d_init(self, in_channel_num, out_channel_num, kernel_size=1,
189 | strides=1, dilations=1, pads=None, weights=None, bias=None, group=1):
190 | self.cin = in_channel_num
191 | self.groups = group
192 | out_channel_num = int(out_channel_num//group)
193 | self.convs = []
194 | for i in range(group):
195 | if pads is not None and max(pads) == 1 and max(strides) == 1:
196 | self.convs.append(keras.layers.Conv1D(
197 | out_channel_num, kernel_size, strides, 'SAME', use_bias=False if bias is None else True,
198 | dilation_rate=dilations,
199 | weights=[weights[:, :, i*out_channel_num:(i+1)*out_channel_num]] if bias is None else [weights[:, :, i*out_channel_num:(i+1)*out_channel_num], bias[i*out_channel_num:(i+1)*out_channel_num]]))
200 | else:
201 | self.convs.append(keras.layers.Conv1D(
202 | out_channel_num, kernel_size, strides, 'VALID', use_bias=False if bias is None else True,
203 | dilation_rate=dilations,
204 | weights=[weights[:, :, i*out_channel_num:(i+1)*out_channel_num]] if bias is None else [weights[:, :, i*out_channel_num:(i+1)*out_channel_num], bias[i*out_channel_num:(i+1)*out_channel_num]]))
205 |         self.pad = None
206 | if pads is not None and (max(pads) != 0 and not (max(pads) == 1 and max(strides) == 1)):
207 | self.pad = keras.layers.ZeroPadding1D(padding=pads)
208 |
209 | def groupconv2d_init(self, in_channel_num, out_channel_num, kernel_size=1,
210 | strides=1, dilations=1, pads=None, weights=None, bias=None, group=1):
211 | if isinstance(dilations, int):
212 | dilations = (dilations, dilations)
213 | if isinstance(strides, int):
214 | strides = (strides, strides)
215 | if dilations[0] != 1 and strides[0] != 1:
216 | raise Exception("Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1.")
217 | self.cin = in_channel_num
218 | self.groups = group
219 | out_channel_num = int(out_channel_num//group)
220 |
221 | self.convs = []
222 | for i in range(group):
223 | if pads is not None and max(pads) == 1 and max(strides) == 1:
224 | self.convs.append(keras.layers.Conv2D(
225 | out_channel_num, kernel_size, strides, 'SAME', use_bias=False if bias is None else True,
226 | dilation_rate=dilations,
227 | weights=[weights[:, :, :, i*out_channel_num:(i+1)*out_channel_num]] if bias is None else [weights[:, :, :, i*out_channel_num:(i+1)*out_channel_num], bias[i*out_channel_num:(i+1)*out_channel_num]]))
228 | else:
229 | self.convs.append(keras.layers.Conv2D(
230 | out_channel_num, kernel_size, strides, 'VALID', use_bias=False if bias is None else True,
231 | dilation_rate=dilations,
232 | weights=[weights[:, :, :, i*out_channel_num:(i+1)*out_channel_num]] if bias is None else [weights[:, :, :, i*out_channel_num:(i+1)*out_channel_num], bias[i*out_channel_num:(i+1)*out_channel_num]]))
233 |         self.pad = None
234 | if pads is not None and (max(pads) != 0 and not (max(pads) == 1 and max(strides) == 1)):
235 | padding = None
236 | if len(pads) == 2 and (pads[0] > 0 or pads[1] > 0):
237 | padding = (pads[0], pads[1])
238 | elif len(pads) == 4 and (pads[0] > 0 or pads[1] > 0 or pads[2] > 0 or pads[3] > 0):
239 | padding = ((pads[0], pads[2]), (pads[1], pads[3]))
240 | self.pad = keras.layers.ZeroPadding2D(padding=padding)
241 |
242 | def __call__(self, inputs):
243 | if self.pad is not None:
244 | inputs = self.pad(inputs)
245 | outs = []
246 | in_s = tf.split(inputs, num_or_size_splits=self.groups, axis=-1)
247 | for i in range(self.groups):
248 | outs.append(self.convs[i](in_s[i]))
249 | outs = tf.concat(outs, axis=-1)
250 | return outs
251 |
252 | class TFDepthwiseConv():
253 |     # Depthwise convolution (used when ONNX group == number of channels)
254 | def __init__(self, kernel_size=1, strides=1, dilations=1, pads=None, weights=None, bias=None) -> None:
255 | super().__init__()
256 | if len(weights.shape) == 3:
257 | weights = weights.transpose(0, 2, 1)
258 | self.dwconv1d_init(kernel_size, strides, dilations, pads, weights, bias)
259 | elif len(weights.shape) == 4:
260 | weights = weights.transpose(0, 1, 3, 2)
261 | self.dwconv2d_init(kernel_size, strides, dilations, pads, weights, bias)
262 | else:
263 | raise NotImplementedError(f"DepthwiseConv{len(weights.shape)-2}d is not implemented")
264 |
265 | def dwconv1d_init(self, kernel_size=1, strides=1, dilations=1, pads=None, weights=None, bias=None):
266 |         self.pad = None
267 | if pads is not None and max(pads) == 1 and max(strides) == 1:
268 | self.conv = keras.layers.DepthwiseConv1D(
269 | kernel_size, strides, "SAME", use_bias=False if bias is None else True,
270 | weights=[weights] if bias is None else [weights, bias],
271 | dilation_rate=dilations,
272 | activation=None,
273 | kernel_initializer='zeros',
274 | bias_initializer='zeros'
275 | )
276 | else:
277 | self.conv = keras.layers.DepthwiseConv1D(
278 | kernel_size, strides, "VALID", use_bias=False if bias is None else True,
279 | weights=[weights] if bias is None else [weights, bias],
280 | dilation_rate=dilations,
281 | activation=None,
282 | kernel_initializer='zeros',
283 | bias_initializer='zeros'
284 | )
285 | if pads is not None and max(pads) != 0:
286 | self.pad = keras.layers.ZeroPadding1D(padding=pads)
287 |
288 | def dwconv2d_init(self, kernel_size=1, strides=1, dilations=1, pads=None, weights=None, bias=None):
289 | if isinstance(dilations, int):
290 | dilations = (dilations, dilations)
291 | if isinstance(strides, int):
292 | strides = (strides, strides)
293 |
294 |         self.pad = None
295 | if pads is not None and max(pads) == 1 and max(strides) == 1:
296 | self.conv = keras.layers.DepthwiseConv2D(
297 | kernel_size, strides, "SAME", use_bias=False if bias is None else True,
298 | weights=[weights] if bias is None else [weights, bias],
299 | dilation_rate=dilations,
300 | activation=None,
301 | kernel_initializer='zeros',
302 | bias_initializer='zeros'
303 | )
304 | else:
305 | self.conv = keras.layers.DepthwiseConv2D(
306 | kernel_size, strides, "VALID", use_bias=False if bias is None else True,
307 | weights=[weights] if bias is None else [weights, bias],
308 | dilation_rate=dilations,
309 | activation=None,
310 | kernel_initializer='zeros',
311 | bias_initializer='zeros'
312 | )
313 | if pads is not None and max(pads) != 0:
314 | padding = None
315 | if len(pads) == 2 and (pads[0] > 0 or pads[1] > 0):
316 | padding = (pads[0], pads[1])
317 | elif len(pads) == 4 and (pads[0] > 0 or pads[1] > 0 or pads[2] > 0 or pads[3] > 0):
318 | padding = ((pads[0], pads[2]), (pads[1], pads[3]))
319 | self.pad = keras.layers.ZeroPadding2D(padding=padding)
320 |
321 | def __call__(self, inputs):
322 | if self.pad:
323 | inputs = self.pad(inputs)
324 | return self.conv(inputs)
--------------------------------------------------------------------------------
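As a cross-check of the two grouped-convolution paths in conv_layers.py above, here is a minimal sketch (shapes and values invented for illustration) showing that the split-and-concat workaround used by `TFGroupConv` agrees with the native `groups` argument of `keras.layers.Conv2D`:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

G, C_in, C_out = 2, 4, 6
x = tf.random.normal((1, 8, 8, C_in))

# Native grouped conv (deployable only with tflite >= 2.9).
native = keras.layers.Conv2D(C_out, 3, padding="same", groups=G, use_bias=False)
y_native = native(x)

# Workaround: one plain conv per channel group, then concat.
kernel = native.kernel.numpy()  # shape (3, 3, C_in // G, C_out)
outs = []
for i, xi in enumerate(tf.split(x, G, axis=-1)):
    wi = kernel[..., i * (C_out // G):(i + 1) * (C_out // G)]
    conv = keras.layers.Conv2D(C_out // G, 3, padding="same",
                               use_bias=False, weights=[wi])
    outs.append(conv(xi))
y_split = tf.concat(outs, axis=-1)

print(np.abs(y_native.numpy() - y_split.numpy()).max())  # ~1e-6 (float tolerance)
```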
/onnx2tflite/layers/deformation_layers.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import tensorflow as tf
3 |
4 | from onnx2tflite.utils.definitions import Layout
5 | from onnx2tflite.utils import OPERATOR, dimension_utils
6 |
7 | LOG = logging.getLogger("deformation_layers :")
8 |
9 | @OPERATOR.register_operator("Transpose")
10 | class TFTranspose():
11 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
12 | super().__init__()
13 | for nop in node_outputs:
14 | layout_dict[nop] = Layout.Channel_First
15 | if kwargs.get("perm_list"):
16 | self.perm_list = kwargs.get("perm_list")
17 | return
18 | self.trans_in = None
19 | self.perm_list = [i for i in node_attribute['perm']]
20 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
21 | # LOG.info("Transpose will process tensor after change back to NCHW format.")
22 | shape_len = len(tensor_grap[node_inputs[0]].shape)
23 | self.trans_in = [0, shape_len-1] + [n for n in range(1, shape_len-1)]
24 |
25 | def __call__(self, inputs):
26 | if self.trans_in:
27 | inputs = tf.transpose(inputs, perm=self.trans_in)
28 | return tf.transpose(inputs, perm=self.perm_list)
29 |
30 | @OPERATOR.register_operator("Slice")
31 | class TFSlice():
32 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
33 | super().__init__()
34 | if len(node_inputs) == 1:
35 | self.starts = node_attribute['starts'][0]
36 | self.ends = node_attribute['ends'][0]
37 | self.axis = node_attribute['axes'][0]
38 | self.steps = 1
39 | else:
40 | self.starts = node_weights[node_inputs[1]][0] if node_inputs[1] in node_weights else tensor_grap[node_inputs[1]][0]
41 | self.axis = node_weights[node_inputs[3]][0] if node_inputs[3] in node_weights else tensor_grap[node_inputs[3]][0]
42 | self.ends = node_weights[node_inputs[2]][0] if node_inputs[2] in node_weights else tensor_grap[node_inputs[2]][0]
43 | self.ends = min(self.ends, tensor_grap[node_inputs[0]].shape[self.axis])
44 | if len(node_inputs) < 5:
45 | self.steps = 1
46 | else:
47 | self.steps = node_weights[node_inputs[4]][0] if node_inputs[4] in node_weights else tensor_grap[node_inputs[4]][0]
48 |
49 | shape = tensor_grap[node_inputs[0]].shape.as_list()
50 | if self.starts < 0:
51 | self.starts = shape[self.axis] + self.starts
52 | if self.ends < 0:
53 | self.ends = shape[self.axis] + self.ends
54 |
55 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
56 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
57 |
58 | def __call__(self, inputs):
59 | indices = tf.keras.backend.arange(self.starts, self.ends, step=self.steps)
60 | return tf.gather(inputs, indices, axis=self.axis)
61 |
62 | @OPERATOR.register_operator("Gather")
63 | class TFGather():
64 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
65 | super().__init__()
66 | self.axis = node_attribute.get('axis', 0)
67 | self.indices = tensor_grap[node_inputs[1]] if node_inputs[1] in tensor_grap else node_weights[node_inputs[1]]
68 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
69 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
70 |
71 | def __call__(self, inputs):
72 | return tf.gather(inputs, self.indices, axis=self.axis)
73 |
74 | @OPERATOR.register_operator("Concat")
75 | class TFConcat():
76 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
77 | super().__init__()
78 |         # TODO: could be optimized by inspecting the consumer node, e.g. prefer channel-last when it feeds a conv.
79 |         self._axis = node_attribute['axis']
80 |         # `count` tallies channel-last inputs (+1) against channel-first/constant inputs (-1).
81 | count = 0
82 | for inp in node_inputs:
83 | if inp in node_weights:
84 | count -= 1
85 | elif layout_dict[inp] == Layout.Channel_Last:
86 | count += 1
87 | else:
88 | count -= 1
89 |
90 | self._gather = []
91 | if count < 0:
92 | # align to Channel_First
93 | layout_dict[node_outputs[0]] = Layout.Channel_First
94 | for inp in node_inputs:
95 | if inp in tensor_grap:
96 | if layout_dict[inp] == Layout.Channel_Last:
97 | tensor_grap[inp] = dimension_utils.tensor_NDC_to_NCD_format(tensor_grap[inp])
98 | self._gather.append(tensor_grap[inp])
99 | else:
100 | self._gather.append(node_weights[inp])
101 | else:
102 | # align to Channel_Last
103 | layout_dict[node_outputs[0]] = Layout.Channel_Last
104 | self._axis = dimension_utils.channel_to_last_dimension(self._axis)
105 | for inp in node_inputs:
106 | if inp in tensor_grap:
107 | if layout_dict[inp] != Layout.Channel_Last:
108 | tensor_grap[inp] = dimension_utils.tensor_NCD_to_NDC_format(tensor_grap[inp])
109 | self._gather.append(tensor_grap[inp])
110 | else:
111 | self._gather.append(dimension_utils.tensor_NCD_to_NDC_format(node_weights[inp]))
112 |
113 | def __call__(self, *args, **kwargs):
114 | return tf.concat(self._gather, axis=self._axis)
115 |
116 | @OPERATOR.register_operator("Reshape")
117 | class TFReshape():
118 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
119 | super().__init__()
120 | self.out_shape = node_weights[node_inputs[1]]
121 | self.trans_in = None
122 | # LOG.info("Reshape will process tensor after change back to NCHW format.")
123 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
124 | shape_len = len(tensor_grap[node_inputs[0]].shape)
125 | self.trans_in = [0, shape_len-1] + [n for n in range(1, shape_len-1)]
126 | for nop in node_outputs:
127 | layout_dict[nop] = Layout.Channel_First
128 |
129 | def __call__(self, inputs):
130 | if self.trans_in:
131 | inputs = tf.transpose(inputs, perm=self.trans_in)
132 | inputs = tf.reshape(inputs, shape=self.out_shape)
133 | return inputs
134 |
135 | @OPERATOR.register_operator("Flatten")
136 | class TFFlatten():
137 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
138 | super().__init__()
139 | num_elements = int(tensor_grap[node_inputs[0]].shape.num_elements()/tensor_grap[node_inputs[0]].shape[0])
140 | input_shape = tensor_grap[node_inputs[0]].shape
141 | self.flat = tf.keras.layers.Flatten()
142 |         '''
143 |         Ensure the memory order matches ONNX. For example:
144 |             onnx   = (B, 2, 3, 4).reshape(B, -1)
145 |             tflite = (B, 3, 4, 2).reshape(B, -1)
146 |         Here onnx.shape == tflite.shape, but np.sum(onnx - tflite) != 0,
147 |         because the two tensors have different memory orders; the tflite
148 |         tensor must be transposed back to ONNX order before flattening.
149 |         That is the general case. A common special case in CNNs:
150 |             onnx   = (B, 512, 1, 1)
151 |             tflite = (B, 1, 1, 512)
152 |             or     = (B, 1, 512, 1)
153 |         These all share the same memory order, so no transpose is needed.
154 |         '''
155 | self.perm = None
156 | if layout_dict[node_inputs[0]] == Layout.Channel_Last and num_elements != max(input_shape[1:]):
157 | self.perm = [0, len(input_shape)-1]
158 | for i in range(len(input_shape)-2):
159 | self.perm.append(i+1)
160 |
161 | def __call__(self, inputs):
162 | if self.perm:
163 | inputs = tf.transpose(inputs, perm=self.perm)
164 | return self.flat(inputs)
165 |
166 | @OPERATOR.register_operator("Split")
167 | class TFSplit():
168 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
169 | super().__init__()
170 | self.outputs_nums = len(node_outputs)
171 | self.axis = node_attribute.get("axis", 0)
172 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
173 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
174 | split_args = None
175 | if 'split' in node_attribute:
176 | split_args = node_attribute['split']
177 | else:
178 | assert len(node_inputs) == 2 and node_inputs[1] in node_weights
179 | split_args = node_weights[node_inputs[1]]
180 |
181 | self.indices = []
182 | start, end = 0, 0
183 | for i in range(self.outputs_nums):
184 | end = start + int(split_args[i])
185 | self.indices.append(tf.keras.backend.arange(start, end, 1))
186 | start = end
187 |
188 | def __call__(self, inputs):
189 | return [tf.gather(inputs, indices=indice, axis=self.axis) for indice in self.indices]
190 |
191 | @OPERATOR.register_operator("Expand")
192 | class TFExpand():
193 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
194 | super().__init__()
195 | self.shape = node_weights[node_inputs[1]]
196 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
197 | self.shape = dimension_utils.shape_NCD_to_NDC_format(self.shape)
198 | def __call__(self, inputs):
199 | for i in range(len(self.shape)):
200 | if int(self.shape[i]//inputs.shape[i]) > 1:
201 | inputs = tf.repeat(inputs, repeats=int(self.shape[i]//inputs.shape[i]), axis=i)
202 | elif self.shape[i] < inputs.shape[i] and self.shape[i] != 1:
203 | inputs = tf.repeat(inputs, repeats=int(self.shape[i]), axis=i)
204 | return inputs
205 |
206 | @OPERATOR.register_operator("GatherElements")
207 | class TFGatherElements():
208 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
209 | super().__init__()
210 |         self.axis = node_attribute.get("axis", 0) # ONNX GatherElements defaults to axis=0
211 | self.indices = None
212 | if 'indices' in node_attribute:
213 | self.indices = node_attribute['indices']
214 | self.indices = dimension_utils.tensor_NCD_to_NDC_format(self.indices)
215 | elif node_inputs[1] in node_weights:
216 | self.indices = node_weights[node_inputs[1]]
217 | self.indices = dimension_utils.tensor_NCD_to_NDC_format(self.indices)
218 | else:
219 | self.indices = tensor_grap[node_inputs[1]]
220 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
221 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
222 | if len(node_inputs) == 1 or layout_dict[node_inputs[1]] != Layout.Channel_Last:
223 | self.indices = dimension_utils.tensor_NCD_to_NDC_format(self.indices)
224 |
225 | def gather_elements(self, input_tensor, indices, axis):
226 | # Get the shape of the input tensor and the indices tensor
227 | input_shape = tf.shape(input_tensor)
228 | indices_shape = tf.shape(indices)
229 |
230 | # Create indices for all dimensions
231 | idx = tf.meshgrid(*[tf.range(s) for s in indices_shape], indexing='ij')
232 | idx = [tf.cast(i, tf.int64) for i in idx]
233 |
234 | # Replace the axis index with the provided indices
235 | idx[axis] = tf.cast(indices, tf.int64)
236 |
237 | # Stack indices to form the final gather indices
238 | gather_indices = tf.stack(idx, axis=-1)
239 |
240 | # Use tf.gather_nd to gather elements
241 | output_tensor = tf.gather_nd(input_tensor, gather_indices)
242 |
243 | return output_tensor
244 |
245 | def __call__(self, inputs):
246 | return self.gather_elements(inputs, self.indices, self.axis)
247 |
248 | @OPERATOR.register_operator("Tile")
249 | class TFTile():
250 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
251 | super().__init__()
252 | self.repeats = node_attribute['repeats'] if 'repeats' in node_attribute else node_weights[node_inputs[1]]
253 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
254 | self.repeats = dimension_utils.shape_NCD_to_NDC_format(self.repeats)
255 |
256 | def __call__(self, inputs):
257 | for i in range(len(self.repeats)):
258 | if self.repeats[i] > 1:
259 | inputs = tf.repeat(inputs, self.repeats[i], axis=i)
260 | return inputs
261 |
262 | @OPERATOR.register_operator("Unsqueeze")
263 | class TFUnsqueeze():
264 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
265 | super().__init__()
266 | self.axis = node_attribute['axes'] if 'axes' in node_attribute else node_weights[node_inputs[1]]
267 | if not isinstance(self.axis, int):
268 | self.axis = int(self.axis[0])
269 | input_shape = tensor_grap[node_inputs[0]].shape
270 | if len(input_shape) == 1:
271 | layout_dict[node_outputs[0]] = Layout.Channel_None
272 | elif len(input_shape) == 2:
273 | layout_dict[node_outputs[0]] = Layout.Channel_First
274 | else:
275 | layout_dict[node_outputs[0]] = layout_dict[node_inputs[0]]
276 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
277 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
278 |
279 | def __call__(self, inputs):
280 | return tf.expand_dims(inputs, self.axis)
281 |
282 | @OPERATOR.register_operator("Squeeze")
283 | class TFSqueeze():
284 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
285 | super().__init__()
286 | self.axis = node_attribute['axes'] if 'axes' in node_attribute else node_weights[node_inputs[1]]
287 | if not isinstance(self.axis, int):
288 | self.axis = int(self.axis[0])
289 | input_shape = tensor_grap[node_inputs[0]].shape
290 | if len(input_shape) <= 3:
291 | layout_dict[node_outputs[0]] = Layout.Channel_None
292 | if len(input_shape) > 2 and layout_dict[node_inputs[0]] == Layout.Channel_Last:
293 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
294 |
295 | def __call__(self, inputs):
296 | return tf.squeeze(inputs, self.axis)
297 |
298 | @OPERATOR.register_operator("DepthToSpace")
299 | class TFDepthToSpace():
300 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None:
301 | super().__init__()
302 | self.block_size = node_attribute.get("blocksize", 2)
303 | self.mode = node_attribute.get("mode", "DCR")
304 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
305 |
306 | def __call__(self, inputs):
307 | if not self.channel_last:
308 | inputs = dimension_utils.tensor_NDC_to_NCD_format(inputs)
309 | if self.mode == "DCR":
310 | return tf.nn.depth_to_space(inputs, self.block_size)
311 | elif self.mode == "CRD":
312 |                 # Help wanted: native tensorflow does not support CRD mode; this workaround emits higher-rank (6-D) reshape/transpose ops.
313 | b, h, w, c = inputs.shape
314 | inputs = tf.reshape(inputs, [b, h, w, c//(self.block_size * self.block_size), self.block_size, self.block_size])
315 | inputs = tf.transpose(inputs, perm=[0, 1, 4, 2, 5, 3])
316 | inputs = tf.reshape(inputs, [b, h*self.block_size, w*self.block_size, c//(self.block_size * self.block_size)])
317 | return inputs
318 | else:
319 | raise KeyError(f"For DepthToSpace, mode must be [DCR, CRD], not {self.mode}")
--------------------------------------------------------------------------------
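To make the `gather_elements` construction in `TFGatherElements` above concrete, a small numeric check (values invented) that the meshgrid-plus-`gather_nd` trick reproduces `np.take_along_axis`, which is the ONNX GatherElements semantics:

```python
import numpy as np
import tensorflow as tf

x = np.arange(12, dtype=np.float32).reshape(3, 4)
idx = np.array([[0, 3], [2, 1], [1, 0]], dtype=np.int64)  # same rank as x

expected = np.take_along_axis(x, idx, axis=1)

grid = tf.meshgrid(*[tf.range(s) for s in idx.shape], indexing="ij")
grid = [tf.cast(g, tf.int64) for g in grid]
grid[1] = tf.constant(idx)                    # replace the gathered axis (axis=1)
out = tf.gather_nd(x, tf.stack(grid, axis=-1))

print(np.array_equal(out.numpy(), expected))  # True
```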
/onnx2tflite/layers/mathematics_layers.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | from onnx2tflite.utils.definitions import Layout
6 | from onnx2tflite.utils import OPERATOR, dimension_utils, np2tf_type
7 |
8 | LOG = logging.getLogger("calculations_layers :")
9 |
10 | def np2tf(x):
11 | if isinstance(x, np.ndarray):
12 | x = tf.convert_to_tensor(x, dtype=np2tf_type[x.dtype.name])
13 | return x, False
14 | return x, True
15 |
16 | def match_tensor(x1, x2, x1_layout:Layout, x2_layout:Layout): # x1, x2: tf.Tensor or np.ndarray
17 |
18 | x1, f1 = np2tf(x1)
19 | x2, f2 = np2tf(x2)
20 |
21 |     # If both vars are graph tensors, only align their layouts; no constant re-layout is needed.
22 | if f1 and f2:
23 | if x1_layout != x2_layout:
24 | if x1_layout == Layout.Channel_Last:
25 | x1 = dimension_utils.tensor_NDC_to_NCD_format(x1)
26 | elif x2_layout == Layout.Channel_Last:
27 | x2 = dimension_utils.tensor_NDC_to_NCD_format(x2)
28 | return x1, x2, Layout.Channel_First
29 |
30 |     # ensure the graph tensor ends up in x1 and the const weights in x2
31 | out_layout = x1_layout
32 | if f2:
33 | x1, x2 = x2, x1
34 | out_layout = x2_layout
35 |
36 |
37 | if out_layout == Layout.Channel_Last:
38 | if x1.shape.ndims != x2.shape.ndims:
39 | while x2.shape.ndims < x1.shape.ndims:
40 | x2 = tf.expand_dims(x2, axis=0)
41 | x2 = dimension_utils.tensor_NCD_to_NDC_format(x2)
42 |
43 | x2 = tf.cast(x2, x1.dtype)
44 | return (x2, x1, out_layout) if f2 else (x1, x2, out_layout)
45 |
46 | '''
47 | tensor(NDC) + const
48 | tensor(NCD) + const
49 | tensor(NDC) + tensor(NDC)
50 | tensor(NCD) + tensor(NCD)
51 | '''
52 |
53 | class BaseArithmetic:
54 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
55 | self.left_val, self.right_val = None, None
56 | left_layout, right_layout = Layout.Default, Layout.Default
57 |
58 | if node_inputs[0] in tensor_grap:
59 | self.left_val = tensor_grap[node_inputs[0]]
60 | left_layout = layout_dict[node_inputs[0]]
61 | else:
62 | self.left_val = node_weights[node_inputs[0]]
63 |
64 | if node_inputs[1] in tensor_grap:
65 | self.right_val = tensor_grap[node_inputs[1]]
66 | right_layout = layout_dict[node_inputs[1]]
67 | else:
68 | self.right_val = node_weights[node_inputs[1]]
69 |
70 | if left_layout == right_layout:
71 | return
72 |
73 | self.left_val, self.right_val, out_layout = match_tensor(self.left_val, self.right_val, left_layout, right_layout)
74 | layout_dict[node_outputs[0]] = out_layout
75 |
76 | @OPERATOR.register_operator("Add")
77 | class TFAdd(BaseArithmetic):
78 | def __init__(self, *args, **kwargs):
79 | super().__init__(*args, **kwargs)
80 |
81 | def __call__(self, *args, **kwargs):
82 | return self.left_val + self.right_val
83 |
84 | @OPERATOR.register_operator("Sub")
85 | class TFSub(BaseArithmetic):
86 | def __init__(self, *args, **kwargs):
87 | super().__init__(*args, **kwargs)
88 |
89 | def __call__(self, *args, **kwargs):
90 | return self.left_val - self.right_val
91 |
92 | @OPERATOR.register_operator("Mul")
93 | class TFMul(BaseArithmetic):
94 | def __init__(self,*args, **kwargs):
95 | super().__init__(*args, **kwargs)
96 |
97 | def __call__(self, *args, **kwargs):
98 | return self.left_val * self.right_val
99 |
100 | @OPERATOR.register_operator("Div")
101 | class TFDiv(BaseArithmetic):
102 | def __init__(self,*args, **kwargs):
103 | super().__init__(*args, **kwargs)
104 |
105 | def __call__(self, *args, **kwargs):
106 | return self.left_val / self.right_val
107 |
108 | @OPERATOR.register_operator("MatMul")
109 | class TFMatMul():
110 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
111 | super().__init__()
112 | if node_inputs[0] in tensor_grap:
113 | self.A = tensor_grap[node_inputs[0]]
114 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
115 | self.A = dimension_utils.tensor_NDC_to_NCD_format(self.A)
116 | else:
117 | self.A = node_weights[node_inputs[0]]
118 |
119 | if node_inputs[1] in tensor_grap:
120 | self.B = tensor_grap[node_inputs[1]]
121 | if layout_dict[node_inputs[1]] == Layout.Channel_Last:
122 | self.B = dimension_utils.tensor_NDC_to_NCD_format(self.B)
123 | else:
124 | self.B = node_weights[node_inputs[1]]
125 |
126 | self.dense = tf.keras.layers.Dense(self.B.shape[-1],
127 | weights=[self.B],
128 | use_bias=False)
129 |
130 | layout_dict[node_outputs[0]] = Layout.Channel_First
131 |
132 | def __call__(self, *args, **kwargs):
133 | # out = tf.matmul(self.A, self.B)
134 | try:
135 | out = self.dense(self.A)
136 | except Exception:
137 | out = tf.matmul(self.A, self.B)
138 | return out
139 |
140 | @OPERATOR.register_operator("Mod")
141 | class TFMod():
142 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
143 | super().__init__()
144 | self.fmod = bool(node_attribute.get("fmod", 0))
145 | self.mod_value = None
146 | if node_inputs[1] in node_weights:
147 | self.mod_value = node_weights[node_inputs[1]]
148 | else:
149 | self.mod_value = tensor_grap[node_inputs[1]]
150 |
151 | def __call__(self, inputs):
152 |         if self.fmod:
153 |             return tf.truncatemod(inputs, tf.cast(self.mod_value, inputs.dtype)) # fmod=1: C-style fmod, sign follows the dividend
154 |         else:
155 |             return tf.math.floormod(inputs, tf.cast(self.mod_value, inputs.dtype)) # fmod=0: Python %, sign follows the divisor
156 |
157 | @OPERATOR.register_operator("Pow")
158 | class TFPow():
159 | def __init__(self, tensor_grap, node_weights, node_inputs, *args, **kwargs):
160 | super().__init__()
161 | self.power_index = node_weights[node_inputs[1]]
162 |
163 | def __call__(self, inputs, *args, **kwargs):
164 | return tf.pow(inputs, self.power_index)
165 |
166 | @OPERATOR.register_operator("Reciprocal")
167 | class TFReciprocal():
168 | def __init__(self, *args, **kwargs):
169 | super().__init__()
170 |
171 | def __call__(self, inputs, *args, **kwargs):
172 | return 1/inputs
173 |
174 | @OPERATOR.register_operator("Sqrt")
175 | class TFSqrt():
176 | def __init__(self, *args, **kwargs):
177 | super().__init__()
178 |
179 | def __call__(self, inputs, *args, **kwargs):
180 | return tf.sqrt(inputs)
181 |
182 | @OPERATOR.register_operator("Exp")
183 | class TFExp():
184 | def __init__(self, *args, **kwargs):
185 | super().__init__()
186 |
187 | def __call__(self, inputs, *args, **kwargs):
188 | return tf.exp(inputs)
189 |
190 | @OPERATOR.register_operator("Log")
191 | class TFLog():
192 | def __init__(self, *args, **kwargs):
193 | super().__init__()
194 |
195 | def __call__(self, inputs, *args, **kwargs):
196 |         return tf.math.log(inputs)
197 |
198 | class ReduceBase:
199 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
200 | self.keep_dims = node_attribute.get("keepdims", 1) == 1
201 | input_shape_len = len(tensor_grap[node_inputs[0]].shape)
202 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
203 | self.axes = [dimension_utils.channel_to_last_dimension(i) if i >=0 else dimension_utils.channel_to_last_dimension(input_shape_len + i) for i in node_attribute.get("axes", [-1])]
204 | else:
205 | self.axes = [i if i >=0 else input_shape_len + i for i in node_attribute.get("axes", [-1])]
206 |
207 | @OPERATOR.register_operator("ReduceSum")
208 | class TFReduceSum(ReduceBase):
209 | def __init__(self, *args, **kwargs):
210 | super().__init__(*args, **kwargs)
211 |
212 | def __call__(self, inputs, *args, **kwargs):
213 | return tf.math.reduce_sum(inputs, axis=self.axes, keepdims=self.keep_dims)
214 |
215 | @OPERATOR.register_operator("ReduceMean")
216 | class TFReduceMean(ReduceBase):
217 | def __init__(self, *args, **kwargs):
218 | super().__init__(*args, **kwargs)
219 |
220 | def __call__(self, inputs, *args, **kwargs):
221 | return tf.math.reduce_mean(inputs, axis=self.axes, keepdims=self.keep_dims)
222 |
223 | @OPERATOR.register_operator("ReduceMax")
224 | class TFReduceMax(ReduceBase):
225 | def __init__(self, *args, **kwargs):
226 | super().__init__(*args, **kwargs)
227 |
228 | def __call__(self, inputs, *args, **kwargs):
229 | return tf.math.reduce_max(inputs, axis=self.axes, keepdims=self.keep_dims)
230 |
231 | @OPERATOR.register_operator("ReduceMin")
232 | class TFReduceMin(ReduceBase):
233 | def __init__(self, *args, **kwargs):
234 | super().__init__(*args, **kwargs)
235 |
236 | def __call__(self, inputs, *args, **kwargs):
237 | return tf.math.reduce_min(inputs, axis=self.axes, keepdims=self.keep_dims)
238 |
239 | @OPERATOR.register_operator("ArgMax")
240 | class TFArgMax():
241 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
242 | super().__init__()
243 | self.axis = node_attribute.get('axis', 0)
244 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
245 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
246 | self.keepdims = node_attribute.get("keepdims", 1) == 1
247 |
248 | def __call__(self, inputs, *args, **kwargs):
249 | _inputs = tf.argmax(inputs, axis=self.axis)
250 | if self.keepdims:
251 | _inputs = tf.expand_dims(_inputs, axis=self.axis)
252 | return _inputs
253 |
254 | @OPERATOR.register_operator("ArgMin")
255 | class TFArgMin():
256 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
257 | super().__init__()
258 | self.axis = node_attribute.get('axis', 0)
259 | if layout_dict[node_inputs[0]] == Layout.Channel_Last:
260 | self.axis = dimension_utils.channel_to_last_dimension(self.axis)
261 | self.keepdims = node_attribute.get("keepdims", 1) == 1
262 |
263 | def __call__(self, inputs, *args, **kwargs):
264 |         _inputs = tf.argmin(inputs, axis=self.axis)
265 | if self.keepdims:
266 | _inputs = tf.expand_dims(_inputs, axis=self.axis)
267 | return _inputs
268 |
269 | @OPERATOR.register_operator("Erf")
270 | class TFErf():
271 | def __init__(self, *args, **kwargs) -> None:
272 | pass
273 |
274 | def __call__(self, inputs):
275 | inputs = tf.math.erf(inputs)
276 | return inputs
--------------------------------------------------------------------------------
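To make the constant re-layout done by `match_tensor` above concrete, a small hedged sketch (values invented) of why an NCHW-shaped bias must be permuted before broadcasting against a channel-last tensor:

```python
import numpy as np
import tensorflow as tf

C = 3
bias_nchw = np.arange(C, dtype=np.float32).reshape(1, C, 1, 1)  # ONNX-side constant
x_nhwc = tf.zeros((1, 4, 4, C))                                 # channel-last graph tensor

# What tensor_NCD_to_NDC_format does for a 4-D constant: move C to the end.
bias_nhwc = np.transpose(bias_nchw, (0, 2, 3, 1))               # (1, 1, 1, C)

y = x_nhwc + bias_nhwc  # now broadcasts per channel, matching the ONNX add
print(y.shape)          # (1, 4, 4, 3)
```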
/onnx2tflite/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dimension_utils import *
2 | from .op_registry import OPERATOR
3 | from .definitions import *
--------------------------------------------------------------------------------
/onnx2tflite/utils/definitions.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from abc import ABC
3 | from enum import Enum, unique
4 |
5 | @unique
6 | class Layout(Enum):
7 | Default = 0
8 |     Channel_First = 1 << 0 # for onnx format
9 | Channel_Last = 1 << 1 # for tensorflow format
10 | Channel_None = 1 << 2 # no channel
11 |
12 | class Node_Layout:
13 |     def __init__(self, name:str, pre:list=None, nxt:list=None) -> None:
14 |         self.name = name
15 |         self.pre = pre if pre is not None else [] # avoid the shared mutable-default pitfall
16 |         self.nxt = nxt if nxt is not None else []
17 | self.layout = Layout.Default
18 |
19 | class BaseOP(ABC):
20 | def __init__(self, tensor_graph, const_weights, node_attributes, node_inputs, node_outputs, layout_dict) -> None:
21 | pass
22 |
23 | onnx2tf_type = {
24 | 1: tf.float32, # ONNX_FLOAT
25 | 2: tf.uint8, # ONNX_UINT8
26 | 3: tf.int8, # ONNX_INT8
27 | 4: tf.uint16, # ONNX_UINT16
28 | 5: tf.int16, # ONNX_INT16
29 | 6: tf.int32, # ONNX_INT32
30 | 7: tf.int64, # ONNX_INT64
31 | 8: tf.string, # ONNX_STRING
32 | 9: tf.bool, # ONNX_BOOL
33 | 10: tf.float16, # ONNX_FLOAT16
34 | 11: tf.float64, # ONNX_DOUBLE
35 | 12: tf.uint32, # ONNX_UINT32
36 | 13: tf.uint64, # ONNX_UINT64
37 | 14: tf.complex64, # ONNX_COMPLEX64
38 | 15: tf.complex128 # ONNX_COMPLEX128
39 | }
40 |
41 | np2tf_type = {
42 | "int32": tf.int32,
43 | "int64": tf.int64,
44 | "float32": tf.float32,
45 | "float64": tf.float64,
46 | "bool": tf.bool,
47 | "uint8": tf.uint8,
48 | "int8": tf.int8,
49 | "int16": tf.int16,
50 | "uint16": tf.uint16,
51 | "uint32": tf.uint32,
52 | "uint64": tf.uint64,
53 | "complex64": tf.complex64,
54 | "complex128": tf.complex128
55 | }
56 |
57 | FORCE_CHANNEL_LAST_OP = ["Conv", "ConvTranspose", "DepthToSpace", "Pad", "AveragePool", "MaxPool", "Upsample", "Resize", "Gemm"]
58 | FORCE_CHANNEL_FIRST_OP = ["Reshape", "Transpose", "ScatterND", "MatMul"]
59 |
60 |
--------------------------------------------------------------------------------
/onnx2tflite/utils/dimension_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | '''
3 | shape and axis transform utils func.
4 | '''
5 | def channel_to_last_dimension(axis):
6 | '''
7 |     map a channel-first (NCHW) axis index to its channel-last (NHWC) position
8 | '''
9 | if axis == 0:
10 | axis = 0
11 | elif axis == 1:
12 | axis = -1
13 | else:
14 | axis -= 1
15 | return axis
16 |
17 | def shape_NCD_to_NDC_format(shape):
18 | '''
19 | make shape format from channel first to channel last
20 | '''
21 | if len(shape) <= 2:
22 | return tuple(shape)
23 | new_shape = [shape[0], *shape[2:], shape[1]]
24 | return tuple(new_shape)
25 |
26 | def shape_NDC_to_NCD_format(shape):
27 | '''
28 | make shape format from channel last to channel first
29 | '''
30 | if len(shape) <= 2:
31 | return tuple(shape)
32 | new_shape = [shape[0], shape[-1], *shape[1:-1]]
33 | return tuple(new_shape)
34 |
35 | def tensor_NCD_to_NDC_format(tensor):
36 | '''
37 | make tensor format from channel first to channel last
38 | '''
39 | if(len(tensor.shape) > 2):
40 | shape = [i for i in range(len(tensor.shape))]
41 | shape = shape_NCD_to_NDC_format(shape)
42 | tensor = tf.transpose(tensor, perm=shape)
43 | return tensor
44 |
45 | def tensor_NDC_to_NCD_format(tensor):
46 | '''
47 | make tensor format from channel last to channel first
48 | '''
49 | if(len(tensor.shape) > 2):
50 | shape = [i for i in range(len(tensor.shape))]
51 | shape = shape_NDC_to_NCD_format(shape)
52 | tensor = tf.transpose(tensor, perm=shape)
53 | return tensor
54 |
55 | def intfloat_to_list(x, lens:int): # x: int, float, or list
56 | if isinstance(x, (int, float)):
57 | return [x]*lens
58 | else:
59 | return x
--------------------------------------------------------------------------------
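A short sanity check of the helpers above, assuming the package layout from this repo:

```python
from onnx2tflite.utils.dimension_utils import (
    channel_to_last_dimension, shape_NCD_to_NDC_format, shape_NDC_to_NCD_format)

assert shape_NCD_to_NDC_format((1, 3, 224, 224)) == (1, 224, 224, 3)
assert shape_NDC_to_NCD_format((1, 224, 224, 3)) == (1, 3, 224, 224)
assert channel_to_last_dimension(1) == -1  # the channel axis moves to the end
assert channel_to_last_dimension(2) == 1   # spatial axes shift down by one
```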
/onnx2tflite/utils/graph_tools.py:
--------------------------------------------------------------------------------
1 | from onnx import numpy_helper
2 | import tensorflow as tf
3 | from tensorflow import keras
4 | from .definitions import *
5 |
6 | # copy from https://github.com/gmalivenko/onnx2keras
7 | def decode_node_attribute(node)->dict:
8 | """
9 | Parse ONNX attributes to Python dictionary
10 | :param args: ONNX attributes object
11 | :return: Python dictionary
12 | """
13 | def onnx_attribute_to_dict(onnx_attr):
14 | """
15 | Parse ONNX attribute
16 | :param onnx_attr: ONNX attribute
17 | :return: Python data type
18 | """
19 | if onnx_attr.HasField('t'):
20 | return numpy_helper.to_array(getattr(onnx_attr, 't'))
21 |
22 | for attr_type in ['f', 'i']:
23 | if onnx_attr.HasField(attr_type):
24 | return getattr(onnx_attr, attr_type)
25 |
26 |         # 's' needs to be decoded from bytes to string
27 | if onnx_attr.HasField('s'):
28 | return getattr(onnx_attr, 's').decode()
29 |
30 | for attr_type in ['floats', 'ints', 'strings']:
31 | if getattr(onnx_attr, attr_type):
32 | return list(getattr(onnx_attr, attr_type))
33 | return {arg.name: onnx_attribute_to_dict(arg) for arg in node.attribute}
34 |
35 | def build_tf_inputs(model_graph, layout_dict:dict):
36 | inputs_name = []
37 | for inp in model_graph.input:
38 | input_shape = [x.dim_value for x in inp.type.tensor_type.shape.dim]
39 | if input_shape == []:
40 | continue
41 | inputs_name.append(inp.name)
42 | layout_dict[inp.name] = Layout.Default
43 | if len(input_shape) < 3:
44 | layout_dict[inp.name] = Layout.Channel_None
45 |
46 | _inputs_name = inputs_name.copy()
47 | for node in model_graph.node:
48 | op_name, node_inputs = node.op_type, node.input
49 | # output_layout = Layout.Default
50 | for ninp in node_inputs:
51 | if ninp in _inputs_name and op_name in FORCE_CHANNEL_LAST_OP and layout_dict[ninp] == Layout.Default:
52 | layout_dict[ninp] = Layout.Channel_Last
53 | _inputs_name.remove(ninp)
54 | if ninp in _inputs_name and op_name in FORCE_CHANNEL_FIRST_OP and layout_dict[ninp] == Layout.Default:
55 | layout_dict[ninp] = Layout.Channel_First
56 | _inputs_name.remove(ninp)
57 | # output_layout = output_layout | node_dict[ninp]
58 |
59 | if len(_inputs_name) == 0:
60 | break
61 |
62 | input_nodes = {}
63 | for inp in model_graph.input:
64 | input_shape = [x.dim_value for x in inp.type.tensor_type.shape.dim]
65 | if input_shape == []:
66 | continue
67 | batch_size = 1 if input_shape[0] <= 0 else input_shape[0]
68 | input_shape = input_shape[1:]
69 | if layout_dict[inp.name] == Layout.Channel_Last:
70 | input_shape = input_shape[1:] + input_shape[0:1]
71 |
72 | input_nodes[inp.name] = keras.Input(shape=input_shape, batch_size=batch_size, dtype=onnx2tf_type.get(inp.type.tensor_type.elem_type))
73 |
74 | return input_nodes
75 |
--------------------------------------------------------------------------------
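What `decode_node_attribute` above returns for a typical node, a minimal sketch using an ONNX node built with invented attribute values:

```python
import onnx
from onnx2tflite.utils.graph_tools import decode_node_attribute

node = onnx.helper.make_node(
    "Conv", inputs=["x", "w"], outputs=["y"],
    kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], group=1)

print(decode_node_attribute(node))
# -> {'group': 1, 'kernel_shape': [3, 3], 'pads': [1, 1, 1, 1], 'strides': [2, 2]}
#    (key order may vary; int lists decode to Python lists, single ints to int)
```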
/onnx2tflite/utils/op_registry.py:
--------------------------------------------------------------------------------
1 | class Registry(object):
2 | def __init__(self, name) -> None:
3 | self._name = name
4 | self._operator_dict = dict()
5 |
6 | def __len__(self):
7 | return len(self._operator_dict)
8 |
9 | @property
10 | def name(self):
11 | return self._name
12 |
13 | @property
14 | def operator_dict(self):
15 | return self._operator_dict
16 |
17 | def get(self, key):
18 | return self._operator_dict.get(key, None)
19 |
20 | def _register_operator(self, op_class, op_name=None):
21 |         if not isinstance(op_name, str):
22 | op_name = op_class.__name__
23 |
24 | if self._operator_dict.get(op_name, None):
25 | raise KeyError(f'{op_name} is already registered in {self._name}')
26 |
27 | self._operator_dict[op_name] = op_class
28 |
29 | def register_operator(self, name=None, op_class=None):
30 | if op_class is not None:
31 | self._register_operator(op_class, name)
32 | return op_class
33 |
34 | def _register(cls):
35 | self._register_operator(cls, name)
36 | return cls
37 |
38 | return _register
39 |
40 | OPERATOR = Registry("TensorflowOP")
--------------------------------------------------------------------------------
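A quick usage sketch of the `Registry` above, covering both the decorator and direct registration forms (class names invented for illustration):

```python
from onnx2tflite.utils.op_registry import Registry

reg = Registry("demo")

@reg.register_operator("Relu")       # decorator form
class TFRelu:
    pass

class TFAbs:
    pass
reg.register_operator("Abs", TFAbs)  # direct form

assert reg.get("Relu") is TFRelu and reg.get("Abs") is TFAbs
assert len(reg) == 2
```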
/readme.md:
--------------------------------------------------------------------------------
1 | # ONNX->Keras and ONNX->TFLite tools
2 | ## Welcome
3 | If you have good ideas, feel free to open a discussion or submit a PR.
4 |
5 | ## Install
6 | ```cmd
7 | git clone https://github.com/MPolaris/onnx2tflite.git
8 | cd onnx2tflite
9 | python setup.py install
10 | ```
11 | ```python
12 | from onnx2tflite import onnx_converter
13 | res = onnx_converter(
14 | onnx_model_path = "./model.onnx",
15 | need_simplify = True,
16 | output_path = "./models/",
17 | target_formats = ['tflite'],
18 | )
19 | ```
20 | ---
21 | ```cmd
22 | # basic usage
23 | python -m onnx2tflite --weights "./your_model.onnx"
24 |
25 | # specify the save path
26 | python -m onnx2tflite --weights "./your_model.onnx" --outpath "./save_path"
27 |
28 | # save tflite model
29 | python -m onnx2tflite --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite"
30 |
31 | # save keras and tflite model
32 | python -m onnx2tflite --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite" "keras"
33 |
34 | # cut the model: redefine inputs and outputs, middle layers are supported
35 | python -m onnx2tflite --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite" --input-node-names "layer_inputname" --output-node-names "layer_outname1" "layer_outname2"
36 |
37 | # quantize model weights only
38 | python -m onnx2tflite --weights "./your_model.onnx" --formats "tflite" --weigthquant
39 |
40 | # quantize the whole model, including inputs and outputs
41 | ## fp16
42 | python -m onnx2tflite --weights "./your_model.onnx" --formats "tflite" --fp16
43 | ## recommended
44 | python -m onnx2tflite --weights "./your_model.onnx" --formats "tflite" --int8 --imgroot "./dataset_path" --int8mean 0 0 0 --int8std 255 255 255
45 | ## generate random calibration data instead of reading from image files
46 | python -m onnx2tflite --weights "./your_model.onnx" --formats "tflite" --int8
47 | ```
48 | ---
49 | ## Features
50 | - High Consistency. Compared with ONNX outputs, the average error is below 1e-5 per element.
51 | - Faster. Produces a tensorflow-lite model about 30% faster than [onnx_tf](https://github.com/onnx/onnx-tensorflow).
52 | - Auto Channel Align. Automatically converts the pytorch format (NCHW) to the tensorflow format (NHWC).
53 | - Deployment Support. Supports exporting quantized models, including fp16 and uint8 quantization.
54 | - Code Friendly. I've been trying to keep the code structure simple and clear.
55 | ---
56 |
57 | ## Pytorch -> ONNX -> Tensorflow-Keras -> Tensorflow-Lite
58 |
59 | - ### From torchvision to tensorflow-lite
60 | ```python
61 | import torch
62 | import torchvision
63 | _input = torch.randn(1, 3, 224, 224)
64 | model = torchvision.models.mobilenet_v2(True)
65 | # the default settings are fine
66 | torch.onnx.export(model, _input, './mobilenetV2.onnx', opset_version=11) # or opset_version=13
67 |
68 | from converter import onnx_converter
69 | onnx_converter(
70 | onnx_model_path = "./mobilenetV2.onnx",
71 | need_simplify = True,
72 | output_path = "./",
73 | target_formats = ['tflite'], # or ['keras'], ['keras', 'tflite']
74 | weight_quant = False,
75 | fp16_model=False,
76 | int8_model = False,
77 | int8_mean = None,
78 | int8_std = None,
79 | image_root = None
80 | )
81 | ```
82 | - ### From custom pytorch model to tensorflow-lite-int8
83 | ```python
84 | import torch
85 | import torch.nn as nn
86 | import torch.nn.functional as F
87 |
88 | class MyModel(nn.Module):
89 |     def __init__(self):
90 |         super().__init__()
91 |         self.conv = nn.Sequential(
92 |             nn.Conv2d(3, 64, kernel_size=3, padding=1),
93 |             nn.BatchNorm2d(64),
94 |             nn.ReLU(inplace=True),
95 |         )
96 | def forward(self, x):
97 | return self.conv(x)
98 |
99 | model = MyModel()
100 | model.load_state_dict(torch.load("model_checkpoint.pth", map_location="cpu"))
101 |
102 | _input = torch.randn(1, 3, 224, 224)
103 | torch.onnx.export(model, _input, './mymodel.onnx', opset_version=11) # or opset_version=13
104 |
105 | from converter import onnx_converter
106 | onnx_converter(
107 | onnx_model_path = "./mymodel.onnx",
108 | need_simplify = True,
109 | output_path = "./",
110 | target_formats = ['tflite'], #or ['keras'], ['keras', 'tflite']
111 | weight_quant = False,
112 |     int8_model = True, # do quantization
113 |     int8_mean = [123.675, 116.28, 103.53], # mean used in image preprocessing
114 |     int8_std = [58.395, 57.12, 57.375], # std used in image preprocessing
115 |     image_root = "./dataset/train" # folder of training images
116 | )
117 | ```
118 | ---
119 | ## Validated models
120 | - [SSD](https://github.com/qfgaohao/pytorch-ssd)
121 | - [HRNet](https://github.com/HRNet/HRNet-Facial-Landmark-Detection)
122 | - [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)
123 | - [YOLOV3](https://github.com/ultralytics/yolov3)
124 | - [YOLOV4](https://github.com/Tianxiaomo/pytorch-YOLOv4)
125 | - [YOLOV5](https://github.com/ultralytics/yolov5)
126 | - [YOLOV6](https://github.com/meituan/YOLOv6)
127 | - [YOLOV7](https://github.com/WongKinYiu/yolov7)
128 | - [YOLOV10](https://github.com/THU-MIG/yolov10)
129 | - [MoveNet](https://github.com/fire717/movenet.pytorch)
130 | - [UNet\FPN](https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets)
131 | - ViT(torchvision)
132 | - [SwinTransformerV1](https://github.com/microsoft/Swin-Transformer)
133 | - MLP(custom)
134 | - DCGAN(custom)
135 | - [AutoEncoder/VAE](https://github.com/AntixK/PyTorch-VAE)
136 | - all torchvision classification models
137 | - some segmentation models in torchvision
138 | - 1D or 2D CNN without special operators (custom)
139 | ---
140 | ## Add operator by yourself
141 | When you encounter an unsupported operator, you can add it yourself or open an issue.
142 | Implementing a new operator parser is simple; just follow the steps below.
143 | Step 0: Select the corresponding layer code file in the [layers folder](./onnx2tflite/layers/), e.g. activations_layers.py for 'HardSigmoid'.
144 | Step 1: Open and edit it:
145 | ```python
146 | # All operators are registered through the OPERATOR registry.
147 | # The registered name must match the ONNX operator name.
148 | @OPERATOR.register_operator("HardSigmoid")
149 | class TFHardSigmoid():
150 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
151 | '''
152 |         :param tensor_grap: dict, key is node name, value is the tensorflow-keras output tensor of that node.
153 |         :param node_weights: dict, key is node name, value is static data such as weight/bias/constant; weights usually need to be transformed by dimension_utils.tensor_NCD_to_NDC_format.
154 |         :param node_inputs: List[str], node input names, indicating which nodes the inputs come from; they may be keys of tensor_grap or node_weights.
155 |         :param node_attribute: dict, key is attribute name, such as 'axis' or 'perm'; value type varies (List[int], int, float, ...). Note that 'axis' values should be adjusted from NCHW to NHWC via dimension_utils.channel_to_last_dimension or dimension_utils.shape_NCD_to_NDC_format.
156 |         :param node_outputs: List[str], node output names.
157 |         :param layout_dict: Dict[str, Layout], the layout of every preceding node.
158 | '''
159 | super().__init__()
160 | self.alpha = node_attribute.get("alpha", 0.2)
161 | self.beta = node_attribute.get("beta", 0.5)
162 |
163 | def __call__(self, inputs):
164 | return tf.clip_by_value(self.alpha*inputs+self.beta, 0, 1)
165 | ```
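For reference, the `OPERATOR` object used above lives in [utils/op_registry.py](./onnx2tflite/utils/op_registry.py) and maps ONNX operator names to parser classes. The sketch below shows the general idea of such a decorator-based registry; it is illustrative only, and the actual implementation may differ:
```python
# Illustrative sketch of a decorator-based operator registry.
# The real one is in onnx2tflite/utils/op_registry.py and may differ.
class OperatorRegistry:
    def __init__(self):
        self._parsers = {}  # maps ONNX op name -> parser class

    def register_operator(self, op_name):
        def wrapper(cls):
            self._parsers[op_name] = cls
            return cls
        return wrapper

    def get(self, op_name):
        return self._parsers.get(op_name)

OPERATOR = OperatorRegistry()
```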
166 | Step 2: Make sure the conversion runs without error.
167 | Step 3: Convert the model to TFLite without any quantization and check the output error.
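A quick way to validate the new parser, following the pattern of the unit tests in [test/](./test): export a tiny PyTorch model that exercises the operator, convert it without quantization, and check the reported error (the model below is purely illustrative):
```python
import torch
import torch.nn as nn
from onnx2tflite import onnx_converter

# A throwaway model containing only the operator under test.
class TinyHardSigmoid(nn.Module):
    def forward(self, x):
        return nn.functional.hardsigmoid(x)

torch.onnx.export(TinyHardSigmoid(), torch.randn(1, 3, 16, 16),
                  "hardsigmoid.onnx", opset_version=11)

res = onnx_converter(
    onnx_model_path = "hardsigmoid.onnx",
    need_simplify = True,
    output_path = "./",
    target_formats = ['tflite'],
    int8_model = False,  # Step 3: validate without quantization first
)
assert res['tflite_error'] < 1e-3
```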
168 |
169 | ---
170 |
171 | # License
172 | This software is covered by the Apache-2.0 license.
173 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | onnx
2 | onnxruntime
3 | onnx-simplifier
4 | numpy <= 1.24
5 | tensorflow>=2.5,<2.13
6 | opencv-python
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup, find_packages
3 | abs_path = os.path.dirname(os.path.abspath(__file__))
4 |
5 | setup(
6 | name="onnx2tflite",
7 | version="2.0",
8 | author="MPolaris",
9 | description="onnx to keras/tensorflow lite",
10 | long_description=open(os.path.join(abs_path, "readme.md")).read(),
11 | long_description_content_type='text/markdown',
12 | packages=find_packages(include=['onnx2tflite']),
13 | license="Apache-2.0",
14 |     platforms=["Windows", "Linux"],
15 | install_requires=open(os.path.join(abs_path, "requirements.txt")).read().splitlines()
16 | )
--------------------------------------------------------------------------------
/test/test_concat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 |
4 | import torch
5 | import torch.nn as nn
6 | from onnx2tflite import onnx_converter
7 |
8 | MODEL_ROOT = "./unit_test"
9 | os.makedirs(MODEL_ROOT, exist_ok=True)
10 |
11 | @pytest.mark.filterwarnings('ignore::UserWarning')
12 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
13 | def test_concat():
14 | class Concat(nn.Module):
15 | def __init__(self, *args, **kwargs) -> None:
16 | super().__init__(*args, **kwargs)
17 | self.conv1 = nn.Conv2d(3, 3, 3, 2, 1)
18 | # self.conv2 = nn.Conv2d(3, 3, 3, 2, 1)
19 | self._const = torch.randn(1,2,16,8)
20 |
21 | def forward(self, x1, x2, x3):
22 | x1 = torch.reshape(x1, (1, 3, 16, 8))
23 | # x = torch.transpose(x, (0, 1, 3, 2))
24 | x2 = torch.transpose(x2, 3, 2)
25 | x3 = self.conv1(x3)
26 | x = torch.concat([x1,x2,x3,self._const], dim=1)
27 | return x
28 |
29 | model = Concat()
30 | x1 = torch.randn(1,3*16*8)
31 | x2 = torch.randn(1,3,8,16)
32 | x3 = torch.randn(1,3,32,16)
33 |
34 | onnx_model_path = os.path.join(MODEL_ROOT, "test_concat.onnx")
35 | torch.onnx.export(model, (x1,x2,x3), onnx_model_path, opset_version=11)
36 |
37 | res = onnx_converter(
38 | onnx_model_path = onnx_model_path,
39 | need_simplify = True,
40 | output_path = MODEL_ROOT,
41 | target_formats = ['tflite'],
42 | native_groupconv=False,
43 | fp16_model=False,
44 | int8_model=False,
45 | )
46 |
47 | assert res['tflite_error'] < 1e-3
--------------------------------------------------------------------------------
/test/test_reshape_transpose.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 |
4 | import torch
5 | import torch.nn as nn
6 | from onnx2tflite import onnx_converter
7 |
8 | MODEL_ROOT = "./unit_test"
9 | os.makedirs(MODEL_ROOT, exist_ok=True)
10 |
11 | @pytest.mark.filterwarnings('ignore::UserWarning')
12 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
13 | def test_reshape_trans():
14 | class test1(nn.Module):
15 | def __init__(self, *args, **kwargs) -> None:
16 | super().__init__(*args, **kwargs)
17 | self.conv1 = nn.Conv2d(3, 3, 3, 2, 1)
18 | self.conv2 = nn.Conv2d(3, 3, 3, 2, 1)
19 |
20 | def forward(self, x):
21 | x = torch.reshape(x, (1, 3, 32, 16))
22 | # x = torch.transpose(x, (0, 1, 3, 2))
23 | x = torch.transpose(x, 3, 2)
24 | x = self.conv1(x)
25 | x = self.conv2(x)
26 | return x
27 |
28 | model = test1()
29 | x = torch.randn(1, 3*32*16)
30 |
31 | onnx_model_path = os.path.join(MODEL_ROOT, "test_reshape_trans.onnx")
32 | torch.onnx.export(model, x, onnx_model_path, opset_version=11)
33 |
34 | res = onnx_converter(
35 | onnx_model_path = onnx_model_path,
36 | need_simplify = True,
37 | output_path = MODEL_ROOT,
38 | target_formats = ['tflite'],
39 | native_groupconv=False,
40 | fp16_model=False,
41 | int8_model = False,
42 | )
43 |
44 | assert res['tflite_error'] < 1e-3
--------------------------------------------------------------------------------
/test/test_squeeze.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 |
4 | import torch
5 | import torch.nn as nn
6 | from onnx2tflite import onnx_converter
7 |
8 | MODEL_ROOT = "./unit_test"
9 | os.makedirs(MODEL_ROOT, exist_ok=True)
10 |
11 | @pytest.mark.filterwarnings('ignore::UserWarning')
12 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
13 | def test_squeeze():
14 | class Squeeze(nn.Module):
15 | def __init__(self, *args, **kwargs) -> None:
16 | super().__init__(*args, **kwargs)
17 |
18 | def forward(self, x):
19 | x = torch.unsqueeze(x, dim=1)
20 | # x = torch.tile(x, dims=(2,1,1))
21 | x = torch.squeeze(x, dim=1)
22 |
23 | return x
24 |
25 | model = Squeeze()
26 | x = torch.randn(1,1,1,2)
27 |
28 | onnx_model_path = os.path.join(MODEL_ROOT, "test_squeeze.onnx")
29 | torch.onnx.export(model, x, onnx_model_path, opset_version=11)
30 |
31 | res = onnx_converter(
32 | onnx_model_path = onnx_model_path,
33 | need_simplify = True,
34 | output_path = MODEL_ROOT,
35 | target_formats = ['tflite'],
36 | native_groupconv=False,
37 | fp16_model=False,
38 | int8_model=False,
39 | )
40 |
41 | assert res['tflite_error'] < 1e-3
--------------------------------------------------------------------------------
/test/test_torchvison.py:
--------------------------------------------------------------------------------
1 | '''
2 | unit tests for torchvision models
3 | '''
4 | import os
5 | import pytest
6 |
7 | import torch
8 | import torchvision
9 | from onnx2tflite import onnx_converter
10 |
11 | MODEL_ROOT = "./unit_test"
12 | os.makedirs(MODEL_ROOT, exist_ok=True)
13 |
14 | @pytest.mark.filterwarnings('ignore::UserWarning')
15 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
16 | def test_resnet():
17 | model = torchvision.models.resnet18(False)
18 | onnx_model_path = os.path.join(MODEL_ROOT, "resnet18.onnx")
19 | torch.onnx.export(model, torch.randn(1, 3, 224, 224), onnx_model_path, opset_version=13)
20 | error = onnx_converter(onnx_model_path, need_simplify = True, output_path = MODEL_ROOT, target_formats = ['tflite'])['tflite_error']
21 | assert error < 1e-3
22 |
23 | @pytest.mark.filterwarnings('ignore::UserWarning')
24 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
25 | def test_mobilenet():
26 | model = torchvision.models.mobilenet_v2(False)
27 | onnx_model_path = os.path.join(MODEL_ROOT, "mobilenet_v2.onnx")
28 | torch.onnx.export(model, torch.randn(1, 3, 224, 224), onnx_model_path, opset_version=13)
29 | error = onnx_converter(onnx_model_path, need_simplify = True, output_path = MODEL_ROOT, target_formats = ['tflite'])['tflite_error']
30 | assert error < 1e-3
31 |
32 | @pytest.mark.filterwarnings('ignore::UserWarning')
33 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
34 | def test_deeplabv3():
35 | model = torchvision.models.segmentation.deeplabv3_resnet50(False)
36 | onnx_model_path = os.path.join(MODEL_ROOT, "deeplabv3_resnet50.onnx")
37 | torch.onnx.export(model, torch.randn(1, 3, 512, 1024), onnx_model_path, opset_version=13)
38 | error = onnx_converter(onnx_model_path, need_simplify = True, output_path = MODEL_ROOT, target_formats = ['tflite'])['tflite_error']
39 | assert error < 1e-3
40 |
41 | @pytest.mark.filterwarnings('ignore::UserWarning')
42 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
43 | def test_vit():
44 | model = torchvision.models.vit_b_16(False)
45 | onnx_model_path = os.path.join(MODEL_ROOT, "vit_b_16.onnx")
46 | torch.onnx.export(model, torch.randn(1, 3, 224, 224), onnx_model_path, opset_version=13)
47 | error = onnx_converter(onnx_model_path, need_simplify = True, output_path = MODEL_ROOT, target_formats = ['tflite'])['tflite_error']
48 | assert error < 1e-3
--------------------------------------------------------------------------------