├── README.md
├── find.py
├── json2prototxt.py
├── mxnet2caffe.py
└── prototxt_basic.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Advanced-Mxnet2Caffe

### Operator Support List

- Convolution
- ChannelwiseConvolution
- BatchNorm
- Activation
- ElementWiseSum
- _Plus
- Concat
- Crop
- Pooling
- Flatten
- FullyConnected
- SoftmaxOutput & SoftmaxFocalOutput
- SoftmaxActivation
- LeakyReLU
- elemwise_add
- UpSampling
- Deconvolution
- Clip
- Reshape

### Tested models

+ Mxnet-SSH
+ MobileNet-V2
+ Resnet-50
+ RetinaFace (with Resnet-50 and MobileNet-0.25 backbones)
+ All models from the Insightface Model Zoo

### Notes & Bugs

The converter is not fully automatic; keep the following caveats in mind:

+ If you convert an UpSampling operator, the converter maps it to a Deconvolution layer in Caffe, and the Deconvolution output channels must be set by hand (in `names_output` in prototxt_basic.py); see the sketch below.
+ If you use a Flatten layer, you need to reconnect the graph manually, because the converted compute graph will be split into two parts.
+ If you convert a detection model, remove the anchor generation from the graph and handle it in post-processing instead.
+ If you run into conversion errors, set the prefix name of your backbone network in mxnet2caffe.py.
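For example, converting an UpSampling node requires telling Caffe how many channels the generated Deconvolution layer should output. A minimal sketch, assuming a RetinaFace-style model whose two upsampling branches each carry 256 channels (these names and values are the defaults already present in prototxt_basic.py; substitute the names of your own UpSampling nodes):

```python
# prototxt_basic.py
# Map each MXNet UpSampling node name to the channel count of its input
# feature map; Caffe's Deconvolution layer needs num_output set explicitly.
names_output = {"rf_c2_upsampling": 256, "rf_c3_upsampling": 256}
```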
--------------------------------------------------------------------------------
/find.py:
--------------------------------------------------------------------------------

from difflib import SequenceMatcher
import json
import collections


def find_backbone(json_path):
    """Guess the backbone prefix: for every pair of consecutive node names,
    take the longest common leading match (up to the first '_'), then return
    the most frequent one."""
    with open(json_path) as json_file:
        jdata = json.load(json_file)

    matches = []
    for i_node in range(0, len(jdata['nodes']) - 1):
        node_i1 = jdata['nodes'][i_node]
        node_i2 = jdata['nodes'][i_node + 1]
        name1 = node_i1['name']
        name2 = node_i2['name']

        match = SequenceMatcher(None, name1, name2).find_longest_match(
            0, name1.find('_'), 0, name2.find('_'))
        matches.append(name1[match.a: match.a + match.size])

    counter = collections.Counter(matches)
    final_match = counter.most_common()[0][0]

    return final_match
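
if __name__ == '__main__':
    # Quick sanity check -- the path below is illustrative; point it at your
    # own exported symbol file. For an InsightFace-style ResNet-50 export this
    # typically prints the shared node-name prefix (e.g. "resnet0").
    print(find_backbone('R50v2/R50v2-symbol.json'))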
--------------------------------------------------------------------------------
/json2prototxt.py:
--------------------------------------------------------------------------------

import argparse
import json
from prototxt_basic import *

parser = argparse.ArgumentParser(description='Convert MXNet json to Caffe prototxt')
parser.add_argument('--mx-json',     type=str, default='R50v2/R50v2-symbol.json')
parser.add_argument('--cf-prototxt', type=str, default='R50v2/R50v2.prototxt')
parser.add_argument('--input_shape', type=str, default='1,3,640,640')
args = parser.parse_args()

with open(args.mx_json) as json_file:
    jdata = json.load(json_file)
# print(jdata)  # debug

with open(args.cf_prototxt, "w") as prototxt_file:
    for i_node in range(0, len(jdata['nodes'])):
        node_i = jdata['nodes'][i_node]
        # Skip parameter variables; keep the 'data' input node.
        if str(node_i['op']) == 'null' and str(node_i['name']) != 'data':
            continue

        print('{}, \top:{}, name:{}'.format(i_node,
                                            node_i['op'].ljust(20),
                                            node_i['name'].ljust(30)))
        info = node_i
        info['top'] = info['name']
        info['bottom'] = []
        info['params'] = []
        for input_idx_i in node_i['inputs']:
            input_i = jdata['nodes'][input_idx_i[0]]
            if str(input_i['op']) != 'null' or str(input_i['name']) == 'data':
                info['bottom'].append(str(input_i['name']))
            if str(input_i['op']) == 'null':
                info['params'].append(str(input_i['name']))
                if not str(input_i['name']).startswith(str(node_i['name'])):
                    print('  use shared weight -> %s' % str(input_i['name']))
                    info['share'] = True

        # The 'data' node carries the network input shape (its op is 'null',
        # so match on the name here).
        if str(node_i['name']) == 'data':
            input_shape = args.input_shape
            for char in ['[', ']', '(', ')']:
                input_shape = input_shape.replace(char, '')
            info['shape'] = [int(item) for item in input_shape.split(',')]

        write_node(prototxt_file, info)

print("*** JSON to PROTOTXT FINISHED ***")
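# For reference, a typical (abridged) node entry in an MXNet symbol JSON --
# the structure the loop above consumes. "inputs" holds indices into the
# same "nodes" list (values here are illustrative):
#
#   {"op": "Convolution", "name": "conv0",
#    "attrs": {"kernel": "(3, 3)", "num_filter": "64", "no_bias": "True"},
#    "inputs": [[0, 0, 0], [1, 0, 0]]}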
--------------------------------------------------------------------------------
/mxnet2caffe.py:
--------------------------------------------------------------------------------

import argparse
import os
import sys

import mxnet as mx

try:
    import caffe
except ImportError:
    # Fall back to a local pycaffe build (path is machine-specific).
    sys.path.append("/Users/yujinke/me/caffe/python")
    import caffe

from find import find_backbone

os.environ["CUDA_VISIBLE_DEVICES"] = '4'

parser = argparse.ArgumentParser(description='Convert MXNet model to Caffe model')
parser.add_argument('--mx-model',    type=str, default='model_mxnet/face/facega2')
parser.add_argument('--mx-epoch',    type=int, default=0)
parser.add_argument('--cf-prototxt', type=str, default='model_caffe/face/facega2.prototxt')
parser.add_argument('--cf-model',    type=str, default='model_caffe/face/facega2.caffemodel')
args = parser.parse_args()

# ------------------------------------------
# Load
_, arg_params, aux_params = mx.model.load_checkpoint(args.mx_model, args.mx_epoch)
net = caffe.Net(args.cf_prototxt, caffe.TEST)

# ------------------------------------------
# Convert
all_keys = sorted(list(arg_params.keys()) + list(aux_params.keys()))

print('----------------------------------\n')
print('ALL KEYS IN MXNET:')
print(all_keys)
print('%d KEYS' % len(all_keys))
print('----------------------------------\n')
print('VALID KEYS:')

# backbone = "hstage1"  # set this manually if find_backbone guesses wrong
backbone = find_backbone(args.mx_model + '-symbol.json')

for i_key, key_i in enumerate(all_keys):
    if key_i == 'data':
        continue
    elif '_weight' in key_i:
        if key_i.find(backbone) != -1 or key_i.find("dense") != -1:
            key_caffe = key_i.replace('_weight', '_fwd')
        else:
            key_caffe = key_i.replace('_weight', '')
        print("{}: {}->{}".format(key_i, arg_params[key_i].shape,
                                  net.params[key_caffe][0].data.shape))
        net.params[key_caffe][0].data.flat = arg_params[key_i].asnumpy().flat
    elif '_bias' in key_i:
        if key_i.find(backbone) != -1 or key_i.find("dense") != -1:
            key_caffe = key_i.replace('_bias', '_fwd')
        else:
            key_caffe = key_i.replace('_bias', '')
        print("{}: {}->{}".format(key_i, arg_params[key_i].shape,
                                  net.params[key_caffe][0].data.shape))
        net.params[key_caffe][1].data.flat = arg_params[key_i].asnumpy().flat
    elif '_gamma' in key_i and 'relu' not in key_i:
        if key_i.find(backbone) != -1:
            key_caffe = key_i.replace('_gamma', '_fwd_scale')
        else:
            key_caffe = key_i.replace('_gamma', '_scale')
        print("{}: {}->{}".format(key_i, arg_params[key_i].shape,
                                  net.params[key_caffe][0].data.shape))
        net.params[key_caffe][0].data.flat = arg_params[key_i].asnumpy().flat
    elif '_gamma' in key_i and 'relu' in key_i:  # PReLU slope
        key_caffe = key_i.replace('_gamma', '')
        print("{}: {}->{}".format(key_i, arg_params[key_i].shape,
                                  net.params[key_caffe][0].data.shape))
        assert len(net.params[key_caffe]) == 1
        net.params[key_caffe][0].data.flat = arg_params[key_i].asnumpy().flat
    elif '_beta' in key_i:
        if key_i.find(backbone) != -1:
            key_caffe = key_i.replace('_beta', '_fwd_scale')
        else:
            key_caffe = key_i.replace('_beta', '_scale')
        print("{}: {}->{}".format(key_i, arg_params[key_i].shape,
                                  net.params[key_caffe][0].data.shape))
        net.params[key_caffe][1].data.flat = arg_params[key_i].asnumpy().flat
    elif '_moving_mean' in key_i:
        key_caffe = key_i.replace('_moving_mean', '')
        print("{}: {}->{}".format(key_i, aux_params[key_i].shape,
                                  net.params[key_caffe][0].data.shape))
        net.params[key_caffe][0].data.flat = aux_params[key_i].asnumpy().flat
        net.params[key_caffe][2].data[...] = 1  # BatchNorm scale factor
    elif '_moving_var' in key_i:
        key_caffe = key_i.replace('_moving_var', '')
        print("{}: {}->{}".format(key_i, aux_params[key_i].shape,
                                  net.params[key_caffe][0].data.shape))
        net.params[key_caffe][1].data.flat = aux_params[key_i].asnumpy().flat
        net.params[key_caffe][2].data[...] = 1
    elif '_running_mean' in key_i or '_running_var' in key_i:
        # Gluon-style running statistics are not handled.
        sys.exit("Error! Unsupported key: {}".format(key_i))
    else:
        sys.exit("Warning! Unknown mxnet key: {}".format(key_i))

    print("% 3d | %s -> %s, initialized."
          % (i_key, key_i.ljust(40), key_caffe.ljust(30)))

# ------------------------------------------
# Finish
net.save(args.cf_model)
print("\n*** PARAMS to CAFFEMODEL Finished. ***\n")
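# Naming convention assumed above (layer names are illustrative):
#   conv0_weight            -> net.params['conv0'][0]      ('conv0_fwd' for
#                              backbone / dense layers)
#   conv0_bias              -> net.params['conv0'][1]
#   bn0_gamma, bn0_beta     -> net.params['bn0_scale'][0], [1]
#   bn0_moving_mean, _var   -> net.params['bn0'][0], [1] (blob [2], Caffe's
#                              moving-average scale factor, is pinned to 1)
#   relu0_gamma             -> net.params['relu0'][0]      (PReLU slope)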
--------------------------------------------------------------------------------
/prototxt_basic.py:
--------------------------------------------------------------------------------

# prototxt_basic

import math

attrstr = "attrs"   # older MXNet symbol files use "param" instead
# attrstr = "param"

# Output channels for UpSampling nodes converted to Deconvolution (see
# README): map each UpSampling node name to its channel count.
names_output = {"rf_c2_upsampling": 256, "rf_c3_upsampling": 256}
# names_output = {"ssh_m2_red_up": 32, "ssh_c3_up": 32}


def data(txt_file, info):
    txt_file.write('name: "mxnet-model"\n')
    txt_file.write('layer {\n')
    txt_file.write('  name: "data"\n')
    txt_file.write('  type: "Input"\n')
    txt_file.write('  top: "data"\n')
    txt_file.write('  input_param {\n')
    txt_file.write('    shape: {{ dim: {} dim: {} dim: {} dim: {} }}\n'.format(info['shape'][0],
                                                                               info['shape'][1],
                                                                               info['shape'][2],
                                                                               info['shape'][3]))
    txt_file.write('  }\n')
    txt_file.write('}\n')
    txt_file.write('\n')


def fuzzy_haskey(d, key):
    for eachkey in d:
        if key in eachkey:
            return True
    return False


def Convolution(txt_file, info):
    if fuzzy_haskey(info['params'], 'bias'):
        bias_term = 'true'
    elif 'no_bias' in info[attrstr] and info[attrstr]['no_bias'] == 'True':
        bias_term = 'false'
    else:
        bias_term = 'true'
    txt_file.write('layer {\n')
    txt_file.write('  bottom: "%s"\n' % info['bottom'][0])
    txt_file.write('  top: "%s"\n' % info['top'])
    txt_file.write('  name: "%s"\n' % info['top'])
    txt_file.write('  type: "Convolution"\n')
    txt_file.write('  convolution_param {\n')
    txt_file.write('    num_output: %s\n' % info[attrstr]['num_filter'])
    # Only square kernels/strides/pads are handled (first tuple element). TODO
    txt_file.write('    kernel_size: %s\n' % info[attrstr]['kernel'].split('(')[1].split(',')[0])
    if 'pad' in info[attrstr]:
        txt_file.write('    pad: %s\n' % info[attrstr]['pad'].split('(')[1].split(',')[0])
    if 'num_group' in info[attrstr]:
        txt_file.write('    group: %s\n' % info[attrstr]['num_group'])
    if 'stride' in info[attrstr]:
        txt_file.write('    stride: %s\n' % info[attrstr]['stride'].split('(')[1].split(',')[0])
    if 'dilate' in info[attrstr]:
        txt_file.write('    dilation: %s\n' % info[attrstr]['dilate'].split('(')[1].split(',')[0])
    txt_file.write('    bias_term: %s\n' % bias_term)
    txt_file.write('  }\n')
    if 'share' in info.keys() and info['share']:
        txt_file.write('  param {\n')
        txt_file.write('    name: "%s"\n' % info['params'][0])
        txt_file.write('  }\n')
    txt_file.write('}\n')
    txt_file.write('\n')


def ChannelwiseConvolution(txt_file, info):
    Convolution(txt_file, info)


def BatchNorm(txt_file, info):
    txt_file.write('layer {\n')
    txt_file.write('  bottom: "%s"\n' % info['bottom'][0])
    txt_file.write('  top: "%s"\n' % info['top'])
    txt_file.write('  name: "%s"\n' % info['top'])
    txt_file.write('  type: "BatchNorm"\n')
    txt_file.write('  batch_norm_param {\n')
    txt_file.write('    use_global_stats: true\n')  # TODO
    if 'momentum' in info[attrstr]:
        txt_file.write('    moving_average_fraction: %s\n' % info[attrstr]['momentum'])
    else:
        txt_file.write('    moving_average_fraction: 0.9\n')
    if 'eps' in info[attrstr]:
        txt_file.write('    eps: %s\n' % info[attrstr]['eps'])
    else:
        txt_file.write('    eps: 0.001\n')
    txt_file.write('  }\n')
    txt_file.write('}\n')
    # if info['fix_gamma'] is "False":  # TODO
    txt_file.write('layer {\n')
    txt_file.write('  bottom: "%s"\n' % info['top'])
    txt_file.write('  top: "%s"\n' % info['top'])
    txt_file.write('  name: "%s_scale"\n' % info['top'])
    txt_file.write('  type: "Scale"\n')
    txt_file.write('  scale_param { bias_term: true }\n')
    txt_file.write('}\n')
    txt_file.write('\n')
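# A BatchNorm node therefore expands into two stacked layers in the generated
# prototxt, roughly (names and values illustrative):
#
#   layer { bottom: "conv0" top: "bn0" name: "bn0" type: "BatchNorm"
#           batch_norm_param { use_global_stats: true
#                              moving_average_fraction: 0.9 eps: 0.001 } }
#   layer { bottom: "bn0" top: "bn0" name: "bn0_scale" type: "Scale"
#           scale_param { bias_term: true } }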
def Activation(txt_file, info):
    txt_file.write('layer {\n')
    txt_file.write('  bottom: "%s"\n' % info['bottom'][0])
    txt_file.write('  top: "%s"\n' % info['top'])
    txt_file.write('  name: "%s"\n' % info['top'])
    if info[attrstr]['act_type'] == 'sigmoid':
        txt_file.write('  type: "Sigmoid"\n')
    else:
        txt_file.write('  type: "ReLU"\n')  # TODO
    txt_file.write('}\n')
    txt_file.write('\n')


def Activation_Relu6(txt_file, info):
    # NOTE: MXNet's clip (ReLU6) is approximated by a plain ReLU here.
    info[attrstr]['act_type'] = 'ReLU'
    Activation(txt_file, info)


def Deconvolution(txt_file, info):
    if fuzzy_haskey(info['params'], 'bias'):
        bias_term = 'true'
    elif 'no_bias' in info[attrstr] and info[attrstr]['no_bias'] == 'True':
        bias_term = 'false'
    else:
        bias_term = 'true'
    txt_file.write('layer {\n')
    txt_file.write('  bottom: "%s"\n' % info['bottom'][0])
    txt_file.write('  top: "%s"\n' % info['top'])
    txt_file.write('  name: "%s"\n' % info['top'])
    txt_file.write('  type: "Deconvolution"\n')
    txt_file.write('  convolution_param {\n')
    txt_file.write('    num_output: %s\n' % info[attrstr]['num_filter'])
    txt_file.write('    kernel_size: %s\n' % info[attrstr]['kernel'].split('(')[1].split(',')[0])  # TODO
    if 'pad' in info[attrstr]:
        txt_file.write('    pad: %s\n' % info[attrstr]['pad'].split('(')[1].split(',')[0])  # TODO
    if 'num_group' in info[attrstr]:
        txt_file.write('    group: %s\n' % info[attrstr]['num_group'])
    if 'stride' in info[attrstr]:
        txt_file.write('    stride: %s\n' % info[attrstr]['stride'].split('(')[1].split(',')[0])
    if 'dilate' in info[attrstr]:
        txt_file.write('    dilation: %s\n' % info[attrstr]['dilate'].split('(')[1].split(',')[0])
    txt_file.write('    bias_term: %s\n' % bias_term)
    txt_file.write('  }\n')
    txt_file.write('}\n')
    txt_file.write('\n')


def Upsampling(txt_file, info):
    # UpSampling is emitted as a fixed bilinear Deconvolution; the output
    # channel count must be provided via names_output (see README).
    scale = int(info[attrstr]['scale'])
    assert scale > 0
    txt_file.write('layer {\n')
    txt_file.write('  bottom: "%s"\n' % info['bottom'][0])
    txt_file.write('  top: "%s"\n' % info['top'])
    txt_file.write('  name: "%s"\n' % info['top'])
    txt_file.write('  type: "Deconvolution"\n')
    txt_file.write('  convolution_param {\n')
    txt_file.write('    num_output: %s\n' % names_output[info["name"]])
    txt_file.write('    kernel_size: %d\n' % (2 * scale - scale % 2))  # TODO
    txt_file.write('    stride: %d\n' % scale)
    txt_file.write('    pad: %d\n' % math.ceil((scale - 1) / 2.0))  # TODO
    txt_file.write('    group: %s\n' % names_output[info["name"]])
    txt_file.write('    bias_term: false\n')
    txt_file.write('    weight_filler: {\n')
    txt_file.write('      type: "bilinear"\n')
    txt_file.write('    }\n')
    txt_file.write('  }\n')
    txt_file.write('}\n')
    txt_file.write('\n')
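# The Deconvolution hyper-parameters above follow the usual bilinear
# upsampling recipe: kernel_size = 2*scale - scale % 2, stride = scale,
# pad = ceil((scale - 1) / 2). For scale = 2 this gives kernel_size = 4,
# stride = 2, pad = 1, with group == num_output so channels stay independent.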
def Concat(txt_file, info):
    txt_file.write('layer {\n')
    txt_file.write('  name: "%s"\n' % info['top'])
    txt_file.write('  type: "Concat"\n')
    for bottom_i in info['bottom']:
        txt_file.write('  bottom: "%s"\n' % bottom_i)
    txt_file.write('  top: "%s"\n' % info['top'])
    txt_file.write('}\n')
    txt_file.write('\n')


def Crop(txt_file, info):
    txt_file.write('layer {\n')
    txt_file.write('  type: "Crop"\n')
    txt_file.write('  top: "%s"\n' % info['top'])
    txt_file.write('  name: "%s"\n' % info['top'])
    for btom in info['bottom']:
        txt_file.write('  bottom: "%s"\n' % btom)
    txt_file.write('  crop_param {\n    axis: 2\n    offset: 0\n  }\n')
    txt_file.write('}\n')
    txt_file.write('\n')


def ElementWiseSum(txt_file, info):
    txt_file.write('layer {\n')
    txt_file.write('  name: "%s"\n' % info['top'])
    txt_file.write('  type: "Eltwise"\n')
    for bottom_i in info['bottom']:
        txt_file.write('  bottom: "%s"\n' % bottom_i)
    txt_file.write('  top: "%s"\n' % info['top'])
    txt_file.write('  eltwise_param { operation: SUM }\n')
    txt_file.write('}\n')
    txt_file.write('\n')


def Pooling(txt_file, info):
    pool_type = 'AVE' if info[attrstr]['pool_type'] == 'avg' else 'MAX'
    txt_file.write('layer {\n')
    txt_file.write('  bottom: "%s"\n' % info['bottom'][0])
    txt_file.write('  top: "%s"\n' % info['top'])
    txt_file.write('  name: "%s"\n' % info['top'])
    txt_file.write('  type: "Pooling"\n')
    txt_file.write('  pooling_param {\n')
    txt_file.write('    pool: %s\n' % pool_type)  # TODO
    if 'global_pool' in info[attrstr] and info[attrstr]['global_pool'] == 'True':
        txt_file.write('    global_pooling: true\n')
    else:
        txt_file.write('    kernel_size: %s\n' % info[attrstr]['kernel'].split('(')[1].split(',')[0])
        txt_file.write('    stride: %s\n' % info[attrstr]['stride'].split('(')[1].split(',')[0])
        if 'pad' in info[attrstr]:
            txt_file.write('    pad: %s\n' % info[attrstr]['pad'].split('(')[1].split(',')[0])
    txt_file.write('  }\n')
    txt_file.write('}\n')
    txt_file.write('\n')


def FullyConnected(txt_file, info):
    txt_file.write('layer {\n')
    txt_file.write('  bottom: "%s"\n' % info['bottom'][0])
    txt_file.write('  top: "%s"\n' % info['top'])
    txt_file.write('  name: "%s"\n' % info['top'])
    txt_file.write('  type: "InnerProduct"\n')
    txt_file.write('  inner_product_param {\n')
    txt_file.write('    num_output: %s\n' % info[attrstr]['num_hidden'])
    txt_file.write('  }\n')
    txt_file.write('}\n')
    txt_file.write('\n')


def Reshape(txt_file, info):
    txt_file.write('layer {\n')
    txt_file.write('  bottom: "%s"\n' % info['bottom'][0])
    txt_file.write('  top: "%s"\n' % info['top'])
    txt_file.write('  name: "%s"\n' % info['top'])
    txt_file.write('  type: "Reshape"\n')
    shape = eval(info[attrstr]['shape'])  # e.g. "(1, 4, -1, 0)" -> tuple
    txt_file.write('  reshape_param {\n')
    txt_file.write('    shape {\n')
    for dim in shape:
        txt_file.write('      dim: %d\n' % dim)
    txt_file.write('    }\n')
    txt_file.write('  }\n')
    txt_file.write('}\n')
    txt_file.write('\n')
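# Example: an MXNet Reshape node with attrs {"shape": "(1, 4, -1, 0)"} is
# emitted as reshape_param { shape { dim: 1 dim: 4 dim: -1 dim: 0 } };
# in both frameworks -1 means "infer this dimension" and 0 means "copy it
# from the input blob".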
def Flatten(txt_file, info):
    # Intentionally a no-op: Flatten splits the converted graph in two, so
    # the layers around it must be reconnected by hand (see README).
    pass


def SoftmaxActivation(txt_file, info):
    # softmax
    txt_file.write('layer {\n')
    txt_file.write('  bottom: "%s"\n' % info['bottom'][0])
    txt_file.write('  top: "%s"\n' % info['top'])
    txt_file.write('  name: "%s"\n' % info['top'])
    txt_file.write('  type: "Softmax"\n')
    txt_file.write('  softmax_param {\n')
    txt_file.write('    axis: 1\n')
    txt_file.write('  }\n')
    txt_file.write('}\n')
    txt_file.write('\n')


def SoftmaxOutput(txt_file, info):
    # softmax
    txt_file.write('layer {\n')
    txt_file.write('  bottom: "%s"\n' % info['bottom'][0])
    txt_file.write('  top: "prob"\n')
    txt_file.write('  name: "prob"\n')
    txt_file.write('  type: "Softmax"\n')
    txt_file.write('  softmax_param {\n')
    txt_file.write('    axis: 1\n')
    txt_file.write('  }\n')
    txt_file.write('}\n')
    txt_file.write('\n')

    # argmax
    txt_file.write('layer {\n')
    txt_file.write('  bottom: "prob"\n')
    txt_file.write('  top: "out_label"\n')
    txt_file.write('  name: "out_label"\n')
    txt_file.write('  type: "ArgMax"\n')
    txt_file.write('  argmax_param {\n')
    txt_file.write('    axis: 1\n')
    txt_file.write('    top_k: 1\n')
    txt_file.write('  }\n')
    txt_file.write('}\n')
    txt_file.write('\n')


def LeakyReLU(txt_file, info):
    if info[attrstr]['act_type'] == 'elu':
        txt_file.write('layer {\n')
        txt_file.write('  bottom: "%s"\n' % info['bottom'][0])
        txt_file.write('  top: "%s"\n' % info['top'])
        txt_file.write('  name: "%s"\n' % info['top'])
        txt_file.write('  type: "ELU"\n')
        txt_file.write('  elu_param { alpha: 0.25 }\n')
        txt_file.write('}\n')
        txt_file.write('\n')
    elif info[attrstr]['act_type'] == 'prelu':
        txt_file.write('layer {\n')
        txt_file.write('  bottom: "%s"\n' % info['bottom'][0])
        txt_file.write('  top: "%s"\n' % info['top'])
        txt_file.write('  name: "%s"\n' % info['top'])
        txt_file.write('  type: "PReLU"\n')
        txt_file.write('}\n')
        txt_file.write('\n')
    else:
        raise Exception("unsupported LeakyReLU act_type: %s" % info[attrstr]['act_type'])


def Eltwise(txt_file, info, op):
    txt_file.write('layer {\n')
    txt_file.write('  type: "Eltwise"\n')
    txt_file.write('  top: "%s"\n' % info['top'])
    txt_file.write('  name: "%s"\n' % info['top'])
    for btom in info['bottom']:
        txt_file.write('  bottom: "%s"\n' % btom)
    txt_file.write('  eltwise_param { operation: %s }\n' % op)
    txt_file.write('}\n')
    txt_file.write('\n')


# ----------------------------------------------------------------
def write_node(txt_file, info):
    # info["top"] = info["top"].replace("_fwd", "")
    if 'label' in info['name']:
        return
    if info['op'] == 'null' and info['name'] == 'data':
        data(txt_file, info)
    elif info['op'] == 'Convolution':
        Convolution(txt_file, info)
    elif info['op'] == 'ChannelwiseConvolution':
        ChannelwiseConvolution(txt_file, info)
    elif info['op'] == 'BatchNorm':
        BatchNorm(txt_file, info)
    elif info['op'] == 'Activation':
        Activation(txt_file, info)
    elif info['op'] in ('ElementWiseSum', '_Plus', 'elemwise_add'):
        ElementWiseSum(txt_file, info)
    elif info['op'] == 'Concat':
        Concat(txt_file, info)
    elif info['op'] == 'Crop':
        Crop(txt_file, info)
    elif info['op'] == 'Pooling':
        Pooling(txt_file, info)
    elif info['op'] == 'Flatten':
        Flatten(txt_file, info)
    elif info['op'] == 'FullyConnected':
        FullyConnected(txt_file, info)
    elif info['op'] == 'SoftmaxOutput' or info['op'] == 'SoftmaxFocalOutput':
        SoftmaxOutput(txt_file, info)
    elif info['op'] == 'LeakyReLU':
        LeakyReLU(txt_file, info)
    elif info['op'] == 'UpSampling':
        Upsampling(txt_file, info)
    elif info['op'] == 'Deconvolution':
        Deconvolution(txt_file, info)
    elif info['op'] == 'clip':
        Activation_Relu6(txt_file, info)
    elif info['op'] == 'Reshape':
        Reshape(txt_file, info)
    elif info['op'] == 'SoftmaxActivation':
        SoftmaxActivation(txt_file, info)
    else:
        print("unknown op", info)
        # raise Exception("Warning! Skip unknown mxnet op: {}".format(info['op']))

--------------------------------------------------------------------------------