├── Converting A PyTorch Model to Tensorflow pb using ONNX.html
├── Converting A PyTorch Model to Tensorflow pb using ONNX.md
├── README.md
├── convert_pytorch2onnx2tfpb.py
├── diymodel.py
├── mlmcmodel.py
└── pants.jpg

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Converting A PyTorch Model to Tensorflow pb using ONNX

pilgrim.bin@gmail.com

**A few words up front, to keep later readers out of the same pits:**

**ONNX was originally pushed by Facebook together with AWS as a counterweight to TensorFlow, so an ONNX-to-TF bridge was always going to be an awkward affair that neither platform officially stands behind. PyTorch and TensorFlow keep evolving independently, and the dynamic-graph vs. static-graph optimization battle is not going to end. If you find yourself converting models, seriously consider: 1. getting your serving platform to support PyTorch; 2. moving training over to TensorFlow; or 3. treating this conversion as a one-off job: do it once and do not keep reworking it.**

The model used in this demo comes from: https://github.com/cinastanbean/Pytorch-Multi-Task-Multi-class-Classification


[TOC]

# 1. Pre-installation

**Version Info**

```
pytorch                   0.4.0           py27_cuda0.0_cudnn0.0_1    pytorch
torchvision               0.2.1                    py27_1    pytorch
tensorflow                1.8.0                     <pip>
onnx                      1.2.2                     <pip>
onnx-tf                   1.1.2                     <pip>
```

Notes:

1. ONNX 1.1.2 is too old and triggers a BatchNormalization error; pip now provides 1.3.0. Installing from source is also an option: `pip install -U git+https://github.com/onnx/onnx.git@master`.
2. This experiment was verified to run correctly with ONNX 1.2.2.
3. onnx-tf is installed from source and requires TensorFlow >= 1.5.0.
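Because the pipeline is sensitive to these versions, it can help to print what is actually installed before starting. A minimal sketch, assuming the packages above are already installed; the `pkg_resources` lookup is used for onnx-tf because that package may not expose a `__version__` attribute:

```
import pkg_resources
import torch
import onnx
import tensorflow as tf

print('pytorch    = {}'.format(torch.__version__))
print('onnx       = {}'.format(onnx.__version__))
print('tensorflow = {}'.format(tf.__version__))
print('onnx-tf    = {}'.format(pkg_resources.get_distribution('onnx-tf').version))
```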
# 2. Conversion Process

## 2.1 Step 1.2.3.

**pipeline: pytorch model --> onnx model --> tensorflow graph pb.**

```
# step 1, load the pytorch model and export onnx while running it.
modelname = 'resnet18'
weightfile = 'models/model_best_checkpoint_resnet18.pth.tar'
modelhandle = DIY_Model(modelname, weightfile, class_numbers)
model = modelhandle.model
#model.eval() # useless
dummy_input = Variable(torch.randn(1, 3, 224, 224)) # nchw
onnx_filename = os.path.split(weightfile)[-1] + ".onnx"
torch.onnx.export(model, dummy_input,
                  onnx_filename,
                  verbose=True)

# step 2, create onnx_model using tensorflow as backend. check if it is right and export the graph.
onnx_model = onnx.load(onnx_filename)
tf_rep = prepare(onnx_model, strict=False)
# install onnx-tensorflow from github, and use tf_rep = prepare(onnx_model, strict=False)
# Reference https://github.com/onnx/onnx-tensorflow/issues/167
#tf_rep = prepare(onnx_model) # without strict=False this leads to KeyError: 'pyfunc_0'
image = Image.open('pants.jpg')
# debug, here using the same input to check onnx and tf.
output_pytorch, img_np = modelhandle.process(image)
print('output_pytorch = {}'.format(output_pytorch))
output_onnx_tf = tf_rep.run(img_np)
print('output_onnx_tf = {}'.format(output_onnx_tf))
# onnx --> tf.graph.pb
tf_pb_path = onnx_filename + '_graph.pb'
tf_rep.export_graph(tf_pb_path)

# step 3, check if tf.pb is right.
with tf.Graph().as_default():
    graph_def = tf.GraphDef()
    with open(tf_pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name="")
    with tf.Session() as sess:
        #init = tf.initialize_all_variables()
        init = tf.global_variables_initializer()
        #sess.run(init)

        # print all ops to check the input/output tensor names.
        # uncomment it if you do not know the io tensor names.
        '''
        print('-------------ops---------------------')
        op = sess.graph.get_operations()
        for m in op:
            print(m.values())
        print('-------------ops done.---------------------')
        '''

        input_x = sess.graph.get_tensor_by_name("0:0") # input
        outputs1 = sess.graph.get_tensor_by_name('add_1:0') # 5 classes
        outputs2 = sess.graph.get_tensor_by_name('add_3:0') # 10 classes
        output_tf_pb = sess.run([outputs1, outputs2], feed_dict={input_x: img_np})
        #output_tf_pb = sess.run([outputs1, outputs2], feed_dict={input_x: np.random.randn(1, 3, 224, 224)})
        print('output_tf_pb = {}'.format(output_tf_pb))
```
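The exported graph above ends up with auto-generated tensor names ("0:0", "add_1:0", "add_3:0"), which is why step 3 has to dump all ops to find them. Newer PyTorch releases add `input_names` / `output_names` arguments to `torch.onnx.export`, which give the ONNX graph predictable names; the names below (`input`, `out_length`, `out_style`) are purely illustrative, not from this repo, and this is only a sketch assuming such a PyTorch version:

```
# assumption: a PyTorch release where torch.onnx.export supports named inputs/outputs
torch.onnx.export(model, dummy_input, onnx_filename,
                  verbose=True,
                  input_names=['input'],                     # the NCHW image input
                  output_names=['out_length', 'out_style'])  # 5-way and 10-way heads
# After onnx-tf conversion, the tensors can then usually be fetched as
# 'input:0', 'out_length:0' and 'out_style:0' instead of '0:0', 'add_1:0', 'add_3:0'.
```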
## 2.2 Verification

**Make sure the outputs are consistent.** Differences on the order of 1e-5 are expected; the tolerance check sketched at the end of this section can automate the comparison.

```
output_pytorch = [array([ 2.5359073 , -1.4261041 , -5.2394    , -0.62402934,  4.7426634 ], dtype=float32), array([ 7.6249304,  5.1203837,  1.8118637,  1.5143847, -4.9409146, 1.1695148, -6.2375665, -1.6033885, -1.4286405, -2.964429 ], dtype=float32)]

output_onnx_tf = Outputs(_0=array([[ 2.5359051, -1.4261056, -5.239397 , -0.6240269,  4.7426634]], dtype=float32), _1=array([[ 7.6249285,  5.12038  ,  1.811865 ,  1.5143874, -4.940915 , 1.1695154, -6.237564 , -1.6033876, -1.4286422, -2.964428 ]], dtype=float32))

output_tf_pb = [array([[ 2.5359051, -1.4261056, -5.239397 , -0.6240269,  4.7426634]], dtype=float32), array([[ 7.6249285,  5.12038  ,  1.811865 ,  1.5143874, -4.940915 , 1.1695154, -6.237564 , -1.6033876, -1.4286422, -2.964428 ]], dtype=float32)]
```

**Standalone TensorFlow verification program**

```
import numpy as np
import tensorflow as tf
from PIL import Image

def get_img_np_nchw(filename):
    try:
        image = Image.open(filename).convert('RGB').resize((224, 224))
        miu = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        #miu = np.array([0.5, 0.5, 0.5])
        #std = np.array([0.22, 0.22, 0.22])
        # img_np.shape = (224, 224, 3)
        img_np = np.array(image, dtype=float) / 255.
        r = (img_np[:,:,0] - miu[0]) / std[0]
        g = (img_np[:,:,1] - miu[1]) / std[1]
        b = (img_np[:,:,2] - miu[2]) / std[2]
        img_np_t = np.array([r, g, b])
        img_np_nchw = np.expand_dims(img_np_t, axis=0)
        return img_np_nchw
    except Exception:
        print("RuntimeError: get_img_np_nchw({}).".format(filename))
        # falls through and returns None


if __name__ == '__main__':

    tf_pb_path = 'model_best_checkpoint_resnet18.pth.tar.onnx_graph.pb'

    filename = 'pants.jpg'
    img_np_nchw = get_img_np_nchw(filename)

    # step 3, check if tf.pb is right.
    with tf.Graph().as_default():
        graph_def = tf.GraphDef()
        with open(tf_pb_path, "rb") as f:
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name="")
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            #init = tf.initialize_all_variables()
            sess.run(init)

            # print all ops to check the input/output tensor names.
            # uncomment it if you do not know the io tensor names.
            '''
            print('-------------ops---------------------')
            op = sess.graph.get_operations()
            for m in op:
                print(m.values())
            print('-------------ops done.---------------------')
            '''

            input_x = sess.graph.get_tensor_by_name("0:0") # input
            outputs1 = sess.graph.get_tensor_by_name('add_1:0') # 5 classes
            outputs2 = sess.graph.get_tensor_by_name('add_3:0') # 10 classes
            output_tf_pb = sess.run([outputs1, outputs2], feed_dict={input_x: img_np_nchw})
            print('output_tf_pb = {}'.format(output_tf_pb))
```
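Eyeballing the printed arrays works, but a numeric check is less error-prone. A small sketch, assuming the `output_pytorch` and `output_tf_pb` variables from the listings above are still in scope; the helper name is hypothetical:

```
import numpy as np

def outputs_match(ref_outputs, test_outputs, atol=1e-4):
    # compare each head's logits after flattening away the batch dimension
    return all(np.allclose(np.asarray(r).ravel(), np.asarray(t).ravel(), atol=atol)
               for r, t in zip(ref_outputs, test_outputs))

print('outputs match = {}'.format(outputs_match(output_pytorch, output_tf_pb)))
```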
# 3. Related Info

## 3.1 ONNX

Open Neural Network Exchange
https://github.com/onnx
https://onnx.ai/

The ONNX exporter is a **trace-based** exporter: it runs your model once and exports only the operators that were actually executed during that run, so anything the dummy input does not exercise never makes it into the graph (see the sketch below). [Limitations](https://pytorch.org/docs/stable/onnx.html#example-end-to-end-alexnet-from-pytorch-to-caffe2)

https://github.com/onnx/tensorflow-onnx
https://github.com/onnx/onnx-tensorflow
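A concrete illustration of that limitation, using a hypothetical toy module (not part of this repo): data-dependent Python control flow is evaluated once with the dummy input, and only the branch that happened to run is recorded in the exported graph.

```
import torch
import torch.nn as nn

class Branchy(nn.Module):
    def forward(self, x):
        # this condition is evaluated once at export time;
        # the exported ONNX graph contains only the branch that was taken
        if x.sum() > 0:
            return x * 2
        return x - 1

dummy_input = torch.randn(1, 3)
torch.onnx.export(Branchy(), dummy_input, 'branchy.onnx', verbose=True)
```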
## 3.2 Microsoft/MMdnn

I did not manage to get this network through the MMdnn converter.
https://github.com/Microsoft/MMdnn/blob/master/mmdnn/conversion/pytorch/README.md

# Reference

1. Open Neural Network Exchange https://github.com/onnx
2. [Exporting model from PyTorch to ONNX](https://github.com/onnx/tutorials/blob/master/tutorials/PytorchOnnxExport.ipynb)
3. [Importing ONNX models to TensorFlow (ONNX)](https://github.com/onnx/tutorials/blob/master/tutorials/OnnxTensorflowImport.ipynb)
4. [Serving TensorFlow with tornado](https://zhuanlan.zhihu.com/p/26136080)
5. [graph_def = tf.GraphDef(); graph_def.ParseFromString(f.read())](https://github.com/llSourcell/tensorflow_image_classifier/blob/master/src/label_image.py)
6. [A Tool Developer's Guide to TensorFlow Model Files](https://www.tensorflow.org/extend/tool_developers/)
7. [TensorFlow study notes: retraining Inception_v3](https://www.jianshu.com/p/613c3b08faea)


--------------------------------------------------------------------------------
/convert_pytorch2onnx2tfpb.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 22 15:18:10 2018

@author: pilgrim.bin@gmail.com
"""

import os
import random
import shutil

import numpy as np
from PIL import Image

# model
from diymodel import DIY_Model

# onnx - step 1
from torch.autograd import Variable
import torch.onnx

# onnx - step 2
import onnx
from onnx_tf.backend import prepare

#
import tensorflow as tf

mlmc_tree = {
    'length': {'c5_changku': 4, 'c2_5fenku': 1, 'c1_duanku': 0, 'c3_7fenku': 2, 'c4_9fenku': 3},
    'style': {'F5_Denglong': 4, 'F7_Kuotui': 6, 'LT_Lianti': 9, 'F3_Zhitong': 2, 'LT_Beidai': 8, 'F4_Kuansong': 3, 'F2_Xiaojiao': 1, 'F8_Laba': 7, 'F6_Halun': 5, 'F1_JinshenQianbi': 0}}
#INFO: = mlmcdataloader.label_to_idx = {'length': 0, 'style': 1}

class_numbers = []
for key in sorted(mlmc_tree.keys()):
    class_numbers.append(len(mlmc_tree[key]))

print('------- = {}'.format(class_numbers))  # class_numbers = [5, 10] (length, style)

def get_label_idx(label):
    idx = 0
    for key in mlmc_tree.keys():
        if label in mlmc_tree[key].keys():
            return idx
        idx += 1
    return None


# usage: is_allowed_extension(filename, IMG_EXTENSIONS)
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
extensions = IMG_EXTENSIONS
def is_allowed_extension(filename, extensions):
    filename_lower = filename.lower()
    return any([filename_lower.endswith(ext) for ext in extensions])


def get_filelist(path):
    filelist = []
    for root, dirs, filenames in os.walk(path):
        for fn in filenames:
            this_path = os.path.join(root, fn)
            filelist.append(this_path)
    return filelist

# usage: mkdir_if_not_exist([root, dir])
def mkdir_if_not_exist(path):
    if not os.path.exists(os.path.join(*path)):
        os.makedirs(os.path.join(*path))


def get_dict_key(dict, value):
    for k in dict.keys():
        if dict[k] == value:
            return k
    return None

def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)


if __name__ == '__main__':

    # pipeline: pytorch model --> onnx model --> tensorflow graph pb.

    # step 1, load the pytorch model and export onnx while running it.
    modelname = 'resnet18'
    weightfile = 'models/model_best_checkpoint_resnet18.pth.tar'
    modelhandle = DIY_Model(modelname, weightfile, class_numbers)
    model = modelhandle.model
    #model.eval() # useless
    dummy_input = Variable(torch.randn(1, 3, 224, 224)) # nchw
    onnx_filename = os.path.split(weightfile)[-1] + ".onnx"
    torch.onnx.export(model, dummy_input,
                      onnx_filename,
                      verbose=True)

    # step 2, create onnx_model using tensorflow as backend. check if it is right and export the graph.
    onnx_model = onnx.load(onnx_filename)
    tf_rep = prepare(onnx_model, strict=False)
    # install onnx-tensorflow from github, and use tf_rep = prepare(onnx_model, strict=False)
    # Reference https://github.com/onnx/onnx-tensorflow/issues/167
    #tf_rep = prepare(onnx_model) # without strict=False this leads to KeyError: 'pyfunc_0'
    image = Image.open('pants.jpg')
    # debug, here using the same input to check onnx and tf.
    output_pytorch, img_np = modelhandle.process(image)
    print('output_pytorch = {}'.format(output_pytorch))
    output_onnx_tf = tf_rep.run(img_np)
    print('output_onnx_tf = {}'.format(output_onnx_tf))
    # onnx --> tf.graph.pb
    tf_pb_path = onnx_filename + '_graph.pb'
    tf_rep.export_graph(tf_pb_path)

    # step 3, check if tf.pb is right.
    with tf.Graph().as_default():
        graph_def = tf.GraphDef()
        with open(tf_pb_path, "rb") as f:
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name="")
        with tf.Session() as sess:
            #init = tf.initialize_all_variables()
            init = tf.global_variables_initializer()
            #sess.run(init)

            # print all ops to check the input/output tensor names.
            # uncomment it if you do not know the io tensor names.
            '''
            print('-------------ops---------------------')
            op = sess.graph.get_operations()
            for m in op:
                print(m.values())
            print('-------------ops done.---------------------')
            '''

            input_x = sess.graph.get_tensor_by_name("0:0") # input
            outputs1 = sess.graph.get_tensor_by_name('add_1:0') # 5 classes
            outputs2 = sess.graph.get_tensor_by_name('add_3:0') # 10 classes
            output_tf_pb = sess.run([outputs1, outputs2], feed_dict={input_x: img_np})
            #output_tf_pb = sess.run([outputs1, outputs2], feed_dict={input_x: np.random.randn(1, 3, 224, 224)})
            print('output_tf_pb = {}'.format(output_tf_pb))
--------------------------------------------------------------------------------
/diymodel.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 13 12:23:29 2018

@author: pilgrim.bin@gmail.com
"""
import os
import sys

import torch
import torchvision.models as models
import torchvision.transforms as transforms

import mlmcmodel

modelnames = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))


def base_model(modelname):
    for mn in modelnames:
        if mn in modelname:
            return mn
    return None


'''
if saving the model using nn.DataParallel, which stores the model in module,
we should convert the keys "module.***" -> "***" when trying to
load it without DataParallel
'''
from collections import OrderedDict
def cvt_state_dict(state_dict):
    if not state_dict.keys()[0].startswith('module.'):
        return state_dict
    # create a new OrderedDict that does not contain 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] # remove `module.`
        new_state_dict[name] = v
    return new_state_dict


class DIY_Model():
    def __init__(self, modelname, weightfile, class_numbers, gpus=None):
        # input check
        bm = base_model(modelname)
        if bm is None:
            raise(RuntimeError("Error: invalid modelname = {}".format(modelname)))
        if not os.path.exists(weightfile):
            raise(RuntimeError("Error: weightfile does not exist = {}".format(weightfile)))

        # create model @ used for both inception / resnet18
        self.bm = bm
        self.model = models.__dict__[bm]()
        fc_features = self.model.fc.in_features
        self.model.fc = mlmcmodel.BuildMultiLabelModel(fc_features, class_numbers)

        '''
        if 'inception' in bm:
            # auxiliary fc
            aux_logits_fc_features = self.model.AuxLogits.fc.in_features
            self.model.AuxLogits.fc = nn.Linear(\
                aux_logits_fc_features, out_features=class_number, bias=True)
        '''
        if torch.cuda.is_available():
            self.model.cuda()
        else:
            self.model.cpu()

        # load model weight
        if torch.cuda.is_available():
            checkpoint = torch.load(weightfile)
        else:
            checkpoint = torch.load(weightfile, map_location='cpu')

        self.model.load_state_dict(
            cvt_state_dict(checkpoint['state_dict']))
        print("=> loaded checkpoint '{}'.".format(weightfile))

        # switch to evaluate mode
        self.model.eval()

        # preprocess transform

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        #normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.22, 0.22, 0.22])
        if 'inception' in bm:
            self.transform = transforms.Compose([
                transforms.Resize(320),
                transforms.CenterCrop(299),
                transforms.ToTensor(),
                normalize,
            ])
        else: # resnet18
            self.transform = transforms.Compose([
                transforms.Resize(224), # raw = 256
                transforms.CenterCrop(224),
                #transforms.Resize(224),
                transforms.ToTensor(),
                normalize,
            ])


    def process(self, img):
        input = self.transform(img)

        if 'inception' in self.bm:
            input = input.reshape([1, 3, 299, 299])
        else:
            input = input.reshape([1, 3, 224, 224])

        if torch.cuda.is_available():
            output = self.model(input.cuda())
        else:
            output = self.model(input)

        result_list = []
        for i in range(len(output)):
            result_list.append(output[i][0].cpu().detach().numpy())

        # used for deployment
        #return result_list

        # used for onnx
        return (result_list, input.cpu().detach().numpy())
--------------------------------------------------------------------------------
/mlmcmodel.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 30 20:15:49 2018

@author: pilgrim.bin@gmail.com
"""
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F

class MultiLabelModel(nn.Module):
    def __init__(self, basemodel_output, num_classes, basemodel=None):
        super(MultiLabelModel, self).__init__()
        self.basemodel = basemodel
        self.num_classes = num_classes

        # config
        self.cfg_normalize = False # unchecked other method, diff with embedding.
        self.cfg_has_embedding = True
        self.cfg_num_features = basemodel_output # is there a better number?
        self.cfg_dropout_ratio = 0. # 0. is better than 0.8 at attributes:pants problem

        # diy head
        for index, num_class in enumerate(num_classes):
            if self.cfg_has_embedding:
                setattr(self, "EmbeddingFeature_FCLayer_" + str(index), nn.Linear(basemodel_output, self.cfg_num_features))
                setattr(self, "EmbeddingFeature_FCLayer_BN_" + str(index), nn.BatchNorm1d(self.cfg_num_features))
                feat = getattr(self, "EmbeddingFeature_FCLayer_" + str(index))
                feat_bn = getattr(self, "EmbeddingFeature_FCLayer_BN_" + str(index))
                init.kaiming_normal(feat.weight, mode='fan_out')
                init.constant(feat.bias, 0)
                init.constant(feat_bn.weight, 1)
                init.constant(feat_bn.bias, 0)
            if self.cfg_dropout_ratio > 0:
                setattr(self, "Dropout_" + str(index), nn.Dropout(self.cfg_dropout_ratio))
            setattr(self, "FullyConnectedLayer_" + str(index), nn.Linear(self.cfg_num_features, num_class))
            classifier = getattr(self, "FullyConnectedLayer_" + str(index))
            init.normal(classifier.weight, std=0.001)
            init.constant(classifier.bias, 0)

    def forward(self, x):
        if self.basemodel is not None:
            x = self.basemodel.forward(x)
        outs = list()
        for index, num_class in enumerate(self.num_classes):
            if self.cfg_has_embedding:
                feat = getattr(self, "EmbeddingFeature_FCLayer_" + str(index))
                feat_bn = getattr(self, "EmbeddingFeature_FCLayer_BN_" + str(index))
                x = feat(x)
                x = feat_bn(x)
            if self.cfg_normalize:
                x = F.normalize(x) # getattr bug
            elif self.cfg_has_embedding:
                x = F.relu(x)
            if self.cfg_dropout_ratio > 0:
                dropout = getattr(self, "Dropout_" + str(index))
                x = dropout(x)
            classifier = getattr(self, "FullyConnectedLayer_" + str(index))
            out = classifier(x)
            outs.append(out)
        return outs


def LoadPretrainedModel(model, pretrained_state_dict):
    model_dict = model.state_dict()
    union_dict = {k : v for k, v in pretrained_state_dict.iteritems() if k in model_dict}
    model_dict.update(union_dict)
    return model_dict

def BuildMultiLabelModel(basemodel_output, num_classes, basemodel=None):
    return MultiLabelModel(basemodel_output, num_classes, basemodel=basemodel)

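# Usage sketch (an illustration, not part of the original file), mirroring what
# diymodel.py does: attach the multi-label head to a torchvision backbone by
# replacing its final fc layer. The [5, 10] class counts match this repo's
# length/style tasks.
#
#   import torchvision.models as models
#   model = models.resnet18()
#   model.fc = BuildMultiLabelModel(model.fc.in_features, [5, 10])
#   # forward() now returns a list of two logit tensors: 5-way (length) and 10-way (style).
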
'''----------------------------------------------------------------------------------------------------'''

# original version of https://github.com/pangwong/pytorch-multi-label-classifier.git
'''
import torch.nn as nn

class MultiLabelModel(nn.Module):
    def __init__(self, basemodel, basemodel_output, num_classes):
        super(MultiLabelModel, self).__init__()
        self.basemodel = basemodel
        self.num_classes = num_classes
        for index, num_class in enumerate(num_classes):
            setattr(self, "FullyConnectedLayer_" + str(index), nn.Linear(basemodel_output, num_class))

    def forward(self, x):
        x = self.basemodel.forward(x)
        outs = list()
        dir(self)
        for index, num_class in enumerate(self.num_classes):
            fun = eval("self.FullyConnectedLayer_" + str(index))
            out = fun(x)
            outs.append(out)
        return outs

def LoadPretrainedModel(model, pretrained_state_dict):
    model_dict = model.state_dict()
    union_dict = {k : v for k, v in pretrained_state_dict.iteritems() if k in model_dict}
    model_dict.update(union_dict)
    return model_dict

def BuildMultiLabelModel(basemodel, basemodel_output, num_classes):
    return MultiLabelModel(basemodel, basemodel_output, num_classes)

'''


--------------------------------------------------------------------------------
/pants.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cinastanbean/pytorch-onnx-tensorflow-pb/15f344b40bb59cffc386447df6962fd1c00ff35e/pants.jpg
--------------------------------------------------------------------------------