├── .gitignore ├── CMakeLists.txt ├── Dockerfile ├── README.md ├── images ├── sitting.jpeg └── standing.jpeg ├── models ├── __init__.py ├── graph │ ├── cmu │ │ ├── __init__.py │ │ └── download.sh │ ├── mobilenet_thin │ │ ├── __init__.py │ │ ├── graph.pb │ │ ├── graph_freeze.pb │ │ └── graph_opt.pb │ └── retrained │ │ ├── retrained_v1.0 │ │ ├── retrained_graph.pb │ │ └── retrained_labels.txt │ │ └── retrained_v2.0 │ │ ├── retrained_graph.pb │ │ └── retrained_labels.txt ├── numpy │ └── download.sh └── pretrained │ ├── mobilenet_v1_0.50_224_2017_06_14 │ └── download.sh │ ├── mobilenet_v1_0.75_224_2017_06_14 │ └── download.sh │ ├── mobilenet_v1_1.0_224_2017_06_14 │ └── download.sh │ └── resnet_v2_101 │ └── download.sh ├── requirements.txt ├── requirements_alt.txt ├── run_webcam.py ├── scripts ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── label_image.cpython-36.pyc ├── broadcaster_ros.py ├── count_ops.py ├── evaluate.py ├── graph_pb2tb.py ├── label_image.py ├── quantize_graph.py ├── retrain.py ├── show_image.py └── visualization.py └── tf_pose ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── common.cpython-36.pyc ├── estimator.cpython-36.pyc ├── network_base.cpython-36.pyc ├── network_cmu.cpython-36.pyc ├── network_mobilenet.cpython-36.pyc ├── network_mobilenet_thin.cpython-36.pyc ├── network_personlab.cpython-36.pyc ├── networks.cpython-36.pyc └── runner.cpython-36.pyc ├── common.py ├── datum_pb2.py ├── estimator.py ├── eval.py ├── network_base.py ├── network_cmu.py ├── network_dsconv.py ├── network_mobilenet.py ├── network_mobilenet_thin.py ├── network_personlab.py ├── networks.py ├── pafprocess ├── README.md ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── pafprocess.cpython-36.pyc ├── _pafprocess.cp36-win_amd64.pyd ├── build │ └── temp.win-amd64-3.6 │ │ └── Release │ │ ├── _pafprocess.cp36-win_amd64.exp │ │ ├── _pafprocess.cp36-win_amd64.lib │ │ ├── pafprocess.obj │ │ └── pafprocess_wrap.obj ├── numpy.i ├── pafprocess.cpp ├── pafprocess.h ├── pafprocess.i ├── pafprocess.py ├── pafprocess_wrap.cpp ├── pafprocess_wrap.cxx └── setup.py ├── pose_augment.py ├── pose_datamaster.py ├── pose_dataset.py ├── pose_dataworker.py ├── pycocotools ├── __init__.py ├── _mask.pyx ├── coco.py ├── cocoeval.py └── mask.py ├── pystopwatch.py ├── runner.py ├── slidingwindow ├── ArrayUtils.py ├── Batching.py ├── Merging.py ├── RectangleUtils.py ├── SlidingWindow.py ├── WindowDistance.py └── __init__.py ├── slim ├── BUILD ├── README.md ├── WORKSPACE ├── __init__.py ├── __pycache__ │ └── __init__.cpython-36.pyc ├── datasets │ ├── __init__.py │ ├── build_imagenet_data.py │ ├── cifar10.py │ ├── dataset_factory.py │ ├── dataset_utils.py │ ├── download_and_convert_cifar10.py │ ├── download_and_convert_flowers.py │ ├── download_and_convert_imagenet.sh │ ├── download_and_convert_mnist.py │ ├── download_imagenet.sh │ ├── flowers.py │ ├── imagenet.py │ ├── imagenet_2012_validation_synset_labels.txt │ ├── imagenet_lsvrc_2015_synsets.txt │ ├── imagenet_metadata.txt │ ├── mnist.py │ ├── preprocess_imagenet_validation_data.py │ └── process_bounding_boxes.py ├── deployment │ ├── __init__.py │ ├── model_deploy.py │ └── model_deploy_test.py ├── download_and_convert_data.py ├── eval_image_classifier.py ├── export_inference_graph.py ├── export_inference_graph_test.py ├── nets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── resnet_utils.cpython-36.pyc │ │ └── resnet_v2.cpython-36.pyc │ ├── alexnet.py │ ├── alexnet_test.py │ ├── cifarnet.py │ ├── 
cyclegan.py │ ├── cyclegan_test.py │ ├── dcgan.py │ ├── dcgan_test.py │ ├── inception.py │ ├── inception_resnet_v2.py │ ├── inception_resnet_v2_test.py │ ├── inception_utils.py │ ├── inception_v1.py │ ├── inception_v1_test.py │ ├── inception_v2.py │ ├── inception_v2_test.py │ ├── inception_v3.py │ ├── inception_v3_test.py │ ├── inception_v4.py │ ├── inception_v4_test.py │ ├── lenet.py │ ├── mobilenet │ │ ├── README.md │ │ ├── __init__.py │ │ ├── conv_blocks.py │ │ ├── madds_top1_accuracy.png │ │ ├── mnet_v1_vs_v2_pixel1_latency.png │ │ ├── mobilenet.py │ │ ├── mobilenet_example.ipynb │ │ ├── mobilenet_v2.py │ │ └── mobilenet_v2_test.py │ ├── mobilenet_v1.md │ ├── mobilenet_v1.png │ ├── mobilenet_v1.py │ ├── mobilenet_v1_eval.py │ ├── mobilenet_v1_test.py │ ├── mobilenet_v1_train.py │ ├── nasnet │ │ ├── README.md │ │ ├── __init__.py │ │ ├── nasnet.py │ │ ├── nasnet_test.py │ │ ├── nasnet_utils.py │ │ ├── nasnet_utils_test.py │ │ ├── pnasnet.py │ │ └── pnasnet_test.py │ ├── nets_factory.py │ ├── nets_factory_test.py │ ├── overfeat.py │ ├── overfeat_test.py │ ├── pix2pix.py │ ├── pix2pix_test.py │ ├── resnet_utils.py │ ├── resnet_v1.py │ ├── resnet_v1_test.py │ ├── resnet_v2.py │ ├── resnet_v2_test.py │ ├── vgg.py │ └── vgg_test.py ├── preprocessing │ ├── __init__.py │ ├── cifarnet_preprocessing.py │ ├── inception_preprocessing.py │ ├── lenet_preprocessing.py │ ├── preprocessing_factory.py │ └── vgg_preprocessing.py ├── scripts │ ├── export_mobilenet.sh │ ├── finetune_inception_resnet_v2_on_flowers.sh │ ├── finetune_inception_v1_on_flowers.sh │ ├── finetune_inception_v3_on_flowers.sh │ ├── finetune_resnet_v1_50_on_flowers.sh │ ├── train_cifarnet_on_cifar10.sh │ └── train_lenet_on_mnist.sh ├── setup.py ├── slim_walkthrough.ipynb └── train_image_classifier.py ├── tensblur ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── smoother.cpython-36.pyc └── smoother.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Example user template template 3 | ### Example user template 4 | 5 | # IntelliJ project files 6 | .idea 7 | *.iml 8 | out 9 | gen 10 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(tfpose_ros) 3 | 4 | ## Add support for C++11, supported in ROS Kinetic and newer 5 | add_definitions(-std=c++11) 6 | 7 | find_package(catkin REQUIRED COMPONENTS 8 | roscpp 9 | rospy 10 | std_msgs 11 | message_generation 12 | ) 13 | 14 | # setup.py is called during catkin build 15 | catkin_python_setup() 16 | 17 | # Generate messages in the 'msg' folder 18 | add_message_files( 19 | FILES 20 | BodyPartElm.msg 21 | Person.msg 22 | Persons.msg 23 | ) 24 | 25 | generate_messages( 26 | DEPENDENCIES std_msgs 27 | ) 28 | 29 | catkin_package( 30 | CATKIN_DEPENDS rospy message_generation message_runtime 31 | ) 32 | 33 | install() 34 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.6-stretch 2 | MAINTAINER Tanmay Thakur 3 | 4 | # Install Build Utilities 5 | RUN apt-get update && \ 6 | apt-get install -y gcc make apt-transport-https ca-certificates build-essential 7 | 8 | # Check Python Environment 9 | RUN python --version 10 | RUN pip --version 11 | 12 | # set the 
working directory for containers 13 | WORKDIR /usr/src/pose 14 | 15 | # Installing Dependencies 16 | COPY requirements.txt . 17 | RUN pip install --no-cache-dir -r requirements.txt 18 | 19 | # Copy all the files from the project’s root to the working directory 20 | COPY . . 21 | 22 | # Running Python Application 23 | CMD ["python", "run_webcam.py","--model=mobilenet_thin","--resize=432x368","--camera=0"] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pose-estimation-detection 2 | 3 | Pose estimation and detection are implemented minimally on top of the OpenPose implementation https://github.com/ildoonet/tf-pose-estimation, using TensorFlow. For binary classification of poses into the classes sitting and standing, the final layer of MobileNet (a CNN originally trained on the ImageNet Large Scale Visual Recognition Challenge dataset) was retrained on a dataset of ~1500 pose images. 4 | 5 | The model estimates human poses and classifies the current pose with a fairly good degree of accuracy. 6 | 7 | ### Demo 8 | 9 | **One way to improve the model, alongside the deep-learning classifier, is to add heuristics based on the relative distances between skeletal keypoints (a minimal sketch of this idea is given after the usage section below).** 10 | 11 | **FPS and estimation/detection quality vary with CPU/GPU power.** 12 | 13 | ### Training Examples 14 | 15 | - For sitting pose 16 | ![alt text](/images/sitting.jpeg) 17 | 18 | - For standing pose 19 | ![alt text](/images/standing.jpeg) 20 | 21 | ### Dependencies 22 | 23 | The following are required: 24 | 25 | - python3 26 | - tensorflow 1.9.0 (works well even with the CPU version) 27 | - opencv3 28 | - slim 29 | - slidingwindow 30 | - https://github.com/adamrehn/slidingwindow 31 | 32 | ### Cloning & installing dependencies 33 | 34 | ```bash 35 | $ git clone https://github.com/SyBorg91/pose-estimation-detection 36 | $ cd pose-estimation-detection 37 | $ pip3 install -r requirements.txt 38 | ``` 39 | 40 | ### Pose Estimation with realtime webcam feed 41 | 42 | ``` 43 | $ python run_webcam.py --model=mobilenet_thin --resize=432x368 --camera=0 44 | ``` 45 | 46 | Run the above command to start pose estimation with the onboard webcam.
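### Heuristic pose classification (sketch)

A minimal sketch of the keypoint-distance heuristic mentioned in the Demo section, assuming the `Human` objects returned by `TfPoseEstimator.inference()` (whose `body_parts` dict is keyed by the `CocoPart` index and holds normalized `x`/`y` coordinates, as used in `run_webcam.py` and `scripts/visualization.py`). The `0.4` threshold is an arbitrary placeholder and would need tuning on real data:

```python
from tf_pose.common import CocoPart

def guess_pose(human):
    # Very rough sitting/standing guess from keypoint geometry (illustrative only).
    parts = human.body_parts  # dict: CocoPart index -> BodyPart with normalized x, y
    needed = [CocoPart.Neck.value, CocoPart.RHip.value, CocoPart.RKnee.value]
    if not all(k in parts for k in needed):
        return 'unknown'  # required keypoints were not detected in this frame
    torso = abs(parts[CocoPart.RHip.value].y - parts[CocoPart.Neck.value].y)
    thigh = abs(parts[CocoPart.RKnee.value].y - parts[CocoPart.RHip.value].y)
    # When sitting, the knee rises toward hip height, so the vertical hip-to-knee
    # extent shrinks relative to the neck-to-hip (torso) extent.
    return 'sitting' if thigh < 0.4 * torso else 'standing'  # 0.4 is a guessed threshold
```

Such a rule could be combined with, or used to sanity-check, the prediction returned by `scripts/label_image.classify` in `run_webcam.py`.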
47 | 48 | ## References 49 | 50 | ### OpenPose 51 | 52 | [1] https://github.com/CMU-Perceptual-Computing-Lab/openpose 53 | 54 | [2] Training Codes : https://github.com/ZheC/Realtime_Multi-Person_Pose_Estimation 55 | 56 | [3] Custom Caffe by Openpose : https://github.com/CMU-Perceptual-Computing-Lab/caffe_train 57 | 58 | [4] Keras Openpose : https://github.com/michalfaber/keras_Realtime_Multi-Person_Pose_Estimation 59 | 60 | [5] Keras Openpose2 : https://github.com/kevinlin311tw/keras-openpose-reproduce 61 | 62 | ### Lifting from the deep 63 | 64 | [1] Arxiv Paper : https://arxiv.org/abs/1701.00295 65 | 66 | [2] https://github.com/DenisTome/Lifting-from-the-Deep-release 67 | 68 | ### Mobilenet 69 | 70 | [1] Original Paper : https://arxiv.org/abs/1704.04861 71 | 72 | [2] Pretrained model (Pose estimation) : https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.md 73 | 74 | [3] Retrained model (Pose detection) : https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/ 75 | 76 | ### Libraries 77 | 78 | [1] Tensorpack : https://github.com/ppwwyyxx/tensorpack 79 | 80 | ### Tensorflow Tips 81 | 82 | [1] Freeze graph : https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py 83 | 84 | [2] Optimize graph : https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /images/sitting.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/images/sitting.jpeg -------------------------------------------------------------------------------- /images/standing.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/images/standing.jpeg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/models/__init__.py -------------------------------------------------------------------------------- /models/graph/cmu/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/models/graph/cmu/__init__.py -------------------------------------------------------------------------------- /models/graph/cmu/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "[download] model graph : cmu" 4 | DIR="$(cd "$(dirname "$0")" && pwd)" 5 | 6 | extract_download_url() { 7 | 8 | url=$( wget -q -O - $1 | grep -o 'http*://download[^"]*' | tail -n 1 ) 9 | echo "$url" 10 | 11 | } 12 | 13 | # if you need, uncomment this. 
14 | # wget -c --tries=2 $( extract_download_url http://www.mediafire.com/file/1pyjsjl0p93x27c/graph_freeze.pb ) -O $DIR/graph_freeze.pb 15 | #wget -c --tries=2 $( extract_download_url http://www.mediafire.com/file/i72ll9k5i7x6qfh/graph.pb ) -O $DIR/graph.pb 16 | wget -c --tries=2 $( extract_download_url http://www.mediafire.com/file/qlzzr20mpocnpa3/graph_opt.pb ) -O $DIR/graph_opt.pb 17 | echo "[download] end" 18 | -------------------------------------------------------------------------------- /models/graph/mobilenet_thin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/models/graph/mobilenet_thin/__init__.py -------------------------------------------------------------------------------- /models/graph/mobilenet_thin/graph_freeze.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/models/graph/mobilenet_thin/graph_freeze.pb -------------------------------------------------------------------------------- /models/graph/mobilenet_thin/graph_opt.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/models/graph/mobilenet_thin/graph_opt.pb -------------------------------------------------------------------------------- /models/graph/retrained/retrained_v1.0/retrained_graph.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/models/graph/retrained/retrained_v1.0/retrained_graph.pb -------------------------------------------------------------------------------- /models/graph/retrained/retrained_v1.0/retrained_labels.txt: -------------------------------------------------------------------------------- 1 | sitting 2 | standing 3 | -------------------------------------------------------------------------------- /models/graph/retrained/retrained_v2.0/retrained_graph.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/models/graph/retrained/retrained_v2.0/retrained_graph.pb -------------------------------------------------------------------------------- /models/graph/retrained/retrained_v2.0/retrained_labels.txt: -------------------------------------------------------------------------------- 1 | sitting 2 | standing 3 | -------------------------------------------------------------------------------- /models/numpy/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "download videos" 4 | 5 | extract_download_url() { 6 | 7 | url=$( wget -q -O - $1 | grep -o 'http*://download[^"]*' | tail -n 1 ) 8 | echo "$url" 9 | 10 | } 11 | 12 | wget --continue $( extract_download_url http://www.mediafire.com/file/ropayv77vklvf56/openpose_coco.npy ) -O openpose_coco.npy 13 | wget --continue $( extract_download_url http://www.mediafire.com/file/7e73ddj31rzw6qq/openpose_vgg16.npy ) -O openpose_vgg16.npy 14 | -------------------------------------------------------------------------------- 
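The model directories above and below each provide a `download.sh` that resolves a direct download URL and fetches the corresponding graph or checkpoint files; the cmu script writes next to itself, while the numpy and pretrained scripts write to the current directory, so run those from their own directory. Assuming a shell with `wget` available, usage might look like:

```bash
bash models/graph/cmu/download.sh           # writes graph_opt.pb next to the script
(cd models/numpy && bash download.sh)       # writes openpose_coco.npy and openpose_vgg16.npy here
(cd models/pretrained/resnet_v2_101 && bash download.sh)
```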
/models/pretrained/mobilenet_v1_0.50_224_2017_06_14/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | extract_download_url() { 4 | 5 | url=$( wget -q -O - $1 | grep -o 'http*://download[^"]*' | tail -n 1 ) 6 | echo "$url" 7 | 8 | } 9 | 10 | wget --continue $( extract_download_url http://www.mediafire.com/file/meu73iq8rxlsd3g/mobilenet_v1_0.50_224.ckpt.data-00000-of-00001 ) -O mobilenet_v1_0.50_224.ckpt.data-00000-of-00001 11 | wget --continue $( extract_download_url http://www.mediafire.com/file/7u6iupfkcaxk5hx/mobilenet_v1_0.50_224.ckpt.index ) -O mobilenet_v1_0.50_224.ckpt.index 12 | wget --continue $( extract_download_url http://www.mediafire.com/file/zp8y4d0ytzharzz/mobilenet_v1_0.50_224.ckpt.meta ) -O mobilenet_v1_0.50_224.ckpt.meta 13 | -------------------------------------------------------------------------------- /models/pretrained/mobilenet_v1_0.75_224_2017_06_14/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | extract_download_url() { 4 | 5 | url=$( wget -q -O - $1 | grep -o 'http*://download[^"]*' | tail -n 1 ) 6 | echo "$url" 7 | 8 | } 9 | 10 | wget --continue $( extract_download_url http://www.mediafire.com/file/kibz0x9e7h11ueb/mobilenet_v1_0.75_224.ckpt.data-00000-of-00001 ) -O mobilenet_v1_0.75_224.ckpt.data-00000-of-00001 11 | wget --continue $( extract_download_url http://www.mediafire.com/file/t8909eaikvc6ea2/mobilenet_v1_0.75_224.ckpt.index ) -O mobilenet_v1_0.75_224.ckpt.index 12 | wget --continue $( extract_download_url http://www.mediafire.com/file/6jjnbn1aged614x/mobilenet_v1_0.75_224.ckpt.meta ) -O mobilenet_v1_0.75_224.ckpt.meta 13 | -------------------------------------------------------------------------------- /models/pretrained/mobilenet_v1_1.0_224_2017_06_14/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | extract_download_url() { 4 | 5 | url=$( wget -q -O - $1 | grep -o 'http*://download[^"]*' | tail -n 1 ) 6 | echo "$url" 7 | 8 | } 9 | 10 | wget --continue $( extract_download_url http://www.mediafire.com/file/oh6njnz9lgoqwdj/mobilenet_v1_1.0_224.ckpt.data-00000-of-00001 ) -O mobilenet_v1_1.0_224.ckpt.data-00000-of-00001 11 | wget --continue $( extract_download_url http://www.mediafire.com/file/61qln0tbac4ny9o/mobilenet_v1_1.0_224.ckpt.meta ) -O mobilenet_v1_1.0_224.ckpt.meta 12 | wget --continue $( extract_download_url http://www.mediafire.com/file/2111rh6tb5fl1lr/mobilenet_v1_1.0_224.ckpt.index ) -O mobilenet_v1_1.0_224.ckpt.index 13 | -------------------------------------------------------------------------------- /models/pretrained/resnet_v2_101/download.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | wget http://download.tensorflow.org/models/resnet_v2_101_2017_04_14.tar.gz 4 | tar -xvf resnet_v2_101_2017_04_14.tar.gz 5 | rm resnet_v2_101_2017_04_14.tar.gz 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argparse 2 | dill 3 | fire 4 | matplotlib 5 | psutil 6 | requests 7 | scikit-image 8 | scipy 9 | slidingwindow 10 | tqdm 11 | git+https://github.com/ppwwyyxx/tensorpack.git 12 | -------------------------------------------------------------------------------- /requirements_alt.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.36.1 2 | requests==2.22.0 3 | tensorpack==0.9.8 4 | setuptools==41.2.0 5 | six==1.12.0 6 | opencv_python==4.0.0.21 7 | tensorflow_gpu==1.12.0 8 | matplotlib==3.1.1 9 | psutil==5.6.6 10 | scipy==1.3.1 11 | ipython==7.8.0 12 | Pillow==6.2.1 13 | protobuf==3.10.0 14 | rospkg==1.1.10 15 | tensorflow==2.0.0 16 | -------------------------------------------------------------------------------- /run_webcam.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import time 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | from tf_pose.estimator import TfPoseEstimator 9 | from tf_pose.networks import get_graph_path, model_wh 10 | import scripts.label_image as label_img 11 | 12 | logger = logging.getLogger('TfPoseEstimator-WebCam') 13 | logger.setLevel(logging.DEBUG) 14 | ch = logging.StreamHandler() 15 | ch.setLevel(logging.DEBUG) 16 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s') 17 | ch.setFormatter(formatter) 18 | logger.addHandler(ch) 19 | 20 | fps_time = 0 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser(description='tf-pose-estimation realtime webcam') 24 | parser.add_argument('--camera', type=int, default=0) 25 | 26 | parser.add_argument('--resize', type=str, default='0x0', 27 | help='if provided, resize images before they are processed. default=0x0, Recommends : 432x368 or 656x368 or 1312x736 ') 28 | parser.add_argument('--resize-out-ratio', type=float, default=4.0, 29 | help='if provided, resize heatmaps before they are post-processed. 
default=1.0') 30 | 31 | parser.add_argument('--model', type=str, default='mobilenet_thin', help='cmu / mobilenet_thin') 32 | parser.add_argument('--show-process', type=bool, default=False, 33 | help='for debug purpose, if enabled, speed for inference is dropped.') 34 | args = parser.parse_args() 35 | 36 | logger.debug('initialization %s : %s' % (args.model, get_graph_path(args.model))) 37 | w, h = model_wh(args.resize) 38 | if w > 0 and h > 0: 39 | e = TfPoseEstimator(get_graph_path(args.model), target_size=(w, h)) 40 | else: 41 | e = TfPoseEstimator(get_graph_path(args.model), target_size=(432, 368)) 42 | logger.debug('cam read+') 43 | cam = cv2.VideoCapture(args.camera) 44 | ret_val, image = cam.read() 45 | logger.info('cam image=%dx%d' % (image.shape[1], image.shape[0])) 46 | 47 | # count = 0 48 | while True: 49 | 50 | logger.debug('+image processing+') 51 | ret_val, image = cam.read() 52 | 53 | logger.debug('+postprocessing+') 54 | humans = e.inference(image, resize_to_default=(w > 0 and h > 0), upsample_size=args.resize_out_ratio) 55 | img = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) 56 | 57 | logger.debug('+classification+') 58 | # Getting only the skeletal structure (with white background) of the actual image 59 | image = np.zeros(image.shape,dtype=np.uint8) 60 | image.fill(255) 61 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) 62 | 63 | # Classification 64 | pose_class = label_img.classify(image) 65 | 66 | logger.debug('+displaying+') 67 | cv2.putText(img, 68 | "Current predicted pose is : %s" %(pose_class), 69 | (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 70 | (0, 255, 0), 2) 71 | 72 | cv2.imshow('tf-pose-estimation result', img) 73 | 74 | fps_time = time.time() 75 | if cv2.waitKey(1) == 27: 76 | break 77 | logger.debug('+finished+') 78 | 79 | # For gathering training data 80 | # title = 'img'+str(count)+'.jpeg' 81 | # path = 82 | # cv2.imwrite(os.path.join(path , title), image) 83 | # count += 1 84 | 85 | cv2.destroyAllWindows() 86 | 87 | # ============================================================================= 88 | # For running the script simply run the following in the cmd prompt/terminal : 89 | # python run_webcam.py --model=mobilenet_thin --resize=432x368 --camera=0 90 | # ============================================================================= 91 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2017 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | -------------------------------------------------------------------------------- /scripts/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/scripts/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/label_image.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/scripts/__pycache__/label_image.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/broadcaster_ros.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import time 3 | import os 4 | import sys 5 | import ast 6 | 7 | from threading import Lock 8 | import rospy 9 | import rospkg 10 | from cv_bridge import CvBridge, CvBridgeError 11 | from std_msgs.msg import String 12 | from sensor_msgs.msg import Image 13 | from tfpose_ros.msg import Persons, Person, BodyPartElm 14 | 15 | from tf_pose.estimator import TfPoseEstimator 16 | from tf_pose.networks import model_wh, get_graph_path 17 | 18 | 19 | def humans_to_msg(humans): 20 | persons = Persons() 21 | 22 | for human in humans: 23 | person = Person() 24 | 25 | for k in human.body_parts: 26 | body_part = human.body_parts[k] 27 | 28 | body_part_msg = BodyPartElm() 29 | body_part_msg.part_id = body_part.part_idx 30 | body_part_msg.x = body_part.x 31 | body_part_msg.y = body_part.y 32 | body_part_msg.confidence = body_part.score 33 | person.body_part.append(body_part_msg) 34 | persons.persons.append(person) 35 | 36 | return persons 37 | 38 | 39 | def callback_image(data): 40 | # et = time.time() 41 | try: 42 | cv_image = cv_bridge.imgmsg_to_cv2(data, "bgr8") 43 | except CvBridgeError as e: 44 | rospy.logerr('[tf-pose-estimation] Converting Image Error. 
' + str(e)) 45 | return 46 | 47 | acquired = tf_lock.acquire(False) 48 | if not acquired: 49 | return 50 | 51 | try: 52 | humans = pose_estimator.inference(cv_image, resize_to_default=True, upsample_size=resize_out_ratio) 53 | finally: 54 | tf_lock.release() 55 | 56 | msg = humans_to_msg(humans) 57 | msg.image_w = data.width 58 | msg.image_h = data.height 59 | msg.header = data.header 60 | 61 | pub_pose.publish(msg) 62 | 63 | 64 | if __name__ == '__main__': 65 | rospy.loginfo('initialization+') 66 | rospy.init_node('TfPoseEstimatorROS', anonymous=True, log_level=rospy.INFO) 67 | 68 | # parameters 69 | image_topic = rospy.get_param('~camera', '') 70 | model = rospy.get_param('~model', 'cmu') 71 | 72 | resolution = rospy.get_param('~resolution', '432x368') 73 | resize_out_ratio = float(rospy.get_param('~resize_out_ratio', '4.0')) 74 | tf_lock = Lock() 75 | 76 | if not image_topic: 77 | rospy.logerr('Parameter \'camera\' is not provided.') 78 | sys.exit(-1) 79 | 80 | try: 81 | w, h = model_wh(resolution) 82 | graph_path = get_graph_path(model) 83 | 84 | rospack = rospkg.RosPack() 85 | graph_path = os.path.join(rospack.get_path('tfpose_ros'), graph_path) 86 | except Exception as e: 87 | rospy.logerr('invalid model: %s, e=%s' % (model, e)) 88 | sys.exit(-1) 89 | 90 | pose_estimator = TfPoseEstimator(graph_path, target_size=(w, h)) 91 | cv_bridge = CvBridge() 92 | 93 | rospy.Subscriber(image_topic, Image, callback_image, queue_size=1, buff_size=2**24) 94 | pub_pose = rospy.Publisher('~pose', Persons, queue_size=1) 95 | 96 | rospy.loginfo('start+') 97 | rospy.spin() 98 | rospy.loginfo('finished') 99 | -------------------------------------------------------------------------------- /scripts/count_ops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2017 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | 22 | import sys 23 | import tensorflow as tf 24 | 25 | def load_graph(file_name): 26 | with open(file_name,'rb') as f: 27 | content = f.read() 28 | graph_def = tf.GraphDef() 29 | graph_def.ParseFromString(content) 30 | with tf.Graph().as_default() as graph: 31 | tf.import_graph_def(graph_def, name='') 32 | return graph 33 | 34 | def count_ops(file_name, op_name = None): 35 | graph = load_graph(file_name) 36 | 37 | if op_name is None: 38 | return len(graph.get_operations()) 39 | else: 40 | return sum(1 for op in graph.get_operations() 41 | if op.name == op_name) 42 | 43 | if __name__ == "__main__": 44 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 45 | print(count_ops(*sys.argv[1:])) 46 | 47 | -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2017 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | 22 | import sys 23 | import argparse 24 | 25 | import numpy as np 26 | import PIL.Image as Image 27 | import tensorflow as tf 28 | 29 | import scripts.retrain as retrain 30 | from scripts.count_ops import load_graph 31 | 32 | def evaluate_graph(graph_file_name): 33 | with load_graph(graph_file_name).as_default() as graph: 34 | ground_truth_input = tf.placeholder( 35 | tf.float32, [None, 5], name='GroundTruthInput') 36 | 37 | image_buffer_input = graph.get_tensor_by_name('input:0') 38 | final_tensor = graph.get_tensor_by_name('final_result:0') 39 | accuracy, _ = retrain.add_evaluation_step(final_tensor, ground_truth_input) 40 | 41 | logits = graph.get_tensor_by_name("final_training_ops/Wx_plus_b/add:0") 42 | xent = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( 43 | labels = ground_truth_input, 44 | logits = logits)) 45 | 46 | image_dir = 'tf_files/flower_photos' 47 | testing_percentage = 10 48 | validation_percentage = 10 49 | validation_batch_size = 100 50 | category='testing' 51 | 52 | image_lists = retrain.create_image_lists( 53 | image_dir, testing_percentage, 54 | validation_percentage) 55 | class_count = len(image_lists.keys()) 56 | 57 | ground_truths = [] 58 | filenames = [] 59 | 60 | for label_index, label_name in enumerate(image_lists.keys()): 61 | for image_index, image_name in enumerate(image_lists[label_name][category]): 62 | image_name = retrain.get_image_path( 63 | image_lists, label_name, image_index, image_dir, category) 64 | ground_truth = np.zeros([1, class_count], dtype=np.float32) 65 | ground_truth[0, label_index] = 1.0 66 | ground_truths.append(ground_truth) 67 | filenames.append(image_name) 68 | 69 | accuracies = [] 70 | xents = [] 71 | with 
tf.Session(graph=graph) as sess: 72 | for filename, ground_truth in zip(filenames, ground_truths): 73 | image = Image.open(filename).resize((224,224),Image.ANTIALIAS) 74 | image = np.array(image, dtype=np.float32)[None,...] 75 | image = (image-128)/128.0 76 | 77 | feed_dict={ 78 | image_buffer_input: image, 79 | ground_truth_input: ground_truth} 80 | 81 | eval_accuracy, eval_xent = sess.run([accuracy, xent], feed_dict) 82 | 83 | accuracies.append(eval_accuracy) 84 | xents.append(eval_xent) 85 | 86 | 87 | return np.mean(accuracies), np.mean(xents) 88 | 89 | if __name__ == "__main__": 90 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 91 | accuracy,xent = evaluate_graph(*sys.argv[1:]) 92 | print('Accuracy: %g' % accuracy) 93 | print('Cross Entropy: %g' % xent) 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /scripts/graph_pb2tb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2017 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import os 18 | import sys 19 | 20 | import tensorflow as tf 21 | 22 | def load_graph(graph_pb_path): 23 | with open(graph_pb_path,'rb') as f: 24 | content = f.read() 25 | graph_def = tf.GraphDef() 26 | graph_def.ParseFromString(content) 27 | with tf.Graph().as_default() as graph: 28 | tf.import_graph_def(graph_def, name='') 29 | return graph 30 | 31 | 32 | def graph_to_tensorboard(graph, out_dir): 33 | with tf.Session(): 34 | train_writer = tf.summary.FileWriter(out_dir) 35 | train_writer.add_graph(graph) 36 | 37 | 38 | def main(out_dir, graph_pb_path): 39 | graph = load_graph(graph_pb_path) 40 | graph_to_tensorboard(graph, out_dir) 41 | 42 | if __name__ == "__main__": 43 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 44 | main(*sys.argv[1:]) 45 | -------------------------------------------------------------------------------- /scripts/label_image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | import time 22 | 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | def load_graph(model_file): 27 | graph = tf.Graph() 28 | graph_def = tf.GraphDef() 29 | 30 | with open(model_file, "rb") as f: 31 | graph_def.ParseFromString(f.read()) 32 | with graph.as_default(): 33 | tf.import_graph_def(graph_def) 34 | 35 | return graph 36 | 37 | def read_tensor_from_image_file(image_file, input_height=299, input_width=299, 38 | input_mean=0, input_std=255): 39 | 40 | float_caster = tf.cast(image_file, tf.float32) 41 | dims_expander = tf.expand_dims(float_caster, 0); 42 | resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width]) 43 | normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std]) 44 | sess = tf.Session() 45 | result = sess.run(normalized) 46 | 47 | return result 48 | 49 | def load_labels(label_file): 50 | label = [] 51 | proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines() 52 | for l in proto_as_ascii_lines: 53 | label.append(l.rstrip()) 54 | return label 55 | 56 | def classify(image_file): 57 | 58 | # ============================================================================= 59 | # Note : Provide your own absolute file path for the following 60 | # You can choose the retrained graph of either as v1.0 or v2.0 61 | # Both models are retrained inception models (on my procured dataset) 62 | # v1.0 was trained for 500 epocs on a preliminary dataset of poses. 63 | # v2.0 was trained for 4000 epocs on a dataset containing the previous dataset 64 | # and more. 65 | # ============================================================================= 66 | # Change the path to your convenience 67 | file_path = os.path.abspath(os.path.dirname(__file__)) 68 | path = os.path.join(file_path, '../models/graph/retrained/retrained_v1.0/') 69 | model_file = path+'retrained_graph.pb' 70 | label_file = path+'retrained_labels.txt' 71 | input_height = 224 72 | input_width = 224 73 | input_mean = 128 74 | input_std = 128 75 | input_layer = "input" 76 | output_layer = "final_result" 77 | 78 | graph = load_graph(model_file) 79 | t = read_tensor_from_image_file(image_file, 80 | input_height=input_height, 81 | input_width=input_width, 82 | input_mean=input_mean, 83 | input_std=input_std) 84 | 85 | input_name = "import/" + input_layer 86 | output_name = "import/" + output_layer 87 | input_operation = graph.get_operation_by_name(input_name); 88 | output_operation = graph.get_operation_by_name(output_name); 89 | 90 | with tf.Session(graph=graph) as sess: 91 | start = time.time() 92 | results = sess.run(output_operation.outputs[0], 93 | {input_operation.outputs[0]: t}) 94 | end=time.time() 95 | results = np.squeeze(results) 96 | 97 | labels = load_labels(label_file) 98 | 99 | print('\nEvaluation time (1-image): {:.3f}s\n'.format(end-start)) 100 | template = "{} (score={:0.5f})" 101 | label = '' 102 | if results[0] > results[1]: 103 | label = labels[0] 104 | result = results[0] 105 | else: 106 | label = labels[1] 107 | result = results[1] 108 | 109 | return template.format(label, result) 110 | -------------------------------------------------------------------------------- /scripts/show_image.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2017 Google Inc. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | import os 20 | 21 | from IPython.display import Image, HTML, display 22 | 23 | 24 | root = "tf_files/flower_photos/" 25 | with open(root+"/LICENSE.txt") as f: 26 | attributions = f.readlines()[4:] 27 | attributions = [line.split(' CC-BY') for line in attributions] 28 | attributions = dict(attributions) 29 | 30 | def show_image(image_path): 31 | display(Image(image_path)) 32 | 33 | image_rel = image_path.replace(root,'') 34 | caption = "Image " + ' - '.join(attributions[image_rel].split(' - ')[:-1]) 35 | display(HTML("
<div>%s</div>
" % caption)) 36 | -------------------------------------------------------------------------------- /scripts/visualization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import time 3 | import cv2 4 | import rospy 5 | from sensor_msgs.msg import Image 6 | from cv_bridge import CvBridge, CvBridgeError 7 | 8 | from tfpose_ros.msg import Persons, Person, BodyPartElm 9 | from tf_pose.estimator import Human, BodyPart, TfPoseEstimator 10 | 11 | 12 | class VideoFrames: 13 | """ 14 | Reference : ros-video-recorder 15 | https://github.com/ildoonet/ros-video-recorder/blob/master/scripts/recorder.py 16 | """ 17 | def __init__(self, image_topic): 18 | self.image_sub = rospy.Subscriber(image_topic, Image, self.callback_image, queue_size=1) 19 | self.bridge = CvBridge() 20 | self.frames = [] 21 | 22 | def callback_image(self, data): 23 | try: 24 | cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8") 25 | except CvBridgeError as e: 26 | rospy.logerr('Converting Image Error. ' + str(e)) 27 | return 28 | 29 | self.frames.append((data.header.stamp, cv_image)) 30 | 31 | def get_latest(self, at_time, remove_older=True): 32 | fs = [x for x in self.frames if x[0] <= at_time] 33 | if len(fs) == 0: 34 | return None 35 | 36 | f = fs[-1] 37 | if remove_older: 38 | self.frames = self.frames[len(fs) - 1:] 39 | 40 | return f[1] 41 | 42 | 43 | def cb_pose(data): 44 | # get image with pose time 45 | t = data.header.stamp 46 | image = vf.get_latest(t, remove_older=True) 47 | if image is None: 48 | rospy.logwarn('No received images.') 49 | return 50 | 51 | h, w = image.shape[:2] 52 | if resize_ratio > 0: 53 | image = cv2.resize(image, (int(resize_ratio*w), int(resize_ratio*h)), interpolation=cv2.INTER_LINEAR) 54 | 55 | # ros topic to Person instance 56 | humans = [] 57 | for p_idx, person in enumerate(data.persons): 58 | human = Human([]) 59 | for body_part in person.body_part: 60 | part = BodyPart('', body_part.part_id, body_part.x, body_part.y, body_part.confidence) 61 | human.body_parts[body_part.part_id] = part 62 | 63 | humans.append(human) 64 | 65 | # draw 66 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) 67 | pub_img.publish(cv_bridge.cv2_to_imgmsg(image, 'bgr8')) 68 | 69 | 70 | if __name__ == '__main__': 71 | rospy.loginfo('initialization+') 72 | rospy.init_node('TfPoseEstimatorROS-Visualization', anonymous=True) 73 | 74 | # topics params 75 | image_topic = rospy.get_param('~camera', '') 76 | pose_topic = rospy.get_param('~pose', '/pose_estimator/pose') 77 | 78 | resize_ratio = float(rospy.get_param('~resize_ratio', '-1')) 79 | 80 | # publishers 81 | pub_img = rospy.Publisher('~output', Image, queue_size=1) 82 | 83 | # initialization 84 | cv_bridge = CvBridge() 85 | vf = VideoFrames(image_topic) 86 | rospy.wait_for_message(image_topic, Image, timeout=30) 87 | 88 | # subscribers 89 | rospy.Subscriber(pose_topic, Persons, cb_pose, queue_size=1) 90 | 91 | # run 92 | rospy.spin() 93 | -------------------------------------------------------------------------------- /tf_pose/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tf_pose.runner import infer, Estimator, get_estimator 6 | -------------------------------------------------------------------------------- /tf_pose/__pycache__/__init__.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/__pycache__/estimator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/__pycache__/estimator.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/__pycache__/network_base.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/__pycache__/network_base.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/__pycache__/network_cmu.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/__pycache__/network_cmu.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/__pycache__/network_mobilenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/__pycache__/network_mobilenet.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/__pycache__/network_mobilenet_thin.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/__pycache__/network_mobilenet_thin.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/__pycache__/network_personlab.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/__pycache__/network_personlab.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/__pycache__/networks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/__pycache__/networks.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/__pycache__/runner.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/__pycache__/runner.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/common.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import tensorflow as tf 4 | import cv2 5 | 6 | 7 | regularizer_conv = 0.004 8 | regularizer_dsconv = 0.0004 9 | batchnorm_fused = True 10 | activation_fn = tf.nn.relu 11 | 12 | 13 | class CocoPart(Enum): 14 | Nose = 0 15 | Neck = 1 16 | RShoulder = 2 17 | RElbow = 3 18 | RWrist = 4 19 | LShoulder = 5 20 | LElbow = 6 21 | LWrist = 7 22 | RHip = 8 23 | RKnee = 9 24 | RAnkle = 10 25 | LHip = 11 26 | LKnee = 12 27 | LAnkle = 13 28 | REye = 14 29 | LEye = 15 30 | REar = 16 31 | LEar = 17 32 | Background = 18 33 | 34 | 35 | class MPIIPart(Enum): 36 | RAnkle = 0 37 | RKnee = 1 38 | RHip = 2 39 | LHip = 3 40 | LKnee = 4 41 | LAnkle = 5 42 | RWrist = 6 43 | RElbow = 7 44 | RShoulder = 8 45 | LShoulder = 9 46 | LElbow = 10 47 | LWrist = 11 48 | Neck = 12 49 | Head = 13 50 | 51 | @staticmethod 52 | def from_coco(human): 53 | # t = { 54 | # MPIIPart.RAnkle: CocoPart.RAnkle, 55 | # MPIIPart.RKnee: CocoPart.RKnee, 56 | # MPIIPart.RHip: CocoPart.RHip, 57 | # MPIIPart.LHip: CocoPart.LHip, 58 | # MPIIPart.LKnee: CocoPart.LKnee, 59 | # MPIIPart.LAnkle: CocoPart.LAnkle, 60 | # MPIIPart.RWrist: CocoPart.RWrist, 61 | # MPIIPart.RElbow: CocoPart.RElbow, 62 | # MPIIPart.RShoulder: CocoPart.RShoulder, 63 | # MPIIPart.LShoulder: CocoPart.LShoulder, 64 | # MPIIPart.LElbow: CocoPart.LElbow, 65 | # MPIIPart.LWrist: CocoPart.LWrist, 66 | # MPIIPart.Neck: CocoPart.Neck, 67 | # MPIIPart.Nose: CocoPart.Nose, 68 | # } 69 | 70 | t = [ 71 | (MPIIPart.Head, CocoPart.Nose), 72 | (MPIIPart.Neck, CocoPart.Neck), 73 | (MPIIPart.RShoulder, CocoPart.RShoulder), 74 | (MPIIPart.RElbow, CocoPart.RElbow), 75 | (MPIIPart.RWrist, CocoPart.RWrist), 76 | (MPIIPart.LShoulder, CocoPart.LShoulder), 77 | (MPIIPart.LElbow, CocoPart.LElbow), 78 | (MPIIPart.LWrist, CocoPart.LWrist), 79 | (MPIIPart.RHip, CocoPart.RHip), 80 | (MPIIPart.RKnee, CocoPart.RKnee), 81 | (MPIIPart.RAnkle, CocoPart.RAnkle), 82 | (MPIIPart.LHip, CocoPart.LHip), 83 | (MPIIPart.LKnee, CocoPart.LKnee), 84 | (MPIIPart.LAnkle, CocoPart.LAnkle), 85 | ] 86 | 87 | pose_2d_mpii = [] 88 | visibilty = [] 89 | for mpi, coco in t: 90 | if coco.value not in human.body_parts.keys(): 91 | pose_2d_mpii.append((0, 0)) 92 | visibilty.append(False) 93 | continue 94 | pose_2d_mpii.append((human.body_parts[coco.value].x, human.body_parts[coco.value].y)) 95 | visibilty.append(True) 96 | return pose_2d_mpii, visibilty 97 | 98 | CocoPairs = [ 99 | (1, 2), (1, 5), (2, 3), (3, 4), (5, 6), (6, 7), (1, 8), (8, 9), (9, 10), (1, 11), 100 | (11, 12), (12, 13), (1, 0), (0, 14), (14, 16), (0, 15), (15, 17), (2, 16), (5, 17) 101 | ] # = 19 102 | CocoPairsRender = CocoPairs[:-2] 103 | # CocoPairsNetwork = [ 104 | # (12, 13), (20, 21), (14, 15), (16, 17), (22, 23), (24, 25), (0, 1), (2, 3), (4, 5), 105 | # (6, 7), (8, 9), (10, 11), (28, 29), (30, 31), (34, 35), (32, 33), (36, 37), (18, 19), (26, 27) 106 | # ] # = 19 107 | 108 | CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 109 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 110 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 111 | 112 | 113 | def read_imgfile(path, 
width=None, height=None): 114 | val_image = cv2.imread(path, cv2.IMREAD_COLOR) 115 | if width is not None and height is not None: 116 | val_image = cv2.resize(val_image, (width, height)) 117 | return val_image 118 | 119 | 120 | def get_sample_images(w, h): 121 | val_image = [ 122 | read_imgfile('./images/p1.jpg', w, h), 123 | read_imgfile('./images/p2.jpg', w, h), 124 | read_imgfile('./images/p3.jpg', w, h), 125 | read_imgfile('./images/golf.jpg', w, h), 126 | read_imgfile('./images/hand1.jpg', w, h), 127 | read_imgfile('./images/hand2.jpg', w, h), 128 | read_imgfile('./images/apink1_crop.jpg', w, h), 129 | read_imgfile('./images/ski.jpg', w, h), 130 | read_imgfile('./images/apink2.jpg', w, h), 131 | read_imgfile('./images/apink3.jpg', w, h), 132 | read_imgfile('./images/handsup1.jpg', w, h), 133 | read_imgfile('./images/p3_dance.png', w, h), 134 | ] 135 | return val_image 136 | -------------------------------------------------------------------------------- /tf_pose/datum_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: datum.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='datum.proto', 20 | package='', 21 | serialized_pb=_b('\n\x0b\x64\x61tum.proto\"\x81\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _DATUM = _descriptor.Descriptor( 29 | name='Datum', 30 | full_name='Datum', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='channels', full_name='Datum.channels', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='Datum.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='width', full_name='Datum.width', index=2, 51 | number=3, type=5, cpp_type=1, label=1, 52 | has_default_value=False, default_value=0, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='data', full_name='Datum.data', index=3, 58 | number=4, type=12, cpp_type=9, label=1, 59 | has_default_value=False, default_value=_b(""), 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, 
extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='label', full_name='Datum.label', index=4, 65 | number=5, type=5, cpp_type=1, label=1, 66 | has_default_value=False, default_value=0, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | _descriptor.FieldDescriptor( 71 | name='float_data', full_name='Datum.float_data', index=5, 72 | number=6, type=2, cpp_type=6, label=3, 73 | has_default_value=False, default_value=[], 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None), 77 | _descriptor.FieldDescriptor( 78 | name='encoded', full_name='Datum.encoded', index=6, 79 | number=7, type=8, cpp_type=7, label=1, 80 | has_default_value=True, default_value=False, 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None), 84 | ], 85 | extensions=[ 86 | ], 87 | nested_types=[], 88 | enum_types=[ 89 | ], 90 | options=None, 91 | is_extendable=False, 92 | extension_ranges=[], 93 | oneofs=[ 94 | ], 95 | serialized_start=16, 96 | serialized_end=145, 97 | ) 98 | 99 | DESCRIPTOR.message_types_by_name['Datum'] = _DATUM 100 | 101 | Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict( 102 | DESCRIPTOR = _DATUM, 103 | __module__ = 'datum_pb2' 104 | # @@protoc_insertion_point(class_scope:Datum) 105 | )) 106 | _sym_db.RegisterMessage(Datum) 107 | 108 | 109 | # @@protoc_insertion_point(module_scope) 110 | -------------------------------------------------------------------------------- /tf_pose/network_mobilenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import tensorflow as tf 4 | 5 | from tf_pose import network_base 6 | 7 | 8 | class MobilenetNetwork(network_base.BaseNetwork): 9 | def __init__(self, inputs, trainable=True, conv_width=1.0, conv_width2=None): 10 | self.conv_width = conv_width 11 | self.conv_width2 = conv_width2 if conv_width2 else conv_width 12 | self.num_refine = 4 13 | network_base.BaseNetwork.__init__(self, inputs, trainable) 14 | 15 | def setup(self): 16 | min_depth = 8 17 | depth = lambda d: max(int(d * self.conv_width), min_depth) 18 | depth2 = lambda d: max(int(d * self.conv_width2), min_depth) 19 | 20 | with tf.variable_scope(None, 'MobilenetV1'): 21 | (self.feed('image') 22 | .convb(3, 3, depth(32), 2, name='Conv2d_0') 23 | .separable_conv(3, 3, depth(64), 1, name='Conv2d_1') 24 | .separable_conv(3, 3, depth(128), 2, name='Conv2d_2') 25 | .separable_conv(3, 3, depth(128), 1, name='Conv2d_3') 26 | .separable_conv(3, 3, depth(256), 2, name='Conv2d_4') 27 | .separable_conv(3, 3, depth(256), 1, name='Conv2d_5') 28 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_6') 29 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_7') 30 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_8') 31 | # .separable_conv(3, 3, depth(512), 1, name='Conv2d_9') 32 | # .separable_conv(3, 3, depth(512), 1, name='Conv2d_10') 33 | # .separable_conv(3, 3, depth(512), 1, name='Conv2d_11') 34 | # .separable_conv(3, 3, depth(1024), 2, name='Conv2d_12') 35 | # .separable_conv(3, 3, depth(1024), 1, name='Conv2d_13') 36 | ) 37 | 38 | (self.feed('Conv2d_1').max_pool(2, 2, 2, 2, name='Conv2d_1_pool')) 39 | (self.feed('Conv2d_7').upsample(2, name='Conv2d_7_upsample')) 40 | 41 | (self.feed('Conv2d_1_pool', 'Conv2d_3', 'Conv2d_7_upsample') 42 | 
.concat(3, name='feat_concat')) 43 | 44 | feature_lv = 'feat_concat' 45 | with tf.variable_scope(None, 'Openpose'): 46 | prefix = 'MConv_Stage1' 47 | (self.feed(feature_lv) 48 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_1') 49 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_2') 50 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_3') 51 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L1_4') 52 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5')) 53 | 54 | (self.feed(feature_lv) 55 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_1') 56 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_2') 57 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_3') 58 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L2_4') 59 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5')) 60 | 61 | for stage_id in range(self.num_refine): 62 | prefix_prev = 'MConv_Stage%d' % (stage_id + 1) 63 | prefix = 'MConv_Stage%d' % (stage_id + 2) 64 | (self.feed(prefix_prev + '_L1_5', 65 | prefix_prev + '_L2_5', 66 | feature_lv) 67 | .concat(3, name=prefix + '_concat') 68 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L1_1') 69 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L1_2') 70 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L1_3') 71 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L1_4') 72 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5')) 73 | 74 | (self.feed(prefix + '_concat') 75 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L2_1') 76 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L2_2') 77 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L2_3') 78 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L2_4') 79 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5')) 80 | 81 | # final result 82 | (self.feed('MConv_Stage%d_L2_5' % self.get_refine_num(), 83 | 'MConv_Stage%d_L1_5' % self.get_refine_num()) 84 | .concat(3, name='concat_stage7')) 85 | 86 | def loss_l1_l2(self): 87 | l1s = [] 88 | l2s = [] 89 | for layer_name in sorted(self.layers.keys()): 90 | if '_L1_5' in layer_name: 91 | l1s.append(self.layers[layer_name]) 92 | if '_L2_5' in layer_name: 93 | l2s.append(self.layers[layer_name]) 94 | 95 | return l1s, l2s 96 | 97 | def loss_last(self): 98 | return self.get_output('MConv_Stage%d_L1_5' % self.get_refine_num()), \ 99 | self.get_output('MConv_Stage%d_L2_5' % self.get_refine_num()) 100 | 101 | def restorable_variables(self): 102 | vs = {v.op.name: v for v in tf.global_variables() if 103 | 'MobilenetV1/Conv2d' in v.op.name and 104 | 'RMSProp' not in v.op.name and 'Momentum' not in v.op.name and 'Ada' not in v.op.name 105 | } 106 | return vs 107 | 108 | def get_refine_num(self): 109 | return self.num_refine + 1 110 | -------------------------------------------------------------------------------- /tf_pose/network_mobilenet_thin.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import tensorflow as tf 4 | 5 | from tf_pose import network_base 6 | 7 | 8 | class MobilenetNetworkThin(network_base.BaseNetwork): 9 | def __init__(self, inputs, trainable=True, conv_width=1.0, conv_width2=None): 10 | self.conv_width = conv_width 11 | self.conv_width2 = conv_width2 if conv_width2 else conv_width 12 | network_base.BaseNetwork.__init__(self, inputs, trainable) 13 | 14 | def setup(self): 15 | min_depth = 8 16 | depth = lambda d: max(int(d * 
self.conv_width), min_depth) 17 | depth2 = lambda d: max(int(d * self.conv_width2), min_depth) 18 | 19 | with tf.variable_scope(None, 'MobilenetV1'): 20 | (self.feed('image') 21 | .convb(3, 3, depth(32), 2, name='Conv2d_0') 22 | .separable_conv(3, 3, depth(64), 1, name='Conv2d_1') 23 | .separable_conv(3, 3, depth(128), 2, name='Conv2d_2') 24 | .separable_conv(3, 3, depth(128), 1, name='Conv2d_3') 25 | .separable_conv(3, 3, depth(256), 2, name='Conv2d_4') 26 | .separable_conv(3, 3, depth(256), 1, name='Conv2d_5') 27 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_6') 28 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_7') 29 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_8') 30 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_9') 31 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_10') 32 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_11') 33 | # .separable_conv(3, 3, depth(1024), 2, name='Conv2d_12') 34 | # .separable_conv(3, 3, depth(1024), 1, name='Conv2d_13') 35 | ) 36 | 37 | (self.feed('Conv2d_3').max_pool(2, 2, 2, 2, name='Conv2d_3_pool')) 38 | 39 | (self.feed('Conv2d_3_pool', 'Conv2d_7', 'Conv2d_11') 40 | .concat(3, name='feat_concat')) 41 | 42 | feature_lv = 'feat_concat' 43 | with tf.variable_scope(None, 'Openpose'): 44 | prefix = 'MConv_Stage1' 45 | (self.feed(feature_lv) 46 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_1') 47 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_2') 48 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_3') 49 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L1_4') 50 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5')) 51 | 52 | (self.feed(feature_lv) 53 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_1') 54 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_2') 55 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_3') 56 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L2_4') 57 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5')) 58 | 59 | for stage_id in range(5): 60 | prefix_prev = 'MConv_Stage%d' % (stage_id + 1) 61 | prefix = 'MConv_Stage%d' % (stage_id + 2) 62 | (self.feed(prefix_prev + '_L1_5', 63 | prefix_prev + '_L2_5', 64 | feature_lv) 65 | .concat(3, name=prefix + '_concat') 66 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_1') 67 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_2') 68 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_3') 69 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L1_4') 70 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5')) 71 | 72 | (self.feed(prefix + '_concat') 73 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_1') 74 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_2') 75 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_3') 76 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L2_4') 77 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5')) 78 | 79 | # final result 80 | (self.feed('MConv_Stage6_L2_5', 81 | 'MConv_Stage6_L1_5') 82 | .concat(3, name='concat_stage7')) 83 | 84 | def loss_l1_l2(self): 85 | l1s = [] 86 | l2s = [] 87 | for layer_name in sorted(self.layers.keys()): 88 | if '_L1_5' in layer_name: 89 | l1s.append(self.layers[layer_name]) 90 | if '_L2_5' in layer_name: 91 | l2s.append(self.layers[layer_name]) 92 | 93 | return l1s, l2s 94 | 95 | def loss_last(self): 96 | return self.get_output('MConv_Stage6_L1_5'), self.get_output('MConv_Stage6_L2_5') 97 | 98 | def 
restorable_variables(self): 99 | vs = {v.op.name: v for v in tf.global_variables() if 100 | 'MobilenetV1/Conv2d' in v.op.name and 101 | # 'global_step' not in v.op.name and 102 | # 'beta1_power' not in v.op.name and 'beta2_power' not in v.op.name and 103 | 'RMSProp' not in v.op.name and 'Momentum' not in v.op.name and 104 | 'Ada' not in v.op.name and 'Adam' not in v.op.name 105 | } 106 | return vs 107 | -------------------------------------------------------------------------------- /tf_pose/network_personlab.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from tf_pose import network_base 4 | from tf_pose.slim.nets.resnet_v2 import resnet_v2_101 5 | 6 | 7 | class PersonLabNetwork(network_base.BaseNetwork): 8 | """ 9 | Reference : PersonLab: Person Pose Estimation and Instance Segmentation with a Bottom-Up, Part-Based, Geometric Embedding Model 10 | 11 | pretrained architecture * weights from : 12 | https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models 13 | """ 14 | 15 | def __init__(self, inputs, trainable=True, backbone='resnet152'): 16 | """ 17 | :param inputs: 18 | :param backbone: resnet101, resnet152, mobilenet-v2-1.0 19 | """ 20 | self.backbone = backbone 21 | super().__init__(inputs, trainable) 22 | 23 | def setup(self): 24 | if self.backbone == 'resnet101': 25 | net, end_points = resnet_v2_101(self.inputs, is_training=self.trainable, global_pool=False, 26 | output_stride=16) 27 | pass 28 | pass 29 | 30 | -------------------------------------------------------------------------------- /tf_pose/networks.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import dirname, abspath 3 | 4 | import tensorflow as tf 5 | from tf_pose.network_mobilenet import MobilenetNetwork 6 | from tf_pose.network_mobilenet_thin import MobilenetNetworkThin 7 | 8 | from tf_pose.network_cmu import CmuNetwork 9 | from tf_pose.network_personlab import PersonLabNetwork 10 | 11 | 12 | def _get_base_path(): 13 | if not os.environ.get('OPENPOSE_MODEL', ''): 14 | return './models' 15 | return os.environ.get('OPENPOSE_MODEL') 16 | 17 | 18 | def get_network(type, placeholder_input, sess_for_load=None, trainable=True): 19 | if type == 'mobilenet': 20 | net = MobilenetNetwork({'image': placeholder_input}, conv_width=0.75, conv_width2=1.00, trainable=trainable) 21 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt' 22 | last_layer = 'MConv_Stage6_L{aux}_5' 23 | elif type == 'mobilenet_fast': 24 | net = MobilenetNetwork({'image': placeholder_input}, conv_width=0.5, conv_width2=0.5, trainable=trainable) 25 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt' 26 | last_layer = 'MConv_Stage6_L{aux}_5' 27 | elif type == 'mobilenet_accurate': 28 | net = MobilenetNetwork({'image': placeholder_input}, conv_width=1.00, conv_width2=1.00, trainable=trainable) 29 | pretrain_path = 'pretrained/mobilenet_v1_1.0_224_2017_06_14/mobilenet_v1_1.0_224.ckpt' 30 | last_layer = 'MConv_Stage6_L{aux}_5' 31 | 32 | elif type == 'mobilenet_thin': 33 | net = MobilenetNetworkThin({'image': placeholder_input}, conv_width=0.75, conv_width2=0.50, trainable=trainable) 34 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt' 35 | last_layer = 'MConv_Stage6_L{aux}_5' 36 | 37 | elif type == 'cmu': 38 | net = CmuNetwork({'image': placeholder_input}, 
trainable=trainable) 39 | pretrain_path = 'numpy/openpose_coco.npy' 40 | last_layer = 'Mconv7_stage6_L{aux}' 41 | elif type == 'vgg': 42 | net = CmuNetwork({'image': placeholder_input}, trainable=trainable) 43 | pretrain_path = 'numpy/openpose_vgg16.npy' 44 | last_layer = 'Mconv7_stage6_L{aux}' 45 | 46 | elif type == 'personlab_resnet101': 47 | net = PersonLabNetwork({'image': placeholder_input}, trainable=trainable) 48 | pretrain_path = 'pretrained/resnet_v2_101/resnet_v2_101.ckpt' 49 | last_layer = 'Mconv7_stage6_L{aux}' 50 | else: 51 | raise Exception('Invalid Mode.') 52 | 53 | pretrain_path_full = os.path.join(_get_base_path(), pretrain_path) 54 | if sess_for_load is not None: 55 | if type == 'cmu' or type == 'vgg': 56 | if not os.path.isfile(pretrain_path_full): 57 | raise Exception('Model file doesn\'t exist, path=%s' % pretrain_path_full) 58 | net.load(os.path.join(_get_base_path(), pretrain_path), sess_for_load) 59 | else: 60 | s = '%dx%d' % (placeholder_input.shape[2], placeholder_input.shape[1]) 61 | ckpts = { 62 | 'mobilenet': 'trained/mobilenet_%s/model-246038' % s, 63 | 'mobilenet_thin': 'trained/mobilenet_thin_%s/model-449003' % s, 64 | 'mobilenet_fast': 'trained/mobilenet_fast_%s/model-189000' % s, 65 | 'mobilenet_accurate': 'trained/mobilenet_accurate/model-170000' 66 | } 67 | ckpt_path = os.path.join(_get_base_path(), ckpts[type]) 68 | loader = tf.train.Saver() 69 | try: 70 | loader.restore(sess_for_load, ckpt_path) 71 | except Exception as e: 72 | raise Exception('Fail to load model files. \npath=%s\nerr=%s' % (ckpt_path, str(e))) 73 | 74 | return net, pretrain_path_full, last_layer 75 | 76 | 77 | def get_graph_path(model_name): 78 | dyn_graph_path = { 79 | 'cmu': 'graph/cmu/graph_opt.pb', 80 | 'mobilenet_thin': 'graph/mobilenet_thin/graph_opt.pb' 81 | } 82 | 83 | base_data_dir = dirname(dirname(abspath(__file__))) 84 | if os.path.exists(os.path.join(base_data_dir, 'models')): 85 | base_data_dir = os.path.join(base_data_dir, 'models') 86 | else: 87 | base_data_dir = os.path.join(base_data_dir, 'tf_pose_data') 88 | 89 | graph_path = os.path.join(base_data_dir, dyn_graph_path[model_name]) 90 | if os.path.isfile(graph_path): 91 | return graph_path 92 | 93 | raise Exception('Graph file doesn\'t exist, path=%s' % graph_path) 94 | 95 | 96 | def model_wh(resolution_str): 97 | width, height = map(int, resolution_str.split('x')) 98 | if width % 16 != 0 or height % 16 != 0: 99 | raise Exception('Width and height should be multiples of 16. w=%d, h=%d' % (width, height)) 100 | return int(width), int(height) 101 | -------------------------------------------------------------------------------- /tf_pose/pafprocess/README.md: -------------------------------------------------------------------------------- 1 | # post-processing for Part-Affinity Fields Map implemented in C++ & Swig 2 | 3 | Need to install swig. 4 | 5 | ```bash 6 | $ sudo apt install swig 7 | ``` 8 | 9 | You need to build pafprocess module which is written in c++. It will be used for post processing. 
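Once the module is built with the command below, it can be driven straight from Python. A minimal sketch, assuming `peaks`, `heatmap` and `pafmap` are float32 3-D arrays taken from the network output (the numpy typemaps in `pafprocess.i` fill in the dimension arguments automatically); the array names and the "negative id means missing part" convention are assumptions for illustration, not part of this README:

```python
import pafprocess  # importable from this directory after the in-place build

# peaks / heatmap / pafmap: (H, W, C) float32 maps produced by the estimator
pafprocess.process_paf(peaks, heatmap, pafmap)

for human_id in range(pafprocess.get_num_humans()):
    human_score = pafprocess.get_score(human_id)
    for part_id in range(18):  # NUM_PART body parts, see pafprocess.h
        cid = pafprocess.get_part_cid(human_id, part_id)
        if cid < 0:
            continue  # assumed: a negative id means this part was not detected
        x, y = pafprocess.get_part_x(cid), pafprocess.get_part_y(cid)
        print(human_id, part_id, x, y, pafprocess.get_part_score(cid), human_score)
```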
10 | 11 | ```bash 12 | $ swig -python -c++ pafprocess.i && python3 setup.py build_ext --inplace 13 | ``` 14 | 15 | -------------------------------------------------------------------------------- /tf_pose/pafprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/pafprocess/__init__.py -------------------------------------------------------------------------------- /tf_pose/pafprocess/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/pafprocess/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/pafprocess/__pycache__/pafprocess.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/pafprocess/__pycache__/pafprocess.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/pafprocess/_pafprocess.cp36-win_amd64.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/pafprocess/_pafprocess.cp36-win_amd64.pyd -------------------------------------------------------------------------------- /tf_pose/pafprocess/build/temp.win-amd64-3.6/Release/_pafprocess.cp36-win_amd64.exp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/pafprocess/build/temp.win-amd64-3.6/Release/_pafprocess.cp36-win_amd64.exp -------------------------------------------------------------------------------- /tf_pose/pafprocess/build/temp.win-amd64-3.6/Release/_pafprocess.cp36-win_amd64.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/pafprocess/build/temp.win-amd64-3.6/Release/_pafprocess.cp36-win_amd64.lib -------------------------------------------------------------------------------- /tf_pose/pafprocess/build/temp.win-amd64-3.6/Release/pafprocess.obj: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/pafprocess/build/temp.win-amd64-3.6/Release/pafprocess.obj -------------------------------------------------------------------------------- /tf_pose/pafprocess/build/temp.win-amd64-3.6/Release/pafprocess_wrap.obj: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/pafprocess/build/temp.win-amd64-3.6/Release/pafprocess_wrap.obj -------------------------------------------------------------------------------- /tf_pose/pafprocess/pafprocess.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #ifndef PAFPROCESS 4 | 
#define PAFPROCESS 5 | 6 | const float THRESH_HEAT = 0.05; 7 | const float THRESH_VECTOR_SCORE = 0.05; 8 | const int THRESH_VECTOR_CNT1 = 8; 9 | const int THRESH_PART_CNT = 4; 10 | const float THRESH_HUMAN_SCORE = 0.4; 11 | const int NUM_PART = 18; 12 | 13 | const int STEP_PAF = 10; 14 | 15 | const int COCOPAIRS_SIZE = 19; 16 | const int COCOPAIRS_NET[COCOPAIRS_SIZE][2] = { 17 | {12, 13}, {20, 21}, {14, 15}, {16, 17}, {22, 23}, {24, 25}, {0, 1}, {2, 3}, {4, 5}, 18 | {6, 7}, {8, 9}, {10, 11}, {28, 29}, {30, 31}, {34, 35}, {32, 33}, {36, 37}, {18, 19}, {26, 27} 19 | }; 20 | 21 | const int COCOPAIRS[COCOPAIRS_SIZE][2] = { 22 | {1, 2}, {1, 5}, {2, 3}, {3, 4}, {5, 6}, {6, 7}, {1, 8}, {8, 9}, {9, 10}, {1, 11}, 23 | {11, 12}, {12, 13}, {1, 0}, {0, 14}, {14, 16}, {0, 15}, {15, 17}, {2, 16}, {5, 17} 24 | }; 25 | 26 | struct Peak { 27 | int x; 28 | int y; 29 | float score; 30 | int id; 31 | }; 32 | 33 | struct VectorXY { 34 | float x; 35 | float y; 36 | }; 37 | 38 | struct ConnectionCandidate { 39 | int idx1; 40 | int idx2; 41 | float score; 42 | float etc; 43 | }; 44 | 45 | struct Connection { 46 | int cid1; 47 | int cid2; 48 | float score; 49 | int peak_id1; 50 | int peak_id2; 51 | }; 52 | 53 | int process_paf(int p1, int p2, int p3, float *peaks, int h1, int h2, int h3, float *heatmap, int f1, int f2, int f3, float *pafmap); 54 | int get_num_humans(); 55 | int get_part_cid(int human_id, int part_id); 56 | float get_score(int human_id); 57 | int get_part_x(int cid); 58 | int get_part_y(int cid); 59 | float get_part_score(int cid); 60 | 61 | #endif 62 | -------------------------------------------------------------------------------- /tf_pose/pafprocess/pafprocess.i: -------------------------------------------------------------------------------- 1 | %module pafprocess 2 | %{ 3 | #define SWIG_FILE_WITH_INIT 4 | #include "pafprocess.h" 5 | %} 6 | 7 | %include "numpy.i" 8 | %init %{ 9 | import_array(); 10 | %} 11 | 12 | //%apply (int DIM1, int DIM2, int* IN_ARRAY2) {(int p1, int p2, int *peak_idxs)} 13 | //%apply (int DIM1, int DIM2, int DIM3, float* IN_ARRAY3) {(int h1, int h2, int h3, float *heatmap), (int f1, int f2, int f3, float *pafmap)}; 14 | %apply (int DIM1, int DIM2, int DIM3, float* IN_ARRAY3) {(int p1, int p2, int p3, float *peaks), (int h1, int h2, int h3, float *heatmap), (int f1, int f2, int f3, float *pafmap)}; 15 | %include "pafprocess.h" 16 | -------------------------------------------------------------------------------- /tf_pose/pafprocess/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup, Extension 2 | import numpy 3 | import os 4 | 5 | # os.environ['CC'] = 'g++'; 6 | setup(name='pafprocess_ext', version='1.0', 7 | ext_modules=[ 8 | Extension('_pafprocess', ['pafprocess.cpp', 'pafprocess.i'], 9 | swig_opts=['-c++'], 10 | depends=["pafprocess.h"], 11 | include_dirs=[numpy.get_include(), '.']) 12 | ], 13 | py_modules=[ 14 | "pafprocess" 15 | ] 16 | ) 17 | -------------------------------------------------------------------------------- /tf_pose/pose_datamaster.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import time 4 | 5 | from tensorpack.dataflow.remote import RemoteDataZMQ 6 | 7 | from tf_pose.pose_dataset import CocoPose 8 | 9 | logging.basicConfig(level=logging.DEBUG, format='[lmdb_dataset] %(asctime)s %(levelname)s %(message)s') 10 | 11 | if __name__ == '__main__': 12 | """ 13 | Speed Test for Getting Input batches from 
other nodes 14 | """ 15 | parser = argparse.ArgumentParser(description='Worker for preparing input batches.') 16 | parser.add_argument('--listen', type=str, default='tcp://0.0.0.0:1027') 17 | parser.add_argument('--show', type=bool, default=False) 18 | args = parser.parse_args() 19 | 20 | df = RemoteDataZMQ(args.listen) 21 | 22 | logging.info('tcp queue start') 23 | df.reset_state() 24 | t = time.time() 25 | for i, dp in enumerate(df.get_data()): 26 | if i == 100: 27 | break 28 | logging.info('Input batch %d received.' % i) 29 | if i == 0: 30 | for d in dp: 31 | logging.info('%d dp shape={}'.format(d.shape)) 32 | 33 | if args.show: 34 | CocoPose.display_image(dp[0][0], dp[1][0], dp[2][0]) 35 | 36 | logging.info('Speed Test Done for 100 Batches in %f seconds.' % (time.time() - t)) 37 | -------------------------------------------------------------------------------- /tf_pose/pose_dataworker.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from tensorpack.dataflow.remote import send_dataflow_zmq 4 | 5 | from tf_pose.pose_dataset import get_dataflow_batch 6 | from tf_pose.pose_augment import set_network_input_wh, set_network_scale 7 | 8 | if __name__ == '__main__': 9 | """ 10 | OpenPose Data Preparation might be a bottleneck for training. 11 | You can run multiple workers to generate input batches in multi-nodes to make training process faster. 12 | """ 13 | parser = argparse.ArgumentParser(description='Worker for preparing input batches.') 14 | parser.add_argument('--datapath', type=str, default='/coco/annotations/') 15 | parser.add_argument('--imgpath', type=str, default='/coco/') 16 | parser.add_argument('--batchsize', type=int, default=64) 17 | parser.add_argument('--train', type=bool, default=True) 18 | parser.add_argument('--master', type=str, default='tcp://csi-cluster-gpu20.dakao.io:1027') 19 | parser.add_argument('--input-width', type=int, default=368) 20 | parser.add_argument('--input-height', type=int, default=368) 21 | parser.add_argument('--scale-factor', type=int, default=2) 22 | args = parser.parse_args() 23 | 24 | set_network_input_wh(args.input_width, args.input_height) 25 | set_network_scale(args.scale_factor) 26 | 27 | df = get_dataflow_batch(args.datapath, args.train, args.batchsize, args.imgpath) 28 | 29 | send_dataflow_zmq(df, args.master, hwm=10) 30 | -------------------------------------------------------------------------------- /tf_pose/pycocotools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /tf_pose/pycocotools/mask.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tsungyi' 2 | 3 | try: 4 | import pycocotools._mask as _mask 5 | except Exception as e: 6 | print('mask.py err=', e) 7 | 8 | class Dummy: 9 | def __init__(self): 10 | self.iou = None 11 | self.merge = None 12 | self.frPyObjects = None 13 | _mask = Dummy() 14 | 15 | # Interface for manipulating masks stored in RLE format. 16 | # 17 | # RLE is a simple yet efficient format for storing binary masks. RLE 18 | # first divides a vector (or vectorized image) into a series of piecewise 19 | # constant regions and then for each piece simply stores the length of 20 | # that piece. 
For example, given M=[0 0 1 1 1 0 1] the RLE counts would 21 | # be [2 3 1 1], or for M=[1 1 1 1 1 1 0] the counts would be [0 6 1] 22 | # (note that the odd counts are always the numbers of zeros). Instead of 23 | # storing the counts directly, additional compression is achieved with a 24 | # variable bitrate representation based on a common scheme called LEB128. 25 | # 26 | # Compression is greatest given large piecewise constant regions. 27 | # Specifically, the size of the RLE is proportional to the number of 28 | # *boundaries* in M (or for an image the number of boundaries in the y 29 | # direction). Assuming fairly simple shapes, the RLE representation is 30 | # O(sqrt(n)) where n is number of pixels in the object. Hence space usage 31 | # is substantially lower, especially for large simple objects (large n). 32 | # 33 | # Many common operations on masks can be computed directly using the RLE 34 | # (without need for decoding). This includes computations such as area, 35 | # union, intersection, etc. All of these operations are linear in the 36 | # size of the RLE, in other words they are O(sqrt(n)) where n is the area 37 | # of the object. Computing these operations on the original mask is O(n). 38 | # Thus, using the RLE can result in substantial computational savings. 39 | # 40 | # The following API functions are defined: 41 | # encode - Encode binary masks using RLE. 42 | # decode - Decode binary masks encoded via RLE. 43 | # merge - Compute union or intersection of encoded masks. 44 | # iou - Compute intersection over union between masks. 45 | # area - Compute area of encoded masks. 46 | # toBbox - Get bounding boxes surrounding encoded masks. 47 | # frPyObjects - Convert polygon, bbox, and uncompressed RLE to encoded RLE mask. 48 | # 49 | # Usage: 50 | # Rs = encode( masks ) 51 | # masks = decode( Rs ) 52 | # R = merge( Rs, intersect=false ) 53 | # o = iou( dt, gt, iscrowd ) 54 | # a = area( Rs ) 55 | # bbs = toBbox( Rs ) 56 | # Rs = frPyObjects( [pyObjects], h, w ) 57 | # 58 | # In the API the following formats are used: 59 | # Rs - [dict] Run-length encoding of binary masks 60 | # R - dict Run-length encoding of binary mask 61 | # masks - [hxwxn] Binary mask(s) (must have type np.ndarray(dtype=uint8) in column-major order) 62 | # iscrowd - [nx1] list of np.ndarray. 1 indicates corresponding gt image has crowd region to ignore 63 | # bbs - [nx4] Bounding box(es) stored as [x y w h] 64 | # poly - Polygon stored as [[x1 y1 x2 y2...],[x1 y1 ...],...] (2D list) 65 | # dt,gt - May be either bounding boxes or encoded masks 66 | # Both poly and bbs are 0-indexed (bbox=[0 0 1 1] encloses first pixel). 67 | # 68 | # Finally, a note about the intersection over union (iou) computation. 69 | # The standard iou of a ground truth (gt) and detected (dt) object is 70 | # iou(gt,dt) = area(intersect(gt,dt)) / area(union(gt,dt)) 71 | # For "crowd" regions, we use a modified criteria. If a gt object is 72 | # marked as "iscrowd", we allow a dt to match any subregion of the gt. 73 | # Choosing gt' in the crowd gt that best matches the dt can be done using 74 | # gt'=intersect(dt,gt). Since by definition union(gt',dt)=dt, computing 75 | # iou(gt,dt,iscrowd) = iou(gt',dt) = area(intersect(gt,dt)) / area(dt) 76 | # For crowd gt regions we use this modified criteria above for the iou. 77 | # 78 | # To compile run "python setup.py build_ext --inplace" 79 | # Please do not contact us for help with compiling. 80 | # 81 | # Microsoft COCO Toolbox. 
version 2.0 82 | # Data, paper, and tutorials available at: http://mscoco.org/ 83 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 84 | # Licensed under the Simplified BSD License [see coco/license.txt] 85 | 86 | iou = _mask.iou 87 | merge = _mask.merge 88 | frPyObjects = _mask.frPyObjects 89 | 90 | def encode(bimask): 91 | if len(bimask.shape) == 3: 92 | return _mask.encode(bimask) 93 | elif len(bimask.shape) == 2: 94 | h, w = bimask.shape 95 | return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0] 96 | 97 | def decode(rleObjs): 98 | if type(rleObjs) == list: 99 | return _mask.decode(rleObjs) 100 | else: 101 | return _mask.decode([rleObjs])[:,:,0] 102 | 103 | def area(rleObjs): 104 | if type(rleObjs) == list: 105 | return _mask.area(rleObjs) 106 | else: 107 | return _mask.area([rleObjs])[0] 108 | 109 | def toBbox(rleObjs): 110 | if type(rleObjs) == list: 111 | return _mask.toBbox(rleObjs) 112 | else: 113 | return _mask.toBbox([rleObjs])[0] -------------------------------------------------------------------------------- /tf_pose/pystopwatch.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import defaultdict 3 | 4 | 5 | class StopWatchManager: 6 | def __init__(self): 7 | self.watches = defaultdict(StopWatch) 8 | 9 | def get(self, name): 10 | return self.watches[name] 11 | 12 | def start(self, name): 13 | self.get(name).start() 14 | 15 | def stop(self, name): 16 | self.get(name).stop() 17 | 18 | def reset(self, name): 19 | self.get(name).reset() 20 | 21 | def get_elapsed(self, name): 22 | return self.get(name).get_elapsed() 23 | 24 | def __repr__(self): 25 | return '\n'.join(['%s: %.8f' % (k, v.elapsed_accumulated) for k, v in self.watches.items()]) 26 | 27 | 28 | class StopWatch: 29 | def __init__(self): 30 | self.elapsed_accumulated = 0.0 31 | self.started_at = time.time() 32 | 33 | def start(self): 34 | self.started_at = time.time() 35 | 36 | def stop(self): 37 | self.elapsed_accumulated += time.time() - self.started_at 38 | 39 | def reset(self): 40 | self.elapsed_accumulated = 0.0 41 | 42 | def get_elapsed(self): 43 | return self.elapsed_accumulated -------------------------------------------------------------------------------- /tf_pose/runner.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import os 4 | import cv2 5 | import base64 6 | 7 | from tf_pose import common 8 | from tf_pose.estimator import TfPoseEstimator 9 | from tf_pose.networks import get_graph_path, model_wh 10 | 11 | Estimator = TfPoseEstimator 12 | 13 | 14 | def get_estimator(model='cmu', resize='0x0', resize_out_ratio=4.0): 15 | w, h = model_wh(resize) 16 | if w == 0 or h == 0: 17 | e = TfPoseEstimator(get_graph_path(model), target_size=(432, 368)) 18 | else: 19 | e = TfPoseEstimator(get_graph_path(model), target_size=(w, h)) 20 | 21 | return e 22 | 23 | 24 | def infer(image, model='cmu', resize='0x0', resize_out_ratio=4.0): 25 | """ 26 | 27 | :param image: 28 | :param model: 29 | :param resize: 30 | :param resize_out_ratio: 31 | :return: 32 | """ 33 | w, h = model_wh(resize) 34 | if w == 0 or h == 0: 35 | e = TfPoseEstimator(get_graph_path(model), target_size=(432, 368)) 36 | else: 37 | e = TfPoseEstimator(get_graph_path(model), target_size=(w, h)) 38 | 39 | # estimate human poses from a single image ! 
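    # read_imgfile() loads the image with OpenCV; inference() then returns the
    # detected humans. resize_to_default only applies when an explicit WxH was
    # requested above, and upsample_size controls how much the heat/PAF maps
    # are upscaled before post-processing.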
40 | image = common.read_imgfile(image, None, None) 41 | if image is None: 42 | raise Exception('Image can not be read, path=%s' % image) 43 | humans = e.inference(image, resize_to_default=(w > 0 and h > 0), upsample_size=resize_out_ratio) 44 | 45 | if "TERM_PROGRAM" in os.environ and 'iTerm' in os.environ["TERM_PROGRAM"]: 46 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) 47 | image_str = cv2.imencode(".jpg", image)[1].tostring() 48 | print("\033]1337;File=name=;inline=1:" + base64.b64encode(image_str).decode("utf-8") + "\a") 49 | 50 | return humans 51 | -------------------------------------------------------------------------------- /tf_pose/slidingwindow/ArrayUtils.py: -------------------------------------------------------------------------------- 1 | import math, mmap, tempfile 2 | import numpy as np 3 | import psutil 4 | 5 | def _requiredSize(shape, dtype): 6 | """ 7 | Determines the number of bytes required to store a NumPy array with 8 | the specified shape and datatype. 9 | """ 10 | return math.floor(np.prod(np.asarray(shape, dtype=np.uint64)) * np.dtype(dtype).itemsize) 11 | 12 | 13 | class TempfileBackedArray(np.ndarray): 14 | """ 15 | A NumPy ndarray that uses a memory-mapped temp file as its backing 16 | """ 17 | 18 | def __new__(subtype, shape, dtype=float, buffer=None, offset=0, strides=None, order=None, info=None): 19 | 20 | # Determine the size in bytes required to hold the array 21 | numBytes = _requiredSize(shape, dtype) 22 | 23 | # Create the temporary file, resize it, and map it into memory 24 | tempFile = tempfile.TemporaryFile() 25 | tempFile.truncate(numBytes) 26 | buf = mmap.mmap(tempFile.fileno(), numBytes, access=mmap.ACCESS_WRITE) 27 | 28 | # Create the ndarray with the memory map as the underlying buffer 29 | obj = super(TempfileBackedArray, subtype).__new__(subtype, shape, dtype, buf, 0, None, order) 30 | 31 | # Attach the file reference to the ndarray object 32 | obj._file = tempFile 33 | return obj 34 | 35 | def __array_finalize__(self, obj): 36 | if obj is None: return 37 | self._file = getattr(obj, '_file', None) 38 | 39 | 40 | def arrayFactory(shape, dtype=float): 41 | """ 42 | Creates a new ndarray of the specified shape and datatype, storing 43 | it in memory if there is sufficient available space or else using 44 | a memory-mapped temporary file to provide the underlying buffer. 45 | """ 46 | 47 | # Determine the number of bytes required to store the array 48 | requiredBytes = _requiredSize(shape, dtype) 49 | 50 | # Determine if there is sufficient available memory 51 | vmem = psutil.virtual_memory() 52 | if vmem.available > requiredBytes: 53 | return np.ndarray(shape=shape, dtype=dtype) 54 | else: 55 | return TempfileBackedArray(shape=shape, dtype=dtype) 56 | 57 | 58 | def zerosFactory(shape, dtype=float): 59 | """ 60 | Creates a new NumPy array using `arrayFactory()` and fills it with zeros. 61 | """ 62 | arr = arrayFactory(shape=shape, dtype=dtype) 63 | arr.fill(0) 64 | return arr 65 | 66 | 67 | def arrayCast(source, dtype): 68 | """ 69 | Casts a NumPy array to the specified datatype, storing the copy 70 | in memory if there is sufficient available space or else using a 71 | memory-mapped temporary file to provide the underlying buffer. 
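    For example (illustrative values), arrayCast(np.ones((512, 512)), np.float32)
    returns an ordinary in-memory copy when enough RAM is available and otherwise
    falls back to a TempfileBackedArray of the same shape and dtype.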
72 | """ 73 | 74 | # Determine the number of bytes required to store the array 75 | requiredBytes = _requiredSize(source.shape, dtype) 76 | 77 | # Determine if there is sufficient available memory 78 | vmem = psutil.virtual_memory() 79 | if vmem.available > requiredBytes: 80 | return source.astype(dtype, subok=False) 81 | else: 82 | dest = arrayFactory(source.shape, dtype) 83 | np.copyto(dest, source, casting='unsafe') 84 | return dest 85 | 86 | 87 | def determineMaxWindowSize(dtype, limit=None): 88 | """ 89 | Determines the largest square window size that can be used, based on 90 | the specified datatype and amount of currently available system memory. 91 | 92 | If `limit` is specified, then this value will be returned in the event 93 | that it is smaller than the maximum computed size. 94 | """ 95 | vmem = psutil.virtual_memory() 96 | maxSize = math.floor(math.sqrt(vmem.available / np.dtype(dtype).itemsize)) 97 | if limit is None or limit >= maxSize: 98 | return maxSize 99 | else: 100 | return limit 101 | -------------------------------------------------------------------------------- /tf_pose/slidingwindow/Batching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def batchWindows(windows, batchSize): 4 | """ 5 | Splits a list of windows into a series of batches. 6 | """ 7 | return np.array_split(np.array(windows), len(windows) // batchSize) 8 | -------------------------------------------------------------------------------- /tf_pose/slidingwindow/Merging.py: -------------------------------------------------------------------------------- 1 | from .SlidingWindow import generate 2 | from .Batching import batchWindows 3 | import numpy as np 4 | 5 | def mergeWindows(data, dimOrder, maxWindowSize, overlapPercent, batchSize, transform, progressCallback = None): 6 | """ 7 | Generates sliding windows for the specified dataset and applies the specified 8 | transformation function to each window. Where multiple overlapping windows 9 | include an element of the input dataset, the overlap is resolved by computing 10 | the mean transform result value for that element. 11 | 12 | Irrespective of the order of the dimensions of the input dataset, the 13 | transformation function should return a NumPy array with dimensions 14 | [batch, height, width, resultChannels]. 15 | 16 | If a progress callback is supplied, it will be called immediately before 17 | applying the transformation function to each batch of windows. The callback 18 | should accept the current batch index and number of batches as arguments. 
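    Note that the transform receives the full dataset together with a batch of
    SlidingWindow objects and must return one result per window, in order, with
    each result matching (or broadcasting against) that window's height and width.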
19 | """ 20 | 21 | # Determine the dimensions of the input data 22 | sourceWidth = data.shape[dimOrder.index('w')] 23 | sourceHeight = data.shape[dimOrder.index('h')] 24 | 25 | # Generate the sliding windows and group them into batches 26 | windows = generate(data, dimOrder, maxWindowSize, overlapPercent) 27 | batches = batchWindows(windows, batchSize) 28 | 29 | # Apply the transform to the first batch of windows and determine the result dimensionality 30 | exemplarResult = transform(data, batches[0]) 31 | resultDimensions = exemplarResult.shape[ len(exemplarResult.shape) - 1 ] 32 | 33 | # Create the matrices to hold the sums and counts for the transform result values 34 | sums = np.zeros((sourceHeight, sourceWidth, resultDimensions), dtype=np.float) 35 | counts = np.zeros((sourceHeight, sourceWidth), dtype=np.uint32) 36 | 37 | # Iterate over the batches and apply the transformation function to each batch 38 | for batchNum, batch in enumerate(batches): 39 | 40 | # If a progress callback was supplied, call it 41 | if progressCallback != None: 42 | progressCallback(batchNum, len(batches)) 43 | 44 | # Apply the transformation function to the current batch 45 | batchResult = transform(data, batch) 46 | 47 | # Iterate over the windows in the batch and update the sums matrix 48 | for windowNum, window in enumerate(batch): 49 | 50 | # Create views into the larger matrices that correspond to the current window 51 | windowIndices = window.indices(False) 52 | sumsView = sums[windowIndices] 53 | countsView = counts[windowIndices] 54 | 55 | # Update the result sums for each of the dataset elements in the window 56 | sumsView[:] += batchResult[windowNum] 57 | countsView[:] += 1 58 | 59 | # Use the sums and the counts to compute the mean values 60 | for dim in range(0, resultDimensions): 61 | sums[:,:,dim] /= counts 62 | 63 | # Return the mean values 64 | return sums 65 | -------------------------------------------------------------------------------- /tf_pose/slidingwindow/RectangleUtils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | def cropRect(rect, cropTop, cropBottom, cropLeft, cropRight): 5 | """ 6 | Crops a rectangle by the specified number of pixels on each side. 7 | 8 | The input rectangle and return value are both a tuple of (x,y,w,h). 9 | """ 10 | 11 | # Unpack the rectangle 12 | x, y, w, h = rect 13 | 14 | # Crop by the specified value 15 | x += cropLeft 16 | y += cropTop 17 | w -= (cropLeft + cropRight) 18 | h -= (cropTop + cropBottom) 19 | 20 | # Re-pack the padded rect 21 | return (x,y,w,h) 22 | 23 | 24 | def padRect(rect, padTop, padBottom, padLeft, padRight, bounds, clipExcess = True): 25 | """ 26 | Pads a rectangle by the specified values on each individual side, 27 | ensuring the padded rectangle falls within the specified bounds. 28 | 29 | The input rectangle, bounds, and return value are all a tuple of (x,y,w,h). 
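    For example (illustrative values), padRect((10, 10, 50, 50), 5, 5, 5, 5,
    bounds=(100, 100)) returns (5, 5, 60, 60): five pixels of padding on every
    side, with nothing clipped because the result still falls inside the bounds.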
30 | """ 31 | 32 | # Unpack the rectangle 33 | x, y, w, h = rect 34 | 35 | # Pad by the specified value 36 | x -= padLeft 37 | y -= padTop 38 | w += (padLeft + padRight) 39 | h += (padTop + padBottom) 40 | 41 | # Determine if we are clipping overflows/underflows or 42 | # shifting the centre of the rectangle to compensate 43 | if clipExcess == True: 44 | 45 | # Clip any underflows 46 | x = max(0, x) 47 | y = max(0, y) 48 | 49 | # Clip any overflows 50 | overflowY = max(0, (y + h) - bounds[0]) 51 | overflowX = max(0, (x + w) - bounds[1]) 52 | h -= overflowY 53 | w -= overflowX 54 | 55 | else: 56 | 57 | # Compensate for any underflows 58 | underflowX = max(0, 0 - x) 59 | underflowY = max(0, 0 - y) 60 | x += underflowX 61 | y += underflowY 62 | 63 | # Compensate for any overflows 64 | overflowY = max(0, (y + h) - bounds[0]) 65 | overflowX = max(0, (x + w) - bounds[1]) 66 | x -= overflowX 67 | w += overflowX 68 | y -= overflowY 69 | h += overflowY 70 | 71 | # If there are still overflows or underflows after our 72 | # modifications, we have no choice but to clip them 73 | x, y, w, h = padRect((x,y,w,h), 0, 0, 0, 0, bounds, True) 74 | 75 | # Re-pack the padded rect 76 | return (x,y,w,h) 77 | 78 | 79 | def cropRectEqually(rect, cropping): 80 | """ 81 | Crops a rectangle by the specified number of pixels on all sides. 82 | 83 | The input rectangle and return value are both a tuple of (x,y,w,h). 84 | """ 85 | return cropRect(rect, cropping, cropping, cropping, cropping) 86 | 87 | 88 | def padRectEqually(rect, padding, bounds, clipExcess = True): 89 | """ 90 | Applies equal padding to all sides of a rectangle, 91 | ensuring the padded rectangle falls within the specified bounds. 92 | 93 | The input rectangle, bounds, and return value are all a tuple of (x,y,w,h). 94 | """ 95 | return padRect(rect, padding, padding, padding, padding, bounds, clipExcess) 96 | 97 | 98 | def squareAspect(rect): 99 | """ 100 | Crops either the width or height, as necessary, to make a rectangle into a square. 101 | 102 | The input rectangle and return value are both a tuple of (x,y,w,h). 103 | """ 104 | 105 | # Determine which dimension needs to be cropped 106 | x,y,w,h = rect 107 | if w > h: 108 | cropX = (w - h) // 2 109 | return cropRect(rect, 0, 0, cropX, cropX) 110 | elif w < h: 111 | cropY = (h - w) // 2 112 | return cropRect(rect, cropY, cropY, 0, 0) 113 | 114 | # Already a square 115 | return rect 116 | 117 | 118 | def fitToSize(rect, targetWidth, targetHeight, bounds): 119 | """ 120 | Pads or crops a rectangle as necessary to achieve the target dimensions, 121 | ensuring the modified rectangle falls within the specified bounds. 122 | 123 | The input rectangle, bounds, and return value are all a tuple of (x,y,w,h). 
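    For example (illustrative values), fitToSize((0, 0, 60, 40), 50, 50,
    bounds=(100, 100)) returns (5, 0, 50, 50): the excess width is cropped
    equally from both sides, while the missing height is added below because
    padding above would fall outside the bounds.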
124 | """ 125 | 126 | # Determine the difference between the current size and target size 127 | x,y,w,h = rect 128 | diffX = w - targetWidth 129 | diffY = h - targetHeight 130 | 131 | # Determine if we are cropping or padding the width 132 | if diffX > 0: 133 | cropLeft = math.floor(diffX / 2) 134 | cropRight = diffX - cropLeft 135 | x,y,w,h = cropRect((x,y,w,h), 0, 0, cropLeft, cropRight) 136 | elif diffX < 0: 137 | padLeft = math.floor(abs(diffX) / 2) 138 | padRight = abs(diffX) - padLeft 139 | x,y,w,h = padRect((x,y,w,h), 0, 0, padLeft, padRight, bounds, False) 140 | 141 | # Determine if we are cropping or padding the height 142 | if diffY > 0: 143 | cropTop = math.floor(diffY / 2) 144 | cropBottom = diffY - cropTop 145 | x,y,w,h = cropRect((x,y,w,h), cropTop, cropBottom, 0, 0) 146 | elif diffY < 0: 147 | padTop = math.floor(abs(diffY) / 2) 148 | padBottom = abs(diffY) - padTop 149 | x,y,w,h = padRect((x,y,w,h), padTop, padBottom, 0, 0, bounds, False) 150 | 151 | return (x,y,w,h) 152 | -------------------------------------------------------------------------------- /tf_pose/slidingwindow/SlidingWindow.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | class DimOrder(object): 4 | """ 5 | Represents the order of the dimensions in a dataset's shape. 6 | """ 7 | ChannelHeightWidth = ['c', 'h', 'w'] 8 | HeightWidthChannel = ['h', 'w', 'c'] 9 | 10 | 11 | class SlidingWindow(object): 12 | """ 13 | Represents a single window into a larger dataset. 14 | """ 15 | 16 | def __init__(self, x, y, w, h, dimOrder, transform = None): 17 | """ 18 | Creates a new window with the specified dimensions and transform 19 | """ 20 | self.x = x 21 | self.y = y 22 | self.w = w 23 | self.h = h 24 | self.dimOrder = dimOrder 25 | self.transform = transform 26 | 27 | def apply(self, matrix): 28 | """ 29 | Slices the supplied matrix and applies any transform bound to this window 30 | """ 31 | view = matrix[ self.indices() ] 32 | return self.transform(view) if self.transform != None else view 33 | 34 | def getRect(self): 35 | """ 36 | Returns the window bounds as a tuple of (x,y,w,h) 37 | """ 38 | return (self.x, self.y, self.w, self.h) 39 | 40 | def setRect(self, rect): 41 | """ 42 | Sets the window bounds from a tuple of (x,y,w,h) 43 | """ 44 | self.x, self.y, self.w, self.h = rect 45 | 46 | def indices(self, includeChannel=True): 47 | """ 48 | Retrieves the indices for this window as a tuple of slices 49 | """ 50 | if self.dimOrder == DimOrder.HeightWidthChannel: 51 | 52 | # Equivalent to [self.y:self.y+self.h+1, self.x:self.x+self.w+1] 53 | return ( 54 | slice(self.y, self.y+self.h), 55 | slice(self.x, self.x+self.w) 56 | ) 57 | 58 | elif self.dimOrder == DimOrder.ChannelHeightWidth: 59 | 60 | if includeChannel is True: 61 | 62 | # Equivalent to [:, self.y:self.y+self.h+1, self.x:self.x+self.w+1] 63 | return ( 64 | slice(None, None), 65 | slice(self.y, self.y+self.h), 66 | slice(self.x, self.x+self.w) 67 | ) 68 | 69 | else: 70 | 71 | # Equivalent to [self.y:self.y+self.h+1, self.x:self.x+self.w+1] 72 | return ( 73 | slice(self.y, self.y+self.h), 74 | slice(self.x, self.x+self.w) 75 | ) 76 | 77 | else: 78 | raise Error('Unsupported order of dimensions: ' + str(self.dimOrder)) 79 | 80 | def __str__(self): 81 | return '(' + str(self.x) + ',' + str(self.y) + ',' + str(self.w) + ',' + str(self.h) + ')' 82 | 83 | def __repr__(self): 84 | return self.__str__() 85 | 86 | 87 | def generate(data, dimOrder, maxWindowSizeW, maxWindowSizeH, overlapPercent, transforms = 
[]): 88 | """ 89 | Generates a set of sliding windows for the specified dataset. 90 | """ 91 | 92 | # Determine the dimensions of the input data 93 | width = data.shape[dimOrder.index('w')] 94 | height = data.shape[dimOrder.index('h')] 95 | 96 | # Generate the windows 97 | return generateForSize(width, height, dimOrder, maxWindowSizeW, maxWindowSizeH, overlapPercent, transforms) 98 | 99 | 100 | def generateForSize(width, height, dimOrder, maxWindowSizeW, maxWindowSizeH, overlapPercent, transforms = []): 101 | """ 102 | Generates a set of sliding windows for a dataset with the specified dimensions and order. 103 | """ 104 | 105 | # If the input data is smaller than the specified window size, 106 | # clip the window size to the input size on both dimensions 107 | windowSizeX = min(maxWindowSizeW, width) 108 | windowSizeY = min(maxWindowSizeH, height) 109 | 110 | # Compute the window overlap and step size 111 | windowOverlapX = int(math.floor(windowSizeX * overlapPercent)) 112 | windowOverlapY = int(math.floor(windowSizeY * overlapPercent)) 113 | stepSizeX = windowSizeX - windowOverlapX 114 | stepSizeY = windowSizeY - windowOverlapY 115 | 116 | # Determine how many windows we will need in order to cover the input data 117 | lastX = width - windowSizeX 118 | lastY = height - windowSizeY 119 | xOffsets = list(range(0, lastX+1, stepSizeX)) 120 | yOffsets = list(range(0, lastY+1, stepSizeY)) 121 | 122 | # Unless the input data dimensions are exact multiples of the step size, 123 | # we will need one additional row and column of windows to get 100% coverage 124 | if len(xOffsets) == 0 or xOffsets[-1] != lastX: 125 | xOffsets.append(lastX) 126 | if len(yOffsets) == 0 or yOffsets[-1] != lastY: 127 | yOffsets.append(lastY) 128 | 129 | # Generate the list of windows 130 | windows = [] 131 | for xOffset in xOffsets: 132 | for yOffset in yOffsets: 133 | for transform in [None] + transforms: 134 | windows.append(SlidingWindow( 135 | x=xOffset, 136 | y=yOffset, 137 | w=windowSizeX, 138 | h=windowSizeY, 139 | dimOrder=dimOrder, 140 | transform=transform 141 | )) 142 | 143 | return windows 144 | -------------------------------------------------------------------------------- /tf_pose/slidingwindow/WindowDistance.py: -------------------------------------------------------------------------------- 1 | from .ArrayUtils import * 2 | import numpy as np 3 | import math 4 | 5 | def generateDistanceMatrix(width, height): 6 | """ 7 | Generates a matrix specifying the distance of each point in a window to its centre. 
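    For example (illustrative values), generateDistanceMatrix(4, 4) measures each
    cell's distance from the point (2.0, 2.0), so cell (2, 2) gets 0.0 and cell
    (0, 0) gets sqrt(8) ≈ 2.83.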
8 | """ 9 | 10 | # Determine the coordinates of the exact centre of the window 11 | originX = width / 2 12 | originY = height / 2 13 | 14 | # Generate the distance matrix 15 | distances = zerosFactory((height,width), dtype=np.float) 16 | for index, val in np.ndenumerate(distances): 17 | y,x = index 18 | distances[(y,x)] = math.sqrt( math.pow(x - originX, 2) + math.pow(y - originY, 2) ) 19 | 20 | return distances 21 | -------------------------------------------------------------------------------- /tf_pose/slidingwindow/__init__.py: -------------------------------------------------------------------------------- 1 | from .SlidingWindow import DimOrder, SlidingWindow, generate, generateForSize 2 | from .WindowDistance import generateDistanceMatrix 3 | from .RectangleUtils import * 4 | from .ArrayUtils import * 5 | from .Batching import * 6 | from .Merging import * 7 | -------------------------------------------------------------------------------- /tf_pose/slim/WORKSPACE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/slim/WORKSPACE -------------------------------------------------------------------------------- /tf_pose/slim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/slim/__init__.py -------------------------------------------------------------------------------- /tf_pose/slim/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/slim/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/slim/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tf_pose/slim/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the Cifar10 dataset. 
16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/research/slim/datasets/download_and_convert_cifar10.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'cifar10_%s.tfrecord' 33 | 34 | SPLITS_TO_SIZES = {'train': 50000, 'test': 10000} 35 | 36 | _NUM_CLASSES = 10 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A [32 x 32 x 3] color image.', 40 | 'label': 'A single integer between 0 and 9', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading cifar10. 46 | 47 | Args: 48 | split_name: A train/test split name. 49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/test split. 60 | """ 61 | if split_name not in SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' % split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 69 | if not reader: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | 'image': slim.tfexample_decoder.Image(shape=[32, 32, 3]), 81 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 82 | } 83 | 84 | decoder = slim.tfexample_decoder.TFExampleDecoder( 85 | keys_to_features, items_to_handlers) 86 | 87 | labels_to_names = None 88 | if dataset_utils.has_labels(dataset_dir): 89 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 90 | 91 | return slim.dataset.Dataset( 92 | data_sources=file_pattern, 93 | reader=reader, 94 | decoder=decoder, 95 | num_samples=SPLITS_TO_SIZES[split_name], 96 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 97 | num_classes=_NUM_CLASSES, 98 | labels_to_names=labels_to_names) 99 | -------------------------------------------------------------------------------- /tf_pose/slim/datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A factory-pattern class which returns classification image/label pairs.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from datasets import cifar10 22 | from datasets import flowers 23 | from datasets import imagenet 24 | from datasets import mnist 25 | 26 | datasets_map = { 27 | 'cifar10': cifar10, 28 | 'flowers': flowers, 29 | 'imagenet': imagenet, 30 | 'mnist': mnist, 31 | } 32 | 33 | 34 | def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None): 35 | """Given a dataset name and a split_name returns a Dataset. 36 | 37 | Args: 38 | name: String, the name of the dataset. 39 | split_name: A train/test split name. 40 | dataset_dir: The directory where the dataset files are stored. 41 | file_pattern: The file pattern to use for matching the dataset source files. 42 | reader: The subclass of tf.ReaderBase. If left as `None`, then the default 43 | reader defined by each dataset is used. 44 | 45 | Returns: 46 | A `Dataset` class. 47 | 48 | Raises: 49 | ValueError: If the dataset `name` is unknown. 50 | """ 51 | if name not in datasets_map: 52 | raise ValueError('Name of dataset unknown %s' % name) 53 | return datasets_map[name].get_split( 54 | split_name, 55 | dataset_dir, 56 | file_pattern, 57 | reader) 58 | -------------------------------------------------------------------------------- /tf_pose/slim/datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains utilities for downloading and converting datasets.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | import sys 22 | import tarfile 23 | 24 | from six.moves import urllib 25 | import tensorflow as tf 26 | 27 | LABELS_FILENAME = 'labels.txt' 28 | 29 | 30 | def int64_feature(values): 31 | """Returns a TF-Feature of int64s. 32 | 33 | Args: 34 | values: A scalar or list of values. 35 | 36 | Returns: 37 | A TF-Feature. 38 | """ 39 | if not isinstance(values, (tuple, list)): 40 | values = [values] 41 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) 42 | 43 | 44 | def bytes_feature(values): 45 | """Returns a TF-Feature of bytes. 46 | 47 | Args: 48 | values: A string. 49 | 50 | Returns: 51 | A TF-Feature. 52 | """ 53 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) 54 | 55 | 56 | def float_feature(values): 57 | """Returns a TF-Feature of floats. 
58 | 59 | Args: 60 | values: A scalar of list of values. 61 | 62 | Returns: 63 | A TF-Feature. 64 | """ 65 | if not isinstance(values, (tuple, list)): 66 | values = [values] 67 | return tf.train.Feature(float_list=tf.train.FloatList(value=values)) 68 | 69 | 70 | def image_to_tfexample(image_data, image_format, height, width, class_id): 71 | return tf.train.Example(features=tf.train.Features(feature={ 72 | 'image/encoded': bytes_feature(image_data), 73 | 'image/format': bytes_feature(image_format), 74 | 'image/class/label': int64_feature(class_id), 75 | 'image/height': int64_feature(height), 76 | 'image/width': int64_feature(width), 77 | })) 78 | 79 | 80 | def download_and_uncompress_tarball(tarball_url, dataset_dir): 81 | """Downloads the `tarball_url` and uncompresses it locally. 82 | 83 | Args: 84 | tarball_url: The URL of a tarball file. 85 | dataset_dir: The directory where the temporary files are stored. 86 | """ 87 | filename = tarball_url.split('/')[-1] 88 | filepath = os.path.join(dataset_dir, filename) 89 | 90 | def _progress(count, block_size, total_size): 91 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 92 | filename, float(count * block_size) / float(total_size) * 100.0)) 93 | sys.stdout.flush() 94 | filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress) 95 | print() 96 | statinfo = os.stat(filepath) 97 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 98 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir) 99 | 100 | 101 | def write_label_file(labels_to_class_names, dataset_dir, 102 | filename=LABELS_FILENAME): 103 | """Writes a file with the list of class names. 104 | 105 | Args: 106 | labels_to_class_names: A map of (integer) labels to class names. 107 | dataset_dir: The directory in which the labels file should be written. 108 | filename: The filename where the class names are written. 109 | """ 110 | labels_filename = os.path.join(dataset_dir, filename) 111 | with tf.gfile.Open(labels_filename, 'w') as f: 112 | for label in labels_to_class_names: 113 | class_name = labels_to_class_names[label] 114 | f.write('%d:%s\n' % (label, class_name)) 115 | 116 | 117 | def has_labels(dataset_dir, filename=LABELS_FILENAME): 118 | """Specifies whether or not the dataset directory contains a label map file. 119 | 120 | Args: 121 | dataset_dir: The directory in which the labels file is found. 122 | filename: The filename where the class names are written. 123 | 124 | Returns: 125 | `True` if the labels file exists and `False` otherwise. 126 | """ 127 | return tf.gfile.Exists(os.path.join(dataset_dir, filename)) 128 | 129 | 130 | def read_label_file(dataset_dir, filename=LABELS_FILENAME): 131 | """Reads the labels file and returns a mapping from ID to class name. 132 | 133 | Args: 134 | dataset_dir: The directory in which the labels file is found. 135 | filename: The filename where the class names are written. 136 | 137 | Returns: 138 | A map from a label (integer) to class name. 
139 | """ 140 | labels_filename = os.path.join(dataset_dir, filename) 141 | with tf.gfile.Open(labels_filename, 'rb') as f: 142 | lines = f.read().decode() 143 | lines = lines.split('\n') 144 | lines = filter(None, lines) 145 | 146 | labels_to_class_names = {} 147 | for line in lines: 148 | index = line.index(':') 149 | labels_to_class_names[int(line[:index])] = line[index+1:] 150 | return labels_to_class_names 151 | -------------------------------------------------------------------------------- /tf_pose/slim/datasets/download_and_convert_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | # Script to download and preprocess ImageNet Challenge 2012 18 | # training and validation data set. 19 | # 20 | # The final output of this script are sharded TFRecord files containing 21 | # serialized Example protocol buffers. See build_imagenet_data.py for 22 | # details of how the Example protocol buffers contain the ImageNet data. 23 | # 24 | # The final output of this script appears as such: 25 | # 26 | # data_dir/train-00000-of-01024 27 | # data_dir/train-00001-of-01024 28 | # ... 29 | # data_dir/train-00127-of-01024 30 | # 31 | # and 32 | # 33 | # data_dir/validation-00000-of-00128 34 | # data_dir/validation-00001-of-00128 35 | # ... 36 | # data_dir/validation-00127-of-00128 37 | # 38 | # Note that this script may take several hours to run to completion. The 39 | # conversion of the ImageNet data to TFRecords alone takes 2-3 hours depending 40 | # on the speed of your machine. Please be patient. 41 | # 42 | # **IMPORTANT** 43 | # To download the raw images, the user must create an account with image-net.org 44 | # and generate a username and access_key. The latter two are required for 45 | # downloading the raw images. 46 | # 47 | # usage: 48 | # cd research/slim 49 | # bazel build :download_and_convert_imagenet 50 | # ./bazel-bin/download_and_convert_imagenet.sh [data-dir] 51 | set -e 52 | 53 | if [ -z "$1" ]; then 54 | echo "usage download_and_convert_imagenet.sh [data dir]" 55 | exit 56 | fi 57 | 58 | # Create the output and temporary directories. 59 | DATA_DIR="${1%/}" 60 | SCRATCH_DIR="${DATA_DIR}/raw-data/" 61 | mkdir -p "${DATA_DIR}" 62 | mkdir -p "${SCRATCH_DIR}" 63 | WORK_DIR="$0.runfiles/__main__" 64 | 65 | # Download the ImageNet data. 66 | LABELS_FILE="${WORK_DIR}/datasets/imagenet_lsvrc_2015_synsets.txt" 67 | DOWNLOAD_SCRIPT="${WORK_DIR}/datasets/download_imagenet.sh" 68 | "${DOWNLOAD_SCRIPT}" "${SCRATCH_DIR}" "${LABELS_FILE}" 69 | 70 | # Note the locations of the train and validation data. 
71 | TRAIN_DIRECTORY="${SCRATCH_DIR}train/" 72 | VALIDATION_DIRECTORY="${SCRATCH_DIR}validation/" 73 | 74 | # Preprocess the validation data by moving the images into the appropriate 75 | # sub-directory based on the label (synset) of the image. 76 | echo "Organizing the validation data into sub-directories." 77 | PREPROCESS_VAL_SCRIPT="${WORK_DIR}/datasets/preprocess_imagenet_validation_data.py" 78 | VAL_LABELS_FILE="${WORK_DIR}/datasets/imagenet_2012_validation_synset_labels.txt" 79 | 80 | "${PREPROCESS_VAL_SCRIPT}" "${VALIDATION_DIRECTORY}" "${VAL_LABELS_FILE}" 81 | 82 | # Convert the XML files for bounding box annotations into a single CSV. 83 | echo "Extracting bounding box information from XML." 84 | BOUNDING_BOX_SCRIPT="${WORK_DIR}/datasets/process_bounding_boxes.py" 85 | BOUNDING_BOX_FILE="${SCRATCH_DIR}/imagenet_2012_bounding_boxes.csv" 86 | BOUNDING_BOX_DIR="${SCRATCH_DIR}bounding_boxes/" 87 | 88 | "${BOUNDING_BOX_SCRIPT}" "${BOUNDING_BOX_DIR}" "${LABELS_FILE}" \ 89 | | sort >"${BOUNDING_BOX_FILE}" 90 | echo "Finished downloading and preprocessing the ImageNet data." 91 | 92 | # Build the TFRecords version of the ImageNet data. 93 | BUILD_SCRIPT="${WORK_DIR}/build_imagenet_data" 94 | OUTPUT_DIRECTORY="${DATA_DIR}" 95 | IMAGENET_METADATA_FILE="${WORK_DIR}/datasets/imagenet_metadata.txt" 96 | 97 | "${BUILD_SCRIPT}" \ 98 | --train_directory="${TRAIN_DIRECTORY}" \ 99 | --validation_directory="${VALIDATION_DIRECTORY}" \ 100 | --output_directory="${OUTPUT_DIRECTORY}" \ 101 | --imagenet_metadata_file="${IMAGENET_METADATA_FILE}" \ 102 | --labels_file="${LABELS_FILE}" \ 103 | --bounding_box_file="${BOUNDING_BOX_FILE}" 104 | -------------------------------------------------------------------------------- /tf_pose/slim/datasets/download_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | # Script to download ImageNet Challenge 2012 training and validation data set. 18 | # 19 | # Downloads and decompresses raw images and bounding boxes. 20 | # 21 | # **IMPORTANT** 22 | # To download the raw images, the user must create an account with image-net.org 23 | # and generate a username and access_key. The latter two are required for 24 | # downloading the raw images. 25 | # 26 | # usage: 27 | # ./download_imagenet.sh [dirname] 28 | set -e 29 | 30 | if [ "x$IMAGENET_ACCESS_KEY" == x -o "x$IMAGENET_USERNAME" == x ]; then 31 | cat < ') 62 | sys.exit(-1) 63 | data_dir = sys.argv[1] 64 | validation_labels_file = sys.argv[2] 65 | 66 | # Read in the 50000 synsets associated with the validation data set. 67 | labels = [l.strip() for l in open(validation_labels_file).readlines()] 68 | unique_labels = set(labels) 69 | 70 | # Make all sub-directories in the validation data dir. 
71 | for label in unique_labels: 72 | labeled_data_dir = os.path.join(data_dir, label) 73 | os.makedirs(labeled_data_dir) 74 | 75 | # Move all of the images to the appropriate sub-directory. 76 | for i in xrange(len(labels)): 77 | basename = 'ILSVRC2012_val_000%.5d.JPEG' % (i + 1) 78 | original_filename = os.path.join(data_dir, basename) 79 | if not os.path.exists(original_filename): 80 | print('Failed to find: %s' % original_filename) 81 | sys.exit(-1) 82 | new_filename = os.path.join(data_dir, labels[i], basename) 83 | os.rename(original_filename, new_filename) 84 | -------------------------------------------------------------------------------- /tf_pose/slim/deployment/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tf_pose/slim/download_and_convert_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Downloads and converts a particular dataset. 16 | 17 | Usage: 18 | ```shell 19 | 20 | $ python download_and_convert_data.py \ 21 | --dataset_name=mnist \ 22 | --dataset_dir=/tmp/mnist 23 | 24 | $ python download_and_convert_data.py \ 25 | --dataset_name=cifar10 \ 26 | --dataset_dir=/tmp/cifar10 27 | 28 | $ python download_and_convert_data.py \ 29 | --dataset_name=flowers \ 30 | --dataset_dir=/tmp/flowers 31 | ``` 32 | """ 33 | from __future__ import absolute_import 34 | from __future__ import division 35 | from __future__ import print_function 36 | 37 | import tensorflow as tf 38 | 39 | from datasets import download_and_convert_cifar10 40 | from datasets import download_and_convert_flowers 41 | from datasets import download_and_convert_mnist 42 | 43 | FLAGS = tf.app.flags.FLAGS 44 | 45 | tf.app.flags.DEFINE_string( 46 | 'dataset_name', 47 | None, 48 | 'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".') 49 | 50 | tf.app.flags.DEFINE_string( 51 | 'dataset_dir', 52 | None, 53 | 'The directory where the output TFRecords and temporary files are saved.') 54 | 55 | 56 | def main(_): 57 | if not FLAGS.dataset_name: 58 | raise ValueError('You must supply the dataset name with --dataset_name') 59 | if not FLAGS.dataset_dir: 60 | raise ValueError('You must supply the dataset directory with --dataset_dir') 61 | 62 | if FLAGS.dataset_name == 'cifar10': 63 | download_and_convert_cifar10.run(FLAGS.dataset_dir) 64 | elif FLAGS.dataset_name == 'flowers': 65 | download_and_convert_flowers.run(FLAGS.dataset_dir) 66 | elif FLAGS.dataset_name == 'mnist': 67 | download_and_convert_mnist.run(FLAGS.dataset_dir) 68 | else: 69 | raise ValueError( 70 | 'dataset_name [%s] was not recognized.'
% FLAGS.dataset_name) 71 | 72 | if __name__ == '__main__': 73 | tf.app.run() 74 | -------------------------------------------------------------------------------- /tf_pose/slim/export_inference_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Saves out a GraphDef containing the architecture of the model. 16 | 17 | To use it, run something like this, with a model name defined by slim: 18 | 19 | bazel build tensorflow_models/research/slim:export_inference_graph 20 | bazel-bin/tensorflow_models/research/slim/export_inference_graph \ 21 | --model_name=inception_v3 --output_file=/tmp/inception_v3_inf_graph.pb 22 | 23 | If you then want to use the resulting model with your own or pretrained 24 | checkpoints as part of a mobile model, you can run freeze_graph to get a graph 25 | def with the variables inlined as constants using: 26 | 27 | bazel build tensorflow/python/tools:freeze_graph 28 | bazel-bin/tensorflow/python/tools/freeze_graph \ 29 | --input_graph=/tmp/inception_v3_inf_graph.pb \ 30 | --input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \ 31 | --input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \ 32 | --output_node_names=InceptionV3/Predictions/Reshape_1 33 | 34 | The output node names will vary depending on the model, but you can inspect and 35 | estimate them using the summarize_graph tool: 36 | 37 | bazel build tensorflow/tools/graph_transforms:summarize_graph 38 | bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \ 39 | --in_graph=/tmp/inception_v3_inf_graph.pb 40 | 41 | To run the resulting graph in C++, you can look at the label_image sample code: 42 | 43 | bazel build tensorflow/examples/label_image:label_image 44 | bazel-bin/tensorflow/examples/label_image/label_image \ 45 | --image=${HOME}/Pictures/flowers.jpg \ 46 | --input_layer=input \ 47 | --output_layer=InceptionV3/Predictions/Reshape_1 \ 48 | --graph=/tmp/frozen_inception_v3.pb \ 49 | --labels=/tmp/imagenet_slim_labels.txt \ 50 | --input_mean=0 \ 51 | --input_std=255 52 | 53 | """ 54 | 55 | from __future__ import absolute_import 56 | from __future__ import division 57 | from __future__ import print_function 58 | 59 | import tensorflow as tf 60 | 61 | from tensorflow.python.platform import gfile 62 | from datasets import dataset_factory 63 | from nets import nets_factory 64 | 65 | 66 | slim = tf.contrib.slim 67 | 68 | tf.app.flags.DEFINE_string( 69 | 'model_name', 'inception_v3', 'The name of the architecture to save.') 70 | 71 | tf.app.flags.DEFINE_boolean( 72 | 'is_training', False, 73 | 'Whether to save out a training-focused version of the model.') 74 | 75 | tf.app.flags.DEFINE_integer( 76 | 'image_size', None, 77 | 'The image size to use, otherwise use the model default_image_size.') 78 | 79 | 
tf.app.flags.DEFINE_integer( 80 | 'batch_size', None, 81 | 'Batch size for the exported model. Defaulted to "None" so batch size can ' 82 | 'be specified at model runtime.') 83 | 84 | tf.app.flags.DEFINE_string('dataset_name', 'imagenet', 85 | 'The name of the dataset to use with the model.') 86 | 87 | tf.app.flags.DEFINE_integer( 88 | 'labels_offset', 0, 89 | 'An offset for the labels in the dataset. This flag is primarily used to ' 90 | 'evaluate the VGG and ResNet architectures which do not use a background ' 91 | 'class for the ImageNet dataset.') 92 | 93 | tf.app.flags.DEFINE_string( 94 | 'output_file', '', 'Where to save the resulting file to.') 95 | 96 | tf.app.flags.DEFINE_string( 97 | 'dataset_dir', '', 'Directory to save intermediate dataset files to') 98 | 99 | FLAGS = tf.app.flags.FLAGS 100 | 101 | 102 | def main(_): 103 | if not FLAGS.output_file: 104 | raise ValueError('You must supply the path to save to with --output_file') 105 | tf.logging.set_verbosity(tf.logging.INFO) 106 | with tf.Graph().as_default() as graph: 107 | dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'train', 108 | FLAGS.dataset_dir) 109 | network_fn = nets_factory.get_network_fn( 110 | FLAGS.model_name, 111 | num_classes=(dataset.num_classes - FLAGS.labels_offset), 112 | is_training=FLAGS.is_training) 113 | image_size = FLAGS.image_size or network_fn.default_image_size 114 | placeholder = tf.placeholder(name='input', dtype=tf.float32, 115 | shape=[FLAGS.batch_size, image_size, 116 | image_size, 3]) 117 | network_fn(placeholder) 118 | graph_def = graph.as_graph_def() 119 | with gfile.GFile(FLAGS.output_file, 'wb') as f: 120 | f.write(graph_def.SerializeToString()) 121 | 122 | 123 | if __name__ == '__main__': 124 | tf.app.run() 125 | -------------------------------------------------------------------------------- /tf_pose/slim/export_inference_graph_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for export_inference_graph.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | 25 | import tensorflow as tf 26 | 27 | from tensorflow.python.platform import gfile 28 | import export_inference_graph 29 | 30 | 31 | class ExportInferenceGraphTest(tf.test.TestCase): 32 | 33 | def testExportInferenceGraph(self): 34 | tmpdir = self.get_temp_dir() 35 | output_file = os.path.join(tmpdir, 'inception_v3.pb') 36 | flags = tf.app.flags.FLAGS 37 | flags.output_file = output_file 38 | flags.model_name = 'inception_v3' 39 | flags.dataset_dir = tmpdir 40 | export_inference_graph.main(None) 41 | self.assertTrue(gfile.Exists(output_file)) 42 | 43 | if __name__ == '__main__': 44 | tf.test.main() 45 | -------------------------------------------------------------------------------- /tf_pose/slim/nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tf_pose/slim/nets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/slim/nets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/slim/nets/__pycache__/resnet_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/slim/nets/__pycache__/resnet_utils.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/slim/nets/__pycache__/resnet_v2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/slim/nets/__pycache__/resnet_v2.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/slim/nets/cifarnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a variant of the CIFAR-10 model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev) 26 | 27 | 28 | def cifarnet(images, num_classes=10, is_training=False, 29 | dropout_keep_prob=0.5, 30 | prediction_fn=slim.softmax, 31 | scope='CifarNet'): 32 | """Creates a variant of the CifarNet model. 33 | 34 | Note that since the output is a set of 'logits', the values fall in the 35 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 36 | probability distribution over the characters, one will need to convert them 37 | using the softmax function: 38 | 39 | logits = cifarnet.cifarnet(images, is_training=False) 40 | probabilities = tf.nn.softmax(logits) 41 | predictions = tf.argmax(logits, 1) 42 | 43 | Args: 44 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 45 | num_classes: the number of classes in the dataset. If 0 or None, the logits 46 | layer is omitted and the input features to the logits layer are returned 47 | instead. 48 | is_training: specifies whether or not we're currently training the model. 49 | This variable will determine the behaviour of the dropout layer. 50 | dropout_keep_prob: the percentage of activation values that are retained. 51 | prediction_fn: a function to get predictions out of logits. 52 | scope: Optional variable_scope. 53 | 54 | Returns: 55 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 56 | is a non-zero integer, or the input to the logits layer if num_classes 57 | is 0 or None. 58 | end_points: a dictionary from components of the network to the corresponding 59 | activation. 60 | """ 61 | end_points = {} 62 | 63 | with tf.variable_scope(scope, 'CifarNet', [images]): 64 | net = slim.conv2d(images, 64, [5, 5], scope='conv1') 65 | end_points['conv1'] = net 66 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 67 | end_points['pool1'] = net 68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1') 69 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 70 | end_points['conv2'] = net 71 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2') 72 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 73 | end_points['pool2'] = net 74 | net = slim.flatten(net) 75 | end_points['Flatten'] = net 76 | net = slim.fully_connected(net, 384, scope='fc3') 77 | end_points['fc3'] = net 78 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 79 | scope='dropout3') 80 | net = slim.fully_connected(net, 192, scope='fc4') 81 | end_points['fc4'] = net 82 | if not num_classes: 83 | return net, end_points 84 | logits = slim.fully_connected(net, num_classes, 85 | biases_initializer=tf.zeros_initializer(), 86 | weights_initializer=trunc_normal(1/192.0), 87 | weights_regularizer=None, 88 | activation_fn=None, 89 | scope='logits') 90 | 91 | end_points['Logits'] = logits 92 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 93 | 94 | return logits, end_points 95 | cifarnet.default_image_size = 32 96 | 97 | 98 | def cifarnet_arg_scope(weight_decay=0.004): 99 | """Defines the default cifarnet argument scope. 
100 | 101 | Args: 102 | weight_decay: The weight decay to use for regularizing the model. 103 | 104 | Returns: 105 | An `arg_scope` to use for the inception v3 model. 106 | """ 107 | with slim.arg_scope( 108 | [slim.conv2d], 109 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2), 110 | activation_fn=tf.nn.relu): 111 | with slim.arg_scope( 112 | [slim.fully_connected], 113 | biases_initializer=tf.constant_initializer(0.1), 114 | weights_initializer=trunc_normal(0.04), 115 | weights_regularizer=slim.l2_regularizer(weight_decay), 116 | activation_fn=tf.nn.relu) as sc: 117 | return sc 118 | -------------------------------------------------------------------------------- /tf_pose/slim/nets/cyclegan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for tensorflow.contrib.slim.nets.cyclegan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import cyclegan 24 | 25 | 26 | # TODO(joelshor): Add a test to check generator endpoints. 
27 | class CycleganTest(tf.test.TestCase): 28 | 29 | def test_generator_inference(self): 30 | """Check one inference step.""" 31 | img_batch = tf.zeros([2, 32, 32, 3]) 32 | model_output, _ = cyclegan.cyclegan_generator_resnet(img_batch) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | sess.run(model_output) 36 | 37 | def _test_generator_graph_helper(self, shape): 38 | """Check that generator can take small and non-square inputs.""" 39 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(tf.ones(shape)) 40 | self.assertAllEqual(shape, output_imgs.shape.as_list()) 41 | 42 | def test_generator_graph_small(self): 43 | self._test_generator_graph_helper([4, 32, 32, 3]) 44 | 45 | def test_generator_graph_medium(self): 46 | self._test_generator_graph_helper([3, 128, 128, 3]) 47 | 48 | def test_generator_graph_nonsquare(self): 49 | self._test_generator_graph_helper([2, 80, 400, 3]) 50 | 51 | def test_generator_unknown_batch_dim(self): 52 | """Check that generator can take unknown batch dimension inputs.""" 53 | img = tf.placeholder(tf.float32, shape=[None, 32, None, 3]) 54 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(img) 55 | 56 | self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list()) 57 | 58 | def _input_and_output_same_shape_helper(self, kernel_size): 59 | img_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3]) 60 | output_img_batch, _ = cyclegan.cyclegan_generator_resnet( 61 | img_batch, kernel_size=kernel_size) 62 | 63 | self.assertAllEqual(img_batch.shape.as_list(), 64 | output_img_batch.shape.as_list()) 65 | 66 | def input_and_output_same_shape_kernel3(self): 67 | self._input_and_output_same_shape_helper(3) 68 | 69 | def input_and_output_same_shape_kernel4(self): 70 | self._input_and_output_same_shape_helper(4) 71 | 72 | def input_and_output_same_shape_kernel5(self): 73 | self._input_and_output_same_shape_helper(5) 74 | 75 | def input_and_output_same_shape_kernel6(self): 76 | self._input_and_output_same_shape_helper(6) 77 | 78 | def _error_if_height_not_multiple_of_four_helper(self, height): 79 | self.assertRaisesRegexp( 80 | ValueError, 81 | 'The input height must be a multiple of 4.', 82 | cyclegan.cyclegan_generator_resnet, 83 | tf.placeholder(tf.float32, shape=[None, height, 32, 3])) 84 | 85 | def test_error_if_height_not_multiple_of_four_height29(self): 86 | self._error_if_height_not_multiple_of_four_helper(29) 87 | 88 | def test_error_if_height_not_multiple_of_four_height30(self): 89 | self._error_if_height_not_multiple_of_four_helper(30) 90 | 91 | def test_error_if_height_not_multiple_of_four_height31(self): 92 | self._error_if_height_not_multiple_of_four_helper(31) 93 | 94 | def _error_if_width_not_multiple_of_four_helper(self, width): 95 | self.assertRaisesRegexp( 96 | ValueError, 97 | 'The input width must be a multiple of 4.', 98 | cyclegan.cyclegan_generator_resnet, 99 | tf.placeholder(tf.float32, shape=[None, 32, width, 3])) 100 | 101 | def test_error_if_width_not_multiple_of_four_width29(self): 102 | self._error_if_width_not_multiple_of_four_helper(29) 103 | 104 | def test_error_if_width_not_multiple_of_four_width30(self): 105 | self._error_if_width_not_multiple_of_four_helper(30) 106 | 107 | def test_error_if_width_not_multiple_of_four_width31(self): 108 | self._error_if_width_not_multiple_of_four_helper(31) 109 | 110 | 111 | if __name__ == '__main__': 112 | tf.test.main() 113 | -------------------------------------------------------------------------------- /tf_pose/slim/nets/dcgan_test.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for dcgan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from six.moves import xrange # pylint: disable=redefined-builtin 22 | import tensorflow as tf 23 | 24 | from nets import dcgan 25 | 26 | 27 | class DCGANTest(tf.test.TestCase): 28 | 29 | def test_generator_run(self): 30 | tf.set_random_seed(1234) 31 | noise = tf.random_normal([100, 64]) 32 | image, _ = dcgan.generator(noise) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | image.eval() 36 | 37 | def test_generator_graph(self): 38 | tf.set_random_seed(1234) 39 | # Check graph construction for a number of image size/depths and batch 40 | # sizes. 41 | for i, batch_size in zip(xrange(3, 7), xrange(3, 8)): 42 | tf.reset_default_graph() 43 | final_size = 2 ** i 44 | noise = tf.random_normal([batch_size, 64]) 45 | image, end_points = dcgan.generator( 46 | noise, 47 | depth=32, 48 | final_size=final_size) 49 | 50 | self.assertAllEqual([batch_size, final_size, final_size, 3], 51 | image.shape.as_list()) 52 | 53 | expected_names = ['deconv%i' % j for j in xrange(1, i)] + ['logits'] 54 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 55 | 56 | # Check layer depths. 57 | for j in range(1, i): 58 | layer = end_points['deconv%i' % j] 59 | self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1]) 60 | 61 | def test_generator_invalid_input(self): 62 | wrong_dim_input = tf.zeros([5, 32, 32]) 63 | with self.assertRaises(ValueError): 64 | dcgan.generator(wrong_dim_input) 65 | 66 | correct_input = tf.zeros([3, 2]) 67 | with self.assertRaisesRegexp(ValueError, 'must be a power of 2'): 68 | dcgan.generator(correct_input, final_size=30) 69 | 70 | with self.assertRaisesRegexp(ValueError, 'must be greater than 8'): 71 | dcgan.generator(correct_input, final_size=4) 72 | 73 | def test_discriminator_run(self): 74 | image = tf.random_uniform([5, 32, 32, 3], -1, 1) 75 | output, _ = dcgan.discriminator(image) 76 | with self.test_session() as sess: 77 | sess.run(tf.global_variables_initializer()) 78 | output.eval() 79 | 80 | def test_discriminator_graph(self): 81 | # Check graph construction for a number of image size/depths and batch 82 | # sizes. 
83 | for i, batch_size in zip(xrange(1, 6), xrange(3, 8)): 84 | tf.reset_default_graph() 85 | img_w = 2 ** i 86 | image = tf.random_uniform([batch_size, img_w, img_w, 3], -1, 1) 87 | output, end_points = dcgan.discriminator( 88 | image, 89 | depth=32) 90 | 91 | self.assertAllEqual([batch_size, 1], output.get_shape().as_list()) 92 | 93 | expected_names = ['conv%i' % j for j in xrange(1, i+1)] + ['logits'] 94 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 95 | 96 | # Check layer depths. 97 | for j in range(1, i+1): 98 | layer = end_points['conv%i' % j] 99 | self.assertEqual(32 * 2**(j-1), layer.get_shape().as_list()[-1]) 100 | 101 | def test_discriminator_invalid_input(self): 102 | wrong_dim_img = tf.zeros([5, 32, 32]) 103 | with self.assertRaises(ValueError): 104 | dcgan.discriminator(wrong_dim_img) 105 | 106 | spatially_undefined_shape = tf.placeholder(tf.float32, [5, 32, None, 3]) 107 | with self.assertRaises(ValueError): 108 | dcgan.discriminator(spatially_undefined_shape) 109 | 110 | not_square = tf.zeros([5, 32, 16, 3]) 111 | with self.assertRaisesRegexp(ValueError, 'not have equal width and height'): 112 | dcgan.discriminator(not_square) 113 | 114 | not_power_2 = tf.zeros([5, 30, 30, 3]) 115 | with self.assertRaisesRegexp(ValueError, 'not a power of 2'): 116 | dcgan.discriminator(not_power_2) 117 | 118 | 119 | if __name__ == '__main__': 120 | tf.test.main() 121 | -------------------------------------------------------------------------------- /tf_pose/slim/nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_resnet_v2 import inception_resnet_v2_base 25 | from nets.inception_v1 import inception_v1 26 | from nets.inception_v1 import inception_v1_arg_scope 27 | from nets.inception_v1 import inception_v1_base 28 | from nets.inception_v2 import inception_v2 29 | from nets.inception_v2 import inception_v2_arg_scope 30 | from nets.inception_v2 import inception_v2_base 31 | from nets.inception_v3 import inception_v3 32 | from nets.inception_v3 import inception_v3_arg_scope 33 | from nets.inception_v3 import inception_v3_base 34 | from nets.inception_v4 import inception_v4 35 | from nets.inception_v4 import inception_v4_arg_scope 36 | from nets.inception_v4 import inception_v4_base 37 | # pylint: enable=unused-import 38 | -------------------------------------------------------------------------------- /tf_pose/slim/nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001, 36 | activation_fn=tf.nn.relu): 37 | """Defines the default arg scope for inception models. 38 | 39 | Args: 40 | weight_decay: The weight decay to use for regularizing the model. 41 | use_batch_norm: "If `True`, batch_norm is applied after each convolution. 42 | batch_norm_decay: Decay for batch norm moving average. 43 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 44 | in batch norm. 45 | activation_fn: Activation function for conv2d. 46 | 47 | Returns: 48 | An `arg_scope` to use for the inception models. 49 | """ 50 | batch_norm_params = { 51 | # Decay for the moving averages. 52 | 'decay': batch_norm_decay, 53 | # epsilon to prevent 0s in variance. 54 | 'epsilon': batch_norm_epsilon, 55 | # collection containing update_ops. 
56 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 57 | # use fused batch norm if possible. 58 | 'fused': None, 59 | } 60 | if use_batch_norm: 61 | normalizer_fn = slim.batch_norm 62 | normalizer_params = batch_norm_params 63 | else: 64 | normalizer_fn = None 65 | normalizer_params = {} 66 | # Set weight_decay for weights in Conv and FC layers. 67 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 68 | weights_regularizer=slim.l2_regularizer(weight_decay)): 69 | with slim.arg_scope( 70 | [slim.conv2d], 71 | weights_initializer=slim.variance_scaling_initializer(), 72 | activation_fn=activation_fn, 73 | normalizer_fn=normalizer_fn, 74 | normalizer_params=normalizer_params) as sc: 75 | return sc 76 | -------------------------------------------------------------------------------- /tf_pose/slim/nets/lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the LeNet model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def lenet(images, num_classes=10, is_training=False, 27 | dropout_keep_prob=0.5, 28 | prediction_fn=slim.softmax, 29 | scope='LeNet'): 30 | """Creates a variant of the LeNet model. 31 | 32 | Note that since the output is a set of 'logits', the values fall in the 33 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 34 | probability distribution over the characters, one will need to convert them 35 | using the softmax function: 36 | 37 | logits = lenet.lenet(images, is_training=False) 38 | probabilities = tf.nn.softmax(logits) 39 | predictions = tf.argmax(logits, 1) 40 | 41 | Args: 42 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 43 | num_classes: the number of classes in the dataset. If 0 or None, the logits 44 | layer is omitted and the input features to the logits layer are returned 45 | instead. 46 | is_training: specifies whether or not we're currently training the model. 47 | This variable will determine the behaviour of the dropout layer. 48 | dropout_keep_prob: the percentage of activation values that are retained. 49 | prediction_fn: a function to get predictions out of logits. 50 | scope: Optional variable_scope. 51 | 52 | Returns: 53 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 54 | is a non-zero integer, or the non-dropped-out input to the logits layer 55 | if num_classes is 0 or None. 56 | end_points: a dictionary from components of the network to the corresponding 57 | activation.
58 | """ 59 | end_points = {} 60 | 61 | with tf.variable_scope(scope, 'LeNet', [images]): 62 | net = end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1') 63 | net = end_points['pool1'] = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 64 | net = end_points['conv2'] = slim.conv2d(net, 64, [5, 5], scope='conv2') 65 | net = end_points['pool2'] = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 66 | net = slim.flatten(net) 67 | end_points['Flatten'] = net 68 | 69 | net = end_points['fc3'] = slim.fully_connected(net, 1024, scope='fc3') 70 | if not num_classes: 71 | return net, end_points 72 | net = end_points['dropout3'] = slim.dropout( 73 | net, dropout_keep_prob, is_training=is_training, scope='dropout3') 74 | logits = end_points['Logits'] = slim.fully_connected( 75 | net, num_classes, activation_fn=None, scope='fc4') 76 | 77 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 78 | 79 | return logits, end_points 80 | lenet.default_image_size = 28 81 | 82 | 83 | def lenet_arg_scope(weight_decay=0.0): 84 | """Defines the default lenet argument scope. 85 | 86 | Args: 87 | weight_decay: The weight decay to use for regularizing the model. 88 | 89 | Returns: 90 | An `arg_scope` to use for the inception v3 model. 91 | """ 92 | with slim.arg_scope( 93 | [slim.conv2d, slim.fully_connected], 94 | weights_regularizer=slim.l2_regularizer(weight_decay), 95 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 96 | activation_fn=tf.nn.relu) as sc: 97 | return sc 98 | -------------------------------------------------------------------------------- /tf_pose/slim/nets/mobilenet/README.md: -------------------------------------------------------------------------------- 1 | # MobileNetV2 2 | This folder contains building code for MobileNetV2, based on 3 | [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) 4 | 5 | # Performance 6 | ## Latency 7 | This is the timing of [MobileNetV1](../mobilenet_v1.md) vs MobileNetV2 using 8 | TF-Lite on the large core of Pixel 1 phone. 9 | 10 | ![mnet_v1_vs_v2_pixel1_latency.png](mnet_v1_vs_v2_pixel1_latency.png) 11 | 12 | ## MACs 13 | MACs, also sometimes known as MADDs - the number of multiply-accumulates needed 14 | to compute an inference on a single image is a common metric to measure the efficiency of the model. 15 | 16 | Below is the graph comparing V2 vs a few selected networks. The size 17 | of each blob represents the number of parameters. Note for [ShuffleNet](https://arxiv.org/abs/1707.01083) there 18 | are no published size numbers. We estimate it to be comparable to MobileNetV2 numbers. 
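As a quick aside before the comparison graph below: for a single standard convolution, the MAC count follows directly from the output resolution and kernel shape, which is a handy sanity check when reading the checkpoint table further down. This is only a hand sketch (the helper name is illustrative; the published numbers come from the actual graphs, not from this formula):

```python
def conv2d_macs(h_out, w_out, c_in, c_out, k_h, k_w):
    # Each output activation needs k_h * k_w * c_in multiply-accumulates.
    return h_out * w_out * c_out * (k_h * k_w * c_in)

# MobileNet's first layer on a 224x224 RGB input: 3x3 conv, stride 2, 32 filters
# -> 112x112 output, roughly 10.8M MACs for that single layer.
print(conv2d_macs(112, 112, 3, 32, 3, 3))
```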
19 | 20 | ![madds_top1_accuracy](madds_top1_accuracy.png) 21 | 22 | # Pretrained models 23 | ## Imagenet Checkpoints 24 | 25 | Classification Checkpoint | MACs (M)| Parameters (M)| Top 1 Accuracy| Top 5 Accuracy | Mobile CPU (ms) Pixel 1 26 | ---------------------------|---------|---------------|---------|----|------------- 27 | | [mobilenet_v2_1.4_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz) | 582 | 6.06 | 75.0 | 92.5 | 138.0 28 | | [mobilenet_v2_1.3_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.3_224.tgz) | 509 | 5.34 | 74.4 | 92.1 | 123.0 29 | | [mobilenet_v2_1.0_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz) | 300 | 3.47 | 71.8 | 91.0 | 73.8 30 | | [mobilenet_v2_1.0_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_192.tgz) | 221 | 3.47 | 70.7 | 90.1 | 55.1 31 | | [mobilenet_v2_1.0_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_160.tgz) | 154 | 3.47 | 68.8 | 89.0 | 40.2 32 | | [mobilenet_v2_1.0_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_128.tgz) | 99 | 3.47 | 65.3 | 86.9 | 27.6 33 | | [mobilenet_v2_1.0_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_96.tgz) | 56 | 3.47 | 60.3 | 83.2 | 17.6 34 | | [mobilenet_v2_0.75_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_224.tgz) | 209 | 2.61 | 69.8 | 89.6 | 55.8 35 | | [mobilenet_v2_0.75_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_192.tgz) | 153 | 2.61 | 68.7 | 88.9 | 41.6 36 | | [mobilenet_v2_0.75_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_160.tgz) | 107 | 2.61 | 66.4 | 87.3 | 30.4 37 | | [mobilenet_v2_0.75_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_128.tgz) | 69 | 2.61 | 63.2 | 85.3 | 21.9 38 | | [mobilenet_v2_0.75_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_96.tgz) | 39 | 2.61 | 58.8 | 81.6 | 14.2 39 | | [mobilenet_v2_0.5_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_224.tgz) | 97 | 1.95 | 65.4 | 86.4 | 28.7 40 | | [mobilenet_v2_0.5_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_192.tgz) | 71 | 1.95 | 63.9 | 85.4 | 21.1 41 | | [mobilenet_v2_0.5_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_160.tgz) | 50 | 1.95 | 61.0 | 83.2 | 14.9 42 | | [mobilenet_v2_0.5_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_128.tgz) | 32 | 1.95 | 57.7 | 80.8 | 9.9 43 | | [mobilenet_v2_0.5_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_96.tgz) | 18 | 1.95 | 51.2 | 75.8 | 6.4 44 | | [mobilenet_v2_0.35_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_224.tgz) | 59 | 1.66 | 60.3 | 82.9 | 19.7 45 | | [mobilenet_v2_0.35_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_192.tgz) | 43 | 1.66 | 58.2 | 81.2 | 14.6 46 | | [mobilenet_v2_0.35_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_160.tgz) | 30 | 1.66 | 55.7 | 79.1 | 10.5 47 | | [mobilenet_v2_0.35_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_128.tgz) | 20 | 1.66 | 50.8 | 75.0 | 6.9 48 | | [mobilenet_v2_0.35_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_96.tgz) | 11 | 1.66 | 45.5 | 70.4 | 4.5 49 | 50 | # Example 
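A minimal usage sketch is shown below. It assumes one of the checkpoints from the table above (e.g. `mobilenet_v2_1.0_224`) has been downloaded and extracted, and that the `mobilenet_v2.mobilenet` / `mobilenet_v2.training_scope` entry points in this folder match the upstream TF-Slim release; treat it as a sketch rather than a reference implementation, and see the notebook linked below for the full walkthrough.

```python
import tensorflow as tf
from nets.mobilenet import mobilenet_v2

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
# Build the inference graph; training_scope(is_training=False) disables BN updates and dropout.
with slim.arg_scope(mobilenet_v2.training_scope(is_training=False)):
    logits, end_points = mobilenet_v2.mobilenet(images, num_classes=1001)

saver = tf.train.Saver()
with tf.Session() as sess:
    # Checkpoint prefix from the extracted tarball (name assumed here).
    saver.restore(sess, 'mobilenet_v2_1.0_224.ckpt')
    # end_points['Predictions'] holds the class probabilities.
```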
51 | 52 | See this [ipython notebook](mobilenet_example.ipynb) or open and run the network directly in [Colaboratory](https://colab.research.google.com/github/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_example.ipynb). 53 | 54 | -------------------------------------------------------------------------------- /tf_pose/slim/nets/mobilenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/slim/nets/mobilenet/__init__.py -------------------------------------------------------------------------------- /tf_pose/slim/nets/mobilenet/madds_top1_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/slim/nets/mobilenet/madds_top1_accuracy.png -------------------------------------------------------------------------------- /tf_pose/slim/nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/slim/nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png -------------------------------------------------------------------------------- /tf_pose/slim/nets/mobilenet_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/slim/nets/mobilenet_v1.png -------------------------------------------------------------------------------- /tf_pose/slim/nets/mobilenet_v1_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Validate mobilenet_v1 with options for quantization.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import math 22 | import tensorflow as tf 23 | 24 | from datasets import dataset_factory 25 | from nets import mobilenet_v1 26 | from preprocessing import preprocessing_factory 27 | 28 | slim = tf.contrib.slim 29 | 30 | flags = tf.app.flags 31 | 32 | flags.DEFINE_string('master', '', 'Session master') 33 | flags.DEFINE_integer('batch_size', 250, 'Batch size') 34 | flags.DEFINE_integer('num_classes', 1001, 'Number of classes to distinguish') 35 | flags.DEFINE_integer('num_examples', 50000, 'Number of examples to evaluate') 36 | flags.DEFINE_integer('image_size', 224, 'Input image resolution') 37 | flags.DEFINE_float('depth_multiplier', 1.0, 'Depth multiplier for mobilenet') 38 | flags.DEFINE_bool('quantize', False, 'Quantize training') 39 | flags.DEFINE_string('checkpoint_dir', '', 'The directory for checkpoints') 40 | flags.DEFINE_string('eval_dir', '', 'Directory for writing eval event logs') 41 | flags.DEFINE_string('dataset_dir', '', 'Location of dataset') 42 | 43 | FLAGS = flags.FLAGS 44 | 45 | 46 | def imagenet_input(is_training): 47 | """Data reader for imagenet. 48 | 49 | Reads in imagenet data and performs pre-processing on the images. 50 | 51 | Args: 52 | is_training: bool specifying if train or validation dataset is needed. 53 | Returns: 54 | A batch of images and labels. 55 | """ 56 | if is_training: 57 | dataset = dataset_factory.get_dataset('imagenet', 'train', 58 | FLAGS.dataset_dir) 59 | else: 60 | dataset = dataset_factory.get_dataset('imagenet', 'validation', 61 | FLAGS.dataset_dir) 62 | 63 | provider = slim.dataset_data_provider.DatasetDataProvider( 64 | dataset, 65 | shuffle=is_training, 66 | common_queue_capacity=2 * FLAGS.batch_size, 67 | common_queue_min=FLAGS.batch_size) 68 | [image, label] = provider.get(['image', 'label']) 69 | 70 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 71 | 'mobilenet_v1', is_training=is_training) 72 | 73 | image = image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size) 74 | 75 | images, labels = tf.train.batch( 76 | tensors=[image, label], 77 | batch_size=FLAGS.batch_size, 78 | num_threads=4, 79 | capacity=5 * FLAGS.batch_size) 80 | return images, labels 81 | 82 | 83 | def metrics(logits, labels): 84 | """Specify the metrics for eval. 85 | 86 | Args: 87 | logits: Logits output from the graph. 88 | labels: Ground truth labels for inputs. 89 | 90 | Returns: 91 | Eval Op for the graph. 92 | """ 93 | labels = tf.squeeze(labels) 94 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({ 95 | 'Accuracy': tf.metrics.accuracy(tf.argmax(logits, 1), labels), 96 | 'Recall_5': tf.metrics.recall_at_k(labels, logits, 5), 97 | }) 98 | for name, value in names_to_values.iteritems(): 99 | slim.summaries.add_scalar_summary( 100 | value, name, prefix='eval', print_summary=True) 101 | return names_to_updates.values() 102 | 103 | 104 | def build_model(): 105 | """Build the mobilenet_v1 model for evaluation. 106 | 107 | Returns: 108 | g: graph with rewrites after insertion of quantization ops and batch norm 109 | folding. 110 | eval_ops: eval ops for inference. 111 | variables_to_restore: List of variables to restore from checkpoint. 
112 | """ 113 | g = tf.Graph() 114 | with g.as_default(): 115 | inputs, labels = imagenet_input(is_training=False) 116 | 117 | scope = mobilenet_v1.mobilenet_v1_arg_scope( 118 | is_training=False, weight_decay=0.0) 119 | with slim.arg_scope(scope): 120 | logits, _ = mobilenet_v1.mobilenet_v1( 121 | inputs, 122 | is_training=False, 123 | depth_multiplier=FLAGS.depth_multiplier, 124 | num_classes=FLAGS.num_classes) 125 | 126 | if FLAGS.quantize: 127 | tf.contrib.quantize.create_eval_graph() 128 | 129 | eval_ops = metrics(logits, labels) 130 | 131 | return g, eval_ops 132 | 133 | 134 | def eval_model(): 135 | """Evaluates mobilenet_v1.""" 136 | g, eval_ops = build_model() 137 | with g.as_default(): 138 | num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)) 139 | slim.evaluation.evaluate_once( 140 | FLAGS.master, 141 | FLAGS.checkpoint_dir, 142 | logdir=FLAGS.eval_dir, 143 | num_evals=num_batches, 144 | eval_op=eval_ops) 145 | 146 | 147 | def main(unused_arg): 148 | eval_model() 149 | 150 | 151 | if __name__ == '__main__': 152 | tf.app.run(main) 153 | -------------------------------------------------------------------------------- /tf_pose/slim/nets/nasnet/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-Slim NASNet-A Implementation/Checkpoints 2 | This directory contains the code for the NASNet-A model from the paper 3 | [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012) by Zoph et al. 4 | In nasnet.py there are three different configurations of NASNet-A that are implementented. One of the models is the NASNet-A built for CIFAR-10 and the 5 | other two are variants of NASNet-A trained on ImageNet, which are listed below. 6 | 7 | # Pre-Trained Models 8 | Two NASNet-A checkpoints are available that have been trained on the 9 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 10 | image classification dataset. Accuracies were computed by evaluating using a single image crop. 11 | 12 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 13 | :----:|:------------:|:----------:|:-------:|:-------:| 14 | [NASNet-A_Mobile_224](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|564|5.3|74.0|91.6| 15 | [NASNet-A_Large_331](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|23800|88.9|82.7|96.2| 16 | 17 | 18 | Here is an example of how to download the NASNet-A_Mobile_224 checkpoint. The way to download the NASNet-A_Large_331 is the same. 19 | 20 | ```shell 21 | CHECKPOINT_DIR=/tmp/checkpoints 22 | mkdir ${CHECKPOINT_DIR} 23 | cd ${CHECKPOINT_DIR} 24 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz 25 | tar -xvf nasnet-a_mobile_04_10_2017.tar.gz 26 | rm nasnet-a_mobile_04_10_2017.tar.gz 27 | ``` 28 | More information on integrating NASNet Models into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md). 29 | 30 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 
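For building the network directly from Python (rather than via the eval script shown in the next section), a minimal construction sketch follows; it assumes the `build_nasnet_mobile` and `nasnet_mobile_arg_scope` entry points in `nasnet.py` match the upstream TF-Slim release and that the NASNet-A Mobile checkpoint above has been downloaded:

```python
import tensorflow as tf
from nets.nasnet import nasnet

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
    # is_training=False builds the inference graph (no drop-path or aux-head updates).
    logits, end_points = nasnet.build_nasnet_mobile(
        images, num_classes=1001, is_training=False)

# Restore /tmp/checkpoints/model.ckpt (downloaded above) with tf.train.Saver()
# before reading predictions from end_points['Predictions'].
```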
31 | 32 | ## Sample Commands for using NASNet-A Mobile and Large Checkpoints for Inference 33 | ------- 34 | Run eval with the NASNet-A mobile ImageNet model 35 | 36 | ```shell 37 | DATASET_DIR=/tmp/imagenet 38 | EVAL_DIR=/tmp/tfmodel/eval 39 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 40 | python tensorflow_models/research/slim/eval_image_classifier \ 41 | --checkpoint_path=${CHECKPOINT_DIR} \ 42 | --eval_dir=${EVAL_DIR} \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --dataset_name=imagenet \ 45 | --dataset_split_name=validation \ 46 | --model_name=nasnet_mobile \ 47 | --eval_image_size=224 48 | ``` 49 | 50 | Run eval with the NASNet-A large ImageNet model 51 | 52 | ```shell 53 | DATASET_DIR=/tmp/imagenet 54 | EVAL_DIR=/tmp/tfmodel/eval 55 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 56 | python tensorflow_models/research/slim/eval_image_classifier \ 57 | --checkpoint_path=${CHECKPOINT_DIR} \ 58 | --eval_dir=${EVAL_DIR} \ 59 | --dataset_dir=${DATASET_DIR} \ 60 | --dataset_name=imagenet \ 61 | --dataset_split_name=validation \ 62 | --model_name=nasnet_large \ 63 | --eval_image_size=331 64 | ``` 65 | -------------------------------------------------------------------------------- /tf_pose/slim/nets/nasnet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tf_pose/slim/nets/nasnet/nasnet_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.nasnet.nasnet_utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets.nasnet import nasnet_utils 24 | 25 | 26 | class NasnetUtilsTest(tf.test.TestCase): 27 | 28 | def testCalcReductionLayers(self): 29 | num_cells = 18 30 | num_reduction_layers = 2 31 | reduction_layers = nasnet_utils.calc_reduction_layers( 32 | num_cells, num_reduction_layers) 33 | self.assertEqual(len(reduction_layers), 2) 34 | self.assertEqual(reduction_layers[0], 6) 35 | self.assertEqual(reduction_layers[1], 12) 36 | 37 | def testGetChannelIndex(self): 38 | data_formats = ['NHWC', 'NCHW'] 39 | for data_format in data_formats: 40 | index = nasnet_utils.get_channel_index(data_format) 41 | correct_index = 3 if data_format == 'NHWC' else 1 42 | self.assertEqual(index, correct_index) 43 | 44 | def testGetChannelDim(self): 45 | data_formats = ['NHWC', 'NCHW'] 46 | shape = [10, 20, 30, 40] 47 | for data_format in data_formats: 48 | dim = nasnet_utils.get_channel_dim(shape, data_format) 49 | correct_dim = shape[3] if data_format == 'NHWC' else shape[1] 50 | self.assertEqual(dim, correct_dim) 51 | 52 | def testGlobalAvgPool(self): 53 | data_formats = ['NHWC', 'NCHW'] 54 | inputs = tf.placeholder(tf.float32, (5, 10, 20, 10)) 55 | for data_format in data_formats: 56 | output = nasnet_utils.global_avg_pool( 57 | inputs, data_format) 58 | self.assertEqual(output.shape, [5, 10]) 59 | 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /tf_pose/slim/nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for slim.inception.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFnFirstHalf(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | for net in list(nets_factory.networks_map.keys())[:10]: 34 | with tf.Graph().as_default() as g, self.test_session(g): 35 | net_fn = nets_factory.get_network_fn(net, num_classes) 36 | # Most networks use 224 as their default_image_size 37 | image_size = getattr(net_fn, 'default_image_size', 224) 38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 39 | logits, end_points = net_fn(inputs) 40 | self.assertTrue(isinstance(logits, tf.Tensor)) 41 | self.assertTrue(isinstance(end_points, dict)) 42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 44 | 45 | def testGetNetworkFnSecondHalf(self): 46 | batch_size = 5 47 | num_classes = 1000 48 | for net in list(nets_factory.networks_map.keys())[10:]: 49 | with tf.Graph().as_default() as g, self.test_session(g): 50 | net_fn = nets_factory.get_network_fn(net, num_classes) 51 | # Most networks use 224 as their default_image_size 52 | image_size = getattr(net_fn, 'default_image_size', 224) 53 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 54 | logits, end_points = net_fn(inputs) 55 | self.assertTrue(isinstance(logits, tf.Tensor)) 56 | self.assertTrue(isinstance(end_points, dict)) 57 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 58 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 59 | 60 | if __name__ == '__main__': 61 | tf.test.main() 62 | -------------------------------------------------------------------------------- /tf_pose/slim/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tf_pose/slim/preprocessing/cifarnet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images in CIFAR-10. 
16 | 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf 24 | 25 | _PADDING = 4 26 | 27 | slim = tf.contrib.slim 28 | 29 | 30 | def preprocess_for_train(image, 31 | output_height, 32 | output_width, 33 | padding=_PADDING, 34 | add_image_summaries=True): 35 | """Preprocesses the given image for training. 36 | 37 | Note that the actual resizing scale is sampled from 38 | [`resize_size_min`, `resize_size_max`]. 39 | 40 | Args: 41 | image: A `Tensor` representing an image of arbitrary size. 42 | output_height: The height of the image after preprocessing. 43 | output_width: The width of the image after preprocessing. 44 | padding: The amound of padding before and after each dimension of the image. 45 | add_image_summaries: Enable image summaries. 46 | 47 | Returns: 48 | A preprocessed image. 49 | """ 50 | if add_image_summaries: 51 | tf.summary.image('image', tf.expand_dims(image, 0)) 52 | 53 | # Transform the image to floats. 54 | image = tf.to_float(image) 55 | if padding > 0: 56 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]]) 57 | # Randomly crop a [height, width] section of the image. 58 | distorted_image = tf.random_crop(image, 59 | [output_height, output_width, 3]) 60 | 61 | # Randomly flip the image horizontally. 62 | distorted_image = tf.image.random_flip_left_right(distorted_image) 63 | 64 | if add_image_summaries: 65 | tf.summary.image('distorted_image', tf.expand_dims(distorted_image, 0)) 66 | 67 | # Because these operations are not commutative, consider randomizing 68 | # the order their operation. 69 | distorted_image = tf.image.random_brightness(distorted_image, 70 | max_delta=63) 71 | distorted_image = tf.image.random_contrast(distorted_image, 72 | lower=0.2, upper=1.8) 73 | # Subtract off the mean and divide by the variance of the pixels. 74 | return tf.image.per_image_standardization(distorted_image) 75 | 76 | 77 | def preprocess_for_eval(image, output_height, output_width, 78 | add_image_summaries=True): 79 | """Preprocesses the given image for evaluation. 80 | 81 | Args: 82 | image: A `Tensor` representing an image of arbitrary size. 83 | output_height: The height of the image after preprocessing. 84 | output_width: The width of the image after preprocessing. 85 | add_image_summaries: Enable image summaries. 86 | 87 | Returns: 88 | A preprocessed image. 89 | """ 90 | if add_image_summaries: 91 | tf.summary.image('image', tf.expand_dims(image, 0)) 92 | # Transform the image to floats. 93 | image = tf.to_float(image) 94 | 95 | # Resize and crop if needed. 96 | resized_image = tf.image.resize_image_with_crop_or_pad(image, 97 | output_width, 98 | output_height) 99 | if add_image_summaries: 100 | tf.summary.image('resized_image', tf.expand_dims(resized_image, 0)) 101 | 102 | # Subtract off the mean and divide by the variance of the pixels. 103 | return tf.image.per_image_standardization(resized_image) 104 | 105 | 106 | def preprocess_image(image, output_height, output_width, is_training=False, 107 | add_image_summaries=True): 108 | """Preprocesses the given image. 109 | 110 | Args: 111 | image: A `Tensor` representing an image of arbitrary size. 112 | output_height: The height of the image after preprocessing. 113 | output_width: The width of the image after preprocessing. 114 | is_training: `True` if we're preprocessing the image for training and 115 | `False` otherwise. 116 | add_image_summaries: Enable image summaries. 
117 | 118 | Returns: 119 | A preprocessed image. 120 | """ 121 | if is_training: 122 | return preprocess_for_train( 123 | image, output_height, output_width, 124 | add_image_summaries=add_image_summaries) 125 | else: 126 | return preprocess_for_eval( 127 | image, output_height, output_width, 128 | add_image_summaries=add_image_summaries) 129 | -------------------------------------------------------------------------------- /tf_pose/slim/preprocessing/lenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities for preprocessing.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def preprocess_image(image, output_height, output_width, is_training): 27 | """Preprocesses the given image. 28 | 29 | Args: 30 | image: A `Tensor` representing an image of arbitrary size. 31 | output_height: The height of the image after preprocessing. 32 | output_width: The width of the image after preprocessing. 33 | is_training: `True` if we're preprocessing the image for training and 34 | `False` otherwise. 35 | 36 | Returns: 37 | A preprocessed image. 38 | """ 39 | image = tf.to_float(image) 40 | image = tf.image.resize_image_with_crop_or_pad( 41 | image, output_width, output_height) 42 | image = tf.subtract(image, 128.0) 43 | image = tf.div(image, 128.0) 44 | return image 45 | -------------------------------------------------------------------------------- /tf_pose/slim/preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import cifarnet_preprocessing 24 | from preprocessing import inception_preprocessing 25 | from preprocessing import lenet_preprocessing 26 | from preprocessing import vgg_preprocessing 27 | 28 | slim = tf.contrib.slim 29 | 30 | 31 | def get_preprocessing(name, is_training=False): 32 | """Returns preprocessing_fn(image, height, width, **kwargs). 33 | 34 | Args: 35 | name: The name of the preprocessing function. 36 | is_training: `True` if the model is being used for training and `False` 37 | otherwise. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocessing a single image (pre-batch). 41 | It has the following signature: 42 | image = preprocessing_fn(image, output_height, output_width, ...). 43 | 44 | Raises: 45 | ValueError: If Preprocessing `name` is not recognized. 46 | """ 47 | preprocessing_fn_map = { 48 | 'cifarnet': cifarnet_preprocessing, 49 | 'inception': inception_preprocessing, 50 | 'inception_v1': inception_preprocessing, 51 | 'inception_v2': inception_preprocessing, 52 | 'inception_v3': inception_preprocessing, 53 | 'inception_v4': inception_preprocessing, 54 | 'inception_resnet_v2': inception_preprocessing, 55 | 'lenet': lenet_preprocessing, 56 | 'mobilenet_v1': inception_preprocessing, 57 | 'nasnet_mobile': inception_preprocessing, 58 | 'nasnet_large': inception_preprocessing, 59 | 'pnasnet_large': inception_preprocessing, 60 | 'resnet_v1_50': vgg_preprocessing, 61 | 'resnet_v1_101': vgg_preprocessing, 62 | 'resnet_v1_152': vgg_preprocessing, 63 | 'resnet_v1_200': vgg_preprocessing, 64 | 'resnet_v2_50': vgg_preprocessing, 65 | 'resnet_v2_101': vgg_preprocessing, 66 | 'resnet_v2_152': vgg_preprocessing, 67 | 'resnet_v2_200': vgg_preprocessing, 68 | 'vgg': vgg_preprocessing, 69 | 'vgg_a': vgg_preprocessing, 70 | 'vgg_16': vgg_preprocessing, 71 | 'vgg_19': vgg_preprocessing, 72 | } 73 | 74 | if name not in preprocessing_fn_map: 75 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 76 | 77 | def preprocessing_fn(image, output_height, output_width, **kwargs): 78 | return preprocessing_fn_map[name].preprocess_image( 79 | image, output_height, output_width, is_training=is_training, **kwargs) 80 | 81 | return preprocessing_fn 82 | -------------------------------------------------------------------------------- /tf_pose/slim/scripts/finetune_inception_resnet_v2_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes an Inception Resnet V2 model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_inception_resnet_v2_on_flowers.sh 25 | set -e 26 | 27 | # Where the pre-trained Inception Resnet V2 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 29 | 30 | # Where the pre-trained Inception Resnet V2 checkpoint is saved to. 31 | MODEL_NAME=inception_resnet_v2 32 | 33 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 34 | TRAIN_DIR=/tmp/flowers-models/${MODEL_NAME} 35 | 36 | # Where the dataset is saved to. 37 | DATASET_DIR=/tmp/flowers 38 | 39 | # Download the pre-trained checkpoint. 40 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 41 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 42 | fi 43 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt ]; then 44 | wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz 45 | tar -xvf inception_resnet_v2_2016_08_30.tar.gz 46 | mv inception_resnet_v2.ckpt ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt 47 | rm inception_resnet_v2_2016_08_30.tar.gz 48 | fi 49 | 50 | # Download the dataset 51 | python download_and_convert_data.py \ 52 | --dataset_name=flowers \ 53 | --dataset_dir=${DATASET_DIR} 54 | 55 | # Fine-tune only the new layers for 1000 steps. 56 | python train_image_classifier.py \ 57 | --train_dir=${TRAIN_DIR} \ 58 | --dataset_name=flowers \ 59 | --dataset_split_name=train \ 60 | --dataset_dir=${DATASET_DIR} \ 61 | --model_name=${MODEL_NAME} \ 62 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt \ 63 | --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \ 64 | --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \ 65 | --max_number_of_steps=1000 \ 66 | --batch_size=32 \ 67 | --learning_rate=0.01 \ 68 | --learning_rate_decay_type=fixed \ 69 | --save_interval_secs=60 \ 70 | --save_summaries_secs=60 \ 71 | --log_every_n_steps=10 \ 72 | --optimizer=rmsprop \ 73 | --weight_decay=0.00004 74 | 75 | # Run evaluation. 76 | python eval_image_classifier.py \ 77 | --checkpoint_path=${TRAIN_DIR} \ 78 | --eval_dir=${TRAIN_DIR} \ 79 | --dataset_name=flowers \ 80 | --dataset_split_name=validation \ 81 | --dataset_dir=${DATASET_DIR} \ 82 | --model_name=${MODEL_NAME} 83 | 84 | # Fine-tune all the new layers for 500 steps. 85 | python train_image_classifier.py \ 86 | --train_dir=${TRAIN_DIR}/all \ 87 | --dataset_name=flowers \ 88 | --dataset_split_name=train \ 89 | --dataset_dir=${DATASET_DIR} \ 90 | --model_name=${MODEL_NAME} \ 91 | --checkpoint_path=${TRAIN_DIR} \ 92 | --max_number_of_steps=500 \ 93 | --batch_size=32 \ 94 | --learning_rate=0.0001 \ 95 | --learning_rate_decay_type=fixed \ 96 | --save_interval_secs=60 \ 97 | --save_summaries_secs=60 \ 98 | --log_every_n_steps=10 \ 99 | --optimizer=rmsprop \ 100 | --weight_decay=0.00004 101 | 102 | # Run evaluation. 
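# This final pass evaluates the checkpoint written to ${TRAIN_DIR}/all by the
# preceding run (which updated every layer for 500 steps) on the Flowers
# validation split.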
103 | python eval_image_classifier.py \ 104 | --checkpoint_path=${TRAIN_DIR}/all \ 105 | --eval_dir=${TRAIN_DIR}/all \ 106 | --dataset_name=flowers \ 107 | --dataset_split_name=validation \ 108 | --dataset_dir=${DATASET_DIR} \ 109 | --model_name=${MODEL_NAME} 110 | -------------------------------------------------------------------------------- /tf_pose/slim/scripts/finetune_inception_v1_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes an InceptionV1 model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_inception_v1_on_flowers.sh 25 | set -e 26 | 27 | # Where the pre-trained InceptionV1 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 29 | 30 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 31 | TRAIN_DIR=/tmp/flowers-models/inception_v1 32 | 33 | # Where the dataset is saved to. 34 | DATASET_DIR=/tmp/flowers 35 | 36 | # Download the pre-trained checkpoint. 37 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 38 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 39 | fi 40 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt ]; then 41 | wget http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz 42 | tar -xvf inception_v1_2016_08_28.tar.gz 43 | mv inception_v1.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt 44 | rm inception_v1_2016_08_28.tar.gz 45 | fi 46 | 47 | # Download the dataset 48 | python download_and_convert_data.py \ 49 | --dataset_name=flowers \ 50 | --dataset_dir=${DATASET_DIR} 51 | 52 | # Fine-tune only the new layers for 2000 steps. 53 | python train_image_classifier.py \ 54 | --train_dir=${TRAIN_DIR} \ 55 | --dataset_name=flowers \ 56 | --dataset_split_name=train \ 57 | --dataset_dir=${DATASET_DIR} \ 58 | --model_name=inception_v1 \ 59 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt \ 60 | --checkpoint_exclude_scopes=InceptionV1/Logits \ 61 | --trainable_scopes=InceptionV1/Logits \ 62 | --max_number_of_steps=3000 \ 63 | --batch_size=32 \ 64 | --learning_rate=0.01 \ 65 | --save_interval_secs=60 \ 66 | --save_summaries_secs=60 \ 67 | --log_every_n_steps=100 \ 68 | --optimizer=rmsprop \ 69 | --weight_decay=0.00004 70 | 71 | # Run evaluation. 72 | python eval_image_classifier.py \ 73 | --checkpoint_path=${TRAIN_DIR} \ 74 | --eval_dir=${TRAIN_DIR} \ 75 | --dataset_name=flowers \ 76 | --dataset_split_name=validation \ 77 | --dataset_dir=${DATASET_DIR} \ 78 | --model_name=inception_v1 79 | 80 | # Fine-tune all the new layers for 1000 steps. 
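# This pass warm-starts from the checkpoint in ${TRAIN_DIR}; since no
# --trainable_scopes is given, every layer is updated at the lower learning
# rate of 0.001.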
81 | python train_image_classifier.py \ 82 | --train_dir=${TRAIN_DIR}/all \ 83 | --dataset_name=flowers \ 84 | --dataset_split_name=train \ 85 | --dataset_dir=${DATASET_DIR} \ 86 | --checkpoint_path=${TRAIN_DIR} \ 87 | --model_name=inception_v1 \ 88 | --max_number_of_steps=1000 \ 89 | --batch_size=32 \ 90 | --learning_rate=0.001 \ 91 | --save_interval_secs=60 \ 92 | --save_summaries_secs=60 \ 93 | --log_every_n_steps=100 \ 94 | --optimizer=rmsprop \ 95 | --weight_decay=0.00004 96 | 97 | # Run evaluation. 98 | python eval_image_classifier.py \ 99 | --checkpoint_path=${TRAIN_DIR}/all \ 100 | --eval_dir=${TRAIN_DIR}/all \ 101 | --dataset_name=flowers \ 102 | --dataset_split_name=validation \ 103 | --dataset_dir=${DATASET_DIR} \ 104 | --model_name=inception_v1 105 | -------------------------------------------------------------------------------- /tf_pose/slim/scripts/finetune_inception_v3_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes an InceptionV3 model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_inception_v3_on_flowers.sh 25 | set -e 26 | 27 | # Where the pre-trained InceptionV3 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 29 | 30 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 31 | TRAIN_DIR=/tmp/flowers-models/inception_v3 32 | 33 | # Where the dataset is saved to. 34 | DATASET_DIR=/tmp/flowers 35 | 36 | # Download the pre-trained checkpoint. 37 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 38 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 39 | fi 40 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt ]; then 41 | wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz 42 | tar -xvf inception_v3_2016_08_28.tar.gz 43 | mv inception_v3.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt 44 | rm inception_v3_2016_08_28.tar.gz 45 | fi 46 | 47 | # Download the dataset 48 | python download_and_convert_data.py \ 49 | --dataset_name=flowers \ 50 | --dataset_dir=${DATASET_DIR} 51 | 52 | # Fine-tune only the new layers for 1000 steps. 
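# --checkpoint_exclude_scopes keeps the ImageNet logits out of the restored
# weights, and --trainable_scopes restricts training to the freshly
# initialized InceptionV3/Logits and InceptionV3/AuxLogits layers.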
53 | python train_image_classifier.py \ 54 | --train_dir=${TRAIN_DIR} \ 55 | --dataset_name=flowers \ 56 | --dataset_split_name=train \ 57 | --dataset_dir=${DATASET_DIR} \ 58 | --model_name=inception_v3 \ 59 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt \ 60 | --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \ 61 | --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \ 62 | --max_number_of_steps=1000 \ 63 | --batch_size=32 \ 64 | --learning_rate=0.01 \ 65 | --learning_rate_decay_type=fixed \ 66 | --save_interval_secs=60 \ 67 | --save_summaries_secs=60 \ 68 | --log_every_n_steps=100 \ 69 | --optimizer=rmsprop \ 70 | --weight_decay=0.00004 71 | 72 | # Run evaluation. 73 | python eval_image_classifier.py \ 74 | --checkpoint_path=${TRAIN_DIR} \ 75 | --eval_dir=${TRAIN_DIR} \ 76 | --dataset_name=flowers \ 77 | --dataset_split_name=validation \ 78 | --dataset_dir=${DATASET_DIR} \ 79 | --model_name=inception_v3 80 | 81 | # Fine-tune all the new layers for 500 steps. 82 | python train_image_classifier.py \ 83 | --train_dir=${TRAIN_DIR}/all \ 84 | --dataset_name=flowers \ 85 | --dataset_split_name=train \ 86 | --dataset_dir=${DATASET_DIR} \ 87 | --model_name=inception_v3 \ 88 | --checkpoint_path=${TRAIN_DIR} \ 89 | --max_number_of_steps=500 \ 90 | --batch_size=32 \ 91 | --learning_rate=0.0001 \ 92 | --learning_rate_decay_type=fixed \ 93 | --save_interval_secs=60 \ 94 | --save_summaries_secs=60 \ 95 | --log_every_n_steps=10 \ 96 | --optimizer=rmsprop \ 97 | --weight_decay=0.00004 98 | 99 | # Run evaluation. 100 | python eval_image_classifier.py \ 101 | --checkpoint_path=${TRAIN_DIR}/all \ 102 | --eval_dir=${TRAIN_DIR}/all \ 103 | --dataset_name=flowers \ 104 | --dataset_split_name=validation \ 105 | --dataset_dir=${DATASET_DIR} \ 106 | --model_name=inception_v3 107 | -------------------------------------------------------------------------------- /tf_pose/slim/scripts/finetune_resnet_v1_50_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes a ResNetV1-50 model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh 25 | set -e 26 | 27 | # Where the pre-trained ResNetV1-50 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 29 | 30 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 31 | TRAIN_DIR=/tmp/flowers-models/resnet_v1_50 32 | 33 | # Where the dataset is saved to. 34 | DATASET_DIR=/tmp/flowers 35 | 36 | # Download the pre-trained checkpoint. 37 | if [ ! 
-d "$PRETRAINED_CHECKPOINT_DIR" ]; then 38 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 39 | fi 40 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then 41 | wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz 42 | tar -xvf resnet_v1_50_2016_08_28.tar.gz 43 | mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt 44 | rm resnet_v1_50_2016_08_28.tar.gz 45 | fi 46 | 47 | # Download the dataset 48 | python download_and_convert_data.py \ 49 | --dataset_name=flowers \ 50 | --dataset_dir=${DATASET_DIR} 51 | 52 | # Fine-tune only the new layers for 3000 steps. 53 | python train_image_classifier.py \ 54 | --train_dir=${TRAIN_DIR} \ 55 | --dataset_name=flowers \ 56 | --dataset_split_name=train \ 57 | --dataset_dir=${DATASET_DIR} \ 58 | --model_name=resnet_v1_50 \ 59 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \ 60 | --checkpoint_exclude_scopes=resnet_v1_50/logits \ 61 | --trainable_scopes=resnet_v1_50/logits \ 62 | --max_number_of_steps=3000 \ 63 | --batch_size=32 \ 64 | --learning_rate=0.01 \ 65 | --save_interval_secs=60 \ 66 | --save_summaries_secs=60 \ 67 | --log_every_n_steps=100 \ 68 | --optimizer=rmsprop \ 69 | --weight_decay=0.00004 70 | 71 | # Run evaluation. 72 | python eval_image_classifier.py \ 73 | --checkpoint_path=${TRAIN_DIR} \ 74 | --eval_dir=${TRAIN_DIR} \ 75 | --dataset_name=flowers \ 76 | --dataset_split_name=validation \ 77 | --dataset_dir=${DATASET_DIR} \ 78 | --model_name=resnet_v1_50 79 | 80 | # Fine-tune all the new layers for 1000 steps. 81 | python train_image_classifier.py \ 82 | --train_dir=${TRAIN_DIR}/all \ 83 | --dataset_name=flowers \ 84 | --dataset_split_name=train \ 85 | --dataset_dir=${DATASET_DIR} \ 86 | --checkpoint_path=${TRAIN_DIR} \ 87 | --model_name=resnet_v1_50 \ 88 | --max_number_of_steps=1000 \ 89 | --batch_size=32 \ 90 | --learning_rate=0.001 \ 91 | --save_interval_secs=60 \ 92 | --save_summaries_secs=60 \ 93 | --log_every_n_steps=100 \ 94 | --optimizer=rmsprop \ 95 | --weight_decay=0.00004 96 | 97 | # Run evaluation. 98 | python eval_image_classifier.py \ 99 | --checkpoint_path=${TRAIN_DIR}/all \ 100 | --eval_dir=${TRAIN_DIR}/all \ 101 | --dataset_name=flowers \ 102 | --dataset_split_name=validation \ 103 | --dataset_dir=${DATASET_DIR} \ 104 | --model_name=resnet_v1_50 105 | -------------------------------------------------------------------------------- /tf_pose/slim/scripts/train_cifarnet_on_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Cifar10 dataset 19 | # 2. Trains a CifarNet model on the Cifar10 training set. 20 | # 3. Evaluates the model on the Cifar10 testing set. 
21 | # 22 | # Usage: 23 | # cd slim 24 | # ./scripts/train_cifarnet_on_cifar10.sh 25 | set -e 26 | 27 | # Where the checkpoint and logs will be saved to. 28 | TRAIN_DIR=/tmp/cifarnet-model 29 | 30 | # Where the dataset is saved to. 31 | DATASET_DIR=/tmp/cifar10 32 | 33 | # Download the dataset 34 | python download_and_convert_data.py \ 35 | --dataset_name=cifar10 \ 36 | --dataset_dir=${DATASET_DIR} 37 | 38 | # Run training. 39 | python train_image_classifier.py \ 40 | --train_dir=${TRAIN_DIR} \ 41 | --dataset_name=cifar10 \ 42 | --dataset_split_name=train \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --model_name=cifarnet \ 45 | --preprocessing_name=cifarnet \ 46 | --max_number_of_steps=100000 \ 47 | --batch_size=128 \ 48 | --save_interval_secs=120 \ 49 | --save_summaries_secs=120 \ 50 | --log_every_n_steps=100 \ 51 | --optimizer=sgd \ 52 | --learning_rate=0.1 \ 53 | --learning_rate_decay_factor=0.1 \ 54 | --num_epochs_per_decay=200 \ 55 | --weight_decay=0.004 56 | 57 | # Run evaluation. 58 | python eval_image_classifier.py \ 59 | --checkpoint_path=${TRAIN_DIR} \ 60 | --eval_dir=${TRAIN_DIR} \ 61 | --dataset_name=cifar10 \ 62 | --dataset_split_name=test \ 63 | --dataset_dir=${DATASET_DIR} \ 64 | --model_name=cifarnet 65 | -------------------------------------------------------------------------------- /tf_pose/slim/scripts/train_lenet_on_mnist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the MNIST dataset 19 | # 2. Trains a LeNet model on the MNIST training set. 20 | # 3. Evaluates the model on the MNIST testing set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/train_lenet_on_mnist.sh 25 | set -e 26 | 27 | # Where the checkpoint and logs will be saved to. 28 | TRAIN_DIR=/tmp/lenet-model 29 | 30 | # Where the dataset is saved to. 31 | DATASET_DIR=/tmp/mnist 32 | 33 | # Download the dataset 34 | python download_and_convert_data.py \ 35 | --dataset_name=mnist \ 36 | --dataset_dir=${DATASET_DIR} 37 | 38 | # Run training. 39 | python train_image_classifier.py \ 40 | --train_dir=${TRAIN_DIR} \ 41 | --dataset_name=mnist \ 42 | --dataset_split_name=train \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --model_name=lenet \ 45 | --preprocessing_name=lenet \ 46 | --max_number_of_steps=20000 \ 47 | --batch_size=50 \ 48 | --learning_rate=0.01 \ 49 | --save_interval_secs=60 \ 50 | --save_summaries_secs=60 \ 51 | --log_every_n_steps=100 \ 52 | --optimizer=sgd \ 53 | --learning_rate_decay_type=fixed \ 54 | --weight_decay=0 55 | 56 | # Run evaluation. 
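# Evaluates the LeNet checkpoint saved in ${TRAIN_DIR} on the MNIST test split.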
57 | python eval_image_classifier.py \ 58 | --checkpoint_path=${TRAIN_DIR} \ 59 | --eval_dir=${TRAIN_DIR} \ 60 | --dataset_name=mnist \ 61 | --dataset_split_name=test \ 62 | --dataset_dir=${DATASET_DIR} \ 63 | --model_name=lenet 64 | -------------------------------------------------------------------------------- /tf_pose/slim/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Setup script for slim.""" 16 | 17 | from setuptools import find_packages 18 | from setuptools import setup 19 | 20 | 21 | setup( 22 | name='slim', 23 | version='0.1', 24 | include_package_data=True, 25 | packages=find_packages(), 26 | description='tf-slim', 27 | ) 28 | -------------------------------------------------------------------------------- /tf_pose/tensblur/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/tensblur/__init__.py -------------------------------------------------------------------------------- /tf_pose/tensblur/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/tensblur/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/tensblur/__pycache__/smoother.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satyaborg/pose-estimation-detection/f173348318876ebf4fc974575c40246bce71427f/tf_pose/tensblur/__pycache__/smoother.cpython-36.pyc -------------------------------------------------------------------------------- /tf_pose/tensblur/smoother.py: -------------------------------------------------------------------------------- 1 | # vim: sta:et:sw=2:ts=2:sts=2 2 | # Written by Antonio Loquercio 3 | 4 | import numpy as np 5 | import scipy.stats as st 6 | import pdb 7 | 8 | import tensorflow as tf 9 | 10 | 11 | def layer(op): 12 | def layer_decorated(self, *args, **kwargs): 13 | # Automatically set a name if not provided. 14 | name = kwargs.setdefault('name', self.get_unique_name(op.__name__)) 15 | # Figure out the layer inputs. 16 | if len(self.terminals) == 0: 17 | raise RuntimeError('No input variables found for layer %s.' % name) 18 | elif len(self.terminals) == 1: 19 | layer_input = self.terminals[0] 20 | else: 21 | layer_input = list(self.terminals) 22 | # Perform the operation and get the output. 23 | layer_output = op(self, layer_input, *args, **kwargs) 24 | # Add to layer LUT. 
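        # The unique name generated above is the dictionary key, so later
        # feed() calls can look this output up by name.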
25 | self.layers[name] = layer_output 26 | # This output is now the input for the next layer. 27 | self.feed(layer_output) 28 | # Return self for chained calls. 29 | return self 30 | 31 | return layer_decorated 32 | 33 | 34 | class Smoother(object): 35 | def __init__(self, inputs, filter_size, sigma): 36 | self.inputs = inputs 37 | self.terminals = [] 38 | self.layers = dict(inputs) 39 | self.filter_size = filter_size 40 | self.sigma = sigma 41 | self.setup() 42 | 43 | def setup(self): 44 | self.feed('data').conv(name='smoothing') 45 | 46 | def get_unique_name(self, prefix): 47 | ident = sum(t.startswith(prefix) for t, _ in self.layers.items()) + 1 48 | return '%s_%d' % (prefix, ident) 49 | 50 | def feed(self, *args): 51 | assert len(args) != 0 52 | self.terminals = [] 53 | for fed_layer in args: 54 | if isinstance(fed_layer, str): 55 | try: 56 | fed_layer = self.layers[fed_layer] 57 | except KeyError: 58 | raise KeyError('Unknown layer name fed: %s' % fed_layer) 59 | self.terminals.append(fed_layer) 60 | return self 61 | 62 | def gauss_kernel(self, kernlen=21, nsig=3, channels=1): 63 | interval = (2*nsig+1.)/(kernlen) 64 | x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1) 65 | kern1d = np.diff(st.norm.cdf(x)) 66 | kernel_raw = np.sqrt(np.outer(kern1d, kern1d)) 67 | kernel = kernel_raw/kernel_raw.sum() 68 | out_filter = np.array(kernel, dtype = np.float32) 69 | out_filter = out_filter.reshape((kernlen, kernlen, 1, 1)) 70 | out_filter = np.repeat(out_filter, channels, axis = 2) 71 | return out_filter 72 | 73 | def make_gauss_var(self, name, size, sigma, c_i): 74 | # with tf.device("/cpu:0"): 75 | kernel = self.gauss_kernel(size, sigma, c_i) 76 | var = tf.Variable(tf.convert_to_tensor(kernel), name=name) 77 | return var 78 | 79 | def get_output(self): 80 | '''Returns the smoother output.''' 81 | return self.terminals[-1] 82 | 83 | @layer 84 | def conv(self, 85 | input, 86 | name, 87 | padding='SAME'): 88 | # Get the number of channels in the input 89 | c_i = input.get_shape().as_list()[3] 90 | # Convolution for a given input and kernel 91 | convolve = lambda i, k: tf.nn.depthwise_conv2d(i, k, [1, 1, 1, 1], padding=padding) 92 | with tf.variable_scope(name) as scope: 93 | kernel = self.make_gauss_var('gauss_weight', self.filter_size, self.sigma, c_i) 94 | output = convolve(input, kernel) 95 | return output 96 | --------------------------------------------------------------------------------
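
The `Smoother` class above is a thin, Network-style wrapper that applies a depthwise Gaussian blur to every channel of its input tensor. The following usage sketch is not part of the repository; the 25/3.0 filter settings and the 19-channel, 46x46 heatmap shape are illustrative assumptions only.

```python
import numpy as np
import tensorflow as tf

from tf_pose.tensblur.smoother import Smoother

# A batch of per-part heatmaps; the 19 channels and 46x46 resolution are example values.
heatmaps = tf.placeholder(tf.float32, [None, 46, 46, 19], name='heatmaps')

# Blur each channel with a 25x25 Gaussian kernel (sigma parameter 3.0).
smoother = Smoother({'data': heatmaps}, filter_size=25, sigma=3.0)
blurred = smoother.get_output()  # same shape as the input

with tf.Session() as sess:
    # Initializes the Gaussian kernel, which is stored as a (non-trainable) Variable.
    sess.run(tf.global_variables_initializer())
    sample = np.random.rand(1, 46, 46, 19).astype(np.float32)
    print(sess.run(blurred, feed_dict={heatmaps: sample}).shape)  # (1, 46, 46, 19)
```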