├── .dockerignore ├── images ├── p1.jpg ├── p2.jpg ├── p3.jpg ├── golf.jpg ├── hand1.jpg ├── hand2.jpg ├── hand1_small.jpg └── valid_person1.jpg ├── etcs ├── inference_result2.png ├── openpose_macbook_cmu.gif ├── openpose_tx2_mobilenet3.gif ├── openpose_macbook_mobilenet3.gif ├── openpose_macbook13_mobilenet2.gif ├── feature.md └── training.md ├── models ├── numpy │ └── download.sh └── pretrained │ ├── mobilenet_v1_1.0_224_2017_06_14 │ └── download.sh │ ├── mobilenet_v1_0.50_224_2017_06_14 │ └── download.sh │ └── mobilenet_v1_0.75_224_2017_06_14 │ └── download.sh ├── requirements.txt ├── .gitattributes ├── convert ├── resources │ └── batch_normalization.png ├── .gitignore ├── inference_by_keras.py ├── README.md └── tensorToKeras.py ├── pose_datamaster.py ├── Dockerfile ├── .gitignore ├── pose_dataworker.py ├── networks.py ├── common_test.py ├── pose_stats.py ├── datum_pb2.py ├── realtime_webcam.py ├── inference.py ├── network_mobilenet.py ├── README.md ├── network_cmu.py ├── network_dsconv.py ├── common.py ├── pose_augment.py ├── LICENSE ├── network_base.py ├── train.py └── pose_dataset.py /.dockerignore: -------------------------------------------------------------------------------- 1 | ./models 2 | ./models/* 3 | models 4 | ./tests 5 | ./tests/* 6 | tests -------------------------------------------------------------------------------- /images/p1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocom-tpo/tf-openpose/HEAD/images/p1.jpg -------------------------------------------------------------------------------- /images/p2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocom-tpo/tf-openpose/HEAD/images/p2.jpg -------------------------------------------------------------------------------- /images/p3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocom-tpo/tf-openpose/HEAD/images/p3.jpg -------------------------------------------------------------------------------- /images/golf.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocom-tpo/tf-openpose/HEAD/images/golf.jpg -------------------------------------------------------------------------------- /images/hand1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocom-tpo/tf-openpose/HEAD/images/hand1.jpg -------------------------------------------------------------------------------- /images/hand2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocom-tpo/tf-openpose/HEAD/images/hand2.jpg -------------------------------------------------------------------------------- /images/hand1_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocom-tpo/tf-openpose/HEAD/images/hand1_small.jpg -------------------------------------------------------------------------------- /images/valid_person1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocom-tpo/tf-openpose/HEAD/images/valid_person1.jpg -------------------------------------------------------------------------------- /etcs/inference_result2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocom-tpo/tf-openpose/HEAD/etcs/inference_result2.png -------------------------------------------------------------------------------- /models/numpy/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget https://www.dropbox.com/s/xh5s7sb7remu8tx/openpose_coco.npy -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argparse 2 | lmdb 3 | matplotlib 4 | scipy 5 | git+https://github.com/ppwwyyxx/tensorpack.git -------------------------------------------------------------------------------- /etcs/openpose_macbook_cmu.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocom-tpo/tf-openpose/HEAD/etcs/openpose_macbook_cmu.gif -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | models/numpy/*.npy filter=lfs diff=lfs merge=lfs -text 2 | *.ckpt* filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /etcs/openpose_tx2_mobilenet3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocom-tpo/tf-openpose/HEAD/etcs/openpose_tx2_mobilenet3.gif -------------------------------------------------------------------------------- /etcs/openpose_macbook_mobilenet3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocom-tpo/tf-openpose/HEAD/etcs/openpose_macbook_mobilenet3.gif -------------------------------------------------------------------------------- /etcs/openpose_macbook13_mobilenet2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocom-tpo/tf-openpose/HEAD/etcs/openpose_macbook13_mobilenet2.gif -------------------------------------------------------------------------------- /convert/resources/batch_normalization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocom-tpo/tf-openpose/HEAD/convert/resources/batch_normalization.png -------------------------------------------------------------------------------- /convert/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/* 2 | !.vscode/settings.json 3 | !.vscode/tasks.json 4 | !.vscode/launch.json 5 | !.vscode/extensions.json 6 | 7 | output/* 8 | models/* -------------------------------------------------------------------------------- /models/pretrained/mobilenet_v1_1.0_224_2017_06_14/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://gpu-twg.kakaocdn.net/braincloud/models/mobilenet/mobilenet_v1_1.0_224.ckpt.data-00000-of-00001 4 | wget http://gpu-twg.kakaocdn.net/braincloud/models/mobilenet/mobilenet_v1_1.0_224.ckpt.index 5 | wget http://gpu-twg.kakaocdn.net/braincloud/models/mobilenet/mobilenet_v1_1.0_224.ckpt.meta -------------------------------------------------------------------------------- /models/pretrained/mobilenet_v1_0.50_224_2017_06_14/download.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://gpu-twg.kakaocdn.net/braincloud/models/mobilenet/mobilenet_v1_0.50_224.ckpt.data-00000-of-00001 4 | wget http://gpu-twg.kakaocdn.net/braincloud/models/mobilenet/mobilenet_v1_0.50_224.ckpt.index 5 | wget http://gpu-twg.kakaocdn.net/braincloud/models/mobilenet/mobilenet_v1_0.50_224.ckpt.meta -------------------------------------------------------------------------------- /models/pretrained/mobilenet_v1_0.75_224_2017_06_14/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://gpu-twg.kakaocdn.net/braincloud/models/mobilenet/mobilenet_v1_0.75_224.ckpt.data-00000-of-00001 4 | wget http://gpu-twg.kakaocdn.net/braincloud/models/mobilenet/mobilenet_v1_0.75_224.ckpt.index 5 | wget http://gpu-twg.kakaocdn.net/braincloud/models/mobilenet/mobilenet_v1_0.75_224.ckpt.meta -------------------------------------------------------------------------------- /etcs/feature.md: -------------------------------------------------------------------------------- 1 | ## Features 2 | 3 | - [x] CMU's original network architecture and weights. 4 | 5 | - [x] Transfer Original Weights to Tensorflow 6 | 7 | - [x] Training Code with multi-gpus 8 | 9 | - [ ] Evaluate with test dataset 10 | 11 | - [ ] Inference 12 | 13 | - [x] Post processing from network output. 14 | 15 | - [ ] Faster post-processing 16 | 17 | - [ ] Multi-Scale Inference 18 | 19 | - [x] Faster network variants using custom mobilenet architecture. 20 | 21 | - [x] Depthwise Separable Convolution Version 22 | 23 | - [x] Mobilenet Version 24 | 25 | - [ ] Demos 26 | 27 | - [x] Realtime Webcam Demo 28 | 29 | - [x] Image File Demo 30 | 31 | - [ ] Video File Demo 32 | 33 | - [ ] ROS Support. -------------------------------------------------------------------------------- /pose_datamaster.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import logging 5 | from tensorpack.dataflow.remote import RemoteDataZMQ 6 | 7 | from pose_dataset import CocoPoseLMDB 8 | 9 | logging.basicConfig(level=logging.DEBUG, format='[lmdb_dataset] %(asctime)s %(levelname)s %(message)s') 10 | 11 | if __name__ == '__main__': 12 | """ 13 | Speed Test for Getting Input batches from other nodes 14 | """ 15 | parser = argparse.ArgumentParser(description='Worker for preparing input batches.') 16 | parser.add_argument('--listen', type=str, default='tcp://0.0.0.0:1027') 17 | parser.add_argument('--show', type=bool, default=False) 18 | args = parser.parse_args() 19 | 20 | df = RemoteDataZMQ(args.listen) 21 | 22 | logging.info('tcp queue start') 23 | df.reset_state() 24 | t = time.time() 25 | for i, dp in enumerate(df.get_data()): 26 | if i == 100: 27 | break 28 | logging.info('Input batch %d received.' % i) 29 | if i == 0: 30 | for d in dp: 31 | logging.info('%d dp shape={}'.format(d.shape)) 32 | 33 | if args.show: 34 | CocoPoseLMDB.display_image(dp[0][0], dp[1][0], dp[2][0]) 35 | 36 | logging.info('Speed Test Done for 100 Batches in %f seconds.' 
% (time.time() - t)) 37 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:16.04 2 | 3 | ENV http_proxy=http://10.41.249.28:8080 https_proxy=http://10.41.249.28:8080 4 | 5 | RUN apt-get update -yq && apt-get install -yq build-essential cmake git pkg-config && \ 6 | apt-get install -yq libjpeg8-dev libtiff5-dev libjasper-dev libpng12-dev && \ 7 | apt-get install -yq libavcodec-dev libavformat-dev libswscale-dev libv4l-dev && \ 8 | apt-get install -yq libgtk2.0-dev && \ 9 | apt-get install -yq libatlas-base-dev gfortran && \ 10 | apt-get install -yq python3 python3-dev python3-pip python3-setuptools python3-tk git && \ 11 | pip3 install numpy && \ 12 | cd ~ && git clone https://github.com/Itseez/opencv.git && \ 13 | cd opencv && mkdir build && cd build && \ 14 | cmake -D CMAKE_BUILD_TYPE=RELEASE \ 15 | -D CMAKE_INSTALL_PREFIX=/usr/local \ 16 | -D INSTALL_PYTHON_EXAMPLES=ON \ 17 | -D BUILD_opencv_python3=yes -D PYTHON_EXECUTABLE=/usr/bin/python3 .. && \ 18 | make -j8 && make install && rm -rf /root/opencv/ && \ 19 | mkdir -p /root/tf-openpose && \ 20 | rm -rf /tmp/*.tar.gz && \ 21 | apt-get clean && rm -rf /tmp/* /var/tmp* /var/lib/apt/lists/* && \ 22 | rm -f /etc/ssh/ssh_host_* && rm -rf /usr/share/man?? /usr/share/man/??_* 23 | 24 | COPY . /root/tf-openpose/ 25 | WORKDIR /root/tf-openpose/ 26 | 27 | RUN cd /root/tf-openpose/ && pip3 install -U setuptools && \ 28 | pip3 install tensorflow && pip3 install -r requirements.txt 29 | 30 | ENTRYPOINT ["python3", "pose_dataworker.py"] 31 | 32 | ENV http_proxy= https_proxy= 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # models 104 | *ckpt* 105 | *.npy 106 | timeline*.json -------------------------------------------------------------------------------- /pose_dataworker.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from shutil import copyfile 4 | 5 | import logging 6 | from tensorpack.dataflow.remote import send_dataflow_zmq 7 | 8 | from pose_augment import set_network_input_wh 9 | from pose_dataset import get_dataflow_batch 10 | 11 | 12 | logging.basicConfig(level=logging.DEBUG, format='[lmdb_dataset] %(asctime)s %(levelname)s %(message)s') 13 | 14 | if __name__ == '__main__': 15 | """ 16 | OpenPose Data Preparation might be a bottleneck for training. 17 | You can run multiple workers to generate input batches in multi-nodes to make training process faster. 
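For example (see etcs/training.md), run `python3 pose_dataworker.py --master=tcp://host:port` on each data-preparation node, then start training with `python3 train.py --remote-data=tcp://0.0.0.0:port`; pose_datamaster.py provides a speed test that consumes batches from workers in the same way.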
18 | """ 19 | parser = argparse.ArgumentParser(description='Worker for preparing input batches.') 20 | parser.add_argument('--datapath', type=str, default='/data/public/rw/coco-pose-estimation-lmdb/') 21 | parser.add_argument('--batchsize', type=int, default=64) 22 | parser.add_argument('--train', type=bool, default=True) 23 | parser.add_argument('--copydb', type=bool, default=False) 24 | parser.add_argument('--master', type=str, default='tcp://csi-cluster-gpu20.dakao.io:1027') 25 | parser.add_argument('--input-width', type=int, default=368) 26 | parser.add_argument('--input-height', type=int, default=368) 27 | args = parser.parse_args() 28 | 29 | set_network_input_wh(args.input_width, args.input_height) 30 | 31 | if args.copydb: 32 | logging.info('db copy to local+') 33 | try: 34 | os.stat('/tmp/openposedb/') 35 | except: 36 | os.mkdir('/tmp/openposedb/') 37 | copyfile(args.datapath + 'data.mdb', '/tmp/openposedb/data.mdb') 38 | copyfile(args.datapath + 'lock.mdb', '/tmp/openposedb/lock.mdb') 39 | logging.info('db copy to local-') 40 | 41 | df = get_dataflow_batch('/tmp/openposedb/', args.train, args.batchsize) 42 | else: 43 | df = get_dataflow_batch(args.datapath, args.train, args.batchsize) 44 | 45 | send_dataflow_zmq(df, args.master, hwm=10) 46 | -------------------------------------------------------------------------------- /convert/inference_by_keras.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import matplotlib as mpl 3 | mpl.use('Agg') 4 | import matplotlib.pyplot as plt 5 | from keras import backend as K 6 | 7 | from PIL import Image 8 | import numpy as np 9 | import sys 10 | import os 11 | import cv2 12 | sys.path.append('../') 13 | 14 | from common import estimate_pose, CocoPairsRender, read_imgfile, CocoColors, draw_humans 15 | from pose_dataset import CocoPoseLMDB 16 | from keras.models import Model 17 | from keras.models import load_model 18 | 19 | 20 | test_img_path = "../images/pose.jpg" 21 | input_height = 368 22 | input_width = 368 23 | 24 | im = read_imgfile(test_img_path, 368, 368) 25 | s = im.shape 26 | _im = im.reshape(1, s[0], s[1], s[2]) 27 | 28 | # if os.path.exists("output/predict.hd5"): 29 | # from keras.applications.mobilenet import DepthwiseConv2D 30 | # from keras.utils.generic_utils import CustomObjectScope 31 | # with CustomObjectScope({'DepthwiseConv2D': DepthwiseConv2D}): 32 | # net = load_model('output/predict.hd5') 33 | # else: 34 | import tensorflow as tf 35 | from tensorToKeras import get_model 36 | config = tf.ConfigProto() 37 | with tf.Session(config=config) as sess: 38 | net = get_model(sess, input_height, input_width) 39 | out = net.predict(_im) 40 | 41 | 42 | heatMat = out[:, :, :, :19] 43 | pafMat = out[:, :, :, 19:] 44 | 45 | heatMat, pafMat = heatMat[0], pafMat[0] 46 | 47 | #--------------- 48 | # Draw Image 49 | #--------------- 50 | 51 | humans = estimate_pose(heatMat, pafMat) 52 | 53 | # im = im[:, :, ::-1] 54 | process_img = CocoPoseLMDB.display_image(im, heatMat, pafMat, as_numpy=True) 55 | 56 | # display 57 | image = cv2.imread(test_img_path) 58 | image_h, image_w = image.shape[:2] 59 | image = draw_humans(image, humans) 60 | 61 | scale = 480.0 / image_h 62 | newh, neww = 480, int(scale * image_w + 0.5) 63 | 64 | image = cv2.resize(image, (neww, newh), interpolation=cv2.INTER_AREA) 65 | 66 | 67 | convas = np.zeros([480, 640 + neww, 3], dtype=np.uint8) 68 | convas[:, :640] = process_img 69 | convas[:, 640:] = image 70 | 71 | pilImg = Image.fromarray(np.uint8(convas)) 72 | 
pilImg.save("result.png") 73 | 74 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | 5 | from network_cmu import CmuNetwork 6 | from network_mobilenet import MobilenetNetwork 7 | 8 | 9 | def _get_base_path(): 10 | if not os.environ.get('OPENPOSE_MODEL', ''): 11 | return './models' 12 | return os.environ.get('OPENPOSE_MODEL') 13 | 14 | 15 | def get_network(type, placeholder_input, sess_for_load=None, trainable=False): 16 | if type == 'mobilenet': 17 | net = MobilenetNetwork({'image': placeholder_input}, trainable=trainable, conv_width=0.75, conv_width2=0.50) 18 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt' 19 | last_layer = 'MConv_Stage6_L{aux}_5' 20 | elif type == 'mobilenet_accurate': 21 | net = MobilenetNetwork({'image': placeholder_input}, trainable=trainable, conv_width=1.00) 22 | pretrain_path = 'pretrained/mobilenet_v1_1.0_224_2017_06_14/mobilenet_v1_1.0_224.ckpt' 23 | last_layer = 'MConv_Stage6_L{aux}_5' 24 | elif type == 'mobilenet_fast': 25 | net = MobilenetNetwork({'image': placeholder_input}, trainable=trainable, conv_width=0.50) 26 | pretrain_path = 'pretrained/mobilenet_v1_0.50_224_2017_06_14/mobilenet_v1_0.50_224.ckpt' 27 | last_layer = 'MConv_Stage6_L{aux}_5' 28 | elif type == 'cmu': 29 | net = CmuNetwork({'image': placeholder_input}) 30 | pretrain_path = 'numpy/openpose_coco.npy' 31 | last_layer = 'Mconv7_stage6_L{aux}' 32 | else: 33 | raise Exception('Invalid Mode.') 34 | 35 | if sess_for_load is not None: 36 | if type == 'cmu': 37 | net.load('./models/numpy/openpose_coco.npy', sess_for_load) 38 | else: 39 | s = '%dx%d' % (placeholder_input.shape[2], placeholder_input.shape[1]) 40 | ckpts = { 41 | 'mobilenet': 'trained/mobilenet_%s/model-release' % s, 42 | 'mobilenet_fast': 'trained/mobilenet_fast/model-163000', 43 | 'mobilenet_accurate': 'trained/mobilenet_accurate/model-170000' 44 | } 45 | loader = tf.train.Saver() 46 | loader.restore(sess_for_load, os.path.join(_get_base_path(), ckpts[type])) 47 | 48 | return net, os.path.join(_get_base_path(), pretrain_path), last_layer 49 | -------------------------------------------------------------------------------- /common_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import logging 4 | import numpy as np 5 | import cv2 6 | import time 7 | 8 | import common 9 | from pose_dataset import CocoPoseLMDB 10 | 11 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') 12 | 13 | 14 | class TestStringMethods(unittest.TestCase): 15 | def _show(self, path, inpmat, heatmat, pafmat, humans): 16 | image = cv2.imread(path) 17 | 18 | # CocoPoseLMDB.display_image(inpmat, heatmat, pafmat) 19 | 20 | image_h, image_w = image.shape[:2] 21 | heat_h, heat_w = heatmat.shape[:2] 22 | for _, human in humans.items(): 23 | for part in human: 24 | if part['partIdx'] not in common.CocoPairsRender: 25 | continue 26 | center1 = (int((part['c1'][0] + 0.5) * image_w / heat_w), int((part['c1'][1] + 0.5) * image_h / heat_h)) 27 | center2 = (int((part['c2'][0] + 0.5) * image_w / heat_w), int((part['c2'][1] + 0.5) * image_h / heat_h)) 28 | cv2.circle(image, center1, 2, (255, 0, 0), thickness=3, lineType=8, shift=0) 29 | cv2.circle(image, center2, 2, (255, 0, 0), thickness=3, lineType=8, shift=0) 30 | cv2.putText(image, str(part['partIdx'][1]), 
center2, cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 0, 0), 1) 31 | image = cv2.line(image, center1, center2, (255, 0, 0), 1) 32 | cv2.imshow('result', image) 33 | cv2.waitKey(0) 34 | 35 | def test_mobilenet(self): 36 | inpmat = np.load('./tests/person3.pickle') 37 | heatmat = np.load('./tests/mobilenet_person3_heatmat.pickle') 38 | pafmat = np.load('./tests/mobilenet_person3_pafmat.pickle') 39 | 40 | t = time.time() 41 | humans = common.estimate_pose(heatmat, pafmat) 42 | elapsed = time.time() - t 43 | logging.info('[test_mobilenet] elapsed=%f' % elapsed) 44 | 45 | self._show('./images/p3.jpg', inpmat, heatmat, pafmat, humans) 46 | 47 | def test_cmu(self): 48 | inpmat = np.load('./tests/person3.pickle') 49 | heatmat = np.load('./tests/cmu_person3_heatmat.pickle') 50 | pafmat = np.load('./tests/cmu_person3_pafmat.pickle') 51 | 52 | t = time.time() 53 | humans = common.estimate_pose(heatmat, pafmat) 54 | elapsed = time.time() - t 55 | logging.info('[test_mobilenet] elapsed=%f' % elapsed) 56 | 57 | # self._show('./images/p3.jpg', inpmat, heatmat, pafmat, humans) 58 | 59 | if __name__ == '__main__': 60 | unittest.main() 61 | -------------------------------------------------------------------------------- /pose_stats.py: -------------------------------------------------------------------------------- 1 | from tensorpack import imgaug 2 | from tensorpack.dataflow.common import MapDataComponent, MapData 3 | from tensorpack.dataflow.image import AugmentImageComponent 4 | 5 | from common import CocoPart 6 | from pose_augment import * 7 | from pose_dataset import CocoPoseLMDB 8 | 9 | 10 | def get_idx_hands_up(): 11 | from pose_augment import set_network_input_wh 12 | set_network_input_wh(368, 368) 13 | 14 | show_sample = True 15 | db = CocoPoseLMDB('/data/public/rw/coco-pose-estimation-lmdb/', is_train=True, decode_img=show_sample) 16 | db.reset_state() 17 | total_cnt = 0 18 | handup_cnt = 0 19 | for idx, metas in enumerate(db.get_data()): 20 | meta = metas[0] 21 | if len(meta.joint_list) <= 0: 22 | continue 23 | body = meta.joint_list[0] 24 | if body[CocoPart.Neck.value][1] <= 0: 25 | continue 26 | if body[CocoPart.LWrist.value][1] <= 0: 27 | continue 28 | if body[CocoPart.RWrist.value][1] <= 0: 29 | continue 30 | 31 | if body[CocoPart.Neck.value][1] > body[CocoPart.LWrist.value][1] or body[CocoPart.Neck.value][1] > body[CocoPart.RWrist.value][1]: 32 | print(meta.idx) 33 | handup_cnt += 1 34 | 35 | if show_sample: 36 | l1, l2, l3 = pose_to_img(metas) 37 | CocoPoseLMDB.display_image(l1, l2, l3) 38 | 39 | total_cnt += 1 40 | 41 | print('%d / %d' % (handup_cnt, total_cnt)) 42 | 43 | 44 | def sample_augmentations(): 45 | ds = CocoPoseLMDB('/data/public/rw/coco-pose-estimation-lmdb/', is_train=False, only_idx=0) 46 | ds = MapDataComponent(ds, pose_random_scale) 47 | ds = MapDataComponent(ds, pose_rotation) 48 | ds = MapDataComponent(ds, pose_flip) 49 | ds = MapDataComponent(ds, pose_resize_shortestedge_random) 50 | ds = MapDataComponent(ds, pose_crop_random) 51 | ds = MapData(ds, pose_to_img) 52 | augs = [ 53 | imgaug.RandomApplyAug(imgaug.RandomChooseAug([ 54 | imgaug.GaussianBlur(3), 55 | imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01), 56 | imgaug.RandomOrderAug([ 57 | imgaug.BrightnessScale((0.8, 1.2), clip=False), 58 | imgaug.Contrast((0.8, 1.2), clip=False), 59 | # imgaug.Saturation(0.4, rgb=True), 60 | ]), 61 | ]), 0.7), 62 | ] 63 | ds = AugmentImageComponent(ds, augs) 64 | 65 | ds.reset_state() 66 | for l1, l2, l3 in ds.get_data(): 67 | CocoPoseLMDB.display_image(l1, l2, l3) 68 | 69 | 70 | if 
__name__ == '__main__': 71 | # codes for tests 72 | # get_idx_hands_up() 73 | 74 | # show augmentation samples 75 | sample_augmentations() 76 | -------------------------------------------------------------------------------- /convert/README.md: -------------------------------------------------------------------------------- 1 | # Keras Mobilenet-Model 2 | 3 | Convert from tf-openpose to Keras. 4 | It worked on Keras, but it did not work on coreml. 5 | The reason is described below. 6 | 7 | # How to use 8 | 9 | * Run 10 | ``` 11 | $ python tensorToKeras.py 12 | ``` 13 | 14 | ### DownLoad Model 15 | - save to models folder 16 | - [model-388003](https://www.dropbox.com/s/09xivpuboecge56/mobilenet_0.75_0.50_model-388003.zip?dl=0) 17 | 18 | ## Dependencies 19 | 20 | ``` 21 | numpy 22 | h5py 23 | scipy 24 | opencv-python=3.3.0.10 25 | coremltools=0.6.3 26 | tensorflow=1.4.0 27 | Keras=2.1.1 28 | Pillow=4.3.0 29 | ``` 30 | 31 | ## Contributer 32 | 33 | - [Infocom TPO](https://lab.infocom.co.jp/) 34 | - [@mganeko](https://github.com/mganeko) 35 | - [@mbotsu](https://github.com/mbotsu) 36 | - [@tnosho](https://github.com/tnosho) 37 | 38 | ## How to convert to CoreML 39 | CoreML doesn't support instance_normalization. 40 | And [tf-openpose-weight](https://www.dropbox.com/s/09xivpuboecge56/mobilenet_0.75_0.50_model-388003.zip?dl=0) that you can download from [original repository](https://github.com/ildoonet/tf-openpose) is trained with instance_normalization. 41 | So, if you would like to convert to CoreML, you have to retrain without instance_normalization.([See this commit](https://github.com/infocom-tpo/tf-openpose/commit/2c6484888f6035054b897ddc35cbcc257f1c1cdf)) 42 | - tf-openpose .. BatchNormalization: instance_normalization supported 43 | - coreml .. instance_normalization not supported 44 | 45 | [Instance normalization removed in 0.4.0](https://forums.developer.apple.com/thread/81520) 46 | 47 | **Retraining or coreml instance_normalization Waiting for support** 48 | 49 | ![batch_normalization](resources/batch_normalization.png) 50 | 51 | 52 | ``` 53 | from keras.models import Model, load_model 54 | import numpy as np 55 | from PIL import Image 56 | from keras.preprocessing.image import load_img, img_to_array 57 | from keras.applications.mobilenet import DepthwiseConv2D 58 | from keras.utils.generic_utils import CustomObjectScope 59 | from keras import backend as K 60 | import coremltools 61 | 62 | img_path = '[filename]' 63 | img = load_img(img_path, target_size=(368, 368)) 64 | kerasImg = img_to_array(img) 65 | kerasImg = np.expand_dims(kerasImg, axis=0) 66 | 67 | with CustomObjectScope({'DepthwiseConv2D': DepthwiseConv2D}): 68 | 69 | model = load_model('./output/predict.hd5') 70 | 71 | prediction = model.predict(kerasImg) 72 | prediction = prediction[0] 73 | 74 | coreml_model = coremltools.converters.keras.convert(model 75 | , input_names = 'image' 76 | , image_input_names='image' 77 | , output_names='net_output' 78 | , is_bgr=True 79 | , image_scale=2./255 80 | , red_bias=-1 81 | , green_bias=-1 82 | , blue_bias=-1 83 | ) 84 | 85 | out = coreml_model.predict({'image': img})['net_output'] 86 | 87 | coreml_model.author = 'Infocom TPO' 88 | coreml_model.license = 'MIT' 89 | coreml_model.save('mobilenet.mlmodel') 90 | ``` -------------------------------------------------------------------------------- /etcs/training.md: -------------------------------------------------------------------------------- 1 | ## Training 2 | 3 | ### Dataset 4 | 5 | You should download the dataset in LMDB format 
provided by CMU. See : https://github.com/ZheC/Realtime_Multi-Person_Pose_Estimation/blob/master/training/get_lmdb.sh 6 | 7 | ``` 8 | $ wget -nc --directory-prefix=lmdb_trainVal/ http://posefs1.perception.cs.cmu.edu/Users/ZheCao/lmdb_trainVal/data.mdb 9 | $ wget -nc --directory-prefix=lmdb_trainVal/ http://posefs1.perception.cs.cmu.edu/Users/ZheCao/lmdb_trainVal/lock.mdb 10 | ``` 11 | 12 | ### Augmentation 13 | 14 | CMU Perceptual Computing Lab has modified Caffe to provide data augmentation. See : https://github.com/CMU-Perceptual-Computing-Lab/caffe_train 15 | 16 | I implemented the augmentation code in the same way as the original version; see [pose_dataset.py](pose_dataset.py) and [pose_augment.py](pose_augment.py). This includes scaling, rotation, flipping, and cropping. 17 | 18 | This process can be a bottleneck for training, so if you have enough computing resources, please see the [Run for Faster Training](#run-for-faster-training) section. 19 | 20 | ### Run 21 | 22 | ``` 23 | $ python3 train.py --model=cmu --datapath={datapath} --batchsize=64 --lr=0.001 --modelpath={path-to-save} 24 | 25 | 2017-09-27 15:58:50,307 INFO Restore pretrained weights... 26 | ``` 27 | 28 | ### Run for Faster Training 29 | 30 | If you have enough computing resources across multiple nodes, you can launch multiple workers on those nodes to help with data preparation. 31 | 32 | ``` 33 | worker-node1$ python3 pose_dataworker.py --master=tcp://host:port 34 | worker-node2$ python3 pose_dataworker.py --master=tcp://host:port 35 | worker-node3$ python3 pose_dataworker.py --master=tcp://host:port 36 | ... 37 | ``` 38 | 39 | After the above preparation, you can launch the training script with the 'remote-data' argument. 40 | 41 | ``` 42 | $ python3 train.py --remote-data=tcp://0.0.0.0:port 43 | 44 | 2017-09-27 15:58:50,307 INFO Restore pretrained weights... 45 | ``` 46 | 47 | Also, you can train more quickly with multiple GPUs. This automatically splits each batch across the GPUs for forward/backward computation. 48 | 49 | ``` 50 | $ python3 train.py --remote-data=tcp://0.0.0.0:port --gpus=8 51 | 52 | 2017-09-27 15:58:50,307 INFO Restore pretrained weights... 53 | ``` 54 | 55 | I trained the models within a day using 8 GPUs and multiple pre-processing nodes with 48-core CPUs. 56 | 57 | ### Model Optimization for Inference 58 | 59 | After training a model, I optimized it by folding batch normalization into the convolutional layers and removing redundant operations. 60 | 61 | Firstly, the model should be frozen. 62 | 63 | ```bash 64 | $ python3 -m tensorflow.python.tools.freeze_graph \ 65 | --input_graph=... \ 66 | --output_graph=... \ 67 | --input_checkpoint=... \ 68 | --output_node_names="Openpose/concat_stage7" 69 | ``` 70 | 71 | The optimization can then be performed on the frozen model via the graph transform tool provided by TensorFlow. 72 | 73 | ```bash 74 | $ bazel build tensorflow/tools/graph_transforms:transform_graph 75 | $ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ 76 | --in_graph=... \ 77 | --out_graph=... \ 78 | --inputs='image:0' \ 79 | --outputs='Openpose/concat_stage7:0' \ 80 | --transforms=' 81 | strip_unused_nodes(type=float, shape="1,368,368,3") 82 | remove_nodes(op=Identity, op=CheckNumerics) 83 | fold_constants(ignoreError=False) 84 | fold_old_batch_norms 85 | fold_batch_norms' 86 | ``` -------------------------------------------------------------------------------- /datum_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: datum.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='datum.proto', 20 | package='', 21 | serialized_pb=_b('\n\x0b\x64\x61tum.proto\"\x81\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _DATUM = _descriptor.Descriptor( 29 | name='Datum', 30 | full_name='Datum', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='channels', full_name='Datum.channels', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='Datum.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='width', full_name='Datum.width', index=2, 51 | number=3, type=5, cpp_type=1, label=1, 52 | has_default_value=False, default_value=0, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='data', full_name='Datum.data', index=3, 58 | number=4, type=12, cpp_type=9, label=1, 59 | has_default_value=False, default_value=_b(""), 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='label', full_name='Datum.label', index=4, 65 | number=5, type=5, cpp_type=1, label=1, 66 | has_default_value=False, default_value=0, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | _descriptor.FieldDescriptor( 71 | name='float_data', full_name='Datum.float_data', index=5, 72 | number=6, type=2, cpp_type=6, label=3, 73 | has_default_value=False, default_value=[], 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None), 77 | _descriptor.FieldDescriptor( 78 | name='encoded', full_name='Datum.encoded', index=6, 79 | number=7, type=8, cpp_type=7, label=1, 80 | has_default_value=True, default_value=False, 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None), 84 | ], 85 | extensions=[ 86 | ], 87 | nested_types=[], 88 | enum_types=[ 89 | ], 90 | options=None, 91 | is_extendable=False, 92 | extension_ranges=[], 93 | 
oneofs=[ 94 | ], 95 | serialized_start=16, 96 | serialized_end=145, 97 | ) 98 | 99 | DESCRIPTOR.message_types_by_name['Datum'] = _DATUM 100 | 101 | Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict( 102 | DESCRIPTOR = _DATUM, 103 | __module__ = 'datum_pb2' 104 | # @@protoc_insertion_point(class_scope:Datum) 105 | )) 106 | _sym_db.RegisterMessage(Datum) 107 | 108 | 109 | # @@protoc_insertion_point(module_scope) 110 | -------------------------------------------------------------------------------- /realtime_webcam.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import time 5 | import logging 6 | 7 | import tensorflow as tf 8 | 9 | from common import CocoPairsRender, CocoColors, preprocess, estimate_pose, draw_humans 10 | from network_cmu import CmuNetwork 11 | from network_mobilenet import MobilenetNetwork 12 | from networks import get_network 13 | from pose_dataset import CocoPoseLMDB 14 | 15 | logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s') 16 | 17 | 18 | fps_time = 0 19 | 20 | 21 | def cb_showimg(img, preprocessed, heatMat, pafMat, humans, show_process=False): 22 | global fps_time 23 | 24 | # display 25 | image = img 26 | image_h, image_w = image.shape[:2] 27 | image = draw_humans(image, humans) 28 | 29 | scale = 480.0 / image_h 30 | newh, neww = 480, int(scale * image_w + 0.5) 31 | 32 | image = cv2.resize(image, (neww, newh), interpolation=cv2.INTER_AREA) 33 | 34 | if show_process: 35 | process_img = CocoPoseLMDB.display_image(preprocessed, heatMat, pafMat, as_numpy=True) 36 | process_img = cv2.resize(process_img, (640, 480), interpolation=cv2.INTER_AREA) 37 | 38 | canvas = np.zeros([480, 640 + neww, 3], dtype=np.uint8) 39 | canvas[:, :640] = process_img 40 | canvas[:, 640:] = image 41 | else: 42 | canvas = image 43 | 44 | cv2.putText(canvas, "FPS: %f" % (1.0 / (time.time() - fps_time)), (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) 45 | cv2.imshow('openpose', canvas) 46 | 47 | fps_time = time.time() 48 | 49 | 50 | if __name__ == '__main__': 51 | parser = argparse.ArgumentParser(description='Tensorflow Openpose Realtime Webcam') 52 | parser.add_argument('--input-width', type=int, default=368) 53 | parser.add_argument('--input-height', type=int, default=368) 54 | parser.add_argument('--stage-level', type=int, default=6) 55 | parser.add_argument('--camera', type=int, default=0) 56 | parser.add_argument('--zoom', type=float, default=1.0) 57 | parser.add_argument('--model', type=str, default='mobilenet', help='cmu / mobilenet / mobilenet_accurate / mobilenet_fast') 58 | parser.add_argument('--show-process', type=bool, default=False, help='for debug purpose, if enabled, speed for inference is dropped.') 59 | args = parser.parse_args() 60 | 61 | input_node = tf.placeholder(tf.float32, shape=(1, args.input_height, args.input_width, 3), name='image') 62 | 63 | with tf.Session() as sess: 64 | net, _, last_layer = get_network(args.model, input_node, sess) 65 | 66 | cam = cv2.VideoCapture(args.camera) 67 | ret_val, img = cam.read() 68 | logging.info('cam image=%dx%d' % (img.shape[1], img.shape[0])) 69 | 70 | while True: 71 | logging.debug('cam read+') 72 | ret_val, img = cam.read() 73 | 74 | logging.debug('cam preprocess+') 75 | if args.zoom < 1.0: 76 | canvas = np.zeros_like(img) 77 | img_scaled = cv2.resize(img, None, fx=args.zoom, fy=args.zoom, interpolation=cv2.INTER_LINEAR) 78 | dx = (canvas.shape[1] - 
img_scaled.shape[1]) // 2 79 | dy = (canvas.shape[0] - img_scaled.shape[0]) // 2 80 | canvas[dy:dy + img_scaled.shape[0], dx:dx + img_scaled.shape[1]] = img_scaled 81 | img = canvas 82 | elif args.zoom > 1.0: 83 | img_scaled = cv2.resize(img, None, fx=args.zoom, fy=args.zoom, interpolation=cv2.INTER_LINEAR) 84 | dx = (img_scaled.shape[1] - img.shape[1]) // 2 85 | dy = (img_scaled.shape[0] - img.shape[0]) // 2 86 | img = img_scaled[dy:img.shape[0], dx:img.shape[1]] 87 | preprocessed = preprocess(img, args.input_width, args.input_height) 88 | 89 | logging.debug('cam process+') 90 | pafMat, heatMat = sess.run( 91 | [ 92 | net.get_output(name=last_layer.format(stage=args.stage_level, aux=1)), 93 | net.get_output(name=last_layer.format(stage=args.stage_level, aux=2)) 94 | ], feed_dict={'image:0': [preprocessed]} 95 | ) 96 | heatMat, pafMat = heatMat[0], pafMat[0] 97 | 98 | logging.debug('cam postprocess+') 99 | t = time.time() 100 | humans = estimate_pose(heatMat, pafMat) 101 | 102 | logging.debug('cam show+') 103 | cb_showimg(img, preprocessed, heatMat, pafMat, humans, show_process=args.show_process) 104 | 105 | if cv2.waitKey(1) == 27: 106 | break # esc to quit 107 | logging.debug('cam finished+') 108 | cv2.destroyAllWindows() 109 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import tensorflow as tf 3 | import cv2 4 | import numpy as np 5 | import time 6 | import logging 7 | import argparse 8 | 9 | from tensorflow.python.client import timeline 10 | 11 | from common import estimate_pose, CocoPairsRender, read_imgfile, CocoColors, draw_humans 12 | from networks import get_network 13 | from pose_dataset import CocoPoseLMDB 14 | 15 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') 16 | 17 | config = tf.ConfigProto() 18 | config.gpu_options.allocator_type = 'BFC' 19 | config.gpu_options.per_process_gpu_memory_fraction = 0.95 20 | config.gpu_options.allow_growth = True 21 | 22 | 23 | if __name__ == '__main__': 24 | parser = argparse.ArgumentParser(description='Tensorflow Openpose Inference') 25 | parser.add_argument('--imgpath', type=str, default='./images/p2.jpg') 26 | parser.add_argument('--input-width', type=int, default=368) 27 | parser.add_argument('--input-height', type=int, default=368) 28 | parser.add_argument('--stage-level', type=int, default=6) 29 | parser.add_argument('--model', type=str, default='mobilenet', help='cmu / mobilenet / mobilenet_accurate / mobilenet_fast') 30 | args = parser.parse_args() 31 | 32 | input_node = tf.placeholder(tf.float32, shape=(1, args.input_height, args.input_width, 3), name='image') 33 | 34 | with tf.Session(config=config) as sess: 35 | net, _, last_layer = get_network(args.model, input_node, sess, trainable=False) 36 | 37 | logging.debug('read image+') 38 | image = read_imgfile(args.imgpath, args.input_width, args.input_height) 39 | vec = sess.run(net.get_output(name='concat_stage7'), feed_dict={'image:0': [image]}) 40 | 41 | a = time.time() 42 | run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 43 | run_metadata = tf.RunMetadata() 44 | pafMat, heatMat = sess.run( 45 | [ 46 | net.get_output(name=last_layer.format(stage=args.stage_level, aux=1)), 47 | net.get_output(name=last_layer.format(stage=args.stage_level, aux=2)) 48 | ], feed_dict={'image:0': [image]}, options=run_options, run_metadata=run_metadata 49 | ) 50 | logging.info('inference- 
elapsed_time={}'.format(time.time() - a)) 51 | 52 | tl = timeline.Timeline(run_metadata.step_stats) 53 | ctf = tl.generate_chrome_trace_format() 54 | with open('timeline.json', 'w') as f: 55 | f.write(ctf) 56 | heatMat, pafMat = heatMat[0], pafMat[0] 57 | 58 | logging.debug('inference+') 59 | 60 | avg = 0 61 | for _ in range(10): 62 | a = time.time() 63 | sess.run( 64 | [ 65 | net.get_output(name=last_layer.format(stage=args.stage_level, aux=1)), 66 | net.get_output(name=last_layer.format(stage=args.stage_level, aux=2)) 67 | ], feed_dict={'image:0': [image]} 68 | ) 69 | logging.info('inference- elapsed_time={}'.format(time.time() - a)) 70 | avg += time.time() - a 71 | logging.info('prediction avg= %f' % (avg / 10)) 72 | 73 | ''' 74 | logging.info('pickle data') 75 | with open('person3.pickle', 'wb') as pickle_file: 76 | pickle.dump(image, pickle_file, pickle.HIGHEST_PROTOCOL) 77 | with open('heatmat.pickle', 'wb') as pickle_file: 78 | pickle.dump(heatMat, pickle_file, pickle.HIGHEST_PROTOCOL) 79 | with open('pafmat.pickle', 'wb') as pickle_file: 80 | pickle.dump(pafMat, pickle_file, pickle.HIGHEST_PROTOCOL) 81 | ''' 82 | 83 | logging.info('pose+') 84 | a = time.time() 85 | humans = estimate_pose(heatMat, pafMat) 86 | logging.info('pose- elapsed_time={}'.format(time.time() - a)) 87 | 88 | logging.info('image={} heatMap={} pafMat={}'.format(image.shape, heatMat.shape, pafMat.shape)) 89 | process_img = CocoPoseLMDB.display_image(image, heatMat, pafMat, as_numpy=True) 90 | 91 | # display 92 | image = cv2.imread(args.imgpath) 93 | image_h, image_w = image.shape[:2] 94 | image = draw_humans(image, humans) 95 | 96 | scale = 480.0 / image_h 97 | newh, neww = 480, int(scale * image_w + 0.5) 98 | 99 | image = cv2.resize(image, (neww, newh), interpolation=cv2.INTER_AREA) 100 | 101 | convas = np.zeros([480, 640 + neww, 3], dtype=np.uint8) 102 | convas[:, :640] = process_img 103 | convas[:, 640:] = image 104 | 105 | cv2.imshow('result', convas) 106 | cv2.waitKey(0) 107 | 108 | tf.train.write_graph(sess.graph_def, '.', 'graph-tmp.pb', as_text=True) 109 | -------------------------------------------------------------------------------- /network_mobilenet.py: -------------------------------------------------------------------------------- 1 | import network_base 2 | import tensorflow as tf 3 | 4 | 5 | class MobilenetNetwork(network_base.BaseNetwork): 6 | def __init__(self, inputs, trainable=True, conv_width=1.0, conv_width2=None): 7 | self.conv_width = conv_width 8 | self.conv_width2 = conv_width2 if conv_width2 else conv_width 9 | network_base.BaseNetwork.__init__(self, inputs, trainable) 10 | 11 | def setup(self): 12 | min_depth = 8 13 | depth = lambda d: max(int(d * self.conv_width), min_depth) 14 | depth2 = lambda d: max(int(d * self.conv_width2), min_depth) 15 | 16 | with tf.variable_scope(None, 'MobilenetV1'): 17 | (self.feed('image') 18 | .convb(3, 3, depth(32), 2, name='Conv2d_0') 19 | .separable_conv(3, 3, depth(64), 1, name='Conv2d_1') 20 | .separable_conv(3, 3, depth(128), 2, name='Conv2d_2') 21 | .separable_conv(3, 3, depth(128), 1, name='Conv2d_3') 22 | .separable_conv(3, 3, depth(256), 2, name='Conv2d_4') 23 | .separable_conv(3, 3, depth(256), 1, name='Conv2d_5') 24 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_6') 25 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_7') 26 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_8') 27 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_9') 28 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_10') 29 | .separable_conv(3, 3, depth(512), 
1, name='Conv2d_11') 30 | # .separable_conv(3, 3, depth(1024), 2, name='Conv2d_12') 31 | # .separable_conv(3, 3, depth(1024), 1, name='Conv2d_13') 32 | ) 33 | 34 | (self.feed('Conv2d_3').max_pool(2, 2, 2, 2, name='Conv2d_3_pool')) 35 | 36 | (self.feed('Conv2d_3_pool', 'Conv2d_7', 'Conv2d_11') 37 | .concat(3, name='feat_concat')) 38 | 39 | feature_lv = 'feat_concat' 40 | with tf.variable_scope(None, 'Openpose'): 41 | prefix = 'MConv_Stage1' 42 | (self.feed(feature_lv) 43 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_1') 44 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_2') 45 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_3') 46 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L1_4') 47 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5')) 48 | 49 | (self.feed(feature_lv) 50 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_1') 51 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_2') 52 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_3') 53 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L2_4') 54 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5')) 55 | 56 | for stage_id in range(5): 57 | prefix_prev = 'MConv_Stage%d' % (stage_id + 1) 58 | prefix = 'MConv_Stage%d' % (stage_id + 2) 59 | (self.feed(prefix_prev + '_L1_5', 60 | prefix_prev + '_L2_5', 61 | feature_lv) 62 | .concat(3, name=prefix + '_concat') 63 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_1') 64 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_2') 65 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_3') 66 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L1_4') 67 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5')) 68 | 69 | (self.feed(prefix + '_concat') 70 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_1') 71 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_2') 72 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_3') 73 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L2_4') 74 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5')) 75 | 76 | # final result 77 | (self.feed('MConv_Stage6_L2_5', 78 | 'MConv_Stage6_L1_5') 79 | .concat(3, name='concat_stage7')) 80 | 81 | def loss_l1_l2(self): 82 | l1s = [] 83 | l2s = [] 84 | for layer_name in sorted(self.layers.keys()): 85 | if '_L1_5' in layer_name: 86 | l1s.append(self.layers[layer_name]) 87 | if '_L2_5' in layer_name: 88 | l2s.append(self.layers[layer_name]) 89 | 90 | return l1s, l2s 91 | 92 | def loss_last(self): 93 | return self.get_output('MConv_Stage6_L1_5'), self.get_output('MConv_Stage6_L2_5') 94 | 95 | def restorable_variables(self): 96 | vs = {v.op.name: v for v in tf.global_variables() if 97 | 'MobilenetV1/Conv2d' in v.op.name and 98 | 'RMSProp' not in v.op.name and 'Momentum' not in v.op.name and 'Ada' not in v.op.name 99 | } 100 | return vs 101 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tf-openpose 2 | 3 | 'Openpose' for human pose estimation have been implemented using Tensorflow. 
It also provides several network variants with modified structures for **real-time processing on the CPU or low-power embedded devices.** 4 | 5 | 6 | **You can even run this on your macbook with decent FPS!** 7 | 8 | Original Repo (Caffe) : https://github.com/CMU-Perceptual-Computing-Lab/openpose 9 | 10 | | CMU's Original Model
on Macbook Pro 15" | Mobilenet Variant
on Macbook Pro 15" | Mobilenet Variant
on Jetson TX2 | 11 | |:---------|:--------------------|:----------------| 12 | | ![cmu-model](/etcs/openpose_macbook_cmu.gif) | ![mb-model-macbook](/etcs/openpose_macbook_mobilenet3.gif) | ![mb-model-tx2](/etcs/openpose_tx2_mobilenet3.gif) | 13 | | **~0.6 FPS** | **~4.2 FPS** @ 368x368 | **~10 FPS** @ 368x368 | 14 | | 2.8GHz Quad-core i7 | 2.8GHz Quad-core i7 | Jetson TX2 Embedded Board | 15 | 16 | Implemented features are listed here : [features](./etcs/feature.md) 17 | 18 | ## Install 19 | 20 | ### Dependencies 21 | 22 | You need the dependencies below. 23 | 24 | - python3 25 | 26 | - tensorflow 1.3 27 | 28 | - opencv3 29 | 30 | - protobuf 31 | 32 | - python3-tk 33 | 34 | ### Install 35 | 36 | ```bash 37 | $ git clone https://www.github.com/ildoonet/tf-openpose 38 | $ cd tf-openpose 39 | $ pip3 install -r requirements.txt 40 | ``` 41 | 42 | ## Models 43 | 44 | - cmu 45 | - the model based on the pretrained VGG network, as described in the original paper. 46 | - I converted the original Caffe weights for use in TensorFlow. 47 | - [weight download](https://www.dropbox.com/s/xh5s7sb7remu8tx/openpose_coco.npy?dl=0) 48 | 49 | - dsconv 50 | - Same architecture as the cmu version, except that it uses
the **depthwise separable convolutions** of mobilenet. 51 | - I trained it using transfer learning, but it did not provide enough speed or accuracy. 52 | 53 | - mobilenet 54 | - Based on the mobilenet paper, 12 convolutional layers are used as feature-extraction layers. 55 | - To improve detection of small persons, **minor modifications** to the architecture have been made. 56 | - Three models were trained with different network-size parameters. 57 | - mobilenet 58 | - 368x368 : [weight download](https://www.dropbox.com/s/09xivpuboecge56/mobilenet_0.75_0.50_model-388003.zip?dl=0) 59 | - mobilenet_fast 60 | - mobilenet_accurate 61 | - The published models are not the best ones, but you can test them before training a model from scratch. 62 | 63 | ### Inference Time 64 | 65 | #### Macbook Pro - 3.1GHz i5 Dual Core 66 | 67 | | Dataset | Model | Inference Time | 68 | |---------|--------------------|----------------:| 69 | | Coco | cmu | 10.0s @ 368x368 | 70 | | Coco | dsconv | 1.10s @ 368x368 | 71 | | Coco | mobilenet_accurate | 0.40s @ 368x368 | 72 | | Coco | mobilenet | 0.24s @ 368x368 | 73 | | Coco | mobilenet_fast | 0.16s @ 368x368 | 74 | 75 | #### Jetson TX2 76 | 77 | Test results on Nvidia's embedded GPU board are as below. 78 | 79 | | Dataset | Model | Inference Time | 80 | |---------|--------------------|----------------:| 81 | | Coco | cmu | OOM @ 368x368
5.5s @ 320x240| 82 | | Coco | mobilenet_accurate | 0.18s @ 368x368 | 83 | | Coco | mobilenet | 0.10s @ 368x368 | 84 | | Coco | mobilenet_fast | 0.07s @ 368x368 | 85 | 86 | CMU's original model can not be executed due to 'out of memory' on '368x368' size. 87 | 88 | ## Demo 89 | 90 | ### Test Inference 91 | 92 | You can test the inference feature with a single image. 93 | 94 | ``` 95 | $ python3 inference.py --model=mobilenet --imgpath=... 96 | ``` 97 | 98 | Then you will see the screen as below with pafmap, heatmap, result and etc. 99 | 100 | ![inferent_result](./etcs/inference_result2.png) 101 | 102 | ### Realtime Webcam 103 | 104 | ``` 105 | $ python3 realtime_webcam.py --camera=0 --model=mobilenet --zoom=1.0 106 | ``` 107 | 108 | Then you will see the realtime webcam screen with estimated poses as below. This [Realtime Result](./etcs/openpose_macbook13_mobilenet2.gif) was recored on macbook pro 13" with 3.1Ghz Dual-Core CPU. 109 | 110 | ## Training 111 | 112 | See : [etcs/training.md](./etcs/training.md) 113 | 114 | ## References 115 | 116 | ### OpenPose 117 | 118 | [1] https://github.com/CMU-Perceptual-Computing-Lab/openpose 119 | 120 | [2] Training Codes : https://github.com/ZheC/Realtime_Multi-Person_Pose_Estimation 121 | 122 | [3] Custom Caffe by Openpose : https://github.com/CMU-Perceptual-Computing-Lab/caffe_train 123 | 124 | [4] Keras Openpose : https://github.com/michalfaber/keras_Realtime_Multi-Person_Pose_Estimation 125 | 126 | ### Mobilenet 127 | 128 | [1] Original Paper : https://arxiv.org/abs/1704.04861 129 | 130 | [2] Pretrained model : https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.md 131 | 132 | ### Libraries 133 | 134 | [1] Tensorpack : https://github.com/ppwwyyxx/tensorpack 135 | 136 | ### Tensorflow Tips 137 | 138 | [1] Freeze graph : https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py 139 | 140 | [2] Optimize graph : https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2 141 | -------------------------------------------------------------------------------- /network_cmu.py: -------------------------------------------------------------------------------- 1 | import network_base 2 | 3 | 4 | class CmuNetwork(network_base.BaseNetwork): 5 | def setup(self): 6 | (self.feed('image') 7 | .conv(3, 3, 64, 1, 1, name='conv1_1') 8 | .conv(3, 3, 64, 1, 1, name='conv1_2') 9 | .max_pool(2, 2, 2, 2, name='pool1_stage1') 10 | .conv(3, 3, 128, 1, 1, name='conv2_1') 11 | .conv(3, 3, 128, 1, 1, name='conv2_2') 12 | .max_pool(2, 2, 2, 2, name='pool2_stage1') 13 | .conv(3, 3, 256, 1, 1, name='conv3_1') 14 | .conv(3, 3, 256, 1, 1, name='conv3_2') 15 | .conv(3, 3, 256, 1, 1, name='conv3_3') 16 | .conv(3, 3, 256, 1, 1, name='conv3_4') 17 | .max_pool(2, 2, 2, 2, name='pool3_stage1') 18 | .conv(3, 3, 512, 1, 1, name='conv4_1') 19 | .conv(3, 3, 512, 1, 1, name='conv4_2') 20 | .conv(3, 3, 256, 1, 1, name='conv4_3_CPM') 21 | .conv(3, 3, 128, 1, 1, name='conv4_4_CPM') # ***** 22 | .conv(3, 3, 128, 1, 1, name='conv5_1_CPM_L1') 23 | .conv(3, 3, 128, 1, 1, name='conv5_2_CPM_L1') 24 | .conv(3, 3, 128, 1, 1, name='conv5_3_CPM_L1') 25 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L1') 26 | .conv(1, 1, 38, 1, 1, relu=False, name='conv5_5_CPM_L1')) 27 | 28 | (self.feed('conv4_4_CPM') 29 | .conv(3, 3, 128, 1, 1, name='conv5_1_CPM_L2') 30 | .conv(3, 3, 128, 1, 1, name='conv5_2_CPM_L2') 31 | .conv(3, 3, 128, 1, 1, name='conv5_3_CPM_L2') 32 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L2') 33 | .conv(1, 1, 19, 1, 1, relu=False, 
name='conv5_5_CPM_L2')) 34 | 35 | (self.feed('conv5_5_CPM_L1', 36 | 'conv5_5_CPM_L2', 37 | 'conv4_4_CPM') 38 | .concat(3, name='concat_stage2') 39 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage2_L1') 40 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage2_L1') 41 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage2_L1') 42 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage2_L1') 43 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage2_L1') 44 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L1') 45 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage2_L1')) 46 | 47 | (self.feed('concat_stage2') 48 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage2_L2') 49 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage2_L2') 50 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage2_L2') 51 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage2_L2') 52 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage2_L2') 53 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L2') 54 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage2_L2')) 55 | 56 | (self.feed('Mconv7_stage2_L1', 57 | 'Mconv7_stage2_L2', 58 | 'conv4_4_CPM') 59 | .concat(3, name='concat_stage3') 60 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage3_L1') 61 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage3_L1') 62 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage3_L1') 63 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage3_L1') 64 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage3_L1') 65 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L1') 66 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage3_L1')) 67 | 68 | (self.feed('concat_stage3') 69 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage3_L2') 70 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage3_L2') 71 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage3_L2') 72 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage3_L2') 73 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage3_L2') 74 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L2') 75 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage3_L2')) 76 | 77 | (self.feed('Mconv7_stage3_L1', 78 | 'Mconv7_stage3_L2', 79 | 'conv4_4_CPM') 80 | .concat(3, name='concat_stage4') 81 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage4_L1') 82 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage4_L1') 83 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage4_L1') 84 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage4_L1') 85 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage4_L1') 86 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L1') 87 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage4_L1')) 88 | 89 | (self.feed('concat_stage4') 90 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage4_L2') 91 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage4_L2') 92 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage4_L2') 93 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage4_L2') 94 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage4_L2') 95 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L2') 96 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage4_L2')) 97 | 98 | (self.feed('Mconv7_stage4_L1', 99 | 'Mconv7_stage4_L2', 100 | 'conv4_4_CPM') 101 | .concat(3, name='concat_stage5') 102 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage5_L1') 103 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage5_L1') 104 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage5_L1') 105 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage5_L1') 106 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage5_L1') 107 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L1') 108 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage5_L1')) 109 | 110 | (self.feed('concat_stage5') 111 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage5_L2') 112 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage5_L2') 113 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage5_L2') 114 | .conv(7, 7, 128, 1, 1, 
name='Mconv4_stage5_L2') 115 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage5_L2') 116 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L2') 117 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage5_L2')) 118 | 119 | (self.feed('Mconv7_stage5_L1', 120 | 'Mconv7_stage5_L2', 121 | 'conv4_4_CPM') 122 | .concat(3, name='concat_stage6') 123 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage6_L1') 124 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage6_L1') 125 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage6_L1') 126 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage6_L1') 127 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage6_L1') 128 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L1') 129 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage6_L1')) 130 | 131 | (self.feed('concat_stage6') 132 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage6_L2') 133 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage6_L2') 134 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage6_L2') 135 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage6_L2') 136 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage6_L2') 137 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L2') 138 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage6_L2')) 139 | 140 | (self.feed('Mconv7_stage6_L2', 141 | 'Mconv7_stage6_L1') 142 | .concat(3, name='concat_stage7')) 143 | 144 | def loss_l1_l2(self): 145 | l1s = [] 146 | l2s = [] 147 | for layer_name in self.layers.keys(): 148 | if 'Mconv7' in layer_name and '_L1' in layer_name: 149 | l1s.append(self.layers[layer_name]) 150 | if 'Mconv7' in layer_name and '_L2' in layer_name: 151 | l2s.append(self.layers[layer_name]) 152 | 153 | return l1s, l2s 154 | 155 | def loss_last(self): 156 | return self.get_output('Mconv7_stage6_L1'), self.get_output('Mconv7_stage6_L2') 157 | 158 | def restorable_variables(self): 159 | return None -------------------------------------------------------------------------------- /network_dsconv.py: -------------------------------------------------------------------------------- 1 | import network_base 2 | import tensorflow as tf 3 | 4 | 5 | class DSConvNetwork(network_base.BaseNetwork): 6 | def __init__(self, inputs, trainable=True, conv_width=1.0): 7 | self.conv_width = conv_width 8 | network_base.BaseNetwork.__init__(self, inputs, trainable) 9 | 10 | def setup(self): 11 | (self.feed('image') 12 | .conv(3, 3, 64, 1, 1, name='conv1_1', trainable=False) 13 | # .conv(3, 3, 64, 1, 1, name='conv1_2', trainable=True) # TODO 14 | .separable_conv(3, 3, round(self.conv_width * 64), 2, name='conv1_2') 15 | # .max_pool(2, 2, 2, 2, name='pool1_stage1') 16 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv2_1') 17 | .separable_conv(3, 3, round(self.conv_width * 128), 2, name='conv2_2') 18 | # .max_pool(2, 2, 2, 2, name='pool2_stage1') 19 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_1') 20 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_2') 21 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_3') 22 | .separable_conv(3, 3, round(self.conv_width * 256), 2, name='conv3_4') 23 | # .max_pool(2, 2, 2, 2, name='pool3_stage1') 24 | .separable_conv(3, 3, round(self.conv_width * 512), 1, name='conv4_1') 25 | .separable_conv(3, 3, round(self.conv_width * 512), 1, name='conv4_2') 26 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv4_3_CPM') 27 | .separable_conv(3, 3, 128, 1, name='conv4_4_CPM') 28 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_1_CPM_L1') 29 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_2_CPM_L1') 30 | 
.separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_3_CPM_L1') 31 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L1') 32 | .conv(1, 1, 38, 1, 1, relu=False, name='conv5_5_CPM_L1')) 33 | 34 | (self.feed('conv4_4_CPM') 35 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_1_CPM_L2') 36 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_2_CPM_L2') 37 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_3_CPM_L2') 38 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L2') 39 | .conv(1, 1, 19, 1, 1, relu=False, name='conv5_5_CPM_L2')) 40 | 41 | (self.feed('conv5_5_CPM_L1', 42 | 'conv5_5_CPM_L2', 43 | 'conv4_4_CPM') 44 | .concat(3, name='concat_stage2') 45 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage2_L1') 46 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage2_L1') 47 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage2_L1') 48 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage2_L1') 49 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage2_L1') 50 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L1') 51 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage2_L1')) 52 | 53 | (self.feed('concat_stage2') 54 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage2_L2') 55 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage2_L2') 56 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage2_L2') 57 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage2_L2') 58 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage2_L2') 59 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L2') 60 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage2_L2')) 61 | 62 | (self.feed('Mconv7_stage2_L1', 63 | 'Mconv7_stage2_L2', 64 | 'conv4_4_CPM') 65 | .concat(3, name='concat_stage3') 66 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage3_L1') 67 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage3_L1') 68 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage3_L1') 69 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage3_L1') 70 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage3_L1') 71 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L1') 72 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage3_L1')) 73 | 74 | (self.feed('concat_stage3') 75 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage3_L2') 76 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage3_L2') 77 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage3_L2') 78 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage3_L2') 79 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage3_L2') 80 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L2') 81 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage3_L2')) 82 | 83 | (self.feed('Mconv7_stage3_L1', 84 | 'Mconv7_stage3_L2', 85 | 'conv4_4_CPM') 86 | .concat(3, name='concat_stage4') 87 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage4_L1') 88 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage4_L1') 89 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage4_L1') 90 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage4_L1') 91 | 
.separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage4_L1') 92 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L1') 93 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage4_L1')) 94 | 95 | (self.feed('concat_stage4') 96 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage4_L2') 97 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage4_L2') 98 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage4_L2') 99 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage4_L2') 100 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage4_L2') 101 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L2') 102 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage4_L2')) 103 | 104 | (self.feed('Mconv7_stage4_L1', 105 | 'Mconv7_stage4_L2', 106 | 'conv4_4_CPM') 107 | .concat(3, name='concat_stage5') 108 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage5_L1') 109 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage5_L1') 110 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage5_L1') 111 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage5_L1') 112 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage5_L1') 113 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L1') 114 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage5_L1')) 115 | 116 | (self.feed('concat_stage5') 117 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage5_L2') 118 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage5_L2') 119 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage5_L2') 120 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage5_L2') 121 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage5_L2') 122 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L2') 123 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage5_L2')) 124 | 125 | (self.feed('Mconv7_stage5_L1', 126 | 'Mconv7_stage5_L2', 127 | 'conv4_4_CPM') 128 | .concat(3, name='concat_stage6') 129 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage6_L1') 130 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage6_L1') 131 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage6_L1') 132 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage6_L1') 133 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage6_L1') 134 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L1') 135 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage6_L1')) 136 | 137 | (self.feed('concat_stage6') 138 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage6_L2') 139 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage6_L2') 140 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage6_L2') 141 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage6_L2') 142 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage6_L2') 143 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L2') 144 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage6_L2')) 145 | 146 | (self.feed('Mconv7_stage6_L2', 147 | 'Mconv7_stage6_L1') 148 | .concat(3, name='concat_stage7')) 149 | -------------------------------------------------------------------------------- /common.py: 
-------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from enum import Enum 3 | import math 4 | import logging 5 | 6 | import numpy as np 7 | import itertools 8 | import cv2 9 | from scipy.ndimage.filters import maximum_filter 10 | 11 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') 12 | 13 | 14 | regularizer_conv = 0.04 15 | regularizer_dsconv = 0.004 16 | batchnorm_fused = True 17 | 18 | 19 | class CocoPart(Enum): 20 | Nose = 0 21 | Neck = 1 22 | RShoulder = 2 23 | RElbow = 3 24 | RWrist = 4 25 | LShoulder = 5 26 | LElbow = 6 27 | LWrist = 7 28 | RHip = 8 29 | RKnee = 9 30 | RAnkle = 10 31 | LHip = 11 32 | LKnee = 12 33 | LAnkle = 13 34 | REye = 14 35 | LEye = 15 36 | REar = 16 37 | LEar = 17 38 | Background = 18 39 | 40 | CocoPairs = [ 41 | (1, 2), (1, 5), (2, 3), (3, 4), (5, 6), (6, 7), (1, 8), (8, 9), (9, 10), (1, 11), 42 | (11, 12), (12, 13), (1, 0), (0, 14), (14, 16), (0, 15), (15, 17), (2, 16), (5, 17) 43 | ] # = 19 44 | CocoPairsRender = CocoPairs[:-2] 45 | CocoPairsNetwork = [ 46 | (12, 13), (20, 21), (14, 15), (16, 17), (22, 23), (24, 25), (0, 1), (2, 3), (4, 5), 47 | (6, 7), (8, 9), (10, 11), (28, 29), (30, 31), (34, 35), (32, 33), (36, 37), (18, 19), (26, 27) 48 | ] # = 19 49 | 50 | CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 51 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 52 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 53 | 54 | NMS_Threshold = 0.1 55 | InterMinAbove_Threshold = 6 56 | Inter_Threashold = 0.1 57 | Min_Subset_Cnt = 4 58 | Min_Subset_Score = 0.8 59 | Max_Human = 96 60 | 61 | 62 | def connections_to_human(connections, heatMat): 63 | point_dict = defaultdict(lambda: None) 64 | for conn in connections: 65 | point_dict[conn['partIdx'][0]] = (conn['partIdx'][0], (conn['c1'][0] / heatMat.shape[2], conn['c1'][1] / heatMat.shape[1]), heatMat[conn['partIdx'][0], conn['c1'][1], conn['c1'][0]]) 66 | point_dict[conn['partIdx'][1]] = (conn['partIdx'][1], (conn['c2'][0] / heatMat.shape[2], conn['c2'][1] / heatMat.shape[1]), heatMat[conn['partIdx'][1], conn['c2'][1], conn['c2'][0]]) 67 | return point_dict 68 | 69 | 70 | def non_max_suppression(np_input, window_size=3, threshold=NMS_Threshold): 71 | under_threshold_indices = np_input < threshold 72 | np_input[under_threshold_indices] = 0 73 | return np_input*(np_input == maximum_filter(np_input, footprint=np.ones((window_size, window_size)))) 74 | 75 | 76 | def estimate_pose(heatMat, pafMat): 77 | if heatMat.shape[2] == 19: 78 | heatMat = np.rollaxis(heatMat, 2, 0) 79 | if pafMat.shape[2] == 38: 80 | pafMat = np.rollaxis(pafMat, 2, 0) 81 | 82 | # reliability issue. 
83 | logging.debug('preprocess') 84 | heatMat = heatMat - heatMat.min(axis=1).min(axis=1).reshape(19, 1, 1) 85 | heatMat = heatMat - heatMat.min(axis=2).reshape(19, heatMat.shape[1], 1) 86 | 87 | _NMS_Threshold = max(np.average(heatMat) * 4.0, NMS_Threshold) 88 | _NMS_Threshold = min(_NMS_Threshold, 0.3) 89 | 90 | logging.debug('nms, th=%f' % _NMS_Threshold) 91 | # heatMat = gaussian_filter(heatMat, sigma=0.5) 92 | coords = [] 93 | for plain in heatMat[:-1]: 94 | nms = non_max_suppression(plain, 5, _NMS_Threshold) 95 | coords.append(np.where(nms >= _NMS_Threshold)) 96 | 97 | logging.debug('estimate_pose1 : estimate pairs') 98 | connection_all = [] 99 | for (idx1, idx2), (paf_x_idx, paf_y_idx) in zip(CocoPairs, CocoPairsNetwork): 100 | connection = estimate_pose_pair(coords, idx1, idx2, pafMat[paf_x_idx], pafMat[paf_y_idx]) 101 | connection_all.extend(connection) 102 | 103 | logging.debug('estimate_pose2, connection=%d' % len(connection_all)) 104 | connection_by_human = dict() 105 | for idx, c in enumerate(connection_all): 106 | connection_by_human['human_%d' % idx] = [c] 107 | 108 | no_merge_cache = defaultdict(list) 109 | while True: 110 | is_merged = False 111 | for k1, k2 in itertools.combinations(connection_by_human.keys(), 2): 112 | if k1 == k2: 113 | continue 114 | if k2 in no_merge_cache[k1]: 115 | continue 116 | for c1, c2 in itertools.product(connection_by_human[k1], connection_by_human[k2]): 117 | if len(set(c1['uPartIdx']) & set(c2['uPartIdx'])) > 0: 118 | is_merged = True 119 | connection_by_human[k1].extend(connection_by_human[k2]) 120 | connection_by_human.pop(k2) 121 | break 122 | if is_merged: 123 | no_merge_cache.pop(k1, None) 124 | break 125 | else: 126 | no_merge_cache[k1].append(k2) 127 | 128 | if not is_merged: 129 | break 130 | 131 | logging.debug('estimate_pose3') 132 | 133 | # reject by subset count 134 | connection_by_human = {k: v for (k, v) in connection_by_human.items() if len(v) >= Min_Subset_Cnt} 135 | 136 | # reject by subset max score 137 | connection_by_human = {k: v for (k, v) in connection_by_human.items() if max([ii['score'] for ii in v]) >= Min_Subset_Score} 138 | 139 | logging.debug('estimate_pose4') 140 | return [connections_to_human(conn, heatMat) for conn in connection_by_human.values()] 141 | 142 | 143 | def estimate_pose_pair(coords, partIdx1, partIdx2, pafMatX, pafMatY): 144 | connection_temp = [] 145 | peak_coord1, peak_coord2 = coords[partIdx1], coords[partIdx2] 146 | 147 | cnt = 0 148 | for idx1, (y1, x1) in enumerate(zip(peak_coord1[0], peak_coord1[1])): 149 | for idx2, (y2, x2) in enumerate(zip(peak_coord2[0], peak_coord2[1])): 150 | score, count = get_score(x1, y1, x2, y2, pafMatX, pafMatY) 151 | cnt += 1 152 | if (partIdx1, partIdx2) in [(2, 3), (3, 4), (5, 6), (6, 7)]: 153 | if count < InterMinAbove_Threshold // 2 or score <= 0.0: 154 | continue 155 | elif count < InterMinAbove_Threshold or score <= 0.0: 156 | continue 157 | connection_temp.append({ 158 | 'score': score, 159 | 'c1': (x1, y1), 160 | 'c2': (x2, y2), 161 | 'idx': (idx1, idx2), 162 | 'partIdx': (partIdx1, partIdx2), 163 | 'uPartIdx': ('{}-{}-{}'.format(x1, y1, partIdx1), '{}-{}-{}'.format(x2, y2, partIdx2)) 164 | }) 165 | 166 | connection = [] 167 | used_idx1, used_idx2 = [], [] 168 | for candidate in sorted(connection_temp, key=lambda x: x['score'], reverse=True): 169 | # check not connected 170 | if candidate['idx'][0] in used_idx1 or candidate['idx'][1] in used_idx2: 171 | continue 172 | connection.append(candidate) 173 | used_idx1.append(candidate['idx'][0]) 174 | 
used_idx2.append(candidate['idx'][1]) 175 | 176 | return connection 177 | 178 | 179 | def get_score(x1, y1, x2, y2, pafMatX, pafMatY): 180 | __num_inter = 10 181 | __num_inter_f = float(__num_inter) 182 | dx, dy = x2 - x1, y2 - y1 183 | normVec = math.sqrt(dx ** 2 + dy ** 2) 184 | 185 | if normVec < 1e-4: 186 | return 0.0, 0 187 | 188 | vx, vy = dx / normVec, dy / normVec 189 | 190 | xs = np.arange(x1, x2, dx / __num_inter_f) if x1 != x2 else np.full((__num_inter, ), x1) 191 | ys = np.arange(y1, y2, dy / __num_inter_f) if y1 != y2 else np.full((__num_inter, ), y1) 192 | xs = (xs + 0.5).astype(np.int8) 193 | ys = (ys + 0.5).astype(np.int8) 194 | 195 | # without vectorization 196 | pafXs = np.zeros(__num_inter) 197 | pafYs = np.zeros(__num_inter) 198 | for idx, (mx, my) in enumerate(zip(xs, ys)): 199 | pafXs[idx] = pafMatX[my][mx] 200 | pafYs[idx] = pafMatY[my][mx] 201 | 202 | # vectorization slow? 203 | # pafXs = pafMatX[ys, xs] 204 | # pafYs = pafMatY[ys, xs] 205 | 206 | local_scores = pafXs * vx + pafYs * vy 207 | thidxs = local_scores > Inter_Threashold 208 | 209 | return sum(local_scores * thidxs), sum(thidxs) 210 | 211 | 212 | def read_imgfile(path, width, height): 213 | val_image = cv2.imread(path) 214 | return preprocess(val_image, width, height) 215 | 216 | 217 | def preprocess(img, width, height): 218 | val_image = cv2.resize(img, (width, height)) 219 | val_image = val_image.astype(float) 220 | val_image = val_image * (2.0 / 255.0) - 1.0 221 | return val_image 222 | 223 | 224 | def draw_humans(img, human_list): 225 | img_copied = np.copy(img) 226 | image_h, image_w = img_copied.shape[:2] 227 | centers = {} 228 | for human in human_list: 229 | part_idxs = human.keys() 230 | 231 | # draw point 232 | for i in range(CocoPart.Background.value): 233 | if i not in part_idxs: 234 | continue 235 | part_coord = human[i][1] 236 | center = (int(part_coord[0] * image_w + 0.5), int(part_coord[1] * image_h + 0.5)) 237 | centers[i] = center 238 | cv2.circle(img_copied, center, 3, CocoColors[i], thickness=3, lineType=8, shift=0) 239 | 240 | # draw line 241 | for pair_order, pair in enumerate(CocoPairsRender): 242 | if pair[0] not in part_idxs or pair[1] not in part_idxs: 243 | continue 244 | 245 | img_copied = cv2.line(img_copied, centers[pair[0]], centers[pair[1]], CocoColors[pair_order], 3) 246 | 247 | return img_copied 248 | -------------------------------------------------------------------------------- /pose_augment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | from tensorpack.dataflow.imgaug.geometry import RotationAndCropValid 8 | 9 | from common import CocoPart 10 | 11 | 12 | _network_w = 368 13 | _network_h = 368 14 | 15 | 16 | def set_network_input_wh(w, h): 17 | global _network_w, _network_h 18 | _network_w, _network_h = w, h 19 | 20 | 21 | def pose_random_scale(meta): 22 | scalew = random.uniform(0.8, 1.4) 23 | scaleh = random.uniform(0.8, 1.4) 24 | neww = int(meta.width * scalew) 25 | newh = int(meta.height * scaleh) 26 | dst = cv2.resize(meta.img, (neww, newh), interpolation=cv2.INTER_AREA) 27 | 28 | # adjust meta data 29 | adjust_joint_list = [] 30 | for joint in meta.joint_list: 31 | adjust_joint = [] 32 | for point in joint: 33 | if point[0] < -100 or point[1] < -100: 34 | adjust_joint.append((-1000, -1000)) 35 | continue 36 | # if point[0] <= 0 or point[1] <= 0 or int(point[0] * scalew + 0.5) > neww or int( 37 | # point[1] * scaleh + 0.5) > newh: 38 | # 
adjust_joint.append((-1, -1)) 39 | # continue 40 | adjust_joint.append((int(point[0] * scalew + 0.5), int(point[1] * scaleh + 0.5))) 41 | adjust_joint_list.append(adjust_joint) 42 | 43 | meta.joint_list = adjust_joint_list 44 | meta.width, meta.height = neww, newh 45 | meta.img = dst 46 | return meta 47 | 48 | 49 | def pose_resize_shortestedge_fixed(meta): 50 | ratio_w = _network_w / meta.width 51 | ratio_h = _network_h / meta.height 52 | ratio = max(ratio_w, ratio_h) 53 | return pose_resize_shortestedge(meta, int(min(meta.width * ratio + 0.5, meta.height * ratio + 0.5))) 54 | 55 | 56 | def pose_resize_shortestedge_random(meta): 57 | target_size = int(min(_network_w, _network_h) * random.uniform(0.7, 1.5)) 58 | return pose_resize_shortestedge(meta, target_size) 59 | 60 | 61 | def pose_resize_shortestedge(meta, target_size): 62 | global _network_w, _network_h 63 | img = meta.img 64 | 65 | # adjust image 66 | scale = target_size * 1.0 / min(meta.height, meta.width) 67 | if meta.height < meta.width: 68 | newh, neww = target_size, int(scale * meta.width + 0.5) 69 | else: 70 | newh, neww = int(scale * meta.height + 0.5), target_size 71 | 72 | dst = cv2.resize(img, (neww, newh), interpolation=cv2.INTER_AREA) 73 | 74 | pw = ph = 0 75 | if neww < _network_w or newh < _network_h: 76 | pw = max(0, (_network_w - neww) // 2) 77 | ph = max(0, (_network_h - newh) // 2) 78 | mw = (_network_w - neww) % 2 79 | mh = (_network_h - newh) % 2 80 | cr = random.randint(0, 4) 81 | if cr == 0: 82 | color = 0 83 | elif cr == 1: 84 | color = 255 85 | else: 86 | color = 255 // 2 87 | dst = cv2.copyMakeBorder(dst, ph, ph+mh, pw, pw+mw, cv2.BORDER_CONSTANT, value=(color, color, color)) 88 | 89 | # adjust meta data 90 | adjust_joint_list = [] 91 | for joint in meta.joint_list: 92 | adjust_joint = [] 93 | for point in joint: 94 | if point[0] < -100 or point[1] < -100: 95 | adjust_joint.append((-1000, -1000)) 96 | continue 97 | # if point[0] <= 0 or point[1] <= 0 or int(point[0]*scale+0.5) > neww or int(point[1]*scale+0.5) > newh: 98 | # adjust_joint.append((-1, -1)) 99 | # continue 100 | adjust_joint.append((int(point[0]*scale+0.5) + pw, int(point[1]*scale+0.5) + ph)) 101 | adjust_joint_list.append(adjust_joint) 102 | 103 | meta.joint_list = adjust_joint_list 104 | meta.width, meta.height = neww + pw * 2, newh + ph * 2 105 | meta.img = dst 106 | return meta 107 | 108 | 109 | def pose_crop_center(meta): 110 | global _network_w, _network_h 111 | target_size = (_network_w, _network_h) 112 | x = (meta.width - target_size[0]) // 2 if meta.width > target_size[0] else 0 113 | y = (meta.height - target_size[1]) // 2 if meta.height > target_size[1] else 0 114 | 115 | return pose_crop(meta, x, y, target_size[0], target_size[1]) 116 | 117 | 118 | def pose_crop_random(meta): 119 | global _network_w, _network_h 120 | target_size = (_network_w, _network_h) 121 | 122 | for _ in range(50): 123 | x = random.randrange(0, meta.width - target_size[0]) if meta.width > target_size[0] else 0 124 | y = random.randrange(0, meta.height - target_size[1]) if meta.height > target_size[1] else 0 125 | 126 | # check whether any face is inside the box to generate a reasonably-balanced datasets 127 | for joint in meta.joint_list: 128 | if x <= joint[CocoPart.Nose.value][0] < x + target_size[0] and y <= joint[CocoPart.Nose.value][1] < y + target_size[1]: 129 | break 130 | 131 | return pose_crop(meta, x, y, target_size[0], target_size[1]) 132 | 133 | 134 | def pose_crop(meta, x, y, w, h): 135 | # adjust image 136 | target_size = (w, h) 137 | 138 | img = 
meta.img 139 | resized = img[y:y+target_size[1], x:x+target_size[0], :] 140 | 141 | # adjust meta data 142 | adjust_joint_list = [] 143 | for joint in meta.joint_list: 144 | adjust_joint = [] 145 | for point in joint: 146 | if point[0] < -100 or point[1] < -100: 147 | adjust_joint.append((-1000, -1000)) 148 | continue 149 | # if point[0] <= 0 or point[1] <= 0: 150 | # adjust_joint.append((-1000, -1000)) 151 | # continue 152 | new_x, new_y = point[0] - x, point[1] - y 153 | # if new_x <= 0 or new_y <= 0 or new_x > target_size[0] or new_y > target_size[1]: 154 | # adjust_joint.append((-1, -1)) 155 | # continue 156 | adjust_joint.append((new_x, new_y)) 157 | adjust_joint_list.append(adjust_joint) 158 | 159 | meta.joint_list = adjust_joint_list 160 | meta.width, meta.height = target_size 161 | meta.img = resized 162 | return meta 163 | 164 | 165 | def pose_flip(meta): 166 | r = random.uniform(0, 1.0) 167 | if r > 0.5: 168 | return meta 169 | 170 | img = meta.img 171 | img = cv2.flip(img, 1) 172 | 173 | # flip meta 174 | flip_list = [CocoPart.Nose, CocoPart.Neck, CocoPart.LShoulder, CocoPart.LElbow, CocoPart.LWrist, CocoPart.RShoulder, CocoPart.RElbow, CocoPart.RWrist, 175 | CocoPart.LHip, CocoPart.LKnee, CocoPart.LAnkle, CocoPart.RHip, CocoPart.RKnee, CocoPart.RAnkle, 176 | CocoPart.LEye, CocoPart.REye, CocoPart.LEar, CocoPart.REar, CocoPart.Background] 177 | adjust_joint_list = [] 178 | for joint in meta.joint_list: 179 | adjust_joint = [] 180 | for cocopart in flip_list: 181 | point = joint[cocopart.value] 182 | if point[0] < -100 or point[1] < -100: 183 | adjust_joint.append((-1000, -1000)) 184 | continue 185 | # if point[0] <= 0 or point[1] <= 0: 186 | # adjust_joint.append((-1, -1)) 187 | # continue 188 | adjust_joint.append((meta.width - point[0], point[1])) 189 | adjust_joint_list.append(adjust_joint) 190 | 191 | meta.joint_list = adjust_joint_list 192 | 193 | meta.img = img 194 | return meta 195 | 196 | 197 | def pose_rotation(meta): 198 | deg = random.uniform(-40.0, 40.0) 199 | img = meta.img 200 | 201 | center = (img.shape[1] * 0.5, img.shape[0] * 0.5) 202 | rot_m = cv2.getRotationMatrix2D((center[0] - 0.5, center[1] - 0.5), deg, 1) 203 | ret = cv2.warpAffine(img, rot_m, img.shape[1::-1], flags=cv2.INTER_AREA, borderMode=cv2.BORDER_CONSTANT) 204 | if img.ndim == 3 and ret.ndim == 2: 205 | ret = ret[:, :, np.newaxis] 206 | neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg) 207 | neww = min(neww, ret.shape[1]) 208 | newh = min(newh, ret.shape[0]) 209 | newx = int(center[0] - neww * 0.5) 210 | newy = int(center[1] - newh * 0.5) 211 | # print(ret.shape, deg, newx, newy, neww, newh) 212 | img = ret[newy:newy + newh, newx:newx + neww] 213 | 214 | # adjust meta data 215 | adjust_joint_list = [] 216 | for joint in meta.joint_list: 217 | adjust_joint = [] 218 | for point in joint: 219 | if point[0] < -100 or point[1] < -100: 220 | adjust_joint.append((-1000, -1000)) 221 | continue 222 | # if point[0] <= 0 or point[1] <= 0: 223 | # adjust_joint.append((-1, -1)) 224 | # continue 225 | x, y = _rotate_coord((meta.width, meta.height), (newx, newy), point, deg) 226 | adjust_joint.append((x, y)) 227 | adjust_joint_list.append(adjust_joint) 228 | 229 | meta.joint_list = adjust_joint_list 230 | meta.width, meta.height = neww, newh 231 | meta.img = img 232 | 233 | return meta 234 | 235 | 236 | def _rotate_coord(shape, newxy, point, angle): 237 | angle = -1 * angle / 180.0 * math.pi 238 | 239 | ox, oy = shape 240 | px, py = point 241 | 242 | ox /= 2 243 | oy /= 2 244 
| 245 | qx = math.cos(angle) * (px - ox) - math.sin(angle) * (py - oy) 246 | qy = math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy) 247 | 248 | new_x, new_y = newxy 249 | 250 | qx += ox - new_x 251 | qy += oy - new_y 252 | 253 | return int(qx + 0.5), int(qy + 0.5) 254 | 255 | 256 | def pose_to_img(meta_l): 257 | global _network_w, _network_h 258 | return [(2.0 / 255.0) * meta_l[0].img - 1.0, 259 | meta_l[0].get_heatmap(target_size=(_network_w // 8, _network_h // 8)), 260 | meta_l[0].get_vectormap(target_size=(_network_w // 8, _network_h // 8))] 261 | -------------------------------------------------------------------------------- /convert/tensorToKeras.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import sys,os 3 | sys.path.append('../') 4 | import tensorflow as tf 5 | import numpy as np 6 | import argparse 7 | import h5py 8 | from network_mobilenet import MobilenetNetwork 9 | from keras.preprocessing.image import load_img, img_to_array 10 | from PIL import Image 11 | from keras.applications.mobilenet import DepthwiseConv2D 12 | from keras import backend as K 13 | from keras.models import Model 14 | from keras.layers import Input, Conv2D, MaxPooling2D, concatenate, BatchNormalization, Activation 15 | from keras.regularizers import l2 16 | 17 | config = tf.ConfigProto() 18 | 19 | parser = argparse.ArgumentParser(description='Tensorflow Openpose Inference') 20 | # parser.add_argument('--imgpath', type=str, default='./images/person1.jpg') 21 | parser.add_argument('--input-width', type=int, default=368) 22 | parser.add_argument('--input-height', type=int, default=368) 23 | args = parser.parse_args() 24 | 25 | input_node = tf.placeholder(tf.float32, shape=(1, args.input_height, args.input_width, 3), name='image') 26 | 27 | global_layers = [] 28 | 29 | def get_variables(model_path, height , width): 30 | input_node = tf.placeholder(tf.float32, shape=(1, height, width, 3), name='image') 31 | 32 | net = MobilenetNetwork({'image': input_node}, trainable=False, conv_width=0.75, conv_width2=0.50) 33 | saver = tf.train.Saver(max_to_keep=100) 34 | config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 35 | with tf.Session(config=config) as sess: 36 | 37 | saver.restore(sess, model_path) 38 | variables = tf.global_variables() 39 | variables = [(v.name, v.eval(session=sess).copy(order='C')) for v in variables] 40 | return variables 41 | 42 | # Load Trained Weights 43 | tf_model_path = './models/model-388003' # includes model-388003.index, model-388003.meta, model-388003.data-00000-of-00001 44 | # tf_model_path = './models/model_final-365221' # includes model-388003.index, model-388003.meta, model-388003.data-00000-of-00001 45 | variables = get_variables(tf_model_path, args.input_height, args.input_width) 46 | 47 | def getTupleLayer(prefix,name): 48 | 49 | if name == "Conv2d_0": 50 | conv2d = (name, prefix + "/" + name + "/weights:0") 51 | waits = [] 52 | waits.append(prefix + "/" + name + "/BatchNorm/beta:0") 53 | waits.append(prefix + "/" + name + "/BatchNorm/moving_mean:0") 54 | waits.append(prefix + "/" + name + "/BatchNorm/moving_variance:0") 55 | wait = (name + "_bn", waits) 56 | layers = [conv2d,wait] 57 | else: 58 | sepConv2d = (name + "_depthwise", prefix + "/" + name + "_depthwise/depthwise_weights:0") 59 | conv2d = (name + "_pointwise", prefix + "/" + name + "_pointwise/weights:0") 60 | 61 | waits = [] 62 | waits.append(prefix + "/" + name + "_pointwise/BatchNorm/beta:0") 63 | waits.append(prefix + "/" + 
name + "_pointwise/BatchNorm/moving_mean:0") 64 | waits.append(prefix + "/" + name + "_pointwise/BatchNorm/moving_variance:0") 65 | wait = (name + "_bn" , waits) 66 | 67 | layers = [sepConv2d,conv2d,wait] 68 | return layers 69 | 70 | def setLayer(model,layers): 71 | global variables 72 | vnames = [name for name, v in variables] 73 | 74 | for ln in layers: 75 | layer = model.get_layer(name=ln[0]) 76 | layer_weights = layer.get_weights() 77 | print("ln: ", ln[1]) 78 | wn = [] 79 | if isinstance(ln[1],list): 80 | # batch_norm 81 | # gamma 82 | wn.append(layer_weights[0]) 83 | 84 | # 1.beta, 2.moving_mean, 3.variance 85 | for i, name in enumerate(ln[1]): 86 | ix = vnames.index(name) 87 | v = variables[ix][1] 88 | wn.append(v) 89 | else: 90 | ix = vnames.index(ln[1]) 91 | v = variables[ix][1] 92 | wn.append(v) 93 | 94 | if len(layer_weights) > 1: 95 | for n in range(1,len(layer_weights)): 96 | # pointwise 97 | wn.append(layer_weights[n]) 98 | 99 | layer.set_weights(wn) 100 | 101 | return model 102 | 103 | def separable_conv(x, c_o, kernel,stride, name, relu=True): 104 | global global_layers 105 | global_layers.append(name) 106 | 107 | x = DepthwiseConv2D(kernel 108 | , strides=stride 109 | , padding='same' 110 | , use_bias=False 111 | , depthwise_regularizer=l2(0.00004) 112 | , name=name+'_depthwise' 113 | )(x) 114 | 115 | x = Conv2D(c_o,(1,1) 116 | ,strides=1 117 | ,use_bias=False 118 | ,padding='same' 119 | ,kernel_regularizer=l2(0.004) 120 | ,name=name+"_pointwise" 121 | )(x) 122 | 123 | x = BatchNormalization(scale=True, name=name+'_bn')(x,training=False) 124 | if relu: 125 | x = Activation('relu', name=name+'_relu')(x) 126 | 127 | return x 128 | 129 | def get_model(sess, height, width): 130 | 131 | init = tf.global_variables_initializer() 132 | sess.run(init) 133 | 134 | net = MobilenetNetwork({'image': input_node} 135 | , trainable=False, conv_width=0.75, conv_width2=0.50) 136 | 137 | K.set_session(sess) 138 | conv_width=0.75 139 | conv_width2=0.50 140 | min_depth = 8 141 | 142 | depth = lambda d: max(int(d * conv_width), min_depth) 143 | depth2 = lambda d: max(int(d * conv_width2), min_depth) 144 | 145 | image = Input(shape=(width, height, 3),name="image") 146 | 147 | x = Conv2D(depth(32),(3,3) 148 | , strides=2 149 | , use_bias=False 150 | , name="Conv2d_0" 151 | , trainable = False 152 | , padding='same' 153 | , kernel_regularizer=l2(0.04) 154 | )(image) 155 | 156 | x = BatchNormalization(scale=True, name='Conv2d_0_bn')(x,training=False) 157 | x = Activation('relu', name='Conv2d_0_relu')(x) 158 | 159 | x = separable_conv(x,depth(64),(3,3),1,name='Conv2d_1') 160 | x = separable_conv(x,depth(128),(3,3),2,name='Conv2d_2') 161 | o3 = separable_conv(x,depth(128),(3,3),1,name='Conv2d_3') 162 | x = separable_conv(o3,depth(256),(3,3),2,name='Conv2d_4') 163 | x = separable_conv(x,depth(256),(3,3),1,name='Conv2d_5') 164 | x = separable_conv(x,depth(512),(3,3),1,name='Conv2d_6') 165 | o7 = separable_conv(x,depth(512),(3,3),1,name='Conv2d_7') 166 | x = separable_conv(o7,depth(512),(3,3),1,name='Conv2d_8') 167 | x = separable_conv(x,depth(512),(3,3),1,name='Conv2d_9') 168 | x = separable_conv(x,depth(512),(3,3),1,name='Conv2d_10') 169 | o11 = separable_conv(x,depth(512),(3,3),1,name='Conv2d_11') 170 | 171 | o3_pool = MaxPooling2D((2, 2),(2, 2),padding='same')(o3) 172 | feat_concat = concatenate([o3_pool,o7,o11], axis=3) 173 | 174 | prefix = 'MConv_Stage1' 175 | 176 | r1 = separable_conv(feat_concat,depth2(128),(3,3),1,name=prefix + '_L1_1') 177 | r1 = 
separable_conv(r1,depth2(128),(3,3),1,name=prefix + '_L1_2') 178 | r1 = separable_conv(r1,depth2(128),(3,3),1,name=prefix + '_L1_3') 179 | r1 = separable_conv(r1,depth2(512),(1,1),1,name=prefix + '_L1_4') 180 | r1 = separable_conv(r1,38,(1,1),1,relu=False,name=prefix + '_L1_5') 181 | 182 | # concat = Input(shape=(46, 46, 864)) 183 | r2 = separable_conv(feat_concat,depth2(128),(3,3),1,name=prefix + '_L2_1') 184 | r2 = separable_conv(r2,depth2(128),(3,3),1,name=prefix + '_L2_2') 185 | r2 = separable_conv(r2,depth2(128),(3,3),1,name=prefix + '_L2_3') 186 | r2 = separable_conv(r2,depth2(512),(1,1),1,name=prefix + '_L2_4') 187 | r2 = separable_conv(r2,19,(1,1),1,relu=False,name=prefix + '_L2_5') 188 | 189 | for stage_id in range(5): 190 | prefix = 'MConv_Stage%d' % (stage_id + 2) 191 | cc = concatenate([r1,r2,feat_concat], axis=3) 192 | 193 | r1 = separable_conv(cc,depth2(128),(3,3),1,name=prefix + '_L1_1') 194 | r1 = separable_conv(r1,depth2(128),(3,3),1,name=prefix + '_L1_2') 195 | r1 = separable_conv(r1,depth2(128),(3,3),1,name=prefix + '_L1_3') 196 | r1 = separable_conv(r1,depth2(128),(1,1),1,name=prefix + '_L1_4') 197 | r1 = separable_conv(r1,38,(1,1),1,relu=False,name=prefix + '_L1_5') 198 | 199 | r2 = separable_conv(cc,depth2(128),(3,3),1,name=prefix + '_L2_1') 200 | r2 = separable_conv(r2,depth2(128),(3,3),1,name=prefix + '_L2_2') 201 | r2 = separable_conv(r2,depth2(128),(3,3),1,name=prefix + '_L2_3') 202 | r2 = separable_conv(r2,depth2(128),(1,1),1,name=prefix + '_L2_4') 203 | r2 = separable_conv(r2,19,(1,1),1,relu=False,name=prefix + '_L2_5') 204 | 205 | out = concatenate([r2, r1],axis=3) 206 | print(out) 207 | 208 | model = Model(image, out) 209 | 210 | layers = getTupleLayer("MobilenetV1","Conv2d_0") 211 | model = setLayer(model,layers) 212 | 213 | for (i, layer) in enumerate(global_layers): 214 | # idx = i + 2 215 | n = layer.split("_") 216 | n.pop() 217 | 218 | prefix = "" 219 | if n[0] == "Conv2d": 220 | prefix = "MobilenetV1" 221 | if n[0] == "MConv": 222 | prefix = "Openpose" 223 | 224 | if prefix != "": 225 | 226 | layers = getTupleLayer(prefix,layer) 227 | model = setLayer(model,layers) 228 | 229 | if not os.path.exists("output"): 230 | os.mkdir("output") 231 | model.save('output/predict.hd5') 232 | 233 | # plot_model(model, to_file='model_shape.png', show_shapes=True) 234 | 235 | # img = load_img(args.imgpath, target_size=(args.input_width, args.input_height)) 236 | # img = np.expand_dims(img, axis=0) 237 | # print(img.shape) 238 | # prediction = model.predict(img) 239 | # prediction = prediction[0] 240 | # print("#output") 241 | # print(prediction.shape) 242 | # print(prediction[0:1, 0:1, :]) 243 | # print(np.mean(prediction)) 244 | 245 | # np.save('output/prediction.npy', prediction, allow_pickle=False) 246 | 247 | return model 248 | 249 | def run(): 250 | with tf.Session(config=config) as sess: 251 | net = get_model(sess, args.input_height, args.input_width) 252 | 253 | if __name__ == "__main__": 254 | run() 255 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /network_base.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow.contrib.slim as slim 5 | 6 | import common 7 | 8 | DEFAULT_PADDING = 'SAME' 9 | 10 | 11 | def layer(op): 12 | ''' 13 | Decorator for composable network layers. 14 | ''' 15 | 16 | def layer_decorated(self, *args, **kwargs): 17 | # Automatically set a name if not provided. 18 | name = kwargs.setdefault('name', self.get_unique_name(op.__name__)) 19 | # Figure out the layer inputs. 20 | if len(self.terminals) == 0: 21 | raise RuntimeError('No input variables found for layer %s.' % name) 22 | elif len(self.terminals) == 1: 23 | layer_input = self.terminals[0] 24 | else: 25 | layer_input = list(self.terminals) 26 | # Perform the operation and get the output. 27 | layer_output = op(self, layer_input, *args, **kwargs) 28 | # Add to layer LUT. 29 | self.layers[name] = layer_output 30 | # This output is now the input for the next layer. 31 | self.feed(layer_output) 32 | # Return self for chained calls. 33 | return self 34 | 35 | return layer_decorated 36 | 37 | 38 | class BaseNetwork(object): 39 | 40 | def __init__(self, inputs, trainable=True): 41 | # The input nodes for this network 42 | self.inputs = inputs 43 | # The current list of terminal nodes 44 | self.terminals = [] 45 | # Mapping from layer names to layers 46 | self.layers = dict(inputs) 47 | # If true, the resulting variables are set as trainable 48 | self.trainable = trainable 49 | # Switch variable for dropout 50 | self.use_dropout = tf.placeholder_with_default(tf.constant(1.0), 51 | shape=[], 52 | name='use_dropout') 53 | self.setup() 54 | 55 | def setup(self): 56 | '''Construct the network. ''' 57 | raise NotImplementedError('Must be implemented by the subclass.') 58 | 59 | def load(self, data_path, session, ignore_missing=False): 60 | ''' 61 | Load network weights. 62 | data_path: The path to the numpy-serialized network weights 63 | session: The current TensorFlow session 64 | ignore_missing: If true, serialized weights for missing layers are ignored. 65 | ''' 66 | data_dict = np.load(data_path, encoding='bytes').item() 67 | for op_name in data_dict: 68 | if isinstance(data_dict[op_name], np.ndarray): 69 | if 'RMSProp' in op_name: 70 | continue 71 | with tf.variable_scope('', reuse=True): 72 | var = tf.get_variable(op_name.replace(':0', '')) 73 | try: 74 | session.run(var.assign(data_dict[op_name])) 75 | except Exception as e: 76 | print(op_name) 77 | print(e) 78 | sys.exit(-1) 79 | else: 80 | with tf.variable_scope(op_name, reuse=True): 81 | for param_name, data in data_dict[op_name].items(): 82 | try: 83 | var = tf.get_variable(param_name.decode("utf-8")) 84 | session.run(var.assign(data)) 85 | except ValueError as e: 86 | print(e) 87 | if not ignore_missing: 88 | raise 89 | 90 | def feed(self, *args): 91 | '''Set the input(s) for the next operation by replacing the terminal nodes. 
92 | The arguments can be either layer names or the actual layers. 93 | ''' 94 | assert len(args) != 0 95 | self.terminals = [] 96 | for fed_layer in args: 97 | try: 98 | is_str = isinstance(fed_layer, basestring) 99 | except NameError: 100 | is_str = isinstance(fed_layer, str) 101 | if is_str: 102 | try: 103 | fed_layer = self.layers[fed_layer] 104 | except KeyError: 105 | raise KeyError('Unknown layer name fed: %s' % fed_layer) 106 | self.terminals.append(fed_layer) 107 | return self 108 | 109 | def get_output(self, name=None): 110 | '''Returns the current network output.''' 111 | if not name: 112 | return self.terminals[-1] 113 | else: 114 | return self.layers[name] 115 | 116 | def get_tensor(self, name): 117 | return self.get_output(name) 118 | 119 | def get_unique_name(self, prefix): 120 | '''Returns an index-suffixed unique name for the given prefix. 121 | This is used for auto-generating layer names based on the type-prefix. 122 | ''' 123 | ident = sum(t.startswith(prefix) for t, _ in self.layers.items()) + 1 124 | return '%s_%d' % (prefix, ident) 125 | 126 | def make_var(self, name, shape, trainable=True): 127 | '''Creates a new TensorFlow variable.''' 128 | return tf.get_variable(name, shape, trainable=self.trainable & trainable, initializer=tf.contrib.layers.xavier_initializer()) 129 | 130 | def validate_padding(self, padding): 131 | '''Verifies that the padding is one of the supported ones.''' 132 | assert padding in ('SAME', 'VALID') 133 | 134 | @layer 135 | def separable_conv(self, input, k_h, k_w, c_o, stride, name, relu=True): 136 | with slim.arg_scope([slim.batch_norm], fused=common.batchnorm_fused, is_training=self.trainable): 137 | output = slim.separable_convolution2d(input, 138 | num_outputs=None, 139 | stride=stride, 140 | trainable=self.trainable, 141 | depth_multiplier=1.0, 142 | kernel_size=[k_h, k_w], 143 | activation_fn=None, 144 | weights_initializer=tf.contrib.layers.xavier_initializer(), 145 | # weights_initializer=tf.truncated_normal_initializer(stddev=0.09), 146 | weights_regularizer=tf.contrib.layers.l2_regularizer(0.00004), 147 | biases_initializer=None, 148 | padding=DEFAULT_PADDING, 149 | scope=name + '_depthwise') 150 | 151 | output = slim.convolution2d(output, 152 | c_o, 153 | stride=1, 154 | kernel_size=[1, 1], 155 | activation_fn=tf.nn.relu if relu else None, 156 | weights_initializer=tf.contrib.layers.xavier_initializer(), 157 | # weights_initializer=tf.truncated_normal_initializer(stddev=0.09), 158 | biases_initializer=slim.init_ops.zeros_initializer(), 159 | normalizer_fn=slim.batch_norm, 160 | trainable=self.trainable, 161 | weights_regularizer=tf.contrib.layers.l2_regularizer(common.regularizer_dsconv), 162 | # weights_regularizer=None, 163 | scope=name + '_pointwise') 164 | 165 | return output 166 | 167 | @layer 168 | def convb(self, input, k_h, k_w, c_o, stride, name): 169 | with slim.arg_scope([slim.batch_norm], fused=common.batchnorm_fused, is_training=self.trainable): 170 | output = slim.convolution2d(input, c_o, kernel_size=[k_h, k_w], 171 | stride=stride, 172 | normalizer_fn=slim.batch_norm, 173 | weights_regularizer=tf.contrib.layers.l2_regularizer(common.regularizer_conv), 174 | scope=name) 175 | return output 176 | 177 | @layer 178 | def conv(self, 179 | input, 180 | k_h, 181 | k_w, 182 | c_o, 183 | s_h, 184 | s_w, 185 | name, 186 | relu=True, 187 | padding=DEFAULT_PADDING, 188 | group=1, 189 | trainable=True, 190 | biased=True): 191 | # Verify that the padding is acceptable 192 | self.validate_padding(padding) 193 | # Get the number of 
channels in the input 194 | c_i = int(input.get_shape()[-1]) 195 | # Verify that the grouping parameter is valid 196 | assert c_i % group == 0 197 | assert c_o % group == 0 198 | # Convolution for a given input and kernel 199 | convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding) 200 | with tf.variable_scope(name) as scope: 201 | kernel = self.make_var('weights', shape=[k_h, k_w, c_i / group, c_o], trainable=self.trainable & trainable) 202 | if group == 1: 203 | # This is the common-case. Convolve the input without any further complications. 204 | output = convolve(input, kernel) 205 | else: 206 | # Split the input into groups and then convolve each of them independently 207 | input_groups = tf.split(3, group, input) 208 | kernel_groups = tf.split(3, group, kernel) 209 | output_groups = [convolve(i, k) for i, k in zip(input_groups, kernel_groups)] 210 | # Concatenate the groups 211 | output = tf.concat(3, output_groups) 212 | # Add the biases 213 | if biased: 214 | biases = self.make_var('biases', [c_o], trainable=self.trainable & trainable) 215 | output = tf.nn.bias_add(output, biases) 216 | 217 | if relu: 218 | # ReLU non-linearity 219 | output = tf.nn.relu(output, name=scope.name) 220 | return output 221 | 222 | @layer 223 | def relu(self, input, name): 224 | return tf.nn.relu(input, name=name) 225 | 226 | @layer 227 | def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING): 228 | self.validate_padding(padding) 229 | return tf.nn.max_pool(input, 230 | ksize=[1, k_h, k_w, 1], 231 | strides=[1, s_h, s_w, 1], 232 | padding=padding, 233 | name=name) 234 | 235 | @layer 236 | def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING): 237 | self.validate_padding(padding) 238 | return tf.nn.avg_pool(input, 239 | ksize=[1, k_h, k_w, 1], 240 | strides=[1, s_h, s_w, 1], 241 | padding=padding, 242 | name=name) 243 | 244 | @layer 245 | def lrn(self, input, radius, alpha, beta, name, bias=1.0): 246 | return tf.nn.local_response_normalization(input, 247 | depth_radius=radius, 248 | alpha=alpha, 249 | beta=beta, 250 | bias=bias, 251 | name=name) 252 | 253 | @layer 254 | def concat(self, inputs, axis, name): 255 | return tf.concat(axis=axis, values=inputs, name=name) 256 | 257 | @layer 258 | def add(self, inputs, name): 259 | return tf.add_n(inputs, name=name) 260 | 261 | @layer 262 | def fc(self, input, num_out, name, relu=True): 263 | with tf.variable_scope(name) as scope: 264 | input_shape = input.get_shape() 265 | if input_shape.ndims == 4: 266 | # The input is spatial. Vectorize it first. 267 | dim = 1 268 | for d in input_shape[1:].as_list(): 269 | dim *= d 270 | feed_in = tf.reshape(input, [-1, dim]) 271 | else: 272 | feed_in, dim = (input, input_shape[-1].value) 273 | weights = self.make_var('weights', shape=[dim, num_out]) 274 | biases = self.make_var('biases', [num_out]) 275 | op = tf.nn.relu_layer if relu else tf.nn.xw_plus_b 276 | fc = op(feed_in, weights, biases, name=scope.name) 277 | return fc 278 | 279 | @layer 280 | def softmax(self, input, name): 281 | input_shape = map(lambda v: v.value, input.get_shape()) 282 | if len(input_shape) > 2: 283 | # For certain models (like NiN), the singleton spatial dimensions 284 | # need to be explicitly squeezed, since they're not broadcast-able 285 | # in TensorFlow's NHWC ordering (unlike Caffe's NCHW). 
286 | if input_shape[1] == 1 and input_shape[2] == 1: 287 | input = tf.squeeze(input, squeeze_dims=[1, 2]) 288 | else: 289 | raise ValueError('Rank 2 tensor input expected for softmax!') 290 | return tf.nn.softmax(input, name=name) 291 | 292 | @layer 293 | def batch_normalization(self, input, name, scale_offset=True, relu=False): 294 | # NOTE: Currently, only inference is supported 295 | with tf.variable_scope(name) as scope: 296 | shape = [input.get_shape()[-1]] 297 | if scale_offset: 298 | scale = self.make_var('scale', shape=shape) 299 | offset = self.make_var('offset', shape=shape) 300 | else: 301 | scale, offset = (None, None) 302 | output = tf.nn.batch_normalization( 303 | input, 304 | mean=self.make_var('mean', shape=shape), 305 | variance=self.make_var('variance', shape=shape), 306 | offset=offset, 307 | scale=scale, 308 | # TODO: This is the default Caffe batch norm eps 309 | # Get the actual eps from parameters 310 | variance_epsilon=1e-5, 311 | name=name) 312 | if relu: 313 | output = tf.nn.relu(output) 314 | return output 315 | 316 | @layer 317 | def dropout(self, input, keep_prob, name): 318 | keep = 1 - self.use_dropout + (self.use_dropout * keep_prob) 319 | return tf.nn.dropout(input, keep, name=name) 320 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import time 5 | import datetime 6 | 7 | import cv2 8 | import numpy as np 9 | import tensorflow as tf 10 | from tensorflow.python.client import timeline 11 | 12 | from common import read_imgfile 13 | from network_cmu import CmuNetwork 14 | from network_mobilenet import MobilenetNetwork 15 | from networks import get_network 16 | from pose_augment import set_network_input_wh 17 | from pose_dataset import get_dataflow_batch, DataFlowToQueue, CocoPoseLMDB 18 | from tensorpack.dataflow.remote import send_dataflow_zmq, RemoteDataZMQ 19 | 20 | logging.basicConfig(level=logging.DEBUG, format='[lmdb_dataset] %(asctime)s %(levelname)s %(message)s') 21 | 22 | 23 | if __name__ == '__main__': 24 | parser = argparse.ArgumentParser(description='Training codes for Openpose using Tensorflow') 25 | parser.add_argument('--model', default='mobilenet', help='model name') 26 | parser.add_argument('--datapath', type=str, default='/data/public/rw/coco-pose-estimation-lmdb/') 27 | parser.add_argument('--batchsize', type=int, default=10) 28 | parser.add_argument('--gpus', type=int, default=1) 29 | parser.add_argument('--max-epoch', type=int, default=60) 30 | parser.add_argument('--lr', type=str, default='0.0001') 31 | parser.add_argument('--modelpath', type=str, default='/data/private/tf-openpose-mobilenet_1.0/') 32 | parser.add_argument('--logpath', type=str, default='/data/private/tf-openpose-log/') 33 | parser.add_argument('--checkpoint', type=str, default='') 34 | parser.add_argument('--tag', type=str, default='') 35 | parser.add_argument('--remote-data', type=str, default='', help='eg. 
tcp://0.0.0.0:1027') 36 | 37 | parser.add_argument('--input-width', type=int, default=368) 38 | parser.add_argument('--input-height', type=int, default=368) 39 | args = parser.parse_args() 40 | 41 | if args.gpus <= 0: 42 | raise Exception('gpus <= 0') 43 | 44 | # define input placeholder 45 | set_network_input_wh(args.input_width, args.input_height) 46 | output_w = args.input_width // 8 47 | output_h = args.input_height // 8 48 | 49 | with tf.device(tf.DeviceSpec(device_type="GPU", device_index=0)): 50 | input_node = tf.placeholder(tf.float32, shape=(args.batchsize, args.input_height, args.input_width, 3), name='image') 51 | vectmap_node = tf.placeholder(tf.float32, shape=(args.batchsize, output_h, output_w, 38), name='vectmap') 52 | heatmap_node = tf.placeholder(tf.float32, shape=(args.batchsize, output_h, output_w, 19), name='heatmap') 53 | 54 | # prepare data 55 | if not args.remote_data: 56 | df = get_dataflow_batch(args.datapath, True, args.batchsize) 57 | else: 58 | df = RemoteDataZMQ(args.remote_data, hwm=5) 59 | enqueuer = DataFlowToQueue(df, [input_node, heatmap_node, vectmap_node], queue_size=100) 60 | q_inp, q_heat, q_vect = enqueuer.dequeue() 61 | 62 | df_valid = get_dataflow_batch(args.datapath, False, args.batchsize) 63 | df_valid.reset_state() 64 | validation_cache = [] 65 | for images_test, heatmaps, vectmaps in df_valid.get_data(): 66 | validation_cache.append((images_test, heatmaps, vectmaps)) 67 | 68 | val_image = read_imgfile('./images/p1.jpg', args.input_width, args.input_height) 69 | val_image2 = read_imgfile('./images/p2.jpg', args.input_width, args.input_height) 70 | val_image3 = read_imgfile('./images/p3.jpg', args.input_width, args.input_height) 71 | 72 | # define model for multi-gpu 73 | q_inp_split = tf.split(q_inp, args.gpus) 74 | output_vectmap = [] 75 | output_heatmap = [] 76 | vectmap_losses = [] 77 | heatmap_losses = [] 78 | 79 | for gpu_id in range(args.gpus): 80 | with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)): 81 | with tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)): 82 | net, pretrain_path, last_layer = get_network(args.model, q_inp_split[gpu_id]) 83 | vect, heat = net.loss_last() 84 | output_vectmap.append(vect) 85 | output_heatmap.append(heat) 86 | 87 | l1s, l2s = net.loss_l1_l2() 88 | 89 | for idx, (l1, l2) in enumerate(zip(l1s, l2s)): 90 | if gpu_id == 0: 91 | vectmap_losses.append([]) 92 | heatmap_losses.append([]) 93 | vectmap_losses[idx].append(l1) 94 | heatmap_losses[idx].append(l2) 95 | 96 | with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)): 97 | # define loss 98 | losses = [] 99 | for l1_idx, l1 in enumerate(vectmap_losses): 100 | l1_concat = tf.concat(l1, axis=0) 101 | loss = tf.nn.l2_loss(l1_concat - q_vect, name='loss_l1_stage%d' % l1_idx) 102 | losses.append(loss) 103 | for l2_idx, l2 in enumerate(heatmap_losses): 104 | l2_concat = tf.concat(l2, axis=0) 105 | loss = tf.nn.l2_loss(l2_concat - q_heat, name='loss_l2_stage%d' % l2_idx) 106 | losses.append(loss) 107 | 108 | output_vectmap = tf.concat(output_vectmap, axis=0) 109 | output_heatmap = tf.concat(output_heatmap, axis=0) 110 | total_loss = tf.reduce_mean(losses) 111 | total_loss_ll_paf = tf.reduce_mean(tf.nn.l2_loss(output_vectmap - q_vect)) 112 | total_loss_ll_heat = tf.reduce_mean(tf.nn.l2_loss(output_heatmap - q_heat)) 113 | total_ll_loss = tf.reduce_mean([total_loss_ll_paf, total_loss_ll_heat]) 114 | 115 | # define optimizer 116 | step_per_epoch = 121745 // args.batchsize 117 | global_step = tf.Variable(0, trainable=False) 
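# --- Editor's note (illustration, not part of train.py) ---------------------
# The next block builds the learning-rate schedule: a single --lr value uses
# exponential decay (x0.8 every 50k steps, staircase), while a comma-separated
# list is meant to give a piecewise-constant schedule that switches every
# 5 epochs. As printed, the boundary comprehension tries to unpack plain ints
# and raises TypeError; an equivalent working form (values are illustrative):
#
#   lrs = [0.001, 0.0005, 0.0001]        # e.g. --lr 0.001,0.0005,0.0001
#   boundaries = [step_per_epoch * 5 * i for i in range(1, len(lrs))]
#   # batchsize=10 -> step_per_epoch = 121745 // 10 = 12174 -> [60870, 121740]
#   learning_rate = tf.train.piecewise_constant(global_step, boundaries, lrs)
# -----------------------------------------------------------------------------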
118 | if ',' not in args.lr: 119 | starter_learning_rate = float(args.lr) 120 | learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, 121 | decay_steps=50000, decay_rate=0.8, staircase=True) 122 | else: 123 | lrs = [float(x) for x in args.lr.split(',')] 124 | boundaries = [step_per_epoch * 5 * i for i, _ in range(len(lrs)) if i > 0] 125 | learning_rate = tf.train.piecewise_constant(global_step, boundaries, lrs) 126 | 127 | optimizer = tf.train.RMSPropOptimizer(learning_rate, decay=0.0005, momentum=0.9, epsilon=1e-10) 128 | # optimizer = tf.train.AdadeltaOptimizer(learning_rate) 129 | # train_op = optimizer.minimize(total_loss, global_step, colocate_gradients_with_ops=True) 130 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 131 | with tf.control_dependencies(update_ops): 132 | train_op = optimizer.minimize(total_loss, global_step, colocate_gradients_with_ops=True) 133 | 134 | # define summary 135 | tf.summary.scalar("loss", total_loss) 136 | tf.summary.scalar("loss_lastlayer", total_ll_loss) 137 | tf.summary.scalar("loss_lastlayer_paf", tf.nn.l2_loss(output_vectmap - q_vect)) 138 | tf.summary.scalar("loss_lastlayer_heat", tf.nn.l2_loss(output_heatmap - q_heat)) 139 | tf.summary.scalar("queue_size", enqueuer.size()) 140 | merged_summary_op = tf.summary.merge_all() 141 | 142 | valid_loss = tf.placeholder(tf.float32, shape=[]) 143 | valid_loss_ll = tf.placeholder(tf.float32, shape=[]) 144 | sample_train = tf.placeholder(tf.float32, shape=(1, 640, 640, 3)) 145 | sample_valid = tf.placeholder(tf.float32, shape=(1, 640, 640, 3)) 146 | sample_valid2 = tf.placeholder(tf.float32, shape=(1, 640, 640, 3)) 147 | sample_valid3 = tf.placeholder(tf.float32, shape=(1, 640, 640, 3)) 148 | train_img = tf.summary.image('training sample', sample_train, 1) 149 | valid_img = tf.summary.image('validation sample', sample_valid, 1) 150 | valid_img2 = tf.summary.image('validation sample2', sample_valid2, 1) 151 | valid_img3 = tf.summary.image('validation sample3', sample_valid3, 1) 152 | valid_loss_t = tf.summary.scalar("loss_valid", valid_loss) 153 | valid_loss_ll_t = tf.summary.scalar("loss_valid_lastlayer", valid_loss_ll) 154 | merged_validate_op = tf.summary.merge([train_img, valid_img, valid_img2, valid_img3, valid_loss_t, valid_loss_ll_t]) 155 | 156 | saver = tf.train.Saver(max_to_keep=100) 157 | config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 158 | with tf.Session(config=config) as sess: 159 | sess.run(tf.global_variables_initializer()) 160 | if args.checkpoint: 161 | logging.info('Restore from checkpoint...') 162 | # loader = tf.train.Saver(net.restorable_variables()) 163 | # loader.restore(sess, tf.train.latest_checkpoint(args.checkpoint)) 164 | saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint)) 165 | logging.info('Restore from checkpoint...Done') 166 | elif pretrain_path: 167 | logging.info('Restore pretrained weights...') 168 | if '.ckpt' in pretrain_path: 169 | loader = tf.train.Saver(net.restorable_variables()) 170 | loader.restore(sess, pretrain_path) 171 | elif '.npy' in pretrain_path: 172 | net.load(pretrain_path, sess, False) 173 | logging.info('Restore pretrained weights...Done') 174 | 175 | logging.info('prepare file writer') 176 | training_name = '{}_batch:{}_lr:{}_gpus:{}_{}x{}_{}'.format( 177 | args.model, 178 | args.batchsize, 179 | args.lr, 180 | args.gpus, 181 | args.input_width, args.input_height, 182 | args.tag 183 | ) 184 | file_writer = tf.summary.FileWriter(args.logpath + training_name, sess.graph) 185 | 
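# --- Editor's note (illustration, not part of train.py) ---------------------
# Example invocation; every path and value below is a placeholder, the real
# defaults come from the argparse block earlier in this file:
#
#   python train.py --model=mobilenet \
#       --datapath=/data/public/rw/coco-pose-estimation-lmdb/ \
#       --batchsize=16 --gpus=2 --lr=0.0001 \
#       --modelpath=/data/private/tf-openpose-models/ \
#       --logpath=/data/private/tf-openpose-log/ --tag=test
#
# The code that follows starts the enqueue thread, traces one step into
# timeline.json, then runs the training loop: it logs every ~100 steps and
# runs validation plus a checkpoint save every ~1000 steps.
# -----------------------------------------------------------------------------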
186 | logging.info('prepare coordinator') 187 | coord = tf.train.Coordinator() 188 | enqueuer.set_coordinator(coord) 189 | enqueuer.start() 190 | 191 | logging.info('examine timeline') 192 | run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 193 | run_metadata = tf.RunMetadata() 194 | sess.run([train_op, global_step]) 195 | _, gs_num = sess.run([train_op, global_step], options=run_options, run_metadata=run_metadata) 196 | tl = timeline.Timeline(run_metadata.step_stats) 197 | ctf = tl.generate_chrome_trace_format() 198 | with open('timeline.json', 'w') as f: 199 | f.write(ctf) 200 | 201 | tf.train.write_graph(sess.graph_def, args.modelpath, 'graph.pb'.format(gs_num)) 202 | 203 | logging.info('Training Started.') 204 | time_started = time.time() 205 | last_gs_num = last_gs_num2 = 0 206 | initial_gs_num = sess.run(global_step) 207 | 208 | while True: 209 | _, gs_num = sess.run([train_op, global_step]) 210 | 211 | if gs_num > step_per_epoch * args.max_epoch: 212 | break 213 | 214 | if gs_num - last_gs_num >= 100: 215 | train_loss, train_loss_ll, train_loss_ll_paf, train_loss_ll_heat, lr_val, summary, queue_size = sess.run([total_loss, total_ll_loss, total_loss_ll_paf, total_loss_ll_heat, learning_rate, merged_summary_op, enqueuer.size()]) 216 | 217 | # log of training loss / accuracy 218 | batch_per_sec = (gs_num - initial_gs_num) / (time.time() - time_started) 219 | logging.info('epoch=%.2f step=%d, %0.4f examples/sec lr=%f, loss=%g, loss_ll=%g, loss_ll_paf=%g, loss_ll_heat=%g, q=%d' % (gs_num / step_per_epoch, gs_num, batch_per_sec * args.batchsize, lr_val, train_loss, train_loss_ll, train_loss_ll_paf, train_loss_ll_heat, queue_size)) 220 | last_gs_num = gs_num 221 | 222 | file_writer.add_summary(summary, gs_num) 223 | 224 | if gs_num - last_gs_num2 >= 1000: 225 | average_loss = average_loss_ll = 0 226 | total_cnt = 0 227 | 228 | # log of test accuracy 229 | for images_test, heatmaps, vectmaps in validation_cache: 230 | lss, lss_ll, vectmap_sample, heatmap_sample = sess.run( 231 | [total_loss, total_ll_loss, output_vectmap, output_heatmap], 232 | feed_dict={q_inp: images_test, q_vect: vectmaps, q_heat: heatmaps} 233 | ) 234 | average_loss += lss * len(images_test) 235 | average_loss_ll += lss_ll * len(images_test) 236 | total_cnt += len(images_test) 237 | 238 | logging.info('validation(%d) loss=%f, loss_ll=%f' % (total_cnt, average_loss / total_cnt, average_loss_ll / total_cnt)) 239 | last_gs_num2 = gs_num 240 | 241 | sample_image = enqueuer.last_dp[0][0] 242 | pafMat, heatMat = sess.run( 243 | [ 244 | net.get_output(name=last_layer.format(aux=1)), 245 | net.get_output(name=last_layer.format(aux=2)) 246 | ], feed_dict={q_inp: np.array([sample_image, val_image, val_image2, val_image3]*(args.batchsize // 4))} 247 | ) 248 | sample_result = CocoPoseLMDB.display_image(sample_image, heatMat[0], pafMat[0], as_numpy=True) 249 | sample_result = cv2.resize(sample_result, (640, 640)) 250 | sample_result = sample_result.reshape([1, 640, 640, 3]).astype(float) 251 | 252 | test_result = CocoPoseLMDB.display_image(val_image, heatMat[1], pafMat[1], as_numpy=True) 253 | test_result = cv2.resize(test_result, (640, 640)) 254 | test_result = test_result.reshape([1, 640, 640, 3]).astype(float) 255 | 256 | test_result2 = CocoPoseLMDB.display_image(val_image2, heatMat[2], pafMat[2], as_numpy=True) 257 | test_result2 = cv2.resize(test_result2, (640, 640)) 258 | test_result2 = test_result2.reshape([1, 640, 640, 3]).astype(float) 259 | 260 | test_result3 = CocoPoseLMDB.display_image(val_image3, 
heatMat[3], pafMat[3], as_numpy=True) 261 | test_result3 = cv2.resize(test_result3, (640, 640)) 262 | test_result3 = test_result3.reshape([1, 640, 640, 3]).astype(float) 263 | 264 | # save summary 265 | summary = sess.run(merged_validate_op, feed_dict={ 266 | valid_loss: average_loss / total_cnt, 267 | valid_loss_ll: average_loss_ll / total_cnt, 268 | sample_valid: test_result, 269 | sample_valid2: test_result2, 270 | sample_valid3: test_result3, 271 | sample_train: sample_result 272 | }) 273 | file_writer.add_summary(summary, gs_num) 274 | 275 | # save weights 276 | saver.save(sess, os.path.join(args.modelpath, 'model'), global_step=global_step) 277 | 278 | saver.save(sess, os.path.join(args.modelpath, 'model_final'), global_step=global_step) 279 | logging.info('optimization finished. %f' % (time.time() - time_started)) 280 | -------------------------------------------------------------------------------- /pose_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import struct 4 | import threading 5 | import logging 6 | import multiprocessing 7 | 8 | from contextlib import contextmanager 9 | 10 | import lmdb 11 | import cv2 12 | import numpy as np 13 | import time 14 | 15 | import tensorflow as tf 16 | 17 | from tensorpack import imgaug 18 | from tensorpack.dataflow.image import MapDataComponent, AugmentImageComponent 19 | from tensorpack.dataflow.common import BatchData, MapData, TestDataSpeed 20 | from tensorpack.dataflow.prefetch import PrefetchData 21 | from tensorpack.dataflow.base import RNGDataFlow, DataFlowTerminated 22 | 23 | from datum_pb2 import Datum 24 | from pose_augment import pose_flip, pose_rotation, pose_to_img, pose_crop_random, \ 25 | pose_resize_shortestedge_random, pose_resize_shortestedge_fixed, pose_crop_center, pose_random_scale 26 | 27 | import matplotlib as mpl 28 | 29 | logging.basicConfig(level=logging.DEBUG, format='[lmdb_dataset] %(asctime)s %(levelname)s %(message)s') 30 | 31 | 32 | class CocoMetadata: 33 | # __coco_parts = 57 34 | __coco_parts = 19 35 | __coco_vecs = list(zip( 36 | [2, 9, 10, 2, 12, 13, 2, 3, 4, 3, 2, 6, 7, 6, 2, 1, 1, 15, 16], 37 | [9, 10, 11, 12, 13, 14, 3, 4, 5, 17, 6, 7, 8, 18, 1, 15, 16, 17, 18] 38 | )) 39 | 40 | @staticmethod 41 | def parse_float(four_np): 42 | assert len(four_np) == 4 43 | return struct.unpack('= 0 or -1000] 79 | joint_y = [val for val in joint_y if val >= 0 or -1000] 80 | joint_list.append(list(zip(joint_x, joint_y))) 81 | 82 | self.joint_list = [] 83 | transform = list(zip( 84 | [1, 6, 7, 9, 11, 6, 8, 10, 13, 15, 17, 12, 14, 16, 3, 2, 5, 4], 85 | [1, 7, 7, 9, 11, 6, 8, 10, 13, 15, 17, 12, 14, 16, 3, 2, 5, 4] 86 | )) 87 | for prev_joint in joint_list: 88 | new_joint = [] 89 | for idx1, idx2 in transform: 90 | j1 = prev_joint[idx1-1] 91 | j2 = prev_joint[idx2-1] 92 | 93 | if j1[0] <= 0 or j1[1] <= 0 or j2[0] <= 0 or j2[1] <= 0: 94 | new_joint.append((-1000, -1000)) 95 | else: 96 | new_joint.append(((j1[0] + j2[0]) / 2, (j1[1] + j2[1]) / 2)) 97 | 98 | new_joint.append((-1000, -1000)) 99 | self.joint_list.append(new_joint) 100 | 101 | logging.debug('joint size=%d' % len(self.joint_list)) 102 | 103 | def get_heatmap(self, target_size): 104 | heatmap = np.zeros((CocoMetadata.__coco_parts, self.height, self.width)) 105 | 106 | for joints in self.joint_list: 107 | for idx, point in enumerate(joints): 108 | if point[0] < 0 or point[1] < 0: 109 | continue 110 | CocoMetadata.put_heatmap(heatmap, idx, point, self.sigma) 111 | 112 | heatmap = 
heatmap.transpose((1, 2, 0)) 113 | 114 | # background 115 | heatmap[:, :, -1] = np.clip(1 - np.amax(heatmap, axis=2), 0.0, 1.0) 116 | 117 | if target_size: 118 | heatmap = cv2.resize(heatmap, target_size, interpolation=cv2.INTER_AREA) 119 | 120 | return heatmap 121 | 122 | @staticmethod 123 | def put_heatmap(heatmap, plane_idx, center, sigma): 124 | center_x, center_y = center 125 | _, height, width = heatmap.shape[:3] 126 | 127 | th = 4.6052 128 | delta = math.sqrt(th * 2) 129 | 130 | x0 = int(max(0, center_x - delta * sigma)) 131 | y0 = int(max(0, center_y - delta * sigma)) 132 | 133 | x1 = int(min(width, center_x + delta * sigma)) 134 | y1 = int(min(height, center_y + delta * sigma)) 135 | 136 | for y in range(y0, y1): 137 | for x in range(x0, x1): 138 | d = (x - center_x) ** 2 + (y - center_y) ** 2 139 | exp = d / 2.0 / sigma / sigma 140 | if exp > th: 141 | continue 142 | heatmap[plane_idx][y][x] = max(heatmap[plane_idx][y][x], math.exp(-exp)) 143 | heatmap[plane_idx][y][x] = min(heatmap[plane_idx][y][x], 1.0) 144 | 145 | def get_vectormap(self, target_size): 146 | vectormap = np.zeros((CocoMetadata.__coco_parts*2, self.height, self.width)) 147 | countmap = np.zeros((CocoMetadata.__coco_parts, self.height, self.width)) 148 | for joints in self.joint_list: 149 | for plane_idx, (j_idx1, j_idx2) in enumerate(CocoMetadata.__coco_vecs): 150 | j_idx1 -= 1 151 | j_idx2 -= 1 152 | 153 | center_from = joints[j_idx1] 154 | center_to = joints[j_idx2] 155 | 156 | if center_from[0] < -100 or center_from[1] < -100 or center_to[0] < -100 or center_to[1] < -100: 157 | continue 158 | 159 | CocoMetadata.put_vectormap(vectormap, countmap, plane_idx, center_from, center_to) 160 | 161 | vectormap = vectormap.transpose((1, 2, 0)) 162 | nonzeros = np.nonzero(countmap) 163 | for p, y, x in zip(nonzeros[0], nonzeros[1], nonzeros[2]): 164 | if countmap[p][y][x] <= 0: 165 | continue 166 | vectormap[y][x][p*2+0] /= countmap[p][y][x] 167 | vectormap[y][x][p*2+1] /= countmap[p][y][x] 168 | 169 | if target_size: 170 | vectormap = cv2.resize(vectormap, target_size, interpolation=cv2.INTER_AREA) 171 | 172 | return vectormap 173 | 174 | @staticmethod 175 | def put_vectormap(vectormap, countmap, plane_idx, center_from, center_to, threshold=8): 176 | _, height, width = vectormap.shape[:3] 177 | 178 | vec_x = center_to[0] - center_from[0] 179 | vec_y = center_to[1] - center_from[1] 180 | 181 | min_x = max(0, int(min(center_from[0], center_to[0]) - threshold)) 182 | min_y = max(0, int(min(center_from[1], center_to[1]) - threshold)) 183 | 184 | max_x = min(width, int(max(center_from[0], center_to[0]) + threshold)) 185 | max_y = min(height, int(max(center_from[1], center_to[1]) + threshold)) 186 | 187 | norm = math.sqrt(vec_x ** 2 + vec_y ** 2) 188 | if norm == 0: 189 | return 190 | 191 | vec_x /= norm 192 | vec_y /= norm 193 | 194 | for y in range(min_y, max_y): 195 | for x in range(min_x, max_x): 196 | bec_x = x - center_from[0] 197 | bec_y = y - center_from[1] 198 | dist = abs(bec_x * vec_y - bec_y * vec_x) 199 | 200 | if dist > threshold: 201 | continue 202 | 203 | countmap[plane_idx][y][x] += 1 204 | 205 | vectormap[plane_idx*2+0][y][x] = vec_x 206 | vectormap[plane_idx*2+1][y][x] = vec_y 207 | 208 | 209 | class CocoPoseLMDB(RNGDataFlow): 210 | __valid_i = 2745 211 | __max_key = 121745 212 | 213 | @staticmethod 214 | def display_image(inp, heatmap, vectmap, as_numpy=False): 215 | if as_numpy: 216 | mpl.use('Agg') 217 | import matplotlib.pyplot as plt 218 | 219 | fig = plt.figure() 220 | a = fig.add_subplot(2, 2, 1) 221 
| a.set_title('Image') 222 | plt.imshow(CocoPoseLMDB.get_bgimg(inp)) 223 | 224 | a = fig.add_subplot(2, 2, 2) 225 | a.set_title('Heatmap') 226 | plt.imshow(CocoPoseLMDB.get_bgimg(inp, target_size=(heatmap.shape[1], heatmap.shape[0])), alpha=0.5) 227 | tmp = np.amax(heatmap, axis=2) 228 | plt.imshow(tmp, cmap=plt.cm.gray, alpha=0.5) 229 | plt.colorbar() 230 | 231 | tmp2 = vectmap.transpose((2, 0, 1)) 232 | tmp2_odd = np.amax(np.absolute(tmp2[::2, :, :]), axis=0) 233 | tmp2_even = np.amax(np.absolute(tmp2[1::2, :, :]), axis=0) 234 | 235 | a = fig.add_subplot(2, 2, 3) 236 | a.set_title('Vectormap-x') 237 | plt.imshow(CocoPoseLMDB.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5) 238 | plt.imshow(tmp2_odd, cmap=plt.cm.gray, alpha=0.5) 239 | plt.colorbar() 240 | 241 | a = fig.add_subplot(2, 2, 4) 242 | a.set_title('Vectormap-y') 243 | plt.imshow(CocoPoseLMDB.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5) 244 | plt.imshow(tmp2_even, cmap=plt.cm.gray, alpha=0.5) 245 | plt.colorbar() 246 | 247 | if not as_numpy: 248 | plt.show() 249 | else: 250 | fig.canvas.draw() 251 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 252 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 253 | fig.clear() 254 | plt.close() 255 | return data 256 | 257 | @staticmethod 258 | def get_bgimg(inp, target_size=None): 259 | if target_size: 260 | inp = cv2.resize(inp, target_size, interpolation=cv2.INTER_AREA) 261 | inp = cv2.cvtColor(((inp + 1.0) * (255.0 / 2.0)).astype(np.uint8), cv2.COLOR_BGR2RGB) 262 | return inp 263 | 264 | def __init__(self, path, is_train=True, decode_img=True, only_idx=-1): 265 | self.is_train = is_train 266 | self.decode_img = decode_img 267 | self.only_idx = only_idx 268 | self.env = lmdb.open(path, map_size=int(1e12), readonly=True) 269 | self.txn = self.env.begin(buffers=True) 270 | pass 271 | 272 | def size(self): 273 | if self.is_train: 274 | return CocoPoseLMDB.__max_key - CocoPoseLMDB.__valid_i 275 | else: 276 | return CocoPoseLMDB.__valid_i 277 | 278 | def get_data(self): 279 | idxs = np.arange(self.size()) 280 | if self.is_train: 281 | idxs += CocoPoseLMDB.__valid_i 282 | self.rng.shuffle(idxs) 283 | else: 284 | pass 285 | 286 | for idx in idxs: 287 | datum = Datum() 288 | if self.only_idx < 0: 289 | s = self.txn.get(('%07d' % idx).encode('utf-8')) 290 | else: 291 | s = self.txn.get(('%07d' % self.only_idx).encode('utf-8')) 292 | datum.ParseFromString(s) 293 | if isinstance(datum.data, bytes): 294 | data = np.fromstring(datum.data, dtype=np.uint8).reshape(datum.channels, datum.height, datum.width) 295 | else: 296 | data = np.fromstring(datum.data.tobytes(), dtype=np.uint8).reshape(datum.channels, datum.height, 297 | datum.width) 298 | if self.decode_img: 299 | img = data[:3].transpose((1, 2, 0)) 300 | else: 301 | img = None 302 | 303 | meta = CocoMetadata(idx, img, data[3], sigma=8.0) 304 | 305 | yield [meta] 306 | 307 | 308 | def get_dataflow(path, is_train): 309 | ds = CocoPoseLMDB(path, is_train) # read data from lmdb 310 | if is_train: 311 | ds = MapDataComponent(ds, pose_random_scale) 312 | ds = MapDataComponent(ds, pose_rotation) 313 | ds = MapDataComponent(ds, pose_flip) 314 | ds = MapDataComponent(ds, pose_resize_shortestedge_random) 315 | ds = MapDataComponent(ds, pose_crop_random) 316 | ds = MapData(ds, pose_to_img) 317 | augs = [ 318 | imgaug.RandomApplyAug(imgaug.RandomChooseAug([ 319 | imgaug.BrightnessScale((0.6, 1.4), clip=False), 320 | imgaug.Contrast((0.7, 1.4), clip=False), 321 
| imgaug.GaussianBlur(max_size=3) 322 | ]), 0.7), 323 | ] 324 | ds = AugmentImageComponent(ds, augs) 325 | else: 326 | ds = MapDataComponent(ds, pose_resize_shortestedge_fixed) 327 | ds = MapDataComponent(ds, pose_crop_center) 328 | ds = MapData(ds, pose_to_img) 329 | 330 | ds = PrefetchData(ds, 1000, multiprocessing.cpu_count()) 331 | 332 | return ds 333 | 334 | 335 | def get_dataflow_batch(path, is_train, batchsize): 336 | ds = get_dataflow(path, is_train) 337 | ds = BatchData(ds, batchsize) 338 | ds = PrefetchData(ds, 10, 2) 339 | 340 | return ds 341 | 342 | 343 | class DataFlowToQueue(threading.Thread): 344 | def __init__(self, ds, placeholders, queue_size=5): 345 | super().__init__() 346 | self.daemon = True 347 | 348 | self.ds = ds 349 | self.placeholders = placeholders 350 | self.queue = tf.FIFOQueue(queue_size, [ph.dtype for ph in placeholders], shapes=[ph.get_shape() for ph in placeholders]) 351 | self.op = self.queue.enqueue(placeholders) 352 | self.close_op = self.queue.close(cancel_pending_enqueues=True) 353 | 354 | self._coord = None 355 | self._sess = None 356 | 357 | self.last_dp = None 358 | 359 | @contextmanager 360 | def default_sess(self): 361 | if self._sess: 362 | with self._sess.as_default(): 363 | yield 364 | else: 365 | logging.warning("DataFlowToQueue {} wasn't under a default session!".format(self.name)) 366 | yield 367 | 368 | def size(self): 369 | return self.queue.size() 370 | 371 | def start(self): 372 | self._sess = tf.get_default_session() 373 | super().start() 374 | 375 | def set_coordinator(self, coord): 376 | self._coord = coord 377 | 378 | def run(self): 379 | with self.default_sess(): 380 | try: 381 | while not self._coord.should_stop(): 382 | try: 383 | self.ds.reset_state() 384 | while True: 385 | for dp in self.ds.get_data(): 386 | feed = dict(zip(self.placeholders, dp)) 387 | self.op.run(feed_dict=feed) 388 | self.last_dp = dp 389 | except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated): 390 | logging.error('err type1, placeholders={}'.format(self.placeholders)) 391 | sys.exit(-1) 392 | except Exception as e: 393 | logging.error('err type2, err={}, placeholders={}'.format(str(e), self.placeholders)) 394 | if isinstance(e, RuntimeError) and 'closed Session' in str(e): 395 | pass 396 | else: 397 | logging.exception("Exception in {}:{}".format(self.name, str(e))) 398 | sys.exit(-1) 399 | except Exception as e: 400 | logging.exception("Exception in {}:{}".format(self.name, str(e))) 401 | finally: 402 | try: 403 | self.close_op.run() 404 | except Exception: 405 | pass 406 | logging.info("{} Exited.".format(self.name)) 407 | 408 | def dequeue(self): 409 | return self.queue.dequeue() 410 | 411 | 412 | if __name__ == '__main__': 413 | import os 414 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 415 | 416 | from pose_augment import set_network_input_wh 417 | set_network_input_wh(368, 368) 418 | 419 | # df = get_dataflow('/data/public/rw/coco-pose-estimation-lmdb/', False) 420 | df = get_dataflow('/data/public/rw/coco-pose-estimation-lmdb/', True) 421 | 422 | # input_node = tf.placeholder(tf.float32, shape=(None, 368, 368, 3), name='image') 423 | with tf.Session() as sess: 424 | # net = CmuNetwork({'image': input_node}, trainable=False) 425 | # net.load('./models/numpy/openpose_coco.npy', sess) 426 | 427 | df.reset_state() 428 | t1 = time.time() 429 | for idx, dp in enumerate(df.get_data()): 430 | if idx == 0: 431 | for d in dp: 432 | logging.info('%d dp shape={}'.format(d.shape)) 433 | if idx % 100 == 0: 434 | print(time.time() - t1) 435 
| t1 = time.time() 436 | CocoPoseLMDB.display_image(dp[0], dp[1], dp[2]) 437 | print(dp[1].shape, dp[2].shape) 438 | 439 | # pafMat, heatMat = sess.run(net.loss_last(), feed_dict={'image:0': [dp[0] / 128.0]}) 440 | # print(heatMat.shape, pafMat.shape)  # only valid when the sess.run above is enabled 441 | # CocoPoseLMDB.display_image(dp[0], heatMat[0], pafMat[0]) 442 | pass 443 | 444 | logging.info('done') 445 | --------------------------------------------------------------------------------
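A note on the Gaussian heatmaps produced by CocoMetadata.put_heatmap above: the constant th = 4.6052 is -ln(0.01), so each keypoint contributes exp(-r^2 / (2*sigma^2)) and is cut off wherever that value would drop below 1%, and delta = sqrt(2*th) ≈ 3.03 restricts the loop to the roughly 3-sigma box where the Gaussian is still above the cutoff. A minimal standalone NumPy sketch of the same idea (not part of the repository; the function name and signature are hypothetical):

import math
import numpy as np

def gaussian_patch(height, width, center, sigma=8.0, th=4.6052):
    # Same cutoff rule as CocoMetadata.put_heatmap: keep exp(-r^2 / (2*sigma^2))
    # only while the exponent stays below th = -ln(0.01).
    cx, cy = center
    ys, xs = np.mgrid[0:height, 0:width]
    exponent = ((xs - cx) ** 2 + (ys - cy) ** 2) / (2.0 * sigma ** 2)
    return np.minimum(np.where(exponent > th, 0.0, np.exp(-exponent)), 1.0)

print(math.exp(-4.6052))        # ~0.01 -> the 1% cutoff value
print(math.sqrt(2 * 4.6052))    # ~3.03 -> loop radius in units of sigma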