├── README.md ├── .gitignore ├── tool.py ├── validate_onnx.py ├── example.py └── surgery.py /README.md: -------------------------------------------------------------------------------- 1 | # onnx-surgery 2 | 3 | Chinese blog about this repo: http://bindog.github.io/blog/2020/03/13/deep-learning-model-convert-and-depoly/ 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | temp/ 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.swp 6 | *.ipynb 7 | *.ipynb_checkpoints 8 | *.log 9 | *.so 10 | *.sh 11 | *.pt 12 | *.onnx 13 | -------------------------------------------------------------------------------- /tool.py: -------------------------------------------------------------------------------- 1 | from onnx import numpy_helper 2 | 3 | 4 | def show_node_attributes(node): 5 | print("="*10, "attributes of node: ", node.name, "="*10) 6 | for attr in node.attribute: 7 | print(attr.name) 8 | print("="*60) 9 | 10 | 11 | def show_node_inputs(node): 12 | # Generally, the first input is the real data input 13 | # and the remaining inputs are weight initializers 14 | print("="*10, "inputs of node: ", node.name, "="*10) 15 | for input_name in node.input: 16 | print(input_name) # type of input_name is str 17 | print("="*60) 18 | 19 | 20 | def show_node_outputs(node): 21 | # Print the name of every tensor this node produces, one per line, 22 | # between separator lines 23 | print("="*10, "outputs of node: ", node.name, "="*10) 24 | for output_name in node.output: 25 | print(output_name) # type of output_name is str 26 | print("="*60) 27 | 28 | 29 | def show_weight(weight): 30 | print("="*10, "details of weight: ", weight.name, "="*10) 31 | print("data type: ", weight.data_type) 32 | print("shape: ", weight.dims) 33 | data_numpy = numpy_helper.to_array(weight) # only used by the commented-out print below 34 | # data_numpy = np.frombuffer(weight.raw_data, dtype=xxx) 35 | # print("detail 
data:", data_numpy) 36 |
print("="*40) 37 | 38 | 39 | # TODO elementwise op (Div Sub ...) constant 40 | 41 | -------------------------------------------------------------------------------- /validate_onnx.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import onnxruntime 4 | 5 | import cv2 6 | 7 | 8 | def validate(model_path, image_path, batch_size=1, input_shape=(224, 224)): 9 | # input_shape is (width, height), the order cv2.resize expects. NOTE(review): pixels stay in cv2's BGR channel order and are only scaled by 1/255; confirm this matches the model's preprocessing 10 | image = cv2.imread(image_path) # NOTE(review): cv2.imread returns None for an unreadable path, which then fails in cv2.resize outside the try below 11 | image = cv2.resize(image, input_shape).astype(np.float32) 12 | image = image / 255.0 13 | img_data = np.expand_dims(image, 0) # HWC -> NHWC 14 | img_data = np.transpose(img_data, [0, 3, 1, 2]) # NHWC -> NCHW 15 | x = np.repeat(img_data, batch_size, axis=0).astype(np.float32) 16 | try: 17 | session_option = onnxruntime.SessionOptions() 18 | session_option.log_severity_level = 4 # 4 = fatal only, hides lower-severity onnxruntime logs 19 | model = onnxruntime.InferenceSession(model_path, sess_options=session_option) 20 | ort_inputs_name = model.get_inputs()[0].name 21 | ort_ouputs_names = [out.name for out in model.get_outputs()] 22 | ort_outs = model.run(ort_ouputs_names, {ort_inputs_name: x.astype('float32')}) 23 | if len(ort_outs) > 1: 24 | outputs = tuple([np.array(out).astype("float32") for out in ort_outs]) 25 | for output in outputs: 26 | print("one of output shape: ", output.shape) 27 | return outputs 28 | else: 29 | outputs = np.array(ort_outs[0]).astype("float32") 30 | print("output shape: ", outputs.shape) 31 | return outputs 32 | except Exception as e: # any load/run failure is reported and turned into a None result 33 | print("validate error, check error message below:") 34 | print(str(e)) 35 | return None 36 | 37 | 38 | if __name__ == 
"__main__": 39 | parser = argparse.ArgumentParser(description="onnx validate") 40 | parser.add_argument("--model", default="", type=str, required=True) 41 | parser.add_argument("--image", default="", type=str, required=True) 42 | args = parser.parse_args() 43 | 44 | if validate(args.model, args.image) is not None: 45 | print("this onnx model seems ok") 46 | else: 47 |
print("something wrong, please check your onnx model according to the error message...") 48 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | from onnx import numpy_helper 5 | 6 | from surgery import Surgery 7 | 8 | 9 | def old_mxnet_version_example(onnxsu): 10 | # NOTE 1 11 | # in some old version mxnet models, fix_gamma in BatchNormalization is set to True, 12 | # but the converted onnx model does NOT have a fix_gamma attribute, so if the 13 | # gamma (named scale in onnx) parameter is not all ones, results may be inconsistent 14 | # NOTE 2 15 | # in some old version mxnet models, the average pooling layer has an attribute "count_include_pad" 16 | # but it was not set when converting to the onnx model; it seems like the default value is 1 17 | bn_nodes = onnxsu.get_nodes_by_optype("BatchNormalization") 18 | for bn_node in bn_nodes: 19 | gamma_name = bn_node.input[1] 20 | onnxsu.set_weight_by_name(gamma_name, all_ones=True) 21 | avg_nodes = onnxsu.get_nodes_by_optype("AveragePool") 22 | for avg_node in avg_nodes: 23 | onnxsu.set_node_attribute(avg_node, "count_include_pad", 1) 24 | 25 | 26 | def tf_set_batch_size_example(onnxsu, batch_size=8): 27 | # NOTE 28 | # when using tf2onnx to convert a tensorflow pb model to onnx, 29 | # the input batch_size dim is not set; we can set it explicitly 30 | onnxsu.list_model_inputs(2) 31 | # onnxsu.set_model_input_shape(name="pb_input:0", shape=(32,3,256,256)) 32 | onnxsu.set_model_input_batch_size(batch_size=batch_size) 33 | 34 | 35 | def debug_internal_output(onnxsu, node_name, output_name): 36 | # NOTE 37 | # sometimes we want the intermediate result of some node for debugging, 38 | # but onnx does NOT have an API to support this.
Don't worry, 39 | # we can append an Identity OP and an extra output following the target 40 | # node to get the result we want 41 | node = onnxsu.get_node_by_name(node_name) 42 | onnxsu.add_extra_output(node, output_name) 43 | 44 | 45 | def tensorrt_set_epsilon_example(onnxsu, epsilon=1e-3): 46 | # NOTE 47 | # We found that when converting an onnx model with an InstanceNormalization OP to a TensorRT engine, the inference result is inaccurate 48 | # you can find the details at https://devtalk.nvidia.com/default/topic/1071094/tensorrt/inference-result-inaccurate-with-conv-and-instancenormalization-under-certain-conditions/ 49 | # After days of debugging, we finally found this issue is caused by the following line of code 50 | # https://github.com/onnx/onnx-tensorrt/blob/5dca8737851118f6ab8a33ea1f7bcb7c9f06caf5/builtin_op_importers.cpp#L1557 51 | # it is strange that TensorRT onnx parser only supports epsilon >= 1e-4, if you do NOT 52 | # want to re-compile the TensorRT OSS, you can change epsilon to 1e-3 manually...
53 | # I tried commenting out that line; it worked, but the error was bigger than with epsilon set to 1e-3 54 | in_nodes = onnxsu.get_nodes_by_optype("InstanceNormalization") 55 | for in_node in in_nodes: 56 | onnxsu.set_node_attribute(in_node, "epsilon", epsilon) 57 | 58 | 59 | def add_conv_layer(onnxsu, target_node_name): 60 | # NOTE: 61 | # The name, attribute and weight of the OP can be found at: 62 | # https://github.com/onnx/onnx/blob/master/docs/Operators.md 63 | # All weights and attributes must follow the ONNX operator definitions 64 | # to avoid unexpected errors 65 | target_node = onnxsu.get_node_by_name(target_node_name) 66 | # NOTE: 67 | # weight names should be unique enough to avoid clashing with existing tensor names, 68 | # and weight_dict must be ordered: insert_op_before feeds the weights to the op in dict order (python >= 3.6) 69 | weight_dict = { 70 | "W_from_a_new_conv_op": np.random.normal(0, 1, (64, 64, 3, 3)).astype(np.float32), 71 | "B_from_a_new_conv_op": np.random.normal(0, 1, (64,)).astype(np.float32) 72 | } 73 | attr_dict = { 74 | "kernel_shape": [3, 3], 75 | "pads": [0, 0, 0, 0] # NOTE(review): a 3x3 kernel with no padding (default strides) shrinks H and W by 2; confirm the target node accepts that shape 76 | } 77 | onnxsu.insert_op_before( 78 | node_name="new_conv_op", 79 | target_node=target_node, 80 | op_name="Conv", 81 | weight_dict=weight_dict, 82 | attr_dict=attr_dict 83 | ) 84 | 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser(description="onnx test") 88 | parser.add_argument("--input", default="", type=str, required=True) 89 | parser.add_argument("--output", default="", type=str, required=True) 90 | args = parser.parse_args() 91 | 92 | onnxsu = Surgery(args.input) 93 | 94 | # old_mxnet_version_example(onnxsu) 95 | # tf_set_batch_size_example(onnxsu, 16) 96 | # debug_internal_output(onnxsu, "your target node name", "debug_test") 97 | # tensorrt_set_epsilon_example(onnxsu, 1e-3) 98 | add_conv_layer(onnxsu, "resnetv24_batchnorm1_fwd") 
99 | 100 | onnxsu.export(args.output) 101 | -------------------------------------------------------------------------------- /surgery.py:
import functools
import onnx
import numpy as np
from onnx import helper
from onnx import numpy_helper


