├── .gitignore ├── 00_searching_subgraphs.py ├── 01_removing_nodes_demo1.py ├── 02_removing_nodes_demo2.py ├── 03_replacing_a_subgraph.jpg ├── 03_replacing_a_subgraph.py ├── README.md ├── happy_onnx_modify.py ├── onnx_matcher.py ├── sample.png └── yolov5s.onnx /.gitignore: -------------------------------------------------------------------------------- 1 | relu_yolov5.onnx 2 | remove01.onnx 3 | replace03.onnx 4 | __pycache__ -------------------------------------------------------------------------------- /00_searching_subgraphs.py: -------------------------------------------------------------------------------- 1 | import onnx_matcher 2 | import onnx.helper as helper 3 | import onnx 4 | 5 | name = "yolov5s.onnx" 6 | model = onnx.load(name) 7 | 8 | # Pattern 1: a Sigmoid, Slice, Mul, Pow chain linked through the tensors x0, a0 and b0. 9 | subgraph_matcher_demo1 = onnx_matcher.Matcher( 10 | """ 11 | Sigmoid(?, x0) 12 | Slice(x0, a0) 13 | Mul(a0, b0) 14 | Pow(b0, ?) 15 | """ 16 | ) 17 | # Pattern 2: two Conv nodes with any two nodes chained between them. 18 | subgraph_matcher_demo2 = onnx_matcher.Matcher( 19 | """ 20 | Conv(?, ?) 21 | ?(?, ?) 22 | ?(?, ?) 23 | Conv(?, ?) 24 | """ 25 | ) 26 | 27 | # Print every subgraph matched by pattern 1 to the console. 28 | subgraph_matcher_demo1.print_match(model) 29 | # Print every subgraph matched by pattern 2 to the console. 30 | subgraph_matcher_demo2.print_match(model) 31 | -------------------------------------------------------------------------------- /01_removing_nodes_demo1.py: -------------------------------------------------------------------------------- 1 | import onnx_matcher 2 | import onnx.helper as helper 3 | import onnx 4 | 5 | name = "yolov5s.onnx" 6 | model = onnx.load(name) 7 | 8 | # Replace policy: rename the first output of the node producing the subgraph's input to the first input of the node consuming the subgraph's output, and return no replacement nodes.
9 | def removing_fuction(model, i, subgraph): 10 | parent = onnx_matcher.find_node_by_output(model, subgraph[0].input[0]) 11 | child = onnx_matcher.find_node_by_input(model, subgraph[-1].output[0]) 12 | parent.output[0] = child.input[0] 13 | return [], [] 14 | 15 | # Pattern for the Conv, Reshape, Transpose chain that the policy above deletes. 16 | subgraph_matcher = onnx_matcher.Matcher( 17 | """ 18 | Conv(?, b0) 19 | Reshape(b0, c0) 20 | Transpose(c0, ?) 21 | """ 22 | ) 23 | 24 | # Print every matched subgraph to the console. 25 | subgraph_matcher.print_match(model) 26 | 27 | # Apply the policy to every match; replace() returns the number of matched subgraphs. 28 | num_replaced_graph = subgraph_matcher.replace(model, removing_fuction) 29 | print(f"Done for replace {num_replaced_graph} nodes.") 30 | onnx.save(model, "remove01.onnx") 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /02_removing_nodes_demo2.py: -------------------------------------------------------------------------------- 1 | import onnx_matcher 2 | import onnx.helper as helper 3 | import onnx 4 | 5 | name = "yolov5s.onnx" 6 | model = onnx.load(name) 7 | 8 | # Replace policy: feed the first node's output straight into the last node and keep only those two nodes. 9 | def removing_fuction(model, i, subgraph): 10 | subgraph[-1].input[0] = subgraph[0].output[0] 11 | return [subgraph[0],subgraph[-1]], [] 12 | 13 | # Pattern: Mul and Sigmoid are kept; the Conv, Reshape and Transpose between them are deleted. 14 | subgraph_matcher = onnx_matcher.Matcher( 15 | """ 16 | Mul(?,a0) 17 | Conv(a0, b0) 18 | Reshape(b0, c0) 19 | Transpose(c0, d0) 20 | Sigmoid(d0, ?) 21 | """ 22 | ) 23 | 24 | # Print every matched subgraph to the console. 25 | subgraph_matcher.print_match(model) 26 | 27 | # Apply the policy to every match; replace() returns the number of matched subgraphs.
28 | num_replaced_graph = subgraph_matcher.replace(model, removing_fuction) 29 | print(f"Done for replace {num_replaced_graph} nodes.") 30 | onnx.save(model, "remove02.onnx") -------------------------------------------------------------------------------- /03_replacing_a_subgraph.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sesmfs/onnx_matcher/e1564c91e0cbcd001331bd7f387356343c61352b/03_replacing_a_subgraph.jpg -------------------------------------------------------------------------------- /03_replacing_a_subgraph.py: -------------------------------------------------------------------------------- 1 | import onnx_matcher 2 | import onnx.helper as helper 3 | import onnx 4 | 5 | name = "yolov5s.onnx" 6 | model = onnx.load(name) 7 | 8 | # Replace policy: drop Add, add a Sigmoid on Sub's output, and make Mul multiply Sub's output by that Sigmoid. 9 | def replacing_fuction(model, i, subgraph): 10 | parent = subgraph[0] 11 | sub = subgraph[1] 12 | add = subgraph[2] 13 | mul = subgraph[3] 14 | 15 | sigmoid = helper.make_node("Sigmoid", inputs=[sub.output[0]], outputs=[f"Custom_sigmoid_{i}"], name=f"Custom_sigmoid_{i}") 16 | mul.input[0] = sub.output[0] 17 | 18 | # Remove the Constant node (if any) that produced Mul's old second input. 19 | onnx_matcher.remove_costnode_by_tensor(model, mul.input[1]) 20 | mul.input[1] = sigmoid.output[0] 21 | 22 | return [parent, sub, sigmoid, mul], [] 23 | 24 | # Define a subgraph pattern; the policy deletes Add and adds a Sigmoid. 25 | # The chain changes to Sub->Mul plus Sub->Sigmoid->Mul. 26 | # The result is shown in 03_replacing_a_subgraph.jpg. 27 | subgraph_matcher = onnx_matcher.Matcher( 28 | """ 29 | ?(?, i0) 30 | Sub(i0, a0) 31 | Add(a0, b0) 32 | Mul(b0, ?) 33 | """ 34 | ) 35 | 36 | # Print every matched subgraph to the console. 37 | subgraph_matcher.print_match(model) 38 | 39 | # Apply the policy to every match; replace() returns the number of matched subgraphs.
40 | num_replaced_graph = subgraph_matcher.replace(model, replacing_fuction) 41 | 42 | print(f"Done for replace {num_replaced_graph} nodes.") 43 | onnx.save(model, "replace03.onnx") 44 | 45 | 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ONNX Pattern Matcher 2 | Use a **pattern matcher** on an ONNX model to match and replace subgraphs. 3 | ![](sample.png) 4 | 5 | 6 | # Getting Started 7 | - Python Code: [happy_onnx_modify.py](happy_onnx_modify.py) 8 | ```python 9 | # Define a replace policy function. 10 | def conv_swish_to_conv_relu(model, i, subgraph): 11 | conv = subgraph[0] 12 | mul = subgraph[2] 13 | relu = helper.make_node("Relu", inputs=conv.output, outputs=mul.output, name=f"{conv.output[0]}_relu") 14 | return [conv, relu], [] 15 | 16 | # Define a subgraph pattern. 17 | subgraph_matcher = onnx_matcher.Matcher( 18 | """ 19 | Conv(?, c0) 20 | Sigmoid(c0, s0) 21 | Mul([s0, c0], ?) 22 | """ 23 | ) 24 | 25 | # Replace every Conv+Sigmoid+Mul with Conv+Relu. 26 | subgraph_matcher.replace(model, conv_swish_to_conv_relu) 27 | ``` 28 | 29 | - Run demo: 30 | ```bash 31 | $> python happy_onnx_modify.py 32 | ``` 33 | 34 | # Subgraph Rules 35 | ```python 36 | layername1/layername2([input_argument1, input_argument2], [output_argument1, output_argument2]) 37 | layername(input_argument, output_argument) 38 | 39 | where: 40 | ? will match any layer or argument. 41 | 42 | Example 1: 43 | """ 44 | Conv(?, c0) 45 | Sigmoid(c0, s0) 46 | Mul([s0, c0], ?) 47 | """ 48 | 49 | Example 2: 50 | """ 51 | Conv/Avgpool(?, c0) 52 | ?(c0, s0) 53 | Mul([s0, c0], ?)
54 | """ 55 | ``` 56 | 57 | # Reference 58 | - No reference 59 | 60 | -------------------------------------------------------------------------------- /happy_onnx_modify.py: -------------------------------------------------------------------------------- 1 | import onnx_matcher 2 | import onnx.helper as helper 3 | import onnx 4 | 5 | name = "yolov5s.onnx" 6 | model = onnx.load(name) 7 | 8 | # Replace policy: keep Conv and add a Relu that reads Conv's outputs and writes Mul's outputs; Sigmoid and Mul are dropped. 9 | def conv_swish_to_conv_relu(model, i, subgraph): 10 | conv = subgraph[0] 11 | mul = subgraph[2] 12 | relu = helper.make_node("Relu", inputs=conv.output, outputs=mul.output, name=f"{conv.output[0]}_relu") 13 | return [conv, relu], [] 14 | 15 | # Pattern: a Conv, a Sigmoid of the Conv output, and a Mul of both (Conv followed by Swish). 16 | subgraph_matcher = onnx_matcher.Matcher( 17 | """ 18 | Conv(?, c0) 19 | Sigmoid(c0, s0) 20 | Mul([s0, c0], ?) 21 | """ 22 | ) 23 | 24 | # Print every matched subgraph to the console. 25 | subgraph_matcher.print_match(model) 26 | 27 | # Use the policy (conv_swish_to_conv_relu) to build new subgraphs and replace the matching ones.
28 | num_replaced_graph = subgraph_matcher.replace(model, conv_swish_to_conv_relu) 29 | 30 | print(f"Done for replace {num_replaced_graph} nodes.") 31 | onnx.save(model, "relu_yolov5.onnx") -------------------------------------------------------------------------------- /onnx_matcher.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # Copyright (c) 2023 Zyy 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
# IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import re
from copy import deepcopy
import inspect


def log(msg):
    """Print *msg* prefixed with the line number of the statement that called log()."""
    # currentframe() avoids inspect.stack(), which builds the full stack (with
    # source context) on every call.
    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    lineno = caller.f_lineno if caller is not None else "?"
    print(f"[ONNX_Matcher:{lineno}]: {msg}")


def find_node_by_output(model, output):
    """Return the first non-Constant node that lists *output* among its outputs, or None."""
    for node in model.graph.node:
        if node.op_type != "Constant" and output in node.output:
            return node
    return None


def find_node_by_input(model, input):
    """Return the first non-Constant node that lists *input* among its inputs, or None."""
    for node in model.graph.node:
        if node.op_type != "Constant" and input in node.input:
            return node
    return None


def find_nodes_by_input(model, input):
    """Return every non-Constant node that lists *input* among its inputs, in graph order."""
    return [
        node
        for node in model.graph.node
        if node.op_type != "Constant" and input in node.input
    ]


def find_nodes_by_output(model, output):
    """Return every non-Constant node that lists *output* among its outputs, in graph order."""
    return [
        node
        for node in model.graph.node
        if node.op_type != "Constant" and output in node.output
    ]


def find_consts(model, name):
    """Return every Constant node that produces the tensor *name*."""
    return [
        node
        for node in model.graph.node
        if node.op_type == "Constant" and name in node.output
    ]


def find_initializers(model, name):
    """Return every graph initializer whose name equals *name*."""
    return [init for init in model.graph.initializer if init.name == name]


def remove_node_and_init_by_indexs(model, inodes, inints):
    """Delete graph nodes and initializers by position.

    Args:
        model: object exposing ``graph.node`` and ``graph.initializer``.
        inodes: indices into ``model.graph.node`` to delete.
        inints: indices into ``model.graph.initializer`` to delete.
    """
    # Delete from the back so earlier indices stay valid, and drop repeated
    # indices: deleting the same position twice would remove an unrelated
    # neighbour.
    for index in sorted(set(inodes), reverse=True):
        del model.graph.node[index]

    for index in sorted(set(inints), reverse=True):
        del model.graph.initializer[index]


def remove_costnode_by_tensor(model, tensor_name):
    """Delete the first Constant node producing *tensor_name*.

    Returns:
        True when a node was deleted, False when none produces the tensor.
    """
    for index, node in enumerate(model.graph.node):
        if node.op_type == "Constant" and tensor_name in node.output:
            log(f"Remove a constant node: {node.name}")
            # Safe while iterating only because we return immediately.
            del model.graph.node[index]
            return True
    return False


def remove_node_and_info(model, node):
    """Delete the Constant nodes and initializers that feed *node*.

    *node* itself is not deleted here. A tensor that is still read by another
    non-Constant node in the graph is left alone, so shared weights and
    constants are not pulled out from under their remaining consumers.
    """
    nodes = list(model.graph.node)
    inits = list(model.graph.initializer)
    node_idxs = set()
    init_idxs = set()
    for name in node.input:
        if any(user != node for user in find_nodes_by_input(model, name)):
            continue

        node_idxs.update(
            index
            for index, candidate in enumerate(nodes)
            if candidate.op_type == "Constant" and name in candidate.output
        )
        init_idxs.update(
            index for index, init in enumerate(inits) if init.name == name
        )

    remove_node_and_init_by_indexs(model, node_idxs, init_idxs)


def cleanup(model):
    """Remove nodes and initializers that do not contribute to any graph output.

    Walks backwards from ``model.graph.output``: a tensor is needed when it is
    a graph output or an input of a node that produces a needed tensor. Nodes
    with no needed output and initializers that are not needed are deleted,
    and each deletion is logged.
    """
    # NOTE(review): tensors referenced only from inside nested subgraph
    # attributes are not visited here - confirm the models in use have none.
    nodes = list(model.graph.node)
    producers = {}
    for index, node in enumerate(nodes):
        for name in node.output:
            producers.setdefault(name, []).append(index)

    needed = set()
    pending = [item.name for item in model.graph.output]
    while pending:
        tensor = pending.pop()
        if tensor in needed:
            continue

        needed.add(tensor)
        for index in producers.get(tensor, ()):
            pending.extend(nodes[index].input)

    del_nodes = []
    for inode, node in enumerate(nodes):
        if not any(output in needed for output in node.output):
            log(f"Remove a floating node: {node.name}, the node output is: {node.output}")
            del_nodes.append(inode)

    del_inits = []
    for iinit, init in enumerate(model.graph.initializer):
        if init.name not in needed:
            log(f"Remove a unused initializer: {init.name}")
            del_inits.append(iinit)

    remove_node_and_init_by_indexs(model, del_nodes, del_inits)


class Lexer:
    """Parse a textual subgraph pattern into per-line matching rules.

    Each non-empty line has the form ``name(inputs, outputs)``, for example
    ``Mul([s0, c0], ?)``. ``name`` may list alternatives separated by ``/``;
    ``inputs`` and ``outputs`` are a single variable or a bracketed list.

    Attributes set by ``__init__``:
        pattern: the original pattern text.
        lines: the non-empty pattern lines with all whitespace removed.
        patterns: one ``[op_types, inputs, outputs]`` entry per line.
    """

    # Raw string: "\W", "\w" and "\(" are regex escapes, not string escapes.
    _LINE_RE = re.compile(r"([\W\w]+)\(([\W\w]+)\)")

    def __init__(self, pattern):
        # Drop every whitespace character (tabs as well as spaces) so the
        # pattern may be indented freely.
        stripped = ("".join(raw.split()) for raw in pattern.splitlines())
        lines = [line for line in stripped if line != ""]

        self.pattern = pattern
        self.lines = lines
        self.patterns = []
        for line in lines:
            names_and_arguments = self._LINE_RE.findall(line)
            assert len(names_and_arguments) == 1, f"Unexpected line: {line}. The valid symbol is: name(input_argument, output_argument)"
            operator_names, arguments = names_and_arguments[0]
            inputs, outputs = self.parse_arguments(arguments)
            self.patterns.append([operator_names.split("/"), inputs, outputs])

    def parse_variable(self):
        """Consume and return the variable name starting at ``self.itoken``.

        A name is a run of alphanumeric characters, ``_`` and ``?``.
        """
        variable_name = ""
        while self.itoken < len(self.symbols):
            self.token = self.symbols[self.itoken]
            if not (self.token.isalnum() or self.token == "?" or self.token == "_"):
                break

            variable_name += self.token
            self.itoken += 1
        return variable_name

    def parse_list(self):
        """Consume a ``[a,b,...]`` list starting at ``self.itoken`` and return its names."""
        self.itoken += 1  # skip the opening "["
        lists = [self.parse_variable()]
        while self.itoken < len(self.symbols):
            self.token = self.symbols[self.itoken]
            if self.token == ",":
                self.itoken += 1
                lists.append(self.parse_variable())
            elif self.token == "]":
                self.itoken += 1
                break
            else:
                raise ValueError(f"Unexpected token: {self.token}")
        assert self.token == "]", f"Unexpected end token for list: ], pos: {self.itoken}"
        return lists

    def parse_arguments(self, symbols):
        """Split an argument string into ``[inputs, outputs]``, each a list of names."""
        self.itoken = 0
        self.symbols = symbols

        lists = []
        while self.itoken < len(symbols):
            self.token = symbols[self.itoken]
            if self.token == "[":
                lists.append(self.parse_list())
            else:
                lists.append([self.parse_variable()])
            self.itoken += 1  # skip the "," between the two arguments
        assert len(lists) == 2, f"Unexpected number of params: {len(lists)}"
        return lists


class Matcher:
    """Find, print, delete or replace the chains of nodes described by a pattern."""

    def __init__(self, pattern):
        self.lexer = Lexer(pattern)

    def _match_io(self, input_params, input_names, variables):
        """Return True when every named input variable is bound to one of *input_names*.

        ``?`` matches anything. A name that no earlier pattern line bound as
        an output raises KeyError.
        """
        for item in input_params:
            if item != "?" and variables[item] not in input_names:
                return False
        return True

    def _try_to_match(self, model, anchor):
        """Return every node path that starts at *anchor* and satisfies the whole pattern."""
        matched_paths = []
        last_condition = len(self.lexer.patterns) - 1
        params_stack = [[[anchor], 0, dict()]]
        while len(params_stack) > 0:
            path, icondition, variables = params_stack.pop()
            node = path[-1]
            allowed_op_types, inputs, outputs = self.lexer.patterns[icondition]
            if "?" not in allowed_op_types and node.op_type not in allowed_op_types:
                continue

            if not self._match_io(inputs, node.input, variables):
                continue

            if icondition == last_condition:
                matched_paths.append(path)
                continue

            named_outputs = [(i, item) for i, item in enumerate(outputs) if item != "?"]
            if any(i >= len(node.output) for i, _ in named_outputs):
                # The pattern names more outputs than this node has.
                continue

            variables = deepcopy(variables)
            for i, item in named_outputs:
                variables[item] = node.output[i]

            # One pass over the graph yields each consumer once, even when it
            # reads several outputs of this node. Empty names are skipped so
            # omitted optional outputs do not link unrelated nodes.
            produced = {name for name in node.output if name}
            for consumer in model.graph.node:
                if consumer.op_type == "Constant":
                    continue
                if any(name in produced for name in consumer.input):
                    params_stack.append([path + [consumer], icondition + 1, variables])
        return matched_paths

    def match(self, model):
        """Return all matched subgraphs, each as a list of nodes in pattern order."""
        all_matched_pairs = []
        for node in model.graph.node:
            if node.op_type == "Constant":
                continue

            all_matched_pairs.extend(self._try_to_match(model, node))
        return all_matched_pairs

    def print_match(self, model):
        """Print every matched subgraph and the pattern text to the console."""
        print("=====================================================================")
        matched_subgraphs = self.match(model)
        log(f"Found {len(matched_subgraphs)} subgraphs:")
        for i, subgraph in enumerate(matched_subgraphs):
            subgraph_names = ", ".join([f"{item.name}({item.op_type})" for item in subgraph])
            print(f"\tSubgraph{i}: {subgraph_names}")

        pattern_text = "\n\t".join(self.lexer.lines)
        log(f"Pattern is:\n\t{pattern_text}")
        print("=====================================================================")

    # delete some subgraph
    def delete(self, model):
        """Delete every matched subgraph and return how many were matched."""
        return self.replace(model, None)

    # replace some subgraph to new
    def replace(self, model, new_graph_fn=None):
        """Replace (or delete) every matched subgraph.

        Args:
            model: the model to edit in place.
            new_graph_fn: callable ``(model, i, subgraph)`` returning
                ``(new_nodes, new_initializers)``. When it is None, or when it
                returns no nodes, the subgraph is deleted and the producers of
                its first node's inputs are rewired to the last node's outputs.

        Returns:
            The number of matched subgraphs.
        """
        matched_subgraphs = self.match(model)
        for i, subgraph in enumerate(matched_subgraphs):
            if new_graph_fn is not None:
                new_nodes, new_initializers = new_graph_fn(model, i, subgraph)
            else:
                new_nodes, new_initializers = [], []

            newgraph_names = ", ".join([f"{item.name}({item.op_type})" for item in new_nodes])
            subgraph_names = ", ".join([f"{item.name}({item.op_type})" for item in subgraph])
            if len(new_nodes) > 0:
                log(f"Replace subgraph{i}: [{subgraph_names}] to: [{newgraph_names}]")
            else:
                log(f"Delete subgraph{i}: {subgraph_names}")

            lnodes = list(model.graph.node)
            idxs = sorted([lnodes.index(item) for item in subgraph], reverse=True)
            for idx in idxs:
                del model.graph.node[idx]

            for n in subgraph:
                # Remove the constants and initializers of nodes that are not kept.
                if n not in new_nodes:
                    remove_node_and_info(model, n)

            if len(new_nodes) == 0:
                input_node = subgraph[0]
                output_node = subgraph[-1]
                if new_graph_fn is None:
                    assert len(input_node.input) == len(output_node.output), "Invalid replace"

                # Only inputs that pair up with an output can be rewired;
                # surplus inputs are left untouched.
                i2o = dict(zip(input_node.input, output_node.output))
                for input_name, output_name in i2o.items():
                    for p in find_nodes_by_output(model, input_name):
                        p.output[list(p.output).index(input_name)] = output_name
            else:
                # NOTE(review): new nodes go in at the lowest index the removed
                # nodes occupied; the graph is not re-sorted topologically.
                insert_point = idxs[-1]
                for node in new_nodes:
                    model.graph.node.insert(insert_point, node)
                    insert_point += 1

            for init in new_initializers:
                model.graph.initializer.append(init)
        return len(matched_subgraphs)

# --------------------------------------------------------------------------------
# /sample.png:
# --------------------------------------------------------------------------------
https://raw.githubusercontent.com/sesmfs/onnx_matcher/e1564c91e0cbcd001331bd7f387356343c61352b/sample.png -------------------------------------------------------------------------------- /yolov5s.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sesmfs/onnx_matcher/e1564c91e0cbcd001331bd7f387356343c61352b/yolov5s.onnx --------------------------------------------------------------------------------