├── Converting A PyTorch Model to Tensorflow pb using ONNX.html
├── Converting A PyTorch Model to Tensorflow pb using ONNX.md
├── README.md
├── convert_pytorch2onnx2tfpb.py
├── diymodel.py
├── mlmcmodel.py
└── pants.jpg

/Converting A PyTorch Model to Tensorflow pb using ONNX.html:
--------------------------------------------------------------------------------
pilgrim.bin@gmail.com
Version Info

pytorch 0.4.0 py27_cuda0.0_cudnn0.0_1 pytorch
torchvision 0.2.1 py27_1 pytorch
tensorflow 1.8.0 <pip>
onnx 1.2.2 <pip>
onnx-tf 1.1.2 <pip>
Note: install ONNX from the GitHub master branch:
pip install -U git+https://github.com/onnx/onnx.git@master
Pipeline: PyTorch model --> ONNX model --> TensorFlow graph pb.
import os

import onnx
import tensorflow as tf
import torch
from onnx_tf.backend import prepare
from PIL import Image
from torch.autograd import Variable

from diymodel import DIY_Model  # assumed: DIY_Model lives in this repo's diymodel.py

# step 1, load the PyTorch model and export ONNX while running it.
# class_numbers (the per-task class counts) is assumed to be defined earlier in the script.
modelname = 'resnet18'
weightfile = 'models/model_best_checkpoint_resnet18.pth.tar'
modelhandle = DIY_Model(modelname, weightfile, class_numbers)
model = modelhandle.model
model.eval()  # inference mode for dropout/batchnorm before tracing (made no difference here)
dummy_input = Variable(torch.randn(1, 3, 224, 224))  # NCHW
onnx_filename = os.path.split(weightfile)[-1] + ".onnx"
torch.onnx.export(model, dummy_input,
                  onnx_filename,
                  verbose=True)

# step 2, create a backend rep of the ONNX model with TensorFlow, check it, and export the graph.
onnx_model = onnx.load(onnx_filename)
# Install onnx-tensorflow from GitHub and call prepare() with strict=False.
# Reference https://github.com/onnx/onnx-tensorflow/issues/167
# Without strict=False, prepare(onnx_model) raises KeyError: 'pyfunc_0'.
tf_rep = prepare(onnx_model, strict=False)
image = Image.open('pants.jpg')
# debug: feed the same input to PyTorch and onnx-tf so their outputs can be compared.
output_pytorch, img_np = modelhandle.process(image)
print('output_pytorch = {}'.format(output_pytorch))
output_onnx_tf = tf_rep.run(img_np)
print('output_onnx_tf = {}'.format(output_onnx_tf))
# onnx --> tf graph .pb
tf_pb_path = onnx_filename + '_graph.pb'
tf_rep.export_graph(tf_pb_path)

# step 3, check that the TF .pb is correct.
with tf.Graph().as_default():
    graph_def = tf.GraphDef()
    with open(tf_pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name="")
    with tf.Session() as sess:
        # The imported GraphDef registers no variables, so running the initializer is a no-op.
        #init = tf.initialize_all_variables()
        init = tf.global_variables_initializer()
        #sess.run(init)

        # Print all ops to find the input/output tensor names.
        # Uncomment this if you don't know the I/O tensor names.
        '''
        print('-------------ops---------------------')
        op = sess.graph.get_operations()
        for m in op:
            print(m.values())
        print('-------------ops done.---------------------')
        '''

        input_x = sess.graph.get_tensor_by_name("0:0")       # input
        outputs1 = sess.graph.get_tensor_by_name('add_1:0')  # head 1: 5 classes
        outputs2 = sess.graph.get_tensor_by_name('add_3:0')  # head 2: 10 classes
        output_tf_pb = sess.run([outputs1, outputs2], feed_dict={input_x: img_np})
        #output_tf_pb = sess.run([outputs1, outputs2], feed_dict={input_x: np.random.randn(1, 3, 224, 224)})
        print('output_tf_pb = {}'.format(output_tf_pb))
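If you don't know the I/O tensor names, a rough heuristic over the parsed GraphDef can narrow them down instead of dumping every op. The sketch below is a hypothetical helper (`guess_io_nodes` is not a TensorFlow API). It relies only on each node exposing `.name`, `.op` and `.input`, which `tf.GraphDef` nodes do, and it is demonstrated on small namedtuple stand-ins so it runs without TensorFlow. It treats `Placeholder` nodes as inputs and any node whose output is never consumed (ignoring `Const`/`NoOp`) as an output candidate. Appending `:0` gives tensor names in the form `0:0` or `add_1:0`, as used above.

```python
from collections import namedtuple


def guess_io_nodes(graph_def):
    """Heuristically guess input/output node names of a frozen GraphDef.

    Inputs:  nodes whose op is 'Placeholder'.
    Outputs: nodes whose output is never consumed by another node
             (Const/NoOp nodes are skipped; they are rarely real outputs).
    """
    consumed = set()
    for node in graph_def.node:
        for inp in node.input:
            if inp.startswith('^'):
                continue                      # control dependency, not a data edge
            consumed.add(inp.split(':')[0])   # 'add_1:0' -> 'add_1'
    inputs = [n.name for n in graph_def.node if n.op == 'Placeholder']
    outputs = [n.name for n in graph_def.node
               if n.name not in consumed and n.op not in ('Const', 'NoOp')]
    return inputs, outputs


# Stand-ins with the same attribute names as tf.GraphDef / NodeDef.
Node = namedtuple('Node', ['name', 'op', 'input'])
Graph = namedtuple('Graph', ['node'])

g = Graph(node=[
    Node('0', 'Placeholder', []),
    Node('w', 'Const', []),
    Node('MatMul', 'MatMul', ['0', 'w']),
    Node('add_1', 'Add', ['MatMul', 'w']),
])
print(guess_io_nodes(g))  # (['0'], ['add_1'])
```

With a real model you would pass the `graph_def` parsed in step 3. The result is only a shortlist: dangling helper nodes can also show up as output candidates, so confirm each with `sess.graph.get_tensor_by_name(name + ':0')`.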
Make sure the outputs match:
output_pytorch = [array([ 2.5359073 , -1.4261041 , -5.2394 , -0.62402934, 4.7426634 ], dtype=float32), array([ 7.6249304, 5.1203837, 1.8118637, 1.5143847, -4.9409146, 1.1695148, -6.2375665, -1.6033885, -1.4286405, -2.964429 ], dtype=float32)]

output_onnx_tf = Outputs(_0=array([[ 2.5359051, -1.4261056, -5.239397 , -0.6240269, 4.7426634]], dtype=float32), _1=array([[ 7.6249285, 5.12038 , 1.811865 , 1.5143874, -4.940915 , 1.1695154, -6.237564 , -1.6033876, -1.4286422, -2.964428 ]], dtype=float32))

output_tf_pb = [array([[ 2.5359051, -1.4261056, -5.239397 , -0.6240269, 4.7426634]], dtype=float32), array([[ 7.6249285, 5.12038 , 1.811865 , 1.5143874, -4.940915 , 1.1695154, -6.237564 , -1.6033876, -1.4286422, -2.964428 ]], dtype=float32)]
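Eyeballing the printed arrays works, but the check can also be automated. Below is a minimal NumPy sketch, assuming per-head outputs as in the run above. The PyTorch outputs are 1-D (shapes `(5,)` and `(10,)`) while the onnx-tf/TF outputs carry a batch dimension (`(1, 5)`, `(1, 10)`), so both are flattened before comparing. `outputs_match` is a hypothetical helper, and `rtol=atol=1e-5` is an assumed tolerance chosen to absorb the roughly 1e-6 float32 differences visible above.

```python
import numpy as np


def outputs_match(ref, other, rtol=1e-5, atol=1e-5):
    """Compare two lists of per-head outputs, ignoring a leading batch dim of 1."""
    if len(ref) != len(other):
        return False
    for a, b in zip(ref, other):
        a = np.asarray(a, dtype=np.float32).reshape(-1)
        b = np.asarray(b, dtype=np.float32).reshape(-1)
        if a.shape != b.shape or not np.allclose(a, b, rtol=rtol, atol=atol):
            return False
    return True


# Values copied from the outputs printed above.
pytorch_out = [
    [2.5359073, -1.4261041, -5.2394, -0.62402934, 4.7426634],
    [7.6249304, 5.1203837, 1.8118637, 1.5143847, -4.9409146,
     1.1695148, -6.2375665, -1.6033885, -1.4286405, -2.964429],
]
tf_pb_out = [
    [[2.5359051, -1.4261056, -5.239397, -0.6240269, 4.7426634]],
    [[7.6249285, 5.12038, 1.811865, 1.5143874, -4.940915,
      1.1695154, -6.237564, -1.6033876, -1.4286422, -2.964428]],
]
print(outputs_match(pytorch_out, tf_pb_out))  # True
```

Comparing with a tolerance rather than exact equality is deliberate: the two frameworks run different kernels, so bit-identical float32 results are not expected.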
Standalone TF verification program
import numpy as np
import tensorflow as tf
from PIL import Image


def get_img_np_nchw(filename):
    try:
        image = Image.open(filename).convert('RGB').resize((224, 224))
        miu = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        #miu = np.array([0.5, 0.5, 0.5])
        #std = np.array([0.22, 0.22, 0.22])
        # img_np.shape = (224, 224, 3)
        img_np = np.array(image, dtype=float) / 255.
        r = (img_np[:,:,0] - miu[0]) / std[0]
        g = (img_np[:,:,1] - miu[1]) / std[1]
        b = (img_np[:,:,2] - miu[2]) / std[2]
        img_np_t = np.array([r, g, b])                  # (3, 224, 224)
        img_np_nchw = np.expand_dims(img_np_t, axis=0)  # (1, 3, 224, 224)
        return img_np_nchw.astype(np.float32)           # match the float32 input tensor
    except Exception as e:
        print("RuntimeError: get_img_np_nchw({}): {}".format(filename, e))
        return None


if __name__ == '__main__':

    tf_pb_path = 'model_best_checkpoint_resnet18.pth.tar.onnx_graph.pb'

    filename = 'pants.jpg'
    img_np_nchw = get_img_np_nchw(filename)
    if img_np_nchw is None:
        raise SystemExit(1)

    # step 3, check that the TF .pb is correct.
    with tf.Graph().as_default():
        graph_def = tf.GraphDef()
        with open(tf_pb_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name="")
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            #init = tf.initialize_all_variables()
            sess.run(init)

            # Print all ops to find the input/output tensor names.
            # Uncomment this if you don't know the I/O tensor names.
            '''
            print('-------------ops---------------------')
            op = sess.graph.get_operations()
            for m in op:
                print(m.values())
            print('-------------ops done.---------------------')
            '''

            input_x = sess.graph.get_tensor_by_name("0:0")       # input
            outputs1 = sess.graph.get_tensor_by_name('add_1:0')  # head 1: 5 classes
            outputs2 = sess.graph.get_tensor_by_name('add_3:0')  # head 2: 10 classes
            output_tf_pb = sess.run([outputs1, outputs2], feed_dict={input_x: img_np_nchw})
            print('output_tf_pb = {}'.format(output_tf_pb))
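The per-channel normalization in `get_img_np_nchw` can also be written with NumPy broadcasting and a single transpose. In the sketch below the function names are hypothetical. It takes an already-resized `(H, W, 3)` uint8 RGB array instead of a filename, so the PIL decoding and resizing step is left out, and it checks itself against a channel-by-channel version that mirrors the original function.

```python
import numpy as np

# ImageNet mean/std, the same values used in get_img_np_nchw above.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def hwc_uint8_to_nchw(img_hwc):
    """Normalize an (H, W, 3) uint8 RGB array and return a (1, 3, H, W) float32 batch."""
    x = np.asarray(img_hwc, dtype=np.float32) / 255.0  # scale to [0, 1]
    x = (x - MEAN) / STD                               # broadcasts over the channel axis
    x = np.transpose(x, (2, 0, 1))                     # HWC -> CHW
    return np.expand_dims(x, axis=0)                   # add the batch dimension


def per_channel_reference(img_hwc):
    """Same computation written channel by channel, mirroring get_img_np_nchw."""
    img = np.asarray(img_hwc, dtype=float) / 255.
    chans = [(img[:, :, c] - MEAN[c]) / STD[c] for c in range(3)]
    return np.expand_dims(np.array(chans), axis=0)


rng = np.random.RandomState(0)
fake_image = rng.randint(0, 256, size=(224, 224, 3)).astype(np.uint8)
fast = hwc_uint8_to_nchw(fake_image)
print(fast.shape, fast.dtype)  # (1, 3, 224, 224) float32
print(np.allclose(fast, per_channel_reference(fake_image), atol=1e-5))  # True
```

The transpose step is the part that matters for the conversion: PIL/NumPy images are HWC, while the exported graph expects NCHW, the layout of `dummy_input` in step 1.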
Open Neural Network Exchange
https://github.com/onnx
https://onnx.ai/
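The ONNX exporter is a trace-based exporter, which means that it operates by executing your model once and exporting the operators that were actually run during this run.
Limitations: if the model is dynamic, that is, its behavior changes depending on the input data (for example an if statement on a tensor value, or a loop whose number of iterations depends on the input), the exported graph only contains the path taken for the dummy input.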
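To make "trace-based" concrete, here is a toy, pure-Python illustration. It is not how PyTorch's tracer is implemented, only the same idea in miniature, and every name in it is hypothetical. Running `model` once on a positive example records only the multiply, so the "exported" replay keeps multiplying even for inputs that would take the other branch.

```python
# Toy model of trace-based export (a pure-Python illustration, NOT PyTorch's tracer).


def model(x):
    # Data-dependent control flow: the branch depends on the input's value.
    if x > 0:
        return x * 2
    return x - 1


class TracedValue:
    """Wraps a number and records every arithmetic op applied to it."""

    def __init__(self, value, tape):
        self.value = value
        self.tape = tape

    def __gt__(self, other):
        # Comparisons return a plain bool: the branch decision is NOT recorded.
        return self.value > other

    def __mul__(self, k):
        self.tape.append(('mul', k))
        return TracedValue(self.value * k, self.tape)

    def __sub__(self, k):
        self.tape.append(('sub', k))
        return TracedValue(self.value - k, self.tape)


def trace(fn, example):
    """Run fn once on `example`; return the recorded ops and a replay function."""
    tape = []
    fn(TracedValue(example, tape))

    def replay(x):
        for op, k in tape:
            x = x * k if op == 'mul' else x - k
        return x

    return tape, replay


tape, exported = trace(model, 3.0)  # trace with a positive dummy input
print(tape)            # [('mul', 2)]
print(model(-5.0))     # -6.0  (the eager model takes the other branch)
print(exported(-5.0))  # -10.0 (the traced graph still multiplies by 2)
```

This is why a fixed-shape `dummy_input` is enough for a plain ResNet-style classifier, whose forward pass does not branch on data, but not for models with input-dependent control flow.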
https://github.com/onnx/tensorflow-onnx
https://github.com/onnx/onnx-tensorflow
MMdnn (could not get it to work with the current network):
https://github.com/Microsoft/MMdnn/blob/master/mmdnn/conversion/pytorch/README.md
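**It needs saying up front, so that those who come later don't fall into the same trap:**

**ONNX was originally launched by Facebook, together with Microsoft and AWS, to counter TensorFlow, so ONNX-TF was doomed to be an illicit affair: ONNX and TF are sneaking around together, and neither platform will vouch for it. PyTorch and TensorFlow are each evolving independently, and the fight between dynamic-graph and static-graph optimization is not going to end. If you are trying to convert models, consider: 1. making your serving/deployment platform support PyTorch; 2. moving your training platform to TF; 3. treating the conversion as a one-off job, and once it is done, not fiddling with it any further.**

The model used in this demo comes from: https://github.com/cinastanbean/Pytorch-Multi-Task-Multi-class-Classification

[TOC]

# 1. Pre-installation

**Version Info**

```
pytorch 0.4.0 py27_cuda0.0_cudnn0.0_1 pytorch
torchvision 0.2.1 py27_1 pytorch
tensorflow 1.8.0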