├── .gitignore
├── LICENSE
├── MyCaffe.py
├── README.md
├── convertCaffe.py
├── model
│   └── MobileNetV2.onnx
├── model_generator
│   ├── MobileNetV2.py
│   ├── __init__.py
│   ├── alexnet.py
│   ├── broadcast_add.py
│   ├── broadcast_mul.py
│   ├── googlenet.py
│   ├── resnet.py
│   └── resnet50.py
├── onnx2caffe
│   ├── __init__.py
│   ├── _error_utils.py
│   ├── _graph.py
│   ├── _operators.py
│   ├── _transformers.py
│   └── _weightloader.py
└── test.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 MTlab, Meitu Inc.
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/MyCaffe.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict, Counter
2 | 
3 | from caffe.proto import caffe_pb2
4 | from google import protobuf
5 | import six
6 | 
7 | def param_name_dict():
8 |     """Find out the correspondence between layer names and parameter names."""
9 | 
10 |     layer = caffe_pb2.LayerParameter()
11 |     # get all parameter names (typically underscore case) and corresponding
12 |     # type names (typically camel case), which contain the layer names
13 |     # (note that not all parameters correspond to layers, but we'll ignore that)
14 |     param_names = [f.name for f in layer.DESCRIPTOR.fields if f.name.endswith('_param')]
15 |     param_type_names = [type(getattr(layer, s)).__name__ for s in param_names]
16 |     # strip the final '_param' or 'Parameter'
17 |     param_names = [s[:-len('_param')] for s in param_names]
18 |     param_type_names = [s[:-len('Parameter')] for s in param_type_names]
19 |     return dict(zip(param_type_names, param_names))
20 | 
21 | def assign_proto(proto, name, val):
22 |     """Assign a Python object to a protobuf message, based on the Python
23 |     type (in recursive fashion). Lists become repeated fields/messages, dicts
24 |     become messages, and other types are assigned directly.
For convenience, 25 | repeated fields whose values are not lists are converted to single-element 26 | lists; e.g., `my_repeated_int_field=3` is converted to 27 | `my_repeated_int_field=[3]`.""" 28 | 29 | is_repeated_field = hasattr(getattr(proto, name), 'extend') 30 | if is_repeated_field and not isinstance(val, list): 31 | val = [val] 32 | if isinstance(val, list): 33 | if isinstance(val[0], dict): 34 | for item in val: 35 | proto_item = getattr(proto, name).add() 36 | for k, v in six.iteritems(item): 37 | assign_proto(proto_item, k, v) 38 | else: 39 | getattr(proto, name).extend(val) 40 | elif isinstance(val, dict): 41 | for k, v in six.iteritems(val): 42 | assign_proto(getattr(proto, name), k, v) 43 | else: 44 | setattr(proto, name, val) 45 | 46 | class Function(object): 47 | """A Function specifies a layer, its parameters, and its inputs (which 48 | are Tops from other layers).""" 49 | 50 | def __init__(self, type_name, layer_name, inputs,outputs, **params): 51 | self.type_name = type_name 52 | self.inputs = inputs 53 | self.outputs = outputs 54 | self.params = params 55 | self.layer_name = layer_name 56 | self.ntop = self.params.get('ntop', 1) 57 | # use del to make sure kwargs are not double-processed as layer params 58 | if 'ntop' in self.params: 59 | del self.params['ntop'] 60 | self.in_place = self.params.get('in_place', False) 61 | if 'in_place' in self.params: 62 | del self.params['in_place'] 63 | # self.tops = tuple(Top(self, n) for n in range(self.ntop))l 64 | 65 | def _get_name(self, names, autonames): 66 | if self not in names and self.ntop > 0: 67 | names[self] = self._get_top_name(self.tops[0], names, autonames) 68 | elif self not in names: 69 | autonames[self.type_name] += 1 70 | names[self] = self.type_name + str(autonames[self.type_name]) 71 | return names[self] 72 | 73 | def _get_top_name(self, top, names, autonames): 74 | if top not in names: 75 | autonames[top.fn.type_name] += 1 76 | names[top] = top.fn.type_name + str(autonames[top.fn.type_name]) 77 | return names[top] 78 | 79 | def _to_proto(self): 80 | bottom_names = [] 81 | for inp in self.inputs: 82 | # inp._to_proto(layers, names, autonames) 83 | bottom_names.append(inp) 84 | layer = caffe_pb2.LayerParameter() 85 | layer.type = self.type_name 86 | layer.bottom.extend(bottom_names) 87 | 88 | if self.in_place: 89 | layer.top.extend(layer.bottom) 90 | else: 91 | for top in self.outputs: 92 | layer.top.append(top) 93 | layer.name = self.layer_name 94 | # print(self.type_name + "...") 95 | for k, v in six.iteritems(self.params): 96 | # special case to handle generic *params 97 | # print("generating "+k+"...") 98 | 99 | if k.endswith('param'): 100 | assign_proto(layer, k, v) 101 | else: 102 | try: 103 | assign_proto(getattr(layer, 104 | _param_names[self.type_name] + '_param'), k, v) 105 | except (AttributeError, KeyError): 106 | assign_proto(layer, k, v) 107 | 108 | return layer 109 | 110 | class Layers(object): 111 | """A Layers object is a pseudo-module which generates functions that specify 112 | layers; e.g., Layers().Convolution(bottom, kernel_size=3) will produce a Top 113 | specifying a 3x3 convolution applied to bottom.""" 114 | 115 | def __getattr__(self, name): 116 | def layer_fn(*args, **kwargs): 117 | fn = Function(name, args, kwargs) 118 | return fn 119 | return layer_fn 120 | 121 | 122 | 123 | 124 | _param_names = param_name_dict() 125 | 126 | -------------------------------------------------------------------------------- /README.md: 
--------------------------------------------------------------------------------
1 | # Convert PyTorch to Caffe via ONNX
2 | This tool converts a [pytorch](https://github.com/pytorch/pytorch) model to a [Caffe](https://github.com/BVLC/caffe) model via [ONNX](https://github.com/onnx/onnx).
3 | It is intended for inference only.
4 | 
5 | ### Dependencies
6 | * caffe (with python support)
7 | * pytorch 0.4 (optional; not needed if you only want to convert an existing onnx model)
8 | * onnx
9 | 
10 | We recommend using protobuf 2.6.1 and installing onnx from source:
11 | ```
12 | git clone --recursive https://github.com/onnx/onnx.git
13 | cd onnx
14 | python setup.py install
15 | ```
16 | 
17 | ### How to use
18 | Run test.py to make sure everything is installed correctly.
19 | To convert an ONNX model to Caffe:
20 | ```
21 | python convertCaffe.py ./model/MobileNetV2.onnx ./model/MobileNetV2.prototxt ./model/MobileNetV2.caffemodel
22 | ```
23 | ### Currently supported operations
24 | * Conv
25 | * ConvTranspose
26 | * BatchNormalization
27 | * MaxPool
28 | * AveragePool
29 | * Relu
30 | * Sigmoid
31 | * Dropout
32 | * Gemm (InnerProduct only)
33 | * Add
34 | * Mul
35 | * Reshape
36 | * Upsample
37 | * Concat
38 | * Flatten
39 | 
40 | ### TODO List
41 | - [ ] support all onnx operations (which is impossible)
42 | - [ ] merge batch normalization into convolution
43 | - [ ] merge scale into convolution
--------------------------------------------------------------------------------
/convertCaffe.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | import caffe
4 | import onnx
5 | import numpy as np
6 | from caffe.proto import caffe_pb2
7 | caffe.set_mode_cpu()
8 | from onnx2caffe._transformers import ConvAddFuser, ConstantsToInitializers
9 | from onnx2caffe._graph import Graph
10 | 
11 | import onnx2caffe._operators as cvt
12 | import onnx2caffe._weightloader as wlr
13 | from onnx2caffe._error_utils import ErrorHandling
14 | from collections import OrderedDict
15 | from onnx import shape_inference
16 | import importlib
17 | 
18 | transformers = [
19 |     ConstantsToInitializers(),
20 |     ConvAddFuser(),
21 | ]
22 | 
23 | def convertToCaffe(graph, prototxt_save_path, caffe_model_save_path):
24 | 
25 |     exist_edges = []
26 |     layers = []
27 |     exist_nodes = []
28 |     err = ErrorHandling()
29 |     for i in graph.inputs:
30 |         edge_name = i[0]
31 |         input_layer = cvt.make_input(i)
32 |         layers.append(input_layer)
33 |         exist_edges.append(i[0])
34 |         graph.channel_dims[edge_name] = graph.shape_dict[edge_name][1]
35 | 
36 | 
37 |     for id, node in enumerate(graph.nodes):
38 |         node_name = node.name
39 |         op_type = node.op_type
40 |         inputs = node.inputs
41 |         inputs_tensor = node.input_tensors
42 |         input_non_exist_flag = False
43 | 
44 |         for inp in inputs:
45 |             if inp not in exist_edges and inp not in inputs_tensor:
46 |                 input_non_exist_flag = True
47 |                 break
48 |         if input_non_exist_flag:
49 |             continue
50 | 
51 |         if op_type not in cvt._ONNX_NODE_REGISTRY:
52 |             err.unsupported_op(node)
53 |             continue
54 |         converter_fn = cvt._ONNX_NODE_REGISTRY[op_type]
55 |         layer = converter_fn(node, graph, err)
56 |         if type(layer) == tuple:
57 |             for l in layer:
58 |                 layers.append(l)
59 |         else:
60 |             layers.append(layer)
61 |         outs = node.outputs
62 |         for out in outs:
63 |             exist_edges.append(out)
64 | 
65 |     net = caffe_pb2.NetParameter()
66 |     for id, layer in enumerate(layers):
67 |         layers[id] = layer._to_proto()
68 |     net.layer.extend(layers)
69 | 
70 |     with open(prototxt_save_path, 'w') as f:
71 |         print(net, file=f)
72 | 
73 |
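# Second pass: reload the prototxt written above into a caffe.Net, let each
# op's registered weight-loader (wlr._ONNX_NODE_REGISTRY) copy the ONNX
# initializers (weights, biases, BN statistics) into the matching layer
# blobs, then serialize the .caffemodel.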
caffe.set_mode_cpu() 74 | deploy = prototxt_save_path 75 | net = caffe.Net(deploy, 76 | caffe.TEST) 77 | 78 | for id, node in enumerate(graph.nodes): 79 | node_name = node.name 80 | op_type = node.op_type 81 | inputs = node.inputs 82 | inputs_tensor = node.input_tensors 83 | input_non_exist_flag = False 84 | if op_type not in wlr._ONNX_NODE_REGISTRY: 85 | err.unsupported_op(node) 86 | continue 87 | converter_fn = wlr._ONNX_NODE_REGISTRY[op_type] 88 | converter_fn(net, node, graph, err) 89 | 90 | net.save(caffe_model_save_path) 91 | return net 92 | 93 | def getGraph(onnx_path): 94 | model = onnx.load(onnx_path) 95 | model = shape_inference.infer_shapes(model) 96 | model_graph = model.graph 97 | graph = Graph.from_onnx(model_graph) 98 | graph = graph.transformed(transformers) 99 | graph.channel_dims = {} 100 | 101 | return graph 102 | 103 | if __name__ == "__main__": 104 | onnx_path = sys.argv[1] 105 | prototxt_path = sys.argv[2] 106 | caffemodel_path = sys.argv[3] 107 | graph = getGraph(onnx_path) 108 | convertToCaffe(graph, prototxt_path, caffemodel_path) 109 | 110 | -------------------------------------------------------------------------------- /model/MobileNetV2.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MTLab/onnx2caffe/46ae6b8b7838361e80cb441a4ca3d082be21bf44/model/MobileNetV2.onnx -------------------------------------------------------------------------------- /model_generator/MobileNetV2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | from torch.autograd import Variable 4 | import torch 5 | import torch.onnx as onnx 6 | import os 7 | 8 | def conv_bn(inp, oup, stride): 9 | return nn.Sequential( 10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=True), 11 | nn.BatchNorm2d(oup), 12 | nn.ReLU(inplace=True) 13 | ) 14 | 15 | 16 | def conv_1x1_bn(inp, oup): 17 | return nn.Sequential( 18 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 19 | nn.BatchNorm2d(oup), 20 | nn.ReLU(inplace=True) 21 | ) 22 | 23 | 24 | class InvertedResidual(nn.Module): 25 | def __init__(self, inp, oup, stride, expand_ratio): 26 | super(InvertedResidual, self).__init__() 27 | self.stride = stride 28 | assert stride in [1, 2] 29 | 30 | self.use_res_connect = self.stride == 1 and inp == oup 31 | 32 | self.conv = nn.Sequential( 33 | # pw 34 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 35 | nn.BatchNorm2d(inp * expand_ratio), 36 | nn.ReLU(inplace=True), 37 | # dw 38 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=True), 39 | nn.BatchNorm2d(inp * expand_ratio), 40 | nn.ReLU(inplace=True), 41 | # pw-linear 42 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 43 | nn.BatchNorm2d(oup), 44 | ) 45 | 46 | def forward(self, x): 47 | if self.use_res_connect: 48 | return x + self.conv(x) 49 | else: 50 | return self.conv(x) 51 | 52 | 53 | class MobileNetV2(nn.Module): 54 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 55 | super(MobileNetV2, self).__init__() 56 | # setting of inverted residual blocks 57 | self.interverted_residual_setting = [ 58 | # t, c, n, s 59 | [1, 16, 1, 1], 60 | [6, 24, 2, 2], 61 | [6, 32, 3, 2], 62 | [6, 64, 4, 2], 63 | [6, 96, 3, 1], 64 | [6, 160, 3, 2], 65 | [6, 320, 1, 1], 66 | ] 67 | 68 | # building first layer 69 | assert input_size % 32 == 0 70 | input_channel = int(32 * width_mult) 71 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 72 | 
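# Each row of interverted_residual_setting above reads (t, c, n, s):
# t is the expansion ratio inside the block, c the output channel count,
# n how many times the block repeats, and s the stride of the first
# repetition (the remaining repetitions use stride 1).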
self.features = [conv_bn(3, input_channel, 2)] 73 | # building inverted residual blocks 74 | for t, c, n, s in self.interverted_residual_setting: 75 | output_channel = int(c * width_mult) 76 | for i in range(n): 77 | if i == 0: 78 | self.features.append(InvertedResidual(input_channel, output_channel, s, t)) 79 | else: 80 | self.features.append(InvertedResidual(input_channel, output_channel, 1, t)) 81 | input_channel = output_channel 82 | # building last several layers 83 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 84 | self.features.append(nn.AvgPool2d(int(input_size/32))) 85 | # make it nn.Sequential 86 | self.features = nn.Sequential(*self.features) 87 | 88 | # building classifier 89 | self.classifier = nn.Sequential( 90 | nn.Dropout(), 91 | nn.Linear(self.last_channel, n_class), 92 | ) 93 | 94 | self._initialize_weights() 95 | 96 | def forward(self, x): 97 | x = self.features(x) 98 | x = x.view(-1, self.last_channel) 99 | x = self.classifier(x) 100 | return x 101 | 102 | def _initialize_weights(self): 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 106 | m.weight.data.normal_(0, math.sqrt(2. / n)) 107 | if m.bias is not None: 108 | m.bias.data.zero_() 109 | elif isinstance(m, nn.BatchNorm2d): 110 | m.weight.data.fill_(1) 111 | m.bias.data.zero_() 112 | elif isinstance(m, nn.Linear): 113 | n = m.weight.size(1) 114 | m.weight.data.normal_(0, 0.01) 115 | m.bias.data.zero_() 116 | 117 | def export(dir): 118 | dummy_input = Variable(torch.randn(1, 3, 224, 224)) 119 | model = MobileNetV2() 120 | model.eval() 121 | torch.save(model.state_dict(),os.path.join(dir,"MobileNetV2.pth")) 122 | onnx.export(model, dummy_input,os.path.join(dir,"MobileNetV2.onnx"), verbose=True) 123 | 124 | 125 | 126 | def get_model_and_input(model_save_dir): 127 | model = MobileNetV2() 128 | model.cpu() 129 | model_path = os.path.join(model_save_dir,'MobileNetV2.pth') 130 | model.load_state_dict(torch.load(model_path)) 131 | model.cpu() 132 | model.eval() 133 | batch_size = 1 134 | channels = 3 135 | height = 224 136 | width = 224 137 | images = Variable(torch.ones(batch_size, channels, height, width)) 138 | return images,model -------------------------------------------------------------------------------- /model_generator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MTLab/onnx2caffe/46ae6b8b7838361e80cb441a4ca3d082be21bf44/model_generator/__init__.py -------------------------------------------------------------------------------- /model_generator/alexnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | from torch.autograd import Variable 7 | import torch.onnx as onnx 8 | 9 | 10 | __all__ = ['AlexNet', 'alexnet'] 11 | 12 | 13 | model_urls = { 14 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 15 | } 16 | 17 | 18 | class AlexNet(nn.Module): 19 | 20 | def __init__(self, num_classes=1000): 21 | super(AlexNet, self).__init__() 22 | self.featuresxxx = nn.Sequential( 23 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 24 | nn.ReLU(inplace=True), 25 | nn.MaxPool2d(kernel_size=3, stride=2), 26 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 27 | nn.ReLU(inplace=True), 28 | nn.MaxPool2d(kernel_size=3, stride=2), 29 | nn.Conv2d(192, 384, 
kernel_size=3, padding=1), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 34 | nn.ReLU(inplace=True), 35 | nn.MaxPool2d(kernel_size=3, stride=2), 36 | ) 37 | 38 | 39 | self.featuresa = nn.Sequential( 40 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2, bias=False), 41 | nn.ReLU(inplace=True), 42 | nn.MaxPool2d(kernel_size=3, stride=2), 43 | nn.Conv2d(64, 192, kernel_size=5, padding=2, bias=False), 44 | nn.ReLU(inplace=True), 45 | nn.MaxPool2d(kernel_size=3, stride=2), 46 | nn.Conv2d(192, 384, kernel_size=3, padding=1, bias=False), 47 | nn.ReLU(inplace=True), 48 | nn.Conv2d(384, 256, kernel_size=3, padding=1, bias=False), 49 | nn.ReLU(inplace=True), 50 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 51 | nn.ReLU(inplace=True), 52 | nn.MaxPool2d(kernel_size=3, stride=2), 53 | ) 54 | 55 | self.features = nn.Sequential( 56 | nn.Conv2d(3, 64, kernel_size=2, stride=2, padding=4, bias=True), 57 | # nn.Conv2d(3, 2, kernel_size=11, stride=4, padding=2), 58 | # nn.ReLU(inplace=True), 59 | #nn.MaxPool2d(kernel_size=3, stride=2), 60 | ) 61 | self.classifier = nn.Sequential( 62 | nn.Dropout(), 63 | nn.Linear(256 * 6 * 6, 4096), 64 | nn.ReLU(inplace=True), 65 | nn.Dropout(), 66 | nn.Linear(4096, 4096), 67 | nn.ReLU(inplace=True), 68 | nn.Linear(4096, num_classes), 69 | ) 70 | 71 | def forward(self, x): 72 | x = self.features(x) 73 | return x 74 | x = x.view(x.size(0), 256 * 6 * 6) 75 | x = self.classifier(x) 76 | return x 77 | #return nn.functional.log_softmax(x) 78 | 79 | 80 | def alexnet(pretrained=False, **kwargs): 81 | r"""AlexNet model architecture from the 82 | `"One weird trick..." `_ paper. 83 | 84 | Args: 85 | pretrained (bool): If True, returns a model pre-trained on ImageNet 86 | """ 87 | model = AlexNet(**kwargs) 88 | if pretrained: 89 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet'])) 90 | return model 91 | 92 | 93 | def export(dir): 94 | file_path = os.path.realpath(__file__) 95 | file_dir = os.path.dirname(file_path) 96 | dummy_input = Variable(torch.randn(1, 3, 224, 224)) 97 | model = AlexNet() 98 | # model = load_network(model,os.path.join(file_dir,'..','model','pose_v02.pth')) 99 | model.eval() 100 | torch.save(model.state_dict(),os.path.join(dir,"alexnet.pth")) 101 | onnx.export(model, dummy_input,os.path.join(dir,"alexnet.onnx"), verbose=True) 102 | 103 | def get_model_and_input(model_save_dir): 104 | model = AlexNet() 105 | model.cpu() 106 | model_path = os.path.join(model_save_dir,'alexnet.pth') 107 | model.load_state_dict(torch.load(model_path)) 108 | model.cpu() 109 | model.eval() 110 | batch_size = 1 111 | channels = 3 112 | height = 224 113 | width = 224 114 | images = Variable(torch.ones(batch_size, channels, height, width)) 115 | return images,model 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /model_generator/broadcast_add.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Variable 3 | import torch 4 | import torch.onnx as onnx 5 | import os 6 | import numpy as np 7 | def conv_bn(inp, oup, stride): 8 | return nn.Sequential( 9 | nn.Conv2d(inp, oup, 3, stride, 1, bias=True), 10 | nn.BatchNorm2d(oup), 11 | nn.ReLU(inplace=True) 12 | ) 13 | class broadcast_add(nn.Module): 14 | def __init__(self): 15 | super(broadcast_add, self).__init__() 16 | self.conv1 = conv_bn(3,128,1) 17 | 
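# Average-pooling the 4x4 feature map down to 1x1 makes the addition in
# forward() a per-channel broadcast across spatial positions, which
# exercises the converter's broadcast-Add path (emitted as Flatten + Bias
# on the Caffe side).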
self.poo1 = nn.AvgPool2d(kernel_size=4) 18 | 19 | def forward(self, x): 20 | x1 = self.conv1(x) 21 | x2 = self.poo1(x1) 22 | out = x1+x2 23 | return out 24 | 25 | def export(dir): 26 | dummy_input = Variable(torch.randn(1, 3, 4, 4)) 27 | model = broadcast_add() 28 | model.eval() 29 | torch.save(model.state_dict(),os.path.join(dir,"broadcast_add.pth")) 30 | onnx.export(model, dummy_input,os.path.join(dir,"broadcast_add.onnx"), verbose=True) 31 | 32 | def get_model_and_input(model_save_dir): 33 | model = broadcast_add() 34 | model.cpu() 35 | model_path = os.path.join(model_save_dir,'broadcast_add.pth') 36 | model.load_state_dict(torch.load(model_path)) 37 | model.cpu() 38 | model.eval() 39 | batch_size = 1 40 | channels = 3 41 | height = 4 42 | width = 4 43 | images = Variable(torch.ones(batch_size, channels, height, width)) 44 | return images,model -------------------------------------------------------------------------------- /model_generator/broadcast_mul.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Variable 3 | import torch 4 | import torch.onnx as onnx 5 | import os 6 | import numpy as np 7 | def conv_bn(inp, oup, stride): 8 | return nn.Sequential( 9 | nn.Conv2d(inp, oup, 3, stride, 1, bias=True), 10 | nn.BatchNorm2d(oup), 11 | nn.ReLU(inplace=True) 12 | ) 13 | class broadcast_mul(nn.Module): 14 | def __init__(self): 15 | super(broadcast_mul, self).__init__() 16 | self.conv1 = conv_bn(3,128,1) 17 | self.poo1 = nn.AvgPool2d(kernel_size=4) 18 | 19 | def forward(self, x): 20 | x1 = self.conv1(x) 21 | x2 = self.poo1(x1) 22 | # x2 = x2.view(x2.size(0), x2.size(1)) 23 | out = x1*x2 24 | return out 25 | 26 | def export(dir): 27 | dummy_input = Variable(torch.randn(1, 3, 4, 4)) 28 | model = broadcast_mul() 29 | model.eval() 30 | torch.save(model.state_dict(),os.path.join(dir,"broadcast_mul.pth")) 31 | onnx.export(model, dummy_input,os.path.join(dir,"broadcast_mul.onnx"), verbose=True) 32 | 33 | 34 | def get_model_and_input(model_save_dir): 35 | model = broadcast_mul() 36 | model.cpu() 37 | model_path = os.path.join(model_save_dir,'broadcast_mul.pth') 38 | model.load_state_dict(torch.load(model_path)) 39 | model.cpu() 40 | model.eval() 41 | batch_size = 1 42 | channels = 3 43 | height = 4 44 | width = 4 45 | images = Variable(torch.ones(batch_size, channels, height, width)) 46 | return images,model 47 | -------------------------------------------------------------------------------- /model_generator/googlenet.py: -------------------------------------------------------------------------------- 1 | '''GoogLeNet with PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torchvision import datasets, transforms 7 | from torch.autograd import Variable 8 | import os 9 | import torch.onnx as onnx 10 | 11 | class Inception(nn.Module): 12 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes): 13 | super(Inception, self).__init__() 14 | # 1x1 conv branch 15 | self.b1 = nn.Sequential( 16 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 17 | nn.BatchNorm2d(n1x1), 18 | nn.ReLU(True), 19 | ) 20 | 21 | # 1x1 conv -> 3x3 conv branch 22 | self.b2 = nn.Sequential( 23 | nn.Conv2d(in_planes, n3x3red, kernel_size=1), 24 | nn.BatchNorm2d(n3x3red), 25 | nn.ReLU(True), 26 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 27 | nn.BatchNorm2d(n3x3), 28 | nn.ReLU(True), 29 | ) 30 | 31 | # 1x1 conv -> 5x5 conv branch 32 | 
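# The "5x5" branch below is implemented as two stacked 3x3 convolutions,
# which cover the same receptive field as a single 5x5 with fewer parameters.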
self.b3 = nn.Sequential( 33 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 34 | nn.BatchNorm2d(n5x5red), 35 | nn.ReLU(True), 36 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 37 | nn.BatchNorm2d(n5x5), 38 | nn.ReLU(True), 39 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 40 | nn.BatchNorm2d(n5x5), 41 | nn.ReLU(True), 42 | ) 43 | 44 | # 3x3 pool -> 1x1 conv branch 45 | self.b4 = nn.Sequential( 46 | nn.MaxPool2d(3, stride=1, padding=1), 47 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 48 | nn.BatchNorm2d(pool_planes), 49 | nn.ReLU(True), 50 | ) 51 | 52 | def forward(self, x): 53 | y1 = self.b1(x) 54 | y2 = self.b2(x) 55 | y3 = self.b3(x) 56 | y4 = self.b4(x) 57 | return torch.cat([y1,y2,y3,y4], 1) 58 | 59 | 60 | class GoogLeNet(nn.Module): 61 | def __init__(self): 62 | super(GoogLeNet, self).__init__() 63 | self.pre_layers = nn.Sequential( 64 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 65 | nn.BatchNorm2d(192), 66 | nn.ReLU(True), 67 | ) 68 | 69 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 70 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 71 | 72 | self.maxpool = nn.MaxPool2d(2, stride=2, padding=0) 73 | 74 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 75 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 76 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 77 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 78 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 79 | 80 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 81 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 82 | 83 | self.avgpool = nn.AvgPool2d(8, stride=1) 84 | self.linear = nn.Linear(1024, 10) 85 | 86 | def forward(self, x): 87 | out = self.pre_layers(x) 88 | out = self.a3(out) 89 | out = self.b3(out) 90 | out = self.maxpool(out) 91 | out = self.a4(out) 92 | out = self.b4(out) 93 | out = self.c4(out) 94 | out = self.d4(out) 95 | out = self.e4(out) 96 | out = self.maxpool(out) 97 | out = self.a5(out) 98 | out = self.b5(out) 99 | out = self.avgpool(out) 100 | out = out.view(out.size(0), -1) 101 | out = self.linear(out) 102 | return out 103 | 104 | def export(dir): 105 | file_path = os.path.realpath(__file__) 106 | file_dir = os.path.dirname(file_path) 107 | dummy_input = Variable(torch.randn(1, 3, 32, 32)) 108 | model = GoogLeNet() 109 | # model = load_network(model,os.path.join(file_dir,'..','model','pose_v02.pth')) 110 | model.eval() 111 | torch.save(model.state_dict(),os.path.join(dir,"googlenet.pth")) 112 | onnx.export(model, dummy_input,os.path.join(dir,"googlenet.onnx"), verbose=True) 113 | 114 | def get_model_and_input(model_save_dir): 115 | model = GoogLeNet() 116 | model.cpu() 117 | model_path = os.path.join(model_save_dir,'googlenet.pth') 118 | model.load_state_dict(torch.load(model_path)) 119 | model.cpu() 120 | model.eval() 121 | batch_size = 1 122 | channels = 3 123 | height = 32 124 | width = 32 125 | images = Variable(torch.ones(batch_size, channels, height, width)) 126 | return images,model 127 | -------------------------------------------------------------------------------- /model_generator/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet18/34/50/101/152 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torchvision import datasets, transforms 7 | from torch.autograd import Variable 8 | import os 9 | import torch.onnx as onnx 10 | 11 | def conv3x3(in_planes, out_planes, stride=1): 12 | return 
nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, in_planes, planes, stride=1): 19 | super(BasicBlock, self).__init__() 20 | self.conv1 = conv3x3(in_planes, planes, stride) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = conv3x3(planes, planes) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride != 1 or in_planes != self.expansion*planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 29 | nn.BatchNorm2d(self.expansion*planes) 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = self.bn2(self.conv2(out)) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class Bottleneck(nn.Module): 41 | expansion = 4 42 | 43 | def __init__(self, in_planes, planes, stride=1): 44 | super(Bottleneck, self).__init__() 45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 51 | 52 | self.shortcut = nn.Sequential() 53 | if stride != 1 or in_planes != self.expansion*planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 56 | nn.BatchNorm2d(self.expansion*planes) 57 | ) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.bn1(self.conv1(x))) 61 | out = F.relu(self.bn2(self.conv2(out))) 62 | out = self.bn3(self.conv3(out)) 63 | out += self.shortcut(x) 64 | out = F.relu(out) 65 | return out 66 | 67 | 68 | class ResNet(nn.Module): 69 | def __init__(self, block, num_blocks, num_classes=10): 70 | super(ResNet, self).__init__() 71 | self.in_planes = 64 72 | 73 | self.conv1 = conv3x3(3,64) 74 | self.bn1 = nn.BatchNorm2d(64) 75 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 76 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 77 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 78 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 79 | self.linear = nn.Linear(512*block.expansion, num_classes) 80 | 81 | def _make_layer(self, block, planes, num_blocks, stride): 82 | strides = [stride] + [1]*(num_blocks-1) 83 | layers = [] 84 | for stride in strides: 85 | layers.append(block(self.in_planes, planes, stride)) 86 | self.in_planes = planes * block.expansion 87 | return nn.Sequential(*layers) 88 | 89 | def forward(self, x): 90 | out = F.relu(self.bn1(self.conv1(x))) 91 | out = self.layer1(out) 92 | out = self.layer2(out) 93 | out = self.layer3(out) 94 | out = self.layer4(out) 95 | out = F.avg_pool2d(out, 4) 96 | out = out.view(out.size(0), -1) 97 | out = self.linear(out) 98 | return out 99 | 100 | 101 | def ResNet18(): 102 | return ResNet(BasicBlock, [2,2,2,2]) 103 | 104 | def ResNet34(): 105 | return ResNet(BasicBlock, [3,4,6,3]) 106 | 107 | def ResNet50(): 108 | return ResNet(Bottleneck, [3,4,6,3]) 109 | 110 | def ResNet101(): 111 | return ResNet(Bottleneck, [3,4,23,3]) 112 | 113 | def ResNet152(): 114 | return ResNet(Bottleneck, [3,8,36,3]) 115 | 116 | def test_resnet(): 117 | net = ResNet34() 118 | y = 
net(Variable(torch.randn(1,3,32,32))) 119 | print("======================") 120 | print(y.size()) 121 | 122 | 123 | def export(dir): 124 | file_path = os.path.realpath(__file__) 125 | file_dir = os.path.dirname(file_path) 126 | dummy_input = Variable(torch.randn(1, 3, 32, 32)) 127 | model = ResNet34() 128 | # model = load_network(model,os.path.join(file_dir,'..','model','pose_v02.pth')) 129 | model.eval() 130 | torch.save(model.state_dict(),os.path.join(dir,"resnet.pth")) 131 | onnx.export(model, dummy_input,os.path.join(dir,"resnet.onnx"), verbose=True) 132 | 133 | def get_model_and_input(model_save_dir): 134 | model = ResNet34() 135 | model.cpu() 136 | model_path = os.path.join(model_save_dir,'resnet.pth') 137 | model.load_state_dict(torch.load(model_path)) 138 | model.cpu() 139 | model.eval() 140 | batch_size = 1 141 | channels = 3 142 | height = 32 143 | width = 32 144 | images = Variable(torch.ones(batch_size, channels, height, width)) 145 | return images,model -------------------------------------------------------------------------------- /model_generator/resnet50.py: -------------------------------------------------------------------------------- 1 | # import torchvision.models as models 2 | from torch import onnx 3 | import torch 4 | import torch.nn as nn 5 | import os 6 | import math 7 | 8 | 9 | # model = models.resnet50(False) 10 | # dummy_input = torch.randn(1, 3, 224, 224) 11 | # onnx.export(model, dummy_input, "resnet50.onnx", verbose=True) 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(inplanes, planes, stride) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, inplanes, planes, stride=1, downsample=None): 55 | super(Bottleneck, self).__init__() 56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(planes) 58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 59 | padding=1, bias=False) 60 | self.bn2 = nn.BatchNorm2d(planes) 61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 62 | self.bn3 = nn.BatchNorm2d(planes * 4) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | residual = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv3(out) 79 | out = self.bn3(out) 80 | 81 | if self.downsample is not None: 82 | residual = self.downsample(x) 83 | 84 | out += residual 85 | out = self.relu(out) 86 | 87 | return out 88 | 89 | 90 | class 
ResNet(nn.Module): 91 | 92 | def __init__(self, block, layers, num_classes=1000): 93 | self.inplanes = 64 94 | super(ResNet, self).__init__() 95 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 96 | bias=False) 97 | self.bn1 = nn.BatchNorm2d(64) 98 | self.relu = nn.ReLU(inplace=True) 99 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 100 | self.layer1 = self._make_layer(block, 64, layers[0]) 101 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 102 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 103 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 104 | self.avgpool = nn.AvgPool2d(7, stride=1) 105 | self.fc = nn.Linear(512 * block.expansion, num_classes) 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | m.weight.data.normal_(0, math.sqrt(2. / n)) 111 | elif isinstance(m, nn.BatchNorm2d): 112 | m.weight.data.fill_(1) 113 | m.bias.data.zero_() 114 | 115 | def _make_layer(self, block, planes, blocks, stride=1): 116 | downsample = None 117 | if stride != 1 or self.inplanes != planes * block.expansion: 118 | downsample = nn.Sequential( 119 | nn.Conv2d(self.inplanes, planes * block.expansion, 120 | kernel_size=1, stride=stride, bias=False), 121 | nn.BatchNorm2d(planes * block.expansion), 122 | ) 123 | 124 | layers = [] 125 | layers.append(block(self.inplanes, planes, stride, downsample)) 126 | self.inplanes = planes * block.expansion 127 | for i in range(1, blocks): 128 | layers.append(block(self.inplanes, planes)) 129 | 130 | return nn.Sequential(*layers) 131 | 132 | def forward(self, x): 133 | x = self.conv1(x) 134 | x = self.bn1(x) 135 | x = self.relu(x) 136 | x = self.maxpool(x) 137 | 138 | x = self.layer1(x) 139 | x = self.layer2(x) 140 | x = self.layer3(x) 141 | x = self.layer4(x) 142 | 143 | x = self.avgpool(x) 144 | x = x.view(x.size(0), -1) 145 | x = self.fc(x) 146 | 147 | return x 148 | 149 | 150 | def resnet50(): 151 | """Constructs a ResNet-50 model.""" 152 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 153 | return model 154 | 155 | 156 | def export(dir): 157 | file_path = os.path.realpath(__file__) 158 | file_dir = os.path.dirname(file_path) 159 | dummy_input = torch.randn(1, 3, 224, 224) 160 | model = resnet50() 161 | # model = load_network(model,os.path.join(file_dir,'..','model','pose_v02.pth')) 162 | model.eval() 163 | torch.save(model.state_dict(),os.path.join(dir,"resnet50.pth")) 164 | onnx.export(model, dummy_input,os.path.join(dir,"resnet50.onnx"), verbose=True) 165 | 166 | 167 | def get_model_and_input(model_save_dir=None): 168 | model = resnet50() 169 | model.cpu() 170 | if model_save_dir is not None: 171 | model_path = os.path.join(model_save_dir, 'resnet50.pth') 172 | model.load_state_dict(torch.load(model_path)) 173 | model.cpu() 174 | model.eval() 175 | batch_size = 1 176 | channels = 3 177 | height = 224 178 | width = 224 179 | images = torch.ones(batch_size, channels, height, width) 180 | return images, model -------------------------------------------------------------------------------- /onnx2caffe/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MTLab/onnx2caffe/46ae6b8b7838361e80cb441a4ca3d082be21bf44/onnx2caffe/__init__.py -------------------------------------------------------------------------------- /onnx2caffe/_error_utils.py: -------------------------------------------------------------------------------- 1 | from 
__future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from typing import Dict, Text, Any, Callable 6 | from ._graph import Node, Graph 7 | 8 | class ErrorHandling(object): 9 | ''' 10 | To handle errors and addition of custom layers 11 | ''' 12 | 13 | def __init__(self, 14 | add_custom_layers = False, # type: bool 15 | custom_conversion_functions = dict(), # type: Dict[Text, Any] 16 | custom_layer_nodes = [], # type : List[Node] 17 | ): 18 | # type: (...) -> None 19 | self.add_custom_layers = add_custom_layers 20 | self.custom_conversion_functions = custom_conversion_functions 21 | self.custom_layer_nodes = custom_layer_nodes 22 | 23 | 24 | def unsupported_op(self, 25 | node, # type: Node 26 | ): 27 | # type: (...) -> Callable[[Any, Node, Graph, ErrorHandling], None] 28 | ''' 29 | Either raise an error for an unsupported op type or return custom layer add function 30 | ''' 31 | if self.add_custom_layers: 32 | from ._operators import _convert_custom 33 | return _convert_custom 34 | else: 35 | raise TypeError( 36 | "ONNX node of type {} is not supported.\n".format(node.op_type,) 37 | ) 38 | 39 | 40 | def unsupported_op_configuration(self, 41 | node, # type: Node 42 | err_message, # type: Text 43 | ): 44 | raise TypeError( 45 | "Error while converting op of type: {}. Error message: {}\n".format(node.op_type, err_message, ) 46 | ) 47 | 48 | 49 | def missing_initializer(self, 50 | node, # type: Node 51 | err_message, # type: Text 52 | ): 53 | # type: (...) -> None 54 | ''' 55 | Missing initializer error 56 | ''' 57 | raise ValueError( 58 | "Missing initializer error in op of type {}, with input name = {}, " 59 | "output name = {}. Error message: {}\n". 60 | format(node.op_type, node.inputs[0], node.outputs[0], err_message) 61 | ) 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /onnx2caffe/_graph.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from onnx import numpy_helper, ValueInfoProto, AttributeProto, GraphProto, NodeProto, TensorProto, TensorShapeProto 7 | from typing import Any, Text, Iterable, List, Dict, Sequence, Optional, Tuple, Union 8 | from typing_extensions import Protocol 9 | import numpy as np 10 | 11 | 12 | class Transformer(Protocol): 13 | def __call__(self, graph): # type: (Graph) -> Graph 14 | pass 15 | 16 | 17 | EdgeInfo = Tuple[Text, Any, TensorShapeProto] 18 | AttributeValue = Any # TODO Union[Sequence[float], Sequence[int], Sequence[Text], Sequence[TensorProto], Sequence[GraphProto]] 19 | 20 | def _input_from_onnx_input(input): # type: (ValueInfoProto) -> EdgeInfo 21 | name = input.name 22 | type = input.type.tensor_type.elem_type 23 | shape = tuple([d.dim_value for d in input.type.tensor_type.shape.dim]) 24 | return (name, type, shape) 25 | 26 | 27 | def _convertAttributeProto(onnx_arg): # type: (AttributeProto) -> AttributeValue 28 | """ 29 | Convert an ONNX AttributeProto into an appropriate Python object 30 | for the type. 
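Singular fields (f, i, s, t) are detected via HasField and take precedence; repeated fields are detected by non-zero length, so an empty attribute raises ValueError.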
31 | NB: Tensor attribute gets returned as numpy array 32 | """ 33 | if onnx_arg.HasField('f'): 34 | return onnx_arg.f 35 | elif onnx_arg.HasField('i'): 36 | return onnx_arg.i 37 | elif onnx_arg.HasField('s'): 38 | return onnx_arg.s 39 | elif onnx_arg.HasField('t'): 40 | return numpy_helper.to_array(onnx_arg.t) 41 | elif len(onnx_arg.floats): 42 | return list(onnx_arg.floats) 43 | elif len(onnx_arg.ints): 44 | return list(onnx_arg.ints) 45 | elif len(onnx_arg.strings): 46 | return list(onnx_arg.strings) 47 | else: 48 | raise ValueError("Unsupported ONNX attribute: {}".format(onnx_arg)) 49 | 50 | 51 | class Attributes(Dict[Text, Any]): 52 | @staticmethod 53 | def from_onnx(args): # type: (Iterable[AttributeProto]) -> Attributes 54 | d = Attributes() 55 | for arg in args: 56 | d[arg.name] = _convertAttributeProto(arg) 57 | return d 58 | 59 | 60 | class Node(object): 61 | def __init__(self, 62 | name, # type: Optional[Text] 63 | op_type, # type: Text 64 | attrs, # type: Dict[Text, AttributeValue] 65 | inputs, # type: List[Text] 66 | outputs, # type: List[Text] 67 | ): 68 | # type: (...) -> None 69 | self.name = name 70 | self.op_type = op_type 71 | self.attrs = attrs 72 | self.inputs = inputs 73 | self.outputs = outputs 74 | self.input_tensors = {} # type: Dict[Text, np._ArrayLike[Any]] 75 | self.parents = [] # type: List[Node] 76 | self.children = [] # type: List[Node] 77 | self.metadata = {} # type: Dict[Any, Any] 78 | 79 | def add_parent(self, parent_node): # type: (Node) -> None 80 | assert parent_node not in self.parents 81 | self.parents.append(parent_node) 82 | if self not in parent_node.children: 83 | parent_node.children.append(self) 84 | 85 | def add_child(self, child_node): # type: (Node) -> None 86 | assert child_node not in self.children 87 | self.children.append(child_node) 88 | if self not in child_node.parents: 89 | child_node.parents.append(self) 90 | 91 | def get_only_parent(self): # type: () -> Node 92 | if len(self.parents) != 1: 93 | raise ValueError('Node ({}) expected to have 1 parent. Found {}.' 94 | .format(self, len(self.parents))) 95 | return self.parents[0] 96 | 97 | @staticmethod 98 | def from_onnx(node): # type: (NodeProto) -> Node 99 | attrs = Attributes.from_onnx(node.attribute) 100 | name = Text(node.name) 101 | if len(name) == 0: 102 | name = "_".join(node.output) 103 | return Node( 104 | name, node.op_type, attrs, list(node.input), list(node.output) 105 | ) 106 | 107 | 108 | class Graph(object): 109 | def __init__(self, 110 | nodes, # type: List[Node] 111 | inputs, # type: List[EdgeInfo] 112 | outputs, # type: List[EdgeInfo] 113 | shape_dict, # type: Dict[Text,Tuple[int,...]] 114 | ): 115 | # type: (...) 
-> None 116 | self.nodes = nodes 117 | self.inputs = inputs 118 | self.outputs = outputs 119 | self.shape_dict = shape_dict # data blob name to its shape 120 | 121 | # data blob name to the list of op types it feeds into 122 | self.blob_to_op_type = {} # type: Dict[Text, List[Text]] 123 | # data blob name to the op_type that generates it 124 | self.blob_from_op_type = {} # type: Dict[Text, Text] 125 | 126 | for node_ in nodes: 127 | for input_ in node_.inputs: 128 | if input_ in self.blob_to_op_type: 129 | self.blob_to_op_type[input_].append(node_.op_type) 130 | else: 131 | self.blob_to_op_type[input_] = [node_.op_type] 132 | for output_ in node_.outputs: 133 | if output_ in self.blob_from_op_type: 134 | raise ValueError("Data blob: %s, is generated by more than 1 op" %(output_)) 135 | self.blob_from_op_type[output_] = node_.op_type 136 | 137 | 138 | def transformed(self, transformers): # type: (Iterable[Transformer]) -> Graph 139 | graph = self 140 | for transformer in transformers: 141 | graph = transformer(graph) 142 | return graph 143 | 144 | def has_edge_name(self, name): # type: (Text) -> bool 145 | ''' 146 | Check if name is already used for graph inputs/outputs or for nodes 147 | inputs/outputs 148 | ''' 149 | names = set() 150 | for input in self.inputs: 151 | names.add(input[0]) 152 | for output in self.outputs: 153 | names.add(output[0]) 154 | for node in self.nodes: 155 | names.update(node.inputs) 156 | names.update(node.outputs) 157 | return name in names 158 | 159 | def get_unique_edge_name(self, name): # type: (Text) -> Text 160 | n_ = name 161 | i = 0 162 | while self.has_edge_name(n_): 163 | n_ = "{}_{}".format(name, i) 164 | i += 1 165 | return n_ 166 | 167 | @staticmethod 168 | def from_onnx(graph): # type: (GraphProto) -> Graph 169 | input_tensors = { 170 | t.name: numpy_helper.to_array(t) for t in graph.initializer 171 | } 172 | nodes_ = [] 173 | nodes_by_input = {} # type: Dict[Text, List[Node]] 174 | nodes_by_output = {} 175 | for node in graph.node: 176 | node_ = Node.from_onnx(node) 177 | for input_ in node_.inputs: 178 | if input_ in input_tensors: 179 | node_.input_tensors[input_] = input_tensors[input_] 180 | else: 181 | if input_ in nodes_by_input: 182 | input_nodes = nodes_by_input[input_] 183 | else: 184 | input_nodes = [] 185 | nodes_by_input[input_] = input_nodes 186 | input_nodes.append(node_) 187 | for output_ in node_.outputs: 188 | nodes_by_output[output_] = node_ 189 | nodes_.append(node_) 190 | 191 | inputs = [] 192 | for i in graph.input: 193 | if i.name not in input_tensors: 194 | inputs.append(_input_from_onnx_input(i)) 195 | 196 | outputs = [] 197 | for o in graph.output: 198 | outputs.append(_input_from_onnx_input(o)) 199 | 200 | for node_ in nodes_: 201 | for input_ in node_.inputs: 202 | if input_ in nodes_by_output: 203 | node_.parents.append(nodes_by_output[input_]) 204 | for output_ in node_.outputs: 205 | if output_ in nodes_by_input: 206 | node_.children.extend(nodes_by_input[output_]) 207 | 208 | # Dictionary to hold the "value_info" field from ONNX graph 209 | shape_dict = {} # type: Dict[Text,Tuple[int,...]] 210 | 211 | def extract_value_info(shape_dict, # type: Dict[Text,Tuple[int,...]] 212 | value_info, # type: ValueInfoProto[...] 213 | ): 214 | # type: (...) 
-> None 215 | shape_dict[value_info.name] = tuple([int(dim.dim_value) for dim in value_info.type.tensor_type.shape.dim]) 216 | 217 | for value_info in graph.value_info: 218 | extract_value_info(shape_dict, value_info) 219 | for value_info in graph.input: 220 | extract_value_info(shape_dict, value_info) 221 | for value_info in graph.output: 222 | extract_value_info(shape_dict, value_info) 223 | 224 | 225 | return Graph(nodes_, inputs, outputs, shape_dict) 226 | -------------------------------------------------------------------------------- /onnx2caffe/_operators.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | from caffe import params as P 6 | import math 7 | import numpy as np 8 | from ._graph import Node, Graph 9 | from MyCaffe import Function as myf 10 | 11 | def _compare(a, b, encoding="utf8"): #type: (Text, Text, Text) -> bool 12 | if isinstance(a, bytes): 13 | a = a.decode(encoding) 14 | if isinstance(b, bytes): 15 | b = b.decode(encoding) 16 | return a == b 17 | 18 | def make_input(input): 19 | name = input[0] 20 | output = input[0] 21 | output = [output] 22 | shape = input[2] 23 | shape = list(shape) 24 | input_layer = myf("Input", name, [], output, input_param=dict(shape=dict(dim=shape))) 25 | return input_layer 26 | 27 | def _convert_conv(node, graph, err): 28 | weight_name = node.inputs[1] 29 | input_name = str(node.inputs[0]) 30 | output_name = str(node.outputs[0]) 31 | node_name = node.name 32 | W = None 33 | if weight_name in node.input_tensors: 34 | W = node.input_tensors[weight_name] 35 | else: 36 | err.missing_initializer(node, 37 | "Weight tensor: {} not found in the graph initializer".format(weight_name,)) 38 | is_deconv = False 39 | if node.op_type.endswith("Transpose"): 40 | is_deconv = True 41 | bias_flag = False 42 | bias = None 43 | if len(node.inputs) > 2: 44 | bias = node.input_tensors[node.inputs[2]] 45 | bias_flag = True 46 | dilations = node.attrs.get("dilations", [1, 1]) 47 | # groups = 1 48 | groups = node.attrs.get("group", 1) 49 | kernel_shape = node.attrs["kernel_shape"] 50 | pads = node.attrs.get("pads", [0, 0, 0, 0]) 51 | strides = node.attrs["strides"] 52 | 53 | layer = myf("Convolution", node_name, [input_name], [output_name], 54 | kernel_h = kernel_shape[0],kernel_w = kernel_shape[1], 55 | stride_h=strides[0], stride_w = strides[1], group = groups, 56 | pad_h = pads[0], pad_w = pads[1], 57 | num_output=W.shape[0], dilation = dilations[0], bias_term = bias_flag) 58 | 59 | graph.channel_dims[output_name] = W.shape[0] 60 | return layer 61 | 62 | def _convert_relu(node,graph,err): 63 | input_name = str(node.inputs[0]) 64 | output_name = str(node.outputs[0]) 65 | name = str(node.name) 66 | 67 | if input_name==output_name: 68 | inplace = True 69 | else: 70 | inplace = False 71 | 72 | layer = myf("ReLU",name,[input_name],[output_name],in_place=inplace) 73 | # l_top_relu1 = L.ReLU(l_bottom, name=name, in_place=True) 74 | 75 | graph.channel_dims[output_name] = graph.channel_dims[input_name] 76 | 77 | return layer 78 | 79 | def _convert_sigmoid(node,graph,err): 80 | input_name = str(node.inputs[0]) 81 | output_name = str(node.outputs[0]) 82 | name = str(node.name) 83 | 84 | if input_name==output_name: 85 | inplace = True 86 | else: 87 | inplace = False 88 | 89 | layer = myf("Sigmoid",name,[input_name],[output_name],in_place=inplace) 90 | # l_top_relu1 = 
L.ReLU(l_bottom, name=name, in_place=True) 91 | 92 | graph.channel_dims[output_name] = graph.channel_dims[input_name] 93 | 94 | return layer 95 | 96 | def _convert_BatchNorm(node,graph,err): 97 | epsilon = node.attrs.get("epsilon", 1e-5) 98 | scale = node.input_tensors[node.inputs[1]] 99 | bias = node.input_tensors[node.inputs[2]] 100 | mean = node.input_tensors[node.inputs[3]] 101 | var = node.input_tensors[node.inputs[4]] 102 | node_name = node.name 103 | 104 | input_name = str(node.inputs[0]) 105 | output_name = str(node.outputs[0]) 106 | 107 | if input_name==output_name: 108 | inplace = True 109 | else: 110 | inplace = False 111 | 112 | bn_layer = myf("BatchNorm", node_name+"_bn",[input_name],[output_name],eps = epsilon, use_global_stats = True, in_place=inplace) 113 | scale_layer = myf("Scale", node_name, [output_name],[output_name],in_place=True,bias_term=True) 114 | 115 | graph.channel_dims[output_name] = graph.channel_dims[input_name] 116 | 117 | return bn_layer,scale_layer 118 | 119 | def _convert_Add(node,graph,err): 120 | input_name_list = [str(i) for i in node.inputs] 121 | output_name = str(node.outputs[0]) 122 | node_name = node.name 123 | 124 | max_dim = 0 125 | for name in input_name_list: 126 | if graph.channel_dims[name]>max_dim: 127 | max_dim = graph.channel_dims[name] 128 | 129 | if 'broadcast' in node.attrs: 130 | if node.attrs['broadcast'] == 1: 131 | input_node_number = len(input_name_list) 132 | if input_node_number !=2: 133 | return err.unsupported_op_configuration(node, "Broadcast Add must has 2 input, not {}".format(input_node_number)) 134 | axis = node.attrs['axis'] 135 | flat_layer = myf("Flatten",node_name+'_flat',[input_name_list[1]],[output_name+'_flat']) 136 | layer = myf("Bias", node_name, [input_name_list[0],output_name+'_flat'], [output_name], axis = axis) 137 | # layer = myf("Bias", node_name, input_name_list, [output_name], bias_term = False, axis = axis) 138 | graph.channel_dims[output_name] = graph.channel_dims[input_name_list[0]] 139 | return flat_layer,layer 140 | 141 | layer = myf("Eltwise",node_name,input_name_list,[output_name],operation=P.Eltwise.SUM) 142 | graph.channel_dims[output_name] = graph.channel_dims[input_name_list[0]] 143 | return layer 144 | 145 | def _convert_Mul(node,graph,err): 146 | input_name_list = [str(i) for i in node.inputs] 147 | output_name = str(node.outputs[0]) 148 | node_name = node.name 149 | 150 | # max_dim = 0 151 | # for name in input_name_list: 152 | # if graph.channel_dims[name]>max_dim: 153 | # max_dim = graph.channel_dims[name] 154 | 155 | if 'broadcast' in node.attrs: 156 | if node.attrs['broadcast'] == 1: 157 | input_node_number = len(input_name_list) 158 | if input_node_number !=2: 159 | return err.unsupported_op_configuration(node, "Broadcast Mul must has 2 input, not {}".format(input_node_number)) 160 | axis = node.attrs['axis'] 161 | flat_layer = myf("Flatten",node_name+'_flat',[input_name_list[1]],[output_name+'_flat']) 162 | layer = myf("Scale", node_name, [input_name_list[0],output_name+'_flat'], [output_name], bias_term = False, axis = axis) 163 | graph.channel_dims[output_name] = graph.channel_dims[input_name_list[0]] 164 | return flat_layer,layer 165 | 166 | layer = myf("Eltwise",node_name,input_name_list,[output_name],operation=P.Eltwise.PROD) 167 | graph.channel_dims[output_name] = graph.channel_dims[input_name_list[0]] 168 | return layer 169 | 170 | def _convert_Reshape(node,graph,err): 171 | node_name = node.name 172 | input_name = str(node.inputs[0]) 173 | output_name = str(node.outputs[0]) 174 
| if len(node.inputs)==1: 175 | shape = tuple(node.attrs.get('shape', ())) 176 | else: 177 | shape = tuple(node.input_tensors[node.inputs[1]]) 178 | # if shape == (): 179 | 180 | 181 | if input_name==output_name: 182 | inplace = True 183 | else: 184 | inplace = False 185 | if len(shape) == 2: 186 | layer = myf("Flatten",node_name,[input_name],[output_name],in_place=inplace) 187 | graph.channel_dims[output_name] = shape[1] 188 | return layer 189 | elif len(shape) == 4: 190 | graph.channel_dims[output_name] = shape[1] 191 | layer = myf("Reshape", node_name, [input_name], [output_name], reshape_param = dict(shape=dict(dim=list(shape)))) 192 | return layer 193 | else: 194 | return err.unsupported_op_configuration(node, "Reshape dimention number shall be 2 or 4") 195 | 196 | def _convert_Flatten(node,graph,err): 197 | node_name = node.name 198 | input_name = str(node.inputs[0]) 199 | output_name = str(node.outputs[0]) 200 | # shape = tuple(node.attrs.get('shape', ())) 201 | if input_name==output_name: 202 | inplace = True 203 | else: 204 | inplace = False 205 | layer = myf("Flatten", node_name, [input_name], [output_name], in_place=inplace) 206 | # graph.channel_dims[output_name] = shape[1] 207 | return layer 208 | 209 | def _convert_pool(node,graph,err): 210 | node_name = node.name 211 | input_name = str(node.inputs[0]) 212 | output_name = str(node.outputs[0]) 213 | if node.op_type.endswith("MaxPool"): 214 | pool_type = P.Pooling.MAX 215 | elif node.op_type.endswith("AveragePool"): 216 | pool_type = P.Pooling.AVE 217 | else: 218 | return err.unsupported_op_configuration(node, "Unsupported pool type") 219 | 220 | kernel_shape = node.attrs["kernel_shape"] 221 | strides = node.attrs.get('strides', [1, 1]) 222 | pads = node.attrs.get('pads', [0, 0, 0, 0]) 223 | 224 | layer = myf("Pooling",node_name,[input_name],[output_name],pooling_param = dict(pool = pool_type, 225 | kernel_h = kernel_shape[0], 226 | kernel_w = kernel_shape[1], 227 | stride_h = strides[0], 228 | stride_w = strides[1], 229 | pad_h = pads[0], 230 | pad_w = pads[1])) 231 | graph.channel_dims[output_name] = graph.channel_dims[input_name] 232 | return layer 233 | 234 | def _convert_dropout(node,graph,err): 235 | node_name = node.name 236 | input_name = str(node.inputs[0]) 237 | output_name = str(node.outputs[0]) 238 | ratio = node.attrs.get('ratio', 0.5) 239 | layer = myf("Dropout", node_name, [input_name], [output_name], dropout_ratio =ratio) 240 | graph.channel_dims[output_name] = graph.channel_dims[input_name] 241 | return layer 242 | 243 | def _convert_gemm(node,graph,err): 244 | node_name = node.name 245 | input_name = str(node.inputs[0]) 246 | output_name = str(node.outputs[0]) 247 | weight_name = node.inputs[1] 248 | if weight_name in node.input_tensors: 249 | W = node.input_tensors[weight_name] 250 | else: 251 | err.missing_initializer(node, 252 | "Weight tensor: {} not found in the graph initializer".format(weight_name, )) 253 | return 254 | 255 | if node.attrs["broadcast"] != 1 or node.attrs["transB"] != 1: 256 | return err.unsupported_op_configuration(node,"Gemm is supported only for inner_product layer") 257 | 258 | b = None 259 | bias_flag = False 260 | if len(node.inputs) > 2: 261 | b = node.input_tensors[node.inputs[2]] 262 | 263 | if len(W.shape) != 2 or (b is not None and len(b.shape) != 1): 264 | return err.unsupported_op_configuration(node, "Gemm is supported only for inner_product layer") 265 | if b is not None: 266 | bias_flag = True 267 | if W.shape[0] != b.shape[0]: 268 | return 
err.unsupported_op_configuration(node, 269 | "Gemm is supported only for inner_product layer") 270 | 271 | layer = myf("InnerProduct",node_name,[input_name],[output_name],num_output = W.shape[0],bias_term = bias_flag) 272 | graph.channel_dims[output_name] = W.shape[0] 273 | 274 | return layer 275 | 276 | def _convert_upsample(node,graph,err): 277 | factor = int(node.attrs["height_scale"]) 278 | node_name = node.name 279 | input_name = str(node.inputs[0]) 280 | output_name = str(node.outputs[0]) 281 | # input_shape = graph.shape_dict[input_name] 282 | # channels = input_shape[1] 283 | channels = graph.channel_dims[input_name] 284 | pad = int(math.ceil((factor - 1) / 2.)) 285 | # layer = myf("Deconvolution", node_name, [input_name], [output_name], 286 | # kernel_size=2 * factor - factor % 2, 287 | # stride=factor, group=channels, 288 | # pad = pad, num_output=channels, bias_term = False) 289 | mode = node.attrs["mode"] 290 | #https://github.com/pytorch/pytorch/issues/6900 291 | if mode=="bilinear": 292 | layer = myf("Deconvolution", node_name, [input_name], [output_name], 293 | convolution_param=dict( 294 | num_output=channels, 295 | kernel_size=2 * factor - factor % 2, 296 | stride=factor, 297 | pad=pad, 298 | group=channels, 299 | bias_term=False, 300 | weight_filler=dict(type="bilinear_upsampling") 301 | )) 302 | else: 303 | layer = myf("Deconvolution", node_name, [input_name], [output_name], 304 | convolution_param=dict( 305 | num_output=channels, 306 | kernel_size=factor, 307 | stride=factor, 308 | group=channels, 309 | bias_term=False, 310 | )) 311 | 312 | graph.channel_dims[output_name] = graph.channel_dims[input_name] 313 | return layer 314 | 315 | def _convert_concat(node,graph,err): 316 | node_name = node.name 317 | input_name_list = [str(i) for i in node.inputs] 318 | output_name = str(node.outputs[0]) 319 | axis = node.attrs.get("axis", 1) 320 | 321 | layer = myf('Concat', node_name, input_name_list, [output_name], axis = axis) 322 | if axis == 1: 323 | dim = 0 324 | for name in input_name_list: 325 | dim+=graph.channel_dims[name] 326 | graph.channel_dims[output_name] = dim 327 | else: 328 | graph.channel_dims[output_name] = graph.channel_dims[input_name_list[0]] 329 | 330 | return layer 331 | 332 | def _convert_conv_transpose(node,graph,err): 333 | input_name = str(node.inputs[0]) 334 | output_name = str(node.outputs[0]) 335 | node_name = node.name 336 | weight_name = node.inputs[1] 337 | W = None 338 | if weight_name in node.input_tensors: 339 | W = node.input_tensors[weight_name] 340 | else: 341 | err.missing_initializer(node, 342 | "Weight tensor: {} not found in the graph initializer".format(weight_name,)) 343 | bias_flag = False 344 | bias = None 345 | if len(node.inputs) > 2: 346 | bias = node.input_tensors[node.inputs[2]] 347 | bias_flag = True 348 | dilations = node.attrs.get("dilations", [1, 1]) 349 | # groups = 1 350 | groups = node.attrs.get("group", 1) 351 | kernel_shape = node.attrs["kernel_shape"] 352 | pads = node.attrs.get("pads", [0, 0, 0, 0]) 353 | strides = node.attrs["strides"] 354 | 355 | layer = myf('Deconvolution', node_name, [input_name], [output_name], 356 | convolution_param=dict( 357 | num_output=W.shape[1], 358 | kernel_h=kernel_shape[0],kernel_w=kernel_shape[1], 359 | stride_h=strides[0],stride_w = strides[1], 360 | group=groups, 361 | pad_h=pads[0], pad_w=pads[1], 362 | bias_term=bias_flag, 363 | )) 364 | 365 | graph.channel_dims[output_name] = W.shape[1] 366 | return layer 367 | 368 | # l_top = L.Deconvolution( 369 | # l_bottom, 370 | # name=name, 
371 | # convolution_param=dict( 372 | # num_output=W.shape[1], 373 | # kernel_h=kernel_h, 374 | # kernel_w=kernel_w, 375 | # stride_h=stride_h, 376 | # stride_w=stride_w, 377 | # pad_h=pad_h, 378 | # pad_w=pad_w, 379 | # group=groups, 380 | # bias_term=bias_term)) 381 | 382 | 383 | 384 | _ONNX_NODE_REGISTRY = { 385 | "Conv": _convert_conv, 386 | "Relu": _convert_relu, 387 | "BatchNormalization": _convert_BatchNorm, 388 | "Add": _convert_Add, 389 | "Mul": _convert_Mul, 390 | "Reshape": _convert_Reshape, 391 | "MaxPool": _convert_pool, 392 | "AveragePool": _convert_pool, 393 | "Dropout": _convert_dropout, 394 | "Gemm": _convert_gemm, 395 | "Upsample": _convert_upsample, 396 | "Concat": _convert_concat, 397 | "ConvTranspose": _convert_conv_transpose, 398 | "Sigmoid": _convert_sigmoid, 399 | "Flatten": _convert_Flatten, 400 | } 401 | -------------------------------------------------------------------------------- /onnx2caffe/_transformers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from typing import Sequence, Text, Dict, List 7 | import numpy as np 8 | 9 | from onnx import TensorProto 10 | 11 | from ._graph import Graph, Node 12 | 13 | 14 | class NodesFuser(object): 15 | ''' 16 | An abstract helper for merging nodes 17 | ''' 18 | def __init__(self, 19 | num_nodes, # type: int 20 | ): 21 | # type: (...) -> None 22 | assert num_nodes >= 2, "Algorithm only works if fusing multiple nodes" 23 | self.num_nodes = num_nodes 24 | 25 | def __call__(self, graph): # type: (Graph) -> Graph 26 | nodes = graph.nodes 27 | merged_nodes = {} 28 | for node in nodes: 29 | nodes_window = [] # type: List[Node] 30 | n = node 31 | for _ in range(self.num_nodes - 1): 32 | if len(n.parents) != 1: 33 | # We're only fusing nodes with single parents 34 | break 35 | p = n.get_only_parent() 36 | if len(p.children) != 1: 37 | # We can only fuse a node if its parent's 38 | # value isn't used by any other node. 
39 | break 40 | nodes_window.insert(0, n) 41 | n = p 42 | if len(nodes_window) > 0: 43 | # add parent of chained nodes 44 | first = nodes_window[0] 45 | p = first.get_only_parent() 46 | if len(p.children) == 1: 47 | nodes_window.insert(0, p) 48 | if len(nodes_window) != self.num_nodes: 49 | continue 50 | if not self.is_eligible(graph, nodes_window): 51 | continue 52 | merged = self.merge(graph, nodes_window) 53 | first, last = nodes_window[0], nodes_window[-1] 54 | for parent in first.parents: 55 | parent.children.remove(first) 56 | if merged[0] not in parent.children: 57 | parent.add_child(merged[0]) 58 | for child in last.children: 59 | child.parents.remove(last) 60 | if merged[-1] not in child.parents: 61 | child.add_parent(merged[-1]) 62 | for n in nodes_window: 63 | merged_nodes[n.name] = merged 64 | 65 | transformed_nodes = [] 66 | added_merged = [] # type: List[Node] 67 | for node in nodes: 68 | if node.name in merged_nodes: 69 | merged = merged_nodes[node.name] 70 | if merged[0] not in added_merged: 71 | for n in merged: 72 | transformed_nodes.append(n) 73 | added_merged.append(merged[0]) 74 | else: 75 | transformed_nodes.append(node) 76 | return Graph(transformed_nodes, graph.inputs, graph.outputs, graph.shape_dict) 77 | 78 | def is_eligible(self, graph, nodes): # type: (Graph, Sequence[Node]) -> bool 79 | '''Returns true if this subset of nodes is eligible for fusion.''' 80 | raise NotImplementedError('Must be implemented by subclass.') 81 | 82 | def merge(self, graph, nodes): # type: (Graph, Sequence[Node]) -> Sequence[Node] 83 | '''Merge nodes''' 84 | nodes[0].outputs = nodes[-1].outputs 85 | return [nodes[0]] 86 | 87 | 88 | class ConvAddFuser(NodesFuser): 89 | ''' 90 | Fuses Add layer into parent convolution layer. 91 | ''' 92 | def __init__(self): # type: () -> None 93 | super(ConvAddFuser, self).__init__(2) 94 | 95 | def is_eligible(self, graph, nodes): # type: (Graph, Sequence[Node]) -> bool 96 | parent, child = nodes[0], nodes[1] 97 | if parent.op_type != 'Conv': 98 | return False 99 | if child.op_type != 'Add': 100 | return False 101 | if 'broadcast' not in child.attrs: 102 | return False 103 | if 'axis' not in child.attrs: 104 | return False 105 | if parent.inputs[1] not in parent.input_tensors: 106 | return False 107 | if len(parent.inputs) > 2 and parent.inputs[2] not in parent.input_tensors: 108 | return False 109 | if child.inputs[1] not in child.input_tensors: 110 | return False 111 | 112 | broadcast = child.attrs['broadcast'] 113 | if broadcast != 1: 114 | return False 115 | 116 | axis = child.attrs['axis'] 117 | if axis != 1: 118 | return False 119 | 120 | return True 121 | 122 | def merge(self, graph, nodes): # type: (Graph, Sequence[Node]) -> Sequence[Node] 123 | parent, child = nodes[0], nodes[1] 124 | output_channels = parent.input_tensors[parent.inputs[1]].shape[0] 125 | if len(parent.inputs) > 2: 126 | bias_input_name = parent.inputs[2] 127 | bias = parent.input_tensors[bias_input_name] 128 | else: 129 | bias_input_name = "{}_bias".format(parent.name,) 130 | parent.inputs.append(bias_input_name) 131 | bias = np.zeros( 132 | (output_channels,), dtype=np.float32 133 | ) 134 | parent.input_tensors[bias_input_name] = bias 135 | bias = bias + child.input_tensors[child.inputs[1]] 136 | parent.input_tensors[bias_input_name] = bias 137 | parent.outputs = child.outputs 138 | parent.children.remove(child) 139 | child.parents.remove(parent) 140 | return [parent] 141 | 142 | 143 | class BNBroadcastedMulFuser(NodesFuser): 144 | ''' 145 | Fuses Mul into BatchNorm 146 | 
''' 147 | def __init__(self): # type: () -> None 148 | super(BNBroadcastedMulFuser, self).__init__(2) 149 | 150 | def is_eligible(self, graph, nodes): # type: (Graph, Sequence[Node]) -> bool 151 | parent, child = nodes[0], nodes[1] 152 | if parent.op_type != 'BatchNormalization': 153 | return False 154 | if child.op_type != 'Mul': 155 | return False 156 | if "broadcast" not in child.attrs: 157 | return False 158 | if child.attrs["broadcast"] != 1: 159 | return False 160 | if "axis" not in child.attrs: 161 | return False 162 | if child.attrs["axis"] != 1: 163 | return False 164 | if child.inputs[1] not in child.input_tensors: 165 | return False 166 | if parent.inputs[1] not in parent.input_tensors: 167 | return False 168 | if parent.inputs[2] not in parent.input_tensors: 169 | return False 170 | return True 171 | 172 | def merge(self, graph, nodes): # type: (Graph, Sequence[Node]) -> Sequence[Node] 173 | parent, child = nodes[0], nodes[1] 174 | weight = parent.input_tensors[parent.inputs[1]] 175 | bias = parent.input_tensors[parent.inputs[2]] 176 | W = child.input_tensors[child.inputs[1]] 177 | parent.input_tensors[parent.inputs[1]] = np.multiply(weight, W) 178 | parent.input_tensors[parent.inputs[2]] = np.multiply(bias, W) 179 | parent.outputs = child.outputs 180 | parent.children.remove(child) 181 | child.parents.remove(parent) 182 | return [parent] 183 | 184 | 185 | class BNBroadcastedAddFuser(NodesFuser): 186 | ''' 187 | Fuses Add into BatchNorm 188 | ''' 189 | def __init__(self): # type: () -> None 190 | super(BNBroadcastedAddFuser, self).__init__(2) 191 | 192 | def is_eligible(self, graph, nodes): # type: (Graph, Sequence[Node]) -> bool 193 | parent, child = nodes[0], nodes[1] 194 | if parent.op_type != 'BatchNormalization': 195 | return False 196 | if child.op_type != 'Add': 197 | return False 198 | if "broadcast" not in child.attrs: 199 | return False 200 | if child.attrs["broadcast"] != 1: 201 | return False 202 | if "axis" not in child.attrs: 203 | return False 204 | if child.attrs["axis"] != 1: 205 | return False 206 | if len(child.inputs) != 2: 207 | return False 208 | if child.inputs[1] not in child.input_tensors: 209 | return False 210 | if parent.inputs[2] not in parent.input_tensors: 211 | return False 212 | return True 213 | 214 | def merge(self, graph, nodes): # type: (Graph, Sequence[Node]) -> Sequence[Node] 215 | parent, child = nodes[0], nodes[1] 216 | bias = parent.input_tensors[parent.inputs[2]] 217 | b = child.input_tensors[child.inputs[1]] 218 | parent.input_tensors[parent.inputs[2]] = bias + b 219 | parent.outputs = child.outputs 220 | parent.children.remove(child) 221 | child.parents.remove(parent) 222 | return [parent] 223 | 224 | 225 | class DropoutRemover(NodesFuser): 226 | ''' 227 | Removes Dropout layer 228 | ''' 229 | def __init__(self): # type: () -> None 230 | super(DropoutRemover, self).__init__(2) 231 | 232 | def is_eligible(self, graph, nodes): # type: (Graph, Sequence[Node]) -> bool 233 | child = nodes[1] 234 | return child.op_type == "Dropout" 235 | 236 | def merge(self, graph, nodes): # type: (Graph, Sequence[Node]) -> Sequence[Node] 237 | parent, child = nodes[0], nodes[1] 238 | parent.children.remove(child) 239 | child.parents.remove(parent) 240 | parent.outputs = child.outputs 241 | return [parent] 242 | 243 | 244 | class ReshapeInitTensorFuser(object): 245 | ''' 246 | Fuses Reshape operator if it is used only to reshape blob in 247 | graph initializer. We can reshape here instead of runtime. 
248 | ''' 249 | 250 | def __call__(self, graph): # type: (Graph) -> Graph 251 | nodes = graph.nodes 252 | removed = [] 253 | for node in nodes: 254 | if node.op_type != 'Reshape': 255 | continue 256 | if not (len(node.input_tensors) == 2 or len(node.input_tensors) == 1): 257 | continue 258 | tensor_name = node.inputs[0] 259 | if tensor_name not in node.input_tensors: 260 | continue 261 | if len(node.inputs) > 1: 262 | shape_name = node.inputs[1] 263 | if shape_name not in node.input_tensors: 264 | continue 265 | is_non_constant_parent = False 266 | if len(node.parents) > 0: 267 | for parent in node.parents: 268 | if parent.op_type != 'Constant': 269 | is_non_constant_parent = True 270 | break 271 | if is_non_constant_parent: 272 | continue 273 | 274 | removed.append(node) 275 | output_name = node.outputs[0] 276 | 277 | tensor = node.input_tensors[tensor_name] 278 | if 'shape' in node.attrs: 279 | shape = tuple(node.attrs["shape"]) 280 | else: 281 | shape = node.input_tensors[shape_name] # type: ignore 282 | 283 | # ONNX spec supports setting dimension to '0', in which case 284 | # it should be taken from old dimension. 285 | # This isn't supported in numpy, so don't transform. 286 | # TODO Should we support this case? 287 | if any([s == 0 for s in shape]): 288 | continue 289 | 290 | reshaped_tensor = tensor.reshape(shape) 291 | 292 | for child in node.children: 293 | child.parents.remove(node) 294 | child.input_tensors[output_name] = reshaped_tensor 295 | 296 | transformed_nodes = [node for node in nodes if node not in removed] 297 | return Graph(transformed_nodes, graph.inputs, graph.outputs, graph.shape_dict) 298 | 299 | 300 | class OutputRenamer(object): 301 | ''' 302 | Rename outputs according to mapping 303 | ''' 304 | def __init__(self, 305 | mapping, # type: Dict[Text, Text] 306 | ): 307 | # type: (...) 
-> None 308 | self.mapping = mapping 309 | 310 | def __call__(self, graph): # type: (Graph) -> Graph 311 | mapping = self.mapping.copy() 312 | nodes = graph.nodes 313 | for node in nodes: 314 | for i in range(len(node.outputs)): 315 | output = node.outputs[i] 316 | if output not in mapping: 317 | continue 318 | node.outputs[i] = mapping[output] 319 | for child in node.children: 320 | for j in range(len(child.inputs)): 321 | input_ = child.inputs[j] 322 | if input_ != output: 323 | continue 324 | child.inputs[j] = mapping[output] 325 | del mapping[output] 326 | if len(mapping) == 0: 327 | break 328 | return graph 329 | 330 | 331 | class PixelShuffleFuser(NodesFuser): 332 | ''' 333 | Fuses 3 operators reshape->transpose->reshape which is equivalent to 334 | pytorch's pixel_shuffle layer 335 | ''' 336 | def __init__(self): # type: () -> None 337 | super(PixelShuffleFuser, self).__init__(3) 338 | self.num_added = 0 339 | 340 | def is_eligible(self, graph, nodes): # type: (Graph, Sequence[Node]) -> bool 341 | if nodes[0].op_type != 'Reshape': 342 | return False 343 | if nodes[1].op_type != 'Transpose': 344 | return False 345 | if nodes[2].op_type != 'Reshape': 346 | return False 347 | if nodes[0].inputs[1] not in nodes[0].input_tensors: 348 | return False 349 | if nodes[2].inputs[1] not in nodes[2].input_tensors: 350 | return False 351 | 352 | shape = nodes[0].input_tensors[nodes[0].inputs[1]] 353 | if len(shape) != 6: 354 | return False 355 | if shape[0] != 1 or shape[2] != shape[3]: 356 | return False 357 | 358 | input_channels = shape[1] 359 | scale_factor = shape[2] 360 | input_height = shape[4] 361 | input_width = shape[5] 362 | 363 | if nodes[1].attrs.get('perm', []) != [0, 1, 4, 2, 5, 3]: 364 | return False 365 | 366 | shape = nodes[2].input_tensors[nodes[2].inputs[1]] 367 | if len(shape) != 4: 368 | return False 369 | 370 | output_channels = shape[1] 371 | output_height = shape[2] 372 | output_width = shape[3] 373 | if input_channels != output_channels: 374 | return False 375 | if (input_height * scale_factor) != output_height: 376 | return False 377 | if (input_width * scale_factor) != output_width: 378 | return False 379 | 380 | return True 381 | 382 | def get_unique_edge_name(self, graph, name): # type: (Graph, Text) -> Text 383 | self.num_added += 1 384 | return graph.get_unique_edge_name(name + '_' + str(self.num_added)) 385 | 386 | def merge(self, graph, nodes): # type: (Graph, Sequence[Node]) -> Sequence[Node] 387 | ''' 388 | Pixel shuffle is implemented using 3 operators: 389 | - Reshape(1, channels, scale, scale, height, width) 390 | - Transpose(0, 1, 4, 2, 5, 3) 391 | - Reshape(1, channels, height * scale, width * scale) 392 | CoreML Reshape and Transpose layers don't support tensors with more 393 | than 4 dimensions. 
Thus we change above sequence of operators to the 394 | following equivalent sequence: 395 | - Reshape(channels, scale * scale, height, width) 396 | - Transpose(0, 2, 1, 3) 397 | - Reshape(channels * height, scale, scale, width) 398 | - Transpose(0, 1, 3, 2) 399 | - Reshape(1, channels, height * scale, width * scale) 400 | ''' 401 | reshape_1 = nodes[0] 402 | transpose_1 = nodes[1] 403 | transpose_1.children = [] 404 | 405 | shape = reshape_1.input_tensors[reshape_1.inputs[1]] 406 | 407 | channels = shape[1] 408 | scale = shape[2] 409 | height = shape[4] 410 | width = shape[5] 411 | 412 | reshape_1.input_tensors[reshape_1.inputs[1]] = np.asarray([channels, scale * scale, height, width]) 413 | transpose_1.attrs['perm'] = [0, 2, 1, 3] 414 | 415 | reshape_output_name = 'pixel_shuffle_reshape' 416 | transpose_output_name = 'pixel_shuffle_transpose' 417 | 418 | transpose_1.outputs = [ 419 | self.get_unique_edge_name(graph, transpose_output_name) 420 | ] 421 | 422 | shape_name_second_reshape = self.get_unique_edge_name(graph, reshape_output_name) 423 | output_name_second_reshape = self.get_unique_edge_name(graph, reshape_output_name) 424 | reshape_2 = Node( 425 | reshape_output_name, 426 | 'Reshape', 427 | {}, 428 | [transpose_1.outputs[0], shape_name_second_reshape], 429 | [output_name_second_reshape] 430 | ) 431 | reshape_2.input_tensors[shape_name_second_reshape] = np.asarray([channels * height, scale, scale, width]) 432 | transpose_1.add_child(reshape_2) 433 | 434 | transpose_2 = Node( 435 | transpose_output_name, 436 | 'Transpose', 437 | {'perm': [0, 1, 3, 2]}, 438 | reshape_2.outputs, 439 | [self.get_unique_edge_name(graph, transpose_output_name)] 440 | ) 441 | reshape_2.add_child(transpose_2) 442 | 443 | final_reshape = nodes[2] 444 | final_reshape.inputs = [transpose_2.outputs[0], nodes[2].inputs[1]] 445 | final_reshape.parents = [] 446 | transpose_2.add_child(final_reshape) 447 | return [reshape_1, transpose_1, reshape_2, transpose_2, final_reshape] 448 | 449 | 450 | class AddModelInputsOutputs(object): 451 | ''' 452 | Expose hidden states of recurrent layers as model inputs and outputs 453 | ''' 454 | def __call__(self, graph): # type: (Graph) -> Graph 455 | input_names = [str(input_[0]) for input_ in graph.inputs] 456 | output_names = [str(output_[0]) for output_ in graph.outputs] 457 | for node in graph.nodes: 458 | if str(node.op_type) == 'LSTM': 459 | input_h = node.inputs[5] if len(node.inputs) > 5 else node.inputs[0] + '_h_input' 460 | input_c = node.inputs[6] if len(node.inputs) > 6 else node.inputs[0] + '_c_input' 461 | output_h = node.outputs[1] if len(node.outputs) > 1 else node.outputs[0] + '_h_output' 462 | output_c = node.outputs[2] if len(node.outputs) > 2 else node.outputs[0] + '_c_output' 463 | h = node.attrs["hidden_size"] 464 | for input_ in [str(input_h), str(input_c)]: 465 | if input_ not in input_names: 466 | graph.inputs.append(tuple((input_, TensorProto.FLOAT, (h,)))) #type: ignore 467 | if input_ not in graph.blob_to_op_type: 468 | graph.blob_to_op_type[input_] = ['LSTM'] 469 | for output_ in [str(output_h), str(output_c)]: 470 | if output_ not in output_names: 471 | graph.outputs.append(tuple((output_, TensorProto.FLOAT, (h,)))) #type: ignore 472 | graph.blob_from_op_type[output_] = 'LSTM' 473 | return graph 474 | 475 | 476 | class ConstantsToInitializers(object): 477 | ''' 478 | Takes onnx Constant nodes and puts the tensor into graph initializers instead. 
479 |     '''
480 |     def __call__(self, graph):  # type: (Graph) -> Graph
481 |         output_names = [str(output_[0]) for output_ in graph.outputs]
482 |         remaining_nodes = []
483 |         for node in graph.nodes:
484 |             if node.op_type != 'Constant' or node.name in output_names:
485 |                 remaining_nodes.append(node)
486 |                 continue
487 |             for child in node.children:
488 |                 child.input_tensors[node.outputs[0]] = node.attrs["value"]
489 | 
490 |         graph.nodes = remaining_nodes
491 |         return graph
492 | 
493 | 
494 | class ImageScalerRemover(object):
495 |     '''
496 |     Removes an ImageScaler layer if it is fed directly by a model input and each of its children has the scaler as its only parent
497 |     '''
498 | 
499 |     def __call__(self, graph):  # type: (Graph) -> Graph
500 |         input_names = [str(input_[0]) for input_ in graph.inputs]
501 |         nodes_to_be_removed = []
502 |         for node in graph.nodes:
503 |             if (node.op_type != 'ImageScaler') or (len(node.parents) != 0) or (node.inputs[0] not in input_names):
504 |                 continue
505 |             is_eligible = True
506 |             for child in node.children:
507 |                 if not (len(child.parents) == 1 and child.inputs[0] == node.outputs[0]):
508 |                     is_eligible = False
509 |                     break
510 |                 child.inputs[0] = node.inputs[0]
511 |                 child.parents = []
512 |             if not is_eligible:
513 |                 continue
514 |             nodes_to_be_removed.append(node.name)
515 | 
516 |         transformed_nodes = []
517 |         for node in graph.nodes:
518 |             if node.name not in nodes_to_be_removed:
519 |                 transformed_nodes.append(node)
520 |         return Graph(transformed_nodes, graph.inputs, graph.outputs, graph.shape_dict)
--------------------------------------------------------------------------------
/onnx2caffe/_weightloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 | # from caffe import params as P
6 | import numpy as np
7 | from ._graph import Node, Graph
8 | 
9 | def _convert_conv(net, node, graph, err):
10 |     weight_name = node.inputs[1]
11 |     input_name = str(node.inputs[0])
12 |     output_name = str(node.outputs[0])
13 |     node_name = node.name
14 |     W = None
15 |     if weight_name in node.input_tensors:
16 |         W = node.input_tensors[weight_name]
17 |     else:
18 |         err.missing_initializer(node,
19 |                                 "Weight tensor: {} not found in the graph initializer".format(weight_name,))
20 |     bias_flag = False
21 |     bias = None
22 |     if len(node.inputs) > 2:
23 |         bias = node.input_tensors[node.inputs[2]]
24 |         bias_flag = True
25 |     # net.params[node_name][0].data = W
26 |     # if bias_flag:
27 |     #     net.params[node_name][1].data = bias
28 |     np.copyto(net.params[node_name][0].data,W,casting='same_kind')
29 |     if bias_flag:
30 |         np.copyto(net.params[node_name][1].data, bias, casting='same_kind')
31 | 
32 | def _convert_relu(net, node, graph, err):
33 |     pass
34 | 
35 | def _convert_sigmoid(net, node, graph, err):
36 |     pass
37 | 
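# --- Added note on the BatchNorm copy below (a fact about Caffe, not code from
# this repo): Caffe's BatchNorm layer keeps three blobs -- [0] the running mean,
# [1] the running variance, [2] a moving-average scale factor. At inference
# Caffe divides the stored statistics by blob [2], so writing 1.0 into it makes
# the copied ONNX mean/var be used unchanged. The learned gamma/beta pair goes
# into the separate Scale layer that _operators._convert_BatchNorm emits.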
38 | def _convert_BatchNorm(net, node, graph, err):
39 |     scale = node.input_tensors[node.inputs[1]]
40 |     bias = node.input_tensors[node.inputs[2]]
41 |     mean = node.input_tensors[node.inputs[3]]
42 |     var = node.input_tensors[node.inputs[4]]
43 |     node_name = node.name
44 |     np.copyto(net.params[node_name + '_bn'][0].data, mean, casting='same_kind')
45 |     np.copyto(net.params[node_name + '_bn'][1].data, var, casting='same_kind')
46 |     net.params[node_name + '_bn'][2].data[...] = 1.0
47 |     np.copyto(net.params[node_name][0].data, scale, casting='same_kind')
48 |     np.copyto(net.params[node_name][1].data, bias, casting='same_kind')
49 |     # net.params[node_name+'_bn'][1].data = var
50 |     # net.params[node_name][0].data = scale
51 |     # net.params[node_name][1].data = bias
52 | 
53 | def _convert_Add(net, node, graph, err):
54 |     pass
55 | 
56 | def _convert_Mul(net, node, graph, err):
57 |     pass
58 | 
59 | def _convert_Reshape(net, node, graph, err):
60 |     pass
61 | 
62 | def _convert_Flatten(net, node, graph, err):
63 |     pass
64 | 
65 | def _convert_pool(net, node, graph, err):
66 |     pass
67 | 
68 | def _convert_dropout(net, node, graph, err):
69 |     pass
70 | 
71 | def _convert_gemm(net, node, graph, err):
72 |     node_name = node.name
73 |     weight_name = node.inputs[1]
74 |     if weight_name in node.input_tensors:
75 |         W = node.input_tensors[weight_name]
76 |     else:
77 |         err.missing_initializer(node,
78 |                                 "Weight tensor: {} not found in the graph initializer".format(weight_name, ))
79 |     if node.attrs["broadcast"] != 1 or node.attrs["transB"] != 1:
80 |         return err.unsupported_op_configuration(node, "Gemm is supported only for inner_product layer")
81 |     b = None
82 |     if len(node.inputs) > 2:
83 |         b = node.input_tensors[node.inputs[2]]
84 |     if len(W.shape) != 2 or (b is not None and len(b.shape) != 1):
85 |         return err.unsupported_op_configuration(node, "Gemm is supported only for inner_product layer")
86 |     if b is not None:
87 |         if W.shape[0] != b.shape[0]:
88 |             return err.unsupported_op_configuration(node, "Gemm is supported only for inner_product layer")
89 |     net.params[node_name][0].data[...] = W
90 |     if b is not None: net.params[node_name][1].data[...] = b  # no bias blob exists when bias_term=False
91 | 
92 | def _convert_upsample(net, node, graph, err):
93 |     mode = node.attrs["mode"]
94 |     node_name = node.name
95 |     if mode == "nearest":
96 |         caffe_params = net.params[node_name][0].data
97 |         weights = np.ones(caffe_params.shape).astype("float32")
98 |         np.copyto(net.params[node_name][0].data, weights, casting='same_kind')
99 |     # net.params[node_name][0].data[]
100 | 
101 | def _convert_concat(net, node, graph, err):
102 |     pass
103 | 
104 | def _convert_conv_transpose(net, node, graph, err):
105 |     weight_name = node.inputs[1]
106 |     input_name = str(node.inputs[0])
107 |     output_name = str(node.outputs[0])
108 |     node_name = node.name
109 |     W = None
110 |     if weight_name in node.input_tensors:
111 |         W = node.input_tensors[weight_name]
112 |     else:
113 |         err.missing_initializer(node,
114 |                                 "Weight tensor: {} not found in the graph initializer".format(weight_name,))
115 |     bias_flag = False
116 |     bias = None
117 |     if len(node.inputs) > 2:
118 |         bias = node.input_tensors[node.inputs[2]]
119 |         bias_flag = True
120 |     # net.params[node_name][0].data = W
121 |     # if bias_flag:
122 |     #     net.params[node_name][1].data = bias
123 |     np.copyto(net.params[node_name][0].data,W,casting='same_kind')
124 |     if bias_flag:
125 |         np.copyto(net.params[node_name][1].data, bias, casting='same_kind')
126 | 
127 | _ONNX_NODE_REGISTRY = {
128 |     "Conv": _convert_conv,
129 |     "Relu": _convert_relu,
130 |     "BatchNormalization": _convert_BatchNorm,
131 |     "Add": _convert_Add,
132 |     "Mul": _convert_Mul,
133 |     "Reshape": _convert_Reshape,
134 |     "MaxPool": _convert_pool,
135 |     "AveragePool": _convert_pool,
136 |     "Dropout": _convert_dropout,
137 |     "Gemm": _convert_gemm,
138 |     "Upsample": _convert_upsample,
139 |     "Concat": _convert_concat,
140 |     "ConvTranspose": _convert_conv_transpose,
141 |     "Sigmoid": _convert_sigmoid,
142 |     "Flatten": _convert_Flatten,
143 | }
144 | 
145 | 
146 | 
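# --- Added note: how this registry pairs with the one in _operators.py, as a
# hedged sketch of the driver loop (the real loop lives in convertCaffe.py,
# which is not shown here and may differ in detail). Conversion is two passes:
# pass 1 calls the _operators registry to emit LayerParameter protos and saves
# the prototxt; pass 2 loads that prototxt as a caffe.Net and calls this
# module's registry to copy the ONNX initializers into the net's blobs:
#
#     net = caffe.Net(prototxt_path, caffe.TEST)
#     for node in graph.nodes:
#         converter_fn = _ONNX_NODE_REGISTRY[node.op_type]
#         converter_fn(net, node, graph, err)
#     net.save(caffemodel_path)
#
# caffe.Net(prototxt, phase) and net.save(path) are standard pycaffe calls.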
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import onnx
3 | import numpy as np
4 | import caffe
5 | caffe.set_mode_cpu()
6 | import importlib
7 | from convertCaffe import convertToCaffe, getGraph
8 | import os
9 | 
10 | def getPytorchModel(name):
11 |     py_model_path = 'model'
12 |     module = importlib.import_module("model_generator."+name)
13 |     var, model = module.get_model_and_input(py_model_path)
14 |     return var, model
15 | 
16 | module_name_list = [
17 |     "broadcast_mul",
18 |     "broadcast_add",
19 |     "googlenet",
20 |     "resnet",
21 |     "MobileNetV2",
22 | ]
23 | 
24 | model_save_dir = 'model'
25 | if not os.path.isdir(model_save_dir):
26 |     os.makedirs(model_save_dir)
27 | 
28 | for module_name in module_name_list:
29 |     print("exporting {} onnx model ...".format(module_name))
30 |     module = importlib.import_module("model_generator"+"."+module_name)
31 |     module.export(model_save_dir)
32 | 
33 |     var, pt_model = getPytorchModel(module_name)
34 |     var_numpy = var.data.numpy()
35 |     pt_model.eval()
36 |     pt_out = pt_model(var)
37 |     pt_out = pt_out.data.numpy()
38 |     onnx_path = os.path.join("model", module_name+'.onnx')
39 |     prototxt_path = os.path.join("model", module_name+'.prototxt')
40 |     caffemodel_path = os.path.join("model", module_name+'.caffemodel')
41 | 
42 |     graph = getGraph(onnx_path)
43 |     print("converting {} to caffe ...".format(module_name))
44 |     caffe_model = convertToCaffe(graph, prototxt_path, caffemodel_path)
45 | 
46 |     input_name = str(graph.inputs[0][0])
47 |     output_name = str(graph.outputs[0][0])
48 | 
49 |     caffe_model.blobs[input_name].data[...] = var_numpy
50 |     net_output = caffe_model.forward()
51 |     caffe_out = net_output[output_name]
52 | 
53 |     minus_result = caffe_out-pt_out
54 |     mse = np.mean(minus_result*minus_result)  # mean squared error, matching the label printed below
55 | 
56 |     print("{} mse between caffe and pytorch model output: {}".format(module_name,mse))
57 | 
58 | 
59 | 
60 | 
61 | 
62 | 
63 | 
64 | 
--------------------------------------------------------------------------------
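A minimal end-to-end usage sketch, assembled from the calls test.py already makes
(getGraph and convertToCaffe come from convertCaffe.py; the MobileNetV2 paths are
illustrative -- the .onnx file ships with the repo, the other two are outputs):

    import caffe
    caffe.set_mode_cpu()
    from convertCaffe import convertToCaffe, getGraph

    onnx_path = "model/MobileNetV2.onnx"
    prototxt_path = "model/MobileNetV2.prototxt"
    caffemodel_path = "model/MobileNetV2.caffemodel"

    graph = getGraph(onnx_path)                                  # parse and transform the ONNX graph
    net = convertToCaffe(graph, prototxt_path, caffemodel_path)  # write prototxt, copy weights

    input_name = str(graph.inputs[0][0])
    net.blobs[input_name].data[...] = 0.0                        # fill with a real image in practice
    out = net.forward()                                          # dict of output blobs, as in test.py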