├── .gitignore ├── README.md ├── common.py ├── dataflow.py ├── hand1.jpg ├── hand2.jpg ├── hands_dataset.py ├── hands_metadata.py ├── images ├── 6264.jpg ├── detection-stages.png ├── hand_sample.png ├── hand_synth_sample1.jpg ├── hand_synth_sample2.jpg └── hand_synth_sample3.jpg ├── models └── numpy │ └── download ├── network_base.py ├── network_cmuhand.py ├── networks.py ├── render_data.py ├── requirements.txt ├── run_checkpoint.py ├── synthhands.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # models 105 | *ckpt* 106 | *.npy 107 | timeline*.json 108 | 109 | # project customs 110 | tmp 111 | models/graph/cmu*/*.pb 112 | models/train/*/checkpoint 113 | models/train/*/*.pb 114 | models/train/*/model*.data-* 115 | models/train/*/model*.index 116 | models/train/*/model*.meta 117 | models/train/*/events* 118 | models/pretrained/resnet_v2_101/eval.graph 119 | models/pretrained/resnet_v2_101/train.graph 120 | src/pafprocess/pafprocess.py 121 | src/pafprocess/pafprocess_wrap.cpp 122 | src/pafprocess/pafprocess_wrap.cxx 123 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # hand_detector_train 2 | 3 | Replicating the Openpose hand detection algorithm, training a similar Convolutional Neural Network using Tensorflow. 4 | 5 | Original Openpose Repo(Caffe) : https://github.com/CMU-Perceptual-Computing-Lab/openpose 6 | 7 | Read the Medium story explaining this repository on: https://medium.com/@apofeniaco/training-a-hand-detector-like-the-openpose-one-in-tensorflow-45c5177d6679 8 | 9 | ![Detection example](./images/detection-stages.png) 10 | 11 | ## Install 12 | 13 | ### Dependencies 14 | 15 | You need dependencies below. 
16 | 
17 | - python3
18 | - tensorflow 1.4.1+
19 | - opencv3, protobuf, python3-tk
20 | 
21 | ### Install
22 | 
23 | Clone the repo and install the third-party libraries.
24 | 
25 | ```bash
26 | $ git clone https://github.com/ortegatron/hand_detector_train.git
27 | $ cd hand_detector_train
28 | $ pip3 install -r requirements.txt
29 | ```
30 | 
31 | Download openpose_vgg16.npy from http://www.mediafire.com/file/7e73ddj31rzw6qq/openpose_vgg16.npy and save it into hand_detector_train/models/numpy. The first layers of VGG-16 are taken from this file to perform feature extraction.
32 | 
33 | Download the [Hands from Synthetic Data Dataset](http://domedb.perception.cs.cmu.edu/panopticDB/hands/hand_labels_synth.zip) and extract it somewhere on your disk. We will refer to this folder as $HANDS_SNYTH_PATH.
34 | 
35 | ### Training
36 | 
37 | To start the training:
38 | 
39 | ```
40 | $ python3 train.py --datapath=$HANDS_SNYTH_PATH
41 | 
42 | [2019-08-10 03:24:51,727] [train] [INFO] define model-
43 | [2019-08-10 03:24:52,163] [train] [INFO] model weights initialization
44 | [2019-08-10 03:24:54,351] [train] [INFO] Restore pretrained weights... ./models/numpy/openpose_vgg16.npy
45 | [2019-08-10 03:25:02,852] [train] [INFO] Restore pretrained weights...Done
46 | [2019-08-10 03:25:02,853] [train] [INFO] prepare file writer
47 | [2019-08-10 03:25:04,574] [train] [INFO] prepare coordinator
48 | [2019-08-10 03:25:04,577] [train] [INFO] Training Started.
49 | ```
50 | This uses the default training parameters. Check the ArgumentParser of train.py for the full list of parameters you can change.
51 | 
52 | While training, the script prints a line like the following every 500 steps:
53 | 
54 | ```
55 | [2019-08-10 03:51:07,374] [train] [INFO] epoch=0.53 step=1000, 40.9523 examples/sec lr=0.000100, loss=155.03, loss_ll_heat=30.5468
56 | ```
57 | This shows the current epoch and step (each epoch has 121745 / batchsize steps), the processing speed, the learning rate being used, the current loss value, and the loss of the last layer. Every 2000 steps a checkpoint of the model is saved in the models/train/test folder. Training runs until it reaches max_epoch, set to 600 by default. That is a lot, but you can stop the training at any moment simply by killing the script.
58 | 
59 | Once you are happy with the loss value reached during training, stop the script; the last saved checkpoint will be available in models/train/test. A checkpoint is made of three files:
60 | - model_latest-XXXX.meta: Stores the training graph structure.
61 | - model_latest-XXXX.index: Stores metadata about the tensors.
62 | - model_latest-XXXX.data-00000-of-00001: Stores the values of all variables in the network.
63 | 
64 | You can see more information about the training process by running TensorBoard on the log folder. It shows the loss graphs and the current results on the set of validation images.
65 | 
66 | ```
67 | $ tensorboard --logdir=models/train/test/
68 | ```
69 | 
70 | While it may be tempting to convert those three files into a .pb file and use it directly, that is not a good idea: the saved graph is the training graph, containing the batch tensors, the training heatmap tensor, etc. We need to generate a fresh graph containing only the tensors we are going to use at inference time.
71 | 
72 | #### Freezing the graph
73 | 
74 | First run:
75 | 
76 | ```
77 | $ python3 run_checkpoint.py
78 | ```
79 | 
80 | This will generate the graph definition and save it to tmp/graph_definition.pb.
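If you want to double-check the name of the output node before freezing, a small sketch like the one below (assuming TensorFlow 1.x and the default ./tmp output path) lists the nodes of the generated graph definition; the freeze step that follows expects a node named "Openpose/out".

```
import tensorflow as tf
from google.protobuf import text_format

# run_checkpoint.py writes the graph with as_text=True, so parse it as a
# text-format GraphDef.
with open('./tmp/graph_definition.pb') as f:
    graph_def = text_format.Parse(f.read(), tf.GraphDef())

# Print the candidate output nodes; 'Openpose/out' should be among them.
print([n.name for n in graph_def.node if n.name.startswith('Openpose')])
```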
81 | 82 | Now we can freeze the graph with the saved checkpoint by doing: 83 | 84 | ``` 85 | $ python3 -m tensorflow.python.tools.freeze_graph \ 86 | --input_graph=./tmp/graph_definition.pb \ 87 | --input_checkpoint=./models/train/test/model_latest-XXXXX \ 88 | --output_graph=./models/frozengraph.pb --output_node_names="Openpose/out" 89 | ``` 90 | Where model_latest-XXXXX is the name for the saved checkpoint you want to use. This will freeze the graph with the same variable's value as the saved checkpoint, and save it on ./models/frozengraph.pb. 91 | 92 | #### Testing the detector 93 | 94 | To test the trained detector: 95 | 96 | ``` 97 | $ python3 test.py --graph-path=$GRAPH_PATH --image-path=$IMAGE_PATH 98 | ``` 99 | 100 | It will show the detected belief maps on top of the image. 101 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import tensorflow as tf 4 | import cv2 5 | 6 | 7 | regularizer_conv = 0.004 8 | regularizer_dsconv = 0.0004 9 | batchnorm_fused = True 10 | activation_fn = tf.nn.relu 11 | 12 | 13 | def read_imgfile(path, width=None, height=None): 14 | val_image = cv2.imread(path, cv2.IMREAD_COLOR) 15 | if width is not None and height is not None: 16 | val_image = cv2.resize(val_image, (width, height)) 17 | return val_image 18 | 19 | 20 | def get_sample_images(w, h): 21 | val_image = [ 22 | read_imgfile('images/hand_sample.png', w, h), 23 | read_imgfile('images/hand_synth_sample1.jpg', w, h), 24 | read_imgfile('images/hand_synth_sample2.jpg', w, h), 25 | read_imgfile('images/hand_synth_sample3.jpg', w, h), 26 | ] 27 | return val_image 28 | 29 | 30 | def to_str(s): 31 | if not isinstance(s, str): 32 | return s.decode('utf-8') 33 | return s 34 | -------------------------------------------------------------------------------- /dataflow.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import threading 4 | import tensorflow as tf 5 | 6 | from tensorpack.dataflow.base import RNGDataFlow, DataFlowTerminated 7 | 8 | try: 9 | from StringIO import StringIO 10 | except ImportError: 11 | from io import StringIO 12 | 13 | from contextlib import contextmanager 14 | 15 | 16 | logging.getLogger("requests").setLevel(logging.WARNING) 17 | logger = logging.getLogger('pose_dataset') 18 | logger.setLevel(logging.INFO) 19 | 20 | class DataFlowToQueue(threading.Thread): 21 | def __init__(self, ds, placeholders, queue_size=5): 22 | super().__init__() 23 | self.daemon = True 24 | 25 | self.ds = ds 26 | self.placeholders = placeholders 27 | self.queue = tf.FIFOQueue(queue_size, [ph.dtype for ph in placeholders], shapes=[ph.get_shape() for ph in placeholders]) 28 | self.op = self.queue.enqueue(placeholders) 29 | self.close_op = self.queue.close(cancel_pending_enqueues=True) 30 | 31 | self._coord = None 32 | self._sess = None 33 | 34 | self.last_dp = None 35 | 36 | @contextmanager 37 | def default_sess(self): 38 | if self._sess: 39 | with self._sess.as_default(): 40 | yield 41 | else: 42 | logger.warning("DataFlowToQueue {} wasn't under a default session!".format(self.name)) 43 | yield 44 | 45 | def size(self): 46 | return self.queue.size() 47 | 48 | def start(self): 49 | self._sess = tf.get_default_session() 50 | super().start() 51 | 52 | def set_coordinator(self, coord): 53 | self._coord = coord 54 | 55 | def run(self): 56 | with self.default_sess(): 57 | try: 58 | while 
not self._coord.should_stop(): 59 | try: 60 | self.ds.reset_state() 61 | while True: 62 | for dp in self.ds.get_data(): 63 | feed = dict(zip(self.placeholders, dp)) 64 | self.op.run(feed_dict=feed) 65 | self.last_dp = dp 66 | except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated): 67 | logger.error('err type1, placeholders={}'.format(self.placeholders)) 68 | sys.exit(-1) 69 | except Exception as e: 70 | logger.error('err type2, err={}, placeholders={}'.format(str(e), self.placeholders)) 71 | if isinstance(e, RuntimeError) and 'closed Session' in str(e): 72 | pass 73 | else: 74 | logger.exception("Exception in {}:{}".format(self.name, str(e))) 75 | sys.exit(-1) 76 | except Exception as e: 77 | logger.exception("Exception in {}:{}".format(self.name, str(e))) 78 | finally: 79 | try: 80 | self.close_op.run() 81 | except Exception: 82 | pass 83 | logger.info("{} Exited.".format(self.name)) 84 | 85 | def dequeue(self): 86 | return self.queue.dequeue() 87 | -------------------------------------------------------------------------------- /hand1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ortegatron/hand_detector_train/06851e55ebd4ec36001382990d18a883030304c2/hand1.jpg -------------------------------------------------------------------------------- /hand2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ortegatron/hand_detector_train/06851e55ebd4ec36001382990d18a883030304c2/hand2.jpg -------------------------------------------------------------------------------- /hands_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import multiprocessing 4 | import struct 5 | import sys 6 | import threading 7 | 8 | try: 9 | from StringIO import StringIO 10 | except ImportError: 11 | from io import StringIO 12 | 13 | from contextlib import contextmanager 14 | 15 | import os 16 | import random 17 | import requests 18 | import cv2 19 | import numpy as np 20 | import time 21 | 22 | import tensorflow as tf 23 | 24 | from tensorpack.dataflow import MultiThreadMapData 25 | from tensorpack.dataflow.image import MapDataComponent 26 | from tensorpack.dataflow.common import BatchData, MapData 27 | from tensorpack.dataflow.parallel import PrefetchData 28 | from tensorpack.dataflow.base import RNGDataFlow, DataFlowTerminated 29 | 30 | from numba import jit 31 | 32 | from dataflow import DataFlowToQueue 33 | from hands_metadata import HandsMetadata 34 | from synthhands import SynthHands 35 | 36 | logging.getLogger("requests").setLevel(logging.WARNING) 37 | logger = logging.getLogger('pose_dataset') 38 | logger.setLevel(logging.INFO) 39 | ch = logging.StreamHandler() 40 | ch.setLevel(logging.DEBUG) 41 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s') 42 | ch.setFormatter(formatter) 43 | logger.addHandler(ch) 44 | 45 | def read_image_url(metas): 46 | for meta in metas: 47 | meta.img = cv2.imread(meta.img_url, cv2.IMREAD_COLOR) 48 | if meta.img is None: 49 | logger.warning('image not read, path=%s' % meta.img_url) 50 | raise Exception() 51 | return metas 52 | 53 | def get_dataflow(path, is_train): 54 | ds = SynthHands(path, is_train) # read data from lmdb 55 | if is_train: 56 | ds = MapData(ds, read_image_url) 57 | ds = MapData(ds, pose_to_img) 58 | ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 1) 59 | else: 60 | ds = 
MultiThreadMapData(ds, nr_thread=16, map_func=read_image_url, buffer_size=1000) 61 | ds = MapData(ds, pose_to_img) 62 | ds = PrefetchData(ds, 100, multiprocessing.cpu_count() // 4) 63 | return ds 64 | 65 | def _get_dataflow_onlyread(path, is_train): 66 | ds = SynthHands(path, is_train) # read data from lmdb 67 | ds = MapData(ds, read_image_url) 68 | ds = MapData(ds, pose_to_img) 69 | # ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 4) 70 | return ds 71 | 72 | def get_dataflow_batch(path, is_train, batchsize): 73 | ds = get_dataflow(path, is_train) 74 | ds = BatchData(ds, batchsize) 75 | return ds 76 | 77 | 78 | _network_w = 368 79 | _network_h = 368 80 | _scale = 8 81 | 82 | def pose_to_img(meta_l): 83 | global _network_w, _network_h, _scale 84 | return [ 85 | meta_l[0].img.astype(np.float16), 86 | meta_l[0].get_heatmap(target_size=(_network_w // _scale, _network_h // _scale)) 87 | ] 88 | 89 | 90 | if __name__ == '__main__': 91 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 92 | 93 | 94 | # df = get_dataflow('/data/public/rw/coco/annotations', True, '/data/public/rw/coco/') 95 | df = _get_dataflow_onlyread('/home/marcelo/hands/hand_labels_synth', True) 96 | # df = get_dataflow('/root/coco/annotations', False, img_path='http://gpu-twg.kakaocdn.net/braincloud/COCO/') 97 | 98 | from tensorpack.dataflow.common import TestDataSpeed 99 | TestDataSpeed(df).start() 100 | 101 | with tf.Session() as sess: 102 | df.reset_state() 103 | t1 = time.time() 104 | for idx, dp in enumerate(df.get_data()): 105 | if idx == 0: 106 | for d in dp: 107 | logger.info('%d dp shape={}'.format(d.shape)) 108 | print(time.time() - t1) 109 | t1 = time.time() 110 | SynthHands.display_image(dp[0], dp[1].astype(np.float32)) 111 | pass 112 | 113 | logger.info('done') 114 | -------------------------------------------------------------------------------- /hands_metadata.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import cv2 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from numba import jit 8 | 9 | 10 | class HandsMetadata: 11 | __hand_parts = 22 12 | def __init__(self, idx, img_url ,img_meta, annotations, sigma): 13 | self.idx = idx 14 | self.img_url = img_url 15 | self.sigma = sigma 16 | self.height = int(img_meta['height']) 17 | self.width = int(img_meta['width']) 18 | joint_list = [] 19 | for ann in annotations: 20 | kp = np.array(ann['keypoints']) 21 | xs = kp[:,0] 22 | ys = kp[:,1] 23 | vs = kp[:,2] 24 | joint_list.append([(x, y) if v >= 1 else (-1000, -1000) for x, y, v in zip(xs, ys, vs)]) 25 | 26 | self.joint_list = [] 27 | transform = list(zip( 28 | [1, 6, 7, 9, 11, 6, 8, 10, 13, 15, 17, 12, 14, 16, 3, 2, 5, 4], 29 | [1, 7, 7, 9, 11, 6, 8, 10, 13, 15, 17, 12, 14, 16, 3, 2, 5, 4] 30 | )) 31 | for prev_joint in joint_list: 32 | new_joint = [] 33 | for idx1, idx2 in transform: 34 | j1 = prev_joint[idx1-1] 35 | j2 = prev_joint[idx2-1] 36 | 37 | if j1[0] <= 0 or j1[1] <= 0 or j2[0] <= 0 or j2[1] <= 0: 38 | new_joint.append((-1000, -1000)) 39 | else: 40 | new_joint.append(((j1[0] + j2[0]) / 2, (j1[1] + j2[1]) / 2)) 41 | 42 | new_joint.append((-1000, -1000)) 43 | self.joint_list.append(new_joint) 44 | 45 | @jit 46 | def get_heatmap(self, target_size): 47 | heatmap = np.zeros((HandsMetadata.__hand_parts, self.height, self.width), dtype=np.float32) 48 | 49 | for joints in self.joint_list: 50 | for idx, point in enumerate(joints): 51 | if point[0] < 0 or point[1] < 0: 52 | continue 53 | HandsMetadata.put_heatmap(heatmap, idx, 
point, self.sigma) 54 | 55 | 56 | heatmap = heatmap.transpose((1, 2, 0)) 57 | 58 | # background 59 | heatmap[:, :, -1] = np.clip(1 - np.amax(heatmap, axis=2), 0.0, 1.0) 60 | 61 | if target_size: 62 | heatmap = cv2.resize(heatmap, target_size, interpolation=cv2.INTER_AREA) 63 | 64 | return heatmap.astype(np.float16) 65 | 66 | @staticmethod 67 | @jit(nopython=True) 68 | def put_heatmap(heatmap, plane_idx, center, sigma): 69 | center_x, center_y = center 70 | _, height, width = heatmap.shape[:3] 71 | 72 | th = 4.6052 73 | delta = math.sqrt(th * 2) 74 | 75 | x0 = int(max(0, center_x - delta * sigma)) 76 | y0 = int(max(0, center_y - delta * sigma)) 77 | 78 | x1 = int(min(width, center_x + delta * sigma)) 79 | y1 = int(min(height, center_y + delta * sigma)) 80 | 81 | for y in range(y0, y1): 82 | for x in range(x0, x1): 83 | d = (x - center_x) ** 2 + (y - center_y) ** 2 84 | exp = d / 2.0 / sigma / sigma 85 | if exp > th: 86 | continue 87 | heatmap[plane_idx][y][x] = max(heatmap[plane_idx][y][x], math.exp(-exp)) 88 | heatmap[plane_idx][y][x] = min(heatmap[plane_idx][y][x], 1.0) 89 | -------------------------------------------------------------------------------- /images/6264.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ortegatron/hand_detector_train/06851e55ebd4ec36001382990d18a883030304c2/images/6264.jpg -------------------------------------------------------------------------------- /images/detection-stages.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ortegatron/hand_detector_train/06851e55ebd4ec36001382990d18a883030304c2/images/detection-stages.png -------------------------------------------------------------------------------- /images/hand_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ortegatron/hand_detector_train/06851e55ebd4ec36001382990d18a883030304c2/images/hand_sample.png -------------------------------------------------------------------------------- /images/hand_synth_sample1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ortegatron/hand_detector_train/06851e55ebd4ec36001382990d18a883030304c2/images/hand_synth_sample1.jpg -------------------------------------------------------------------------------- /images/hand_synth_sample2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ortegatron/hand_detector_train/06851e55ebd4ec36001382990d18a883030304c2/images/hand_synth_sample2.jpg -------------------------------------------------------------------------------- /images/hand_synth_sample3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ortegatron/hand_detector_train/06851e55ebd4ec36001382990d18a883030304c2/images/hand_synth_sample3.jpg -------------------------------------------------------------------------------- /models/numpy/download: -------------------------------------------------------------------------------- 1 | Download openpose_vgg16.npy from http://www.mediafire.com/file/7e73ddj31rzw6qq/openpose_vgg16.npy 2 | -------------------------------------------------------------------------------- /network_base.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 
import sys 4 | 5 | import abc 6 | import numpy as np 7 | import tensorflow as tf 8 | import tensorflow.contrib.slim as slim 9 | 10 | import common 11 | 12 | DEFAULT_PADDING = 'SAME' 13 | 14 | 15 | _init_xavier = tf.contrib.layers.xavier_initializer() 16 | _init_norm = tf.truncated_normal_initializer(stddev=0.01) 17 | _init_zero = slim.init_ops.zeros_initializer() 18 | _l2_regularizer_00004 = tf.contrib.layers.l2_regularizer(0.00004) 19 | _l2_regularizer_convb = tf.contrib.layers.l2_regularizer(common.regularizer_conv) 20 | 21 | 22 | def layer(op): 23 | ''' 24 | Decorator for composable network layers. 25 | ''' 26 | 27 | def layer_decorated(self, *args, **kwargs): 28 | # Automatically set a name if not provided. 29 | name = kwargs.setdefault('name', self.get_unique_name(op.__name__)) 30 | # Figure out the layer inputs. 31 | if len(self.terminals) == 0: 32 | raise RuntimeError('No input variables found for layer %s.' % name) 33 | elif len(self.terminals) == 1: 34 | layer_input = self.terminals[0] 35 | else: 36 | layer_input = list(self.terminals) 37 | # Perform the operation and get the output. 38 | layer_output = op(self, layer_input, *args, **kwargs) 39 | # Add to layer LUT. 40 | self.layers[name] = layer_output 41 | # This output is now the input for the next layer. 42 | self.feed(layer_output) 43 | # Return self for chained calls. 44 | return self 45 | 46 | return layer_decorated 47 | 48 | 49 | class BaseNetwork(object): 50 | def __init__(self, inputs, trainable=True): 51 | # The input nodes for this network 52 | self.inputs = inputs 53 | # The current list of terminal nodes 54 | self.terminals = [] 55 | # Mapping from layer names to layers 56 | self.layers = dict(inputs) 57 | # If true, the resulting variables are set as trainable 58 | self.trainable = trainable 59 | # Switch variable for dropout 60 | self.use_dropout = tf.placeholder_with_default(tf.constant(1.0), 61 | shape=[], 62 | name='use_dropout') 63 | self.setup() 64 | 65 | @abc.abstractmethod 66 | def setup(self): 67 | '''Construct the network. ''' 68 | raise NotImplementedError('Must be implemented by the subclass.') 69 | 70 | def load(self, data_path, session, ignore_missing=True): 71 | ''' 72 | Load network weights. 73 | data_path: The path to the numpy-serialized network weights 74 | session: The current TensorFlow session 75 | ignore_missing: If true, serialized weights for missing layers are ignored. 76 | ''' 77 | data_dict = np.load(data_path, encoding='bytes').item() 78 | print(data_dict.keys()) 79 | for op_name, param_dict in data_dict.items(): 80 | if isinstance(data_dict[op_name], np.ndarray): 81 | if 'RMSProp' in op_name: 82 | continue 83 | with tf.variable_scope('', reuse=True): 84 | var = tf.get_variable(op_name.replace(':0', '')) 85 | try: 86 | session.run(var.assign(data_dict[op_name])) 87 | except Exception as e: 88 | print(op_name) 89 | print(e) 90 | sys.exit(-1) 91 | else: 92 | op_name = common.to_str(op_name) 93 | # if op_name > 'conv4': 94 | # print(op_name, 'skipped') 95 | # continue 96 | # print(op_name, 'restored') 97 | with tf.variable_scope(op_name, reuse=True): 98 | for param_name, data in param_dict.items(): 99 | try: 100 | var = tf.get_variable(common.to_str(param_name)) 101 | session.run(var.assign(data)) 102 | except ValueError as e: 103 | print(e) 104 | if not ignore_missing: 105 | raise 106 | 107 | def feed(self, *args): 108 | '''Set the input(s) for the next operation by replacing the terminal nodes. 109 | The arguments can be either layer names or the actual layers. 
110 | ''' 111 | assert len(args) != 0 112 | self.terminals = [] 113 | for fed_layer in args: 114 | try: 115 | is_str = isinstance(fed_layer, basestring) 116 | except NameError: 117 | is_str = isinstance(fed_layer, str) 118 | if is_str: 119 | try: 120 | fed_layer = self.layers[fed_layer] 121 | except KeyError: 122 | raise KeyError('Unknown layer name fed: %s' % fed_layer) 123 | self.terminals.append(fed_layer) 124 | return self 125 | 126 | def get_output(self, name=None): 127 | '''Returns the current network output.''' 128 | if not name: 129 | return self.terminals[-1] 130 | else: 131 | return self.layers[name] 132 | 133 | def get_tensor(self, name): 134 | return self.get_output(name) 135 | 136 | def get_unique_name(self, prefix): 137 | '''Returns an index-suffixed unique name for the given prefix. 138 | This is used for auto-generating layer names based on the type-prefix. 139 | ''' 140 | ident = sum(t.startswith(prefix) for t, _ in self.layers.items()) + 1 141 | return '%s_%d' % (prefix, ident) 142 | 143 | def make_var(self, name, shape, trainable=True): 144 | '''Creates a new TensorFlow variable.''' 145 | return tf.get_variable(name, shape, trainable=self.trainable & trainable, initializer=tf.contrib.layers.xavier_initializer()) 146 | 147 | def validate_padding(self, padding): 148 | '''Verifies that the padding is one of the supported ones.''' 149 | assert padding in ('SAME', 'VALID') 150 | 151 | @layer 152 | def normalize_vgg(self, input, name): 153 | # normalize input -0.5 ~ 0.5 154 | input = tf.multiply(input, 1./ 256.0, name=name + '_divide') 155 | input = tf.add(input, -0.5, name=name + '_subtract') 156 | return input 157 | 158 | @layer 159 | def normalize_mobilenet(self, input, name): 160 | input = tf.divide(input, 128.0, name=name + '_divide') 161 | input = tf.subtract(input, 1.0, name=name + '_subtract') 162 | return input 163 | 164 | @layer 165 | def normalize_nasnet(self, input, name): 166 | input = tf.divide(input, 255.0, name=name + '_divide') 167 | input = tf.subtract(input, 0.5, name=name + '_subtract') 168 | input = tf.multiply(input, 2.0, name=name + '_multiply') 169 | return input 170 | 171 | @layer 172 | def upsample(self, input, factor, name): 173 | if isinstance(factor, str): 174 | sh = tf.shape(self.get_tensor(factor))[1:3] 175 | else: 176 | sh = tf.shape(input)[1:3] * factor 177 | return tf.image.resize_bilinear(input, sh, align_corners=False, name=name) 178 | 179 | @layer 180 | def separable_conv(self, input, k_h, k_w, c_o, stride, name, relu=True, set_bias=True): 181 | with slim.arg_scope([slim.batch_norm], decay=0.999, fused=common.batchnorm_fused, is_training=self.trainable): 182 | output = slim.separable_convolution2d(input, 183 | num_outputs=None, 184 | stride=stride, 185 | trainable=self.trainable, 186 | depth_multiplier=1.0, 187 | kernel_size=[k_h, k_w], 188 | # activation_fn=common.activation_fn if relu else None, 189 | activation_fn=None, 190 | # normalizer_fn=slim.batch_norm, 191 | weights_initializer=_init_xavier, 192 | # weights_initializer=_init_norm, 193 | weights_regularizer=_l2_regularizer_00004, 194 | biases_initializer=None, 195 | padding=DEFAULT_PADDING, 196 | scope=name + '_depthwise') 197 | 198 | output = slim.convolution2d(output, 199 | c_o, 200 | stride=1, 201 | kernel_size=[1, 1], 202 | activation_fn=common.activation_fn if relu else None, 203 | weights_initializer=_init_xavier, 204 | # weights_initializer=_init_norm, 205 | biases_initializer=_init_zero if set_bias else None, 206 | normalizer_fn=slim.batch_norm, 207 | 
trainable=self.trainable, 208 | weights_regularizer=None, 209 | scope=name + '_pointwise') 210 | 211 | return output 212 | 213 | @layer 214 | def convb(self, input, k_h, k_w, c_o, stride, name, relu=True, set_bias=True, set_tanh=False): 215 | with slim.arg_scope([slim.batch_norm], decay=0.999, fused=common.batchnorm_fused, is_training=self.trainable): 216 | output = slim.convolution2d(input, c_o, kernel_size=[k_h, k_w], 217 | stride=stride, 218 | normalizer_fn=slim.batch_norm, 219 | weights_regularizer=_l2_regularizer_convb, 220 | weights_initializer=_init_xavier, 221 | # weights_initializer=tf.truncated_normal_initializer(stddev=0.01), 222 | biases_initializer=_init_zero if set_bias else None, 223 | trainable=self.trainable, 224 | activation_fn=common.activation_fn if relu else None, 225 | scope=name) 226 | if set_tanh: 227 | output = tf.nn.sigmoid(output, name=name + '_extra_acv') 228 | return output 229 | 230 | @layer 231 | def conv(self, 232 | input, 233 | k_h, 234 | k_w, 235 | c_o, 236 | s_h, 237 | s_w, 238 | name, 239 | relu=True, 240 | padding=DEFAULT_PADDING, 241 | group=1, 242 | trainable=True, 243 | biased=True): 244 | # Verify that the padding is acceptable 245 | self.validate_padding(padding) 246 | # Get the number of channels in the input 247 | c_i = int(input.get_shape()[-1]) 248 | # Verify that the grouping parameter is valid 249 | assert c_i % group == 0 250 | assert c_o % group == 0 251 | # Convolution for a given input and kernel 252 | convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding) 253 | with tf.variable_scope(name) as scope: 254 | kernel = self.make_var('weights', shape=[k_h, k_w, c_i / group, c_o], trainable=self.trainable & trainable) 255 | if group == 1: 256 | # This is the common-case. Convolve the input without any further complications. 
257 | output = convolve(input, kernel) 258 | else: 259 | # Split the input into groups and then convolve each of them independently 260 | input_groups = tf.split(3, group, input) 261 | kernel_groups = tf.split(3, group, kernel) 262 | output_groups = [convolve(i, k) for i, k in zip(input_groups, kernel_groups)] 263 | # Concatenate the groups 264 | output = tf.concat(3, output_groups) 265 | # Add the biases 266 | if biased: 267 | biases = self.make_var('biases', [c_o], trainable=self.trainable & trainable) 268 | output = tf.nn.bias_add(output, biases) 269 | 270 | if relu: 271 | # ReLU non-linearity 272 | output = tf.nn.relu(output, name=scope.name) 273 | return output 274 | 275 | @layer 276 | def relu(self, input, name): 277 | return tf.nn.relu(input, name=name) 278 | 279 | @layer 280 | def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING): 281 | self.validate_padding(padding) 282 | return tf.nn.max_pool(input, 283 | ksize=[1, k_h, k_w, 1], 284 | strides=[1, s_h, s_w, 1], 285 | padding=padding, 286 | name=name) 287 | 288 | @layer 289 | def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING): 290 | self.validate_padding(padding) 291 | return tf.nn.avg_pool(input, 292 | ksize=[1, k_h, k_w, 1], 293 | strides=[1, s_h, s_w, 1], 294 | padding=padding, 295 | name=name) 296 | 297 | @layer 298 | def lrn(self, input, radius, alpha, beta, name, bias=1.0): 299 | return tf.nn.local_response_normalization(input, 300 | depth_radius=radius, 301 | alpha=alpha, 302 | beta=beta, 303 | bias=bias, 304 | name=name) 305 | 306 | @layer 307 | def concat(self, inputs, axis, name): 308 | return tf.concat(axis=axis, values=inputs, name=name) 309 | 310 | @layer 311 | def add(self, inputs, name): 312 | return tf.add_n(inputs, name=name) 313 | 314 | @layer 315 | def fc(self, input, num_out, name, relu=True): 316 | with tf.variable_scope(name) as scope: 317 | input_shape = input.get_shape() 318 | if input_shape.ndims == 4: 319 | # The input is spatial. Vectorize it first. 320 | dim = 1 321 | for d in input_shape[1:].as_list(): 322 | dim *= d 323 | feed_in = tf.reshape(input, [-1, dim]) 324 | else: 325 | feed_in, dim = (input, input_shape[-1].value) 326 | weights = self.make_var('weights', shape=[dim, num_out]) 327 | biases = self.make_var('biases', [num_out]) 328 | op = tf.nn.relu_layer if relu else tf.nn.xw_plus_b 329 | fc = op(feed_in, weights, biases, name=scope.name) 330 | return fc 331 | 332 | @layer 333 | def softmax(self, input, name): 334 | input_shape = map(lambda v: v.value, input.get_shape()) 335 | if len(input_shape) > 2: 336 | # For certain models (like NiN), the singleton spatial dimensions 337 | # need to be explicitly squeezed, since they're not broadcast-able 338 | # in TensorFlow's NHWC ordering (unlike Caffe's NCHW). 
339 | if input_shape[1] == 1 and input_shape[2] == 1: 340 | input = tf.squeeze(input, squeeze_dims=[1, 2]) 341 | else: 342 | raise ValueError('Rank 2 tensor input expected for softmax!') 343 | return tf.nn.softmax(input, name=name) 344 | 345 | @layer 346 | def batch_normalization(self, input, name, scale_offset=True, relu=False): 347 | # NOTE: Currently, only inference is supported 348 | with tf.variable_scope(name) as scope: 349 | shape = [input.get_shape()[-1]] 350 | if scale_offset: 351 | scale = self.make_var('scale', shape=shape) 352 | offset = self.make_var('offset', shape=shape) 353 | else: 354 | scale, offset = (None, None) 355 | output = tf.nn.batch_normalization( 356 | input, 357 | mean=self.make_var('mean', shape=shape), 358 | variance=self.make_var('variance', shape=shape), 359 | offset=offset, 360 | scale=scale, 361 | # TODO: This is the default Caffe batch norm eps 362 | # Get the actual eps from parameters 363 | variance_epsilon=1e-5, 364 | name=name) 365 | if relu: 366 | output = tf.nn.relu(output) 367 | return output 368 | 369 | @layer 370 | def dropout(self, input, keep_prob, name): 371 | keep = 1 - self.use_dropout + (self.use_dropout * keep_prob) 372 | return tf.nn.dropout(input, keep, name=name) 373 | 374 | @layer 375 | def se_block(self, input_feature, name, ratio=8): 376 | """Contains the implementation of Squeeze-and-Excitation block. 377 | As described in https://arxiv.org/abs/1709.01507. 378 | ref : https://github.com/kobiso/SENet-tensorflow-slim/blob/master/nets/attention_module.py 379 | """ 380 | 381 | kernel_initializer = tf.contrib.layers.variance_scaling_initializer() 382 | bias_initializer = tf.constant_initializer(value=0.0) 383 | 384 | with tf.variable_scope(name): 385 | channel = input_feature.get_shape()[-1] 386 | # Global average pooling 387 | squeeze = tf.reduce_mean(input_feature, axis=[1, 2], keepdims=True) 388 | excitation = tf.layers.dense(inputs=squeeze, 389 | units=channel // ratio, 390 | activation=tf.nn.relu, 391 | kernel_initializer=kernel_initializer, 392 | bias_initializer=bias_initializer, 393 | name='bottleneck_fc') 394 | excitation = tf.layers.dense(inputs=excitation, 395 | units=channel, 396 | activation=tf.nn.sigmoid, 397 | kernel_initializer=kernel_initializer, 398 | bias_initializer=bias_initializer, 399 | name='recover_fc') 400 | scale = input_feature * excitation 401 | return scale 402 | -------------------------------------------------------------------------------- /network_cmuhand.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import network_base 4 | import tensorflow as tf 5 | 6 | class CmuHandNetwork(network_base.BaseNetwork): 7 | def setup(self): 8 | (self.feed('image') 9 | .normalize_vgg(name='preprocess') 10 | .conv(3, 3, 64, 1, 1, name='conv1_1') 11 | .conv(3, 3, 64, 1, 1, name='conv1_2') 12 | .max_pool(2, 2, 2, 2, name='pool1_stage1', padding='VALID') 13 | .conv(3, 3, 128, 1, 1, name='conv2_1') 14 | .conv(3, 3, 128, 1, 1, name='conv2_2') 15 | .max_pool(2, 2, 2, 2, name='pool2_stage1', padding='VALID') 16 | .conv(3, 3, 256, 1, 1, name='conv3_1') 17 | .conv(3, 3, 256, 1, 1, name='conv3_2') 18 | .conv(3, 3, 256, 1, 1, name='conv3_3') 19 | .conv(3, 3, 256, 1, 1, name='conv3_4') 20 | .max_pool(2, 2, 2, 2, name='pool3_stage1', padding='VALID') 21 | .conv(3, 3, 512, 1, 1, name='conv4_1') 22 | .conv(3, 3, 512, 1, 1, name='conv4_2') 23 | .conv(3, 3, 512, 1, 1, name='conv4_3') 24 | .conv(3, 3, 512, 1, 1, name='conv4_4') 25 | .conv(3, 3, 512, 1, 
1, name='conv5_1') 26 | .conv(3, 3, 512, 1, 1, name='conv5_2') 27 | .conv(3, 3, 128, 1, 1, name='conv5_3_CPM') # ***** 28 | .conv(3, 3, 512, 1, 1, name='conv6_1_CPM') 29 | .conv(3, 3, 22, 1, 1, relu = False, name='conv6_2_CPM')) 30 | 31 | (self.feed('conv5_3_CPM', 32 | 'conv6_2_CPM',) 33 | .concat(3, name='concat_stage2') 34 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage2') 35 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage2') 36 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage2') 37 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage2') 38 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage2') 39 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2') 40 | .conv(1, 1, 22, 1, 1, relu=False, name='Mconv7_stage2')) 41 | 42 | (self.feed('conv5_3_CPM', 43 | 'Mconv7_stage2',) 44 | .concat(3, name='concat_stage3') 45 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage3') 46 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage3') 47 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage3') 48 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage3') 49 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage3') 50 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3') 51 | .conv(1, 1, 22, 1, 1, relu=False, name='Mconv7_stage3')) 52 | 53 | (self.feed('conv5_3_CPM', 54 | 'Mconv7_stage3',) 55 | .concat(3, name='concat_stage4') 56 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage4') 57 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage4') 58 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage4') 59 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage4') 60 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage4') 61 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4') 62 | .conv(1, 1, 22, 1, 1, relu=False, name='Mconv7_stage4')) 63 | 64 | (self.feed('conv5_3_CPM', 65 | 'Mconv7_stage4',) 66 | .concat(3, name='concat_stage5') 67 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage5') 68 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage5') 69 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage5') 70 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage5') 71 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage5') 72 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5') 73 | .conv(1, 1, 22, 1, 1, relu=False, name='Mconv7_stage5')) 74 | 75 | (self.feed('conv5_3_CPM', 76 | 'Mconv7_stage5',) 77 | .concat(3, name='concat_stage6') 78 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage6') 79 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage6') 80 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage6') 81 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage6') 82 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage6') 83 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6')) 84 | 85 | (self.feed('Mconv6_stage6') 86 | .conv(1, 1, 22, 1, 1, relu=False, name='Mconv7_stage6')) 87 | 88 | with tf.variable_scope('Openpose'): 89 | (self.feed('Mconv7_stage6') 90 | .concat(3, name='out')) 91 | 92 | 93 | def loss_l2(self): 94 | l2s = [] 95 | for layer_name in self.layers.keys(): 96 | if 'Mconv7' in layer_name: 97 | l2s.append(self.layers[layer_name]) 98 | return l2s 99 | 100 | def loss_last(self): 101 | return self.get_output('Mconv7_stage6') 102 | 103 | def restorable_variables(self): 104 | return None 105 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import dirname, abspath 3 | import tensorflow as tf 4 | import network_base 5 | 6 | from network_cmuhand import CmuHandNetwork 7 | 8 | 9 | def _get_base_path(): 10 | if not os.environ.get('OPENPOSE_MODEL', ''): 11 | return './models' 12 | return os.environ.get('OPENPOSE_MODEL') 13 | 14 | def get_network(type, placeholder_input, sess_for_load=None, 
trainable=True):
15 |     if type == "vgg":
16 |         net = CmuHandNetwork({'image': placeholder_input}, trainable=trainable)
17 |         pretrain_path = 'numpy/openpose_vgg16.npy'
18 |         last_layer = 'Mconv7_stage6_L{aux}'
19 |     else:
20 |         raise Exception('Invalid Model Name.')
21 | 
22 |     pretrain_path_full = os.path.join(_get_base_path(), pretrain_path)
23 |     if sess_for_load is not None:
24 |         if not os.path.isfile(pretrain_path_full):
25 |             raise Exception('Model file doesn\'t exist, path=%s' % pretrain_path_full)
26 |         net.load(os.path.join(_get_base_path(), pretrain_path), sess_for_load)
27 | 
28 |     return net, pretrain_path_full, last_layer
29 | 
--------------------------------------------------------------------------------
/render_data.py:
--------------------------------------------------------------------------------
 1 | import cv2
 2 | import numpy as np
 3 | 
 4 | # Hand skeleton: keypoint index pairs forming the limbs, with one BGR color per keypoint/limb.
 5 | HandKeypointsPairs = [[0,1],[1,2], [2,3], [3,4],[0,5],[5,6], [6,7],[7,8], [0,9],
 6 | [9,10],[10,11], [11,12], [0,13],[13,14], [14,15], [15,16], [0,17], [17,18], [18,19], [19,20]]
 7 | HandKeypointsColors = [[100, 100, 100], \
 8 | [100, 0, 0], \
 9 | [150, 0, 0], \
10 | [200, 0, 0], \
11 | [255, 0, 0], \
12 | [100, 100, 0], \
13 | [150, 150, 0], \
14 | [200, 200, 0], \
15 | [255, 255, 0], \
16 | [0, 100, 50], \
17 | [0, 150, 75], \
18 | [0, 200, 100], \
19 | [0, 255, 125], \
20 | [0, 50, 100], \
21 | [0, 75, 150], \
22 | [0, 100, 200], \
23 | [0, 125, 255], \
24 | [100, 0, 100], \
25 | [150, 0, 150], \
26 | [200, 0, 200], \
27 | [255, 0, 255]]
28 | HandKeypoints = 21
29 | 
30 | def draw_humans(img, hand_list):
31 |     img_copied = np.copy(img)
32 |     image_h, image_w = img_copied.shape[:2]
33 |     centers = {}
34 |     for hand in hand_list:
35 |         part_idxs = hand.keys()
36 | 
37 |         # draw point
38 |         for i in range(HandKeypoints):
39 |             if i not in part_idxs:
40 |                 continue
41 |             part_coord = hand[i][0:2]
42 |             center = (int(part_coord[1] * image_w + 0.5), int(part_coord[0] * image_h + 0.5))
43 |             centers[i] = center
44 |             cv2.circle(img_copied, center, 3, HandKeypointsColors[i], thickness=3, lineType=8, shift=0)
45 | 
46 |         # draw line
47 |         for pair_order, pair in enumerate(HandKeypointsPairs):
48 |             if pair[0] not in part_idxs or pair[1] not in part_idxs:
49 |                 continue
50 | 
51 |             img_copied = cv2.line(img_copied, centers[pair[0]], centers[pair[1]], HandKeypointsColors[pair_order], 3)
52 | 
53 |     return img_copied
54 | 
55 | def render_image(path, hand_list):
56 |     img = cv2.imread(path)
57 |     image = draw_humans(img, hand_list)
58 |     image = cv2.resize(image, (368, 368), interpolation=cv2.INTER_AREA)
59 |     cv2.imshow('result', image)
60 |     cv2.waitKey(0)
61 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | argparse
 2 | dill
 3 | fire
 4 | matplotlib
 5 | numba
 6 | psutil
 7 | pycocotools
 8 | requests
 9 | scikit-image
10 | scipy
11 | slidingwindow
12 | tqdm
13 | git+https://github.com/ppwwyyxx/tensorpack.git
14 | 
--------------------------------------------------------------------------------
/run_checkpoint.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | import logging
 3 | import os
 4 | 
 5 | import tensorflow as tf
 6 | from networks import get_network
 7 | 
 8 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
 9 | 
10 | config = tf.ConfigProto()
11 | config.gpu_options.allocator_type = 'BFC'
12 | config.gpu_options.per_process_gpu_memory_fraction = 0.95
13 | config.gpu_options.allow_growth = True
14 | 
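# Note: this script rebuilds the CmuHandNetwork for inference only (no input queue or
# loss tensors) and writes its graph definition to ./tmp/graph_definition.pb, which is
# the --input_graph expected by the freeze_graph step described in the README. The
# placeholder below uses w = h = None, so no input resolution is hard-coded in the graph.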
15 | 16 | if __name__ == '__main__': 17 | """ 18 | Use this script to just save graph and checkpoint. 19 | While training, checkpoints are saved. You can test them with this python code. 20 | """ 21 | parser = argparse.ArgumentParser(description='Tensorflow Pose Estimation Graph Extractor') 22 | args = parser.parse_args() 23 | 24 | w = h = None 25 | input_node = tf.placeholder(tf.float32, shape=(None, h, w, 3), name='image') 26 | net, pretrain_path, last_layer = get_network("vgg", input_node) 27 | 28 | with tf.Session(config=config) as sess: 29 | net.load(pretrain_path, sess, True) 30 | tf.train.write_graph(sess.graph_def, './tmp', 'graph_definition.pb', as_text=True) 31 | flops = tf.profiler.profile(None, cmd='graph', options=tf.profiler.ProfileOptionBuilder.float_operation()) 32 | print('FLOP = ', flops.total_float_ops / float(1e6)) 33 | -------------------------------------------------------------------------------- /synthhands.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import multiprocessing 4 | import struct 5 | import sys 6 | import json 7 | from PIL import Image 8 | try: 9 | from StringIO import StringIO 10 | except ImportError: 11 | from io import StringIO 12 | import os 13 | import cv2 14 | import numpy as np 15 | import time 16 | import glob 17 | 18 | from numba import jit 19 | 20 | from hands_metadata import HandsMetadata 21 | 22 | from tensorpack.dataflow.base import RNGDataFlow, DataFlowTerminated 23 | 24 | class SynthHands(RNGDataFlow): 25 | @staticmethod 26 | def display_image(inp, heatmap, as_numpy=False): 27 | 28 | import matplotlib.pyplot as plt 29 | 30 | fig = plt.figure() 31 | a = fig.add_subplot(2, 2, 1) 32 | a.set_title('Image') 33 | plt.imshow(SynthHands.get_bgimg(inp)) 34 | 35 | a = fig.add_subplot(2, 2, 2) 36 | a.set_title('Heatmap') 37 | plt.imshow(SynthHands.get_bgimg(inp, target_size=(heatmap.shape[1], heatmap.shape[0])), alpha=0.5) 38 | tmp = np.amax(heatmap, axis=2) 39 | plt.imshow(tmp, cmap=plt.cm.gray, alpha=0.5) 40 | plt.colorbar() 41 | 42 | if not as_numpy: 43 | plt.show() 44 | else: 45 | fig.canvas.draw() 46 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 47 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 48 | fig.clear() 49 | plt.close() 50 | return data 51 | 52 | @staticmethod 53 | def get_bgimg(inp, target_size=None): 54 | inp = cv2.cvtColor(inp.astype(np.uint8), cv2.COLOR_BGR2RGB) 55 | if target_size: 56 | inp = cv2.resize(inp, target_size, interpolation=cv2.INTER_AREA) 57 | return inp 58 | 59 | def __init__(self, path, is_train=True): 60 | self.is_train = is_train 61 | self.path = path 62 | self.idxs = [] 63 | for i in [2,3]: 64 | synth_idx = "synth" + str(i) + "/" 65 | path = self.path + "/" + synth_idx 66 | json_files = [f for f in glob.glob(path + "*.json")] 67 | # Only saves /synthX/XXX 68 | self.idxs += [synth_idx + os.path.basename(j).split(".")[0] for j in json_files] 69 | 70 | def size(self): 71 | return len(self.idxs) 72 | 73 | def get_data(self): 74 | idxs = np.arange(self.size()) 75 | if self.is_train: 76 | self.rng.shuffle(idxs) 77 | else: 78 | pass 79 | for idx in self.idxs: 80 | json_path = self.path + "/" + idx + ".json" 81 | img_url = self.path + "/" + idx + ".jpg" 82 | 83 | img_meta = {} 84 | img_meta['width'] , img_meta['height'] = Image.open(img_url).size 85 | with open(json_path) as json_file: 86 | data = json.load(json_file) 87 | annotation = {} 88 | annotation['keypoints'] = data['hand_pts'] 89 | meta = 
HandsMetadata(idx, img_url, img_meta, [annotation], sigma=8.0) 90 | yield [meta] 91 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import sys 3 | import cv2 4 | import os 5 | from sys import platform 6 | import argparse 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | from networks import get_network 10 | from synthhands import SynthHands 11 | 12 | 13 | 14 | def read_imgfile(path, width, height): 15 | img = cv2.imread(path) 16 | if img.shape[0] != width or img.shape[1] != height: 17 | raise Exception('Image size must be 368x368!') 18 | return img.astype(np.float16) 19 | 20 | 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser(description='Test trained detector') 23 | parser.add_argument('--graph-path', type=str, default='./models/frozengraph.pb') 24 | parser.add_argument('--image-path', type=str, default='./images/hand_sample.png') 25 | args = parser.parse_args() 26 | 27 | img = read_imgfile(args.image_path,368,368) 28 | 29 | with tf.gfile.GFile(args.graph_path, 'rb') as f: 30 | graph_def = tf.GraphDef() 31 | graph_def.ParseFromString(f.read()) 32 | 33 | graph = tf.get_default_graph() 34 | tf.import_graph_def(graph_def, name='CmuHand') 35 | 36 | tf_config= tf.ConfigProto(allow_soft_placement=True, log_device_placement=True) 37 | with tf.Session(config=tf_config) as sess: 38 | graph = tf.get_default_graph() 39 | inputs = graph.get_tensor_by_name('CmuHand/image:0') 40 | out = graph.get_tensor_by_name('CmuHand/Openpose/out:0') 41 | stage_5 = graph.get_tensor_by_name('CmuHand/Mconv7_stage5/BiasAdd:0') 42 | stage_4 = graph.get_tensor_by_name('CmuHand/Mconv7_stage4/BiasAdd:0') 43 | stage_3 = graph.get_tensor_by_name('CmuHand/Mconv7_stage3/BiasAdd:0') 44 | stage_2 = graph.get_tensor_by_name('CmuHand/Mconv7_stage2/BiasAdd:0') 45 | stage_1 = graph.get_tensor_by_name('CmuHand/conv6_2_CPM/BiasAdd:0') 46 | stages_outs = sess.run([stage_1, stage_2, stage_3, stage_4, stage_5, out], feed_dict={ 47 | inputs: [img] 48 | }) 49 | last_stage = stages_outs[-1][0] 50 | 51 | print("Belief maps for last stage, one for each keypoint plus and additional one for the background") 52 | fig, ax = plt.subplots(nrows=5, ncols=5) 53 | index = 0 54 | for row in ax: 55 | for col in row: 56 | col.imshow(last_stage[:,:,index]) 57 | index += 1 58 | if index >= 22: 59 | break 60 | plt.show() 61 | 62 | print("Draws last stage belief maps on top of image") 63 | test_result = SynthHands.display_image(img, last_stage, as_numpy=True) 64 | test_result = cv2.cvtColor(test_result, cv2.COLOR_RGB2BGR) 65 | cv2.imshow("Belief Maps",test_result) 66 | cv2.waitKey(0) 67 | 68 | 69 | print("Draws each stage belief maps") 70 | fig = plt.figure() 71 | for i, stage_out in enumerate(stages_outs): 72 | stage_out= stage_out[0] 73 | a = fig.add_subplot(3, 2, i+1) 74 | a.set_title('Stage #{}'.format(i+1)) 75 | plt.imshow(SynthHands.get_bgimg(img, target_size=(stage_out.shape[1], stage_out.shape[0])), alpha=0.5) 76 | tmp = np.amax(stage_out, axis=2) 77 | plt.imshow(tmp, cmap=plt.cm.gray, alpha=0.5) 78 | plt.colorbar() 79 | plt.show() 80 | 81 | cv2.destroyAllWindows() 82 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import time 5 | 6 | import cv2 7 | import numpy as np 8 | import 
tensorflow as tf 9 | from tqdm import tqdm 10 | 11 | from networks import get_network 12 | from synthhands import SynthHands 13 | from hands_dataset import get_dataflow_batch 14 | from dataflow import DataFlowToQueue 15 | from common import get_sample_images 16 | 17 | logger = logging.getLogger('train') 18 | logger.handlers.clear() 19 | logger.setLevel(logging.DEBUG) 20 | ch = logging.StreamHandler() 21 | ch.setLevel(logging.DEBUG) 22 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s') 23 | ch.setFormatter(formatter) 24 | logger.addHandler(ch) 25 | 26 | 27 | if __name__ == '__main__': 28 | parser = argparse.ArgumentParser(description='Training codes for Openpose using Tensorflow') 29 | parser.add_argument('--datapath', type=str, default='../hand_labels_synth') 30 | parser.add_argument('--batchsize', type=int, default=64) 31 | parser.add_argument('--input-width', type=int, default=368) 32 | parser.add_argument('--input-height', type=int, default=368) 33 | parser.add_argument('--gpus', type=int, default=4) 34 | parser.add_argument('--checkpoint', type=str, default='') 35 | parser.add_argument('--lr', type=str, default='0.0001') 36 | parser.add_argument('--max-epoch', type=int, default=600) 37 | parser.add_argument('--tag', type=str, default='test') 38 | args = parser.parse_args() 39 | 40 | modelpath = logpath = './models/train/' 41 | 42 | scale = 8 43 | output_w, output_h = args.input_width // scale, args.input_height // scale 44 | logger.info('define model+') 45 | with tf.device(tf.DeviceSpec(device_type="CPU")): 46 | input_node = tf.placeholder(tf.float32, shape=(args.batchsize, args.input_height, args.input_width, 3), name='image') 47 | heatmap_node = tf.placeholder(tf.float32, shape=(args.batchsize, output_h, output_w, 22), name='heatmap') 48 | 49 | # prepare data 50 | df = get_dataflow_batch(args.datapath, True, args.batchsize) 51 | enqueuer = DataFlowToQueue(df, [input_node, heatmap_node], queue_size=100) 52 | q_inp, q_heat = enqueuer.dequeue() 53 | 54 | df_valid = get_dataflow_batch(args.datapath, False, args.batchsize) 55 | df_valid.reset_state() 56 | validation_cache = [] 57 | 58 | val_image = get_sample_images(args.input_width, args.input_height) 59 | logger.debug('tensorboard val image: %d' % len(val_image)) 60 | logger.debug(q_inp) 61 | logger.debug(q_heat) 62 | 63 | # define model for multi-gpu 64 | q_inp_split, q_heat_split = tf.split(q_inp, args.gpus), tf.split(q_heat, args.gpus) 65 | 66 | output_vectmap = [] 67 | output_heatmap = [] 68 | losses = [] 69 | last_losses_l2 = [] 70 | outputs = [] 71 | for gpu_id in range(args.gpus): 72 | with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)): 73 | with tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)): 74 | net, pretrain_path, last_layer = get_network("vgg", q_inp_split[gpu_id]) 75 | if args.checkpoint: 76 | pretrain_path = args.checkpoint 77 | heat = net.loss_last() 78 | output_heatmap.append(heat) 79 | outputs.append(net.get_output()) 80 | 81 | l2s = net.loss_l2() 82 | for idx, l2 in enumerate(l2s): 83 | loss_l2 = tf.nn.l2_loss(tf.concat(l2, axis=0) - q_heat_split[gpu_id], name='loss_l2_stage%d_tower%d' % (idx, gpu_id)) 84 | losses.append(tf.reduce_mean([loss_l2])) 85 | 86 | last_losses_l2.append(loss_l2) 87 | 88 | outputs = tf.concat(outputs, axis=0) 89 | 90 | with tf.device(tf.DeviceSpec(device_type="GPU")): 91 | # define loss 92 | total_loss = tf.reduce_sum(losses) / args.batchsize 93 | total_loss_ll_heat = tf.reduce_sum(last_losses_l2) / args.batchsize 94 | 95 | 
# define optimizer
 96 |         step_per_epoch = 121745 // args.batchsize
 97 |         global_step = tf.Variable(0, trainable=False)
 98 |         if ',' not in args.lr:
 99 |             starter_learning_rate = float(args.lr)
100 |             # learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
101 |             #                                             decay_steps=10000, decay_rate=0.33, staircase=True)
102 |             learning_rate = tf.train.cosine_decay(starter_learning_rate, global_step, args.max_epoch * step_per_epoch, alpha=0.0)
103 |         else:
104 |             lrs = [float(x) for x in args.lr.split(',')]
105 |             boundaries = [step_per_epoch * 5 * i for i, _ in enumerate(lrs) if i > 0]  # switch to the next learning rate every 5 epochs
106 |             learning_rate = tf.train.piecewise_constant(global_step, boundaries, lrs)
107 | 
108 | 
109 |         # optimizer = tf.train.RMSPropOptimizer(learning_rate, decay=0.0005, momentum=0.9, epsilon=1e-10)
110 |         optimizer = tf.train.AdamOptimizer(learning_rate, epsilon=1e-8)
111 |         # optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.8, use_locking=True, use_nesterov=True)
112 |         update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
113 |         with tf.control_dependencies(update_ops):
114 |             train_op = optimizer.minimize(total_loss, global_step, colocate_gradients_with_ops=True)
115 |     logger.info('define model-')
116 | 
117 |     # define summary
118 |     tf.summary.scalar("loss", total_loss)
119 |     tf.summary.scalar("loss_lastlayer_heat", total_loss_ll_heat)
120 |     tf.summary.scalar("queue_size", enqueuer.size())
121 |     tf.summary.scalar("lr", learning_rate)
122 |     merged_summary_op = tf.summary.merge_all()
123 | 
124 |     valid_loss = tf.placeholder(tf.float32, shape=[])
125 |     valid_loss_ll_heat = tf.placeholder(tf.float32, shape=[])
126 |     sample_train = tf.placeholder(tf.float32, shape=(4, 640, 640, 3))
127 |     sample_valid = tf.placeholder(tf.float32, shape=(4, 640, 640, 3))
128 |     train_img = tf.summary.image('training sample', sample_train, 4)
129 |     valid_img = tf.summary.image('validation sample', sample_valid, 12)
130 |     valid_loss_t = tf.summary.scalar("loss_valid", valid_loss)
131 |     merged_validate_op = tf.summary.merge([train_img, valid_img, valid_loss_t])
132 | 
133 | 
134 |     saver = tf.train.Saver(max_to_keep=1000)
135 |     config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
136 |     config.gpu_options.allow_growth = True
137 |     with tf.Session(config=config) as sess:
138 |         logger.info('model weights initialization')
139 |         sess.run(tf.global_variables_initializer())
140 | 
141 |         if args.checkpoint and os.path.isdir(args.checkpoint):
142 |             logger.info('Restore from checkpoint...')
143 |             # loader = tf.train.Saver(net.restorable_variables())
144 |             # loader.restore(sess, tf.train.latest_checkpoint(args.checkpoint))
145 |             saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint))
146 |             logger.info('Restore from checkpoint...Done')
147 |         elif pretrain_path:
148 |             logger.info('Restore pretrained weights...
%s' % pretrain_path) 149 | if '.npy' in pretrain_path: 150 | net.load(pretrain_path, sess, True) 151 | else: 152 | try: 153 | loader = tf.train.Saver(net.restorable_variables(only_backbone=False)) 154 | loader.restore(sess, pretrain_path) 155 | except: 156 | logger.info('Restore only weights in backbone layers.') 157 | loader = tf.train.Saver(net.restorable_variables()) 158 | loader.restore(sess, pretrain_path) 159 | logger.info('Restore pretrained weights...Done') 160 | 161 | logger.info('prepare file writer') 162 | file_writer = tf.summary.FileWriter(os.path.join(logpath, args.tag), sess.graph) 163 | 164 | logger.info('prepare coordinator') 165 | coord = tf.train.Coordinator() 166 | enqueuer.set_coordinator(coord) 167 | enqueuer.start() 168 | 169 | logger.info('Training Started.') 170 | time_started = time.time() 171 | last_gs_num = last_gs_num2 = 0 172 | initial_gs_num = sess.run(global_step) 173 | 174 | last_log_epoch1 = last_log_epoch2 = -1 175 | while True: 176 | _, gs_num = sess.run([train_op, global_step]) 177 | curr_epoch = float(gs_num) / step_per_epoch 178 | 179 | if gs_num > step_per_epoch * args.max_epoch: 180 | break 181 | 182 | if gs_num - last_gs_num >= 500: 183 | train_loss, train_loss_ll_heat, lr_val, summary = sess.run([total_loss, total_loss_ll_heat, learning_rate, merged_summary_op]) 184 | 185 | # log of training loss / accuracy 186 | batch_per_sec = (gs_num - initial_gs_num) / (time.time() - time_started) 187 | logger.info('epoch=%.2f step=%d, %0.4f examples/sec lr=%f, loss=%g, loss_ll_heat=%g' % (gs_num / step_per_epoch, gs_num, batch_per_sec * args.batchsize, lr_val, train_loss, train_loss_ll_heat)) 188 | last_gs_num = gs_num 189 | 190 | if last_log_epoch1 < curr_epoch: 191 | file_writer.add_summary(summary, curr_epoch) 192 | last_log_epoch1 = curr_epoch 193 | 194 | if gs_num - last_gs_num2 >= 2000: 195 | # save weights 196 | saver.save(sess, os.path.join(modelpath, args.tag, 'model_latest'), global_step=global_step) 197 | 198 | average_loss = average_loss_ll_heat = 0 199 | total_cnt = 0 200 | 201 | if len(validation_cache) == 0: 202 | for images_test, heatmaps in tqdm(df_valid.get_data()): 203 | validation_cache.append((images_test, heatmaps)) 204 | df_valid.reset_state() 205 | del df_valid 206 | df_valid = None 207 | 208 | # log of test accuracy 209 | for images_test, heatmaps in validation_cache: 210 | lss, lss_ll_heat, vectmap_sample, heatmap_sample = sess.run( 211 | [total_loss, total_loss_ll_heat, output_vectmap, output_heatmap], 212 | feed_dict={q_inp: images_test, q_heat: heatmaps} 213 | ) 214 | average_loss += lss * len(images_test) 215 | average_loss_ll_heat += lss_ll_heat * len(images_test) 216 | total_cnt += len(images_test) 217 | 218 | logger.info('validation(%d) %s loss=%f, loss_ll_heat=%f' % (total_cnt, args.tag, average_loss / total_cnt, average_loss_ll_heat / total_cnt)) 219 | last_gs_num2 = gs_num 220 | 221 | sample_image = [enqueuer.last_dp[0][i] for i in range(4)] 222 | outputMat = sess.run( 223 | outputs, 224 | feed_dict={q_inp: np.array((sample_image + val_image) * max(1, (args.batchsize // 8)))} 225 | ) 226 | heatMat = outputMat[:, :, :, :19] 227 | 228 | sample_results = [] 229 | for i in range(len(sample_image)): 230 | test_result = SynthHands.display_image(sample_image[i], heatMat[i], as_numpy=True) 231 | test_result = cv2.resize(test_result, (640, 640)) 232 | test_result = test_result.reshape([640, 640, 3]).astype(float) 233 | sample_results.append(test_result) 234 | 235 | test_results = [] 236 | for i in range(len(val_image)): 237 | 
test_result = SynthHands.display_image(val_image[i], heatMat[len(sample_image) + i], as_numpy=True) 238 | test_result = cv2.resize(test_result, (640, 640)) 239 | test_result = test_result.reshape([640, 640, 3]).astype(float) 240 | test_results.append(test_result) 241 | 242 | # save summary 243 | summary = sess.run(merged_validate_op, feed_dict={ 244 | valid_loss: average_loss / total_cnt, 245 | valid_loss_ll_heat: average_loss_ll_heat / total_cnt, 246 | sample_valid: test_results, 247 | sample_train: sample_results 248 | }) 249 | if last_log_epoch2 < curr_epoch: 250 | file_writer.add_summary(summary, curr_epoch) 251 | last_log_epoch2 = curr_epoch 252 | 253 | saver.save(sess, os.path.join(modelpath, args.tag, 'model'), global_step=global_step) 254 | logger.info('optimization finished. %f' % (time.time() - time_started)) 255 | --------------------------------------------------------------------------------