├── .gitignore
├── README.md
├── convert_models.md
├── onnx2pytorch.py
├── onnx2pytorch
    ├── __init__.py
    ├── convert
    │   ├── __init__.py
    │   ├── attribute.py
    │   ├── debug.py
    │   ├── layer.py
    │   ├── model.py
    │   └── operations.py
    ├── operations
    │   ├── __init__.py
    │   ├── add.py
    │   ├── cast.py
    │   ├── clamp.py
    │   ├── concat.py
    │   ├── constant.py
    │   ├── flatten.py
    │   ├── gather.py
    │   ├── matmul.py
    │   ├── mul.py
    │   ├── pad.py
    │   ├── pooling.py
    │   ├── reshape.py
    │   ├── resize.py
    │   ├── shape.py
    │   ├── slice.py
    │   ├── split.py
    │   ├── squeeze.py
    │   └── where.py
    └── utils.py
├── onnxsim
    ├── __init__.py
    ├── __main__.py
    └── onnx_simplifier.py
└── tool.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.onnx
3 | *.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ONNX2X
2 | ONNX, an intermediate representation for neural network models originally introduced by Microsoft and Facebook, is widely used across deep learning training frameworks, including PyTorch, TensorFlow, OneFlow, Keras, and Paddle. This raises a question: can an ONNX model exported from TensorFlow be used in PyTorch by way of ONNX? ONNX aims to be the interchange format for models from all frameworks, so all that is missing is the reverse conversion from ONNX back to each framework. This project tries to support converting ONNX to various training frameworks, mainly as an exercise in operator alignment and as a way to understand ONNX more deeply.
3 |
4 | # Code structure
5 |
6 | ```markdown
7 | - onnx2pytorch        ONNX-to-PyTorch conversion code
8 | - onnx2pytorch.py     test script for the ONNX-to-PyTorch conversion
9 | - convert_models.md   commands and recorded results for converting models from the ONNX Model Zoo
10 | - README.md
11 | ```
12 |
13 | # Requirements
14 |
15 | - pytorch >= 1.1.0
16 | - onnx>=1.8.1
17 | - >=1.6.0
18 | - onnxoptimizer>=0.2.3
19 |
20 | # Usage
21 |
22 | Use the following command to convert an ONNX model exported from any training framework into a PyTorch model:
23 |
24 | ```sh
25 | python .\onnx2pytorch.py ...
26 | ```
27 |
28 | The parameters are:
29 |
30 | - `--onnx_path` string, required; path of the ONNX model
31 | - `--pytorch_path` string, required; path where the converted PyTorch model is saved
32 | - `--simplify_path` string, optional; path where the simplified ONNX model (e.g. with Dropout and constant OPs removed) is saved (note that `onnx2pytorch.py` currently declares this argument with `required=True`)
33 | - `--input_shape` string, required; name and dimensions of the ONNX model's input layer, e.g. `input:1,3,224,224`
34 |
35 | # Example
36 |
37 | ```sh
38 | python .\onnx2pytorch.py --onnx_path .\models\mobilenetv2-7.onnx --simplify_path .\models\mobilenetv2-7-simplify.onnx --pytorch_path .\models\mobilenetv2-7.pth --input_shape input:1,3,224,224
39 | ```
40 |
41 | # Handling failed conversions
42 |
43 | - In `onnx2pytorch.py`, change `debug=False` to `debug=True` in the line `model = convert.ConvertModel(onnx_model, debug=False)` and rerun the conversion to locate the OP that fails. You can then open an issue in this project, or fix it yourself and submit a PR.
44 |
45 | # ONNX2Pytorch
46 |
47 | ## Supported ONNX OPs
48 |
49 | - [x] Conv
50 | - [x] BatchNormalization
51 | - [x] GlobalAveragePool
52 | - [x] AvgPool
53 | - [x] MaxPool
54 | - [x] BatchNorm
55 | - [x] Flatten
56 | - [x] Reshape
57 | - [x] Relu
58 | - [x] Add
59 | - [x] Gemm
60 | - [x] Sigmoid
61 | - [x] Mul
62 | - [x] Concat
63 | - [x] Resize (some issues remain; the current version supports fixed scale factors)
64 | - [x] Transpose
65 | - [x] LRN
66 | - [x] Clip
67 | - [x] Pad2d
68 | - [x] Split
69 | - [x] ReduceMean
70 | - [x] LeakyRelu
71 |
72 | ## Verified models
73 |
74 | A conversion is considered successful when the MSE between the ONNXRuntime and PyTorch inference outputs is below 1e-7.
75 |
76 | ### Classification models
77 | - [x] zfnet512-9.onnx
78 | - [x] resnet50-v2-7.onnx
79 | - [x] mobilenetv2-7.onnx
80 | - [x] mobilenetv2-1.0.onnx
81 | - [x] bvlcalexnet-9.onnx
82 | - [x] googlenet-9.onnx
83 | - [x] squeezenet1.1-7.onnx
84 | - [x] shufflenet-v2-10.onnx
85 | - [x] inception-v1-9.onnx
86 | - [x] inception-v2-9.onnx
87 | - [x] vgg19-caffe2-9.onnx
88 | - [x] rcnn-ilsvrc13-9.onnx
89 |
90 | ### Detection models
91 | - [x] yolov5s-simple.onnx
92 |
93 | ### Segmentation models
94 |
95 | # TODO
96 |
97 | - [ ] Support more models
98 | - [ ] Refactor the project and fix the issue that some OPs are not visible in Netron after certain models are converted to PyTorch
99 | - [ ] Some deployment work, e.g. converting an ONNX model exported from Keras to PyTorch, re-exporting it to ONNX, and handing it to NCNN for inference
100 |
101 | # Related links
102 |
103 | - https://github.com/ToriML/onnx2pytorch
104 | - https://github.com/daquexian/onnx-simplifier
105 |
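The `--input_shape` format from the parameter list (`name:dim0,dim1,...`) can be illustrated with a small parser. This is a minimal sketch that mirrors the splitting logic in `onnx2pytorch.py`; the helper name `parse_input_shape` is hypothetical and not part of this project.

```python
from typing import List, Optional, Tuple


def parse_input_shape(spec: str) -> Tuple[Optional[str], List[int]]:
    """Split "name:1,3,224,224" into ("name", [1, 3, 224, 224]).

    Hypothetical helper, for illustration only.
    """
    if ":" not in spec:
        # No input name given, only the dimensions.
        return None, [int(d) for d in spec.split(",")]
    # Split on the last colon so that names such as "input:0" stay intact.
    name, dims = spec.rsplit(":", 1)
    return name, [int(d) for d in dims.split(",")]


if __name__ == "__main__":
    print(parse_input_shape("input:1,3,224,224"))
    print(parse_input_shape("gpu_0/data_0:1,3,224,224"))
```

Splitting on the last colon rather than the first is what keeps an input name that itself contains a colon usable as a dictionary key for the simplifier's `input_shapes` argument.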
--------------------------------------------------------------------------------
/convert_models.md:
--------------------------------------------------------------------------------
1 | ## Model conversion commands
2 |
3 | ### 1. resnet50-v2-7.onnx
4 |
5 | ```sh
6 | python .\onnx2pytorch.py --onnx_path .\models\resnet50-v2-7.onnx --simplify_path .\models\resnet50-v2-7-simplify.onnx --pytorch_path .\models\resnet50-v2-7.pth --input_shape data:1,3,224,224
7 |
8 | {'data': [1, 3, 224, 224]}
9 | mse 4.643172069052071e-12
10 | ```
11 |
12 |
13 | ### 2. mobilenetv2-7.onnx
14 |
15 | ```sh
16 | python .\onnx2pytorch.py --onnx_path .\models\mobilenetv2-7.onnx --simplify_path .\models\mobilenetv2-7-simplify.onnx --pytorch_path .\models\mobilenetv2-7.pth --input_shape input:1,3,224,224
17 |
18 |
19 | {'input': [1, 3, 224, 224]}
20 | mse 3.6263321234741854e-12
21 | ```
22 |
23 | ### 3. bvlcalexnet-9.onnx
24 |
25 |
26 | ```sh
27 | python .\onnx2pytorch.py --onnx_path .\models\bvlcalexnet-9.onnx --simplify_path .\models\bvlcalexnet-9-simplify.onnx --pytorch_path .\models\bvlcalexnet-9.pth --input_shape data_0:1,3,224,224
28 |
29 |
30 | {'data_0': [1, 3, 224, 224]}
31 | mse 4.7017316594428514e-20
32 | ```
33 |
34 | ### 4. googlenet-9.onnx
35 |
36 | ```sh
37 | python .\onnx2pytorch.py --onnx_path .\models\googlenet-9.onnx --simplify_path .\models\googlenet-9-simplify.onnx --pytorch_path .\models\googlenet-9.pth --input_shape data_0:1,3,224,224
38 |
39 |
40 | {'data_0': [1, 3, 224, 224]}
41 | mse 4.14498926989363e-17
42 | ```
43 |
44 | ### 5. squeezenet1.1-7.onnx
45 |
46 | ```sh
47 | python .\onnx2pytorch.py --onnx_path .\models\squeezenet1.1-7.onnx --simplify_path .\models\squeezenet1.1-7-simplify.onnx --pytorch_path .\models\squeezenet1.1-7.pth --input_shape data:1,3,224,224
48 |
49 |
50 | {'data': [1, 3, 224, 224]}
51 | mse 1.0111956827429935e-12
52 | ```
53 |
54 | ### 6. shufflenet-v2-10.onnx
55 |
56 | ```sh
57 | python .\onnx2pytorch.py --onnx_path .\models\shufflenet-v2-10.onnx --simplify_path .\models\shufflenet-v2-10-simplify.onnx --pytorch_path .\models\shufflenet-v2-10.pth --input_shape input:1,3,224,224
58 |
59 |
60 | {'input': [1, 3, 224, 224]}
61 | mse 5.285994753023715e-12
62 | ```
63 |
64 | ### 7. inception-v1-9.onnx
65 |
66 | ```sh
67 | python .\onnx2pytorch.py --onnx_path .\models\inception-v1-9.onnx --simplify_path .\models\inception-v1-9-simplify.onnx --pytorch_path .\models\inception-v1-9.pth --input_shape data_0:1,3,224,224
68 |
69 |
70 | {'data_0': [1, 3, 224, 224]}
71 | mse 1.6917238484094424e-17
72 | ```
73 |
74 | ### 8. inception-v2-9.onnx
75 |
76 | ```sh
77 | python .\onnx2pytorch.py --onnx_path .\models\inception-v2-9.onnx --simplify_path .\models\inception-v2-9-simplify.onnx --pytorch_path .\models\inception-v2-9.pth --input_shape data_0:1,3,224,224
78 |
79 | {'data_0': [1, 3, 224, 224]}
80 | mse 1.363866701867278e-15
81 | ```
82 |
83 | ### 9. mobilenetv2-1.0.onnx
84 |
85 | ```sh
86 | python .\onnx2pytorch.py --onnx_path .\models\mobilenetv2-1.0.onnx --simplify_path .\models\mobilenetv2-1.0-simplify.onnx --pytorch_path .\models\mobilenetv2-1.0.pth --input_shape data:1,3,224,224
87 |
88 | {'data': [1, 3, 224, 224]}
89 | mse 5.929283286576492e-12
90 | ```
91 |
92 | ### 10. vgg19-caffe2-9.onnx
93 |
94 | - The model converts successfully, but the resulting PyTorch model is too large to be written with `torch.save`, which fails with a `MemoryError`.
95 |
96 | ```sh
97 | python .\onnx2pytorch.py --onnx_path .\models\vgg19-caffe2-9.onnx --simplify_path .\models\vgg19-caffe2-9-simplify.onnx --pytorch_path .\models\vgg19-caffe2-9.pth --input_shape data_0:1,3,224,224
98 |
99 | {'data_0': [1, 3, 224, 224]}
100 | mse 8.932564880980331e-19
101 | Traceback (most recent call last):
102 |   File ".\onnx2pytorch.py", line 92, in <module>
103 |     convert_onnx_pytorch(model_slim, pytorch_model, output, input)
104 |   File ".\onnx2pytorch.py", line 28, in convert_onnx_pytorch
105 |     torch.save(model, pytorch_model)
106 |   File "D:\Anaconda3\lib\site-packages\torch\serialization.py", line 372, in save
107 |     _save(obj, opened_zipfile, pickle_module, pickle_protocol)
108 |   File "D:\Anaconda3\lib\site-packages\torch\serialization.py", line 476, in _save
109 |     pickler.dump(obj)
110 | MemoryError
111 | ```
112 |
113 | ### 11. zfnet512-9.onnx
114 |
115 | ```sh
116 | python .\onnx2pytorch.py --onnx_path .\models\zfnet512-9.onnx --simplify_path .\models\zfnet512-9-simplify.onnx --pytorch_path .\models\zfnet512-9.pth --input_shape gpu_0/data_0:1,3,224,224
117 |
118 | {'gpu_0/data_0': [1, 3, 224, 224]}
119 | mse 5.180448423209617e-18
120 | ```
121 |
122 | ### 12. rcnn-ilsvrc13-9.onnx
123 |
124 | ```sh
125 | python .\onnx2pytorch.py --onnx_path .\models\rcnn-ilsvrc13-9.onnx --simplify_path .\models\rcnn-ilsvrc13-9-simplify.onnx --pytorch_path .\models\rcnn-ilsvrc13-9.pth --input_shape data_0:1,3,224,224
126 |
127 | {'data_0': [1, 3, 224, 224]}
128 | mse 2.1032064978498965e-14
129 | ```
--------------------------------------------------------------------------------
/onnx2pytorch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import re
4 | import onnx
5 | import numpy as np
6 | import torch
7 | import onnxruntime as ort
8 | import argparse
9 | from tool import *
10 |
11 | from onnx2pytorch import convert
12 | from onnxsim import simplify
13 |
14 | def convert_onnx_pytorch(onnx_model, pytorch_model, onnx_model_outputs, onnx_inputs):
15 |     model = convert.ConvertModel(onnx_model, debug=False)
16 |     model.eval()
17 |     model.cpu()
18 |     with torch.no_grad():
19 |         outputs = model(onnx_inputs)
20 |     if not isinstance(outputs, list):
21 |         outputs = [outputs]
22 |     outputs = [x.cpu().numpy() for x in outputs]
23 |     # print(outputs[0][0][0:10])
24 |     for output, onnx_model_output in zip(outputs, onnx_model_outputs):
25 |         print("mse", ((onnx_model_output - output) ** 2).sum() / onnx_model_output.size)
26 |         np.testing.assert_allclose(onnx_model_output, output, atol=1e-5, rtol=1e-3)
27 |
28 |     torch.save(model, pytorch_model)
29 |
30 | def get_onnx_output(onnx_model, onnx_inputs):
31 |     sess = ort.InferenceSession(onnx_model.SerializeToString())
32 |     sess.set_providers(['CPUExecutionProvider'])
33 |     input_name = sess.get_inputs()[0].name
34 |     output_name = sess.get_outputs()[0].name
35 |     output = sess.run([output_name], {input_name : onnx_inputs.numpy()})
36 |     # print(output[0][0][0:10])
37 |     return output
38 |
39 |
40 |
41 |
42 | if __name__ == "__main__":
43 |     parser = argparse.ArgumentParser(description="onnx2pytorch test")
44 |     parser.add_argument("--onnx_path", default="", type=str, required=True)
45 |     parser.add_argument("--simplify_path", default="", type=str, required=True)
46 |     parser.add_argument("--pytorch_path", default="", type=str, required=True)
47 |     parser.add_argument("--input_shape", default="input:1,3,224,224", type=str, required=True)
48 |     args = parser.parse_args()
49 |
50 |     input_shape_backup = args.input_shape
51 |     input_shape = re.split(':', input_shape_backup)[-1]
52 |     input_shape = re.split(',', input_shape)
53 |     input = torch.randn(list(map(int, input_shape)))
54 |
55 |     if not args.onnx_path.endswith('.onnx'):
56 |         print('Please Check Your ONNX Model Path Format')
57 |     if not args.simplify_path.endswith('.onnx'):
58 |         print('Please Check Your ONNX Simplify Model Path Format')
59 |     if not args.pytorch_path.endswith('.pth'):
60 |         print('Please Check Your Pytorch Model Path Format')
61 |
62 |     tool = Tool(args.onnx_path)
63 |     for i, node in enumerate(tool.model.graph.node):
64 |         if node.op_type == "Dropout":
65 |             tool.remove_node(node)
66 |
67 |     input_shape_backup = [input_shape_backup]
68 |
69 |     input_shapes = {}
70 |
71 |     if input_shape_backup is not None:
72 |         for x in input_shape_backup:
73 |             if ':' not in x:
74 |                 input_shapes[None] = list(map(int, x.split(',')))
75 |             else:
76 |                 pieces = x.split(':')
77 |                 # for the input name like input:0
78 |                 name, shape = ':'.join(
79 |                     pieces[:-1]), list(map(int, pieces[-1].split(',')))
80 |                 input_shapes[name] = shape
81 |
82 |     print(input_shapes)
83 |
84 |     model_slim, check = simplify(tool.model, input_shapes=input_shapes)
85 |
86 |     assert check, "Simplified ONNX model could not be validated"
87 |
88 |     if args.simplify_path.endswith('.onnx'):
89 |         onnx.save(model_slim, args.simplify_path)
90 |
91 |     pytorch_model = args.pytorch_path
92 |     output = get_onnx_output(model_slim, input)
93 |
94 |     convert_onnx_pytorch(model_slim, pytorch_model, output, input)
95 |
--------------------------------------------------------------------------------
/onnx2pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from .convert import ConvertModel
2 |
--------------------------------------------------------------------------------
/onnx2pytorch/convert/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import ConvertModel
2 | from .layer import *
3 |
--------------------------------------------------------------------------------
/onnx2pytorch/convert/attribute.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import onnx
4 | from onnx import numpy_helper
5 |
6 | from onnx2pytorch.utils import (
7 |     extract_padding_params_for_conv_layer,
8 |     extract_padding_params,
9 | )
10 |
11 | TENSOR_PROTO_MAPPING = dict([i[::-1] for i in onnx.TensorProto.DataType.items()])
12 |
13 | AttributeType = dict(
14 |     UNDEFINED=0,
15 |     FLOAT=1,
16 |     INT=2,
17 |     STRING=3,
18 |     TENSOR=4,
19 |     GRAPH=5,
20 |     SPARSE_TENSOR=11,
21 |     FLOATS=6,
22 |     INTS=7,
23 |     STRINGS=8,
24 |     TENSORS=9,
25 |     GRAPHS=10,
26 |     SPARSE_TENSORS=12,
27 | )
28 |
29 | # Get the concrete value of an ONNX node attribute
30 | def extract_attr_values(attr):
31 |     """Extract onnx attribute values."""
32 |     if attr.type == AttributeType["INT"]:
33 |         value = attr.i
34 |     elif attr.type == AttributeType["FLOAT"]:
35 |         value = attr.f
36 |     elif attr.type == AttributeType["INTS"]:
37 |         value = tuple(attr.ints)
38 |     elif attr.type == AttributeType["FLOATS"]:
39 |         value = tuple(attr.floats)
40 |     elif attr.type == AttributeType["TENSOR"]:
41 |         value = numpy_helper.to_array(attr.t)
42 |     elif attr.type == AttributeType["STRING"]:
43 |         value = attr.s.decode()
44 |     else:
45 |         raise NotImplementedError(
46 |             "Extraction of attribute type {} not implemented.".format(attr.type)
47 |         )
48 |     return value
49 |
50 | # Extract all attributes of an ONNX node
51 | def extract_attributes(node):
52 |     """Extract onnx attributes. Map onnx feature naming to pytorch."""
53 |     kwargs = {}
54 |     for attr in node.attribute:
55 |         if attr.name == "dilations":
56 |             kwargs["dilation"] = extract_attr_values(attr)
57 |         elif attr.name == "group":
58 |             kwargs["groups"] = extract_attr_values(attr)
59 |         elif attr.name == "kernel_shape":
60 |             kwargs["kernel_size"] = extract_attr_values(attr)
61 |         elif attr.name == "pads":
62 |             params = extract_attr_values(attr)
63 |             if node.op_type == "Pad":
64 |                 kwargs["padding"] = extract_padding_params(params)
65 |             else:
66 |                 # Works for Conv, MaxPooling and other layers from convert_layer func
67 |                 kwargs["padding"] = extract_padding_params_for_conv_layer(params)
68 |         elif attr.name == "strides":
69 |             kwargs["stride"] = extract_attr_values(attr)
70 |         elif attr.name == "axis" and node.op_type == "Flatten":
71 |             kwargs["start_dim"] = extract_attr_values(attr)
72 |         elif attr.name == "axis" or attr.name == "axes":
73 |             v = extract_attr_values(attr)
74 |             if isinstance(v, (tuple, list)) and len(v) == 1:
75 |                 kwargs["dim"] = v[0]
76 |             else:
77 |                 kwargs["dim"] = v
78 |         elif attr.name == "keepdims":
79 |             kwargs["keepdim"] = bool(extract_attr_values(attr))
80 |         elif attr.name == "epsilon":
81 |             kwargs["eps"] = extract_attr_values(attr)
82 |         elif attr.name == "momentum":
83 |             kwargs["momentum"] = extract_attr_values(attr)
84 |         elif attr.name == "ceil_mode":
85 |             kwargs["ceil_mode"] = bool(extract_attr_values(attr))
86 |         elif attr.name == "value":
87 |             kwargs["constant"] = extract_attr_values(attr)
88 |         elif attr.name == "perm":
89 |             kwargs["dims"] = extract_attr_values(attr)
90 |         elif attr.name == "split":
91 |             kwargs["split_size_or_sections"] = extract_attr_values(attr)
92 |         elif attr.name == "spatial":
93 |             kwargs["spatial"] = extract_attr_values(attr)  # Batch norm parameter
94 |         elif attr.name == "to":
95 |             kwargs["dtype"] = TENSOR_PROTO_MAPPING[extract_attr_values(attr)].lower()
96 |         elif attr.name == "mode":
97 |             kwargs["mode"] = extract_attr_values(attr)
98 |         elif attr.name == "transB":
99 |             kwargs["transpose_weight"] = not extract_attr_values(attr)
100 |         elif attr.name == "transA":
101 |             kwargs["transpose_activation"] = bool(extract_attr_values(attr))
102 |         elif attr.name == "alpha" and node.op_type == "LeakyRelu":
103 |             kwargs["negative_slope"] = extract_attr_values(attr)
104 |         elif attr.name == "alpha" and node.op_type != "LRN":
105 |             kwargs["weight_multiplier"] = extract_attr_values(attr)
106 |         elif attr.name == "beta" and node.op_type != "LRN":
107 |             kwargs["bias_multiplier"] = extract_attr_values(attr)
108 |         elif attr.name == "starts":
109 |             kwargs["starts"] = extract_attr_values(attr)
110 |         elif attr.name == "ends":
111 |             kwargs["ends"] = extract_attr_values(attr)
112 |         elif attr.name == "coordinate_transformation_mode":
113 |             arg = extract_attr_values(attr)
114 |             if arg == "align_corners":
115 |                 kwargs["align_corners"] = True
116 |             else:
117 |                 warnings.warn(
118 |                     "Pytorch's interpolate uses no coordinate_transformation_mode={}. "
119 |                     "Result might differ.".format(arg)
120 |                 )
121 |         elif node.op_type == "Resize":
122 |             # These parameters are not used, warn in Resize operator
123 |             kwargs[attr.name] = extract_attr_values(attr)
124 |         elif attr.name == "auto_pad":
125 |             value = extract_attr_values(attr)
126 |             if value == "NOTSET":
127 |                 pass
128 |             else:
129 |                 raise NotImplementedError(
130 |                     "auto_pad={} functionality not implemented.".format(value)
131 |                 )
132 |         elif attr.name == "alpha" and node.op_type == "LRN":
133 |             kwargs["alpha"] = extract_attr_values(attr)
134 |         elif attr.name == "beta" and node.op_type == "LRN":
135 |             kwargs["beta"] = extract_attr_values(attr)
136 |         elif attr.name == "bias" and node.op_type == "LRN":
137 |             kwargs["k"] = extract_attr_values(attr)
138 |         elif attr.name == "size" and node.op_type == "LRN":
139 |             kwargs["size"] = extract_attr_values(attr)
140 |         elif attr.name == "min":
141 |             kwargs["min"] = extract_attr_values(attr)
142 |         elif attr.name == "max":
143 |             kwargs["max"] = extract_attr_values(attr)
144 |         else:
145 |             raise NotImplementedError(
146 |                 "Extraction of attribute {} not implemented.".format(attr.name)
147 |             )
148 |     return kwargs
149 |
--------------------------------------------------------------------------------
/onnx2pytorch/convert/debug.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from onnx2pytorch.utils import get_activation_value
5 |
6 |
7 | def debug_model_conversion(onnx_model, inputs, pred_act, node, rtol=1e-3, atol=1e-4):
8 |     """Compare if the activations of pytorch are the same as from onnxruntime."""
9 |     if not isinstance(inputs, list):
10 |         raise TypeError("inputs should be in a list.")
11 |
12 |     if not all(isinstance(x, np.ndarray) for x in inputs):
13 |         inputs = [x.detach().numpy() for x in inputs]
14 |
15 |     exp_act = get_activation_value(onnx_model, inputs, list(node.output))
16 |     if isinstance(pred_act, list):
17 |         for a, b in zip(exp_act, pred_act):
18 |             assert torch.allclose(torch.from_numpy(a), b, rtol=rtol, atol=atol)
19 |     else:
20 |         a = torch.from_numpy(exp_act[0])
21 |         b = pred_act
22 |         if not torch.allclose(a, b, rtol=rtol, atol=atol):
23 |             print(node.input[0])
24 |             print(node.op_type)
25 |         assert torch.allclose(a, b, rtol=rtol, atol=atol)
26 |
--------------------------------------------------------------------------------
/onnx2pytorch/convert/layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from onnx import numpy_helper
4 |
5 | from onnx2pytorch.convert.attribute import extract_attributes, extract_attr_values
6 |
7 |
8 | def extract_params(params):
9 |     """Extract weights and biases."""
10 |     param_length = len(params)
11 |     if param_length == 1:
12 |         weight = params[0]
13 |         bias = None
14 |     elif param_length == 2:
15 |         weight = params[0]
16 |         bias = params[1]
17 |     else:
18 |         raise ValueError("Unexpected number of parameters: {}".format(param_length))
19 |     return weight, bias
20 |
21 |
22 | def load_params(layer, weight, bias):
23 |     """Load weight and bias to a given layer from onnx format."""
24 |     layer.weight.data = torch.from_numpy(numpy_helper.to_array(weight))
25 |     if bias is not None:
26 |         layer.bias.data = torch.from_numpy(numpy_helper.to_array(bias))
27 |
28 | # Conversion of convolution / transposed convolution / pooling layers
29 | def convert_layer(node, layer_type, params=None):
30 |     """Use to convert Conv, MaxPool, AvgPool layers."""
31 |     assert layer_type in [
32 |         "Conv",
33 |         "ConvTranspose",
34 |         "MaxPool",
35 |         "AvgPool",
36 |     ], "Incorrect layer type: {}".format(layer_type)
37 |     kwargs = extract_attributes(node)
38 |     kernel_size_length = len(kwargs["kernel_size"])
39 |     try:
40 |         layer = getattr(nn, "{}{}d".format(layer_type, kernel_size_length))
41 |     except AttributeError:
42 |         raise ValueError(
43 |             "Unexpected length of kernel_size dimension: {}".format(kernel_size_length)
44 |         )
45 |
46 |     if params:
47 |         pad_layer = None
48 |         weight, bias = extract_params(params)
49 |         kwargs["bias"] = bias is not None
50 |         kwargs["in_channels"] = weight.dims[1] * kwargs.get("groups", 1)
51 |         kwargs["out_channels"] = weight.dims[0]
52 |
53 |         if layer_type == "ConvTranspose":
54 |             kwargs["in_channels"], kwargs["out_channels"] = (
55 |                 kwargs["out_channels"],
56 |                 kwargs["in_channels"],
57 |             )
58 |
59 |         # if padding is a layer, remove from kwargs and prepend later
60 |         if isinstance(kwargs["padding"], nn.Module):
61 |             pad_layer = kwargs.pop("padding")
62 |
63 |         # initialize layer and load weights
64 |         layer = layer(**kwargs)
65 |         load_params(layer, weight, bias)
66 |         if pad_layer is not None:
67 |             layer = nn.Sequential(pad_layer, layer)
68 |     else:
69 |         # initialize operations without parameters (MaxPool, AvgPool, etc.)
70 |         if len(kwargs["padding"]) == 2:
71 |             if node.op_type == "AveragePool":
72 |                 kwargs["count_include_pad"] = False
73 |             layer = layer(**kwargs)
74 |         else:
75 |             if node.op_type == "AveragePool":
76 |                 kernel_size_x = kwargs["kernel_size"][0]
77 |                 kernel_size_y = kwargs["kernel_size"][1]
78 |                 pad_x = kwargs["padding"][1]
79 |                 pad_y = kwargs["padding"][3]
80 |                 kernel_size_x -= pad_x
81 |                 kernel_size_y -= pad_y
82 |                 kwargs["padding"] = (0, 0)
83 |                 kwargs["kernel_size"] = (kernel_size_x, kernel_size_y)
84 |                 layer = layer(**kwargs)
85 |             else:
86 |                 pad_layer = nn.ConstantPad2d(kwargs["padding"], 0.0)
87 |                 kwargs["padding"] = (0, 0)
88 |                 layer = layer(**kwargs)
89 |                 layer = nn.Sequential(pad_layer, layer)
90 |     return layer
91 |
92 | # Conversion of BatchNorm layers
93 | def convert_batch_norm_layer(node, params):
94 |     kwargs = extract_attributes(node)
95 |     layer = nn.BatchNorm2d  # keep the class here; it is instantiated below once num_features is known
96 |
97 |     kwargs["num_features"] = params[0].dims[0]
98 |     # initialize layer and load weights
99 |     layer = layer(**kwargs)
100 |     keys = ["weight", "bias", "running_mean", "running_var"]
101 |     for key, value in zip(keys, params):
102 |         getattr(layer, key).data = torch.from_numpy(numpy_helper.to_array(value))
103 |
104 |     return layer
105 |
106 | # Conversion of InstanceNorm layers
107 | def convert_instance_norm_layer(node, params):
108 |     kwargs = extract_attributes(node)
109 |     # Skips input dimension check, not possible before forward pass
110 |     layer = nn.InstanceNorm2d  # keep the class here; it is instantiated below once num_features is known
111 |
112 |     kwargs["num_features"] = params[0].dims[0]
113 |     # initialize layer and load weights
114 |     layer = layer(**kwargs)
115 |     keys = ["weight", "bias"]
116 |     for key, value in zip(keys, params):
117 |         getattr(layer, key).data = torch.from_numpy(numpy_helper.to_array(value))
118 |
119 |     return layer
120 |
121 | # Conversion of fully connected layers
122 | def convert_linear_layer(node, params):
123 |     """Convert linear layer from onnx node and params."""
124 |     # Default Gemm attributes
125 |     dc = dict(
126 |         transpose_weight=True,
127 |         transpose_activation=False,
128 |         weight_multiplier=1,
129 |         bias_multiplier=1,
130 |     )
131 |     dc.update(extract_attributes(node))
132 |     for attr in node.attribute:
133 |         if attr.name in ["transA"] and extract_attr_values(attr) != 0:
134 |             raise NotImplementedError(
135 |                 "Not implemented for attr.name={} and value!=0.".format(attr.name)
136 |             )
137 |
138 |     kwargs = {}
139 |     weight, bias = extract_params(params)
140 |     kwargs["bias"] = bias is not None
141 |     kwargs["in_features"] = weight.dims[1]
142 |     kwargs["out_features"] = weight.dims[0]
143 |
144 |     # initialize layer and load weights
145 |     layer = nn.Linear(**kwargs)
146 |     load_params(layer, weight, bias)
147 |
148 |     # apply onnx gemm attributes
149 |     if dc.get("transpose_weight"):
150 |         layer.weight.data = layer.weight.data.t()
151 |
152 |     layer.weight.data *= dc.get("weight_multiplier")
153 |     if layer.bias is not None:
154 |         layer.bias.data *= dc.get("bias_multiplier")
155 |
156 |     return layer
157 |
--------------------------------------------------------------------------------
/onnx2pytorch/convert/model.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import warnings
3 |
4 | import onnx
5 | import torch
6 | from onnx import numpy_helper
7 | from torch import nn
8 | from torch.jit import TracerWarning
9 | from torch.nn.modules.conv import _ConvNd
10 | from torch.nn.modules.batchnorm import _BatchNorm
11 | from torch.nn.modules.instancenorm import _InstanceNorm
12 | from torch.nn.modules.linear import Identity
13 |
14 | from onnx2pytorch.operations import Split
15 | from onnx2pytorch.convert.debug import debug_model_conversion
16 | from onnx2pytorch.convert.operations import convert_operations
17 | from onnx2pytorch.utils import get_inputs_names
18 |
19 |
20 | # Initializer parameters
21 | class InitParameters(dict):
22 |     """Use for parameters that are hidden."""
23 |
24 |     def __getitem__(self, item):
25 |         with warnings.catch_warnings():
26 |             warnings.simplefilter("ignore", TracerWarning)
27 |             return torch.from_numpy(numpy_helper.to_array(super().__getitem__(item)))
28 |
29 |     def get(self, item, default):
30 |         if item in self:
31 |             return self[item]
32 |         else:
33 |             return default
34 |
35 |
36 | class ConvertModel(nn.Module):
37 |     def __init__(
38 |         self, onnx_model: onnx.ModelProto, batch_dim=0, debug=False
39 |     ):
40 |         """
41 |         Convert onnx model to pytorch.
42 |
43 |         Parameters
44 |         ----------
45 |         onnx_model: onnx.ModelProto
46 |             Loaded onnx model.
47 |         batch_dim: int
48 |             Dimension of the batch.
49 |
50 |         Returns
51 |         -------
52 |         model: torch.nn.Module
53 |             A converted pytorch model.
54 |         """
55 |         super().__init__()
56 |         self.onnx_model = onnx_model
57 |         self.batch_dim = batch_dim
58 |         self.debug = debug
59 |         self.mapping = {}
60 |         for op_id, op_name, op in convert_operations(onnx_model, batch_dim):
61 |             # Set the attribute; it does not need to exist beforehand
62 |             setattr(self, op_name, op)
63 |             self.mapping[op_id] = op_name
64 |
65 |         self.init_parameters = InitParameters(
66 |             {tensor.name: tensor for tensor in self.onnx_model.graph.initializer}
67 |         )
68 |
69 |         self.input_names = get_inputs_names(onnx_model)
70 |
71 |     def forward(self, *input):
72 |         if input[0].shape[self.batch_dim] > 1:
73 |             raise NotImplementedError(
74 |                 "Input with larger batch size than 1 not supported yet."
75 |             )
76 |
77 |         activations = dict(zip(self.input_names, input))
78 |
79 |         for node in self.onnx_model.graph.node:
80 |             # Identify the node's id and name
81 |             out_op_id = node.output[0]
82 |             out_op_name = self.mapping[out_op_id]
83 |
84 |             # Get the PyTorch OP that corresponds to the current ONNX node
85 |             op = getattr(self, out_op_name)
86 |
87 |             layer_types = (nn.Linear, _ConvNd, _BatchNorm, _InstanceNorm)
88 |             if isinstance(op, layer_types) or (
89 |                 isinstance(op, nn.Sequential)
90 |                 and any(isinstance(x, layer_types) for x in op.modules())
91 |             ):
92 |                 in_activations = [
93 |                     activations[in_op_id]
94 |                     for in_op_id in node.input
95 |                     if in_op_id in activations
96 |                 ]
97 |             else:
98 |                 in_activations = [
99 |                     activations[in_op_id] if in_op_id in activations
100 |                     # If the input (in_op_id) is not in activations, it must be in the initializers
101 |                     else self.init_parameters.get(in_op_id, input[0])
102 |                     for in_op_id in node.input
103 |                 ]
104 |
105 |             # store activations for next layer
106 |             if isinstance(op, partial) and op.func == torch.cat:
107 |                 activations[out_op_id] = op(in_activations)
108 |             elif isinstance(op, Split):
109 |                 for out_op_id, output in zip(node.output, op(*in_activations)):
110 |                     activations[out_op_id] = output
111 |             elif isinstance(op, Identity):
112 |                 # After batch norm fusion the batch norm parameters
113 |                 # were all passed to identity instead of first one only
114 |                 activations[out_op_id] = op(in_activations[0])
115 |             else:
116 |                 activations[out_op_id] = op(*in_activations)
117 |
118 |             if self.debug:
119 |                 # In debug mode, check that each OP's output from PyTorch matches ONNXRuntime (within tolerance)
120 |                 debug_model_conversion(
121 |                     self.onnx_model,
122 |                     [activations[x] for x in self.input_names],
123 |                     activations[out_op_id],
124 |                     node,
125 |                 )
126 |
127 |         outputs = [activations[x.name] for x in self.onnx_model.graph.output]
128 |         if len(outputs) == 1:
129 |             outputs = outputs[0]
130 |         return outputs
131 |
--------------------------------------------------------------------------------
/onnx2pytorch/convert/operations.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from onnx import numpy_helper
7 |
8 | from onnx2pytorch.convert.attribute import extract_attributes
9 | from onnx2pytorch.convert.layer import (
10 |     convert_layer,
11 |     convert_linear_layer,
12 |     convert_batch_norm_layer,
13 |     convert_instance_norm_layer,
14 | )
15 | from onnx2pytorch.operations import *
16 | from onnx2pytorch.operations import Resize, Upsample
17 | from onnx2pytorch.utils import value_wrapper
18 |
19 |
20 | def convert_operations(onnx_model, batch_dim=0):
21 |     """
22 |     Convert onnx model operations. Yields onnx's operator_id, operator_name and
23 |     converted pytorch operator.
24 |
25 |     Parameters
26 |     ----------
27 |     onnx_model: onnx.ModelProto
28 |         Loaded onnx model.
29 |     batch_dim: int
30 |         Usually 0 for computer vision models and 1 for NLP models.
31 |
32 |     Returns
33 |     -------
34 |     iterator: (op_id, op_name, op)
35 |     """
36 |     weights = {tensor.name: tensor for tensor in onnx_model.graph.initializer}
37 |
38 |     for i, node in enumerate(onnx_model.graph.node):
39 |         params = [weights[par_name] for par_name in node.input if par_name in weights]
40 |
41 |         if node.op_type == "Conv":
42 |             op = convert_layer(node, "Conv", params)
43 |         elif node.op_type == "Relu":
44 |             op = nn.ReLU()
45 |         elif node.op_type == "LeakyRelu":
46 |             op = nn.LeakyReLU(**extract_attributes(node))
47 |         elif node.op_type == "Sigmoid":
48 |             op = nn.Sigmoid()
49 |         elif node.op_type == "MaxPool":
50 |             op = convert_layer(node, "MaxPool")
51 |         elif node.op_type == "AveragePool":
52 |             op = convert_layer(node, "AvgPool")
53 |         elif node.op_type == "Flatten":
54 |             op = Flatten(**extract_attributes(node))
55 |         elif node.op_type == "Gemm":
56 |             op = convert_linear_layer(node, params)
57 |             op.feature_dim = batch_dim + 1  # Necessary for transformers
58 |         elif node.op_type == "BatchNormalization":
59 |             op = convert_batch_norm_layer(node, params=params)
60 |         elif node.op_type == "InstanceNormalization":
61 |             op = convert_instance_norm_layer(node, params=params)
62 |         elif node.op_type == "Concat":
63 |             op = Concat(**extract_attributes(node))
64 |         elif node.op_type == "Constant":
65 |             # To handle constant OPs, run the model through ONNX-Simplifier first
66 |             op = value_wrapper(torch.from_numpy(extract_attributes(node)["constant"]))
67 |         elif node.op_type == "Reshape":
68 |             shape = list(
69 |                 filter(lambda x: x.name == node.input[1], onnx_model.graph.initializer)
70 |             )
71 |             shape = numpy_helper.to_array(shape[0]) if shape else None
72 |             op = Reshape(tuple(shape))
73 |         elif node.op_type == "Shape":
74 |             op = Shape()
75 |         elif node.op_type == "Gather":
76 |             op = Gather(**extract_attributes(node))
77 |         elif node.op_type == "Squeeze":
78 |             op = Squeeze(**extract_attributes(node))
79 |         elif node.op_type == "Unsqueeze":
80 |             op = partial(torch.unsqueeze, **extract_attributes(node))
81 |         elif node.op_type == "ConstantOfShape":
82 |             op = ConstantOfShape(**extract_attributes(node))
83 |         elif node.op_type == "Slice":
84 |             op = Slice(**extract_attributes(node))
85 |         elif node.op_type == "Cast":
86 |             op = Cast(**extract_attributes(node))
87 |         elif node.op_type == "Where":
88 |             op = Where()
89 |         elif node.op_type == "Equal":
90 |             op = torch.eq
91 |         elif node.op_type == "Mul":
92 |             op = Mul(**extract_attributes(node))
93 |         elif node.op_type == "Div":
94 |             op = torch.true_divide
95 |         elif node.op_type == "MatMul":
96 |             if params:
97 |                 weight = torch.from_numpy(numpy_helper.to_array(params[0]))
98 |                 op = nn.Linear(weight.shape[0], weight.shape[1], bias=False)
99 |                 op.weight.data = weight.t()
100 |
101 |                 # check if next node Add to add bias
102 |                 next_node = onnx_model.graph.node[i + 1]
103 |                 next_params = [
104 |                     weights[par_name]
105 |                     for par_name in next_node.input
106 |                     if par_name in weights
107 |                 ]
108 |                 if next_params and next_node.op_type == "Add":
109 |                     bias = torch.from_numpy(numpy_helper.to_array(next_params[0]))
110 |                     op.bias = nn.Parameter(bias)
111 |                     node.output.pop()
112 |                     node.output.extend(next_node.output)
113 |                     onnx_model.graph.node.pop(i + 1)  # remove next node
114 |             else:
115 |                 op = Matmul()
116 |         elif node.op_type == "Sub":
117 |             op = torch.sub
118 |         elif node.op_type == "Pow":
119 |             op = torch.pow
120 |         elif node.op_type == "Sqrt":
121 |             op = torch.sqrt
122 |         elif node.op_type == "Softmax":
123 |             op = nn.Softmax(dim=1)
124 |         elif node.op_type == "Transpose":
125 |             op = partial(torch.Tensor.permute, **extract_attributes(node))
126 |         elif node.op_type == "Split":
127 |             kwargs = extract_attributes(node)
128 |             # if the split_size_or_sections is not in node attributes,
129 |             # the number_of_splits becomes the number of node outputs
130 |             if "split_size_or_sections" not in kwargs:
131 |                 kwargs["number_of_splits"] = len(node.output)
132 |             op = Split(**kwargs)
133 |         elif node.op_type == "ReduceMean":
134 |             kwargs = dict(keepdim=True)
135 |             kwargs.update(extract_attributes(node))
136 |             op = partial(torch.mean, **kwargs)
137 |         elif node.op_type == "Add":
138 |             op = Add()
139 |         elif node.op_type == "GlobalAveragePool":
140 |             op = GlobalAveragePool()
141 |         elif node.op_type == "ConvTranspose":
142 |             op = convert_layer(node, "ConvTranspose", params)
143 |         elif node.op_type == "Identity":
144 |             op = nn.Identity()
145 |         elif node.op_type == "Resize":
146 |             op = Resize(**extract_attributes(node))
147 |         elif node.op_type == "Upsample":
148 |             op = Upsample(**extract_attributes(node))
149 |         elif node.op_type == "OneHot":
150 |             op = OneHot(**extract_attributes(node))
151 |         elif node.op_type == "Pad":
152 |             op = Pad(**extract_attributes(node))
153 |         elif node.op_type == "Clip":
154 |             op = Clamp(**extract_attributes(node))
155 |         elif node.op_type == "Tanh":
156 |             op = torch.tanh
157 |         elif node.op_type == "Erf":
158 |             op = torch.erf
159 |         elif node.op_type == "Log":
160 |             op = torch.log
161 |         elif node.op_type == "Exp":
162 | 
op = torch.exp 163 | elif node.op_type == "LRN": 164 | op = nn.LocalResponseNorm(**extract_attributes(node)) 165 | elif node.op_type == "Dropout": 166 | op = nn.Dropout(p=1.0) 167 | else: 168 | op = getattr(torch, node.op_type.lower(), None) 169 | if op is None: 170 | raise NotImplementedError( 171 | "Conversion not implemented for op_type={}.".format(node.op_type) 172 | ) 173 | else: 174 | print( 175 | "Automatic inference of operator: {}".format(node.op_type.lower()) 176 | ) 177 | 178 | op_name = "{}_{}".format(node.op_type, node.output[0]) 179 | op_id = node.output[0] 180 | yield op_id, op_name, op 181 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/__init__.py: -------------------------------------------------------------------------------- 1 | from .add import Add 2 | from .cast import Cast 3 | from .constant import ConstantOfShape 4 | from .flatten import Flatten 5 | from .gather import Gather 6 | from .pad import Pad 7 | from .pooling import GlobalAveragePool 8 | from .reshape import Reshape 9 | from .shape import Shape 10 | from .slice import Slice 11 | from .split import Split 12 | from .squeeze import Squeeze 13 | from .resize import Resize, Upsample 14 | from .mul import Mul 15 | from .concat import Concat 16 | from .where import Where 17 | from .matmul import Matmul 18 | from .clamp import Clamp 19 | 20 | __all__ = [ 21 | "Add", 22 | "Cast", 23 | "ConstantOfShape", 24 | "Flatten", 25 | "Gather", 26 | "Pad", 27 | "GlobalAveragePool", 28 | "Reshape", 29 | "Shape", 30 | "Slice", 31 | "Split", 32 | "Squeeze", 33 | "Resize", 34 | "Upsample", 35 | "Mul", 36 | "Concat", 37 | "Where", 38 | "Matmul", 39 | 'Clamp' 40 | ] 41 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/add.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Add(nn.Module): 6 | def 
__init__(self): 7 | super().__init__() 8 | 9 | def forward(self, input1, input2): 10 | return torch.add(input1, input2) 11 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/cast.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Cast(nn.Module): 6 | def __init__(self, dtype): 7 | if isinstance(dtype, str): 8 | dtype = getattr(torch, dtype.lower()) 9 | self.dtype = dtype 10 | super().__init__() 11 | 12 | def forward(self, input: torch.Tensor): 13 | return input.to(self.dtype) 14 | 15 | def extra_repr(self) -> str: 16 | return "dtype={}".format(self.dtype) 17 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/clamp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Clamp(nn.Module): 6 | def __init__(self, min, max): 7 | self.min = min 8 | self.max = max 9 | super().__init__() 10 | 11 | def forward(self, input): 12 | return torch.clamp(input, self.min, self.max) 13 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/concat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Concat(nn.Module): 6 | def __init__(self, dim=1): 7 | super().__init__() 8 | self.axis = dim 9 | 10 | def forward(self, *input): 11 | return torch.cat(input, axis=self.axis) 12 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/constant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ConstantOfShape(nn.Module): 6 | def __init__(self, constant): 7 | super().__init__() 8 | self.constant = 
torch.from_numpy(constant) 9 | 10 | def forward(self, shape: torch.Tensor): 11 | return self.constant * torch.ones(*shape) 12 | 13 | def extra_repr(self) -> str: 14 | return "constant={}".format(self.constant) 15 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/flatten.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Flatten(nn.Module): 6 | def __init__(self, start_dim=1, end_dim=-1): 7 | super().__init__() 8 | self.start_dim = start_dim 9 | self.end_dim = end_dim 10 | 11 | def forward(self, input: torch.Tensor): 12 | return torch.flatten(input, start_dim=self.start_dim, end_dim=self.end_dim) 13 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Gather(nn.Module): 6 | def __init__(self, dim=0): 7 | self.dim = dim 8 | super().__init__() 9 | 10 | def forward(self, input: torch.Tensor, index: torch.Tensor): 11 | return torch.gather(input, self.dim, index) 12 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/matmul.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class Matmul(nn.Module): 5 | def __init__(self): 6 | super().__init__() 7 | 8 | def forward(self, input1: torch.Tensor, input2: torch.Tensor): 9 | return torch.matmul(input1, input2) 10 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/mul.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Mul(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def 
forward(self, input1: torch.Tensor, input2: torch.Tensor): 10 | return torch.mul(input1, input2) 11 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/pad.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch import nn 3 | 4 | class Pad(nn.Module): 5 | def __init__(self, mode="constant", padding=None): 6 | self.mode = mode 7 | self.padding = padding 8 | super().__init__() 9 | 10 | def forward(self, input, pads=None, value=0): 11 | if self.padding is not None: 12 | pads = self.padding 13 | elif pads is None: 14 | raise TypeError("pad forward() missing 1 required positional argument: 'pads'") 15 | return F.pad(input, list(pads), mode=self.mode, value=value) 16 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class GlobalAveragePool(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 9 | def forward(self, input: torch.Tensor): 10 | return self.avgpool(input) 11 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/reshape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Reshape(nn.Module): 6 | def __init__(self, shape=None): 7 | super().__init__() 8 | self.shape = shape 9 | self.initial_input_shape = None 10 | def forward(self, input: torch.Tensor, shape=None): 11 | shape = shape if shape is not None else self.shape 12 | shape = [x if x != 0 else input.size(i) for i, x in enumerate(shape)] 13 | inp_shape = torch.tensor(input.shape) 14 | if self.initial_input_shape is None: 15 | self.initial_input_shape = inp_shape 16 | elif 
len(shape) == 2 and shape[-1] == -1: 17 | pass 18 | elif torch.equal(self.initial_input_shape, inp_shape): 19 | pass 20 | 21 | return torch.reshape(input, tuple(shape)) 22 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/resize.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | empty_tensor = torch.Tensor([]) 7 | 8 | 9 | class Resize(nn.Module): 10 | def __init__(self, mode="nearest", align_corners=None, **kwargs): 11 | self.mode = mode 12 | self.align_corners = align_corners 13 | for key in kwargs.keys(): 14 | warnings.warn( 15 | "Pytorch's interpolate uses no {}. " "Result might differ.".format(key) 16 | ) 17 | super().__init__() 18 | 19 | def forward(self, inp, roi=empty_tensor, scales=empty_tensor, sizes=empty_tensor): 20 | if roi.nelement() > 0: 21 | warnings.warn("Pytorch's interpolate uses no roi. Result might differ.") 22 | 23 | scales = list(scales) 24 | sizes = list(sizes) 25 | shape = list(inp.shape) 26 | if shape[:2] == sizes[:2]: 27 | sizes = sizes[2:] # Pytorch's interpolate takes only H and W params 28 | elif scales[:2] == [1, 1]: 29 | scales = scales[2:] 30 | elif len(scales) == 0 and len(sizes) == 0: 31 | raise ValueError("One of the two, scales or sizes, needs to be defined.") 32 | else: 33 | raise NotImplementedError( 34 | "Pytorch's interpolate does not scale batch and channel dimensions." 35 | ) 36 | 37 | if len(scales) == 0: 38 | scales = None 39 | elif len(sizes) == 0: 40 | sizes = None 41 | else: 42 | raise ValueError( 43 | "Only one of the two, scales or sizes, needs to be defined." 
44 | ) 45 | 46 | return F.interpolate( 47 | inp, 48 | scale_factor=2, 49 | size=sizes, 50 | mode=self.mode, 51 | align_corners=self.align_corners, 52 | ) 53 | 54 | 55 | class Upsample(Resize): 56 | """Deprecated onnx operator.""" 57 | 58 | def forward(self, inp, scales): 59 | return super().forward(inp, torch.tensor([]), scales, torch.tensor([])) 60 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/shape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class Shape(nn.Module): 5 | def __init__(self): 6 | super().__init__() 7 | 8 | def forward(self, input: torch.Tensor): 9 | return torch.tensor(input.shape) 10 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/slice.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Slice(nn.Module): 6 | def __init__(self, dim=None, starts=None, ends=None, steps=None): 7 | self.dim = [dim] if isinstance(dim, int) else dim 8 | self.starts = starts 9 | self.ends = ends 10 | self.steps = steps 11 | super().__init__() 12 | 13 | def forward( 14 | self, input: torch.Tensor, starts=None, ends=None, axes=None, steps=None 15 | ): 16 | if axes is None: 17 | axes = self.dim 18 | if starts is None: 19 | starts = self.starts 20 | if ends is None: 21 | ends = self.ends 22 | if steps is None: 23 | steps = self.steps 24 | 25 | # If axes=None set them to (0, 1, 2, ...) 
26 | if axes is None: 27 | axes = tuple(range(len(starts))) 28 | if steps is None: 29 | steps = tuple(1 for _ in axes) 30 | 31 | selection = [slice(None) for _ in range(max(axes) + 1)] 32 | for i, axis in enumerate(axes): 33 | selection[axis] = slice(starts[i], ends[i], steps[i]) 34 | return input.__getitem__(selection) 35 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/split.py: -------------------------------------------------------------------------------- 1 | from itertools import accumulate 2 | import torch 3 | from torch import nn 4 | from onnx2pytorch.utils import assign_values_to_dim 5 | 6 | 7 | class Split(nn.Module): 8 | def __init__( 9 | self, split_size_or_sections=None, number_of_splits=None, dim=0, keep_size=True 10 | ): 11 | """ 12 | Parameters 13 | ---------- 14 | split_size_or_sections: tuple[int] 15 | number_of_splits: int 16 | The number of equal splits along dim. 17 | dim: int 18 | Split dimension. Tensor is split over this axis. 19 | keep_size: bool 20 | If True it keeps the size of the split the same as in initial pass. 21 | Else it splits it accordingly to the pruned input. 22 | """ 23 | assert ( 24 | split_size_or_sections is not None or number_of_splits is not None 25 | ), "One of the parameters needs to be set." 
26 | self.dim = dim 27 | self.split_size_or_sections = split_size_or_sections 28 | self.number_of_splits = number_of_splits 29 | self.keep_size = keep_size 30 | self.input_indices = None 31 | self.placeholder = None 32 | super().__init__() 33 | 34 | def _get_sections(self, input): 35 | """Calculate sections from number of splits.""" 36 | dim_size = input[0].shape[self.dim] 37 | assert ( 38 | dim_size % self.number_of_splits == 0 39 | ), "Dimension size {} not equally divisible by {}.".format( 40 | dim_size, self.number_of_splits 41 | ) 42 | s = dim_size // self.number_of_splits 43 | sections = tuple(s for _ in range(self.number_of_splits)) 44 | return sections 45 | 46 | def forward(self, *input): 47 | if self.split_size_or_sections is None: 48 | self.split_size_or_sections = self._get_sections(input) 49 | 50 | if self.input_indices is not None: 51 | self.placeholder *= 0 52 | assign_values_to_dim( 53 | self.placeholder, input[0], self.input_indices, self.dim 54 | ) 55 | split = torch.split(self.placeholder, self.split_size_or_sections, self.dim) 56 | else: 57 | split = torch.split(*input, self.split_size_or_sections, dim=self.dim) 58 | return split 59 | 60 | def set_input_indices(self, input: tuple): 61 | assert isinstance(input, (tuple, list)) 62 | 63 | inp = input[0] 64 | # We assume that aggregation dimensions correspond to split dimension 65 | axis = self.get_axis(inp.shape, self.dim) 66 | 67 | # Mask shows where features are non zero in the whole axis. 
68 | mask = inp != 0 69 | if len(inp.shape) > 1: 70 | mask = mask.sum(axis=tuple(axis)) != 0 71 | 72 | if not self.keep_size: 73 | # Read docstrings 74 | if isinstance(self.split_size_or_sections, tuple): 75 | indices = list(accumulate(self.split_size_or_sections)) 76 | indices = torch.tensor(indices) - 1 77 | else: 78 | raise NotImplementedError("Not implemented for split size.") 79 | cs = torch.cumsum(mask, 0) 80 | ind = [0] + cs[indices].tolist() 81 | sec = [ind[i + 1] - ind[i] for i in range(len(ind) - 1)] 82 | self.split_size_or_sections = sec 83 | else: 84 | (self.input_indices,) = torch.where(mask) 85 | self.placeholder = torch.zeros(inp.shape) 86 | 87 | def __str__(self): 88 | return "Split(dim={})".format(self.dim) 89 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/squeeze.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from onnx2pytorch.utils import get_selection 4 | 5 | 6 | class Squeeze(nn.Module): 7 | def __init__(self, dim=None): 8 | self.dim = dim 9 | super().__init__() 10 | 11 | def forward(self, input): 12 | if self.dim is None: 13 | return torch.squeeze(input) 14 | elif isinstance(self.dim, int): 15 | return torch.squeeze(input, dim=self.dim) 16 | else: 17 | for dim in sorted(self.dim, reverse=True): 18 | input = torch.squeeze(input, dim=dim) 19 | return input 20 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/where.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Where(nn.Module): 6 | def forward(self, *input): 7 | return torch.where(input[0], input[1], input[2]) 8 | -------------------------------------------------------------------------------- /onnx2pytorch/utils.py: -------------------------------------------------------------------------------- 1 | 
import io 2 | 3 | import torch 4 | import numpy as np 5 | import onnx 6 | 7 | try: 8 | import onnxruntime as ort 9 | except ImportError: 10 | ort = None 11 | 12 | 13 | def value_wrapper(value): 14 | def callback(*args, **kwargs): 15 | return value 16 | 17 | return callback 18 | 19 | def is_constant(value): 20 | return value.ndim == 0 or value.shape == torch.Size([1]) 21 | 22 | 23 | def is_symmetric(params): 24 | """ 25 | Check if parameters are symmetric, all values [2,2,2,2]. 26 | Then we can use only [2,2]. 27 | """ 28 | assert len(params) // 2 == len(params) / 2, "Non even number of parameters." 29 | idx = len(params) // 2 30 | for i in range(0, idx): 31 | if params[i] != params[idx + i]: 32 | return False 33 | return True 34 | 35 | # Extract padding parameters for ops such as ConstantPad2D 36 | def extract_padding_params(params): 37 | """Extract padding parameters for Pad layers.""" 38 | pad_dim = len(params) // 2 39 | pads = np.array(params).reshape(-1, pad_dim).T.flatten() # .tolist() 40 | 41 | # Some padding modes do not support padding in batch and channel dimension. 42 | # If batch and channel dimension have no padding, discard. 43 | if (pads[:4] == 0).all(): 44 | pads = pads[4:] 45 | pads = pads.tolist() 46 | # Reverse, because for pytorch first two numbers correspond to last dimension, etc. 47 | pads.reverse() 48 | return pads 49 | 50 | # Extract padding parameters for convolution ops 51 | # >>> import torch 52 | # >>> x = [1, 2,3, 4] 53 | # >>> import numpy as np 54 | # >>> y = np.array(x).reshape(-1,2).T.flatten() 55 | # >>> print(y) 56 | # [1 3 2 4] 57 | # >>> print(y[:4]) 58 | # [1 3 2 4] 59 | # >>> print(y[4:]) 60 | # [] 61 | 62 | def extract_padding_params_for_conv_layer(params): 63 | """ 64 | Padding params in onnx are different than in pytorch. That is why we need to 65 | check if they are symmetric and cut half or return a padding layer. 
66 | """ 67 | # Check whether the parameters are symmetric 68 | if is_symmetric(params): 69 | return params[: len(params) // 2] 70 | else: 71 | pads = extract_padding_params(params)[::-1] 72 | return pads 73 | 74 | 75 | def get_selection(indices, dim): 76 | """ 77 | Give selection to assign values to specific indices at given dimension. 78 | Enables dimension to be dynamic: 79 | tensor[get_selection(indices, dim=2)] = values 80 | Alternatively the dimension is fixed in code syntax: 81 | tensor[:, :, indices] = values 82 | """ 83 | assert dim >= 0, "Negative dimension not supported." 84 | # Behaviour with python lists is unfortunately not working the same. 85 | if isinstance(indices, list): 86 | indices = torch.tensor(indices) 87 | assert isinstance(indices, (torch.Tensor, np.ndarray)) 88 | selection = [slice(None) for _ in range(dim + 1)] 89 | selection[dim] = indices 90 | return selection 91 | 92 | 93 | def assign_values_to_dim(tensor, values, indices, dim, inplace=True): 94 | """ 95 | Inplace tensor operation that assigns values to corresponding indices 96 | at given dimension. 97 | """ 98 | if dim < 0: 99 | dim = dim + len(tensor.shape) 100 | selection = get_selection(indices, dim) 101 | if not inplace: 102 | tensor = tensor.clone() 103 | tensor[selection] = values 104 | return tensor 105 | 106 | 107 | def get_type(x): 108 | """ 109 | Extract type from onnxruntime input. 110 | 111 | Parameters 112 | ---------- 113 | x: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg 114 | """ 115 | if x.type.startswith("tensor"): 116 | typ = x.type[7:-1] 117 | else: 118 | raise NotImplementedError("For type: {}".format(x.type)) 119 | 120 | if typ == "float": 121 | typ = "float32" 122 | elif typ == "double": 123 | typ = "float64" 124 | return typ 125 | 126 | 127 | def get_shape(x, unknown_dim_size=1): 128 | """ 129 | Extract shape from onnxruntime input. 130 | Replace unknown dimension by default with 1. 
131 | 132 | Parameters 133 | ---------- 134 | x: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg 135 | unknown_dim_size: int 136 | Default: 1 137 | """ 138 | shape = x.shape 139 | # replace unknown dimensions by default with 1 140 | shape = [i if isinstance(i, int) else unknown_dim_size for i in shape] 141 | return shape 142 | 143 | 144 | def get_activation_value(onnx_model, inputs, activation_names): 145 | """ 146 | Get activation value from an onnx model. 147 | 148 | Parameters 149 | ---------- 150 | onnx_model: onnx.ModelProto 151 | inputs: list[np.ndarray] 152 | activation_names: list[str] 153 | Can be retrieved from onnx node: list(node.output) 154 | 155 | Returns 156 | ------- 157 | value: list[np.ndarray] 158 | Value of the activation with activation_name. 159 | """ 160 | assert ort is not None, "onnxruntime needed. pip install onnxruntime" 161 | assert all(isinstance(x, np.ndarray) for x in inputs) 162 | 163 | if not isinstance(activation_names, (list, tuple)): 164 | activation_names = [activation_names] 165 | 166 | # clear output 167 | while len(onnx_model.graph.output): 168 | onnx_model.graph.output.pop() 169 | 170 | for activation_name in activation_names: 171 | activation_value = onnx.helper.ValueInfoProto() 172 | activation_value.name = activation_name 173 | onnx_model.graph.output.append(activation_value) 174 | 175 | buffer = io.BytesIO() 176 | onnx.save(onnx_model, buffer) 177 | buffer.seek(0) 178 | onnx_model_new = onnx.load(buffer) 179 | sess = ort.InferenceSession(onnx_model_new.SerializeToString()) 180 | 181 | input_names = [x.name for x in sess.get_inputs()] 182 | if not isinstance(inputs, list): 183 | inputs = [inputs] 184 | inputs = dict(zip(input_names, inputs)) 185 | 186 | return sess.run(None, inputs) 187 | 188 | # Get the names of all input nodes of the network 189 | def get_inputs_names(onnx_model): 190 | param_names = set([x.name for x in onnx_model.graph.initializer]) 191 | input_names = [x.name for x in onnx_model.graph.input] 192 | input_names = [x for x in 
input_names if x not in param_names] 193 | return input_names 194 | 195 | 196 | def get_inputs_sample(onnx_model, to_torch=False): 197 | """Get inputs sample from onnx model.""" 198 | assert ort is not None, "onnxruntime needed. pip install onnxruntime" 199 | 200 | sess = ort.InferenceSession(onnx_model.SerializeToString()) 201 | inputs = sess.get_inputs() 202 | input_names = get_inputs_names(onnx_model) 203 | input_tensors = [ 204 | np.abs(np.random.rand(*get_shape(x)).astype(get_type(x))) for x in inputs 205 | ] 206 | if to_torch: 207 | input_tensors = [torch.from_numpy(x) for x in input_tensors] 208 | return dict(zip(input_names, input_tensors)) 209 | -------------------------------------------------------------------------------- /onnxsim/__init__.py: -------------------------------------------------------------------------------- 1 | from onnxsim.onnx_simplifier import simplify 2 | 3 | __version__ = '0.0.0' 4 | -------------------------------------------------------------------------------- /onnxsim/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | import onnx # type: ignore 5 | import onnxsim 6 | import numpy as np 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('input_model', help='Input ONNX model') 12 | parser.add_argument('output_model', help='Output ONNX model') 13 | parser.add_argument('check_n', help='Check whether the output is correct with n random inputs', 14 | nargs='?', type=int, default=3) 15 | parser.add_argument('--enable-fuse-bn', help='This option is deprecated. 
Fusing bn into conv is enabled by default.', 16 | action='store_true') 17 | parser.add_argument('--skip-fuse-bn', help='Skip fusing batchnorm into conv.', 18 | action='store_true') 19 | parser.add_argument('--skip-optimization', help='Skip optimization of ONNX optimizers.', 20 | action='store_true') 21 | parser.add_argument( 22 | '--input-shape', help='The manually-set static input shape, useful when the input shape is dynamic. The value should be "input_name:dim0,dim1,...,dimN" or simply "dim0,dim1,...,dimN" when there is only one input, for example, "data:1,3,224,224" or "1,3,224,224". Note: you might want to use some visualization tools like netron to make sure what the input name and dimension ordering (NCHW or NHWC) is.', type=str, nargs='+') 23 | parser.add_argument( 24 | '--skip-optimizer', help='Skip a certain ONNX optimizer', type=str, nargs='+') 25 | parser.add_argument('--skip-shape-inference', 26 | help='Skip shape inference. Shape inference causes segfault on some large models', action='store_true') 27 | parser.add_argument('--dynamic-input-shape', help='This option enables dynamic input shape support. "Shape" ops will not be eliminated in this case. Note that "--input-shape" is also needed for generating random inputs and checking equality. 
If "dynamic_input_shape" is False, the input shape in simplified model will be overwritten by the value of "input_shapes" param.', action='store_true') 28 | parser.add_argument( 29 | '--input-data-path', help='input data, The value should be "input_name1:xxx1.bin" "input_name2:xxx2.bin ...", input data should be a binary data file.', type=str, nargs='+') 30 | parser.add_argument( 31 | '--custom-lib', help="custom lib path which should be absolute path, if you have custom onnxruntime backend you should use this to register you custom op", type=str) 32 | 33 | args = parser.parse_args() 34 | 35 | print("Simplifying...") 36 | 37 | if args.dynamic_input_shape and args.input_shape is None: 38 | raise RuntimeError( 39 | 'Please pass "--input-shape" argument for generating random input and checking equality. Run "python3 -m onnxsim -h" for details.') 40 | if args.input_shape is not None and not args.dynamic_input_shape: 41 | print("Note: The input shape of the simplified model will be overwritten by the value of '--input-shape' argument. Pass '--dynamic-input-shape' if it is not what you want. 
Run 'python3 -m onnxsim -h' for details.") 42 | input_shapes = dict() 43 | if args.input_shape is not None: 44 | for x in args.input_shape: 45 | if ':' not in x: 46 | input_shapes[None] = list(map(int, x.split(','))) 47 | else: 48 | pieces = x.split(':') 49 | # for the input name like input:0 50 | name, shape = ':'.join( 51 | pieces[:-1]), list(map(int, pieces[-1].split(','))) 52 | input_shapes.update({name: shape}) 53 | 54 | input_data_paths = dict() 55 | if args.input_data_path is not None: 56 | for x in args.input_data_path: 57 | pieces = x.split(':') 58 | name, data = ':'.join(pieces[:-1]), pieces[-1] 59 | input_data_paths.update({name: data}) 60 | 61 | input_tensors = dict() 62 | if len(input_data_paths) > 0 and args.input_shape is not None: 63 | for name in input_shapes.keys(): 64 | input_data = np.fromfile(input_data_paths[name], dtype=np.float32) 65 | input_data = input_data.reshape(input_shapes[name]) 66 | input_tensors.update({name: input_data}) 67 | 68 | model_opt, check_ok = onnxsim.simplify( 69 | args.input_model, 70 | check_n=args.check_n, 71 | perform_optimization=not args.skip_optimization, 72 | skip_fuse_bn=args.skip_fuse_bn, 73 | input_shapes=input_shapes, 74 | skipped_optimizers=args.skip_optimizer, 75 | skip_shape_inference=args.skip_shape_inference, 76 | input_data=input_tensors, 77 | dynamic_input_shape=args.dynamic_input_shape, 78 | custom_lib=args.custom_lib) 79 | 80 | onnx.save(model_opt, args.output_model) 81 | 82 | if check_ok: 83 | print("Ok!") 84 | else: 85 | print("Check failed, please be careful to use the simplified model, or try specifying \"--skip-fuse-bn\" or \"--skip-optimization\" (run \"python3 -m onnxsim -h\" for details)") 86 | sys.exit(1) 87 | 88 | 89 | if __name__ == '__main__': 90 | main() 91 | -------------------------------------------------------------------------------- /onnxsim/onnx_simplifier.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 
| from functools import reduce 3 | 4 | from typing import Callable, List, Dict, Union, Optional, Tuple, Sequence, TypeVar 5 | import copy 6 | 7 | import onnx # type: ignore 8 | import onnx.helper # type: ignore 9 | import onnx.shape_inference # type: ignore 10 | import onnx.numpy_helper # type: ignore 11 | import onnxruntime as rt # type: ignore 12 | import onnxoptimizer # type: ignore 13 | 14 | import numpy as np # type: ignore 15 | import os 16 | import sys 17 | 18 | Tensors = Dict[str, np.ndarray] 19 | TensorShape = List[int] 20 | TensorShapes = Dict[str, TensorShape] 21 | TensorShapesWithOptionalKey = Dict[Optional[str], TensorShape] 22 | 23 | 24 | def add_features_to_output(m: onnx.ModelProto, nodes: List[onnx.NodeProto]) -> None: 25 | """ 26 | Add features to output in pb, so that ONNX Runtime will output them. 27 | :param m: the model that will be run in ONNX Runtime 28 | :param nodes: nodes whose outputs will be added into the graph outputs 29 | """ 30 | for node in nodes: 31 | for output in node.output: 32 | # Extend the graph outputs of the ONNX model so that the outputs of all static (constant) ops are returned along with the original outputs 33 | m.graph.output.extend([onnx.ValueInfoProto(name=output)]) 34 | 35 | 36 | def get_shape_from_value_info_proto(v: onnx.ValueInfoProto) -> List[int]: 37 | return [dim.dim_value for dim in v.type.tensor_type.shape.dim] 38 | 39 | 40 | def get_value_info_all(m: onnx.ModelProto, name: str) -> Optional[onnx.ValueInfoProto]: 41 | for v in m.graph.value_info: 42 | if v.name == name: 43 | return v 44 | 45 | for v in m.graph.input: 46 | if v.name == name: 47 | return v 48 | 49 | for v in m.graph.output: 50 | if v.name == name: 51 | return v 52 | 53 | return None 54 | 55 | 56 | def get_shape(m: onnx.ModelProto, name: str) -> TensorShape: 57 | """ 58 | Note: This method relies on onnx shape inference, which is not reliable. 
So only use it on input or output tensors 59 | """ 60 | v = get_value_info_all(m, name) 61 | if v is not None: 62 | return get_shape_from_value_info_proto(v) 63 | raise RuntimeError('Cannot get shape of "{}"'.format(name)) 64 | 65 | 66 | def get_elem_type(m: onnx.ModelProto, name: str) -> int: 67 | v = get_value_info_all(m, name) 68 | if v is not None: 69 | return v.type.tensor_type.elem_type 70 | raise RuntimeError('Cannot get dtype of "{}"'.format(name)) 71 | 72 | 73 | def get_np_type_from_elem_type(elem_type: int) -> np.dtype: 74 | sizes = (None, np.float32, np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64, str, np.bool_, 75 | np.float16, np.double, np.uint32, np.uint64, np.complex64, np.complex128, np.float16) 76 | assert len(sizes) == 17 77 | size = sizes[elem_type] 78 | assert size is not None 79 | return size 80 | 81 | 82 | def get_inputs(model: onnx.ModelProto) -> List[onnx.ValueInfoProto]: 83 | initializer_names = [x.name for x in model.graph.initializer] 84 | return [ipt for ipt in model.graph.input if ipt.name not in initializer_names] 85 | 86 | 87 | def get_input_names(model: onnx.ModelProto) -> List[str]: 88 | input_names = [ipt.name for ipt in get_inputs(model)] 89 | return input_names 90 | 91 | 92 | def generate_specific_rand_input(model, input_shapes: TensorShapes): 93 | """ 94 | Only generate rand inputs whose shape in `input_shapes` 95 | """ 96 | 97 | for key, shape in input_shapes.items(): 98 | if not np.all(np.array(shape) > 0): 99 | raise RuntimeError( 100 | 'The shape of input "{}" has dynamic size "{}", ' 101 | 'please determine the input size manually by ' 102 | '"--dynamic-input-shape --input-shape xxx" or "--input-shape xxx". 
' 103 | 'Run "python3 -m onnxsim -h" for details'.format(key, shape)) 104 | 105 | inputs = {ipt: np.array(np.random.rand(*input_shapes[ipt]), 106 | dtype=get_np_type_from_elem_type(get_elem_type(model, ipt))) for ipt in 107 | input_shapes} 108 | return inputs 109 | 110 | 111 | def generate_all_rand_input(model, input_shapes: Optional[TensorShapes] = None): 112 | """ 113 | Generate random array for all inputs of a model 114 | """ 115 | if input_shapes is None: 116 | input_shapes = {} 117 | input_names = get_input_names(model) 118 | full_input_shapes = {ipt: get_shape(model, ipt) for ipt in input_names} 119 | assert None not in input_shapes 120 | full_input_shapes.update(input_shapes) # type: ignore 121 | return generate_specific_rand_input(model, full_input_shapes) 122 | 123 | 124 | def is_non_deterministic_node(node: onnx.NodeProto) -> bool: 125 | # TODO: handle node with subgraph 126 | return node.op_type in ['RandomNormal', 'RandomNormalLike', 'RandomUniform', 'RandomUniformLike'] 127 | 128 | 129 | def is_non_deterministic_model(model: onnx.ModelProto) -> bool: 130 | return any([is_non_deterministic_node(node) for node in model.graph.node]) 131 | 132 | 133 | def get_constant_nodes(m: onnx.ModelProto, dynamic_input_shape: bool = False) -> List[onnx.NodeProto]: 134 | const_nodes = [] 135 | # A tensor whose name appears in the initializer list of the ONNX GraphProto is a constant tensor 136 | const_tensors = [x.name for x in m.graph.initializer] 137 | # The outputs of explicit Constant OPs are added as well 138 | const_tensors.extend([node.output[0] 139 | for node in m.graph.node if node.op_type == 'Constant']) 140 | 141 | # The output shape of some nodes is determined by the input nodes; such an output shape is not treated as constant, 142 | # so these nodes do not need to be simplified 143 | dynamic_tensors = [] 144 | if dynamic_input_shape: 145 | dynamic_tensors.extend(get_input_names(m)) 146 | # Check whether a node is a dynamic OP 147 | def is_dynamic(node): 148 | if node.op_type in ['NonMaxSuppression', 'NonZero', 'Unique'] and node.input[0] not in const_tensors: 149 | return True 150 | if node.op_type in ['Reshape', 'Expand', 'Upsample', 'ConstantOfShape'] and 
len(node.input) > 1 and node.input[1] not in const_tensors: 151 | return True 152 | if node.op_type in ['Resize'] and ((len(node.input) > 2 and node.input[2] not in const_tensors) or (len(node.input) > 3 and node.input[3] not in const_tensors)): 153 | return True 154 | return False 155 | 156 | def has_subgraph_in_node(node: onnx.NodeProto): 157 | for attr in node.attribute: 158 | if attr.type in [onnx.AttributeProto.GRAPH, onnx.AttributeProto.GRAPHS]: 159 | return True 160 | return False 161 | 162 | for node in m.graph.node: 163 | if any(x in dynamic_tensors for x in node.input): 164 | dynamic_tensors.extend(node.output) 165 | # Note "elif" here, only Shape op with non-dynamic input will be seen as const node 166 | elif node.op_type == 'Shape': 167 | const_nodes.append(node) 168 | const_tensors.extend(node.output) 169 | elif is_dynamic(node): 170 | dynamic_tensors.extend(node.output) 171 | elif has_subgraph_in_node(node): 172 | # Skip this node if this node has subgraph in it 173 | # "If" node with const cond will be eliminated by onnxoptimizer 174 | pass 175 | elif all([x in const_tensors for x in node.input]) and not is_non_deterministic_node(node): 176 | const_nodes.append(node) 177 | const_tensors.extend(node.output) 178 | # Return a deep copy 179 | return copy.deepcopy(const_nodes) 180 | 181 | 182 | def forward(model, 183 | input_data: Optional[Tensors] = None, 184 | input_shapes: Optional[TensorShapes] = None, 185 | custom_lib: Optional[str] = None) -> Tensors: 186 | if input_shapes is None: 187 | input_shapes = {} 188 | sess_options = rt.SessionOptions() 189 | if custom_lib is not None: 190 | if os.path.exists(custom_lib): 191 | sess_options.register_custom_ops_library(custom_lib) 192 | else: 193 | print("No such file '{}'".format(custom_lib), file=sys.stderr) 194 | sys.exit(1) 195 | sess_options.graph_optimization_level = rt.GraphOptimizationLevel(0) 196 | sess_options.log_severity_level = 3 197 | sess = rt.InferenceSession(model.SerializeToString( 198 | ), 
sess_options=sess_options, providers=['CPUExecutionProvider']) 199 | 200 | input_names = get_input_names(model) 201 | inputs = {} 202 | for name in input_names: 203 | if input_data is not None and input_data.get(name, None) is not None: 204 | inputs[name] = input_data[name] 205 | else: 206 | if input_shapes is not None and input_shapes.get(name, None) is not None: 207 | shape = input_shapes[name] 208 | else: 209 | shape = get_shape(model, name) 210 | inputs.update(generate_specific_rand_input(model, {name: shape})) 211 | 212 | outputs = [x.name for x in sess.get_outputs()] 213 | run_options = rt.RunOptions() 214 | run_options.log_severity_level = 3 215 | res = OrderedDict(zip(outputs, sess.run( 216 | outputs, inputs, run_options=run_options))) 217 | return res 218 | 219 | 220 | def forward_for_node_outputs(model: onnx.ModelProto, 221 | nodes: List[onnx.NodeProto], 222 | input_shapes: Optional[TensorShapes] = None, 223 | input_data: Optional[Tensors] = None, 224 | custom_lib: Optional[str] = None) -> Tensors: 225 | if input_shapes is None: 226 | input_shapes = {} 227 | model = copy.deepcopy(model) 228 | # nodes are all the constant OPs in the graph 229 | add_features_to_output(model, nodes) 230 | res = forward(model, 231 | input_data=input_data, 232 | input_shapes=input_shapes, 233 | custom_lib=custom_lib) 234 | return res 235 | 236 | 237 | def insert_elem(repeated_container, index: int, element): 238 | repeated_container.extend([repeated_container[-1]]) 239 | for i in reversed(range(index + 1, len(repeated_container) - 1)): 240 | repeated_container[i].CopyFrom(repeated_container[i - 1]) 241 | repeated_container[index].CopyFrom(element) 242 | 243 | 244 | def eliminate_const_nodes(model: onnx.ModelProto, const_nodes: List[onnx.NodeProto], 245 | res: Tensors) -> onnx.ModelProto: 246 | """ 247 | :param model: the original ONNX model 248 | :param const_nodes: constant OPs obtained from `get_constant_nodes` 249 | :param res: the dict containing all output tensors 250 | :return: the simplified model, with all redundant operations removed. 
251 | """ 252 | for i, node in enumerate(model.graph.node): 253 | if node in const_nodes: 254 | for output in node.output: 255 | new_node = copy.deepcopy(node) 256 | new_node.name = "node_" + output 257 | new_node.op_type = 'Constant' 258 | new_attr = onnx.helper.make_attribute( 259 | 'value', 260 | onnx.numpy_helper.from_array(res[output], name=output) 261 | ) 262 | del new_node.input[:] 263 | del new_node.attribute[:] 264 | del new_node.output[:] 265 | new_node.output.extend([output]) 266 | new_node.attribute.extend([new_attr]) 267 | insert_elem(model.graph.node, i + 1, new_node) 268 | del model.graph.node[i] 269 | 270 | return model 271 | 272 | 273 | def optimize(model: onnx.ModelProto, skip_fuse_bn: bool, skipped_optimizers: Optional[Sequence[str]]) -> onnx.ModelProto: 274 | """ 275 | :param model: the ONNX model to optimize. 276 | :return: the optimized ONNX model. 277 | Before simplifying, use this method to generate the ValueInfo that 'forward_all' will use. 278 | After simplifying, use this method to fold the constants produced in the previous step into initializers and eliminate unused constants. 279 | """ 280 | 281 | onnx.checker.check_model(model) 282 | onnx.helper.strip_doc_string(model) 283 | optimizers_list = onnxoptimizer.get_fuse_and_elimination_passes() 284 | if not skip_fuse_bn: 285 | optimizers_list.append('fuse_bn_into_conv') 286 | if skipped_optimizers is not None: 287 | for opt in skipped_optimizers: 288 | try: 289 | optimizers_list.remove(opt) 290 | except ValueError: 291 | pass 292 | 293 | model = onnxoptimizer.optimize(model, optimizers_list, 294 | fixed_point=True) 295 | onnx.checker.check_model(model) 296 | return model 297 | 298 | 299 | def check(model_opt: onnx.ModelProto, model_ori: onnx.ModelProto, n_times: int = 5, 300 | input_shapes: Optional[TensorShapes] = None) -> bool: 301 | """ 302 | Warning: Some models (e.g., MobileNet) may fail this check by a small magnitude. 303 | Just ignore if it happens. 
304 | :param input_shapes: Shapes of generated random inputs 305 | :param model_opt: The simplified ONNX model 306 | :param model_ori: The original ONNX model 307 | :param n_times: Generate n random inputs 308 | """ 309 | if input_shapes is None: 310 | input_shapes = {} 311 | onnx.checker.check_model(model_opt) 312 | 313 | if is_non_deterministic_model(model_ori) and n_times > 0: 314 | print("The model has random ops like RandomNormal. Skip checking..") 315 | n_times = 0 316 | 317 | for i in range(n_times): 318 | print("Checking {}/{}...".format(i, n_times)) 319 | rand_input = generate_all_rand_input( 320 | model_opt, input_shapes=input_shapes) 321 | res_opt = forward(model_opt, input_data=rand_input) 322 | res_ori = forward(model_ori, input_data=rand_input) 323 | 324 | for name in res_opt.keys(): 325 | if not np.allclose(res_opt[name], res_ori[name], rtol=1e-4, atol=1e-5): 326 | print("Tensor {} changes after simplifying. The max diff is {}.".format( 327 | name, np.max(np.abs(res_opt[name] - res_ori[name])))) 328 | print("Note that the checking is not always correct.") 329 | print("After simplifying:") 330 | print(res_opt[name]) 331 | print("Before simplifying:") 332 | print(res_ori[name]) 333 | print("----------------") 334 | return False 335 | return True 336 | 337 | 338 | def clean_constant_nodes(const_nodes: List[onnx.NodeProto], res: Tensors): 339 | """ 340 | It seems not needed since commit 6f2a72, but maybe it still prevents some unknown bug 341 | :param const_nodes: const nodes detected by `get_constant_nodes` 342 | :param res: The dict containing all tensors, got by `forward_all` 343 | :return: The constant nodes which have an output in res 344 | """ 345 | return [node for node in const_nodes if node.output[0] in res] 346 | 347 | 348 | def check_and_update_input_shapes(model: onnx.ModelProto, input_shapes: TensorShapesWithOptionalKey) -> TensorShapes: 349 | input_names = get_input_names(model) 350 | if None in input_shapes: 351 | if len(input_names) == 1: 
352 | input_shapes[input_names[0]] = input_shapes[None] 353 | del input_shapes[None] 354 | else: 355 | raise RuntimeError( 356 | 'The model has more than one input, please use the format "input_name:dim0,dim1,...,dimN" in --input-shape') 357 | for x in input_shapes: 358 | if x not in input_names: 359 | raise RuntimeError( 360 | 'The model doesn\'t have input named "{}"'.format(x)) 361 | return input_shapes # type: ignore 362 | 363 | 364 | def infer_shapes(model: onnx.ModelProto) -> onnx.ModelProto: 365 | try: 366 | model = onnx.shape_inference.infer_shapes(model) 367 | except Exception: 368 | pass 369 | return model 370 | 371 | 372 | T = TypeVar('T') 373 | 374 | # Apply func_a and func_b repeatedly until the model stops changing 375 | def fixed_point(x: T, func_a: Callable[[T], T], func_b: Callable[[T], T]) -> T: 376 | """ 377 | Run `func_a` and `func_b` on `x` until func_b(func_a(x)) == x 378 | :param x: 379 | :param func_a: A function satisfying func_a(func_a(x)) == func_a(x) 380 | :param func_b: A function satisfying func_b(func_b(x)) == func_b(x) 381 | :return: the x that satisfies func_b(func_a(x)) == x 382 | """ 383 | x = func_a(x) 384 | x = func_b(x) 385 | while True: 386 | y = func_a(x) 387 | if y == x: 388 | # Since func_b(func_b(x)) == func_b(x), 389 | # we are already at the fixed point if 390 | # `y == x` 391 | return x 392 | x = y 393 | y = func_b(x) 394 | if y == x: 395 | return x 396 | x = y 397 | 398 | 399 | def simplify(model: Union[str, onnx.ModelProto], 400 | check_n: int = 0, 401 | perform_optimization: bool = True, 402 | skip_fuse_bn: bool = False, 403 | input_shapes: Optional[TensorShapesWithOptionalKey] = None, 404 | skipped_optimizers: Optional[Sequence[str]] = None, 405 | skip_shape_inference=False, 406 | input_data: Optional[Tensors] = None, 407 | dynamic_input_shape: bool = False, 408 | custom_lib: Optional[str] = None) -> Tuple[onnx.ModelProto, bool]: 409 | """ 410 | :param model: onnx ModelProto object or file path 411 | :param check_n: The simplified model will be checked for `check_n` 
times by random inputs 412 | :param perform_optimization: Whether to run onnx optimizer on the model 413 | :param skip_fuse_bn: Skip fuse_bn_into_conv onnx optimizer 414 | :param input_shapes: If the model has dynamic input shape, user must pass a fixed input shape 415 | for generating random inputs and checking equality. (Also see "dynamic_input_shape" param) 416 | :param skipped_optimizers: Skip some specific onnx optimizers 417 | :param skip_shape_inference: Skip shape inference (sometimes shape inference will crash) 418 | :param input_data: Feed custom input data for checking if needed 419 | :param dynamic_input_shape: Indicates whether the input shape should be dynamic. Note that 420 | input_shapes is also needed even if dynamic_input_shape is True, 421 | the value of input_shapes will be used when generating random inputs for checking equality. 422 | If 'dynamic_input_shape' is False, the input shape in simplified model will be overwritten 423 | by the value of 'input_shapes' param. 
424 | :param custom_lib: shared library of onnxruntime custom ops 425 | :return: A tuple (simplified model, success(True) or failed(False)) 426 | """ 427 | if input_shapes is None: 428 | input_shapes = {} 429 | if input_data is None: 430 | input_data = {} 431 | 432 | if isinstance(model, str): 433 | # Load the ONNX model 434 | model = onnx.load(model) 435 | assert(isinstance(model, onnx.ModelProto)) 436 | # Check that the ONNX model format is valid, the graph structure is complete, the nodes are correct, etc. 437 | onnx.checker.check_model(model) 438 | # Keep a deep copy of the original ONNX model 439 | model_ori = copy.deepcopy(model) 440 | 441 | 442 | input_names = get_input_names(model) 443 | for input_name, data in input_data.items(): 444 | if input_name not in input_names: 445 | raise RuntimeError( 446 | 'The model doesn\'t have input named "{}"'.format(input_name)) 447 | 448 | shape = list(input_data[input_name].shape) 449 | 450 | # special case for single constant variables (with shape []) 451 | if len(shape) == 0: 452 | shape = [input_data[input_name].size] 453 | if input_name in input_shapes and shape != input_shapes[input_name]: 454 | raise RuntimeError('The shape of input_data[{}] does not match input_shapes[{}]'.format( 455 | input_name, input_name)) 456 | elif input_name not in input_shapes: 457 | input_shapes[input_name] = shape 458 | 459 | # Validate the input shapes against the model inputs 460 | updated_input_shapes = check_and_update_input_shapes(model, input_shapes) 461 | 462 | 463 | def infer_shapes_and_optimize(model: onnx.ModelProto) -> onnx.ModelProto: 464 | # Run shape inference on the nodes of the ONNX model 465 | def infer_shapes_if_applicable(model: onnx.ModelProto) -> onnx.ModelProto: 466 | if not skip_shape_inference: 467 | model = infer_shapes(model) 468 | return model 469 | # Run the optimizer on the ONNX model 470 | def optimize_if_applicable(model: onnx.ModelProto) -> onnx.ModelProto: 471 | if perform_optimization: 472 | model = optimize(model, skip_fuse_bn, skipped_optimizers) 473 | return model 474 | # Apply infer_shapes_if_applicable and optimize_if_applicable repeatedly until the model stops changing 475 | return fixed_point(model, infer_shapes_if_applicable, 
optimize_if_applicable) 476 | 477 | def constant_folding(model: onnx.ModelProto) -> onnx.ModelProto: 478 | # Collect the constant OPs of the model 479 | const_nodes = get_constant_nodes( 480 | model, dynamic_input_shape=dynamic_input_shape) 481 | # Get the output values of all constant OPs and of the original output nodes 482 | res = forward_for_node_outputs(model, 483 | const_nodes, 484 | input_shapes=updated_input_shapes, 485 | input_data=input_data, 486 | custom_lib=custom_lib) 487 | # Drop the constant nodes that onnxruntime did not evaluate 488 | const_nodes = clean_constant_nodes(const_nodes, res) 489 | # Remove the constant OPs to obtain the simplified ONNX model 490 | model = eliminate_const_nodes(model, const_nodes, res) 491 | # Check that the ONNX model format is valid, the graph structure is complete, the nodes are correct, etc. 492 | onnx.checker.check_model(model) 493 | return model 494 | 495 | # Apply infer_shapes_and_optimize and constant_folding repeatedly until the model stops changing 496 | model = fixed_point(model, infer_shapes_and_optimize, constant_folding) 497 | 498 | # Overwrite the input shapes of the model 499 | if not dynamic_input_shape: 500 | for name, input_shape in updated_input_shapes.items(): 501 | for ipt in model.graph.input: 502 | if ipt.name == name: 503 | for i, dim in enumerate(ipt.type.tensor_type.shape.dim): 504 | dim.dim_value = input_shape[i] 505 | # Check the simplified model against the original model 506 | check_ok = check(model, model_ori, check_n, 507 | input_shapes=updated_input_shapes) 508 | 509 | return model, check_ok -------------------------------------------------------------------------------- /tool.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import onnx 3 | import numpy as np 4 | from onnx import helper 5 | from onnx import numpy_helper 6 | 7 | class Tool(object): 8 | # Load the onnx model 9 | def __init__(self, onnx_model_path): 10 | self.model = onnx.load(onnx_model_path) 11 | self.model = onnx.shape_inference.infer_shapes(self.model) 12 | self.inputs = [] 13 | self.outputs = [] 14 | 15 | # Save the onnx model 16 | def export(self, save_path): 17 | onnx.checker.check_model(self.model) 18 | self.model = onnx.shape_inference.infer_shapes(self.model) 19 | onnx.save(self.model, save_path) 20 | 21 | # 
Get the input names of the onnx model, returned as a list 22 | def get_input_names(self): 23 | set_input = set() 24 | set_initializer = set() 25 | for ipt in self.model.graph.input: 26 | set_input.add(ipt.name) 27 | for x in self.model.graph.initializer: 28 | set_initializer.add(x.name) 29 | return list(set_input - set_initializer) 30 | 31 | # Set the batch dimension (dim 0) of an onnx model input 32 | def set_model_input_batch(self, index=0, name=None, batch_size=4): 33 | model_input = None 34 | if name is not None: 35 | for ipt in self.model.graph.input: 36 | if ipt.name == name: 37 | model_input = ipt 38 | else: 39 | model_input = self.model.graph.input[index] 40 | if model_input: 41 | tensor_dim = model_input.type.tensor_type.shape.dim 42 | tensor_dim[0].ClearField("dim_param") 43 | tensor_dim[0].dim_value = batch_size 44 | else: 45 | print('get model input failed, check index or name') 46 | 47 | # Set the shape of an onnx model input 48 | def set_model_input_shape(self, index=0, name=None, shape=None): 49 | model_input = None 50 | if name is not None: 51 | for ipt in self.model.graph.input: 52 | if ipt.name == name: 53 | model_input = ipt 54 | else: 55 | model_input = self.model.graph.input[index] 56 | if model_input: 57 | if shape is not None: 58 | tensor_shape_proto = model_input.type.tensor_type.shape 59 | tensor_shape_proto.ClearField("dim") 60 | tensor_shape_proto.dim.extend([]) 61 | for d in shape: 62 | dim = tensor_shape_proto.dim.add() 63 | dim.dim_value = d 64 | else: 65 | print('get input shape failed, check input') 66 | else: 67 | print('get model input failed, check index or name') 68 | 69 | # Get a compute node of the onnx model by name 70 | def get_node_by_name(self, name): 71 | for node in self.model.graph.node: 72 | if node.name == name: 73 | return node 74 | 75 | # Get the compute nodes of the onnx model by op type 76 | def get_nodes_by_optype(self, typename): 77 | nodes = [] 78 | for node in self.model.graph.node: 79 | if node.op_type == typename: 80 | nodes.append(node) 81 | return nodes 82 | 83 | # Get a weight (initializer) of the onnx model by name 84 | def get_weight_by_name(self, name): 85 | for weight in 
self.model.graph.initializer: 86 | if weight.name == name: 87 | return weight 88 | 89 | # Set a weight; note that this weight is a TensorProto, `https://github.com/onnx/onnx/blob/b1e0bc9a31eaefc2a9946182fbad939843534984/onnx/onnx.proto#L461` 90 | def set_weight(self, weight, data_numpy=None, all_ones=False, all_zeros=False): 91 | if data_numpy is not None: 92 | raw_shape = tuple([i for i in weight.dims]) 93 | new_shape = np.shape(data_numpy) 94 | if weight.data_type == 8: 95 | print("Can NOT handle string data type right now...") 96 | exit() 97 | if new_shape != raw_shape: 98 | print("Warning: the new weight shape is not consistent with original shape!") 99 | weight.dims[:] = list(new_shape) 100 | for model_input in self.model.graph.input: 101 | if model_input.name == weight.name: 102 | # copy from onnx.helper... 103 | tensor_shape_proto = model_input.type.tensor_type.shape 104 | tensor_shape_proto.ClearField("dim") 105 | tensor_shape_proto.dim.extend([]) 106 | for d in new_shape: 107 | dim = tensor_shape_proto.dim.add() 108 | dim.dim_value = d 109 | 110 | weight.ClearField("float_data") 111 | weight.ClearField("int32_data") 112 | weight.ClearField("int64_data") 113 | weight.raw_data = data_numpy.tobytes() 114 | else: 115 | if all_ones: 116 | wr = numpy_helper.to_array(weight) 117 | wn = np.ones_like(wr) 118 | elif all_zeros: 119 | wr = numpy_helper.to_array(weight) 120 | wn = np.zeros_like(wr) 121 | else: 122 | print("You must give a data_numpy to set the weight, or set the all_ones/all_zeros flag.") 123 | exit() 124 | weight.ClearField("float_data") 125 | weight.ClearField("int32_data") 126 | weight.ClearField("int64_data") 127 | weight.raw_data = wn.tobytes() 128 | 129 | # Set a weight of the ONNX model by name 130 | def set_weight_by_name(self, name, data_numpy=None, all_ones=False, all_zeros=False): 131 | weight = self.get_weight_by_name(name) 132 | self.set_weight(weight, data_numpy, all_ones, all_zeros) 133 | 134 | # Remove the target node from the ONNX model 135 | def remove_node(self, target_node): 136 | ''' 137 | Remove a node that has exactly one input and one output 
138 | ''' 139 | node_input = target_node.input[0] 140 | node_output = target_node.output[0] 141 | # Point the inputs of the successor nodes at the input of the target node 142 | for node in self.model.graph.node: 143 | for i, n in enumerate(node.input): 144 | if n == node_output: 145 | node.input[i] = node_input 146 | 147 | target_names = set(target_node.input) & set([weight.name for weight in self.model.graph.initializer]) 148 | self.remove_weights(target_names) 149 | target_names.add(node_output) 150 | self.remove_inputs(target_names) 151 | self.remove_value_infos(target_names) 152 | self.model.graph.node.remove(target_node) 153 | 154 | # Remove the named weights from the ONNX model 155 | def remove_weights(self, name_list): 156 | rm_list = [] 157 | for weight in self.model.graph.initializer: 158 | if weight.name in name_list: 159 | rm_list.append(weight) 160 | for weight in rm_list: 161 | self.model.graph.initializer.remove(weight) 162 | 163 | # Remove the named inputs from the ONNX model 164 | def remove_inputs(self, name_list): 165 | rm_list = [] 166 | for input_t in self.model.graph.input: 167 | if input_t.name in name_list: 168 | rm_list.append(input_t) 169 | for input_t in rm_list: 170 | self.model.graph.input.remove(input_t) 171 | 172 | # Remove the named value_info entries from the ONNX model 173 | def remove_value_infos(self, name_list): 174 | rm_list = [] 175 | for value_info in self.model.graph.value_info: 176 | if value_info.name in name_list: 177 | rm_list.append(value_info) 178 | for value_info in rm_list: 179 | self.model.graph.value_info.remove(value_info) 180 | 181 | # Set the given attribute on the target node of the ONNX model 182 | def set_node_attribute(self, target_node, attr_name, attr_value): 183 | flag = False 184 | for attr in target_node.attribute: 185 | if (attr.name == attr_name): 186 | if attr.type == 1: 187 | attr.f = attr_value 188 | elif attr.type == 2: 189 | attr.i = attr_value 190 | elif attr.type == 3: 191 | attr.s = attr_value 192 | elif attr.type == 4: 193 | attr.t.CopyFrom(attr_value) 194 | elif attr.type == 5: 195 | attr.g.CopyFrom(attr_value) 196 | # NOTE: For repeated composite types, we should use something like 197 
| # del attr.xxx[:] 198 | # attr.xxx.extend([n1, n2, n3]) 199 | elif attr.type == 6: 200 | attr.floats[:] = attr_value 201 | elif attr.type == 7: 202 | attr.ints[:] = attr_value 203 | elif attr.type == 8: 204 | attr.strings[:] = attr_value 205 | else: 206 | print("unsupported attribute data type with attribute name") 207 | return False 208 | flag = True 209 | 210 | if not flag: 211 | # attribute not in original node 212 | print("Warning: you are appending a new attribute to the node!") 213 | target_node.attribute.append(helper.make_attribute(attr_name, attr_value)) 214 | flag = True 215 | return flag 216 | 217 | def chunk_at(self, target_node): 218 | r_nodes = [target_node] 219 | r_input_names = [input_n for input_n in target_node.input] 220 | r_count = len(r_nodes) + len(r_input_names) 221 | 222 | while True: 223 | for node in self.model.graph.node: 224 | # print("nn", node.output) 225 | if node in r_nodes: 226 | continue 227 | for o in node.output: 228 | if o in r_input_names: 229 | r_nodes.append(node) 230 | r_input_names.extend([input_n for input_n in node.input]) 231 | continue 232 | n_count = len(r_nodes) + len(r_input_names) 233 | if n_count == r_count: 234 | break 235 | r_count = n_count 236 | 237 | print("debug r count", r_count) 238 | 239 | d_nodes = [] 240 | d_inputs = [] 241 | d_weights = [] 242 | d_value_infos = [] 243 | for node in self.model.graph.node: 244 | if node not in r_nodes: 245 | d_nodes.append(node) 246 | for model_input in self.model.graph.input: 247 | if model_input.name not in r_input_names: 248 | d_inputs.append(model_input) 249 | for weight in self.model.graph.initializer: 250 | if weight.name not in r_input_names: 251 | d_weights.append(weight) 252 | for value_info in self.model.graph.value_info: 253 | if value_info.name not in r_input_names: 254 | d_value_infos.append(value_info) 255 | for node in d_nodes: 256 | self.model.graph.node.remove(node) 257 | for model_input in d_inputs: 258 | self.model.graph.input.remove(model_input) 259 | for 
weight in d_weights: 260 | self.model.graph.initializer.remove(weight) 261 | for value_info in d_value_infos: 262 | self.model.graph.value_info.remove(value_info) 263 | 264 | target_node.output[0] = self.model.graph.output[0].name 265 | # remove other outputs if model has multi-output 266 | d_outputs = [] 267 | for i, output in enumerate(self.model.graph.output): 268 | if i != 0 : 269 | d_outputs.append(output) 270 | for output in d_outputs: 271 | self.model.graph.output.remove(output) 272 | 273 | # Insert a flatten node before the given node 274 | def insert_flatten_before(self, target_node): 275 | # get target_node inputs 276 | node_input = target_node.input[0] 277 | # create new node 278 | node_name = "flatten_test" 279 | flatten_node = helper.make_node('Flatten', inputs=[node_input], outputs=[node_name], name=node_name) 280 | # set target_node inputs to new node outputs 281 | target_node.input[0] = node_name 282 | for target_node_index, _target_node in enumerate(self.model.graph.node): 283 | if _target_node == target_node: 284 | self.model.graph.node.insert(target_node_index, flatten_node) 285 | break 286 | 287 | # Insert a new OP before the given node target_node 288 | def insert_op_before(self, node_name, target_node, input_idx=0, *args, **kwargs): 289 | ''' 290 | op_name 291 | weight_dict 292 | attr_dict 293 | ...... 
294 | NOTE: 295 | you must ensure the output shape match the input shape of target_node 296 | ''' 297 | # get target_node inputs 298 | node_input = target_node.input[input_idx] 299 | weight_input = [] 300 | weight_input_vi = [] 301 | weight_initializer = [] 302 | if "weight_dict" in kwargs: 303 | for weight_name, weight_numpy in kwargs["weight_dict"].items(): 304 | weight_input.append(weight_name) 305 | weight_input_vi.append( 306 | helper.make_tensor_value_info( 307 | name=weight_name, 308 | elem_type=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[weight_numpy.dtype], 309 | shape=weight_numpy.shape 310 | ) 311 | ) 312 | weight_initializer.append( 313 | helper.make_tensor( 314 | name=weight_name, 315 | data_type=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[weight_numpy.dtype], 316 | dims=weight_numpy.shape, 317 | vals=weight_numpy.tobytes(), 318 | raw=True 319 | ) 320 | ) 321 | # create new node 322 | new_op_node = helper.make_node( 323 | kwargs["op_name"], 324 | inputs=[node_input, *weight_input], 325 | outputs=[node_name], 326 | name=node_name, 327 | **kwargs["attr_dict"] 328 | ) 329 | # set target_node input to new node outputs 330 | target_node.input[input_idx] = node_name 331 | # TODO: change other nodes input into the new node? 332 | # iterator all the nodes in the graph and find 333 | # which node's input equals the original target_node input 334 | # ... 
335 | # add new node and weight input into the graph 336 | for target_node_index, _target_node in enumerate(self.model.graph.node): 337 | if _target_node == target_node: 338 | self.model.graph.node.insert(target_node_index, new_op_node) 339 | break 340 | self.model.graph.input.extend(weight_input_vi) 341 | self.model.graph.initializer.extend(weight_initializer) 342 | 343 | # Expose the output of target_node as an extra output of the ONNX model 344 | def add_extra_output(self, target_node, output_name): 345 | target_output = target_node.output[0] 346 | extra_shape = [] 347 | for vi in self.model.graph.value_info: 348 | if vi.name == target_output: 349 | extra_elem_type = vi.type.tensor_type.elem_type 350 | for s in vi.type.tensor_type.shape.dim: 351 | extra_shape.append(s.dim_value) 352 | extra_output = helper.make_tensor_value_info( 353 | output_name, 354 | extra_elem_type, 355 | extra_shape 356 | ) 357 | identity_node = helper.make_node('Identity', inputs=[target_output], outputs=[output_name], name=output_name) 358 | self.model.graph.node.append(identity_node) 359 | self.model.graph.output.append(extra_output) --------------------------------------------------------------------------------
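The backward traversal in `Tool.chunk_at` (keep the target node plus every node whose outputs feed it, directly or transitively, and delete the rest) can be sketched without ONNX at all. The block below is only an illustration under stated assumptions: `Node` and `upstream_nodes` are hypothetical names introduced here, not part of `tool.py`, and plain tuples of tensor names stand in for `NodeProto.input` and `NodeProto.output`.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for onnx.NodeProto: a name plus tensor names."""
    name: str
    inputs: Tuple[str, ...]
    outputs: Tuple[str, ...]


def upstream_nodes(nodes: List[Node], target: Node) -> List[Node]:
    """Return `target` and every node it depends on, in original graph order."""
    kept = {target.name}
    needed = set(target.inputs)  # tensor names that still need a producer
    changed = True
    while changed:  # repeat until a full pass over the graph adds nothing
        changed = False
        for node in nodes:
            if node.name in kept:
                continue
            if any(out in needed for out in node.outputs):
                kept.add(node.name)
                needed.update(node.inputs)
                changed = True
    return [node for node in nodes if node.name in kept]


graph = [
    Node("conv", ("x", "w"), ("a",)),
    Node("side", ("x",), ("s",)),
    Node("relu", ("a",), ("b",)),
    Node("fc", ("b", "w2"), ("y",)),
]
kept = upstream_nodes(graph, graph[2])  # cut the graph at "relu"
print([node.name for node in kept])  # prints ['conv', 'relu']
```

Everything outside the returned list corresponds to what `chunk_at` removes from the graph; the real method additionally prunes the unused inputs, initializers and value_info entries, then renames the target node's output to the first graph output.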