├── .gitignore ├── README.md ├── convert2onnx └── pytorch2onnx_resize.py ├── onnxapi └── creat_onnx_example.py ├── test_tool.py └── tools ├── __init__.py └── tool.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.onnx 2 | *.pyc 3 | *.pth 4 | model/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # onnx_learn 2 | 3 | 记录学习ONNX的一些资料以及提供一些实践经验。 4 | 5 | # 依赖 6 | 7 | - numpy 必选 8 | - onnx 必选 9 | - onnxruntime 可选,如果脚本里面使用到了onnxruntime则需要安装 10 | - pytorch 可选,如果脚本里面使用到了pytorch则需要安装 11 | 12 | # 代码结构 13 | 14 | - convert2onnx 15 | - pytorch2onnx_resize.py 通过Pytorch导出ONNX模型,Reshape操作 16 | - onnxapi 17 | - creat_onnx_example.py 创建一个onnx模型例子 18 | - tools 维护一个工具类,方便修改ONNX模型来解决ONNX版本迭代以及框架之间对OP定义的不兼容问题 19 | - test_tool.py 测试ONNX工具类 20 | 21 | # 学习笔记 22 | 23 | - [ONNX初探](https://mp.weixin.qq.com/s/H1tDcmrg0vTcSw9PgpgIIQ) 24 | - [ONNX再探](https://mp.weixin.qq.com/s/_iNhfZNR5-swXLhHKjYRkQ) 25 | - [onnx simplifier 和 optimizer](https://mp.weixin.qq.com/s/q0Aa2LRpeCPCnIzRJbMmaQ) 26 | - [onnx2pytorch和onnx-simplifier新版介绍](https://mp.weixin.qq.com/s/NDv-quXeBrPeDcCbg97FHA) 27 | - [深度学习框架OneFlow是如何和ONNX交互的?](https://mp.weixin.qq.com/s/sxBDHl00jAKRXq-Y6Rii7A) 28 | - [Pytorch转ONNX-理论篇](https://mp.weixin.qq.com/s/RoqaMPwCbtHfLKgnJX95ng) 29 | - [Pytorch转ONNX-实战篇1(tracing机制)](https://mp.weixin.qq.com/s/L2lZAo35ZeybuiH3tgJsvw) 30 | - [Pytorch转ONNX-实战篇2(实战踩坑总结)](https://mp.weixin.qq.com/s/nG45SDO2_J48omSkn27EtQ) -------------------------------------------------------------------------------- /convert2onnx/pytorch2onnx_resize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class JustReshape(torch.nn.Module): 5 | def __init__(self): 6 | super(JustReshape, self).__init__() 7 | self.mean = torch.randn(2, 3, 4, 5) 8 | self.std = torch.randn(2, 3, 4, 5) 9 | 10 | def forward(self, x): 11 | # x = (x - self.mean) / self.std 12 | return x.view((x.shape[0], x.shape[1], x.shape[3], x.shape[2])) 13 | 14 | 15 | net = JustReshape() 16 | model_name = '../model/just_reshape.onnx' 17 | dummy_input = torch.randn(1, 3, 4, 5) 18 | torch.onnx.export(net, dummy_input, model_name, input_names=['input'], output_names=['output']) -------------------------------------------------------------------------------- /onnxapi/creat_onnx_example.py: -------------------------------------------------------------------------------- 1 | import onnx 2 | from onnx import helper 3 | from onnx import AttributeProto, TensorProto, GraphProto 4 | 5 | 6 | # The protobuf definition can be found here: 7 | # https://github.com/onnx/onnx/blob/master/onnx/onnx.proto 8 | 9 | 10 | # Create one input (ValueInfoProto) 11 | X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [3, 2]) 12 | pads = helper.make_tensor_value_info('pads', TensorProto.FLOAT, [1, 4]) 13 | 14 | value = helper.make_tensor_value_info('value', AttributeProto.FLOAT, [1]) 15 | 16 | 17 | # Create one output (ValueInfoProto) 18 | Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [3, 4]) 19 | 20 | # Create a node (NodeProto) - This is based on Pad-11 21 | node_def = helper.make_node( 22 | 'Pad', # node name 23 | ['X', 'pads', 'value'], # inputs 24 | ['Y'], # outputs 25 | mode='constant', # attributes 26 | ) 27 | 28 | # Create the graph (GraphProto) 29 | graph_def = helper.make_graph( 30 | [node_def], 31 | 'test-model', 32 | [X, pads, value], 33 | [Y], 34 | ) 35 | 36 | # Create the model (ModelProto) 37 | model_def = helper.make_model(graph_def, producer_name='onnx-example') 38 | 39 | print('The model is:\n{}'.format(model_def)) 40 | onnx.checker.check_model(model_def) 41 | print('The model is checked!') -------------------------------------------------------------------------------- /test_tool.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import onnx 3 | import numpy as np 4 | from onnx import helper 5 | from onnx import numpy_helper 6 | import argparse 7 | from tools import * 8 | 9 | # mxnet老版本遗留问题 10 | def old_mxnet_version_example(tool): 11 | # NOTE 1 12 | # in some old version mxnet model, the fix_gamma in BatchNormalization is set to True, 13 | # but when converting to onnx model which do NOT have the fix_gamma attribute, and the 14 | # gamma (named scale in onnx) parameter is not all ones, it may cause result inconsistent 15 | # NOTE 2 16 | # in some old version mxnet model, the average pooling layer has an attribute "count_include_pad" 17 | # but is was not set when converting to onnx model, it seems like the default value is 1 18 | bn_nodes = tool.get_nodes_by_optype("BatchNormalization") 19 | for bn_node in bn_nodes: 20 | gamma_name = bn_node.input[1] 21 | tool.set_weight_by_name(gamma_name, all_ones=True) 22 | avg_nodes = tool.get_nodes_by_optype("AveragePool") 23 | for avg_node in avg_nodes: 24 | tool.set_node_attribute(avg_node, "count_include_pad", 1) 25 | 26 | # 为tensorflow转出来的ONNX模型设置batch维度 27 | def tf_set_batch_size_example(tool, batch_size=8): 28 | # NOTE 29 | # when using tf2onnx convert the tensorflow pb model to onnx 30 | # the input batch_size dim is not set, we can append it 31 | tool.list_model_inputs(2) 32 | # tool.set_model_input_shape(name="pb_input:0", shape=(32,3,256,256)) 33 | tool.set_model_input_batch_size(batch_size=batch_size) 34 | 35 | # 获取我们想要的任意节点的推理结果 36 | def debug_internal_output(tool, node_name, output_name): 37 | # NOTE 38 | # sometimes we hope to get the internal result of some node for debug, 39 | # but onnx do NOT have the API to support this function. Don't worry, 40 | # we can append an Identity OP and an extra output following the target 41 | # node to get the result we want 42 | node = tool.get_node_by_name(node_name) 43 | tool.add_extra_output(node, output_name) 44 | 45 | 46 | def tensorrt_set_epsilon_example(tool, epsilon=1e-3): 47 | # NOTE 48 | # We found when converting an onnx model with InstanceNormalization OP to TensorRT engine, the inference result is inaccurate 49 | # you can find the details at https://devtalk.nvidia.com/default/topic/1071094/tensorrt/inference-result-inaccurate-with-conv-and-instancenormalization-under-certain-conditions/ 50 | # After days of debugging, and we finally find this issue is caused by the following line of code 51 | # https://github.com/onnx/onnx-tensorrt/blob/5dca8737851118f6ab8a33ea1f7bcb7c9f06caf5/builtin_op_importers.cpp#L1557 52 | # it is strange that TensorRT onnx parser only supports epsilon >= 1e-4, if you do NOT 53 | # want to re-compile the TensorRT OSS, you can change epsilon to 1e-3 manually... 54 | # I tried comment out that line, it worked but the error is bigger than setting epsilon to 1e-3 55 | in_nodes = tool.get_nodes_by_optype("InstanceNormalization") 56 | for in_node in in_nodes: 57 | tool.set_node_attribute(in_node, "epsilon", epsilon) 58 | 59 | # 在ONNX模型的指定节点前添加卷积OP 60 | def add_conv_layer(tool, target_node_name): 61 | # NOTE: 62 | # The name, attribute and weight of the OP can be found at: 63 | # https://github.com/onnx/onnx/blob/master/docs/Operators.md 64 | # You must convert all your weight and attribute to the standard 65 | # of the ONNX to avoid unexpected error 66 | target_node = tool.get_node_by_name(target_node_name) 67 | # NOTE: 68 | # the weight name better be complicated enough to avoid conflict, 69 | # And weight_dict must be in order (make sure your python version >= 3.6) 70 | weight_dict = { 71 | "W_from_a_new_conv_op": np.random.normal(0, 1, (64, 64, 3, 3)).astype(np.float32), 72 | "B_from_a_new_conv_op": np.random.normal(0, 1, (64,)).astype(np.float32) 73 | } 74 | attr_dict = { 75 | "kernel_shape": [3, 3], 76 | "pads": [0, 0, 0, 0] 77 | } 78 | tool.insert_op_before( 79 | node_name="new_conv_op", 80 | target_node=target_node, 81 | op_name="Conv", 82 | weight_dict=weight_dict, 83 | attr_dict=attr_dict 84 | ) 85 | 86 | # 查看ONNX节点的属性 87 | def show_node_attributes(node): 88 | print("="*10, "attributes of node: ", node.name, "="*10) 89 | for attr in node.attribute: 90 | print(attr.name) 91 | print("="*60) 92 | 93 | # 查看ONNX节点的输入 94 | def show_node_inputs(node): 95 | # Generally, the first input is the truely input 96 | # and the rest input is weight initializer 97 | print("="*10, "inputs of node: ", node.name, "="*10) 98 | for input_name in node.input: 99 | print(input_name) # type of input_name is str 100 | print("="*60) 101 | 102 | # 查看ONNX节点的输出 103 | def show_node_outputs(node): 104 | print("="*10, "outputs of node: ", node.name, "="*10) 105 | for output_name in node.output: 106 | print(output_name) # type of output_name is str 107 | print("="*60) 108 | 109 | # 打印权重 110 | def show_weight(weight): 111 | print("="*10, "details of weight: ", weight.name, "="*10) 112 | print("data type: ", weight.data_type) 113 | print("shape: ", weight.dims) 114 | data_numpy = numpy_helper.to_array(weight) 115 | # data_numpy = np.frombuffer(weight.raw_data, dtype=xxx) 116 | # print("detail data:", data_numpy) 117 | print("="*40) 118 | 119 | if __name__ == "__main__": 120 | parser = argparse.ArgumentParser(description="onnx test") 121 | parser.add_argument("--input", default="", type=str, required=True) 122 | parser.add_argument("--output", default="", type=str, required=True) 123 | args = parser.parse_args() 124 | 125 | tool = Tool(args.input) 126 | 127 | # old_mxnet_version_example(tool) 128 | # tf_set_batch_size_example(tool, 16) 129 | # debug_internal_output(tool, "your target node name", "debug_test") 130 | tensorrt_set_epsilon_example(tool, 1e-3) 131 | 132 | tool.export(args.output) -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .tool import * -------------------------------------------------------------------------------- /tools/tool.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import onnx 3 | import numpy as np 4 | from onnx import helper 5 | from onnx import numpy_helper 6 | 7 | class Tool(object): 8 | # 初始化onnx模型 9 | def __init__(self, onnx_model_path): 10 | self.model = onnx.load(onnx_model_path) 11 | self.model = onnx.shape_inference.infer_shapes(self.model) 12 | self.inputs = [] 13 | self.outputs = [] 14 | 15 | # 保存onnx模型 16 | def export(self, save_path): 17 | onnx.checker.check_model(self.model) 18 | self.model = onnx.shape_inference.infer_shapes(self.model) 19 | onnx.save(self.model, save_path) 20 | 21 | # 获取onnx模型的输入,返回一个列表 22 | def get_input_names(self): 23 | set_input = set() 24 | set_initializer = set() 25 | for ipt in self.model.graph.input: 26 | set_input.add(ipt.name) 27 | for x in model.graph.initializer: 28 | set_initializer.add(x.name) 29 | return list(set_input - set_initializer) 30 | 31 | # 为onnx模型增加batch维度 32 | def set_model_input_batch(self, index=0, name=None, batch_size=4): 33 | model_input = None 34 | if name is not None: 35 | for ipt in self.model.graph.input: 36 | if ipt.name == name: 37 | model_input = ipt 38 | else: 39 | model_input = self.model.graph.input[index] 40 | if model_input: 41 | tensor_dim = model_input.type.tensor_type.shape.dim 42 | tensor_dim[0].ClearField("dim_param") 43 | tensor_dim[0].dim_value = batch_size 44 | else: 45 | print('get model input failed, check index or name') 46 | 47 | # 为onnx模型的输入设置形状 48 | def set_model_input_shape(self, index=0, name=None, shape=None): 49 | model_input = None 50 | if name is not None: 51 | for ipt in self.model.graph.input: 52 | if ipt.name == name: 53 | model_input = ipt 54 | else: 55 | model_input = self.model.graph.input[index] 56 | if model_input: 57 | if shape is not None: 58 | tensor_shape_proto = model_input.type.tensor_type.shape 59 | tensor_shape_proto.ClearField("dim") 60 | tensor_shape_proto.dim.extend([]) 61 | for d in shape: 62 | dim = tensor_shape_proto.dim.add() 63 | dim.dim_value = d 64 | else: 65 | print('get input shape failed, check input') 66 | else: 67 | print('get model input failed, check index or name') 68 | 69 | # 通过名字获取onnx模型中的计算节点 70 | def get_node_by_name(self, name): 71 | for node in self.model.graph.node: 72 | if node.name == name: 73 | return node 74 | 75 | # 通过op的类型获取onnx模型的计算节点 76 | def get_nodes_by_optype(self, typename): 77 | nodes = [] 78 | for node in self.model.graph.node: 79 | if node.op_type == typename: 80 | nodes.append(node) 81 | return nodes 82 | 83 | # 通过名字获取onnx模型计算节点的权重 84 | def get_weight_by_name(self, name): 85 | for weight in self.model.graph.initializer: 86 | if weight.name == name: 87 | return weight 88 | 89 | # 设置权重,注意这个weight是TensorProto类型,`https://github.com/onnx/onnx/blob/b1e0bc9a31eaefc2a9946182fbad939843534984/onnx/onnx.proto#L461` 90 | def set_weight(self, weight, data_numpy=None, all_ones=False, all_zeros=False): 91 | if data_numpy is not None: 92 | raw_shape = tuple([i for i in weight.dims]) 93 | new_shape = np.shape(data_numpy) 94 | if weight.data_type == 8: 95 | print("Can NOT handle string data type right now...") 96 | exit() 97 | if new_shape != raw_shape: 98 | print("Warning: the new weight shape is not consistent with original shape!") 99 | weight.dims[:] = list(new_shape) 100 | for model_input in self.model.graph.input: 101 | if model_input.name == weight.name: 102 | # copy from onnx.helper... 103 | tensor_shape_proto = model_input.type.tensor_type.shape 104 | tensor_shape_proto.ClearField("dim") 105 | tensor_shape_proto.dim.extend([]) 106 | for d in new_shape: 107 | dim = tensor_shape_proto.dim.add() 108 | dim.dim_value = d 109 | 110 | weight.ClearField("float_data") 111 | weight.ClearField("int32_data") 112 | weight.ClearField("int64_data") 113 | weight.raw_data = data_numpy.tobytes() 114 | else: 115 | if all_ones: 116 | wr = numpy_helper.to_array(weight) 117 | wn = np.ones_like(wr) 118 | elif all_zeros: 119 | wr = numpy_helper.to_array(weight) 120 | wn = np.zeros_like(wr) 121 | else: 122 | print("You must give a data_numpy to set the weight, or set the all_ones/all_zeros flag.") 123 | exit() 124 | weight.ClearField("float_data") 125 | weight.ClearField("int32_data") 126 | weight.ClearField("int64_data") 127 | weight.raw_data = wn.tobytes() 128 | 129 | # 通过名字设置ONNX节点的权重 130 | def set_weight_by_name(self, name, data_numpy=None, all_ones=False, all_zeros=False): 131 | weight = self.get_weight_by_name(name) 132 | self.set_weight(weight, data_numpy, all_ones, all_zeros) 133 | 134 | # 移除ONNX模型中的目标节点 135 | def remove_node(self, target_node): 136 | ''' 137 | 删除只有一个输入和输出的节点 138 | ''' 139 | node_input = target_node.input[0] 140 | node_output = target_node.output[0] 141 | # 将后继节点的输入设置为目标节点的前置节点 142 | for node in self.model.graph.node: 143 | for i, n in enumerate(node.input): 144 | if n == node_output: 145 | node.input[i] = node_input 146 | 147 | target_names = set(target_node.input) & set([weight.name for weight in self.model.graph.initializer]) 148 | self.remove_weights(target_names) 149 | target_names.add(node_output) 150 | self.remove_inputs(target_names) 151 | self.remove_value_infos(target_names) 152 | self.model.graph.node.remove(target_node) 153 | 154 | # 移除ONNX模型中指定节点的权重 155 | def remove_weights(self, name_list): 156 | rm_list = [] 157 | for weight in self.model.graph.initializer: 158 | if weight.name in name_list: 159 | rm_list.append(weight) 160 | for weight in rm_list: 161 | self.model.graph.initializer.remove(weight) 162 | 163 | # 移除ONNX模型中指定的输入节点 164 | def remove_inputs(self, name_list): 165 | rm_list = [] 166 | for input_t in self.model.graph.input: 167 | if input_t.name in name_list: 168 | rm_list.append(input_t) 169 | for input_t in rm_list: 170 | self.model.graph.input.remove(input_t) 171 | 172 | # 移除ONNX模型中指定的输入输出节点 173 | def remove_value_infos(self, name_list): 174 | rm_list = [] 175 | for value_info in self.model.graph.value_info: 176 | if value_info.name in name_list: 177 | rm_list.append(value_info) 178 | for value_info in rm_list: 179 | self.model.graph.value_info.remove(value_info) 180 | 181 | # 给ONNX模型中的目标节点设置指定属性 182 | def set_node_attribute(self, target_node, attr_name, attr_value): 183 | flag = False 184 | for attr in target_node.attribute: 185 | if (attr.name == attr_name): 186 | if attr.type == 1: 187 | attr.f = attr_value 188 | elif attr.type == 2: 189 | attr.i = attr_value 190 | elif attr.type == 3: 191 | attr.s = attr_value 192 | elif attr.type == 4: 193 | attr.t = attr_value 194 | elif attr.type == 5: 195 | attr.g = attr_value 196 | # NOTE: For repeated composite types, we should use something like 197 | # del attr.xxx[:] 198 | # attr.xxx.extend([n1, n2, n3]) 199 | elif attr.type == 6: 200 | attr.floats[:] = attr_value 201 | elif attr.type == 7: 202 | attr.ints[:] = attr_value 203 | elif attr.type == 8: 204 | attr.strings[:] = attr_value 205 | else: 206 | print("unsupported attribute data type with attribute name") 207 | return False 208 | flag = True 209 | 210 | if not flag: 211 | # attribute not in original node 212 | print("Warning: you are appending a new attribute to the node!") 213 | target_node.attribute.append(helper.make_attribute(attr_name, attr_value)) 214 | flag = True 215 | return flag 216 | 217 | def chunk_at(self, target_node): 218 | r_nodes = [target_node] 219 | r_input_names = [input_n for input_n in target_node.input] 220 | r_count = len(r_nodes) + len(r_input_names) 221 | 222 | while True: 223 | for node in self.model.graph.node: 224 | # print("nn", node.output) 225 | if node in r_nodes: 226 | continue 227 | for o in node.output: 228 | if o in r_input_names: 229 | r_nodes.append(node) 230 | r_input_names.extend([input_n for input_n in node.input]) 231 | continue 232 | n_count = len(r_nodes) + len(r_input_names) 233 | if n_count == r_count: 234 | break 235 | r_count = n_count 236 | 237 | print("debug r count", r_count) 238 | 239 | d_nodes = [] 240 | d_inputs = [] 241 | d_weights = [] 242 | d_value_infos = [] 243 | for node in self.model.graph.node: 244 | if node not in r_nodes: 245 | d_nodes.append(node) 246 | for model_input in self.model.graph.input: 247 | if model_input.name not in r_input_names: 248 | d_inputs.append(model_input) 249 | for weight in self.model.graph.initializer: 250 | if weight.name not in r_input_names: 251 | d_weights.append(weight) 252 | for value_info in self.model.graph.value_info: 253 | if value_info.name not in r_input_names: 254 | d_values.append(value_info) 255 | for node in d_nodes: 256 | self.model.graph.node.remove(node) 257 | for model_input in d_inputs: 258 | self.model.graph.input.remove(model_input) 259 | for weight in d_weights: 260 | self.model.graph.initializer.remove(weight) 261 | for value_info in d_value_infos: 262 | self.model.graph.value_info.remove(value_info) 263 | 264 | target_node.output[0] = self.model.graph.output[0].name 265 | # remove other outputs if model has multi-output 266 | d_outputs = [] 267 | for i, output in enumerate(self.model.graph.output): 268 | if i != 0 : 269 | d_outputs.append(output) 270 | for output in d_outputs: 271 | self.model.graph.output.remove(output) 272 | 273 | # 在指定节点前插入flatten node 274 | def insert_flatten_before(self, target_node): 275 | # get target_node inputs 276 | node_input = target_node.input[0] 277 | # create new node 278 | node_name = "flatten_test" 279 | flatten_node = helper.make_node('Flatten', inputs=[node_input], outputs=[node_name], name=node_name) 280 | # set target_node inputs to new node outputs 281 | target_node.input[0] = node_name 282 | for target_node_index, _target_node in enumerate(self.model.graph.node): 283 | if _target_node == target_node: 284 | self.model.graph.node.insert(target_node_index, flatten_node) 285 | break 286 | 287 | # 在指定节点target_node前插入一个新的OP 288 | def insert_op_before(self, node_name, target_node, input_idx=0, *args, **kwargs): 289 | ''' 290 | op_name 291 | weight_dict 292 | attr_dict 293 | ...... 294 | NOTE: 295 | you must ensure the output shape match the input shape of target_node 296 | ''' 297 | # get target_node inputs 298 | node_input = target_node.input[input_idx] 299 | weight_input = [] 300 | weight_input_vi = [] 301 | weight_initializer = [] 302 | if "weight_dict" in kwargs: 303 | for weight_name, weight_numpy in kwargs["weight_dict"].items(): 304 | weight_input.append(weight_name) 305 | weight_input_vi.append( 306 | helper.make_tensor_value_info( 307 | name=weight_name, 308 | elem_type=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[weight_numpy.dtype], 309 | shape=weight_numpy.shape 310 | ) 311 | ) 312 | weight_initializer.append( 313 | helper.make_tensor( 314 | name=weight_name, 315 | data_type=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[weight_numpy.dtype], 316 | dims=weight_numpy.shape, 317 | vals=weight_numpy.tobytes(), 318 | raw=True 319 | ) 320 | ) 321 | # create new node 322 | new_op_node = helper.make_node( 323 | kwargs["op_name"], 324 | inputs=[node_input, *weight_input], 325 | outputs=[node_name], 326 | name=node_name, 327 | **kwargs["attr_dict"] 328 | ) 329 | # set target_node input to new node outputs 330 | target_node.input[input_idx] = node_name 331 | # TODO: change other nodes input into the new node? 332 | # iterator all the nodes in the graph and find 333 | # which node's input equals the original target_node input 334 | # ... 335 | # add new node and weight input into the graph 336 | for target_node_index, _target_node in enumerate(self.model.graph.node): 337 | if _target_node == target_node: 338 | self.model.graph.node.insert(target_node_index, new_op_node) 339 | break 340 | self.model.graph.input.extend(weight_input_vi) 341 | self.model.graph.initializer.extend(weight_initializer) 342 | 343 | # 将target_node添加到ONNX模型中作为输出节点 344 | def add_extra_output(self, target_node, output_name): 345 | target_output = target_node.output[0] 346 | extra_shape = [] 347 | for vi in self.model.graph.value_info: 348 | if vi.name == target_output: 349 | extra_elem_type = vi.type.tensor_type.elem_type 350 | for s in vi.type.tensor_type.shape.dim: 351 | extra_shape.append(s.dim_value) 352 | extra_output = helper.make_tensor_value_info( 353 | output_name, 354 | extra_elem_type, 355 | extra_shape 356 | ) 357 | identity_node = helper.make_node('Identity', inputs=[target_output], outputs=[output_name], name=output_name) 358 | self.model.graph.node.append(identity_node) 359 | self.model.graph.output.append(extra_output) 360 | --------------------------------------------------------------------------------