def _np_dtype_to_tensor_type(np_dtype):
    """Return the ``onnx.TensorProto`` data-type enum for a numpy dtype.

    Uses ``helper.np_dtype_to_tensor_dtype`` when this onnx version provides
    it and falls back to the older (deprecated) ``onnx.mapping`` table.
    """
    np_dtype = np.dtype(np_dtype)
    if hasattr(helper, "np_dtype_to_tensor_dtype"):
        return helper.np_dtype_to_tensor_dtype(np_dtype)
    return onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np_dtype]


def _to_bytes(value):
    """Encode ``str`` as UTF-8; protobuf ``bytes`` fields reject ``str``."""
    return value.encode("utf-8") if isinstance(value, str) else value


class Surgery(object):
    """Load an ONNX model and edit its graph in place.

    The ``onnx.ModelProto`` lives in ``self.model``. Every method below
    mutates it directly, and :meth:`export` checks it and writes it to disk.
    """

    def __init__(self, onnx_model_path):
        self.model = onnx.load(onnx_model_path)
        # fills graph.value_info with type/shape info for intermediate tensors
        # (add_extra_output looks it up there)
        self.model = onnx.shape_inference.infer_shapes(self.model)

    def export(self, file_name, infer_shapes=False):
        """Validate the model with the onnx checker and save it to ``file_name``.

        ``infer_shapes`` only controls whether shape inference also runs
        *before* the checker; inference always runs again after the check,
        right before saving.
        """
        if infer_shapes:
            self.model = onnx.shape_inference.infer_shapes(self.model)
        onnx.checker.check_model(self.model)
        self.model = onnx.shape_inference.infer_shapes(self.model)
        onnx.save(self.model, file_name)

    def list_model_inputs(self, nums):
        """Print the first ``nums`` graph inputs (all of them if ``nums`` <= 0).

        Each printed input is a ValueInfoProto. Its shape is a list of ``dim``
        entries, and each dim holds either a ``dim_value`` (int) or a
        ``dim_param`` (symbolic name such as "batch_size"); they share a oneof.
        Access them like this::

            tensor_dim = model_input.type.tensor_type.shape.dim
            print(tensor_dim[0].dim_param, tensor_dim[0].dim_value)
        """
        for count, model_input in enumerate(self.model.graph.input, start=1):
            print(model_input)
            if count == nums:
                break

    def _get_model_input(self, index=0, name=None):
        """Return the graph input called ``name``, or the one at ``index`` when
        no name is given; ``None`` if no such input exists."""
        graph_inputs = self.model.graph.input
        if name is not None:
            for model_input in graph_inputs:
                if model_input.name == name:
                    return model_input
            return None
        if -len(graph_inputs) <= index < len(graph_inputs):
            return graph_inputs[index]
        return None

    def set_model_input_batch_size(self, index=0, name=None, batch_size=8):
        """Fix the first (batch) dim of a graph input to ``batch_size``.

        The input is selected by ``name`` when given, otherwise by ``index``.
        """
        model_input = self._get_model_input(index, name)
        if model_input is None:
            print("get model input error, check your index or name")
            return
        tensor_dim = model_input.type.tensor_type.shape.dim
        if len(tensor_dim) == 0:
            print("model input has no dims, use set_model_input_shape instead")
            return
        # replace a symbolic batch dim (dim_param) with a concrete value
        tensor_dim[0].ClearField("dim_param")
        tensor_dim[0].dim_value = batch_size

    def set_model_input_shape(self, index=0, name=None, shape=None):
        """Replace the full shape of a graph input with the ints in ``shape``.

        The input is selected by ``name`` when given, otherwise by ``index``.
        """
        model_input = self._get_model_input(index, name)
        if model_input is None:
            print("get model input error, check your index or name")
            return
        if shape is None:
            print("input shape must be set")
            return
        tensor_shape_proto = model_input.type.tensor_type.shape
        tensor_shape_proto.ClearField("dim")
        for d in shape:
            tensor_shape_proto.dim.add().dim_value = d

    def get_node_by_name(self, name):
        """Return the first node called ``name``, or ``None``."""
        for node in self.model.graph.node:
            if node.name == name:
                return node
        return None

    def get_nodes_by_optype(self, typename):
        """Return all nodes whose ``op_type`` equals ``typename``, in graph order."""
        return [node for node in self.model.graph.node if node.op_type == typename]

    def get_weight_by_name(self, name):
        """Return the initializer called ``name``, or ``None``."""
        for weight in self.model.graph.initializer:
            if weight.name == name:
                return weight
        return None

    def set_weight_by_name(self, name, data_numpy=None, all_ones=False, all_zeros=False):
        """Look up initializer ``name`` and overwrite it (see :meth:`set_weight`).

        Raises:
            ValueError: if the graph has no initializer with that name.
        """
        weight = self.get_weight_by_name(name)
        if weight is None:
            raise ValueError("no initializer named %r in the graph" % name)
        self.set_weight(weight, data_numpy, all_ones, all_zeros)

    def remove_node_by_name(self, name):
        """Look up node ``name`` and remove it (see :meth:`remove_node`).

        Raises:
            ValueError: if the graph has no node with that name.
        """
        target_node = self.get_node_by_name(name)
        if target_node is None:
            raise ValueError("no node named %r in the graph" % name)
        self.remove_node(target_node)

    def remove_node(self, target_node):
        """Remove a pass-through node and reconnect its neighbours.

        Every consumer of the node's first output is rewired to read the
        node's first input instead. Initializers that fed only this node are
        deleted, together with graph-input / value_info entries for them and
        for the removed output.
        """
        graph = self.model.graph
        node_input = target_node.input[0]
        node_output = target_node.output[0]
        # set input of successor nodes to the predecessor of the target node
        # NOTE(review): if node_output is also a graph output it is left dangling
        for node in graph.node:
            for i, name in enumerate(node.input):
                if name == node_output:
                    node.input[i] = node_input

        # initializers can be shared (e.g. shape constants); keep those that
        # any other node still reads
        still_used = {
            name
            for node in graph.node
            if node != target_node
            for name in node.input
        }
        initializer_names = {weight.name for weight in graph.initializer}
        target_names = (set(target_node.input) & initializer_names) - still_used
        self.remove_weights(target_names)
        target_names.add(node_output)
        self.remove_inputs(target_names)
        self.remove_value_infos(target_names)
        graph.node.remove(target_node)

    def remove_weights(self, name_list):
        """Delete every initializer whose name is in ``name_list``."""
        initializers = self.model.graph.initializer
        for weight in [w for w in initializers if w.name in name_list]:
            initializers.remove(weight)

    def remove_inputs(self, name_list):
        """Delete every graph input whose name is in ``name_list``."""
        graph_inputs = self.model.graph.input
        for input_t in [t for t in graph_inputs if t.name in name_list]:
            graph_inputs.remove(input_t)

    def remove_value_infos(self, name_list):
        """Delete every value_info entry whose name is in ``name_list``."""
        value_infos = self.model.graph.value_info
        for value_info in [v for v in value_infos if v.name in name_list]:
            value_infos.remove(value_info)

    def set_weight(self, weight, data_numpy=None, all_ones=False, all_zeros=False):
        """Overwrite the values of initializer ``weight`` in place.

        Exactly one source is used, in this priority: ``data_numpy`` (its
        bytes are stored verbatim, so it must already have the weight's
        dtype), then ``all_ones``, then ``all_zeros``. If ``data_numpy``
        has a different shape, ``weight.dims`` and any graph input with the
        same name are updated to the new shape.

        Raises:
            NotImplementedError: ``data_numpy`` given for a STRING tensor.
            ValueError: neither ``data_numpy`` nor a flag was given.
        """
        # NOTE: weight values can be stored in typed fields (float_data,
        # int32_data, ...) as well as raw_data; when writing raw_data those
        # fields must be cleared or they take precedence/conflict
        if data_numpy is not None:
            if weight.data_type == onnx.TensorProto.STRING:
                # string tensors must use string_data, NOT raw_data
                raise NotImplementedError("Can NOT handle string data type right now...")
            raw_shape = tuple(weight.dims)
            new_shape = np.shape(data_numpy)
            if new_shape != raw_shape:
                print("Warning: the new weight shape is not consistent with original shape!")
                weight.dims[:] = list(new_shape)
                for model_input in self.model.graph.input:
                    if model_input.name == weight.name:
                        tensor_shape_proto = model_input.type.tensor_type.shape
                        tensor_shape_proto.ClearField("dim")
                        for d in new_shape:
                            tensor_shape_proto.dim.add().dim_value = d
            new_bytes = data_numpy.tobytes()
        elif all_ones or all_zeros:
            current = numpy_helper.to_array(weight)
            filled = np.ones_like(current) if all_ones else np.zeros_like(current)
            new_bytes = filled.tobytes()
        else:
            raise ValueError(
                "You must give a data_numpy to set the weight, "
                "or set the all_ones/all_zeros flag.")
        for field in ("float_data", "int32_data", "int64_data", "double_data", "uint64_data"):
            weight.ClearField(field)
        weight.raw_data = new_bytes

    def set_node_attribute(self, target_node, attr_name, attr_value):
        """Set attribute ``attr_name`` of ``target_node`` to ``attr_value``.

        An existing attribute keeps its declared type and is updated in place;
        a missing one is appended via ``helper.make_attribute`` (type inferred
        from the value). ``str`` values for STRING/STRINGS attributes are
        UTF-8 encoded. Returns False for unsupported attribute types, else True.
        """
        attr_proto = onnx.AttributeProto
        flag = False
        for attr in target_node.attribute:
            if attr.name != attr_name:
                continue
            if attr.type == attr_proto.FLOAT:
                attr.f = attr_value
            elif attr.type == attr_proto.INT:
                attr.i = attr_value
            elif attr.type == attr_proto.STRING:
                attr.s = _to_bytes(attr_value)
            elif attr.type == attr_proto.TENSOR:
                # singular message fields cannot be assigned in protobuf python
                attr.t.CopyFrom(attr_value)
            elif attr.type == attr_proto.GRAPH:
                attr.g.CopyFrom(attr_value)
            elif attr.type == attr_proto.FLOATS:
                attr.floats[:] = attr_value
            elif attr.type == attr_proto.INTS:
                attr.ints[:] = attr_value
            elif attr.type == attr_proto.STRINGS:
                attr.strings[:] = [_to_bytes(s) for s in attr_value]
            else:
                print("unsupported attribute data type with attribute name")
                return False
            flag = True

        if not flag:
            # attribute not in original node
            print("Warning: you are appending a new attribute to the node!")
            target_node.attribute.append(helper.make_attribute(attr_name, attr_value))
            flag = True
        return flag

    def chunk_at(self, target_node):
        """Truncate the graph so that ``target_node`` produces the model output.

        Keeps ``target_node`` and every node it transitively depends on, and
        deletes all other nodes, graph inputs, initializers and value_info
        entries. The target's first output is renamed to the model's first
        output name, and all other model outputs are removed.
        """
        graph = self.model.graph
        keep_nodes = [target_node]
        keep_names = set(target_node.input)
        # walk producers backwards until a full pass adds no new node
        changed = True
        while changed:
            changed = False
            for node in graph.node:
                if node in keep_nodes:
                    continue
                if any(output in keep_names for output in node.output):
                    keep_nodes.append(node)
                    keep_names.update(node.input)
                    changed = True

        d_nodes = [node for node in graph.node if node not in keep_nodes]
        d_inputs = [mi for mi in graph.input if mi.name not in keep_names]
        d_weights = [w for w in graph.initializer if w.name not in keep_names]
        d_value_infos = [vi for vi in graph.value_info if vi.name not in keep_names]
        for node in d_nodes:
            graph.node.remove(node)
        for model_input in d_inputs:
            graph.input.remove(model_input)
        for weight in d_weights:
            graph.initializer.remove(weight)
        for value_info in d_value_infos:
            graph.value_info.remove(value_info)

        target_node.output[0] = graph.output[0].name
        # remove other outputs if model has multi-output
        for output in list(graph.output)[1:]:
            graph.output.remove(output)

    def _insert_node_before(self, target_node, new_node):
        """Insert ``new_node`` into graph.node directly before ``target_node``."""
        for index, node in enumerate(self.model.graph.node):
            if node == target_node:
                self.model.graph.node.insert(index, new_node)
                return

    def insert_flatten_before(self, target_node, node_name="flatten_test"):
        """Insert a Flatten node named ``node_name`` in front of the first
        input of ``target_node``."""
        node_input = target_node.input[0]
        flatten_node = helper.make_node(
            "Flatten", inputs=[node_input], outputs=[node_name], name=node_name)
        # set target_node input to the new node's output
        target_node.input[0] = node_name
        self._insert_node_before(target_node, flatten_node)

    def insert_op_before(self, node_name, target_node, input_idx=0, *args, **kwargs):
        """Insert a new op between ``target_node`` and its ``input_idx``-th input.

        Keyword arguments:
            op_name (str): ONNX op type of the new node (required).
            weight_dict (dict[str, np.ndarray]): extra inputs of the new op,
                added as initializers and graph inputs, fed to the op in dict
                order after the data input.
            attr_dict (dict): attributes of the new node (optional).

        NOTE:
            you must ensure the output shape of the new op matches the input
            shape of target_node. Only target_node is rewired; other consumers
            of the original tensor keep reading it directly.
        """
        node_input = target_node.input[input_idx]
        weight_input = []
        weight_input_vi = []
        weight_initializer = []
        for weight_name, weight_numpy in (kwargs.get("weight_dict") or {}).items():
            weight_numpy = np.asarray(weight_numpy)
            elem_type = _np_dtype_to_tensor_type(weight_numpy.dtype)
            weight_input.append(weight_name)
            weight_input_vi.append(
                helper.make_tensor_value_info(
                    name=weight_name,
                    elem_type=elem_type,
                    shape=weight_numpy.shape
                )
            )
            weight_initializer.append(
                helper.make_tensor(
                    name=weight_name,
                    data_type=elem_type,
                    dims=weight_numpy.shape,
                    vals=weight_numpy.tobytes(),
                    raw=True
                )
            )
        new_op_node = helper.make_node(
            kwargs["op_name"],
            inputs=[node_input, *weight_input],
            outputs=[node_name],
            name=node_name,
            **(kwargs.get("attr_dict") or {})
        )
        # set target_node input to the new node's output
        target_node.input[input_idx] = node_name
        self._insert_node_before(target_node, new_op_node)
        self.model.graph.input.extend(weight_input_vi)
        self.model.graph.initializer.extend(weight_initializer)

    def add_extra_output(self, target_node, output_name):
        """Expose ``target_node``'s first output as an extra model output.

        An Identity node named ``output_name`` copies the tensor into a new
        graph output with the same element type and shape. Type information is
        looked up in graph.value_info (filled by shape inference), then in
        graph.output and graph.input. Symbolic dims keep their name; unknown
        dims and unknown shapes stay unknown.

        Raises:
            ValueError: if no type information exists for the tensor.
        """
        graph = self.model.graph
        target_output = target_node.output[0]
        source_vi = None
        for vi in (*graph.value_info, *graph.output, *graph.input):
            if vi.name == target_output:
                source_vi = vi
                break
        if source_vi is None:
            raise ValueError(
                "no type information for tensor %r; run shape inference first" % target_output)

        tensor_type = source_vi.type.tensor_type
        extra_shape = None  # None -> shape left unset (unknown rank)
        if tensor_type.HasField("shape"):
            extra_shape = [
                d.dim_value if d.HasField("dim_value") else (d.dim_param or None)
                for d in tensor_type.shape.dim
            ]
        extra_output = helper.make_tensor_value_info(
            output_name, tensor_type.elem_type, extra_shape)
        identity_node = helper.make_node(
            "Identity", inputs=[target_output], outputs=[output_name], name=output_name)
        graph.node.append(identity_node)
        graph.output.append(extra_output)