├── .gitignore
├── LICENSE
├── onnx2tflite
│   ├── __init__.py
│   ├── __main__.py
│   ├── components
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── dataloader.py
│   │   ├── onnx_loader.py
│   │   └── output_check.py
│   ├── converter.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── activations_layers.py
│   │   ├── common_layers.py
│   │   ├── conv_layers.py
│   │   ├── deformation_layers.py
│   │   └── mathematics_layers.py
│   └── utils
│       ├── __init__.py
│       ├── definitions.py
│       ├── dimension_utils.py
│       ├── graph_tools.py
│       └── op_registry.py
├── readme.md
├── requirements.txt
├── setup.py
└── test
    ├── test_concat.py
    ├── test_reshape_transpose.py
    ├── test_squeeze.py
    └── test_torchvison.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.onnx
2 | *.tflite
3 | __pycache__/
4 | .ipynb_checkpoints/
5 | test.py
6 | gen_model.py
7 | models/
8 | unit_test/
9 | onnx2tflite.egg-info/
10 | build/
11 | dist/
12 | main.py
13 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 | 
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 | 
30 |       "Object" form shall mean any form resulting from mechanical
31 |       transformation or translation of a Source form, including but
32 |       not limited to compiled object code, generated documentation,
33 |       and conversions to other media types.
34 | 
35 |       "Work" shall mean the work of authorship, whether in Source or
36 |       Object form, made available under the License, as indicated by a
37 |       copyright notice that is included in or attached to the work
38 |       (an example is provided in the Appendix below).
39 | 
40 |       "Derivative Works" shall mean any work, whether in Source or Object
41 |       form, that is based on (or derived from) the Work and for which the
42 |       editorial revisions, annotations, elaborations, or other modifications
43 |       represent, as a whole, an original work of authorship. For the purposes
44 |       of this License, Derivative Works shall not include works that remain
45 |       separable from, or merely link (or bind by name) to the interfaces of,
46 |       the Work and Derivative Works thereof.
47 | 
48 |       "Contribution" shall mean any work of authorship, including
49 |       the original version of the Work and any modifications or additions
50 |       to that Work or Derivative Works thereof, that is intentionally
51 |       submitted to Licensor for inclusion in the Work by the copyright owner
52 |       or by an individual or Legal Entity authorized to submit on behalf of
53 |       the copyright owner. For the purposes of this definition, "submitted"
54 |       means any form of electronic, verbal, or written communication sent
55 |       to the Licensor or its representatives, including but not limited to
56 |       communication on electronic mailing lists, source code control systems,
57 |       and issue tracking systems that are managed by, or on behalf of, the
58 |       Licensor for the purpose of discussing and improving the Work, but
59 |       excluding communication that is conspicuously marked or otherwise
60 |       designated in writing by the copyright owner as "Not a Contribution."
61 | 
62 |       "Contributor" shall mean Licensor and any individual or Legal Entity
63 |       on behalf of whom a Contribution has been received by Licensor and
64 |       subsequently incorporated within the Work.
65 | 
66 |    2. Grant of Copyright License. Subject to the terms and conditions of
67 |       this License, each Contributor hereby grants to You a perpetual,
68 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 |       copyright license to reproduce, prepare Derivative Works of,
70 |       publicly display, publicly perform, sublicense, and distribute the
71 |       Work and such Derivative Works in Source or Object form.
72 | 
73 |    3. Grant of Patent License. Subject to the terms and conditions of
74 |       this License, each Contributor hereby grants to You a perpetual,
75 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 |       (except as stated in this section) patent license to make, have made,
77 |       use, offer to sell, sell, import, and otherwise transfer the Work,
78 |       where such license applies only to those patent claims licensable
79 |       by such Contributor that are necessarily infringed by their
80 |       Contribution(s) alone or by combination of their Contribution(s)
81 |       with the Work to which such Contribution(s) was submitted. If You
82 |       institute patent litigation against any entity (including a
83 |       cross-claim or counterclaim in a lawsuit) alleging that the Work
84 |       or a Contribution incorporated within the Work constitutes direct
85 |       or contributory patent infringement, then any patent licenses
86 |       granted to You under this License for that Work shall terminate
87 |       as of the date such litigation is filed.
88 | 
89 |    4. Redistribution. You may reproduce and distribute copies of the
90 |       Work or Derivative Works thereof in any medium, with or without
91 |       modifications, and in Source or Object form, provided that You
92 |       meet the following conditions:
93 | 
94 |       (a) You must give any other recipients of the Work or
95 |           Derivative Works a copy of this License; and
96 | 
97 |       (b) You must cause any modified files to carry prominent notices
98 |           stating that You changed the files; and
99 | 
100 |       (c) You must retain, in the Source form of any Derivative Works
101 |           that You distribute, all copyright, patent, trademark, and
102 |           attribution notices from the Source form of the Work,
103 |           excluding those notices that do not pertain to any part of
104 |           the Derivative Works; and
105 | 
106 |       (d) If the Work includes a "NOTICE" text file as part of its
107 |           distribution, then any Derivative Works that You distribute must
108 |           include a readable copy of the attribution notices contained
109 |           within such NOTICE file, excluding those notices that do not
110 |           pertain to any part of the Derivative Works, in at least one
111 |           of the following places: within a NOTICE text file distributed
112 |           as part of the Derivative Works; within the Source form or
113 |           documentation, if provided along with the Derivative Works; or,
114 |           within a display generated by the Derivative Works, if and
115 |           wherever such third-party notices normally appear. The contents
116 |           of the NOTICE file are for informational purposes only and
117 |           do not modify the License. You may add Your own attribution
118 |           notices within Derivative Works that You distribute, alongside
119 |           or as an addendum to the NOTICE text from the Work, provided
120 |           that such additional attribution notices cannot be construed
121 |           as modifying the License.
122 | 
123 |       You may add Your own copyright statement to Your modifications and
124 |       may provide additional or different license terms and conditions
125 |       for use, reproduction, or distribution of Your modifications, or
126 |       for any such Derivative Works as a whole, provided Your use,
127 |       reproduction, and distribution of the Work otherwise complies with
128 |       the conditions stated in this License.
129 | 
130 |    5. Submission of Contributions. Unless You explicitly state otherwise,
131 |       any Contribution intentionally submitted for inclusion in the Work
132 |       by You to the Licensor shall be under the terms and conditions of
133 |       this License, without any additional terms or conditions.
134 |       Notwithstanding the above, nothing herein shall supersede or modify
135 |       the terms of any separate license agreement you may have executed
136 |       with Licensor regarding such Contributions.
137 | 
138 |    6. Trademarks. This License does not grant permission to use the trade
139 |       names, trademarks, service marks, or product names of the Licensor,
140 |       except as required for reasonable and customary use in describing the
141 |       origin of the Work and reproducing the content of the NOTICE file.
142 | 
143 |    7. Disclaimer of Warranty. Unless required by applicable law or
144 |       agreed to in writing, Licensor provides the Work (and each
145 |       Contributor provides its Contributions) on an "AS IS" BASIS,
146 |       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 |       implied, including, without limitation, any warranties or conditions
148 |       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 |       PARTICULAR PURPOSE. You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 | 
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 | 
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 | 
176 |    END OF TERMS AND CONDITIONS
177 | 
178 |    APPENDIX: How to apply the Apache License to your work.
179 | 
180 |       To apply the Apache License to your work, attach the following
181 |       boilerplate notice, with the fields enclosed by brackets "[]"
182 |       replaced with your own identifying information. (Don't include
183 |       the brackets!) The text should be enclosed in the appropriate
184 |       comment syntax for the file format. We also recommend that a
185 |       file or class name and description of purpose be included on the
186 |       same "printed page" as the copyright notice for easier
187 |       identification within third-party archives.
188 | 
189 |    Copyright [MPolaris] [name of copyright owner]
190 | 
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 | 
--------------------------------------------------------------------------------
/onnx2tflite/__init__.py:
--------------------------------------------------------------------------------
1 | __VERSION__ = "2.0"
2 | 
3 | from .converter import onnx_converter
--------------------------------------------------------------------------------
/onnx2tflite/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from .converter import onnx_converter
3 | 
4 | def parse_opt():
5 |     parser = argparse.ArgumentParser()
6 |     parser.add_argument('--weights', type=str, required=True, help='onnx model path')
7 |     parser.add_argument('--outpath', type=str, default=None, help='tflite model save path')
8 |     parser.add_argument('--input-node-names', nargs="+", default=None, help='input node names you want; intermediate layers are supported; None uses the original onnx inputs')
9 |     parser.add_argument('--output-node-names', nargs="+", default=None, help='output node names you want; intermediate layers are supported; None uses the original onnx outputs')
10 |     parser.add_argument('--nosimplify', default=False, action='store_true', help='do not simplify the model')
11 |     parser.add_argument("--native-groupconv", default=False, action='store_true', help='use the native method for grouped convolution, only supported for tflite version >= 2.9')
12 |     parser.add_argument('--weigthquant', default=False, action='store_true', help='weight-only int8 quantization')
13 |     parser.add_argument('--fp16', default=False, action='store_true', help='fp16 quantization, including inputs and outputs')
14 |     parser.add_argument('--int8', default=False, action='store_true', help='int8 quantization, including inputs and outputs')
15 |     parser.add_argument('--imgroot', type=str, default=None, help='when --int8 is set, imgroot should be given for calculating the calibration mean and std')
16 |     parser.add_argument('--int8mean', type=float, nargs='+', default=[123.675, 116.28, 103.53], help='int8 image preprocessing mean, float or list')
17 |     parser.add_argument('--int8std', type=float, nargs='+', default=[58.395, 57.12, 57.375], help='int8 image preprocessing std, float or list')
18 |     parser.add_argument('--formats', nargs='+', default=['keras', 'tflite'], help='available formats are (keras, tflite)')
19 |     opt = parser.parse_args()
20 |     return opt
21 | 
22 | def run():
23 |     opt = parse_opt()
24 |     onnx_converter(
25 |         onnx_model_path = opt.weights,
26 |         need_simplify = not opt.nosimplify,
27 |         input_node_names = opt.input_node_names,
28 |         output_node_names = opt.output_node_names,
29 |         output_path = opt.outpath,
30 |         target_formats = opt.formats,
31 |         native_groupconv = opt.native_groupconv,
32 |         weight_quant=opt.weigthquant,
33 |         fp16_model=opt.fp16,
34 |         int8_model=opt.int8,
35 |         int8_mean=opt.int8mean,
36 |         int8_std=opt.int8std,
37 |         image_root=opt.imgroot
38 |     )
39 | 
40 | if __name__ == "__main__":
41 |     run()
--------------------------------------------------------------------------------
/onnx2tflite/components/__init__.py:
--------------------------------------------------------------------------------
1 | from .output_check import get_elements_error
2 | from .onnx_loader import load_onnx_modelproto
3 | from .builder import keras_builder, tflite_builder
4 | 
5 | __all__ = ['load_onnx_modelproto', 'keras_builder', 'tflite_builder', 'get_elements_error']
--------------------------------------------------------------------------------
/onnx2tflite/components/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3 | 
4 | import tensorflow as tf
5 | from tensorflow import keras
6 | from onnx import numpy_helper
7 | from .dataloader import RandomLoader, ImageLoader
8 | 
9 | from onnx2tflite.utils import OPERATOR
10 | from onnx2tflite.layers import conv_layers
11 | from onnx2tflite.utils.definitions import *
12 | from onnx2tflite.utils.graph_tools import build_tf_inputs, decode_node_attribute
13 | 
14 | def keras_builder(onnx_model, native_groupconv:bool=False):
15 | 
16 |     conv_layers.USE_NATIVE_GROUP_CONV = native_groupconv
17 | 
18 |     model_graph = onnx_model.graph
19 |     layout_dict, tf_tensor = {}, {}
20 | 
21 |     '''
22 |     init the onnx model's built-in tensors
23 |     '''
24 |     onnx_weights = dict()
25 |     for initializer in model_graph.initializer:
26 |         onnx_weights[initializer.name] = numpy_helper.to_array(initializer)
27 | 
28 |     '''
29 |     build input nodes
30 |     '''
31 |     input_nodes = build_tf_inputs(model_graph, layout_dict)
32 |     tf_tensor.update(input_nodes)
33 | 
34 |     '''
35 |     build the model by iterating over the onnx nodes.
36 |     '''
37 |     for node in model_graph.node:
38 |         op_name, node_inputs, node_outputs = node.op_type, node.input, node.output
39 |         op_attr = decode_node_attribute(node)
40 | 
41 |         tf_operator = OPERATOR.get(op_name)
42 |         if tf_operator is None:
43 |             raise KeyError(f"{op_name} not implemented yet")
44 | 
45 |         _inputs = None
46 |         if len(node_inputs) > 0:
47 |             _inputs = tf_tensor[node_inputs[0]] if node_inputs[0] in tf_tensor else onnx_weights[node_inputs[0]]
48 | 
49 |         # init layout
50 |         for index in range(len(node_outputs)):
51 |             layout_dict[node_outputs[index]] = layout_dict.get(node_inputs[0], Layout.Default) if len(node_inputs) > 0 else Layout.Default
52 | 
53 |         res = tf_operator(tf_tensor, onnx_weights, node_inputs, op_attr, node_outputs, layout_dict)(_inputs)
54 |         if isinstance(res, list):
55 |             for index in range(len(node_outputs)):
56 |                 tf_tensor[node_outputs[index]] = res[index]
57 |         else:
58 |             tf_tensor[node_outputs[0]] = res
59 | 
60 |     '''
61 |     build keras model
62 |     '''
63 |     input_nodes = [tf_tensor[x.name] for x in model_graph.input]
64 |     outputs_nodes = [tf_tensor[x.name] for x in model_graph.output]
65 |     keras_model = keras.Model(inputs=input_nodes, outputs=outputs_nodes)
66 |     keras_model.trainable = False
67 |     # keras_model.summary()
68 |     # print(layout_dict)
69 |     input_layout, output_layout = {}, {}
70 |     for inp in model_graph.input:
71 |         input_layout[inp.name] = layout_dict[inp.name]
72 |     for oup in model_graph.output:
73 |         output_layout[oup.name] = layout_dict[oup.name]
74 |     return keras_model, input_layout, output_layout
75 | 
76 | def tflite_builder(keras_model, weight_quant:bool=False, fp16_model=False, int8_model:bool=False, image_root:str=None,
77 |                    int8_mean:list or float = [123.675, 116.28, 103.53], int8_std:list or float = [58.395, 57.12, 57.375]):
78 |     converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
79 |     converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
80 |     if weight_quant or int8_model or fp16_model:
81 |         converter.experimental_new_converter = True
82 |         converter.optimizations = [tf.lite.Optimize.DEFAULT]
83 | 
84 |     if fp16_model:
85 |         converter.target_spec.supported_types = [tf.float16]
86 |         converter.inference_input_type = tf.float32
87 |         converter.inference_output_type = tf.float32
88 |     elif int8_model:
89 |         assert len(keras_model.inputs) == 1, "only single-input models are supported for int8 quantization."
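        # NOTE: TFLiteConverter iterates `representative_dataset` to collect int8
        # calibration statistics; it expects a callable that returns an iterable
        # yielding lists of input arrays. A minimal hand-rolled equivalent of the
        # loaders used below (illustrative only, the name `rep_data` is hypothetical):
        #
        #     def rep_data():
        #         for _ in range(5):
        #             yield [np.ones((1, 224, 224, 3), dtype=np.float32)]
        #     converter.representative_dataset = rep_data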
90 |         shape = list(keras_model.inputs[0].shape)
91 |         dataset = RandomLoader(shape) if image_root is None else ImageLoader(image_root, shape, int8_mean, int8_std)
92 |         converter.representative_dataset = lambda: dataset
93 |         converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8, tf.lite.OpsSet.SELECT_TF_OPS]
94 |         converter.target_spec.supported_types = []
95 |         converter.inference_input_type = tf.uint8
96 |         converter.inference_output_type = tf.uint8
97 |         converter.experimental_new_converter = True
98 | 
99 |     tflite_model = converter.convert()
100 |     return tflite_model
--------------------------------------------------------------------------------
/onnx2tflite/components/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import logging
4 | import numpy as np
5 | 
6 | LOG = logging.getLogger("Quantization DataLoader :")
7 | 
8 | class RandomLoader(object):
9 |     def __init__(self, target_size):
10 |         self.target_size = target_size
11 |         LOG.warning("Generating quantization data from random noise; this will lead to accuracy problems!")
12 | 
13 |     def __iter__(self):
14 |         self.index = 0
15 |         return self
16 | 
17 |     def __next__(self):
18 |         if self.index > 5:
19 |             raise StopIteration()
20 |         self.index += 1
21 |         return [np.random.randn(*self.target_size).astype(np.float32)]
22 | 
23 | class ImageLoader(object):
24 |     '''
25 |     generate calibration data for quantization from image files.
26 |     img_quant_data = (img - mean)/std; correct mean/std values are important for model accuracy.
27 |     '''
28 |     VALID_FORMAT = ['.jpg', '.png', '.jpeg']
29 | 
30 |     def __init__(self, img_root, target_size, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) -> None:
31 |         assert os.path.exists(img_root), f"{img_root} does not exist, please check!"
32 |         self.fns = os.listdir(img_root)
33 |         self.fns = list(filter(lambda fn: os.path.splitext(fn)[-1].lower() in self.VALID_FORMAT, self.fns))
34 |         self.nums = len(self.fns)
35 |         assert self.nums > 0, f"No images detected in {img_root}."
36 |         if self.nums > 100:
37 |             LOG.warning(f"{self.nums} images detected; using no more than 100 images is recommended.")
38 |         else:
39 |             LOG.info(f"{self.nums} images detected.")
40 |         self.fns = [os.path.join(img_root, fn) for fn in self.fns]
41 | 
42 |         self.batch, self.size = target_size[0], target_size[1:-1]
43 |         if isinstance(mean, list):
44 |             mean = np.array(mean, dtype=np.float32)
45 |         if isinstance(std, list):
46 |             std = np.array(std, dtype=np.float32)
47 |         self.mean, self.std = mean, std
48 | 
49 |     def __iter__(self):
50 |         self.index = 0
51 |         return self
52 | 
53 |     def __next__(self):
54 |         if self.index >= self.nums:
55 |             raise StopIteration()
56 | 
57 |         _input = cv2.imread(self.fns[self.index])
58 |         _input = cv2.resize(_input, self.size)[:, :, ::-1]  # BGR->RGB
59 |         _input = _input.astype(np.float32)
60 | 
61 |         if self.mean is not None:
62 |             _input = (_input - self.mean)
63 |         if self.std is not None:
64 |             _input = _input/self.std
65 | 
66 |         _input = np.expand_dims(_input, axis=0)
67 |         if self.batch > 1:
68 |             _input = np.repeat(_input, self.batch, axis=0).astype(np.float32)
69 | 
70 |         self.index += 1
71 |         return [_input]
72 | 
--------------------------------------------------------------------------------
/onnx2tflite/components/onnx_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import onnx
3 | import logging
4 | from onnxsim import simplify
5 | 
6 | LOG = logging.getLogger("onnx_loader running:")
7 | LOG.setLevel(logging.INFO)
8 | 
9 | def clean_model_input(model_proto):
10 |     inputs = model_proto.graph.input
11 |     name_to_input = {}
12 |     for input in inputs:
13 |         name_to_input[input.name] = input
14 | 
15 |     names = []
16 |     for initializer in model_proto.graph.initializer:
17 |         if initializer.name in name_to_input:
18 |             inputs.remove(name_to_input[initializer.name])
19 |             names.append(initializer.name)
20 | 
21 |     if len(names) > 0:
22 |         LOG.warning(f"[{len(names)}] redundant input nodes are removed.\n \
23 |             nodes name : {','.join(names)}")
24 | 
25 | def get_onnx_submodel(onnx_model_path:str, input_node_names:list=None, output_node_names:list=None):
26 |     '''
27 |     cut a sub-model out of an onnx model
28 |     '''
29 |     model_proto = onnx.load(onnx_model_path)
30 |     if input_node_names is None:
31 |         input_node_names = []
32 |         for inp in model_proto.graph.input:
33 |             input_node_names.append(inp.name)
34 | 
35 |     if output_node_names is None:
36 |         output_node_names = []
37 |         for oup in model_proto.graph.output:
38 |             output_node_names.append(oup.name)
39 |     del model_proto
40 | 
41 |     new_model_path = os.path.splitext(onnx_model_path)[0] + "_sub.onnx"
42 |     onnx.utils.extract_model(onnx_model_path, new_model_path, input_node_names, output_node_names)
43 |     model_proto = onnx.load(new_model_path)
44 |     return model_proto
45 | 
46 | def get_proto(onnx_model_path:str, input_node_names:list=None, output_node_names:list=None):
47 |     if input_node_names is None and output_node_names is None:
48 |         return onnx.load(onnx_model_path)
49 |     else:
50 |         return get_onnx_submodel(onnx_model_path, input_node_names, output_node_names)
51 | 
52 | def load_onnx_modelproto(onnx_model_path:str, input_node_names:list=None, output_node_names:list=None, need_simplify:bool=True):
53 |     if not os.path.exists(onnx_model_path):
54 |         LOG.error(f"{onnx_model_path} does not exist.")
55 |         raise FileNotFoundError(f"{onnx_model_path} does not exist.")
56 |     model_proto = get_proto(onnx_model_path, input_node_names, output_node_names)
57 |     dynamic_input = False
58 |     for inp in model_proto.graph.input:
59 |         for x in inp.type.tensor_type.shape.dim:
60 |             if x.dim_value <= 0:
61 |                 dynamic_input = True
62 |                 break
63 |     if need_simplify:
64 |         success = False
65 |         try:
66 |             model_proto, success = simplify(model_proto, check_n=1, dynamic_input_shape=dynamic_input)
67 |         except Exception:
68 |             success = False
69 |         if not success:
70 |             LOG.warning("onnxsim failed; the conversion may fail as a result.")
71 |             model_proto = onnx.load(onnx_model_path)
72 |     clean_model_input(model_proto)
73 |     return model_proto
--------------------------------------------------------------------------------
/onnx2tflite/components/output_check.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 | import onnxruntime as ort
5 | from onnx2tflite.utils.definitions import Layout
6 | from onnx2tflite.utils.dimension_utils import tensor_NDC_to_NCD_format
7 | 
8 | def tflite_run(model_path:str) -> np.ndarray:
9 |     '''
10 |     tflite runtime
11 |     '''
12 |     tflite_runtime = tf.lite.Interpreter(model_path, num_threads=4)
13 |     tflite_runtime.allocate_tensors()
14 |     input_details, output_details = tflite_runtime.get_input_details(), tflite_runtime.get_output_details()
15 |     for i in range(len(input_details)):
16 |         tflite_runtime.set_tensor(input_details[i]['index'], np.ones(input_details[i]['shape'], dtype=np.float32))
17 |     tflite_runtime.invoke()
18 | 
19 |     # comparing one output is enough.
20 |     tflite_output = tflite_runtime.get_tensor(output_details[0]['index'])
21 |     return tflite_output
22 | 
23 | def keras_run(model_path:str) -> np.ndarray:
24 |     '''
25 |     keras runtime
26 |     '''
27 |     keras_runtime = tf.keras.models.load_model(model_path)
28 |     _input = []
29 |     for inp in keras_runtime.inputs:
30 |         _input.append(np.ones(list(inp.shape), dtype=np.float32))
31 | 
32 |     keras_output = keras_runtime.predict(_input)
33 |     # comparing one output is enough.
34 |     if isinstance(keras_output, list):
35 |         keras_output = keras_output[0]
36 |     return keras_output
37 | 
38 | 
39 | def get_elements_error(onnx_proto, keras_model_path:str, tflite_model_path:str, input_layout:dict, output_layout:dict) -> dict:
40 |     '''
41 |     use an all-ones input array to check the model.
42 |     more careful checks are up to your own custom code.
43 |     '''
44 |     result = {}
45 |     # test onnx
46 |     onnx_runtime = ort.InferenceSession(onnx_proto.SerializeToString())
47 |     onnx_inputs = {}
48 |     for inp in onnx_runtime.get_inputs():
49 |         shape = inp.shape
50 |         if isinstance(shape[0], str) or shape[0] < 1:
51 |             shape[0] = 1
52 |         onnx_inputs[inp.name] = np.ones(shape, dtype=np.float32)
53 |         if len(shape) > 2:
54 |             _transpose_index = [i for i in range(len(shape))]
55 |             _transpose_index = _transpose_index[0:1] + _transpose_index[2:] + _transpose_index[1:2]
56 |     onnx_outputs = onnx_runtime.run([], onnx_inputs)
57 | 
58 |     channel_last = False
59 |     for oup in onnx_proto.graph.output:
60 |         channel_last = output_layout[oup.name] == Layout.Channel_Last
61 |         break
62 | 
63 |     if keras_model_path is not None:
64 |         # test keras model
65 |         keras_output = keras_run(keras_model_path)
66 |         if channel_last:
67 |             keras_output = tensor_NDC_to_NCD_format(keras_output)
68 |         # get max error
69 |         keras_max_error = 1000
70 |         for onnx_output in onnx_outputs:
71 |             if onnx_output.shape != keras_output.shape:
72 |                 continue
73 |             diff = np.abs(onnx_output - keras_output)
74 |             max_diff = np.max(diff)
75 |             keras_max_error = min(keras_max_error, max_diff)
76 |         result['keras'] = keras_max_error
77 | 
78 |     if tflite_model_path is not None:
79 |         # test tflite
80 |         tflite_output = tflite_run(tflite_model_path)
81 |         if channel_last:
82 |             tflite_output = tensor_NDC_to_NCD_format(tflite_output)
83 |         # get max error
84 |         tflite_max_error = 1000
85 |         for onnx_output in onnx_outputs:
86 |             if onnx_output.shape != tflite_output.shape:
87 |                 continue
88 |             diff = np.abs(onnx_output - tflite_output)
89 |             max_diff = np.max(diff)
90 |             tflite_max_error = min(tflite_max_error, max_diff)
91 |         result['tflite'] = tflite_max_error
92 | 
93 |     return result
--------------------------------------------------------------------------------
/onnx2tflite/converter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from .components import load_onnx_modelproto, keras_builder, tflite_builder, get_elements_error
4 | 
5 | logging.basicConfig(level=logging.INFO)
6 | LOG = logging.getLogger("converter running:")
7 | 
8 | def onnx_converter(onnx_model_path:str, output_path:str=None,
9 |                    input_node_names:list=None, output_node_names:list=None,
10 |                    need_simplify:bool=True, target_formats:list = ['keras', 'tflite'],
11 |                    native_groupconv:bool=False,
12 |                    weight_quant:bool=False, fp16_model:bool=False, int8_model:bool=False, image_root:str=None,
13 |                    int8_mean:list or float = [123.675, 116.28, 103.53], int8_std:list or float = [58.395, 57.12, 57.375])->dict:
14 |     """
15 |     Converts an ONNX model to various target formats with optional optimizations.
16 | 
17 |     Parameters:
18 |         onnx_model_path (str): Path to the input ONNX model file.
19 |         output_path (str, optional): Path to save the converted model(s). If None, the converted model(s) will be saved in the same directory as the input model.
20 |         input_node_names (list, optional): List of input node names. If None, the default input nodes of the ONNX model are used.
21 |         output_node_names (list, optional): List of output node names. If None, the default output nodes of the ONNX model are used.
22 |         need_simplify (bool, optional): If True, the ONNX model will be simplified before conversion. Default is True.
23 |         target_formats (list, optional): List of target formats to convert the ONNX model to. Default is ['keras', 'tflite'].
24 |         native_groupconv (bool, optional): If True, retains native group convolution operations during conversion. Default is False.
25 |         weight_quant (bool, optional): If True, applies weight quantization to the converted model. Default is False.
26 |         fp16_model (bool, optional): If True, converts the model to use FP16 precision. Default is False.
27 |         int8_model (bool, optional): If True, converts the model to use INT8 precision. Default is False.
28 |         image_root (str, optional): Path to the root directory of images for calibration if INT8 quantization is enabled. Default is None.
29 |         int8_mean (list or float, optional): Mean values for INT8 quantization. Default is [123.675, 116.28, 103.53].
30 |         int8_std (list or float, optional): Standard deviation values for INT8 quantization. Default is [58.395, 57.12, 57.375].
31 | 
32 |     Returns:
33 |         dict: Conversion result with the saved model paths and the maximum element error per format.
34 | 
35 |     Note:
36 |         - The function supports multiple target formats for conversion and allows for various optimizations such as simplification, quantization, and precision reduction.
37 |         - When INT8 quantization is enabled, 'image_root', 'int8_mean', and 'int8_std' parameters are used for calibration.
38 |     """
39 |     if not isinstance(target_formats, list) or ('keras' not in target_formats and 'tflite' not in target_formats):
40 |         raise KeyError("'keras' or 'tflite' should be in target_formats")
41 | 
42 |     model_proto = load_onnx_modelproto(onnx_model_path, input_node_names, output_node_names, need_simplify)
43 | 
44 |     keras_model, input_layout, output_layout = keras_builder(model_proto, native_groupconv)
45 | 
46 |     if 'tflite' in target_formats:
47 |         tflite_model = tflite_builder(keras_model, weight_quant, fp16_model, int8_model, image_root, int8_mean, int8_std)
48 | 
49 |     onnx_path, model_name = os.path.split(onnx_model_path)
50 |     if output_path is None:
51 |         output_path = onnx_path
52 |     output_path = os.path.join(output_path, model_name.split('.')[0])
53 | 
54 |     if fp16_model:
55 |         output_path = output_path + "_fp16"
56 |     elif int8_model:
57 |         output_path = output_path + "_int8"
58 | 
59 |     keras_model_path = None
60 |     if 'keras' in target_formats:
61 |         keras_model_path = output_path + ".h5"
62 |         keras_model.save(keras_model_path)
63 |         LOG.info(f"keras model saved in {keras_model_path}")
64 | 
65 |     tflite_model_path = None
66 |     if 'tflite' in target_formats:
67 |         tflite_model_path = output_path + ".tflite"
68 |         with open(tflite_model_path, "wb") as fp:
69 |             fp.write(tflite_model)
70 | 
71 |     convert_result = {"keras":keras_model_path, "tflite":tflite_model_path, "keras_error":0, "tflite_error":0}
72 |     # skip the output check for quantized models
73 |     if int8_model:
74 |         return convert_result
75 | 
76 |     error_dict = {}
77 |     try:
78 |         error_dict = get_elements_error(model_proto, keras_model_path, tflite_model_path, input_layout, output_layout)
79 |         keras_error, tflite_error = error_dict.get("keras", None), error_dict.get("tflite", None)
80 |         if keras_error:
81 |             if keras_error > 1e-2:
82 |                 LOG.error("h5 model elements' max error has reached {:^.4E}, but conversion is done, please check {} carefully!".format(keras_error, keras_model_path))
83 |             elif keras_error > 1e-4:
84 |                 LOG.warning("h5 model elements' max error is {:^.4E}, pass, h5 saved in {}".format(keras_error, keras_model_path))
85 |             else:
86 |                 LOG.info("h5 model elements' max error is {:^.4E}, pass, h5 saved in {}".format(keras_error, keras_model_path))
87 |         if tflite_error:
88 |             if tflite_error > 1e-2:
89 |                 LOG.error("tflite model elements' max error has reached {:^.4E}, but conversion is done, please check {} carefully!".format(tflite_error, tflite_model_path))
carefully!".format(tflite_error, tflite_model_path)) 90 | elif tflite_error > 1e-4: 91 | LOG.warning("tflite model elements' max error is {:^.4E}, pass, tflite saved in {}".format(tflite_error, tflite_model_path)) 92 | else: 93 | LOG.info("tflite model elements' max error is {:^.4E}, pass, tflite saved in {}".format(tflite_error, tflite_model_path)) 94 | except: 95 | LOG.warning("convert is successed, but model running is failed, please check carefully!") 96 | 97 | convert_result["keras_error"] = error_dict.get("keras", None) 98 | convert_result["tflite_error"] = error_dict.get("tflite", None) 99 | return convert_result -------------------------------------------------------------------------------- /onnx2tflite/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv_layers import * 2 | from .common_layers import * 3 | from .activations_layers import * 4 | from .mathematics_layers import * 5 | from .deformation_layers import * -------------------------------------------------------------------------------- /onnx2tflite/layers/activations_layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow import keras 4 | 5 | from onnx2tflite.utils.definitions import Layout 6 | from onnx2tflite.utils import OPERATOR, channel_to_last_dimension, tensor_NCD_to_NDC_format 7 | 8 | @OPERATOR.register_operator("Relu") 9 | class TFRelu(): 10 | def __init__(self, *args, **kwargs) -> None: 11 | super().__init__() 12 | 13 | def __call__(self, inputs): 14 | return keras.activations.relu(inputs) 15 | 16 | @OPERATOR.register_operator("HardSigmoid") 17 | class TFHardSigmoid(): 18 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None: 19 | super().__init__() 20 | self.alpha = node_attribute.get("alpha", 0.2) 21 | self.beta = node_attribute.get("beta", 0.5) 22 | 23 | def __call__(self, inputs): 24 | return tf.clip_by_value(self.alpha*inputs+self.beta, 0, 1) 25 | 26 | @OPERATOR.register_operator("HardSwish") 27 | class TFHardSwish(): 28 | def __init__(self, *args, **kwargs) -> None: 29 | super().__init__() 30 | 31 | def __call__(self, inputs): 32 | return inputs*tf.clip_by_value(inputs/6+0.5, 0, 1) 33 | 34 | @OPERATOR.register_operator("Mish") 35 | class TFMish(): 36 | def __init__(self, *args, **kwargs) -> None: 37 | super().__init__() 38 | 39 | def __call__(self, inputs): 40 | return inputs*tf.tanh(tf.math.log(tf.math.exp(inputs)+1)) 41 | 42 | @OPERATOR.register_operator("Sigmoid") 43 | class TFSigmoid(): 44 | def __init__(self, *args, **kwargs) -> None: 45 | super().__init__() 46 | 47 | def __call__(self, inputs): 48 | return keras.activations.sigmoid(inputs) 49 | 50 | @OPERATOR.register_operator("LeakyRelu") 51 | class TFLeakyRelu(): 52 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None: 53 | super().__init__() 54 | self.alpha = node_attribute.get('alpha', 0.01) 55 | 56 | def __call__(self, inputs): 57 | return keras.activations.relu(inputs, alpha=self.alpha) 58 | 59 | @OPERATOR.register_operator("PRelu") 60 | class TFPRelu(): 61 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None: 62 | super().__init__() 63 | if 'slope' in node_attribute: 64 | self.slope = node_attribute['slope'] 65 | elif node_inputs[1] in node_weights: 66 | self.slope = node_weights[node_inputs[1]] 67 | else: 
68 |             self.slope = tensor_grap[node_inputs[1]]
69 |         input_tensor_shape = tensor_grap[node_inputs[0]].shape
70 |         channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
71 |         if isinstance(self.slope, np.ndarray):
72 |             while self.slope.ndim < input_tensor_shape.ndims:
73 |                 self.slope = self.slope[np.newaxis, :]
74 |             if channel_last:
75 |                 self.slope = tensor_NCD_to_NDC_format(self.slope)
76 |             if self.slope.ndim > 1:
77 |                 # remove batchsize
78 |                 self.slope = self.slope[0]
79 |         axes = [i for i in range(1, input_tensor_shape.ndims-1)] if channel_last else [i for i in range(2, input_tensor_shape.ndims)]
80 |         self.PRelu = tf.keras.layers.PReLU(weights=[self.slope], shared_axes=axes)
81 | 
82 |     def __call__(self, inputs):
83 |         return self.PRelu(inputs)
84 | 
85 | @OPERATOR.register_operator("Sin")
86 | class TFSin():
87 |     def __init__(self, *args, **kwargs) -> None:
88 |         super().__init__()
89 | 
90 |     def __call__(self, inputs):
91 |         return tf.sin(inputs)
92 | 
93 | @OPERATOR.register_operator("Sinh")
94 | class TFSinh():
95 |     def __init__(self, *args, **kwargs) -> None:
96 |         super().__init__()
97 | 
98 |     def __call__(self, inputs):
99 |         return tf.sinh(inputs)
100 | 
101 | @OPERATOR.register_operator("Cos")
102 | class TFCos():
103 |     def __init__(self, *args, **kwargs) -> None:
104 |         super().__init__()
105 | 
106 |     def __call__(self, inputs):
107 |         return tf.cos(inputs)
108 | 
109 | @OPERATOR.register_operator("Cosh")
110 | class TFCosh():
111 |     def __init__(self, *args, **kwargs) -> None:
112 |         super().__init__()
113 | 
114 |     def __call__(self, inputs):
115 |         return tf.cosh(inputs)
116 | 
117 | @OPERATOR.register_operator("Tan")
118 | class TFTan():
119 |     def __init__(self, *args, **kwargs) -> None:
120 |         super().__init__()
121 | 
122 |     def __call__(self, inputs):
123 |         return tf.tan(inputs)
124 | 
125 | @OPERATOR.register_operator("Tanh")
126 | class TFTanh():
127 |     def __init__(self, *args, **kwargs) -> None:
128 |         super().__init__()
129 | 
130 |     def __call__(self, inputs):
131 |         return tf.tanh(inputs)
132 | 
133 | @OPERATOR.register_operator("Softmax")
134 | class TFSoftmax():
135 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
136 |         super().__init__()
137 |         self.axis = node_attribute.get('axis', -1)
138 |         if self.axis == -1:
139 |             self.axis = len(tensor_grap[node_inputs[0]].shape.as_list()) - 1
140 |         if layout_dict[node_inputs[0]] == Layout.Channel_Last:
141 |             self.axis = channel_to_last_dimension(self.axis)
142 | 
143 |     def __call__(self, inputs):
144 |         return keras.activations.softmax(inputs, axis=self.axis)
145 | 
146 | @OPERATOR.register_operator("Softplus")
147 | class TFSoftplus():
148 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
149 |         super().__init__()
150 | 
151 |     def __call__(self, inputs):
152 |         return keras.activations.softplus(inputs)
153 | 
154 | @OPERATOR.register_operator("Softsign")
155 | class TFSoftsign():
156 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
157 |         super().__init__()
158 | 
159 |     def __call__(self, inputs):
160 |         return keras.activations.softsign(inputs)
161 | 
162 | @OPERATOR.register_operator("Selu")
163 | class TFSelu():
164 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
165 |         super().__init__()
166 | 
167 |     def __call__(self, inputs):
168 |         return keras.activations.selu(inputs)
169 | 
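# NOTE: illustrative sketch, not part of the original file -- a missing ONNX
# activation can be supported by registering a class under its op_type,
# following the exact pattern of the layers above, e.g.:
#
#     @OPERATOR.register_operator("Gelu")
#     class TFGelu():
#         def __init__(self, *args, **kwargs) -> None:
#             super().__init__()
#
#         def __call__(self, inputs):
#             return keras.activations.gelu(inputs)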
@OPERATOR.register_operator("Elu") 171 | class TFElu(): 172 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None: 173 | super().__init__() 174 | 175 | def __call__(self, inputs): 176 | return keras.activations.elu(inputs) 177 | 178 | @OPERATOR.register_operator("Celu") 179 | class TFCelu(): 180 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None: 181 | super().__init__() 182 | self.alpha = node_attribute.get("alpha", 1.0) 183 | 184 | def __call__(self, inputs): 185 | return tf.maximum(inputs, 0) + tf.minimum(0, self.alpha*(tf.exp(inputs/self.alpha)-1)) -------------------------------------------------------------------------------- /onnx2tflite/layers/common_layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow import keras 6 | 7 | from onnx2tflite.utils.definitions import Layout 8 | from onnx2tflite.utils import OPERATOR, intfloat_to_list, dimension_utils 9 | 10 | LOG = logging.getLogger("common_layers :") 11 | 12 | @OPERATOR.register_operator("BatchNormalization") 13 | class TFBatchNormalization(): 14 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs): 15 | super().__init__() 16 | epsilon = node_attribute.get("epsilon", 1e-5) 17 | momentum = node_attribute.get("momentum", 0.9) 18 | 19 | self.bn = keras.layers.BatchNormalization( 20 | gamma_initializer=keras.initializers.Constant(node_weights[node_inputs[1]]), 21 | beta_initializer=keras.initializers.Constant(node_weights[node_inputs[2]]), 22 | moving_mean_initializer=keras.initializers.Constant(node_weights[node_inputs[3]]), 23 | moving_variance_initializer=keras.initializers.Constant(node_weights[node_inputs[4]]), 24 | epsilon=epsilon, 25 | momentum=momentum) 26 | 27 | def __call__(self, inputs): 28 | return self.bn(inputs) 29 | 30 | @OPERATOR.register_operator("InstanceNormalization") 31 | class TFInstanceNormalization(): 32 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs): 33 | super().__init__() 34 | self.epsilon = node_attribute.get("epsilon", 1e-5) 35 | self.scale = node_weights[node_inputs[1]] 36 | self.bias = node_weights[node_inputs[2]] 37 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last 38 | 39 | def __call__(self, inputs): 40 | axes = tuple(range(1, len(inputs.shape)-1)) if self.channel_last else tuple(range(2, len(inputs.shape))) 41 | mean = tf.reduce_mean(inputs, axis=axes, keepdims=True) 42 | var = tf.math.reduce_variance(inputs, axis= axes, keepdims=True) 43 | return self.scale*(inputs - mean)/tf.sqrt(var + self.epsilon) + self.bias 44 | 45 | @OPERATOR.register_operator("Pad") 46 | class TFPad(): 47 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs): 48 | super().__init__() 49 | if node_attribute.get("pads") is not None: 50 | pads = node_attribute['pads'] 51 | elif node_inputs[1] in node_weights: 52 | pads = node_weights[node_inputs[1]] 53 | else: 54 | pads = tensor_grap[node_inputs[1]] 55 | self.pad = [[pads[0], pads[4]], [pads[2], pads[6]], [pads[3], pads[7]], [pads[1], pads[5]]] 56 | self.model = node_attribute.get("mode", "constant").upper() 57 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last 58 | layout_dict[node_outputs[0]] = Layout.Channel_Last 59 | 
60 |     def __call__(self, inputs):
61 |         if not self.channel_last:
62 |             inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
63 |         return tf.pad(inputs, self.pad, mode=self.mode)
64 | 
65 | @OPERATOR.register_operator("Clip")
66 | class TFClip():
67 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
68 |         super().__init__()
69 |         if "min" in node_attribute:
70 |             self.min = node_attribute.get("min")
71 |         else:
72 |             self.min = tensor_grap[node_inputs[1]] if node_inputs[1] in tensor_grap else node_weights[node_inputs[1]]
73 |         if "max" in node_attribute:
74 |             self.max = node_attribute.get("max")
75 |         else:
76 |             self.max = tensor_grap[node_inputs[2]] if node_inputs[2] in tensor_grap else node_weights[node_inputs[2]]
77 | 
78 |     def __call__(self, inputs):
79 |         if float(self.min) == 0 and float(self.max) == 6:
80 |             return tf.nn.relu6(inputs)
81 |         return tf.clip_by_value(inputs, self.min, self.max)
82 | 
83 | @OPERATOR.register_operator("GlobalMaxPool")
84 | class TFGlobalMaxPool():
85 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
86 |         super().__init__()
87 |         self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
88 | 
89 |     def __call__(self, inputs):
90 |         if self.channel_last:
91 |             return tf.reduce_max(inputs, axis=[i for i in range(1, len(inputs.shape)-1)], keepdims=True)
92 |         else:
93 |             return tf.reduce_max(inputs, axis=[i for i in range(2, len(inputs.shape))], keepdims=True)
94 | 
95 | @OPERATOR.register_operator("GlobalAveragePool")
96 | class TFGlobalAveragePool():
97 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
98 |         super().__init__()
99 |         self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
100 | 
101 |     def __call__(self, inputs):
102 |         if self.channel_last:
103 |             return tf.reduce_mean(inputs, axis=[i for i in range(1, len(inputs.shape)-1)], keepdims=True)
104 |         else:
105 |             return tf.reduce_mean(inputs, axis=[i for i in range(2, len(inputs.shape))], keepdims=True)
106 | 
107 | @OPERATOR.register_operator("AveragePool")
108 | class TFAveragePool():
109 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
110 |         super().__init__()
111 |         kernel_shape = intfloat_to_list(node_attribute.get("kernel_shape", [2, 2]), 2)
112 |         strides = intfloat_to_list(node_attribute.get("strides", [1, 1]), 2)
113 |         dilations = intfloat_to_list(node_attribute.get("dilations", [1, 1]), 2)
114 |         ceil_mode = node_attribute.get("ceil_mode", 0)
115 |         pads = intfloat_to_list(node_attribute.get("pads", [0, 0, 0, 0]), 4)
116 | 
117 |         func = math.floor if ceil_mode == 0 else math.ceil
118 | 
119 |         pad_mode = "SAME"
120 |         input_shape = tensor_grap[node_inputs[0]].shape
121 |         for i in range(len(input_shape)-2):
122 |             pad_shape = pads[i] + pads[i+2]
123 |             onnx_output_shape = func((input_shape[1+i]+pad_shape-((kernel_shape[i]-1)*dilations[i]+1))/strides[i]+1)
124 |             tf_output_shape = math.floor((input_shape[1+i] - kernel_shape[i]) / strides[i]) + 1
125 |             pads[2+i] = max(onnx_output_shape-tf_output_shape, pads[2+i]) # right_down pad
126 |             if pad_mode == "SAME" and onnx_output_shape != input_shape[1+i]:
127 |                 pad_mode = "VALID"
128 |         self.avg_pool = keras.layers.AveragePooling2D(pool_size=kernel_shape, strides=strides, padding=pad_mode)
129 | 
130 |         self.pad = None
131 |         if pad_mode == "VALID" and pads is not None and np.sum(pads) > 0:
132 |             if np.sum(pads) > 0:
133 |                 self.pad = keras.layers.ZeroPadding2D(padding=((pads[0], pads[2]), (pads[1], pads[3])))
134 | 
135 |         self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
136 |         layout_dict[node_outputs[0]] = Layout.Channel_Last
137 | 
138 |     def __call__(self, inputs):
139 |         if not self.channel_last:
140 |             inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
141 |         if self.pad:
142 |             inputs = self.pad(inputs)
143 |         return self.avg_pool(inputs)
144 | 
145 | @OPERATOR.register_operator("MaxPool")
146 | class TFMaxPool():
147 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
148 |         super().__init__()
149 |         kernel_shape = intfloat_to_list(node_attribute.get("kernel_shape", [2, 2]), 2)
150 |         strides = intfloat_to_list(node_attribute.get("strides", [1, 1]), 2)
151 |         dilations = intfloat_to_list(node_attribute.get("dilations", [1, 1]), 2)
152 |         ceil_mode = node_attribute.get("ceil_mode", 0)
153 |         pads = intfloat_to_list(node_attribute.get("pads", [0, 0, 0, 0]), 4)
154 | 
155 |         func = math.floor if ceil_mode == 0 else math.ceil
156 | 
157 |         pad_mode = "SAME"
158 |         input_shape = tensor_grap[node_inputs[0]].shape
159 |         for i in range(len(input_shape)-2):
160 |             pad_shape = pads[i] + pads[i+2]
161 |             onnx_output_shape = func((input_shape[1+i]+pad_shape-((kernel_shape[i]-1)*dilations[i]+1))/strides[i]+1)
162 |             tf_output_shape = math.floor((input_shape[1+i] - kernel_shape[i]) / strides[i]) + 1
163 |             pads[2+i] = max(onnx_output_shape-tf_output_shape, pads[2+i]) # right_down pad
164 |             if pad_mode == "SAME" and onnx_output_shape != input_shape[1+i]:
165 |                 pad_mode = "VALID"
166 |         self.max_pool = keras.layers.MaxPool2D(pool_size=kernel_shape, strides=strides, padding=pad_mode)
167 | 
168 |         self.pad = None
169 |         if pad_mode == "VALID" and pads is not None and np.sum(pads) > 0:
170 |             if np.sum(pads) > 0:
171 |                 self.pad = keras.layers.ZeroPadding2D(padding=((pads[0], pads[2]), (pads[1], pads[3])))
172 | 
173 |         self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
174 |         layout_dict[node_outputs[0]] = Layout.Channel_Last
175 | 
176 |     def __call__(self, inputs):
177 |         if not self.channel_last:
178 |             inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
179 |         if self.pad:
180 |             inputs = self.pad(inputs)
181 |         return self.max_pool(inputs)
182 | 
183 | @OPERATOR.register_operator("Upsample")
184 | class TFUpsample():
185 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
186 |         super().__init__()
187 |         _, h, w, _ = tensor_grap[node_inputs[0]].shape
188 |         scale = node_weights[node_inputs[1]]
189 | 
190 |         self.scale = (int(h*scale[2]), int(w*scale[3]))
191 |         if node_attribute.get("mode", "nearest").lower() == 'nearest':
192 |             self.method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
193 |         else:
194 |             self.method = tf.image.ResizeMethod.BILINEAR
195 | 
196 |         self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
197 |         layout_dict[node_outputs[0]] = Layout.Channel_Last
198 | 
199 |     def __call__(self, inputs):
200 |         if not self.channel_last:
201 |             inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
202 |         return tf.image.resize(inputs, self.scale, method=self.method)
203 | 
204 | @OPERATOR.register_operator("Constant")
205 | class TFConstant():
206 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
207 |         super().__init__()
208 |         self.val = node_attribute['value']
209 | 
210 |     def __call__(self, *args, **kwargs):
211 |         return self.val
212 | 
213 | @OPERATOR.register_operator("ScatterND")
214 | class TFScatterND():
215 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
216 |         super().__init__()
217 |         self.indices = node_weights[node_inputs[1]]
218 |         self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
219 |         if node_inputs[2] in tensor_grap:
220 |             self.updates = tensor_grap[node_inputs[2]]
221 |             if self.channel_last:
222 |                 self.updates = dimension_utils.tensor_NDC_to_NCD_format(self.updates)
223 |         else:
224 |             self.updates = node_weights[node_inputs[2]]
225 | 
226 |         layout_dict[node_outputs[0]] = Layout.Channel_First
227 | 
228 |     def __call__(self, inputs):
229 |         if self.channel_last:
230 |             inputs = dimension_utils.tensor_NDC_to_NCD_format(inputs)
231 |         inputs = tf.tensor_scatter_nd_update(inputs, self.indices, self.updates)
232 |         return inputs
233 | 
234 | @OPERATOR.register_operator("Resize")
235 | class TFResize():
236 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs):
237 |         super().__init__()
238 |         if node_inputs[-1] in node_weights:
239 |             _, _, nh, nw = node_weights[node_inputs[-1]]
240 |             if len(node_inputs) != 4:
241 |                 _, h, w, _ = tensor_grap[node_inputs[0]].shape
242 |                 nh, nw = int(h*nh), int(w*nw)
243 |             self.scale = (nh, nw)
244 |         else:
245 |             scales = tensor_grap[node_inputs[0]].shape[1:3]*tensor_grap[node_inputs[2]][2:3]
246 |             self.scale = scales
247 | 
248 |         if node_attribute.get("mode", "nearest").lower() == 'nearest':
249 |             self.method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
250 |         else:
251 |             self.method = tf.image.ResizeMethod.BILINEAR
252 | 
253 |         self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
254 |         layout_dict[node_outputs[0]] = Layout.Channel_Last
255 | 
256 |     def __call__(self, inputs):
257 |         if not self.channel_last:
258 |             inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
259 |         return tf.image.resize(inputs, self.scale, method=self.method)
260 | 
261 | @OPERATOR.register_operator("Gemm")
262 | class TFGemm():
263 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None:
264 |         super().__init__()
265 |         if len(node_inputs) > 2:
266 |             weights = [node_weights[node_inputs[1]].T, node_weights[node_inputs[2]]]
267 |         else:
268 |             weights = [node_weights[node_inputs[1]].T]
269 | 
270 |         self.dense = keras.layers.Dense(weights[0].shape[1],
271 |                                         weights=weights,
272 |                                         use_bias=len(weights)==2)
273 | 
274 |         self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last
275 |         layout_dict[node_outputs[0]] = Layout.Channel_Last
276 | 
277 |     def __call__(self, inputs):
278 |         if not self.channel_last:
279 |             inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs)
280 |         return self.dense(inputs)
281 | 
282 | @OPERATOR.register_operator("Identity")
283 | class TFIdentity():
284 |     def __init__(self, *args, **kwargs):
285 |         super().__init__()
286 | 
287 |     def __call__(self, inputs):
288 |         return inputs
289 | 
290 | @OPERATOR.register_operator("Dropout")
291 | class TFDropout():
292 |     '''
293 |     Dropout will be ignored in deployment.
294 |     '''
295 |     def __init__(self, *args, **kwargs):
296 |         super().__init__()
297 | 
298 |     def __call__(self, inputs):
299 |         return inputs
300 | 
301 | @OPERATOR.register_operator("TopK")
302 | class TFTopK():
303 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
304 | 
305 |         self.axis = node_attribute.get("axis", -1)
306 |         self.largest = node_attribute.get("largest", 1)
307 |         self.sorted = bool(node_attribute.get("sorted", 1))
308 |         self.K = node_attribute.get('K') if len(node_inputs)==1 else node_weights[node_inputs[1]][0]
309 | 
310 |     def __call__(self, inputs):
311 |         res = tf.math.top_k(inputs, k=self.K, sorted=self.sorted)
312 |         return [res[0], res[1]]
313 | 
314 | @OPERATOR.register_operator("Cast")
315 | class TFCast():
316 |     def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
317 |         super().__init__()
318 |         self.cast_to = int(node_attribute.get("to", 1))
319 |         assert self.cast_to > 0 and self.cast_to < 12, f"Unknown cast type [{self.cast_to}]"
320 |         self.np_cast_map = {
321 |             1: np.float32,
322 |             2: np.uint8,
323 |             3: np.int8,
324 |             5: np.int16,
325 |             6: np.int32,
326 |             7: np.int64,
327 |             9: np.bool_,
328 |             10: np.float16,
329 |             11: np.double,
330 |         }
331 |         self.tf_cast_map = {
332 |             1: tf.float32,
333 |             2: tf.uint8,
334 |             3: tf.int8,
335 |             5: tf.int16,
336 |             6: tf.int32,
337 |             7: tf.int64,
338 |             9: tf.bool,
339 |             10: tf.float16,
340 |             11: tf.double,
341 |         }
342 | 
343 |     def __call__(self, inputs):
344 |         if isinstance(inputs, list):
345 |             for i in range(len(inputs)):
346 |                 if isinstance(inputs[i], np.ndarray) or isinstance(inputs[i], np.generic):
347 |                     inputs[i] = self.np_cast_map[self.cast_to](inputs[i])
348 |                 else:
349 |                     inputs[i] = tf.cast(inputs[i], dtype=self.tf_cast_map[self.cast_to])
350 |         else:
351 |             if isinstance(inputs, np.ndarray) or isinstance(inputs, np.generic):
352 |                 inputs = self.np_cast_map[self.cast_to](inputs)
353 |             else:
354 |                 inputs = tf.cast(inputs, dtype=self.tf_cast_map[self.cast_to])
355 | 
356 |         return inputs
357 | 
--------------------------------------------------------------------------------
/onnx2tflite/layers/conv_layers.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: MPolaris && yutaka329 && lkdci
3 | 
4 | Thanks to yutaka329 for the pad tricks.
5 | https://github.com/MPolaris/onnx2tflite/issues/5
6 | 
7 | Thanks to lkdci for the native method of group conv.
8 | https://github.com/MPolaris/onnx2tflite/issues/19
9 | '''
10 | import logging
11 | import tensorflow as tf
12 | from tensorflow import keras
13 | from onnx2tflite.utils.op_registry import OPERATOR
14 | from onnx2tflite.utils.definitions import Layout
15 | from onnx2tflite.utils.dimension_utils import tensor_NCD_to_NDC_format as NCD2NDC
16 | 
17 | LOG = logging.getLogger("convolution_layers :")
18 | 
19 | # Whether to implement grouped convolution using the native `keras.layers.Conv2D` class with the groups != 1 argument.
20 | # This implementation is supported only with tflite version >= 2.9.
21 | # If set to `False`, the grouped convolution is built using a regular conv per group, then concatenated, as a workaround
22 | # to support older versions of tflite.
23 | # Using the native keras implementation results in a simplified tflite graph and is expected to run faster.
24 | # See https://github.com/MPolaris/onnx2tflite/issues/19 for more details.
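# The flag below is normally set through the public API rather than edited here:
# onnx_converter(..., native_groupconv=True) is forwarded to keras_builder, which
# assigns conv_layers.USE_NATIVE_GROUP_CONV before building the graph.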
25 | USE_NATIVE_GROUP_CONV = False 26 | 27 | @OPERATOR.register_operator("ConvTranspose") 28 | class TFConvTranspose(): 29 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None: 30 | super().__init__() 31 | # out_channel, in_channel = node_weights[node_inputs[1]].shape[:2] 32 | dilations, group = node_attribute.get('dilations', 1), node_attribute.get('group', 1) 33 | pads = node_attribute['pads'] if "pads" in node_attribute else None 34 | kernel_shape, strides = node_attribute.get('kernel_shape', 1), node_attribute.get('strides', 1) 35 | 36 | weights = node_weights[node_inputs[1]].transpose(2,3,1,0) 37 | bias = node_weights[node_inputs[2]] if len(node_inputs) == 3 else None 38 | height, width, n_filters, channels = weights.shape 39 | self.pad = None 40 | self.conv = keras.layers.Conv2DTranspose(filters=n_filters, kernel_size=(height, width), strides=strides, padding='VALID', use_bias=False if bias is None else True, 41 | weights=[weights] if bias is None else [weights, bias], 42 | dilation_rate=dilations) 43 | if pads is not None and max(pads) != 0: 44 | padding = None 45 | if len(pads) == 2 and (pads[0] > 0 or pads[1] > 0): 46 | padding = (pads[0], pads[1]) 47 | elif len(pads) == 4 and (pads[0] > 0 or pads[1] > 0 or pads[2] > 0 or pads[3] > 0): 48 | padding = ((pads[0], pads[2]), (pads[1], pads[3])) 49 | self.pad = keras.layers.Cropping2D(padding) 50 | 51 | for nop in node_outputs: 52 | layout_dict[nop] = Layout.Channel_Last 53 | 54 | self.need_trans = layout_dict[node_inputs[0]] != Layout.Channel_Last 55 | 56 | def __call__(self, inputs): 57 | if self.need_trans: 58 | inputs = NCD2NDC(inputs) 59 | inputs = self.conv(inputs) 60 | if self.pad: 61 | inputs = self.pad(inputs) 62 | return inputs 63 | 64 | @OPERATOR.register_operator("Conv") 65 | class Convolution(): 66 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None: 67 | super().__init__() 68 | # out_channel/in_channel are derived from the resolved weights below, since node_inputs[1] may come from tensor_grap rather than node_weights 69 | dilations, group = node_attribute.get('dilations', 1), node_attribute.get('group', 1) 70 | pads = node_attribute['pads'] if "pads" in node_attribute else None 71 | kernel_shape, strides = node_attribute.get('kernel_shape', 1), node_attribute.get('strides', 1) 72 | 73 | weights = node_weights[node_inputs[1]] if node_inputs[1] in node_weights else tensor_grap[node_inputs[1]] 74 | out_channel, in_channel = weights.shape[:2] 75 | 76 | channel_sequence = [2+i for i in range(len(weights.shape)-2)] + [1, 0] 77 | weights = weights.transpose(*channel_sequence) 78 | 79 | bias = None 80 | if len(node_inputs) == 3: 81 | bias = node_weights[node_inputs[2]] if node_inputs[2] in node_weights else tensor_grap[node_inputs[2]] 82 | 83 | if group == 1: 84 | self.conv = TFConv(in_channel, out_channel, kernel_shape, strides, dilations, pads, weights, bias) 85 | elif group == out_channel: 86 | self.conv = TFDepthwiseConv(kernel_shape, strides, dilations, pads, weights, bias) 87 | else: 88 | if USE_NATIVE_GROUP_CONV: 89 | self.conv = TFConv(in_channel, out_channel, kernel_shape, strides, dilations, pads, weights, bias, group=group) 90 | LOG.warning("Group convolution detected; using the native method, which requires tflite >= 2.9. \ 91 | If a compatibility error occurs, set USE_NATIVE_GROUP_CONV = False!") 92 | else: 93 | self.conv = TFGroupConv(in_channel, out_channel, kernel_shape, strides, dilations, pads, weights, bias, group=group)
94 | 95 | for nop in node_outputs: 96 | layout_dict[nop] = Layout.Channel_Last 97 | 98 | self.need_trans = layout_dict[node_inputs[0]] != Layout.Channel_Last 99 | 100 | def __call__(self, inputs): 101 | if self.need_trans: 102 | inputs = NCD2NDC(inputs) 103 | return self.conv(inputs) 104 | 105 | class TFConv(): 106 | # Standard convolution 107 | def __init__(self, in_channel_num, out_channel_num, kernel_size=1, 108 | strides=1, dilations=1, pads=None, weights=None, bias=None, group=1): 109 | super().__init__() 110 | 111 | if len(weights.shape) == 3: 112 | self.conv1d_init(in_channel_num, out_channel_num, kernel_size, strides, dilations, pads, weights, bias, group) 113 | elif len(weights.shape) == 4: 114 | self.conv2d_init(in_channel_num, out_channel_num, kernel_size, strides, dilations, pads, weights, bias, group) 115 | elif len(weights.shape) == 5: 116 | self.conv3d_init(in_channel_num, out_channel_num, kernel_size, strides, dilations, pads, weights, bias, group) 117 | else: 118 | raise NotImplementedError(f"Conv{len(weights.shape)-2}d is not implemented") 119 | 120 | def conv1d_init(self, in_channel_num, out_channel_num, kernel_size=1, 121 | strides=1, dilations=1, pads=None, weights=None, bias=None, group=1): 122 | self.pad = None 123 | if pads is not None and max(pads) == 1 and max(strides) == 1: 124 | self.conv = keras.layers.Conv1D( 125 | out_channel_num, kernel_size, strides, "SAME", use_bias=False if bias is None else True, 126 | weights=[weights] if bias is None else [weights, bias], 127 | dilation_rate=dilations, groups=group) 128 | else: 129 | self.conv = keras.layers.Conv1D( 130 | out_channel_num, kernel_size, strides, "VALID", use_bias=False if bias is None else True, 131 | weights=[weights] if bias is None else [weights, bias], 132 | dilation_rate=dilations, groups=group) 133 | if pads is not None and max(pads) != 0: 134 | self.pad = keras.layers.ZeroPadding1D(padding=pads) 135 | 136 | def conv2d_init(self, in_channel_num, out_channel_num, kernel_size=1, 137 | strides=1, dilations=1, pads=None, weights=None, bias=None, group=1): 138 | if isinstance(dilations, int): 139 | dilations = (dilations, dilations) 140 | if isinstance(strides, int): 141 | strides = (strides, strides) 142 | if dilations[0] != 1 and strides[0] != 1: 143 | raise Exception("Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1.") 144 | 145 | self.pad = None 146 | if pads is not None and max(pads) == 1 and max(strides) == 1: 147 | self.conv = keras.layers.Conv2D( 148 | out_channel_num, kernel_size, strides, "SAME", use_bias=False if bias is None else True, 149 | weights=[weights] if bias is None else [weights, bias], 150 | dilation_rate=dilations, groups=group) 151 | else: 152 | self.conv = keras.layers.Conv2D( 153 | out_channel_num, kernel_size, strides, "VALID", use_bias=False if bias is None else True, 154 | weights=[weights] if bias is None else [weights, bias], 155 | dilation_rate=dilations, groups=group) 156 | if pads is not None and max(pads) != 0: 157 | padding = None 158 | if len(pads) == 2 and (pads[0] > 0 or pads[1] > 0): 159 | padding = (pads[0], pads[1]) 160 | elif len(pads) == 4 and (pads[0] > 0 or pads[1] > 0 or pads[2] > 0 or pads[3] > 0): 161 | padding = ((pads[0], pads[2]), (pads[1], pads[3])) 162 | self.pad = keras.layers.ZeroPadding2D(padding=padding) 163 | 164 | def conv3d_init(self, in_channel_num, out_channel_num, kernel_size=1, 165 | strides=1, dilations=1, pads=None, weights=None, bias=None, group=1): 166 | raise
NotImplementedError("Conv3d is not implemented") 167 | 168 | def __call__(self, inputs): 169 | if self.pad: 170 | inputs = self.pad(inputs) 171 | return self.conv(inputs) 172 | 173 | class TFGroupConv(): 174 | ''' 175 | Group Convolution, using split method to implement, not native. 176 | ''' 177 | def __init__(self, in_channel_num, out_channel_num, kernel_size=1, 178 | strides=1, dilations=1, pads=None, weights=None, bias=None, group=1): 179 | super().__init__() 180 | 181 | if len(weights.shape) == 3: 182 | self.groupconv1d_init(in_channel_num, out_channel_num, kernel_size, strides, dilations, pads, weights, bias, group) 183 | elif len(weights.shape) == 4: 184 | self.groupconv2d_init(in_channel_num, out_channel_num, kernel_size, strides, dilations, pads, weights, bias, group) 185 | else: 186 | raise NotImplementedError(f"GroupConv{len(weights.shape)-2}d is not implemented") 187 | 188 | def groupconv1d_init(self, in_channel_num, out_channel_num, kernel_size=1, 189 | strides=1, dilations=1, pads=None, weights=None, bias=None, group=1): 190 | self.cin = in_channel_num 191 | self.groups = group 192 | out_channel_num = int(out_channel_num//group) 193 | self.convs = [] 194 | for i in range(group): 195 | if pads is not None and max(pads) == 1 and max(strides) == 1: 196 | self.convs.append(keras.layers.Conv1D( 197 | out_channel_num, kernel_size, strides, 'SAME', use_bias=False if bias is None else True, 198 | dilation_rate=dilations, 199 | weights=[weights[:, :, i*out_channel_num:(i+1)*out_channel_num]] if bias is None else [weights[:, :, i*out_channel_num:(i+1)*out_channel_num], bias[i*out_channel_num:(i+1)*out_channel_num]])) 200 | else: 201 | self.convs.append(keras.layers.Conv1D( 202 | out_channel_num, kernel_size, strides, 'VALID', use_bias=False if bias is None else True, 203 | dilation_rate=dilations, 204 | weights=[weights[:, :, i*out_channel_num:(i+1)*out_channel_num]] if bias is None else [weights[:, :, i*out_channel_num:(i+1)*out_channel_num], bias[i*out_channel_num:(i+1)*out_channel_num]])) 205 | self.pad =None 206 | if pads is not None and (max(pads) != 0 and not (max(pads) == 1 and max(strides) == 1)): 207 | self.pad = keras.layers.ZeroPadding1D(padding=pads) 208 | 209 | def groupconv2d_init(self, in_channel_num, out_channel_num, kernel_size=1, 210 | strides=1, dilations=1, pads=None, weights=None, bias=None, group=1): 211 | if isinstance(dilations, int): 212 | dilations = (dilations, dilations) 213 | if isinstance(strides, int): 214 | strides = (strides, strides) 215 | if dilations[0] != 1 and strides[0] != 1: 216 | raise Exception("Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1.") 217 | self.cin = in_channel_num 218 | self.groups = group 219 | out_channel_num = int(out_channel_num//group) 220 | 221 | self.convs = [] 222 | for i in range(group): 223 | if pads is not None and max(pads) == 1 and max(strides) == 1: 224 | self.convs.append(keras.layers.Conv2D( 225 | out_channel_num, kernel_size, strides, 'SAME', use_bias=False if bias is None else True, 226 | dilation_rate=dilations, 227 | weights=[weights[:, :, :, i*out_channel_num:(i+1)*out_channel_num]] if bias is None else [weights[:, :, :, i*out_channel_num:(i+1)*out_channel_num], bias[i*out_channel_num:(i+1)*out_channel_num]])) 228 | else: 229 | self.convs.append(keras.layers.Conv2D( 230 | out_channel_num, kernel_size, strides, 'VALID', use_bias=False if bias is None else True, 231 | dilation_rate=dilations, 232 | weights=[weights[:, :, :, 
i*out_channel_num:(i+1)*out_channel_num]] if bias is None else [weights[:, :, :, i*out_channel_num:(i+1)*out_channel_num], bias[i*out_channel_num:(i+1)*out_channel_num]])) 233 | self.pad = None 234 | if pads is not None and (max(pads) != 0 and not (max(pads) == 1 and max(strides) == 1)): 235 | padding = None 236 | if len(pads) == 2 and (pads[0] > 0 or pads[1] > 0): 237 | padding = (pads[0], pads[1]) 238 | elif len(pads) == 4 and (pads[0] > 0 or pads[1] > 0 or pads[2] > 0 or pads[3] > 0): 239 | padding = ((pads[0], pads[2]), (pads[1], pads[3])) 240 | self.pad = keras.layers.ZeroPadding2D(padding=padding) 241 | 242 | def __call__(self, inputs): 243 | if self.pad is not None: 244 | inputs = self.pad(inputs) 245 | outs = [] 246 | in_s = tf.split(inputs, num_or_size_splits=self.groups, axis=-1) 247 | for i in range(self.groups): 248 | outs.append(self.convs[i](in_s[i])) 249 | outs = tf.concat(outs, axis=-1) 250 | return outs 251 | 252 | class TFDepthwiseConv(): 253 | # Depthwise convolution (depth multiplier = 1, i.e. group == channels) 254 | def __init__(self, kernel_size=1, strides=1, dilations=1, pads=None, weights=None, bias=None) -> None: 255 | super().__init__() 256 | if len(weights.shape) == 3: 257 | weights = weights.transpose(0, 2, 1) 258 | self.dwconv1d_init(kernel_size, strides, dilations, pads, weights, bias) 259 | elif len(weights.shape) == 4: 260 | weights = weights.transpose(0, 1, 3, 2) 261 | self.dwconv2d_init(kernel_size, strides, dilations, pads, weights, bias) 262 | else: 263 | raise NotImplementedError(f"DepthwiseConv{len(weights.shape)-2}d is not implemented") 264 | 265 | def dwconv1d_init(self, kernel_size=1, strides=1, dilations=1, pads=None, weights=None, bias=None): 266 | self.pad = None 267 | if pads is not None and max(pads) == 1 and max(strides) == 1: 268 | self.conv = keras.layers.DepthwiseConv1D( 269 | kernel_size, strides, "SAME", use_bias=False if bias is None else True, 270 | weights=[weights] if bias is None else [weights, bias], 271 | dilation_rate=dilations, 272 | activation=None, 273 | kernel_initializer='zeros', 274 | bias_initializer='zeros' 275 | ) 276 | else: 277 | self.conv = keras.layers.DepthwiseConv1D( 278 | kernel_size, strides, "VALID", use_bias=False if bias is None else True, 279 | weights=[weights] if bias is None else [weights, bias], 280 | dilation_rate=dilations, 281 | activation=None, 282 | kernel_initializer='zeros', 283 | bias_initializer='zeros' 284 | ) 285 | if pads is not None and max(pads) != 0: 286 | self.pad = keras.layers.ZeroPadding1D(padding=pads) 287 | 288 | def dwconv2d_init(self, kernel_size=1, strides=1, dilations=1, pads=None, weights=None, bias=None): 289 | if isinstance(dilations, int): 290 | dilations = (dilations, dilations) 291 | if isinstance(strides, int): 292 | strides = (strides, strides) 293 | 294 | self.pad = None 295 | if pads is not None and max(pads) == 1 and max(strides) == 1: 296 | self.conv = keras.layers.DepthwiseConv2D( 297 | kernel_size, strides, "SAME", use_bias=False if bias is None else True, 298 | weights=[weights] if bias is None else [weights, bias], 299 | dilation_rate=dilations, 300 | activation=None, 301 | kernel_initializer='zeros', 302 | bias_initializer='zeros' 303 | ) 304 | else: 305 | self.conv = keras.layers.DepthwiseConv2D( 306 | kernel_size, strides, "VALID", use_bias=False if bias is None else True, 307 | weights=[weights] if bias is None else [weights, bias], 308 | dilation_rate=dilations, 309 | activation=None, 310 | kernel_initializer='zeros', 311 | bias_initializer='zeros' 312 | ) 313 | if pads is not None and max(pads) !=
0: 314 | padding = None 315 | if len(pads) == 2 and (pads[0] > 0 or pads[1] > 0): 316 | padding = (pads[0], pads[1]) 317 | elif len(pads) == 4 and (pads[0] > 0 or pads[1] > 0 or pads[2] > 0 or pads[3] > 0): 318 | padding = ((pads[0], pads[2]), (pads[1], pads[3])) 319 | self.pad = keras.layers.ZeroPadding2D(padding=padding) 320 | 321 | def __call__(self, inputs): 322 | if self.pad: 323 | inputs = self.pad(inputs) 324 | return self.conv(inputs) -------------------------------------------------------------------------------- /onnx2tflite/layers/deformation_layers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import tensorflow as tf 3 | 4 | from onnx2tflite.utils.definitions import Layout 5 | from onnx2tflite.utils import OPERATOR, dimension_utils 6 | 7 | LOG = logging.getLogger("deformation_layers :") 8 | 9 | @OPERATOR.register_operator("Transpose") 10 | class TFTranspose(): 11 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None: 12 | super().__init__() 13 | for nop in node_outputs: 14 | layout_dict[nop] = Layout.Channel_First 15 | if kwargs.get("perm_list"): 16 | self.perm_list = kwargs.get("perm_list") 17 | return 18 | self.trans_in = None 19 | self.perm_list = [i for i in node_attribute['perm']] 20 | if layout_dict[node_inputs[0]] == Layout.Channel_Last: 21 | # LOG.info("Transpose will process tensor after change back to NCHW format.") 22 | shape_len = len(tensor_grap[node_inputs[0]].shape) 23 | self.trans_in = [0, shape_len-1] + [n for n in range(1, shape_len-1)] 24 | 25 | def __call__(self, inputs): 26 | if self.trans_in: 27 | inputs = tf.transpose(inputs, perm=self.trans_in) 28 | return tf.transpose(inputs, perm=self.perm_list) 29 | 30 | @OPERATOR.register_operator("Slice") 31 | class TFSlice(): 32 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None: 33 | super().__init__() 34 | if len(node_inputs) == 1: 35 | self.starts = node_attribute['starts'][0] 36 | self.ends = node_attribute['ends'][0] 37 | self.axis = node_attribute['axes'][0] 38 | self.steps = 1 39 | else: 40 | self.starts = node_weights[node_inputs[1]][0] if node_inputs[1] in node_weights else tensor_grap[node_inputs[1]][0] 41 | self.axis = node_weights[node_inputs[3]][0] if node_inputs[3] in node_weights else tensor_grap[node_inputs[3]][0] 42 | self.ends = node_weights[node_inputs[2]][0] if node_inputs[2] in node_weights else tensor_grap[node_inputs[2]][0] 43 | self.ends = min(self.ends, tensor_grap[node_inputs[0]].shape[self.axis]) 44 | if len(node_inputs) < 5: 45 | self.steps = 1 46 | else: 47 | self.steps = node_weights[node_inputs[4]][0] if node_inputs[4] in node_weights else tensor_grap[node_inputs[4]][0] 48 | 49 | shape = tensor_grap[node_inputs[0]].shape.as_list() 50 | if self.starts < 0: 51 | self.starts = shape[self.axis] + self.starts 52 | if self.ends < 0: 53 | self.ends = shape[self.axis] + self.ends 54 | 55 | if layout_dict[node_inputs[0]] == Layout.Channel_Last: 56 | self.axis = dimension_utils.channel_to_last_dimension(self.axis) 57 | 58 | def __call__(self, inputs): 59 | indices = tf.keras.backend.arange(self.starts, self.ends, step=self.steps) 60 | return tf.gather(inputs, indices, axis=self.axis) 61 | 62 | @OPERATOR.register_operator("Gather") 63 | class TFGather(): 64 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None: 
65 | super().__init__() 66 | self.axis = node_attribute.get('axis', 0) 67 | self.indices = tensor_grap[node_inputs[1]] if node_inputs[1] in tensor_grap else node_weights[node_inputs[1]] 68 | if layout_dict[node_inputs[0]] == Layout.Channel_Last: 69 | self.axis = dimension_utils.channel_to_last_dimension(self.axis) 70 | 71 | def __call__(self, inputs): 72 | return tf.gather(inputs, self.indices, axis=self.axis) 73 | 74 | @OPERATOR.register_operator("Concat") 75 | class TFConcat(): 76 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs): 77 | super().__init__() 78 | # TODO: could be optimized by inspecting the consuming node; if it is a conv, prefer channel-last. 79 | self._axis = node_attribute['axis'] 80 | # use `count` to vote between layouts: channel-last inputs increment it, channel-first or constant inputs decrement it 81 | count = 0 82 | for inp in node_inputs: 83 | if inp in node_weights: 84 | count -= 1 85 | elif layout_dict[inp] == Layout.Channel_Last: 86 | count += 1 87 | else: 88 | count -= 1 89 | 90 | self._gather = [] 91 | if count < 0: 92 | # align to Channel_First 93 | layout_dict[node_outputs[0]] = Layout.Channel_First 94 | for inp in node_inputs: 95 | if inp in tensor_grap: 96 | if layout_dict[inp] == Layout.Channel_Last: 97 | tensor_grap[inp] = dimension_utils.tensor_NDC_to_NCD_format(tensor_grap[inp]) 98 | self._gather.append(tensor_grap[inp]) 99 | else: 100 | self._gather.append(node_weights[inp]) 101 | else: 102 | # align to Channel_Last 103 | layout_dict[node_outputs[0]] = Layout.Channel_Last 104 | self._axis = dimension_utils.channel_to_last_dimension(self._axis) 105 | for inp in node_inputs: 106 | if inp in tensor_grap: 107 | if layout_dict[inp] != Layout.Channel_Last: 108 | tensor_grap[inp] = dimension_utils.tensor_NCD_to_NDC_format(tensor_grap[inp]) 109 | self._gather.append(tensor_grap[inp]) 110 | else: 111 | self._gather.append(dimension_utils.tensor_NCD_to_NDC_format(node_weights[inp])) 112 | 113 | def __call__(self, *args, **kwargs): 114 | return tf.concat(self._gather, axis=self._axis) 115 | 116 | @OPERATOR.register_operator("Reshape") 117 | class TFReshape(): 118 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs): 119 | super().__init__() 120 | self.out_shape = node_weights[node_inputs[1]] 121 | self.trans_in = None 122 | # LOG.info("Reshape will process tensor after change back to NCHW format.") 123 | if layout_dict[node_inputs[0]] == Layout.Channel_Last: 124 | shape_len = len(tensor_grap[node_inputs[0]].shape) 125 | self.trans_in = [0, shape_len-1] + [n for n in range(1, shape_len-1)] 126 | for nop in node_outputs: 127 | layout_dict[nop] = Layout.Channel_First 128 | 129 | def __call__(self, inputs): 130 | if self.trans_in: 131 | inputs = tf.transpose(inputs, perm=self.trans_in) 132 | inputs = tf.reshape(inputs, shape=self.out_shape) 133 | return inputs 134 | 135 | @OPERATOR.register_operator("Flatten") 136 | class TFFlatten(): 137 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None: 138 | super().__init__() 139 | num_elements = int(tensor_grap[node_inputs[0]].shape.num_elements()/tensor_grap[node_inputs[0]].shape[0]) 140 | input_shape = tensor_grap[node_inputs[0]].shape 141 | self.flat = tf.keras.layers.Flatten() 142 | ''' 143 | ensure the memory order matches, for example: 144 | onnx = (B, 2, 3, 4).reshape(B, -1) 145 | tflite = (B, 3, 4, 2).reshape(B, -1) 146 | we can observe that: 147 | onnx.shape == tflite.shape, but
np.sum(onnx-tflite) != 0 148 | this is because the memory order of the two tensors differs, so we must transpose the tflite tensor back to the onnx layout. 149 | that is the general case; the special case below is common in cnns: 150 | onnx = (B, 512, 1, 1) 151 | tflite = (B, 1, 1, 512) 152 | or = (B, 1, 512, 1) 153 | here the memory orders are all the same, so no transpose is needed. 154 | ''' 155 | self.perm = None 156 | if layout_dict[node_inputs[0]] == Layout.Channel_Last and num_elements != max(input_shape[1:]): 157 | self.perm = [0, len(input_shape)-1] 158 | for i in range(len(input_shape)-2): 159 | self.perm.append(i+1) 160 | 161 | def __call__(self, inputs): 162 | if self.perm: 163 | inputs = tf.transpose(inputs, perm=self.perm) 164 | return self.flat(inputs) 165 | 166 | @OPERATOR.register_operator("Split") 167 | class TFSplit(): 168 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None: 169 | super().__init__() 170 | self.outputs_nums = len(node_outputs) 171 | self.axis = node_attribute.get("axis", 0) 172 | if layout_dict[node_inputs[0]] == Layout.Channel_Last: 173 | self.axis = dimension_utils.channel_to_last_dimension(self.axis) 174 | split_args = None 175 | if 'split' in node_attribute: 176 | split_args = node_attribute['split'] 177 | else: 178 | assert len(node_inputs) == 2 and node_inputs[1] in node_weights 179 | split_args = node_weights[node_inputs[1]] 180 | 181 | self.indices = [] 182 | start, end = 0, 0 183 | for i in range(self.outputs_nums): 184 | end = start + int(split_args[i]) 185 | self.indices.append(tf.keras.backend.arange(start, end, 1)) 186 | start = end 187 | 188 | def __call__(self, inputs): 189 | return [tf.gather(inputs, indices=indice, axis=self.axis) for indice in self.indices] 190 | 191 | @OPERATOR.register_operator("Expand") 192 | class TFExpand(): 193 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None: 194 | super().__init__() 195 | self.shape = node_weights[node_inputs[1]] 196 | if layout_dict[node_inputs[0]] == Layout.Channel_Last: 197 | self.shape = dimension_utils.shape_NCD_to_NDC_format(self.shape) 198 | def __call__(self, inputs): 199 | for i in range(len(self.shape)): 200 | if int(self.shape[i]//inputs.shape[i]) > 1: 201 | inputs = tf.repeat(inputs, repeats=int(self.shape[i]//inputs.shape[i]), axis=i) 202 | elif self.shape[i] < inputs.shape[i] and self.shape[i] != 1: 203 | inputs = tf.repeat(inputs, repeats=int(self.shape[i]), axis=i) 204 | return inputs 205 | 206 | @OPERATOR.register_operator("GatherElements") 207 | class TFGatherElements(): 208 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None: 209 | super().__init__() 210 | self.axis = node_attribute.get("axis", 1) 211 | self.indices = None 212 | if 'indices' in node_attribute: 213 | self.indices = node_attribute['indices'] 214 | self.indices = dimension_utils.tensor_NCD_to_NDC_format(self.indices) 215 | elif node_inputs[1] in node_weights: 216 | self.indices = node_weights[node_inputs[1]] 217 | self.indices = dimension_utils.tensor_NCD_to_NDC_format(self.indices) 218 | else: 219 | self.indices = tensor_grap[node_inputs[1]] 220 | if layout_dict[node_inputs[0]] == Layout.Channel_Last: 221 | self.axis = dimension_utils.channel_to_last_dimension(self.axis) 222 | if len(node_inputs) == 1 or layout_dict[node_inputs[1]] != Layout.Channel_Last: 223 | self.indices =
dimension_utils.tensor_NCD_to_NDC_format(self.indices) 224 | 225 | def gather_elements(self, input_tensor, indices, axis): 226 | # Get the shape of the input tensor and the indices tensor 227 | input_shape = tf.shape(input_tensor) 228 | indices_shape = tf.shape(indices) 229 | 230 | # Create indices for all dimensions 231 | idx = tf.meshgrid(*[tf.range(s) for s in indices_shape], indexing='ij') 232 | idx = [tf.cast(i, tf.int64) for i in idx] 233 | 234 | # Replace the axis index with the provided indices 235 | idx[axis] = tf.cast(indices, tf.int64) 236 | 237 | # Stack indices to form the final gather indices 238 | gather_indices = tf.stack(idx, axis=-1) 239 | 240 | # Use tf.gather_nd to gather elements 241 | output_tensor = tf.gather_nd(input_tensor, gather_indices) 242 | 243 | return output_tensor 244 | 245 | def __call__(self, inputs): 246 | return self.gather_elements(inputs, self.indices, self.axis) 247 | 248 | @OPERATOR.register_operator("Tile") 249 | class TFTile(): 250 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None: 251 | super().__init__() 252 | self.repeats = node_attribute['repeats'] if 'repeats' in node_attribute else node_weights[node_inputs[1]] 253 | if layout_dict[node_inputs[0]] == Layout.Channel_Last: 254 | self.repeats = dimension_utils.shape_NCD_to_NDC_format(self.repeats) 255 | 256 | def __call__(self, inputs): 257 | for i in range(len(self.repeats)): 258 | if self.repeats[i] > 1: 259 | inputs = tf.repeat(inputs, self.repeats[i], axis=i) 260 | return inputs 261 | 262 | @OPERATOR.register_operator("Unsqueeze") 263 | class TFUnsqueeze(): 264 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None: 265 | super().__init__() 266 | self.axis = node_attribute['axes'] if 'axes' in node_attribute else node_weights[node_inputs[1]] 267 | if not isinstance(self.axis, int): 268 | self.axis = int(self.axis[0]) 269 | input_shape = tensor_grap[node_inputs[0]].shape 270 | if len(input_shape) == 1: 271 | layout_dict[node_outputs[0]] = Layout.Channel_None 272 | elif len(input_shape) == 2: 273 | layout_dict[node_outputs[0]] = Layout.Channel_First 274 | else: 275 | layout_dict[node_outputs[0]] = layout_dict[node_inputs[0]] 276 | if layout_dict[node_inputs[0]] == Layout.Channel_Last: 277 | self.axis = dimension_utils.channel_to_last_dimension(self.axis) 278 | 279 | def __call__(self, inputs): 280 | return tf.expand_dims(inputs, self.axis) 281 | 282 | @OPERATOR.register_operator("Squeeze") 283 | class TFSqueeze(): 284 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None: 285 | super().__init__() 286 | self.axis = node_attribute['axes'] if 'axes' in node_attribute else node_weights[node_inputs[1]] 287 | if not isinstance(self.axis, int): 288 | self.axis = int(self.axis[0]) 289 | input_shape = tensor_grap[node_inputs[0]].shape 290 | if len(input_shape) <= 3: 291 | layout_dict[node_outputs[0]] = Layout.Channel_None 292 | if len(input_shape) > 2 and layout_dict[node_inputs[0]] == Layout.Channel_Last: 293 | self.axis = dimension_utils.channel_to_last_dimension(self.axis) 294 | 295 | def __call__(self, inputs): 296 | return tf.squeeze(inputs, self.axis) 297 | 298 | @OPERATOR.register_operator("DepthToSpace") 299 | class TFDepthToSpace(): 300 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs)->None: 301 | 
super().__init__() 302 | self.block_size = node_attribute.get("blocksize", 2) 303 | self.mode = node_attribute.get("mode", "DCR") 304 | self.channel_last = layout_dict[node_inputs[0]] == Layout.Channel_Last 305 | 306 | def __call__(self, inputs): 307 | if not self.channel_last: 308 | inputs = dimension_utils.tensor_NCD_to_NDC_format(inputs) # channel-first inputs must become channel-last before tf.nn.depth_to_space 309 | if self.mode == "DCR": 310 | return tf.nn.depth_to_space(inputs, self.block_size) 311 | elif self.mode == "CRD": 312 | # help wanted: native tensorflow does not support CRD mode; this workaround goes through a 6-dim reshape/transpose. 313 | b, h, w, c = inputs.shape 314 | inputs = tf.reshape(inputs, [b, h, w, c//(self.block_size * self.block_size), self.block_size, self.block_size]) 315 | inputs = tf.transpose(inputs, perm=[0, 1, 4, 2, 5, 3]) 316 | inputs = tf.reshape(inputs, [b, h*self.block_size, w*self.block_size, c//(self.block_size * self.block_size)]) 317 | return inputs 318 | else: 319 | raise KeyError(f"For DepthToSpace, mode must be [DCR, CRD], not {self.mode}") -------------------------------------------------------------------------------- /onnx2tflite/layers/mathematics_layers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from onnx2tflite.utils.definitions import Layout 6 | from onnx2tflite.utils import OPERATOR, dimension_utils, np2tf_type 7 | 8 | LOG = logging.getLogger("calculations_layers :") 9 | 10 | def np2tf(x): 11 | if isinstance(x, np.ndarray): 12 | x = tf.convert_to_tensor(x, dtype=np2tf_type[x.dtype.name]) 13 | return x, False 14 | return x, True 15 | 16 | def match_tensor(x1:tf.Tensor or np.ndarray, x2:tf.Tensor or np.ndarray, x1_layout:Layout, x2_layout:Layout): 17 | 18 | x1, f1 = np2tf(x1) 19 | x2, f2 = np2tf(x2) 20 | 21 | # if both vars are graph tensors, we only need to align their layouts when they differ; we assume tensors are produced by the graph.
22 | if f1 and f2: 23 | if x1_layout != x2_layout: 24 | if x1_layout == Layout.Channel_Last: 25 | x1 = dimension_utils.tensor_NDC_to_NCD_format(x1) 26 | elif x2_layout == Layout.Channel_Last: 27 | x2 = dimension_utils.tensor_NDC_to_NCD_format(x2) 28 | return x1, x2, Layout.Channel_First 29 | 30 | # ensure tensor is set to x1, const weights set to x2 31 | out_layout = x1_layout 32 | if f2: 33 | x1, x2 = x2, x1 34 | out_layout = x2_layout 35 | 36 | 37 | if out_layout == Layout.Channel_Last: 38 | if x1.shape.ndims != x2.shape.ndims: 39 | while x2.shape.ndims < x1.shape.ndims: 40 | x2 = tf.expand_dims(x2, axis=0) 41 | x2 = dimension_utils.tensor_NCD_to_NDC_format(x2) 42 | 43 | x2 = tf.cast(x2, x1.dtype) 44 | return (x2, x1, out_layout) if f2 else (x1, x2, out_layout) 45 | 46 | ''' 47 | tensor(NDC) + const 48 | tensor(NCD) + const 49 | tensor(NDC) + tensor(NDC) 50 | tensor(NCD) + tensor(NCD) 51 | ''' 52 | 53 | class BaseArithmetic: 54 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs): 55 | self.left_val, self.right_val = None, None 56 | left_layout, right_layout = Layout.Default, Layout.Default 57 | 58 | if node_inputs[0] in tensor_grap: 59 | self.left_val = tensor_grap[node_inputs[0]] 60 | left_layout = layout_dict[node_inputs[0]] 61 | else: 62 | self.left_val = node_weights[node_inputs[0]] 63 | 64 | if node_inputs[1] in tensor_grap: 65 | self.right_val = tensor_grap[node_inputs[1]] 66 | right_layout = layout_dict[node_inputs[1]] 67 | else: 68 | self.right_val = node_weights[node_inputs[1]] 69 | 70 | if left_layout == right_layout: 71 | return 72 | 73 | self.left_val, self.right_val, out_layout = match_tensor(self.left_val, self.right_val, left_layout, right_layout) 74 | layout_dict[node_outputs[0]] = out_layout 75 | 76 | @OPERATOR.register_operator("Add") 77 | class TFAdd(BaseArithmetic): 78 | def __init__(self, *args, **kwargs): 79 | super().__init__(*args, **kwargs) 80 | 81 | def __call__(self, *args, **kwargs): 82 | return self.left_val + self.right_val 83 | 84 | @OPERATOR.register_operator("Sub") 85 | class TFSub(BaseArithmetic): 86 | def __init__(self, *args, **kwargs): 87 | super().__init__(*args, **kwargs) 88 | 89 | def __call__(self, *args, **kwargs): 90 | return self.left_val - self.right_val 91 | 92 | @OPERATOR.register_operator("Mul") 93 | class TFMul(BaseArithmetic): 94 | def __init__(self,*args, **kwargs): 95 | super().__init__(*args, **kwargs) 96 | 97 | def __call__(self, *args, **kwargs): 98 | return self.left_val * self.right_val 99 | 100 | @OPERATOR.register_operator("Div") 101 | class TFDiv(BaseArithmetic): 102 | def __init__(self,*args, **kwargs): 103 | super().__init__(*args, **kwargs) 104 | 105 | def __call__(self, *args, **kwargs): 106 | return self.left_val / self.right_val 107 | 108 | @OPERATOR.register_operator("MatMul") 109 | class TFMatMul(): 110 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs): 111 | super().__init__() 112 | if node_inputs[0] in tensor_grap: 113 | self.A = tensor_grap[node_inputs[0]] 114 | if layout_dict[node_inputs[0]] == Layout.Channel_Last: 115 | self.A = dimension_utils.tensor_NDC_to_NCD_format(self.A) 116 | else: 117 | self.A = node_weights[node_inputs[0]] 118 | 119 | if node_inputs[1] in tensor_grap: 120 | self.B = tensor_grap[node_inputs[1]] 121 | if layout_dict[node_inputs[1]] == Layout.Channel_Last: 122 | self.B = dimension_utils.tensor_NDC_to_NCD_format(self.B) 123 | else: 124 | self.B = 
node_weights[node_inputs[1]] 125 | 126 | self.dense = tf.keras.layers.Dense(self.B.shape[-1], 127 | weights=[self.B], 128 | use_bias=False) 129 | 130 | layout_dict[node_outputs[0]] = Layout.Channel_First 131 | 132 | def __call__(self, *args, **kwargs): 133 | # out = tf.matmul(self.A, self.B) 134 | try: 135 | out = self.dense(self.A) 136 | except Exception: 137 | out = tf.matmul(self.A, self.B) 138 | return out 139 | 140 | @OPERATOR.register_operator("Mod") 141 | class TFMod(): 142 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs): 143 | super().__init__() 144 | self.fmod = bool(node_attribute.get("fmod", 0)) 145 | self.mod_value = None 146 | if node_inputs[1] in node_weights: 147 | self.mod_value = node_weights[node_inputs[1]] 148 | else: 149 | self.mod_value = tensor_grap[node_inputs[1]] 150 | 151 | def __call__(self, inputs): 152 | if self.fmod: 153 | return tf.math.truncatemod(inputs, tf.cast(self.mod_value, inputs.dtype)) # fmod=1: the sign follows the dividend, like C fmod 154 | else: 155 | return tf.math.mod(inputs, tf.cast(self.mod_value, inputs.dtype)) 156 | 157 | @OPERATOR.register_operator("Pow") 158 | class TFPow(): 159 | def __init__(self, tensor_grap, node_weights, node_inputs, *args, **kwargs): 160 | super().__init__() 161 | self.power_index = node_weights[node_inputs[1]] 162 | 163 | def __call__(self, inputs, *args, **kwargs): 164 | return tf.pow(inputs, self.power_index) 165 | 166 | @OPERATOR.register_operator("Reciprocal") 167 | class TFReciprocal(): 168 | def __init__(self, *args, **kwargs): 169 | super().__init__() 170 | 171 | def __call__(self, inputs, *args, **kwargs): 172 | return 1/inputs 173 | 174 | @OPERATOR.register_operator("Sqrt") 175 | class TFSqrt(): 176 | def __init__(self, *args, **kwargs): 177 | super().__init__() 178 | 179 | def __call__(self, inputs, *args, **kwargs): 180 | return tf.sqrt(inputs) 181 | 182 | @OPERATOR.register_operator("Exp") 183 | class TFExp(): 184 | def __init__(self, *args, **kwargs): 185 | super().__init__() 186 | 187 | def __call__(self, inputs, *args, **kwargs): 188 | return tf.exp(inputs) 189 | 190 | @OPERATOR.register_operator("Log") 191 | class TFLog(): 192 | def __init__(self, *args, **kwargs): 193 | super().__init__() 194 | 195 | def __call__(self, inputs, *args, **kwargs): 196 | return tf.math.log(inputs) 197 | 198 | class ReduceBase: 199 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs): 200 | self.keep_dims = node_attribute.get("keepdims", 1) == 1 201 | input_shape_len = len(tensor_grap[node_inputs[0]].shape) 202 | if layout_dict[node_inputs[0]] == Layout.Channel_Last: 203 | self.axes = [dimension_utils.channel_to_last_dimension(i) if i >=0 else dimension_utils.channel_to_last_dimension(input_shape_len + i) for i in node_attribute.get("axes", [-1])] 204 | else: 205 | self.axes = [i if i >=0 else input_shape_len + i for i in node_attribute.get("axes", [-1])] 206 | 207 | @OPERATOR.register_operator("ReduceSum") 208 | class TFReduceSum(ReduceBase): 209 | def __init__(self, *args, **kwargs): 210 | super().__init__(*args, **kwargs) 211 | 212 | def __call__(self, inputs, *args, **kwargs): 213 | return tf.math.reduce_sum(inputs, axis=self.axes, keepdims=self.keep_dims) 214 | 215 | @OPERATOR.register_operator("ReduceMean") 216 | class TFReduceMean(ReduceBase): 217 | def __init__(self, *args, **kwargs): 218 | super().__init__(*args, **kwargs) 219 | 220 | def __call__(self, inputs, *args, **kwargs): 221 | return tf.math.reduce_mean(inputs, axis=self.axes, keepdims=self.keep_dims)
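# A worked illustration of the ReduceBase axis remapping above (a comment sketch, not part of the original source): for a 4D NCHW input kept in channel-last (NHWC) layout, ONNX axes are remapped before reduction: onnx axis 1 (C) maps to tf axis -1, and onnx axes [2, 3] (H, W) map to tf axes [1, 2]; so an ONNX ReduceMean over axis 1 becomes tf.math.reduce_mean(x_nhwc, axis=[-1], keepdims=True).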
222 | 223 | @OPERATOR.register_operator("ReduceMax") 224 | class TFReduceMax(ReduceBase): 225 | def __init__(self, *args, **kwargs): 226 | super().__init__(*args, **kwargs) 227 | 228 | def __call__(self, inputs, *args, **kwargs): 229 | return tf.math.reduce_max(inputs, axis=self.axes, keepdims=self.keep_dims) 230 | 231 | @OPERATOR.register_operator("ReduceMin") 232 | class TFReduceMin(ReduceBase): 233 | def __init__(self, *args, **kwargs): 234 | super().__init__(*args, **kwargs) 235 | 236 | def __call__(self, inputs, *args, **kwargs): 237 | return tf.math.reduce_min(inputs, axis=self.axes, keepdims=self.keep_dims) 238 | 239 | @OPERATOR.register_operator("ArgMax") 240 | class TFArgMax(): 241 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs): 242 | super().__init__() 243 | self.axis = node_attribute.get('axis', 0) 244 | if layout_dict[node_inputs[0]] == Layout.Channel_Last: 245 | self.axis = dimension_utils.channel_to_last_dimension(self.axis) 246 | self.keepdims = node_attribute.get("keepdims", 1) == 1 247 | 248 | def __call__(self, inputs, *args, **kwargs): 249 | _inputs = tf.argmax(inputs, axis=self.axis) 250 | if self.keepdims: 251 | _inputs = tf.expand_dims(_inputs, axis=self.axis) 252 | return _inputs 253 | 254 | @OPERATOR.register_operator("ArgMin") 255 | class TFArgMin(): 256 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs): 257 | super().__init__() 258 | self.axis = node_attribute.get('axis', 0) 259 | if layout_dict[node_inputs[0]] == Layout.Channel_Last: 260 | self.axis = dimension_utils.channel_to_last_dimension(self.axis) 261 | self.keepdims = node_attribute.get("keepdims", 1) == 1 262 | 263 | def __call__(self, inputs, *args, **kwargs): 264 | _inputs = tf.argmin(inputs, axis=self.axis) 265 | if self.keepdims: 266 | _inputs = tf.expand_dims(_inputs, axis=self.axis) 267 | return _inputs 268 | 269 | @OPERATOR.register_operator("Erf") 270 | class TFErf(): 271 | def __init__(self, *args, **kwargs) -> None: 272 | pass 273 | 274 | def __call__(self, inputs): 275 | inputs = tf.math.erf(inputs) 276 | return inputs -------------------------------------------------------------------------------- /onnx2tflite/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dimension_utils import * 2 | from .op_registry import OPERATOR 3 | from .definitions import * -------------------------------------------------------------------------------- /onnx2tflite/utils/definitions.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from abc import ABC 3 | from enum import Enum, unique 4 | 5 | @unique 6 | class Layout(Enum): 7 | Default = 0 8 | Channel_First = 1 << 0 # for onnx format 9 | Channel_Last = 1 << 1 # for tensorflow format 10 | Channel_None = 1 << 2 # no channel 11 | 12 | class Node_Layout: 13 | def __init__(self, name:str, pre:list=[], nxt:list=[]) -> None: 14 | self.name = name 15 | self.pre = pre 16 | self.nxt = nxt 17 | self.layout = Layout.Default 18 | 19 | class BaseOP(ABC): 20 | def __init__(self, tensor_graph, const_weights, node_attributes, node_inputs, node_outputs, layout_dict) -> None: 21 | pass 22 | 23 | onnx2tf_type = { 24 | 1: tf.float32, # ONNX_FLOAT 25 | 2: tf.uint8, # ONNX_UINT8 26 | 3: tf.int8, # ONNX_INT8 27 | 4: tf.uint16, # ONNX_UINT16 28 | 5: tf.int16, # ONNX_INT16 29 | 6: tf.int32, # ONNX_INT32 30 | 7: tf.int64, #
ONNX_INT64 31 | 8: tf.string, # ONNX_STRING 32 | 9: tf.bool, # ONNX_BOOL 33 | 10: tf.float16, # ONNX_FLOAT16 34 | 11: tf.float64, # ONNX_DOUBLE 35 | 12: tf.uint32, # ONNX_UINT32 36 | 13: tf.uint64, # ONNX_UINT64 37 | 14: tf.complex64, # ONNX_COMPLEX64 38 | 15: tf.complex128 # ONNX_COMPLEX128 39 | } 40 | 41 | np2tf_type = { 42 | "int32": tf.int32, 43 | "int64": tf.int64, 44 | "float32": tf.float32, 45 | "float64": tf.float64, 46 | "bool": tf.bool, 47 | "uint8": tf.uint8, 48 | "int8": tf.int8, 49 | "int16": tf.int16, 50 | "uint16": tf.uint16, 51 | "uint32": tf.uint32, 52 | "uint64": tf.uint64, 53 | "complex64": tf.complex64, 54 | "complex128": tf.complex128 55 | } 56 | 57 | FORCE_CHANNEL_LAST_OP = ["Conv", "ConvTranspose", "DepthToSpace", "Pad", "AveragePool", "MaxPool", "Upsample", "Resize", "Gemm"] 58 | FORCE_CHANNEL_FIRST_OP = ["Reshape", "Transpose", "ScatterND", "MatMul"] 59 | 60 | -------------------------------------------------------------------------------- /onnx2tflite/utils/dimension_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | ''' 3 | shape and axis transform utils func. 4 | ''' 5 | def channel_to_last_dimension(axis): 6 | ''' 7 | make channel first to channel last 8 | ''' 9 | if axis == 0: 10 | axis = 0 11 | elif axis == 1: 12 | axis = -1 13 | else: 14 | axis -= 1 15 | return axis 16 | 17 | def shape_NCD_to_NDC_format(shape): 18 | ''' 19 | make shape format from channel first to channel last 20 | ''' 21 | if len(shape) <= 2: 22 | return tuple(shape) 23 | new_shape = [shape[0], *shape[2:], shape[1]] 24 | return tuple(new_shape) 25 | 26 | def shape_NDC_to_NCD_format(shape): 27 | ''' 28 | make shape format from channel last to channel first 29 | ''' 30 | if len(shape) <= 2: 31 | return tuple(shape) 32 | new_shape = [shape[0], shape[-1], *shape[1:-1]] 33 | return tuple(new_shape) 34 | 35 | def tensor_NCD_to_NDC_format(tensor): 36 | ''' 37 | make tensor format from channel first to channel last 38 | ''' 39 | if(len(tensor.shape) > 2): 40 | shape = [i for i in range(len(tensor.shape))] 41 | shape = shape_NCD_to_NDC_format(shape) 42 | tensor = tf.transpose(tensor, perm=shape) 43 | return tensor 44 | 45 | def tensor_NDC_to_NCD_format(tensor): 46 | ''' 47 | make tensor format from channel last to channel first 48 | ''' 49 | if(len(tensor.shape) > 2): 50 | shape = [i for i in range(len(tensor.shape))] 51 | shape = shape_NDC_to_NCD_format(shape) 52 | tensor = tf.transpose(tensor, perm=shape) 53 | return tensor 54 | 55 | def intfloat_to_list(x:int or float, lens:int): 56 | if isinstance(x, (int, float)): 57 | return [x]*lens 58 | else: 59 | return x -------------------------------------------------------------------------------- /onnx2tflite/utils/graph_tools.py: -------------------------------------------------------------------------------- 1 | from onnx import numpy_helper 2 | import tensorflow as tf 3 | from tensorflow import keras 4 | from .definitions import * 5 | 6 | # copy from https://github.com/gmalivenko/onnx2keras 7 | def decode_node_attribute(node)->dict: 8 | """ 9 | Parse ONNX attributes to Python dictionary 10 | :param args: ONNX attributes object 11 | :return: Python dictionary 12 | """ 13 | def onnx_attribute_to_dict(onnx_attr): 14 | """ 15 | Parse ONNX attribute 16 | :param onnx_attr: ONNX attribute 17 | :return: Python data type 18 | """ 19 | if onnx_attr.HasField('t'): 20 | return numpy_helper.to_array(getattr(onnx_attr, 't')) 21 | 22 | for attr_type in ['f', 'i']: 23 | if 
onnx_attr.HasField(attr_type): 24 | return getattr(onnx_attr, attr_type) 25 | 26 | # s need to be decode, bytes to string 27 | if onnx_attr.HasField('s'): 28 | return getattr(onnx_attr, 's').decode() 29 | 30 | for attr_type in ['floats', 'ints', 'strings']: 31 | if getattr(onnx_attr, attr_type): 32 | return list(getattr(onnx_attr, attr_type)) 33 | return {arg.name: onnx_attribute_to_dict(arg) for arg in node.attribute} 34 | 35 | def build_tf_inputs(model_graph, layout_dict:dict): 36 | inputs_name = [] 37 | for inp in model_graph.input: 38 | input_shape = [x.dim_value for x in inp.type.tensor_type.shape.dim] 39 | if input_shape == []: 40 | continue 41 | inputs_name.append(inp.name) 42 | layout_dict[inp.name] = Layout.Default 43 | if len(input_shape) < 3: 44 | layout_dict[inp.name] = Layout.Channel_None 45 | 46 | _inputs_name = inputs_name.copy() 47 | for node in model_graph.node: 48 | op_name, node_inputs = node.op_type, node.input 49 | # output_layout = Layout.Default 50 | for ninp in node_inputs: 51 | if ninp in _inputs_name and op_name in FORCE_CHANNEL_LAST_OP and layout_dict[ninp] == Layout.Default: 52 | layout_dict[ninp] = Layout.Channel_Last 53 | _inputs_name.remove(ninp) 54 | if ninp in _inputs_name and op_name in FORCE_CHANNEL_FIRST_OP and layout_dict[ninp] == Layout.Default: 55 | layout_dict[ninp] = Layout.Channel_First 56 | _inputs_name.remove(ninp) 57 | # output_layout = output_layout | node_dict[ninp] 58 | 59 | if len(_inputs_name) == 0: 60 | break 61 | 62 | input_nodes = {} 63 | for inp in model_graph.input: 64 | input_shape = [x.dim_value for x in inp.type.tensor_type.shape.dim] 65 | if input_shape == []: 66 | continue 67 | batch_size = 1 if input_shape[0] <= 0 else input_shape[0] 68 | input_shape = input_shape[1:] 69 | if layout_dict[inp.name] == Layout.Channel_Last: 70 | input_shape = input_shape[1:] + input_shape[0:1] 71 | 72 | input_nodes[inp.name] = keras.Input(shape=input_shape, batch_size=batch_size, dtype=onnx2tf_type.get(inp.type.tensor_type.elem_type)) 73 | 74 | return input_nodes 75 | -------------------------------------------------------------------------------- /onnx2tflite/utils/op_registry.py: -------------------------------------------------------------------------------- 1 | class Registry(object): 2 | def __init__(self, name) -> None: 3 | self._name = name 4 | self._operator_dict = dict() 5 | 6 | def __len__(self): 7 | return len(self._operator_dict) 8 | 9 | @property 10 | def name(self): 11 | return self._name 12 | 13 | @property 14 | def operator_dict(self): 15 | return self._operator_dict 16 | 17 | def get(self, key): 18 | return self._operator_dict.get(key, None) 19 | 20 | def _register_operator(self, op_class, op_name=None): 21 | if (not isinstance(op_name, str)) or op_name is None: 22 | op_name = op_class.__name__ 23 | 24 | if self._operator_dict.get(op_name, None): 25 | raise KeyError(f'{op_name} is already registered in {self._name}') 26 | 27 | self._operator_dict[op_name] = op_class 28 | 29 | def register_operator(self, name=None, op_class=None): 30 | if op_class is not None: 31 | self._register_operator(op_class, name) 32 | return op_class 33 | 34 | def _register(cls): 35 | self._register_operator(cls, name) 36 | return cls 37 | 38 | return _register 39 | 40 | OPERATOR = Registry("TensorflowOP") -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # ONNX->Keras and ONNX->TFLite tools 2 | ## Welcome 3 | If you have some good 
ideas, welcome to discuss them or open a PR. 4 | 5 | ## Install 6 | ```cmd 7 | git clone https://github.com/MPolaris/onnx2tflite.git 8 | cd onnx2tflite 9 | python setup.py install 10 | ``` 11 | ```python 12 | from onnx2tflite import onnx_converter 13 | res = onnx_converter( 14 | onnx_model_path = "./model.onnx", 15 | need_simplify = True, 16 | output_path = "./models/", 17 | target_formats = ['tflite'], 18 | ) 19 | ``` 20 | --- 21 | ```cmd 22 | # basic usage 23 | python -m onnx2tflite --weights "./your_model.onnx" 24 | 25 | # specify the save path 26 | python -m onnx2tflite --weights "./your_model.onnx" --outpath "./save_path" 27 | 28 | # save a tflite model 29 | python -m onnx2tflite --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite" 30 | 31 | # save keras and tflite models 32 | python -m onnx2tflite --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite" "keras" 33 | 34 | # cut the model, redefining its inputs and outputs; middle layers are supported 35 | python -m onnx2tflite --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite" --input-node-names "layer_inputname" --output-node-names "layer_outname1" "layer_outname2" 36 | 37 | # quantize model weights only 38 | python -m onnx2tflite --weights "./your_model.onnx" --formats "tflite" --weigthquant 39 | 40 | # quantize the model, including input and output 41 | ## fp16 42 | python -m onnx2tflite --weights "./your_model.onnx" --formats "tflite" --fp16 43 | ## recommended 44 | python -m onnx2tflite --weights "./your_model.onnx" --formats "tflite" --int8 --imgroot "./dataset_path" --int8mean 0 0 0 --int8std 255 255 255 45 | ## generate random calibration data instead of reading from image files 46 | python -m onnx2tflite --weights "./your_model.onnx" --formats "tflite" --int8 47 | ``` 48 | --- 49 | ## Features 50 | - High Consistency. Compared to ONNX outputs, the average error is less than 1e-5 per element. 51 | - Faster. Outputs a tensorflow-lite model 30% faster than [onnx_tf](https://github.com/onnx/onnx-tensorflow). 52 | - Auto Channel Align. Automatically converts the pytorch format (NCHW) to the tensorflow format (NHWC). 53 | - Deployment Support. Supports exporting quantized models, including fp16 and uint8 quantization. 54 | - Code Friendly. I've been trying to keep the code structure simple and clear.
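For example, the CLI quantization flags above map onto `onnx_converter` keyword arguments; a weight-only quantized export might look like this (a sketch mirroring the `--weigthquant` flag and the `weight_quant` parameter used in the examples below):
```python
from onnx2tflite import onnx_converter

# weight-only quantization (a sketch; activations stay in float)
res = onnx_converter(
    onnx_model_path = "./your_model.onnx",
    need_simplify = True,
    output_path = "./save_path",
    target_formats = ['tflite'],
    weight_quant = True,
)
```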
55 | --- 56 | 57 | ## Pytorch -> ONNX -> Tensorflow-Keras -> Tensorflow-Lite 58 | 59 | - ### From torchvision to tensorflow-lite 60 | ```python 61 | import torch 62 | import torchvision 63 | _input = torch.randn(1, 3, 224, 224) 64 | model = torchvision.models.mobilenet_v2(True) 65 | # the default settings are fine 66 | torch.onnx.export(model, _input, './mobilenetV2.onnx', opset_version=11) # or opset_version=13 67 | 68 | from onnx2tflite import onnx_converter 69 | onnx_converter( 70 | onnx_model_path = "./mobilenetV2.onnx", 71 | need_simplify = True, 72 | output_path = "./", 73 | target_formats = ['tflite'], # or ['keras'], ['keras', 'tflite'] 74 | weight_quant = False, 75 | fp16_model=False, 76 | int8_model = False, 77 | int8_mean = None, 78 | int8_std = None, 79 | image_root = None 80 | ) 81 | ``` 82 | - ### From custom pytorch model to tensorflow-lite-int8 83 | ```python 84 | import torch 85 | import torch.nn as nn 86 | import torch.nn.functional as F 87 | 88 | class MyModel(nn.Module): 89 | def __init__(self): 90 | super().__init__(); self.conv = nn.Sequential( 91 | nn.Conv2d(3, 64, kernel_size=3, padding=1), 92 | nn.BatchNorm2d(64), 93 | nn.ReLU(inplace=True), 94 | ) 95 | 96 | def forward(self, x): 97 | return self.conv(x) 98 | 99 | model = MyModel() 100 | model.load_state_dict(torch.load("model_checkpoint.pth", map_location="cpu")) 101 | 102 | _input = torch.randn(1, 3, 224, 224) 103 | torch.onnx.export(model, _input, './mymodel.onnx', opset_version=11) # or opset_version=13 104 | 105 | from onnx2tflite import onnx_converter 106 | onnx_converter( 107 | onnx_model_path = "./mymodel.onnx", 108 | need_simplify = True, 109 | output_path = "./", 110 | target_formats = ['tflite'], # or ['keras'], ['keras', 'tflite'] 111 | weight_quant = False, 112 | int8_model = True, # enable int8 quantization 113 | int8_mean = [123.675, 116.28, 103.53], # mean used in image preprocessing 114 | int8_std = [58.395, 57.12, 57.375], # std used in image preprocessing 115 | image_root = "./dataset/train" # folder of training images used for calibration 116 | ) 117 | ``` 118 | --- 119 | ## Validated models 120 | - [SSD](https://github.com/qfgaohao/pytorch-ssd) 121 | - [HRNet](https://github.com/HRNet/HRNet-Facial-Landmark-Detection) 122 | - [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX) 123 | - [YOLOV3](https://github.com/ultralytics/yolov3) 124 | - [YOLOV4](https://github.com/Tianxiaomo/pytorch-YOLOv4) 125 | - [YOLOV5](https://github.com/ultralytics/yolov5) 126 | - [YOLOV6](https://github.com/meituan/YOLOv6) 127 | - [YOLOV7](https://github.com/WongKinYiu/yolov7) 128 | - [YOLOV10](https://github.com/THU-MIG/yolov10) 129 | - [MoveNet](https://github.com/fire717/movenet.pytorch) 130 | - [UNet\FPN](https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets) 131 | - ViT(torchvision) 132 | - [SwinTransformerV1](https://github.com/microsoft/Swin-Transformer) 133 | - MLP(custom) 134 | - DCGAN(custom) 135 | - [AutoEncoder/VAE](https://github.com/AntixK/PyTorch-VAE) 136 | - all torchvision classification models 137 | - some segmentation models in torchvision 138 | - 1D or 2D CNNs without special operators (custom) 139 | --- 140 | ## Add operator by yourself 141 | When you encounter an unsupported operator, you can either add it yourself or open an issue.
142 | It's very simple to implement a new operator parser; just follow the steps below.
143 | Step 0: Select the corresponding layer code file in the [layers folder](./onnx2tflite/layers/), e.g. activations_layers.py for 'HardSigmoid'.
144 | Step 1: Open it, and edit it: 145 | ```python 146 | # all operators are registered through the OPERATOR registry. 147 | # the registered name is the onnx operator name. 148 | @OPERATOR.register_operator("HardSigmoid") 149 | class TFHardSigmoid(): 150 | def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, node_outputs, layout_dict, *args, **kwargs) -> None: 151 | ''' 152 | :param tensor_grap: dict, key is node name, value is tensorflow-keras node output tensor. 153 | :param node_weights: dict, key is node name, value is static data such as weight/bias/constant; weights usually need to be transformed by dimension_utils.tensor_NCD_to_NDC_format. 154 | :param node_inputs: List[str], node input names, indicating which nodes the inputs come from; they may be found in tensor_grap or node_weights. 155 | :param node_attribute: dict, key is attribute name, such as 'axis' or 'perm'. The value type varies, e.g. List[int], int or float. Note that 'axis' values should be adjusted from NCHW to NHWC by dimension_utils.channel_to_last_dimension or dimension_utils.shape_NCD_to_NDC_format. 156 | :param node_outputs: List[str], node output names. 157 | :param layout_dict: Dict[str, Layout], the layout of every preceding node. 158 | ''' 159 | super().__init__() 160 | self.alpha = node_attribute.get("alpha", 0.2) 161 | self.beta = node_attribute.get("beta", 0.5) 162 | 163 | def __call__(self, inputs): 164 | return tf.clip_by_value(self.alpha*inputs+self.beta, 0, 1) 165 | ``` 166 | Step 2: Make it work without error.
167 | Step 3: Convert the model to tflite without any quantization.
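Step 4 (optional): add a small unit test in the style of those under the test folder: export a toy model that exercises the new operator and assert that the conversion error stays small. A minimal sketch (the toy module and file names here are illustrative assumptions):
```python
import os
import torch
from onnx2tflite import onnx_converter

class ToyHardSigmoid(torch.nn.Module):  # hypothetical toy model exercising the new op
    def forward(self, x):
        return torch.nn.functional.hardsigmoid(x)

os.makedirs("./unit_test", exist_ok=True)
onnx_path = "./unit_test/test_hardsigmoid.onnx"
torch.onnx.export(ToyHardSigmoid(), torch.randn(1, 3, 32, 32), onnx_path, opset_version=11)

# convert and check the average error against the ONNX outputs
res = onnx_converter(onnx_model_path=onnx_path, need_simplify=True,
                     output_path="./unit_test", target_formats=['tflite'])
assert res['tflite_error'] < 1e-3
```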
168 | 169 | --- 170 | 171 | # License 172 | This software is covered by Apache-2.0 license. 173 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | onnx 2 | onnxruntime 3 | onnx-simplifier 4 | numpy <= 1.24 5 | tensorflow>=2.5,<2.13 6 | opencv-python -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup, find_packages 3 | abs_path = os.path.dirname(os.path.abspath(__file__)) 4 | 5 | setup( 6 | name="onnx2tflite", 7 | version="2.0", 8 | author="MPolaris", 9 | description="onnx to keras/tensorflow lite", 10 | long_description=open(os.path.join(abs_path, "readme.md")).read(), 11 | long_description_content_type='text/markdown', 12 | packages=find_packages(include=['onnx2tflite']), 13 | license="Apache-2.0", 14 | platforms=["Windows", "linux"], 15 | install_requires=open(os.path.join(abs_path, "requirements.txt")).read().splitlines() 16 | ) -------------------------------------------------------------------------------- /test/test_concat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | 4 | import torch 5 | import torch.nn as nn 6 | from onnx2tflite import onnx_converter 7 | 8 | MODEL_ROOT = "./unit_test" 9 | os.makedirs(MODEL_ROOT, exist_ok=True) 10 | 11 | @pytest.mark.filterwarnings('ignore::UserWarning') 12 | @pytest.mark.filterwarnings('ignore::DeprecationWarning') 13 | def test_concat(): 14 | class Concat(nn.Module): 15 | def __init__(self, *args, **kwargs) -> None: 16 | super().__init__(*args, **kwargs) 17 | self.conv1 = nn.Conv2d(3, 3, 3, 2, 1) 18 | # self.conv2 = nn.Conv2d(3, 3, 3, 2, 1) 19 | self._const = torch.randn(1,2,16,8) 20 | 21 | def forward(self, x1, x2, x3): 22 | x1 = torch.reshape(x1, (1, 3, 16, 8)) 23 | # x = torch.transpose(x, (0, 1, 3, 2)) 24 | x2 = torch.transpose(x2, 3, 2) 25 | x3 = self.conv1(x3) 26 | x = torch.concat([x1,x2,x3,self._const], dim=1) 27 | return x 28 | 29 | model = Concat() 30 | x1 = torch.randn(1,3*16*8) 31 | x2 = torch.randn(1,3,8,16) 32 | x3 = torch.randn(1,3,32,16) 33 | 34 | onnx_model_path = os.path.join(MODEL_ROOT, "test_concat.onnx") 35 | torch.onnx.export(model, (x1,x2,x3), onnx_model_path, opset_version=11) 36 | 37 | res = onnx_converter( 38 | onnx_model_path = onnx_model_path, 39 | need_simplify = True, 40 | output_path = MODEL_ROOT, 41 | target_formats = ['tflite'], 42 | native_groupconv=False, 43 | fp16_model=False, 44 | int8_model=False, 45 | ) 46 | 47 | assert res['tflite_error'] < 1e-3 -------------------------------------------------------------------------------- /test/test_reshape_transpose.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | 4 | import torch 5 | import torch.nn as nn 6 | from onnx2tflite import onnx_converter 7 | 8 | MODEL_ROOT = "./unit_test" 9 | os.makedirs(MODEL_ROOT, exist_ok=True) 10 | 11 | @pytest.mark.filterwarnings('ignore::UserWarning') 12 | @pytest.mark.filterwarnings('ignore::DeprecationWarning') 13 | def test_reshape_trans(): 14 | class test1(nn.Module): 15 | def __init__(self, *args, **kwargs) -> None: 16 | super().__init__(*args, **kwargs) 17 | self.conv1 = nn.Conv2d(3, 3, 3, 2, 1) 18 | self.conv2 = nn.Conv2d(3, 3, 3, 2, 1) 19 | 20 | def forward(self, x): 21 | x = 
22 |             # x = torch.transpose(x, (0, 1, 3, 2))
23 |             x = torch.transpose(x, 3, 2)
24 |             x = self.conv1(x)
25 |             x = self.conv2(x)
26 |             return x
27 | 
28 |     model = ReshapeTranspose()
29 |     x = torch.randn(1, 3*32*16)
30 | 
31 |     onnx_model_path = os.path.join(MODEL_ROOT, "test_reshape_trans.onnx")
32 |     torch.onnx.export(model, x, onnx_model_path, opset_version=11)
33 | 
34 |     res = onnx_converter(
35 |         onnx_model_path = onnx_model_path,
36 |         need_simplify = True,
37 |         output_path = MODEL_ROOT,
38 |         target_formats = ['tflite'],
39 |         native_groupconv=False,
40 |         fp16_model=False,
41 |         int8_model=False,
42 |     )
43 | 
44 |     assert res['tflite_error'] < 1e-3
--------------------------------------------------------------------------------
/test/test_squeeze.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | 
4 | import torch
5 | import torch.nn as nn
6 | from onnx2tflite import onnx_converter
7 | 
8 | MODEL_ROOT = "./unit_test"
9 | os.makedirs(MODEL_ROOT, exist_ok=True)
10 | 
11 | @pytest.mark.filterwarnings('ignore::UserWarning')
12 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
13 | def test_squeeze():
14 |     class Squeeze(nn.Module):
15 |         def __init__(self, *args, **kwargs) -> None:
16 |             super().__init__(*args, **kwargs)
17 | 
18 |         def forward(self, x):
19 |             x = torch.unsqueeze(x, dim=1)
20 |             # x = torch.tile(x, dims=(2,1,1))
21 |             x = torch.squeeze(x, dim=1)
22 | 
23 |             return x
24 | 
25 |     model = Squeeze()
26 |     x = torch.randn(1,1,1,2)
27 | 
28 |     onnx_model_path = os.path.join(MODEL_ROOT, "test_squeeze.onnx")
29 |     torch.onnx.export(model, x, onnx_model_path, opset_version=11)
30 | 
31 |     res = onnx_converter(
32 |         onnx_model_path = onnx_model_path,
33 |         need_simplify = True,
34 |         output_path = MODEL_ROOT,
35 |         target_formats = ['tflite'],
36 |         native_groupconv=False,
37 |         fp16_model=False,
38 |         int8_model=False,
39 |     )
40 | 
41 |     assert res['tflite_error'] < 1e-3
--------------------------------------------------------------------------------
/test/test_torchvison.py:
--------------------------------------------------------------------------------
1 | '''
2 | Unit tests for torchvision models.
3 | '''
4 | import os
5 | import pytest
6 | 
7 | import torch
8 | import torchvision
9 | from onnx2tflite import onnx_converter
10 | 
11 | MODEL_ROOT = "./unit_test"
12 | os.makedirs(MODEL_ROOT, exist_ok=True)
13 | 
14 | @pytest.mark.filterwarnings('ignore::UserWarning')
15 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
16 | def test_resnet():
17 |     model = torchvision.models.resnet18(False)
18 |     onnx_model_path = os.path.join(MODEL_ROOT, "resnet18.onnx")
19 |     torch.onnx.export(model, torch.randn(1, 3, 224, 224), onnx_model_path, opset_version=13)
20 |     error = onnx_converter(onnx_model_path, need_simplify = True, output_path = MODEL_ROOT, target_formats = ['tflite'])['tflite_error']
21 |     assert error < 1e-3
22 | 
23 | @pytest.mark.filterwarnings('ignore::UserWarning')
24 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
25 | def test_mobilenet():
26 |     model = torchvision.models.mobilenet_v2(False)
27 |     onnx_model_path = os.path.join(MODEL_ROOT, "mobilenet_v2.onnx")
28 |     torch.onnx.export(model, torch.randn(1, 3, 224, 224), onnx_model_path, opset_version=13)
29 |     error = onnx_converter(onnx_model_path, need_simplify = True, output_path = MODEL_ROOT, target_formats = ['tflite'])['tflite_error']
30 |     assert error < 1e-3
31 | 
32 | @pytest.mark.filterwarnings('ignore::UserWarning')
33 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
34 | def test_deeplabv3():
35 |     model = torchvision.models.segmentation.deeplabv3_resnet50(False)
36 |     onnx_model_path = os.path.join(MODEL_ROOT, "deeplabv3_resnet50.onnx")
37 |     torch.onnx.export(model, torch.randn(1, 3, 512, 1024), onnx_model_path, opset_version=13)
38 |     error = onnx_converter(onnx_model_path, need_simplify = True, output_path = MODEL_ROOT, target_formats = ['tflite'])['tflite_error']
39 |     assert error < 1e-3
40 | 
41 | @pytest.mark.filterwarnings('ignore::UserWarning')
42 | @pytest.mark.filterwarnings('ignore::DeprecationWarning')
43 | def test_vit():
44 |     model = torchvision.models.vit_b_16(False)
45 |     onnx_model_path = os.path.join(MODEL_ROOT, "vit_b_16.onnx")
46 |     torch.onnx.export(model, torch.randn(1, 3, 224, 224), onnx_model_path, opset_version=13)
47 |     error = onnx_converter(onnx_model_path, need_simplify = True, output_path = MODEL_ROOT, target_formats = ['tflite'])['tflite_error']
48 |     assert error < 1e-3
--------------------------------------------------------------------------------