├── .dockerignore ├── .gitattributes ├── .gitignore ├── CMakeLists.txt ├── Dockerfile ├── LICENSE ├── README.md ├── docker ├── Dockerfile └── README.md ├── etcs ├── dance.mp4 ├── feature.md ├── inference_result2.png ├── loss_ll_heat.png ├── loss_ll_paf.png ├── openpose_macbook13_mobilenet2.gif ├── openpose_macbook_cmu.gif ├── openpose_macbook_mobilenet3.gif ├── openpose_p40_cmu.gif ├── openpose_p40_mobilenet.gif ├── openpose_tx2_mobilenet3.gif ├── ros.md └── training.md ├── images ├── apink1.jpg ├── apink1_crop.jpg ├── apink1_crop_s1.jpg ├── apink2.jpg ├── apink3.jpg ├── cat.jpg ├── golf.jpg ├── hand1.jpg ├── hand1_small.jpg ├── hand2.jpg ├── handsup1.jpg ├── p1.jpg ├── p2.jpg ├── p3.jpg ├── p3_dance.png ├── ski.jpg └── valid_person1.jpg ├── launch └── demo_video.launch ├── models ├── graph │ ├── cmu │ │ └── download.sh │ └── mobilenet_thin │ │ ├── graph.pb │ │ ├── graph_freeze.pb │ │ └── graph_opt.pb ├── numpy │ └── download.sh └── pretrained │ ├── mobilenet_v1_0.50_224_2017_06_14 │ └── download.sh │ ├── mobilenet_v1_0.75_224_2017_06_14 │ └── download.sh │ └── mobilenet_v1_1.0_224_2017_06_14 │ └── download.sh ├── msg ├── BodyPartElm.msg ├── Person.msg └── Persons.msg ├── package.xml ├── predict ├── predict.md └── predict.py ├── requirements.txt ├── scripts ├── broadcaster_ros.py └── visualization.py └── src ├── __init__.py ├── common.py ├── datum_pb2.py ├── estimator.py ├── lifting ├── __init__.py ├── config.py ├── draw.py ├── models │ └── prob_model_params.mat ├── prob_model.py └── upright_fast.py ├── network_base.py ├── network_cmu.py ├── network_dsconv.py ├── network_mobilenet.py ├── network_mobilenet_thin.py ├── networks.py ├── pose_augment.py ├── pose_datamaster.py ├── pose_dataset.py ├── pose_dataworker.py ├── pose_stats.py ├── run.py ├── run_checkpoint.py ├── run_directory.py ├── run_video.py ├── run_webcam.py ├── slim ├── __init__.py ├── nets │ ├── __init__.py │ ├── alexnet.py │ ├── alexnet_test.py │ ├── cifarnet.py │ ├── cyclegan.py │ ├── cyclegan_test.py │ ├── dcgan.py │ ├── dcgan_test.py │ ├── inception.py │ ├── inception_resnet_v2.py │ ├── inception_resnet_v2_test.py │ ├── inception_utils.py │ ├── inception_v1.py │ ├── inception_v1_test.py │ ├── inception_v2.py │ ├── inception_v2_test.py │ ├── inception_v3.py │ ├── inception_v3_test.py │ ├── inception_v4.py │ ├── inception_v4_test.py │ ├── lenet.py │ ├── mobilenet_v1.md │ ├── mobilenet_v1.png │ ├── mobilenet_v1.py │ ├── mobilenet_v1_test.py │ ├── nasnet │ │ ├── README.md │ │ ├── __init__.py │ │ ├── nasnet.py │ │ ├── nasnet_test.py │ │ ├── nasnet_utils.py │ │ └── nasnet_utils_test.py │ ├── nets_factory.py │ ├── nets_factory_test.py │ ├── overfeat.py │ ├── overfeat_test.py │ ├── pix2pix.py │ ├── pix2pix_test.py │ ├── resnet_utils.py │ ├── resnet_v1.py │ ├── resnet_v1_test.py │ ├── resnet_v2.py │ ├── resnet_v2_test.py │ ├── vgg.py │ └── vgg_test.py └── preprocessing │ ├── __init__.py │ ├── cifarnet_preprocessing.py │ ├── inception_preprocessing.py │ ├── lenet_preprocessing.py │ ├── preprocessing_factory.py │ └── vgg_preprocessing.py └── train.py /.dockerignore: -------------------------------------------------------------------------------- 1 | ./models 2 | ./models/* 3 | ./models/*/* 4 | ./models/*/*/* 5 | ./models/*/*/*/* 6 | models 7 | models/* 8 | models/*/* 9 | models/*/*/* 10 | models/*/*/*/* 11 | *.meta 12 | *.index 13 | *.data-* 14 | *.ckpt.* 15 | 16 | ./tests 17 | ./tests/* 18 | tests 19 | graph*.pb 20 | chk*.meta 21 | lifting 22 | slim 23 | 24 | 
-------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | models/numpy/*.npy filter=lfs diff=lfs merge=lfs -text 2 | *.ckpt* filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # models 104 | *ckpt* 105 | *.npy 106 | timeline*.json 107 | 108 | tmp 109 | models/graph/cmu*/*.pb 110 | models/trained/*/checkpoint 111 | models/trained/*/*.pb 112 | models/trained/*/model-*.data-* 113 | models/trained/*/model-*.index 114 | models/trained/*/model-*.meta -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(tfpose_ros) 3 | 4 | ## Add support for C++11, supported in ROS Kinetic and newer 5 | add_definitions(-std=c++11) 6 | 7 | find_package(catkin REQUIRED COMPONENTS 8 | roscpp 9 | rospy 10 | std_msgs 11 | message_generation 12 | ) 13 | 14 | # Generate messages in the 'msg' folder 15 | add_message_files( 16 | FILES 17 | BodyPartElm.msg 18 | Person.msg 19 | Persons.msg 20 | ) 21 | 22 | generate_messages( 23 | DEPENDENCIES std_msgs 24 | ) 25 | 26 | catkin_package( 27 | CATKIN_DEPENDS rospy message_generation message_runtime 28 | ) 29 | 30 | install() 31 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM idock.daumkakao.io/kakaobrain/deepcloud-sshd:openpose-preprocess 2 | 3 | COPY ./*.py /root/tf-openpose/ 4 | WORKDIR /root/tf-openpose/ 5 | 6 | RUN cd /root/tf-openpose/ && pip3 install -r 
requirements.txt 7 | 8 | ENTRYPOINT ["python3", "pose_dataworker.py"] 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tf-pose-estimation 2 | 3 | 'Openpose', for human pose estimation, has been implemented using Tensorflow. It also provides several variants that make some changes to the network structure for **real-time processing on the CPU or low-power embedded devices.** 4 | 5 | ## ------------------------------ Update 2018-03-26 ------------------------------ 6 | #### Prediction 7 | The original author has open-sourced the models they trained, which can be reused directly, so I extracted part of the code and wrote a prediction function: [predict.md](https://github.com/mattzheng/tf-pose-estimation-applied/blob/master/predict/predict.md) 8 | 9 | 10 | 11 | ---------- 12 | 13 | 14 | **You can even run this on your macbook with decent FPS!** 15 | 16 | Original Repo(Caffe) : https://github.com/CMU-Perceptual-Computing-Lab/openpose 17 | 18 | | CMU's Original Model <br/>
on Macbook Pro 15" | Mobilenet Variant <br/>
on Macbook Pro 15" | Mobilenet Variant <br/>
on Jetson TX2 | 19 | |:---------|:--------------------|:----------------| 20 | | ![cmu-model](/etcs/openpose_macbook_cmu.gif) | ![mb-model-macbook](/etcs/openpose_macbook_mobilenet3.gif) | ![mb-model-tx2](/etcs/openpose_tx2_mobilenet3.gif) | 21 | | **~0.6 FPS** | **~4.2 FPS** @ 368x368 | **~10 FPS** @ 368x368 | 22 | | 2.8GHz Quad-core i7 | 2.8GHz Quad-core i7 | Jetson TX2 Embedded Board | 23 | 24 | Implemented features are listed here : [features](./etcs/feature.md) 25 | 26 | ## Important Updates 27 | 28 | 2018.2.7 Arguments in the run.py script changed. Dynamic input sizes are now supported. 29 | 30 | ## Install 31 | 32 | ### Dependencies 33 | 34 | You need the dependencies below. 35 | 36 | - python3 37 | - tensorflow 1.4.1+ 38 | - opencv3, protobuf, python3-tk 39 | 40 | ### Install 41 | 42 | ```bash 43 | $ git clone https://www.github.com/ildoonet/tf-openpose 44 | $ cd tf-openpose 45 | $ pip3 install -r requirements.txt 46 | ``` 47 | 48 | ## Models 49 | 50 | I have tried multiple variations of models to find an optimized network architecture. Some of them are listed below, and checkpoint files are provided for research purposes. 51 | 52 | - cmu 53 | - the model based on the VGG pretrained network, as described in the original paper. 54 | - I converted the weights from Caffe format for use in tensorflow. 55 | - [pretrained weight download](https://www.dropbox.com/s/xh5s7sb7remu8tx/openpose_coco.npy?dl=0) 56 | 57 | - dsconv 58 | - Same architecture as the cmu version except that it uses the **depthwise separable convolution** of mobilenet. 59 | - I trained it using 'transfer learning', but it does not provide enough speed and accuracy. 60 | 61 | - mobilenet 62 | - Based on the mobilenet paper, 12 convolutional layers are used as feature-extraction layers. 63 | - To improve detection of small persons, **minor modifications** to the architecture have been made. 64 | - Three models were trained with different network size parameters. 65 | - mobilenet 66 | - 368x368 : [checkpoint weight download](https://www.dropbox.com/s/09xivpuboecge56/mobilenet_0.75_0.50_model-388003.zip?dl=0) 67 | - mobilenet_fast 68 | - mobilenet_accurate 69 | - The published models are not the best ones, but you can test them before training a model from scratch. 70 | 71 | ### Download Tensorflow Graph File (pb file) 72 | 73 | Before running the demo, you should download the graph files. You can deploy these graphs on mobile or other platforms. 74 | 75 | - cmu (trained at 656x368) 76 | - mobilenet_thin (trained at 432x368) 77 | 78 | CMU's model graphs are too large for git, so I uploaded them to an external cloud. You should download them if you want to use cmu's original model. Download scripts are provided in the model folder. 79 | 80 | ``` 81 | $ cd models/graph/cmu 82 | $ bash download.sh 83 | ``` 84 | 85 | 86 | 87 | ### Inference Time 88 | 89 | | Dataset | Model | Inference Time <br/>
Macbook Pro i5 3.1G | Inference Time <br/>
Jetson TX2 | 90 | |---------|--------------------|----------------:|----------------:| 91 | | Coco | cmu | 10.0s @ 368x368 | OOM @ 368x368 <br/>
5.5s @ 320x240| 92 | | Coco | dsconv | 1.10s @ 368x368 | 93 | | Coco | mobilenet_accurate | 0.40s @ 368x368 | 0.18s @ 368x368 | 94 | | Coco | mobilenet | 0.24s @ 368x368 | 0.10s @ 368x368 | 95 | | Coco | mobilenet_fast | 0.16s @ 368x368 | 0.07s @ 368x368 | 96 | 97 | ## Demo 98 | 99 | ### Test Inference 100 | 101 | You can test the inference feature with a single image. 102 | 103 | ``` 104 | $ python3 run.py --model=mobilenet_thin --resolution=432x368 --image=... 105 | ``` 106 | 107 | The image flag MUST be relative to the src folder with no "~", i.e: 108 | ``` 109 | --image ../../Desktop 110 | ``` 111 | 112 | Then you will see the screen as below with pafmap, heatmap, result and etc. 113 | 114 | ![inferent_result](./etcs/inference_result2.png) 115 | 116 | ### Realtime Webcam 117 | 118 | ``` 119 | $ python3 run_webcam.py --model=mobilenet_thin --resolution=432x368 --camera=0 120 | ``` 121 | 122 | Then you will see the realtime webcam screen with estimated poses as below. This [Realtime Result](./etcs/openpose_macbook13_mobilenet2.gif) was recored on macbook pro 13" with 3.1Ghz Dual-Core CPU. 123 | 124 | ## Python Usage 125 | 126 | This pose estimator provides simple python classes that you can use in your applications. 127 | 128 | See [run.py](run.py) or [run_webcam.py](run_webcam.py) as references. 129 | 130 | ```python 131 | e = TfPoseEstimator(get_graph_path(args.model), target_size=(w, h)) 132 | humans = e.inference(image) 133 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) 134 | ``` 135 | 136 | ## ROS Support 137 | 138 | See : [etcs/ros.md](./etcs/ros.md) 139 | 140 | ## Training 141 | 142 | See : [etcs/training.md](./etcs/training.md) 143 | 144 | ## References 145 | 146 | ### OpenPose 147 | 148 | [1] https://github.com/CMU-Perceptual-Computing-Lab/openpose 149 | 150 | [2] Training Codes : https://github.com/ZheC/Realtime_Multi-Person_Pose_Estimation 151 | 152 | [3] Custom Caffe by Openpose : https://github.com/CMU-Perceptual-Computing-Lab/caffe_train 153 | 154 | [4] Keras Openpose : https://github.com/michalfaber/keras_Realtime_Multi-Person_Pose_Estimation 155 | 156 | ### Lifting from the deep 157 | 158 | [1] Arxiv Paper : https://arxiv.org/abs/1701.00295 159 | 160 | [2] https://github.com/DenisTome/Lifting-from-the-Deep-release 161 | 162 | ### Mobilenet 163 | 164 | [1] Original Paper : https://arxiv.org/abs/1704.04861 165 | 166 | [2] Pretrained model : https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.md 167 | 168 | ### Libraries 169 | 170 | [1] Tensorpack : https://github.com/ppwwyyxx/tensorpack 171 | 172 | ### Tensorflow Tips 173 | 174 | [1] Freeze graph : https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py 175 | 176 | [2] Optimize graph : https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2 177 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:16.04 2 | 3 | #ENV http_proxy=http://10.41.249.28:8080 https_proxy=http://10.41.249.28:8080 4 | 5 | RUN apt-get update -yq && apt-get install -yq build-essential cmake git pkg-config wget zip && \ 6 | apt-get install -yq libjpeg8-dev libtiff5-dev libjasper-dev libpng12-dev && \ 7 | apt-get install -yq libavcodec-dev libavformat-dev libswscale-dev libv4l-dev && \ 8 | apt-get install -yq libgtk2.0-dev && \ 9 | apt-get install -yq libatlas-base-dev gfortran && \ 10 | apt-get install -yq python3 
python3-dev python3-pip python3-setuptools python3-tk git && \ 11 | apt-get remove -yq python-pip python3-pip && wget https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && \ 12 | pip3 install numpy && \ 13 | cd ~ && git clone https://github.com/Itseez/opencv.git && \ 14 | cd opencv && mkdir build && cd build && \ 15 | cmake -D CMAKE_BUILD_TYPE=RELEASE \ 16 | -D CMAKE_INSTALL_PREFIX=/usr/local \ 17 | -D INSTALL_PYTHON_EXAMPLES=ON \ 18 | -D BUILD_opencv_python3=yes -D PYTHON_EXECUTABLE=/usr/bin/python3 .. && \ 19 | make -j8 && make install && rm -rf /root/opencv/ && \ 20 | mkdir -p /root/tf-openpose && \ 21 | rm -rf /tmp/*.tar.gz && \ 22 | apt-get clean && rm -rf /tmp/* /var/tmp* /var/lib/apt/lists/* && \ 23 | rm -f /etc/ssh/ssh_host_* && rm -rf /usr/share/man?? /usr/share/man/??_* 24 | 25 | COPY . /root/tf-openpose/ 26 | WORKDIR /root/tf-openpose/ 27 | 28 | RUN cd /root/tf-openpose/ && pip3 install -U setuptools && \ 29 | pip3 install tensorflow && pip3 install -r requirements.txt 30 | 31 | RUN cd /root && git clone https://github.com/cocodataset/cocoapi && \ 32 | pip3 install cython && \ 33 | cd cocoapi/PythonAPI && python3 setup.py build_ext --inplace && python3 setup.py build_ext install && \ 34 | mkdir /coco && cd /coco && wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip && \ 35 | unzip annotations_trainval2017.zip && rm -rf annotations_trainval2017.zip 36 | 37 | ENTRYPOINT ["python3", "pose_dataworker.py"] 38 | 39 | #ENV http_proxy= https_proxy= 40 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | ``` 2 | $ docker build --tag idock.daumkakao.io/kakaobrain/deepcloud-sshd:openpose-preprocess -f ./docker/full/ . 3 | ``` 4 | 5 | ``` 6 | $ docker build --tag idock.daumkakao.io/kakaobrain/deepcloud-sshd:openpose-preprocess -f ./docker/update/Dockerfile . 7 | ``` -------------------------------------------------------------------------------- /etcs/dance.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/etcs/dance.mp4 -------------------------------------------------------------------------------- /etcs/feature.md: -------------------------------------------------------------------------------- 1 | ## Features 2 | 3 | - [x] CMU's original network architecture and weights. 4 | 5 | - [x] Transfer Original Weights to Tensorflow 6 | 7 | - [x] Training Code with multi-gpus 8 | 9 | - [x] Evaluate with test dataset 10 | 11 | - [ ] Inference 12 | 13 | - [x] Post processing from network output. 14 | 15 | - [x] Faster post-processing 16 | 17 | - [ ] Multi-Scale Inference 18 | 19 | - [x] Faster network variants using custom mobilenet architecture. 20 | 21 | - [x] Depthwise Separable Convolution Version 22 | 23 | - [x] Mobilenet Version 24 | 25 | - [ ] Demos 26 | 27 | - [x] Realtime Webcam Demo 28 | 29 | - [x] Image File Demo 30 | 31 | - [ ] Video File Demo 32 | 33 | - [x] ROS Support. See [./etcs/ros.md](ros readme). 
-------------------------------------------------------------------------------- /etcs/inference_result2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/etcs/inference_result2.png -------------------------------------------------------------------------------- /etcs/loss_ll_heat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/etcs/loss_ll_heat.png -------------------------------------------------------------------------------- /etcs/loss_ll_paf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/etcs/loss_ll_paf.png -------------------------------------------------------------------------------- /etcs/openpose_macbook13_mobilenet2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/etcs/openpose_macbook13_mobilenet2.gif -------------------------------------------------------------------------------- /etcs/openpose_macbook_cmu.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/etcs/openpose_macbook_cmu.gif -------------------------------------------------------------------------------- /etcs/openpose_macbook_mobilenet3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/etcs/openpose_macbook_mobilenet3.gif -------------------------------------------------------------------------------- /etcs/openpose_p40_cmu.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/etcs/openpose_p40_cmu.gif -------------------------------------------------------------------------------- /etcs/openpose_p40_mobilenet.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/etcs/openpose_p40_mobilenet.gif -------------------------------------------------------------------------------- /etcs/openpose_tx2_mobilenet3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/etcs/openpose_tx2_mobilenet3.gif -------------------------------------------------------------------------------- /etcs/ros.md: -------------------------------------------------------------------------------- 1 | # tf-pose-estimation for ROS 2 | 3 | Human pose estimation is expected to use on mobile robots which need human interactions. 4 | 5 | ## Installation 6 | 7 | Cloning this repository under src folder in your ros workstation. And the same should be carried out as [README.md](README.md). 
8 | 9 | ``` 10 | $ cd $(ros-workspace) 11 | $ cd src 12 | $ git clone https://github.com/ildoonet/tf-pose-estimation 13 | $ pip install -r tf-pose-estimation/requirements.txt 14 | ``` 15 | 16 | These dependencies are needed to launch the demo: 17 | 18 | - video_stream_opencv 19 | - image_view 20 | - ros_video_recorder : https://github.com/ildoonet/ros-video-recorder 21 | 22 | ## Video/camera demo 23 | 24 | | CMU <br/>
640x360 | Mobilenet_Thin <br/>
432x368 | 25 | |:----------------|:---------------------------| 26 | | ![cmu-model](/etcs/openpose_p40_cmu.gif) | ![cmu-model](/etcs/openpose_p40_mobilenet.gif) | 27 | 28 | The tests above were run on a P40 gpu. Latency between the current video frame and the processed frame is much lower with the mobilenet version. 29 | 30 | Source : https://www.youtube.com/watch?v=rSZnyBuc6tc 31 | 32 | ``` 33 | $ roslaunch tfpose_ros demo_video.launch 34 | ``` 35 | 36 | You can specify the 'video' argument to launch a realtime video demo using your camera. See the [ros launch file](./launch/demo_video.launch). -------------------------------------------------------------------------------- /etcs/training.md: -------------------------------------------------------------------------------- 1 | ## Training 2 | 3 | ### Coco Dataset 4 | 5 | You should download the COCO Dataset from http://cocodataset.org/#download 6 | 7 | Also, you need to install the cocoapi for easy parsing : https://github.com/cocodataset/cocoapi 8 | 9 | 10 | ``` 11 | $ git clone https://github.com/cocodataset/cocoapi 12 | $ cd cocoapi/PythonAPI 13 | $ python3 setup.py build_ext --inplace 14 | $ python3 setup.py build_ext install 15 | ``` 16 | 17 | ### Augmentation 18 | 19 | CMU Perceptual Computing Lab has modified Caffe to provide data augmentation. See : https://github.com/CMU-Perceptual-Computing-Lab/caffe_train 20 | 21 | I implemented the augmentation code in the same way as the original version; see [pose_dataset.py](pose_dataset.py) and [pose_augment.py](pose_augment.py). This includes scaling, rotation, flipping and cropping. 22 | 23 | This process can be a bottleneck for training, so if you have enough computing resources, please see the [Run for Faster Training](#run-for-faster-training) section. 24 | 25 | ### Run 26 | 27 | ``` 28 | $ python3 train.py --model=cmu --datapath={datapath} --batchsize=64 --lr=0.001 --modelpath={path-to-save} 29 | 30 | 2017-09-27 15:58:50,307 INFO Restore pretrained weights... 31 | ``` 32 | 33 | If you want to reproduce the original paper's results, the following settings are recommended. 34 | 35 | - model : vgg 36 | - lr : 0.0001 or 0.00004 37 | - input-width = input-height = 368x368 or 432x368 38 | - batchsize : 10 (I trained with batch sizes up to 128, and they trained well) 39 | 40 | | Heatmap Loss | PAFmap(Part Affinity Field) Loss | 41 | |-------------------------------------------|------------------------------------------| 42 | | ![train_loss_cmu](/etcs/loss_ll_heat.png) | ![train_loss_cmu](/etcs/loss_ll_paf.png) | 43 | 44 | As you can see from the table above, the training loss converged following almost the same trend as in the original paper. 45 | 46 | The mobilenet versions have slightly worse loss values compared to the original one: training losses are 3 to 8% larger, and validation losses are 5 to 14% larger. 47 | 48 | 49 | ### Run for Faster Training 50 | 51 | If you have enough computing resources across multiple nodes, you can launch multiple workers on those nodes to help with data preparation. 52 | 53 | ``` 54 | worker-node1$ python3 pose_dataworker.py --master=tcp://host:port 55 | worker-node2$ python3 pose_dataworker.py --master=tcp://host:port 56 | worker-node3$ python3 pose_dataworker.py --master=tcp://host:port 57 | ... 58 | ``` 59 | 60 | After the above preparation, you can launch the training script with the 'remote-data' argument. 61 | 62 | ``` 63 | $ python3 train.py --remote-data=tcp://0.0.0.0:port 64 | 65 | 2017-09-27 15:58:50,307 INFO Restore pretrained weights... 66 | ``` 67 | 68 | Also, you can train quickly with multiple gpus.
This automatically splits each batch across multiple gpus for the forward/backward computations. 69 | 70 | ``` 71 | $ python3 train.py --remote-data=tcp://0.0.0.0:port --gpus=8 72 | 73 | 2017-09-27 15:58:50,307 INFO Restore pretrained weights... 74 | ``` 75 | 76 | I trained models within a day using 8 gpus and multiple pre-processing nodes with 48-core cpus. 77 | 78 | ### Model Optimization for Inference 79 | 80 | After training a model, I optimized it by folding batch normalization into the convolutional layers and removing redundant operations. 81 | 82 | First, the model should be frozen. 83 | 84 | ```bash 85 | $ python3 -m tensorflow.python.tools.freeze_graph \ 86 | --input_graph=... \ 87 | --output_graph=... \ 88 | --input_checkpoint=... \ 89 | --output_node_names="Openpose/concat_stage7" 90 | ``` 91 | 92 | The optimization can then be performed on the frozen model via the graph transform tool provided by tensorflow. 93 | 94 | ```bash 95 | $ bazel build tensorflow/tools/graph_transforms:transform_graph 96 | $ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ 97 | --in_graph=... \ 98 | --out_graph=... \ 99 | --inputs='image:0' \ 100 | --outputs='Openpose/concat_stage7:0' \ 101 | --transforms=' 102 | strip_unused_nodes(type=float, shape="1,368,368,3") 103 | remove_nodes(op=Identity, op=CheckNumerics) 104 | fold_constants(ignoreError=False) 105 | fold_old_batch_norms 106 | fold_batch_norms' 107 | ``` 108 | 109 | Quantizing the neural network to 8 bits is also a promising way to get further speed improvements, though in my case it made inference less accurate and slower on Intel CPUs. 110 | 111 | ``` 112 | $ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ 113 | --in_graph=/Users/ildoonet/repos/tf-openpose/tmp/cmu_640x480/graph_opt.pb \ 114 | --out_graph=/Users/ildoonet/repos/tf-openpose/tmp/cmu_640x480/graph_q.pb \ 115 | --inputs='image' \ 116 | --outputs='Openpose/concat_stage7:0' \ 117 | --transforms='add_default_attributes strip_unused_nodes(type=float, shape="1,360,640,3") 118 | remove_nodes(op=Identity, op=CheckNumerics) fold_constants(ignore_errors=true) 119 | fold_batch_norms fold_old_batch_norms quantize_weights quantize_nodes 120 | strip_unused_nodes sort_by_execution_order' 121 | ``` 122 | A short Python sketch showing how such an optimized graph can be loaded and run is included further below, after the following image entries. -------------------------------------------------------------------------------- /images/apink1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/apink1.jpg -------------------------------------------------------------------------------- /images/apink1_crop.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/apink1_crop.jpg -------------------------------------------------------------------------------- /images/apink1_crop_s1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/apink1_crop_s1.jpg -------------------------------------------------------------------------------- /images/apink2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/apink2.jpg
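The `graph_opt.pb` produced by the freeze/transform steps above (and downloaded by `models/graph/cmu/download.sh`) can also be loaded directly with the TensorFlow 1.x API. The following is a minimal sketch and is not part of the original repository: it assumes the tensor names used in the transform commands above (`image:0` as input, `Openpose/concat_stage7:0` as output) and does no post-processing. In practice the `TfPoseEstimator` class in `src/estimator.py` wraps this loading step together with the heatmap/PAF post-processing that produces `Human` objects, as shown in the README's Python Usage section.

```python
# Minimal sketch (assumption: tensor names 'image:0' / 'Openpose/concat_stage7:0',
# as used in the transform_graph commands above). TensorFlow 1.x API.
import cv2
import numpy as np
import tensorflow as tf

GRAPH_PATH = 'models/graph/cmu/graph_opt.pb'   # downloaded by models/graph/cmu/download.sh
TARGET_SIZE = (368, 368)                       # (width, height) the graph was stripped for

# Load the serialized, optimized GraphDef.
with tf.gfile.GFile(GRAPH_PATH, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import it into a fresh graph under a name scope.
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='TfPoseEstimator')

image_tensor = graph.get_tensor_by_name('TfPoseEstimator/image:0')
output_tensor = graph.get_tensor_by_name('TfPoseEstimator/Openpose/concat_stage7:0')

# Read a sample image from the repository (BGR, as in src/common.py) and resize it.
img = cv2.imread('images/p1.jpg', cv2.IMREAD_COLOR)
img = cv2.resize(img, TARGET_SIZE)

with tf.Session(graph=graph) as sess:
    # The output stacks the part heatmaps and the PAFs; for the COCO models it is
    # typically shaped (1, height/8, width/8, 57).
    heat_paf = sess.run(output_tensor,
                        feed_dict={image_tensor: [img.astype(np.float32)]})
    print(heat_paf.shape)
```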
-------------------------------------------------------------------------------- /images/apink3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/apink3.jpg -------------------------------------------------------------------------------- /images/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/cat.jpg -------------------------------------------------------------------------------- /images/golf.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/golf.jpg -------------------------------------------------------------------------------- /images/hand1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/hand1.jpg -------------------------------------------------------------------------------- /images/hand1_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/hand1_small.jpg -------------------------------------------------------------------------------- /images/hand2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/hand2.jpg -------------------------------------------------------------------------------- /images/handsup1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/handsup1.jpg -------------------------------------------------------------------------------- /images/p1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/p1.jpg -------------------------------------------------------------------------------- /images/p2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/p2.jpg -------------------------------------------------------------------------------- /images/p3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/p3.jpg -------------------------------------------------------------------------------- /images/p3_dance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/p3_dance.png -------------------------------------------------------------------------------- /images/ski.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/ski.jpg -------------------------------------------------------------------------------- /images/valid_person1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/images/valid_person1.jpg -------------------------------------------------------------------------------- /launch/demo_video.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /models/graph/cmu/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "download model graph : cmu" 4 | 5 | extract_download_url() { 6 | 7 | url=$( wget -q -O - $1 | grep -o 'http*://download[^"]*' | tail -n 1 ) 8 | echo "$url" 9 | 10 | } 11 | 12 | wget $( extract_download_url http://www.mediafire.com/file/1pyjsjl0p93x27c/graph_freeze.pb ) -O graph_freeze.pb 13 | wget $( extract_download_url http://www.mediafire.com/file/qlzzr20mpocnpa3/graph_opt.pb ) -O graph_opt.pb 14 | wget $( extract_download_url http://www.mediafire.com/file/i72ll9k5i7x6qfh/graph.pb ) -O graph.pb 15 | -------------------------------------------------------------------------------- /models/graph/mobilenet_thin/graph_freeze.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/models/graph/mobilenet_thin/graph_freeze.pb -------------------------------------------------------------------------------- /models/graph/mobilenet_thin/graph_opt.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/models/graph/mobilenet_thin/graph_opt.pb -------------------------------------------------------------------------------- /models/numpy/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://www.mediafire.com/file/ropayv77vklvf56/openpose_coco.npy 4 | wget http://www.mediafire.com/file/7e73ddj31rzw6qq/openpose_vgg16.npy 5 | -------------------------------------------------------------------------------- /models/pretrained/mobilenet_v1_0.50_224_2017_06_14/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://www.mediafire.com/file/meu73iq8rxlsd3g/mobilenet_v1_0.50_224.ckpt.data-00000-of-00001 4 | wget http://www.mediafire.com/file/7u6iupfkcaxk5hx/mobilenet_v1_0.50_224.ckpt.index 5 | wget http://www.mediafire.com/file/zp8y4d0ytzharzz/mobilenet_v1_0.50_224.ckpt.meta 6 | -------------------------------------------------------------------------------- /models/pretrained/mobilenet_v1_0.75_224_2017_06_14/download.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://www.mediafire.com/file/kibz0x9e7h11ueb/mobilenet_v1_0.75_224.ckpt.data-00000-of-00001 4 | wget http://www.mediafire.com/file/t8909eaikvc6ea2/mobilenet_v1_0.75_224.ckpt.index 5 | wget http://www.mediafire.com/file/6jjnbn1aged614x/mobilenet_v1_0.75_224.ckpt.meta 6 | -------------------------------------------------------------------------------- /models/pretrained/mobilenet_v1_1.0_224_2017_06_14/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://www.mediafire.com/file/oh6njnz9lgoqwdj/mobilenet_v1_1.0_224.ckpt.data-00000-of-00001 4 | wget http://www.mediafire.com/file/61qln0tbac4ny9o/mobilenet_v1_1.0_224.ckpt.meta 5 | wget http://www.mediafire.com/file/2111rh6tb5fl1lr/mobilenet_v1_1.0_224.ckpt.index 6 | -------------------------------------------------------------------------------- /msg/BodyPartElm.msg: -------------------------------------------------------------------------------- 1 | uint32 part_id 2 | float32 x 3 | float32 y 4 | float32 confidence -------------------------------------------------------------------------------- /msg/Person.msg: -------------------------------------------------------------------------------- 1 | BodyPartElm[] body_part -------------------------------------------------------------------------------- /msg/Persons.msg: -------------------------------------------------------------------------------- 1 | Person[] persons 2 | uint32 image_w 3 | uint32 image_h 4 | Header header -------------------------------------------------------------------------------- /package.xml: -------------------------------------------------------------------------------- 1 | <?xml version="1.0"?> 2 | <package> 3 | <name>tfpose_ros</name> 4 | <version>0.1.0</version> 5 | <description>ROS Package for Human Pose Estimation Implemented in Tensorflow</description> 6 | 7 | <maintainer>Curtis Kim</maintainer> 8 | 9 | <license>GNU2.0</license> 10 | 11 | <buildtool_depend>catkin</buildtool_depend> 12 | 13 | <build_depend>rospy</build_depend> 14 | <build_depend>message_generation</build_depend> 15 | <build_depend>std_msgs</build_depend> 16 | <build_depend>cv_bridge</build_depend> 17 | 18 | <run_depend>message_generation</run_depend> 19 | <run_depend>message_runtime</run_depend> 20 | <run_depend>std_msgs</run_depend> 21 | <run_depend>rospy</run_depend> 22 | 23 | </package> -------------------------------------------------------------------------------- /predict/predict.md: -------------------------------------------------------------------------------- 1 | ## 1. Preparation 2 | This model is adapted from openpose. You need to download the following in advance: 3 | 4 | - 1. The trained pretrained network model file ([pretrained weight 5 | download](https://www.dropbox.com/s/xh5s7sb7remu8tx/openpose_coco.npy?dl=0)): `openpose_coco.npy 6 | · 199.56 MB` 7 | - 2. The Tensorflow Graph File (CMU's model): 8 | 9 | ``` 10 | $ cd models/graph/cmu 11 | $ bash download.sh 12 | ``` 13 | Of course, if you want to process images in batches, you may need to write another script. 14 | 15 | ## 2. Training and Prediction 16 | #### 2.1 About training 17 | The models are generally trained on the coco dataset, with 18 keypoints. The training process is very well encapsulated in the original project, so at first I wanted to reuse it, but it turned out to be wrapped too tightly to adapt, which was quite frustrating!
18 | Later I found that someone who had taken part in the AI challenger competition wrote up the training process in a simple form; see: [galaxy-fangfang/AI-challenger-Realtime_Multi-Person_Pose_Estimation-training](https://github.com/galaxy-fangfang/AI-challenger-Realtime_Multi-Person_Pose_Estimation-training) 19 | 20 | The connections between keypoints are shown in the figure below: 21 | ![keypoint connections](https://camo.githubusercontent.com/5833ee83e638a1b622a16fd6447d64a9668efcf5/687474703a2f2f696d672e626c6f672e6373646e2e6e65742f32303137303930383134343233383631303f77617465726d61726b2f322f746578742f6148523063446f764c324a736232637559334e6b626935755a5851766147467763486c6f62334a70656d6c7662673d3d2f666f6e742f3561364c354c32542f666f6e7473697a652f3430302f66696c6c2f49304a42516b46434d413d3d2f646973736f6c76652f37302f677261766974792f536f75746845617374) 22 | 23 | #### 2.2 About prediction 24 | Prediction is run from the command line: 25 | 26 | ``` 27 | python3 run.py --model=mobilenet_thin --resolution=432x368 --image=... 28 | ``` 29 | I extracted the relevant prediction logic from the original author's prediction command and wrote two simple functions, `get_keypoint` and `PoseEstimatorPredict`, in [predict.py](https://github.com/mattzheng/tf-pose-estimation-applied/blob/master/predict/predict.py). 30 | 31 | Two models are currently supported: `mobilenet_thin` and `cmu`. 32 | There are six model types in total; see `/src/networks.py` for details. 33 | 34 | `PoseEstimatorPredict` has two output modes: 35 | 36 | - plot=False: returns one item, the keypoint information 37 | - plot=True: returns two items, the keypoint information plus the annotated image matrix 38 | 39 | 40 | 41 | 42 | 43 | **Prediction speed: even with a titanxP + the CMU model it does not seem very fast:** 44 | 45 | ``` 46 | CPU times: user 1.42 s, sys: 5.24 s, total: 6.66 s 47 | Wall time: 6.25 s 48 | ``` 49 | 50 | -------------------------------------------------------------------------------- /predict/predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import sys 5 | import os 6 | sys.path.append('tf_pose_estimation/tf-openpose/src/') 7 | 8 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1" 9 | 10 | # import tensorflow as tf 11 | # from keras.backend.tensorflow_backend import set_session 12 | # config = tf.ConfigProto() 13 | # config.gpu_options.per_process_gpu_memory_fraction = 0.8 14 | # set_session(tf.Session(config=config)) 15 | 16 | import argparse 17 | import logging 18 | import time 19 | import ast 20 | 21 | import common 22 | import cv2 23 | import numpy as np 24 | from estimator import TfPoseEstimator 25 | from networks import get_graph_path, model_wh 26 | 27 | from lifting.prob_model import Prob3dPose 28 | from lifting.draw import plot_pose 29 | import matplotlib.pyplot as plt 30 | # %matplotlib inline (IPython magic; only valid in a notebook or IPython session) 31 | from scipy import misc 32 | 33 | def get_keypoint(image,humans): 34 | ''' 35 | input: 36 | image, the image as a numpy array (matrix) 37 | humans, the keypoint information predicted by the estimator 38 | 39 | output: 40 | centers, keypoint information per person, in the format 41 | part index: [(x, y), score] 42 | {1: [(142, 303), 10.328406150326954], 43 | 2: [(154, 303), 6.621983647346497], 44 | 3: [(154, 323), 7.118330538272858]} 45 | 46 | ''' 47 | image_h, image_w = image.shape[:2] 48 | centers = {} 49 | for n,human in enumerate(humans): 50 | center_tmp = {} 51 | for i in range(common.CocoPart.Background.value): # range(common.CocoPart.Background.value) = range(18) 52 | if i not in human.body_parts.keys(): 53 | continue 54 | body_part = human.body_parts[i] 55 | # print(human.body_parts[i]) 56 | # human.body_parts[i].score 57 | # human.body_parts[i].x 58 | # human.body_parts[i].y 59 | center = (int(body_part.x * image_w + 0.5), int(body_part.y * image_h + 0.5)) 60 | center_tmp[i] = [center,human.body_parts[i].score] 61 | centers[n] = center_tmp 62 | # print(center) 63 | return centers 64 | 65 | def PoseEstimatorPredict(image_path,plot = False,resolution ='432x368', scales = '[None]',model = 'mobilenet_thin'): 66 | ''' 67 | input: 68 |
image_path, path to the image (jpg) 69 | plot = False, whether to draw the result; if True, returns two items: keypoint information + annotated image matrix 70 | resolution ='432x368', input resolution 71 | scales = '[None]', 72 | model = 'mobilenet_thin', model selection 73 | 74 | output: 75 | if plot is False, returns one item: the keypoint information 76 | if plot is True, returns two items: the keypoint information + the annotated image matrix 77 | ''' 78 | w, h = model_wh(resolution) 79 | e = TfPoseEstimator(get_graph_path(model), target_size=(w, h)) 80 | image = common.read_imgfile(image_path, None, None) 81 | t = time.time() 82 | humans = e.inference(image, scales=scales) # main prediction call 83 | elapsed = time.time() - t 84 | 85 | logger.info('inference image: %s in %.4f seconds.' % (image_path, elapsed)) 86 | centers = get_keypoint(image,humans) # collect the keypoint information 87 | 88 | if plot: 89 | # with drawing: 90 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) # drawing function 91 | fig = plt.figure() 92 | a = fig.add_subplot(2, 2, 1) 93 | a.set_title('Result') 94 | plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) 95 | return centers,image 96 | else: 97 | # without drawing: 98 | return centers 99 | 100 | if __name__ == '__main__': 101 | 102 | # logger setup 103 | logger = logging.getLogger('TfPoseEstimator') 104 | logger.setLevel(logging.DEBUG) 105 | ch = logging.StreamHandler() 106 | ch.setLevel(logging.DEBUG) 107 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s') 108 | ch.setFormatter(formatter) 109 | logger.addHandler(ch) 110 | 111 | # prediction 112 | image_path = '/data1/research/matt/docker/tf_pose_estimation/tf-openpose/images/hand1.jpg' 113 | centers,images_p = PoseEstimatorPredict(image_path,plot = True,model = 'mobilenet_thin') # prefix with %time in IPython to measure runtime 114 | ''' 115 | # Two models are currently supported: mobilenet_thin and cmu 116 | # There are six model types in total; see /src/networks.py 117 | ''' 118 | 119 | # save the result 120 | misc.imsave('/images/000000000569_1.jpg', images_p) 121 | 122 | 123 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argparse 2 | matplotlib 3 | scipy 4 | tqdm 5 | requests 6 | fire 7 | ast 8 | dill 9 | git+https://github.com/ppwwyyxx/tensorpack.git -------------------------------------------------------------------------------- /scripts/broadcaster_ros.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import time 3 | import os 4 | import sys 5 | import ast 6 | 7 | from threading import Lock 8 | import rospy 9 | import rospkg 10 | from cv_bridge import CvBridge, CvBridgeError 11 | from std_msgs.msg import String 12 | from sensor_msgs.msg import Image 13 | from tfpose_ros.msg import Persons, Person, BodyPartElm 14 | 15 | from estimator import TfPoseEstimator 16 | from networks import model_wh, get_graph_path 17 | 18 | 19 | scales = [None] 20 | 21 | 22 | def humans_to_msg(humans): 23 | persons = Persons() 24 | 25 | for human in humans: 26 | person = Person() 27 | 28 | for k in human.body_parts: 29 | body_part = human.body_parts[k] 30 | 31 | body_part_msg = BodyPartElm() 32 | body_part_msg.part_id = body_part.part_idx 33 | body_part_msg.x = body_part.x 34 | body_part_msg.y = body_part.y 35 | body_part_msg.confidence = body_part.score 36 | person.body_part.append(body_part_msg) 37 | persons.persons.append(person) 38 | 39 | return persons 40 | 41 | 42 | def callback_image(data): 43 | # et = time.time() 44 | try: 45 | cv_image = cv_bridge.imgmsg_to_cv2(data, "bgr8") 46 | except CvBridgeError as e: 47 | rospy.logerr('[tf-pose-estimation] Converting Image Error. 
' + str(e)) 48 | return 49 | 50 | acquired = tf_lock.acquire(False) 51 | if not acquired: 52 | return 53 | 54 | try: 55 | global scales 56 | humans = pose_estimator.inference(cv_image, scales) 57 | finally: 58 | tf_lock.release() 59 | 60 | msg = humans_to_msg(humans) 61 | msg.image_w = data.width 62 | msg.image_h = data.height 63 | msg.header = data.header 64 | 65 | pub_pose.publish(msg) 66 | # rospy.loginfo(time.time() - et) 67 | 68 | 69 | def callback_scales(data): 70 | global scales 71 | scales = ast.literal_eval(data.data) 72 | rospy.logdebug('[tf-pose-estimation] scale changed: ' + str(scales)) 73 | 74 | 75 | if __name__ == '__main__': 76 | global scales 77 | 78 | rospy.loginfo('initialization+') 79 | rospy.init_node('TfPoseEstimatorROS', anonymous=True, log_level=rospy.INFO) 80 | 81 | # parameters 82 | image_topic = rospy.get_param('~camera', '') 83 | model = rospy.get_param('~model', 'cmu') 84 | resolution = rospy.get_param('~resolution', '432x368') 85 | scales_str = rospy.get_param('~scales', '[None]') 86 | scales = ast.literal_eval(scales_str) 87 | tf_lock = Lock() 88 | 89 | rospy.loginfo('[TfPoseEstimatorROS] scales(%d)=%s' % (len(scales), str(scales))) 90 | 91 | if not image_topic: 92 | rospy.logerr('Parameter \'camera\' is not provided.') 93 | sys.exit(-1) 94 | 95 | try: 96 | w, h = model_wh(resolution) 97 | graph_path = get_graph_path(model) 98 | 99 | rospack = rospkg.RosPack() 100 | graph_path = os.path.join(rospack.get_path('tfpose_ros'), graph_path) 101 | except Exception as e: 102 | rospy.logerr('invalid model: %s, e=%s' % (model, e)) 103 | sys.exit(-1) 104 | 105 | pose_estimator = TfPoseEstimator(graph_path, target_size=(w, h)) 106 | cv_bridge = CvBridge() 107 | 108 | rospy.Subscriber(image_topic, Image, callback_image, queue_size=1, buff_size=2**24) 109 | rospy.Subscriber('~scales', String, callback_scales, queue_size=1) 110 | pub_pose = rospy.Publisher('~pose', Persons, queue_size=1) 111 | 112 | rospy.loginfo('start+') 113 | rospy.spin() 114 | rospy.loginfo('finished') 115 | -------------------------------------------------------------------------------- /scripts/visualization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import time 3 | import cv2 4 | import rospy 5 | from sensor_msgs.msg import Image 6 | from cv_bridge import CvBridge, CvBridgeError 7 | 8 | from tfpose_ros.msg import Persons, Person, BodyPartElm 9 | from estimator import Human, BodyPart, TfPoseEstimator 10 | 11 | 12 | class VideoFrames: 13 | """ 14 | Reference : ros-video-recorder 15 | https://github.com/ildoonet/ros-video-recorder/blob/master/scripts/recorder.py 16 | """ 17 | def __init__(self, image_topic): 18 | self.image_sub = rospy.Subscriber(image_topic, Image, self.callback_image, queue_size=1) 19 | self.bridge = CvBridge() 20 | self.frames = [] 21 | 22 | def callback_image(self, data): 23 | try: 24 | cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8") 25 | except CvBridgeError as e: 26 | rospy.logerr('Converting Image Error. 
' + str(e)) 27 | return 28 | 29 | self.frames.append((data.header.stamp, cv_image)) 30 | 31 | def get_latest(self, at_time, remove_older=True): 32 | fs = [x for x in self.frames if x[0] <= at_time] 33 | if len(fs) == 0: 34 | return None 35 | 36 | f = fs[-1] 37 | if remove_older: 38 | self.frames = self.frames[len(fs) - 1:] 39 | 40 | return f[1] 41 | 42 | 43 | def cb_pose(data): 44 | # get image with pose time 45 | t = data.header.stamp 46 | image = vf.get_latest(t, remove_older=True) 47 | if image is None: 48 | rospy.logwarn('No received images.') 49 | return 50 | 51 | h, w = image.shape[:2] 52 | if resize_ratio > 0: 53 | image = cv2.resize(image, (int(resize_ratio*w), int(resize_ratio*h)), interpolation=cv2.INTER_LINEAR) 54 | 55 | # ros topic to Person instance 56 | humans = [] 57 | for p_idx, person in enumerate(data.persons): 58 | human = Human([]) 59 | for body_part in person.body_part: 60 | part = BodyPart('', body_part.part_id, body_part.x, body_part.y, body_part.confidence) 61 | human.body_parts[body_part.part_id] = part 62 | 63 | humans.append(human) 64 | 65 | # draw 66 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) 67 | pub_img.publish(cv_bridge.cv2_to_imgmsg(image, 'bgr8')) 68 | 69 | 70 | if __name__ == '__main__': 71 | rospy.loginfo('initialization+') 72 | rospy.init_node('TfPoseEstimatorROS-Visualization', anonymous=True) 73 | 74 | # topics params 75 | image_topic = rospy.get_param('~camera', '') 76 | pose_topic = rospy.get_param('~pose', '/pose_estimator/pose') 77 | 78 | resize_ratio = float(rospy.get_param('~resize_ratio', '-1')) 79 | 80 | # publishers 81 | pub_img = rospy.Publisher('~output', Image, queue_size=1) 82 | 83 | # initialization 84 | cv_bridge = CvBridge() 85 | vf = VideoFrames(image_topic) 86 | rospy.wait_for_message(image_topic, Image, timeout=30) 87 | 88 | # subscribers 89 | rospy.Subscriber(pose_topic, Persons, cb_pose, queue_size=1) 90 | 91 | # run 92 | rospy.spin() 93 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/src/__init__.py -------------------------------------------------------------------------------- /src/common.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import tensorflow as tf 4 | import cv2 5 | 6 | 7 | regularizer_conv = 0.004 8 | regularizer_dsconv = 0.0004 9 | batchnorm_fused = True 10 | activation_fn = tf.nn.relu 11 | 12 | 13 | class CocoPart(Enum): 14 | Nose = 0 15 | Neck = 1 16 | RShoulder = 2 17 | RElbow = 3 18 | RWrist = 4 19 | LShoulder = 5 20 | LElbow = 6 21 | LWrist = 7 22 | RHip = 8 23 | RKnee = 9 24 | RAnkle = 10 25 | LHip = 11 26 | LKnee = 12 27 | LAnkle = 13 28 | REye = 14 29 | LEye = 15 30 | REar = 16 31 | LEar = 17 32 | Background = 18 33 | 34 | 35 | class MPIIPart(Enum): 36 | RAnkle = 0 37 | RKnee = 1 38 | RHip = 2 39 | LHip = 3 40 | LKnee = 4 41 | LAnkle = 5 42 | RWrist = 6 43 | RElbow = 7 44 | RShoulder = 8 45 | LShoulder = 9 46 | LElbow = 10 47 | LWrist = 11 48 | Neck = 12 49 | Head = 13 50 | 51 | @staticmethod 52 | def from_coco(human): 53 | # t = { 54 | # MPIIPart.RAnkle: CocoPart.RAnkle, 55 | # MPIIPart.RKnee: CocoPart.RKnee, 56 | # MPIIPart.RHip: CocoPart.RHip, 57 | # MPIIPart.LHip: CocoPart.LHip, 58 | # MPIIPart.LKnee: CocoPart.LKnee, 59 | # MPIIPart.LAnkle: CocoPart.LAnkle, 60 | # 
MPIIPart.RWrist: CocoPart.RWrist, 61 | # MPIIPart.RElbow: CocoPart.RElbow, 62 | # MPIIPart.RShoulder: CocoPart.RShoulder, 63 | # MPIIPart.LShoulder: CocoPart.LShoulder, 64 | # MPIIPart.LElbow: CocoPart.LElbow, 65 | # MPIIPart.LWrist: CocoPart.LWrist, 66 | # MPIIPart.Neck: CocoPart.Neck, 67 | # MPIIPart.Nose: CocoPart.Nose, 68 | # } 69 | 70 | t = [ 71 | (MPIIPart.Head, CocoPart.Nose), 72 | (MPIIPart.Neck, CocoPart.Neck), 73 | (MPIIPart.RShoulder, CocoPart.RShoulder), 74 | (MPIIPart.RElbow, CocoPart.RElbow), 75 | (MPIIPart.RWrist, CocoPart.RWrist), 76 | (MPIIPart.LShoulder, CocoPart.LShoulder), 77 | (MPIIPart.LElbow, CocoPart.LElbow), 78 | (MPIIPart.LWrist, CocoPart.LWrist), 79 | (MPIIPart.RHip, CocoPart.RHip), 80 | (MPIIPart.RKnee, CocoPart.RKnee), 81 | (MPIIPart.RAnkle, CocoPart.RAnkle), 82 | (MPIIPart.LHip, CocoPart.LHip), 83 | (MPIIPart.LKnee, CocoPart.LKnee), 84 | (MPIIPart.LAnkle, CocoPart.LAnkle), 85 | ] 86 | 87 | pose_2d_mpii = [] 88 | visibilty = [] 89 | for mpi, coco in t: 90 | if coco.value not in human.body_parts.keys(): 91 | pose_2d_mpii.append((0, 0)) 92 | visibilty.append(False) 93 | continue 94 | pose_2d_mpii.append((human.body_parts[coco.value].x, human.body_parts[coco.value].y)) 95 | visibilty.append(True) 96 | return pose_2d_mpii, visibilty 97 | 98 | CocoPairs = [ 99 | (1, 2), (1, 5), (2, 3), (3, 4), (5, 6), (6, 7), (1, 8), (8, 9), (9, 10), (1, 11), 100 | (11, 12), (12, 13), (1, 0), (0, 14), (14, 16), (0, 15), (15, 17), (2, 16), (5, 17) 101 | ] # = 19 102 | CocoPairsRender = CocoPairs[:-2] 103 | CocoPairsNetwork = [ 104 | (12, 13), (20, 21), (14, 15), (16, 17), (22, 23), (24, 25), (0, 1), (2, 3), (4, 5), 105 | (6, 7), (8, 9), (10, 11), (28, 29), (30, 31), (34, 35), (32, 33), (36, 37), (18, 19), (26, 27) 106 | ] # = 19 107 | 108 | CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 109 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 110 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 111 | 112 | 113 | def read_imgfile(path, width, height): 114 | val_image = cv2.imread(path, cv2.IMREAD_COLOR) 115 | if width is not None and height is not None: 116 | val_image = cv2.resize(val_image, (width, height)) 117 | return val_image 118 | 119 | 120 | def get_sample_images(w, h): 121 | val_image = [ 122 | read_imgfile('./images/p1.jpg', w, h), 123 | read_imgfile('./images/p2.jpg', w, h), 124 | read_imgfile('./images/p3.jpg', w, h), 125 | read_imgfile('./images/golf.jpg', w, h), 126 | read_imgfile('./images/hand1.jpg', w, h), 127 | read_imgfile('./images/hand2.jpg', w, h), 128 | read_imgfile('./images/apink1_crop.jpg', w, h), 129 | read_imgfile('./images/ski.jpg', w, h), 130 | read_imgfile('./images/apink2.jpg', w, h), 131 | read_imgfile('./images/apink3.jpg', w, h), 132 | read_imgfile('./images/handsup1.jpg', w, h), 133 | read_imgfile('./images/p3_dance.png', w, h), 134 | ] 135 | return val_image 136 | -------------------------------------------------------------------------------- /src/datum_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: datum.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='datum.proto', 20 | package='', 21 | serialized_pb=_b('\n\x0b\x64\x61tum.proto\"\x81\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _DATUM = _descriptor.Descriptor( 29 | name='Datum', 30 | full_name='Datum', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='channels', full_name='Datum.channels', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='Datum.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='width', full_name='Datum.width', index=2, 51 | number=3, type=5, cpp_type=1, label=1, 52 | has_default_value=False, default_value=0, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='data', full_name='Datum.data', index=3, 58 | number=4, type=12, cpp_type=9, label=1, 59 | has_default_value=False, default_value=_b(""), 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='label', full_name='Datum.label', index=4, 65 | number=5, type=5, cpp_type=1, label=1, 66 | has_default_value=False, default_value=0, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | _descriptor.FieldDescriptor( 71 | name='float_data', full_name='Datum.float_data', index=5, 72 | number=6, type=2, cpp_type=6, label=3, 73 | has_default_value=False, default_value=[], 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None), 77 | _descriptor.FieldDescriptor( 78 | name='encoded', full_name='Datum.encoded', index=6, 79 | number=7, type=8, cpp_type=7, label=1, 80 | has_default_value=True, default_value=False, 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None), 84 | ], 85 | extensions=[ 86 | ], 87 | nested_types=[], 88 | enum_types=[ 89 | ], 90 | options=None, 91 | is_extendable=False, 92 | extension_ranges=[], 93 | 
oneofs=[ 94 | ], 95 | serialized_start=16, 96 | serialized_end=145, 97 | ) 98 | 99 | DESCRIPTOR.message_types_by_name['Datum'] = _DATUM 100 | 101 | Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict( 102 | DESCRIPTOR = _DATUM, 103 | __module__ = 'datum_pb2' 104 | # @@protoc_insertion_point(class_scope:Datum) 105 | )) 106 | _sym_db.RegisterMessage(Datum) 107 | 108 | 109 | # @@protoc_insertion_point(module_scope) 110 | -------------------------------------------------------------------------------- /src/lifting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/src/lifting/__init__.py -------------------------------------------------------------------------------- /src/lifting/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mar 23 11:57 2017 4 | 5 | @author: Denis Tome' 6 | """ 7 | 8 | __all__ = [ 9 | 'VISIBLE_PART', 10 | 'MIN_NUM_JOINTS', 11 | 'CENTER_TR', 12 | 'SIGMA', 13 | 'STRIDE', 14 | 'SIGMA_CENTER', 15 | 'INPUT_SIZE', 16 | 'OUTPUT_SIZE', 17 | 'NUM_JOINTS', 18 | 'NUM_OUTPUT', 19 | 'H36M_NUM_JOINTS', 20 | 'JOINT_DRAW_SIZE', 21 | 'LIMB_DRAW_SIZE' 22 | ] 23 | 24 | # threshold 25 | VISIBLE_PART = 1e-3 26 | MIN_NUM_JOINTS = 5 27 | CENTER_TR = 0.4 28 | 29 | # net attributes 30 | SIGMA = 7 31 | STRIDE = 8 32 | SIGMA_CENTER = 21 33 | INPUT_SIZE = 368 34 | OUTPUT_SIZE = 46 35 | NUM_JOINTS = 14 36 | NUM_OUTPUT = NUM_JOINTS + 1 37 | H36M_NUM_JOINTS = 17 38 | 39 | # draw options 40 | JOINT_DRAW_SIZE = 3 41 | LIMB_DRAW_SIZE = 2 42 | NORMALISATION_COEFFICIENT = 1280*720 43 | -------------------------------------------------------------------------------- /src/lifting/draw.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mar 23 15:04 2017 4 | 5 | @author: Denis Tome' 6 | """ 7 | import cv2 8 | import numpy as np 9 | from .config import JOINT_DRAW_SIZE 10 | from .config import LIMB_DRAW_SIZE 11 | from .config import NORMALISATION_COEFFICIENT 12 | import matplotlib.pyplot as plt 13 | import math 14 | 15 | __all__ = [ 16 | 'draw_limbs', 17 | 'plot_pose' 18 | ] 19 | 20 | 21 | def draw_limbs(image, pose_2d, visible): 22 | """Draw the 2D pose without the occluded/not visible joints.""" 23 | 24 | _COLORS = [ 25 | [0, 0, 255], [0, 170, 255], [0, 255, 170], [0, 255, 0], 26 | [170, 255, 0], [255, 170, 0], [255, 0, 0], [255, 0, 170], 27 | [170, 0, 255] 28 | ] 29 | _LIMBS = np.array([0, 1, 2, 3, 3, 4, 5, 6, 6, 7, 8, 9, 30 | 9, 10, 11, 12, 12, 13]).reshape((-1, 2)) 31 | 32 | _NORMALISATION_FACTOR = int(math.floor(math.sqrt(image.shape[0] * image.shape[1] / NORMALISATION_COEFFICIENT))) 33 | 34 | for oid in range(pose_2d.shape[0]): 35 | for lid, (p0, p1) in enumerate(_LIMBS): 36 | if not (visible[oid][p0] and visible[oid][p1]): 37 | continue 38 | y0, x0 = pose_2d[oid][p0] 39 | y1, x1 = pose_2d[oid][p1] 40 | cv2.circle(image, (x0, y0), JOINT_DRAW_SIZE *_NORMALISATION_FACTOR , _COLORS[lid], -1) 41 | cv2.circle(image, (x1, y1), JOINT_DRAW_SIZE*_NORMALISATION_FACTOR , _COLORS[lid], -1) 42 | cv2.line(image, (x0, y0), (x1, y1), 43 | _COLORS[lid], LIMB_DRAW_SIZE*_NORMALISATION_FACTOR , 16) 44 | 45 | 46 | def plot_pose(pose): 47 | """Plot the 3D pose showing the joint connections.""" 48 | import mpl_toolkits.mplot3d.axes3d as p3 49 | 50 | _CONNECTION = [ 51 | 
[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7], [7, 8], 52 | [8, 9], [9, 10], [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], 53 | [15, 16]] 54 | 55 | def joint_color(j): 56 | """ 57 | TODO: 'j' shadows name 'j' from outer scope 58 | """ 59 | 60 | colors = [(0, 0, 0), (255, 0, 255), (0, 0, 255), 61 | (0, 255, 255), (255, 0, 0), (0, 255, 0)] 62 | _c = 0 63 | if j in range(1, 4): 64 | _c = 1 65 | if j in range(4, 7): 66 | _c = 2 67 | if j in range(9, 11): 68 | _c = 3 69 | if j in range(11, 14): 70 | _c = 4 71 | if j in range(14, 17): 72 | _c = 5 73 | return colors[_c] 74 | 75 | assert (pose.ndim == 2) 76 | assert (pose.shape[0] == 3) 77 | fig = plt.figure() 78 | ax = fig.gca(projection='3d') 79 | for c in _CONNECTION: 80 | col = '#%02x%02x%02x' % joint_color(c[0]) 81 | ax.plot([pose[0, c[0]], pose[0, c[1]]], 82 | [pose[1, c[0]], pose[1, c[1]]], 83 | [pose[2, c[0]], pose[2, c[1]]], c=col) 84 | for j in range(pose.shape[1]): 85 | col = '#%02x%02x%02x' % joint_color(j) 86 | ax.scatter(pose[0, j], pose[1, j], pose[2, j], 87 | c=col, marker='o', edgecolor=col) 88 | smallest = pose.min() 89 | largest = pose.max() 90 | ax.set_xlim3d(smallest, largest) 91 | ax.set_ylim3d(smallest, largest) 92 | ax.set_zlim3d(smallest, largest) 93 | 94 | return fig 95 | 96 | 97 | -------------------------------------------------------------------------------- /src/lifting/models/prob_model_params.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/src/lifting/models/prob_model_params.mat -------------------------------------------------------------------------------- /src/network_cmu.py: -------------------------------------------------------------------------------- 1 | import network_base 2 | import tensorflow as tf 3 | 4 | 5 | class CmuNetwork(network_base.BaseNetwork): 6 | def setup(self): 7 | (self.feed('image') 8 | .normalize_vgg(name='preprocess') 9 | .conv(3, 3, 64, 1, 1, name='conv1_1') 10 | .conv(3, 3, 64, 1, 1, name='conv1_2') 11 | .max_pool(2, 2, 2, 2, name='pool1_stage1') 12 | .conv(3, 3, 128, 1, 1, name='conv2_1') 13 | .conv(3, 3, 128, 1, 1, name='conv2_2') 14 | .max_pool(2, 2, 2, 2, name='pool2_stage1') 15 | .conv(3, 3, 256, 1, 1, name='conv3_1') 16 | .conv(3, 3, 256, 1, 1, name='conv3_2') 17 | .conv(3, 3, 256, 1, 1, name='conv3_3') 18 | .conv(3, 3, 256, 1, 1, name='conv3_4') 19 | .max_pool(2, 2, 2, 2, name='pool3_stage1') 20 | .conv(3, 3, 512, 1, 1, name='conv4_1') 21 | .conv(3, 3, 512, 1, 1, name='conv4_2') 22 | .conv(3, 3, 256, 1, 1, name='conv4_3_CPM') 23 | .conv(3, 3, 128, 1, 1, name='conv4_4_CPM') # ***** 24 | 25 | .conv(3, 3, 128, 1, 1, name='conv5_1_CPM_L1') 26 | .conv(3, 3, 128, 1, 1, name='conv5_2_CPM_L1') 27 | .conv(3, 3, 128, 1, 1, name='conv5_3_CPM_L1') 28 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L1') 29 | .conv(1, 1, 38, 1, 1, relu=False, name='conv5_5_CPM_L1')) 30 | 31 | (self.feed('conv4_4_CPM') 32 | .conv(3, 3, 128, 1, 1, name='conv5_1_CPM_L2') 33 | .conv(3, 3, 128, 1, 1, name='conv5_2_CPM_L2') 34 | .conv(3, 3, 128, 1, 1, name='conv5_3_CPM_L2') 35 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L2') 36 | .conv(1, 1, 19, 1, 1, relu=False, name='conv5_5_CPM_L2')) 37 | 38 | (self.feed('conv5_5_CPM_L1', 39 | 'conv5_5_CPM_L2', 40 | 'conv4_4_CPM') 41 | .concat(3, name='concat_stage2') 42 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage2_L1') 43 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage2_L1') 44 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage2_L1') 45 
| .conv(7, 7, 128, 1, 1, name='Mconv4_stage2_L1') 46 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage2_L1') 47 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L1') 48 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage2_L1')) 49 | 50 | (self.feed('concat_stage2') 51 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage2_L2') 52 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage2_L2') 53 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage2_L2') 54 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage2_L2') 55 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage2_L2') 56 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L2') 57 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage2_L2')) 58 | 59 | (self.feed('Mconv7_stage2_L1', 60 | 'Mconv7_stage2_L2', 61 | 'conv4_4_CPM') 62 | .concat(3, name='concat_stage3') 63 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage3_L1') 64 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage3_L1') 65 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage3_L1') 66 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage3_L1') 67 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage3_L1') 68 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L1') 69 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage3_L1')) 70 | 71 | (self.feed('concat_stage3') 72 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage3_L2') 73 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage3_L2') 74 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage3_L2') 75 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage3_L2') 76 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage3_L2') 77 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L2') 78 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage3_L2')) 79 | 80 | (self.feed('Mconv7_stage3_L1', 81 | 'Mconv7_stage3_L2', 82 | 'conv4_4_CPM') 83 | .concat(3, name='concat_stage4') 84 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage4_L1') 85 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage4_L1') 86 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage4_L1') 87 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage4_L1') 88 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage4_L1') 89 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L1') 90 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage4_L1')) 91 | 92 | (self.feed('concat_stage4') 93 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage4_L2') 94 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage4_L2') 95 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage4_L2') 96 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage4_L2') 97 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage4_L2') 98 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L2') 99 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage4_L2')) 100 | 101 | (self.feed('Mconv7_stage4_L1', 102 | 'Mconv7_stage4_L2', 103 | 'conv4_4_CPM') 104 | .concat(3, name='concat_stage5') 105 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage5_L1') 106 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage5_L1') 107 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage5_L1') 108 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage5_L1') 109 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage5_L1') 110 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L1') 111 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage5_L1')) 112 | 113 | (self.feed('concat_stage5') 114 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage5_L2') 115 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage5_L2') 116 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage5_L2') 117 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage5_L2') 118 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage5_L2') 119 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L2') 120 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage5_L2')) 121 | 122 | (self.feed('Mconv7_stage5_L1', 123 | 'Mconv7_stage5_L2', 124 | 'conv4_4_CPM') 125 | .concat(3, 
name='concat_stage6') 126 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage6_L1') 127 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage6_L1') 128 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage6_L1') 129 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage6_L1') 130 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage6_L1') 131 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L1') 132 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage6_L1')) 133 | 134 | (self.feed('concat_stage6') 135 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage6_L2') 136 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage6_L2') 137 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage6_L2') 138 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage6_L2') 139 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage6_L2') 140 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L2') 141 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage6_L2')) 142 | 143 | with tf.variable_scope('Openpose'): 144 | (self.feed('Mconv7_stage6_L2', 145 | 'Mconv7_stage6_L1') 146 | .concat(3, name='concat_stage7')) 147 | 148 | def loss_l1_l2(self): 149 | l1s = [] 150 | l2s = [] 151 | for layer_name in self.layers.keys(): 152 | if 'Mconv7' in layer_name and '_L1' in layer_name: 153 | l1s.append(self.layers[layer_name]) 154 | if 'Mconv7' in layer_name and '_L2' in layer_name: 155 | l2s.append(self.layers[layer_name]) 156 | 157 | return l1s, l2s 158 | 159 | def loss_last(self): 160 | return self.get_output('Mconv7_stage6_L1'), self.get_output('Mconv7_stage6_L2') 161 | 162 | def restorable_variables(self): 163 | return None -------------------------------------------------------------------------------- /src/network_dsconv.py: -------------------------------------------------------------------------------- 1 | import network_base 2 | 3 | 4 | class DSConvNetwork(network_base.BaseNetwork): 5 | def __init__(self, inputs, trainable=True, conv_width=1.0): 6 | self.conv_width = conv_width 7 | network_base.BaseNetwork.__init__(self, inputs, trainable) 8 | 9 | def setup(self): 10 | (self.feed('image') 11 | .conv(3, 3, 64, 1, 1, name='conv1_1', trainable=False) 12 | # .conv(3, 3, 64, 1, 1, name='conv1_2', trainable=True) # TODO 13 | .separable_conv(3, 3, round(self.conv_width * 64), 2, name='conv1_2') 14 | # .max_pool(2, 2, 2, 2, name='pool1_stage1') 15 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv2_1') 16 | .separable_conv(3, 3, round(self.conv_width * 128), 2, name='conv2_2') 17 | # .max_pool(2, 2, 2, 2, name='pool2_stage1') 18 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_1') 19 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_2') 20 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_3') 21 | .separable_conv(3, 3, round(self.conv_width * 256), 2, name='conv3_4') 22 | # .max_pool(2, 2, 2, 2, name='pool3_stage1') 23 | .separable_conv(3, 3, round(self.conv_width * 512), 1, name='conv4_1') 24 | .separable_conv(3, 3, round(self.conv_width * 512), 1, name='conv4_2') 25 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv4_3_CPM') 26 | .separable_conv(3, 3, 128, 1, name='conv4_4_CPM') 27 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_1_CPM_L1') 28 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_2_CPM_L1') 29 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_3_CPM_L1') 30 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L1') 31 | .conv(1, 1, 38, 1, 1, relu=False, name='conv5_5_CPM_L1')) 32 | 33 | (self.feed('conv4_4_CPM') 34 | .separable_conv(3, 3, round(self.conv_width * 128), 1, 
name='conv5_1_CPM_L2') 35 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_2_CPM_L2') 36 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_3_CPM_L2') 37 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L2') 38 | .conv(1, 1, 19, 1, 1, relu=False, name='conv5_5_CPM_L2')) 39 | 40 | (self.feed('conv5_5_CPM_L1', 41 | 'conv5_5_CPM_L2', 42 | 'conv4_4_CPM') 43 | .concat(3, name='concat_stage2') 44 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage2_L1') 45 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage2_L1') 46 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage2_L1') 47 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage2_L1') 48 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage2_L1') 49 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L1') 50 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage2_L1')) 51 | 52 | (self.feed('concat_stage2') 53 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage2_L2') 54 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage2_L2') 55 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage2_L2') 56 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage2_L2') 57 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage2_L2') 58 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L2') 59 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage2_L2')) 60 | 61 | (self.feed('Mconv7_stage2_L1', 62 | 'Mconv7_stage2_L2', 63 | 'conv4_4_CPM') 64 | .concat(3, name='concat_stage3') 65 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage3_L1') 66 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage3_L1') 67 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage3_L1') 68 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage3_L1') 69 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage3_L1') 70 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L1') 71 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage3_L1')) 72 | 73 | (self.feed('concat_stage3') 74 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage3_L2') 75 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage3_L2') 76 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage3_L2') 77 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage3_L2') 78 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage3_L2') 79 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L2') 80 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage3_L2')) 81 | 82 | (self.feed('Mconv7_stage3_L1', 83 | 'Mconv7_stage3_L2', 84 | 'conv4_4_CPM') 85 | .concat(3, name='concat_stage4') 86 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage4_L1') 87 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage4_L1') 88 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage4_L1') 89 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage4_L1') 90 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage4_L1') 91 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L1') 92 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage4_L1')) 93 | 94 | (self.feed('concat_stage4') 95 | .separable_conv(7, 7, round(self.conv_width * 
128), 1, name='Mconv1_stage4_L2') 96 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage4_L2') 97 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage4_L2') 98 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage4_L2') 99 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage4_L2') 100 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L2') 101 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage4_L2')) 102 | 103 | (self.feed('Mconv7_stage4_L1', 104 | 'Mconv7_stage4_L2', 105 | 'conv4_4_CPM') 106 | .concat(3, name='concat_stage5') 107 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage5_L1') 108 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage5_L1') 109 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage5_L1') 110 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage5_L1') 111 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage5_L1') 112 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L1') 113 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage5_L1')) 114 | 115 | (self.feed('concat_stage5') 116 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage5_L2') 117 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage5_L2') 118 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage5_L2') 119 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage5_L2') 120 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage5_L2') 121 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L2') 122 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage5_L2')) 123 | 124 | (self.feed('Mconv7_stage5_L1', 125 | 'Mconv7_stage5_L2', 126 | 'conv4_4_CPM') 127 | .concat(3, name='concat_stage6') 128 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage6_L1') 129 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage6_L1') 130 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage6_L1') 131 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage6_L1') 132 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage6_L1') 133 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L1') 134 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage6_L1')) 135 | 136 | (self.feed('concat_stage6') 137 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage6_L2') 138 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage6_L2') 139 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage6_L2') 140 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage6_L2') 141 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage6_L2') 142 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L2') 143 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage6_L2')) 144 | 145 | (self.feed('Mconv7_stage6_L2', 146 | 'Mconv7_stage6_L1') 147 | .concat(3, name='concat_stage7')) 148 | -------------------------------------------------------------------------------- /src/network_mobilenet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import network_base 4 | 5 | 6 | class MobilenetNetwork(network_base.BaseNetwork): 7 | def __init__(self, inputs, trainable=True, conv_width=1.0, conv_width2=None): 8 | 
self.conv_width = conv_width 9 | self.conv_width2 = conv_width2 if conv_width2 else conv_width 10 | self.num_refine = 4 11 | network_base.BaseNetwork.__init__(self, inputs, trainable) 12 | 13 | def setup(self): 14 | min_depth = 8 15 | depth = lambda d: max(int(d * self.conv_width), min_depth) 16 | depth2 = lambda d: max(int(d * self.conv_width2), min_depth) 17 | 18 | with tf.variable_scope(None, 'MobilenetV1'): 19 | (self.feed('image') 20 | .convb(3, 3, depth(32), 2, name='Conv2d_0') 21 | .separable_conv(3, 3, depth(64), 1, name='Conv2d_1') 22 | .separable_conv(3, 3, depth(128), 2, name='Conv2d_2') 23 | .separable_conv(3, 3, depth(128), 1, name='Conv2d_3') 24 | .separable_conv(3, 3, depth(256), 2, name='Conv2d_4') 25 | .separable_conv(3, 3, depth(256), 1, name='Conv2d_5') 26 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_6') 27 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_7') 28 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_8') 29 | # .separable_conv(3, 3, depth(512), 1, name='Conv2d_9') 30 | # .separable_conv(3, 3, depth(512), 1, name='Conv2d_10') 31 | # .separable_conv(3, 3, depth(512), 1, name='Conv2d_11') 32 | # .separable_conv(3, 3, depth(1024), 2, name='Conv2d_12') 33 | # .separable_conv(3, 3, depth(1024), 1, name='Conv2d_13') 34 | ) 35 | 36 | (self.feed('Conv2d_1').max_pool(2, 2, 2, 2, name='Conv2d_1_pool')) 37 | (self.feed('Conv2d_7').upsample(2, name='Conv2d_7_upsample')) 38 | 39 | (self.feed('Conv2d_1_pool', 'Conv2d_3', 'Conv2d_7_upsample') 40 | .concat(3, name='feat_concat')) 41 | 42 | feature_lv = 'feat_concat' 43 | with tf.variable_scope(None, 'Openpose'): 44 | prefix = 'MConv_Stage1' 45 | (self.feed(feature_lv) 46 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_1') 47 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_2') 48 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_3') 49 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L1_4') 50 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5')) 51 | 52 | (self.feed(feature_lv) 53 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_1') 54 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_2') 55 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_3') 56 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L2_4') 57 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5')) 58 | 59 | for stage_id in range(self.num_refine): 60 | prefix_prev = 'MConv_Stage%d' % (stage_id + 1) 61 | prefix = 'MConv_Stage%d' % (stage_id + 2) 62 | (self.feed(prefix_prev + '_L1_5', 63 | prefix_prev + '_L2_5', 64 | feature_lv) 65 | .concat(3, name=prefix + '_concat') 66 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L1_1') 67 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L1_2') 68 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L1_3') 69 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L1_4') 70 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5')) 71 | 72 | (self.feed(prefix + '_concat') 73 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L2_1') 74 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L2_2') 75 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L2_3') 76 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L2_4') 77 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5')) 78 | 79 | # final result 80 | (self.feed('MConv_Stage%d_L2_5' % self.get_refine_num(), 81 | 'MConv_Stage%d_L1_5' % self.get_refine_num()) 82 | .concat(3, name='concat_stage7')) 83 | 84 | def 
loss_l1_l2(self): 85 | l1s = [] 86 | l2s = [] 87 | for layer_name in sorted(self.layers.keys()): 88 | if '_L1_5' in layer_name: 89 | l1s.append(self.layers[layer_name]) 90 | if '_L2_5' in layer_name: 91 | l2s.append(self.layers[layer_name]) 92 | 93 | return l1s, l2s 94 | 95 | def loss_last(self): 96 | return self.get_output('MConv_Stage%d_L1_5' % self.get_refine_num()), \ 97 | self.get_output('MConv_Stage%d_L2_5' % self.get_refine_num()) 98 | 99 | def restorable_variables(self): 100 | vs = {v.op.name: v for v in tf.global_variables() if 101 | 'MobilenetV1/Conv2d' in v.op.name and 102 | 'RMSProp' not in v.op.name and 'Momentum' not in v.op.name and 'Ada' not in v.op.name 103 | } 104 | return vs 105 | 106 | def get_refine_num(self): 107 | return self.num_refine + 1 108 | -------------------------------------------------------------------------------- /src/network_mobilenet_thin.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import network_base 4 | 5 | 6 | class MobilenetNetworkThin(network_base.BaseNetwork): 7 | def __init__(self, inputs, trainable=True, conv_width=1.0, conv_width2=None): 8 | self.conv_width = conv_width 9 | self.conv_width2 = conv_width2 if conv_width2 else conv_width 10 | network_base.BaseNetwork.__init__(self, inputs, trainable) 11 | 12 | def setup(self): 13 | min_depth = 8 14 | depth = lambda d: max(int(d * self.conv_width), min_depth) 15 | depth2 = lambda d: max(int(d * self.conv_width2), min_depth) 16 | 17 | with tf.variable_scope(None, 'MobilenetV1'): 18 | (self.feed('image') 19 | .convb(3, 3, depth(32), 2, name='Conv2d_0') 20 | .separable_conv(3, 3, depth(64), 1, name='Conv2d_1') 21 | .separable_conv(3, 3, depth(128), 2, name='Conv2d_2') 22 | .separable_conv(3, 3, depth(128), 1, name='Conv2d_3') 23 | .separable_conv(3, 3, depth(256), 2, name='Conv2d_4') 24 | .separable_conv(3, 3, depth(256), 1, name='Conv2d_5') 25 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_6') 26 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_7') 27 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_8') 28 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_9') 29 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_10') 30 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_11') 31 | # .separable_conv(3, 3, depth(1024), 2, name='Conv2d_12') 32 | # .separable_conv(3, 3, depth(1024), 1, name='Conv2d_13') 33 | ) 34 | 35 | (self.feed('Conv2d_3').max_pool(2, 2, 2, 2, name='Conv2d_3_pool')) 36 | 37 | (self.feed('Conv2d_3_pool', 'Conv2d_7', 'Conv2d_11') 38 | .concat(3, name='feat_concat')) 39 | 40 | feature_lv = 'feat_concat' 41 | with tf.variable_scope(None, 'Openpose'): 42 | prefix = 'MConv_Stage1' 43 | (self.feed(feature_lv) 44 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_1') 45 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_2') 46 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_3') 47 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L1_4') 48 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5')) 49 | 50 | (self.feed(feature_lv) 51 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_1') 52 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_2') 53 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_3') 54 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L2_4') 55 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5')) 56 | 57 | for stage_id in range(5): 58 | prefix_prev = 'MConv_Stage%d' % (stage_id + 1) 59 | prefix = 
'MConv_Stage%d' % (stage_id + 2) 60 | (self.feed(prefix_prev + '_L1_5', 61 | prefix_prev + '_L2_5', 62 | feature_lv) 63 | .concat(3, name=prefix + '_concat') 64 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_1') 65 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_2') 66 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_3') 67 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L1_4') 68 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5')) 69 | 70 | (self.feed(prefix + '_concat') 71 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_1') 72 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_2') 73 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_3') 74 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L2_4') 75 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5')) 76 | 77 | # final result 78 | (self.feed('MConv_Stage6_L2_5', 79 | 'MConv_Stage6_L1_5') 80 | .concat(3, name='concat_stage7')) 81 | 82 | def loss_l1_l2(self): 83 | l1s = [] 84 | l2s = [] 85 | for layer_name in sorted(self.layers.keys()): 86 | if '_L1_5' in layer_name: 87 | l1s.append(self.layers[layer_name]) 88 | if '_L2_5' in layer_name: 89 | l2s.append(self.layers[layer_name]) 90 | 91 | return l1s, l2s 92 | 93 | def loss_last(self): 94 | return self.get_output('MConv_Stage6_L1_5'), self.get_output('MConv_Stage6_L2_5') 95 | 96 | def restorable_variables(self): 97 | vs = {v.op.name: v for v in tf.global_variables() if 98 | 'MobilenetV1/Conv2d' in v.op.name and 99 | # 'global_step' not in v.op.name and 100 | # 'beta1_power' not in v.op.name and 'beta2_power' not in v.op.name and 101 | 'RMSProp' not in v.op.name and 'Momentum' not in v.op.name and 102 | 'Ada' not in v.op.name and 'Adam' not in v.op.name 103 | } 104 | return vs 105 | -------------------------------------------------------------------------------- /src/networks.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | from network_mobilenet import MobilenetNetwork 5 | from network_mobilenet_thin import MobilenetNetworkThin 6 | 7 | from network_cmu import CmuNetwork 8 | 9 | 10 | def _get_base_path(): 11 | if not os.environ.get('OPENPOSE_MODEL', ''): 12 | return './models' 13 | return os.environ.get('OPENPOSE_MODEL') 14 | 15 | 16 | def get_network(type, placeholder_input, sess_for_load=None, trainable=True): 17 | if type == 'mobilenet': 18 | net = MobilenetNetwork({'image': placeholder_input}, conv_width=0.75, conv_width2=1.00, trainable=trainable) 19 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt' 20 | last_layer = 'MConv_Stage6_L{aux}_5' 21 | elif type == 'mobilenet_fast': 22 | net = MobilenetNetwork({'image': placeholder_input}, conv_width=0.5, conv_width2=0.5, trainable=trainable) 23 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt' 24 | last_layer = 'MConv_Stage6_L{aux}_5' 25 | elif type == 'mobilenet_accurate': 26 | net = MobilenetNetwork({'image': placeholder_input}, conv_width=1.00, conv_width2=1.00, trainable=trainable) 27 | pretrain_path = 'pretrained/mobilenet_v1_1.0_224_2017_06_14/mobilenet_v1_1.0_224.ckpt' 28 | last_layer = 'MConv_Stage6_L{aux}_5' 29 | 30 | elif type == 'mobilenet_thin': 31 | net = MobilenetNetworkThin({'image': placeholder_input}, conv_width=0.75, conv_width2=0.50, trainable=trainable) 32 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_1.0_224.ckpt' 33 | 
last_layer = 'MConv_Stage6_L{aux}_5' 34 | 35 | elif type == 'cmu': 36 | net = CmuNetwork({'image': placeholder_input}, trainable=trainable) 37 | pretrain_path = 'numpy/openpose_coco.npy' 38 | last_layer = 'Mconv7_stage6_L{aux}' 39 | elif type == 'vgg': 40 | net = CmuNetwork({'image': placeholder_input}, trainable=trainable) 41 | pretrain_path = 'numpy/openpose_vgg16.npy' 42 | last_layer = 'Mconv7_stage6_L{aux}' 43 | else: 44 | raise Exception('Invalid Mode.') 45 | 46 | pretrain_path_full = os.path.join(_get_base_path(), pretrain_path) 47 | if sess_for_load is not None: 48 | if type == 'cmu' or type == 'vgg': 49 | if not os.path.isfile(pretrain_path_full): 50 | raise Exception('Model file doesn\'t exist, path=%s' % pretrain_path_full) 51 | net.load(os.path.join(_get_base_path(), pretrain_path), sess_for_load) 52 | else: 53 | s = '%dx%d' % (placeholder_input.shape[2], placeholder_input.shape[1]) 54 | ckpts = { 55 | 'mobilenet': 'trained/mobilenet_%s/model-246038' % s, 56 | 'mobilenet_thin': 'trained/mobilenet_thin_%s/model-449003' % s, 57 | 'mobilenet_fast': 'trained/mobilenet_fast_%s/model-189000' % s, 58 | 'mobilenet_accurate': 'trained/mobilenet_accurate/model-170000' 59 | } 60 | ckpt_path = os.path.join(_get_base_path(), ckpts[type]) 61 | loader = tf.train.Saver() 62 | try: 63 | loader.restore(sess_for_load, ckpt_path) 64 | except Exception as e: 65 | raise Exception('Fail to load model files. \npath=%s\nerr=%s' % (ckpt_path, str(e))) 66 | 67 | return net, pretrain_path_full, last_layer 68 | 69 | 70 | def get_graph_path(model_name): 71 | dyn_graph_path = { 72 | 'cmu': './models/graph/cmu/graph_opt.pb', 73 | 'mobilenet_thin': './models/graph/mobilenet_thin/graph_opt.pb' 74 | } 75 | graph_path = dyn_graph_path[model_name] 76 | for path in (graph_path, os.path.join(os.path.dirname(os.path.abspath(__file__)), graph_path), os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', graph_path)): 77 | if not os.path.isfile(path): 78 | continue 79 | return path 80 | raise Exception('Graph file doesn\'t exist, path=%s' % graph_path) 81 | 82 | 83 | def model_wh(resolution_str): 84 | width, height = map(int, resolution_str.split('x')) 85 | if width % 16 != 0 or height % 16 != 0: 86 | raise Exception('Width and height should be multiples of 16. 
w=%d, h=%d' % (width, height)) 87 | return int(width), int(height) 88 | -------------------------------------------------------------------------------- /src/pose_augment.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import cv2 5 | import numpy as np 6 | from tensorpack.dataflow.imgaug.geometry import RotationAndCropValid 7 | 8 | from common import CocoPart 9 | 10 | _network_w = 368 11 | _network_h = 368 12 | _scale = 2 13 | 14 | 15 | def set_network_input_wh(w, h): 16 | global _network_w, _network_h 17 | _network_w, _network_h = w, h 18 | 19 | 20 | def set_network_scale(scale): 21 | global _scale 22 | _scale = scale 23 | 24 | 25 | def pose_random_scale(meta): 26 | scalew = random.uniform(0.8, 1.2) 27 | scaleh = random.uniform(0.8, 1.2) 28 | neww = int(meta.width * scalew) 29 | newh = int(meta.height * scaleh) 30 | dst = cv2.resize(meta.img, (neww, newh), interpolation=cv2.INTER_AREA) 31 | 32 | # adjust meta data 33 | adjust_joint_list = [] 34 | for joint in meta.joint_list: 35 | adjust_joint = [] 36 | for point in joint: 37 | if point[0] < -100 or point[1] < -100: 38 | adjust_joint.append((-1000, -1000)) 39 | continue 40 | # if point[0] <= 0 or point[1] <= 0 or int(point[0] * scalew + 0.5) > neww or int( 41 | # point[1] * scaleh + 0.5) > newh: 42 | # adjust_joint.append((-1, -1)) 43 | # continue 44 | adjust_joint.append((int(point[0] * scalew + 0.5), int(point[1] * scaleh + 0.5))) 45 | adjust_joint_list.append(adjust_joint) 46 | 47 | meta.joint_list = adjust_joint_list 48 | meta.width, meta.height = neww, newh 49 | meta.img = dst 50 | return meta 51 | 52 | 53 | def pose_resize_shortestedge_fixed(meta): 54 | ratio_w = _network_w / meta.width 55 | ratio_h = _network_h / meta.height 56 | ratio = max(ratio_w, ratio_h) 57 | return pose_resize_shortestedge(meta, int(min(meta.width * ratio + 0.5, meta.height * ratio + 0.5))) 58 | 59 | 60 | def pose_resize_shortestedge_random(meta): 61 | ratio_w = _network_w / meta.width 62 | ratio_h = _network_h / meta.height 63 | ratio = min(ratio_w, ratio_h) 64 | target_size = int(min(meta.width * ratio + 0.5, meta.height * ratio + 0.5)) 65 | target_size = int(target_size * random.uniform(0.95, 1.6)) 66 | # target_size = int(min(_network_w, _network_h) * random.uniform(0.7, 1.5)) 67 | return pose_resize_shortestedge(meta, target_size) 68 | 69 | 70 | def pose_resize_shortestedge(meta, target_size): 71 | global _network_w, _network_h 72 | img = meta.img 73 | 74 | # adjust image 75 | scale = target_size / min(meta.height, meta.width) 76 | if meta.height < meta.width: 77 | newh, neww = target_size, int(scale * meta.width + 0.5) 78 | else: 79 | newh, neww = int(scale * meta.height + 0.5), target_size 80 | 81 | dst = cv2.resize(img, (neww, newh), interpolation=cv2.INTER_AREA) 82 | 83 | pw = ph = 0 84 | if neww < _network_w or newh < _network_h: 85 | pw = max(0, (_network_w - neww) // 2) 86 | ph = max(0, (_network_h - newh) // 2) 87 | mw = (_network_w - neww) % 2 88 | mh = (_network_h - newh) % 2 89 | color = random.randint(0, 255) 90 | dst = cv2.copyMakeBorder(dst, ph, ph+mh, pw, pw+mw, cv2.BORDER_CONSTANT, value=(color, 0, 0)) 91 | 92 | # adjust meta data 93 | adjust_joint_list = [] 94 | for joint in meta.joint_list: 95 | adjust_joint = [] 96 | for point in joint: 97 | if point[0] < -100 or point[1] < -100: 98 | adjust_joint.append((-1000, -1000)) 99 | continue 100 | # if point[0] <= 0 or point[1] <= 0 or int(point[0]*scale+0.5) > neww or int(point[1]*scale+0.5) > newh: 101 | # 
adjust_joint.append((-1, -1)) 102 | # continue 103 | adjust_joint.append((int(point[0]*scale+0.5) + pw, int(point[1]*scale+0.5) + ph)) 104 | adjust_joint_list.append(adjust_joint) 105 | 106 | meta.joint_list = adjust_joint_list 107 | meta.width, meta.height = neww + pw * 2, newh + ph * 2 108 | meta.img = dst 109 | return meta 110 | 111 | 112 | def pose_crop_center(meta): 113 | global _network_w, _network_h 114 | target_size = (_network_w, _network_h) 115 | x = (meta.width - target_size[0]) // 2 if meta.width > target_size[0] else 0 116 | y = (meta.height - target_size[1]) // 2 if meta.height > target_size[1] else 0 117 | 118 | return pose_crop(meta, x, y, target_size[0], target_size[1]) 119 | 120 | 121 | def pose_crop_random(meta): 122 | global _network_w, _network_h 123 | target_size = (_network_w, _network_h) 124 | 125 | for _ in range(50): 126 | x = random.randrange(0, meta.width - target_size[0]) if meta.width > target_size[0] else 0 127 | y = random.randrange(0, meta.height - target_size[1]) if meta.height > target_size[1] else 0 128 | 129 | # check whether any face is inside the box to generate a reasonably-balanced datasets 130 | for joint in meta.joint_list: 131 | if x <= joint[CocoPart.Nose.value][0] < x + target_size[0] and y <= joint[CocoPart.Nose.value][1] < y + target_size[1]: 132 | break 133 | 134 | return pose_crop(meta, x, y, target_size[0], target_size[1]) 135 | 136 | 137 | def pose_crop(meta, x, y, w, h): 138 | # adjust image 139 | target_size = (w, h) 140 | 141 | img = meta.img 142 | resized = img[y:y+target_size[1], x:x+target_size[0], :] 143 | 144 | # adjust meta data 145 | adjust_joint_list = [] 146 | for joint in meta.joint_list: 147 | adjust_joint = [] 148 | for point in joint: 149 | if point[0] < -100 or point[1] < -100: 150 | adjust_joint.append((-1000, -1000)) 151 | continue 152 | # if point[0] <= 0 or point[1] <= 0: 153 | # adjust_joint.append((-1000, -1000)) 154 | # continue 155 | new_x, new_y = point[0] - x, point[1] - y 156 | # if new_x <= 0 or new_y <= 0 or new_x > target_size[0] or new_y > target_size[1]: 157 | # adjust_joint.append((-1, -1)) 158 | # continue 159 | adjust_joint.append((new_x, new_y)) 160 | adjust_joint_list.append(adjust_joint) 161 | 162 | meta.joint_list = adjust_joint_list 163 | meta.width, meta.height = target_size 164 | meta.img = resized 165 | return meta 166 | 167 | 168 | def pose_flip(meta): 169 | r = random.uniform(0, 1.0) 170 | if r > 0.5: 171 | return meta 172 | 173 | img = meta.img 174 | img = cv2.flip(img, 1) 175 | 176 | # flip meta 177 | flip_list = [CocoPart.Nose, CocoPart.Neck, CocoPart.LShoulder, CocoPart.LElbow, CocoPart.LWrist, CocoPart.RShoulder, CocoPart.RElbow, CocoPart.RWrist, 178 | CocoPart.LHip, CocoPart.LKnee, CocoPart.LAnkle, CocoPart.RHip, CocoPart.RKnee, CocoPart.RAnkle, 179 | CocoPart.LEye, CocoPart.REye, CocoPart.LEar, CocoPart.REar, CocoPart.Background] 180 | adjust_joint_list = [] 181 | for joint in meta.joint_list: 182 | adjust_joint = [] 183 | for cocopart in flip_list: 184 | point = joint[cocopart.value] 185 | if point[0] < -100 or point[1] < -100: 186 | adjust_joint.append((-1000, -1000)) 187 | continue 188 | # if point[0] <= 0 or point[1] <= 0: 189 | # adjust_joint.append((-1, -1)) 190 | # continue 191 | adjust_joint.append((meta.width - point[0], point[1])) 192 | adjust_joint_list.append(adjust_joint) 193 | 194 | meta.joint_list = adjust_joint_list 195 | 196 | meta.img = img 197 | return meta 198 | 199 | 200 | def pose_rotation(meta): 201 | deg = random.uniform(-15.0, 15.0) 202 | img = meta.img 203 | 
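    # The steps below: rotate the image about its centre by the sampled angle,
    # keep the largest upright rectangle that still fits inside the rotated frame
    # (RotationAndCropValid.largest_rotated_rect), crop to it, and then remap each
    # joint with _rotate_coord(); joints already flagged as missing (coords < -100)
    # keep the (-1000, -1000) sentinel.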
204 | center = (img.shape[1] * 0.5, img.shape[0] * 0.5) # x, y 205 | rot_m = cv2.getRotationMatrix2D((int(center[0]), int(center[1])), deg, 1) 206 | ret = cv2.warpAffine(img, rot_m, img.shape[1::-1], flags=cv2.INTER_AREA, borderMode=cv2.BORDER_CONSTANT) 207 | if img.ndim == 3 and ret.ndim == 2: 208 | ret = ret[:, :, np.newaxis] 209 | neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg) 210 | neww = min(neww, ret.shape[1]) 211 | newh = min(newh, ret.shape[0]) 212 | newx = int(center[0] - neww * 0.5) 213 | newy = int(center[1] - newh * 0.5) 214 | # print(ret.shape, deg, newx, newy, neww, newh) 215 | img = ret[newy:newy + newh, newx:newx + neww] 216 | 217 | # adjust meta data 218 | adjust_joint_list = [] 219 | for joint in meta.joint_list: 220 | adjust_joint = [] 221 | for point in joint: 222 | if point[0] < -100 or point[1] < -100: 223 | adjust_joint.append((-1000, -1000)) 224 | continue 225 | # if point[0] <= 0 or point[1] <= 0: 226 | # adjust_joint.append((-1, -1)) 227 | # continue 228 | x, y = _rotate_coord((meta.width, meta.height), (newx, newy), point, deg) 229 | adjust_joint.append((x, y)) 230 | adjust_joint_list.append(adjust_joint) 231 | 232 | meta.joint_list = adjust_joint_list 233 | meta.width, meta.height = neww, newh 234 | meta.img = img 235 | 236 | return meta 237 | 238 | 239 | def _rotate_coord(shape, newxy, point, angle): 240 | angle = -1 * angle / 180.0 * math.pi 241 | 242 | ox, oy = shape 243 | px, py = point 244 | 245 | ox /= 2 246 | oy /= 2 247 | 248 | qx = math.cos(angle) * (px - ox) - math.sin(angle) * (py - oy) 249 | qy = math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy) 250 | 251 | new_x, new_y = newxy 252 | 253 | qx += ox - new_x 254 | qy += oy - new_y 255 | 256 | return int(qx + 0.5), int(qy + 0.5) 257 | 258 | 259 | def pose_to_img(meta_l): 260 | global _network_w, _network_h, _scale 261 | return [ 262 | meta_l[0].img.astype(np.float16), 263 | meta_l[0].get_heatmap(target_size=(_network_w // _scale, _network_h // _scale)), 264 | meta_l[0].get_vectormap(target_size=(_network_w // _scale, _network_h // _scale)) 265 | ] 266 | -------------------------------------------------------------------------------- /src/pose_datamaster.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import time 4 | 5 | from tensorpack.dataflow.remote import RemoteDataZMQ 6 | 7 | from pose_dataset import CocoPose 8 | 9 | logging.basicConfig(level=logging.DEBUG, format='[lmdb_dataset] %(asctime)s %(levelname)s %(message)s') 10 | 11 | if __name__ == '__main__': 12 | """ 13 | Speed Test for Getting Input batches from other nodes 14 | """ 15 | parser = argparse.ArgumentParser(description='Worker for preparing input batches.') 16 | parser.add_argument('--listen', type=str, default='tcp://0.0.0.0:1027') 17 | parser.add_argument('--show', type=bool, default=False) 18 | args = parser.parse_args() 19 | 20 | df = RemoteDataZMQ(args.listen) 21 | 22 | logging.info('tcp queue start') 23 | df.reset_state() 24 | t = time.time() 25 | for i, dp in enumerate(df.get_data()): 26 | if i == 100: 27 | break 28 | logging.info('Input batch %d received.' % i) 29 | if i == 0: 30 | for d in dp: 31 | logging.info('%d dp shape={}'.format(d.shape)) 32 | 33 | if args.show: 34 | CocoPose.display_image(dp[0][0], dp[1][0], dp[2][0]) 35 | 36 | logging.info('Speed Test Done for 100 Batches in %f seconds.' 
% (time.time() - t)) 37 | -------------------------------------------------------------------------------- /src/pose_dataworker.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from tensorpack.dataflow.remote import send_dataflow_zmq 4 | 5 | from pose_dataset import get_dataflow_batch 6 | from pose_augment import set_network_input_wh, set_network_scale 7 | 8 | if __name__ == '__main__': 9 | """ 10 | OpenPose Data Preparation might be a bottleneck for training. 11 | You can run multiple workers to generate input batches in multi-nodes to make training process faster. 12 | """ 13 | parser = argparse.ArgumentParser(description='Worker for preparing input batches.') 14 | parser.add_argument('--datapath', type=str, default='/coco/annotations/') 15 | parser.add_argument('--imgpath', type=str, default='/coco/') 16 | parser.add_argument('--batchsize', type=int, default=64) 17 | parser.add_argument('--train', type=bool, default=True) 18 | parser.add_argument('--master', type=str, default='tcp://csi-cluster-gpu20.dakao.io:1027') 19 | parser.add_argument('--input-width', type=int, default=368) 20 | parser.add_argument('--input-height', type=int, default=368) 21 | parser.add_argument('--scale-factor', type=int, default=2) 22 | args = parser.parse_args() 23 | 24 | set_network_input_wh(args.input_width, args.input_height) 25 | set_network_scale(args.scale_factor) 26 | 27 | df = get_dataflow_batch(args.datapath, args.train, args.batchsize, args.imgpath) 28 | 29 | send_dataflow_zmq(df, args.master, hwm=10) 30 | -------------------------------------------------------------------------------- /src/pose_stats.py: -------------------------------------------------------------------------------- 1 | from pose_dataset import CocoPose 2 | from tensorpack import imgaug 3 | from tensorpack.dataflow.common import MapDataComponent, MapData 4 | from tensorpack.dataflow.image import AugmentImageComponent 5 | 6 | from pose_augment import * 7 | 8 | 9 | def get_idx_hands_up(): 10 | from src.pose_augment import set_network_input_wh 11 | set_network_input_wh(368, 368) 12 | 13 | show_sample = True 14 | db = CocoPoseLMDB('/data/public/rw/coco-pose-estimation-lmdb/', is_train=True, decode_img=show_sample) 15 | db.reset_state() 16 | total_cnt = 0 17 | handup_cnt = 0 18 | for idx, metas in enumerate(db.get_data()): 19 | meta = metas[0] 20 | if len(meta.joint_list) <= 0: 21 | continue 22 | body = meta.joint_list[0] 23 | if body[CocoPart.Neck.value][1] <= 0: 24 | continue 25 | if body[CocoPart.LWrist.value][1] <= 0: 26 | continue 27 | if body[CocoPart.RWrist.value][1] <= 0: 28 | continue 29 | 30 | if body[CocoPart.Neck.value][1] > body[CocoPart.LWrist.value][1] or body[CocoPart.Neck.value][1] > body[CocoPart.RWrist.value][1]: 31 | print(meta.idx) 32 | handup_cnt += 1 33 | 34 | if show_sample: 35 | l1, l2, l3 = pose_to_img(metas) 36 | CocoPose.display_image(l1, l2, l3) 37 | 38 | total_cnt += 1 39 | 40 | print('%d / %d' % (handup_cnt, total_cnt)) 41 | 42 | 43 | def sample_augmentations(): 44 | ds = CocoPose('/data/public/rw/coco-pose-estimation-lmdb/', is_train=False, only_idx=0) 45 | ds = MapDataComponent(ds, pose_random_scale) 46 | ds = MapDataComponent(ds, pose_rotation) 47 | ds = MapDataComponent(ds, pose_flip) 48 | ds = MapDataComponent(ds, pose_resize_shortestedge_random) 49 | ds = MapDataComponent(ds, pose_crop_random) 50 | ds = MapData(ds, pose_to_img) 51 | augs = [ 52 | imgaug.RandomApplyAug(imgaug.RandomChooseAug([ 53 | imgaug.GaussianBlur(3), 54 | 
imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01), 55 | imgaug.RandomOrderAug([ 56 | imgaug.BrightnessScale((0.8, 1.2), clip=False), 57 | imgaug.Contrast((0.8, 1.2), clip=False), 58 | # imgaug.Saturation(0.4, rgb=True), 59 | ]), 60 | ]), 0.7), 61 | ] 62 | ds = AugmentImageComponent(ds, augs) 63 | 64 | ds.reset_state() 65 | for l1, l2, l3 in ds.get_data(): 66 | CocoPose.display_image(l1, l2, l3) 67 | 68 | 69 | if __name__ == '__main__': 70 | # codes for tests 71 | # get_idx_hands_up() 72 | 73 | # show augmentation samples 74 | sample_augmentations() 75 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import time 4 | import ast 5 | 6 | import common 7 | import cv2 8 | import numpy as np 9 | from estimator import TfPoseEstimator 10 | from networks import get_graph_path, model_wh 11 | 12 | from lifting.prob_model import Prob3dPose 13 | from lifting.draw import plot_pose 14 | 15 | logger = logging.getLogger('TfPoseEstimator') 16 | logger.setLevel(logging.DEBUG) 17 | ch = logging.StreamHandler() 18 | ch.setLevel(logging.DEBUG) 19 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s') 20 | ch.setFormatter(formatter) 21 | logger.addHandler(ch) 22 | 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser(description='tf-pose-estimation run') 26 | parser.add_argument('--image', type=str, default='./images/p1.jpg') 27 | parser.add_argument('--resolution', type=str, default='432x368', help='network input resolution. default=432x368') 28 | parser.add_argument('--model', type=str, default='mobilenet_thin', help='cmu / mobilenet_thin') 29 | parser.add_argument('--scales', type=str, default='[None]', help='for multiple scales, eg. [1.0, (1.1, 0.05)]') 30 | args = parser.parse_args() 31 | scales = ast.literal_eval(args.scales) 32 | 33 | w, h = model_wh(args.resolution) 34 | e = TfPoseEstimator(get_graph_path(args.model), target_size=(w, h)) 35 | 36 | # estimate human poses from a single image ! 37 | image = common.read_imgfile(args.image, None, None) 38 | # image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21) 39 | t = time.time() 40 | humans = e.inference(image, scales=scales) 41 | elapsed = time.time() - t 42 | 43 | logger.info('inference image: %s in %.4f seconds.' 
% (args.image, elapsed)) 44 | 45 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) 46 | # cv2.imshow('tf-pose-estimation result', image) 47 | # cv2.waitKey() 48 | 49 | import matplotlib.pyplot as plt 50 | 51 | fig = plt.figure() 52 | a = fig.add_subplot(2, 2, 1) 53 | a.set_title('Result') 54 | plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) 55 | 56 | bgimg = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_BGR2RGB) 57 | bgimg = cv2.resize(bgimg, (e.heatMat.shape[1], e.heatMat.shape[0]), interpolation=cv2.INTER_AREA) 58 | 59 | # show network output 60 | a = fig.add_subplot(2, 2, 2) 61 | plt.imshow(bgimg, alpha=0.5) 62 | tmp = np.amax(e.heatMat[:, :, :-1], axis=2) 63 | plt.imshow(tmp, cmap=plt.cm.gray, alpha=0.5) 64 | plt.colorbar() 65 | 66 | tmp2 = e.pafMat.transpose((2, 0, 1)) 67 | tmp2_odd = np.amax(np.absolute(tmp2[::2, :, :]), axis=0) 68 | tmp2_even = np.amax(np.absolute(tmp2[1::2, :, :]), axis=0) 69 | 70 | a = fig.add_subplot(2, 2, 3) 71 | a.set_title('Vectormap-x') 72 | # plt.imshow(CocoPose.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5) 73 | plt.imshow(tmp2_odd, cmap=plt.cm.gray, alpha=0.5) 74 | plt.colorbar() 75 | 76 | a = fig.add_subplot(2, 2, 4) 77 | a.set_title('Vectormap-y') 78 | # plt.imshow(CocoPose.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5) 79 | plt.imshow(tmp2_even, cmap=plt.cm.gray, alpha=0.5) 80 | plt.colorbar() 81 | plt.show() 82 | 83 | import sys 84 | sys.exit(0) 85 | 86 | logger.info('3d lifting initialization.') 87 | poseLifting = Prob3dPose('./src/lifting/models/prob_model_params.mat') 88 | 89 | image_h, image_w = image.shape[:2] 90 | standard_w = 640 91 | standard_h = 480 92 | 93 | pose_2d_mpiis = [] 94 | visibilities = [] 95 | for human in humans: 96 | pose_2d_mpii, visibility = common.MPIIPart.from_coco(human) 97 | pose_2d_mpiis.append([(int(x * standard_w + 0.5), int(y * standard_h + 0.5)) for x, y in pose_2d_mpii]) 98 | visibilities.append(visibility) 99 | 100 | pose_2d_mpiis = np.array(pose_2d_mpiis) 101 | visibilities = np.array(visibilities) 102 | transformed_pose2d, weights = poseLifting.transform_joints(pose_2d_mpiis, visibilities) 103 | pose_3d = poseLifting.compute_3d(transformed_pose2d, weights) 104 | 105 | for i, single_3d in enumerate(pose_3d): 106 | plot_pose(single_3d) 107 | plt.show() 108 | 109 | pass 110 | -------------------------------------------------------------------------------- /src/run_checkpoint.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | import tensorflow as tf 5 | 6 | from networks import get_network 7 | 8 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') 9 | 10 | config = tf.ConfigProto() 11 | config.gpu_options.allocator_type = 'BFC' 12 | config.gpu_options.per_process_gpu_memory_fraction = 0.95 13 | config.gpu_options.allow_growth = True 14 | 15 | 16 | if __name__ == '__main__': 17 | """ 18 | Use this script to just save graph and checkpoint. 19 | While training, checkpoints are saved. You can test them with this python code. 
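    Running it builds the selected network, writes the graph definition to
    ./tmp/graph.pb, prints the names of the 'concat_stage*' nodes, and saves a
    checkpoint with tf.train.Saver.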
20 | """ 21 | parser = argparse.ArgumentParser(description='Tensorflow Pose Estimation Graph Extractor') 22 | parser.add_argument('--model', type=str, default='mobilenet_thin', help='cmu / mobilenet / mobilenet_thin') 23 | args = parser.parse_args() 24 | 25 | input_node = tf.placeholder(tf.float32, shape=(None, None, None, 3), name='image') 26 | 27 | with tf.Session(config=config) as sess: 28 | net, _, last_layer = get_network(args.model, input_node, sess, trainable=False) 29 | 30 | tf.train.write_graph(sess.graph_def, './tmp', 'graph.pb', as_text=True) 31 | 32 | graph = tf.get_default_graph() 33 | dir(graph) 34 | for n in tf.get_default_graph().as_graph_def().node: 35 | if 'concat_stage' not in n.name: 36 | continue 37 | print(n.name) 38 | 39 | saver = tf.train.Saver(max_to_keep=100) 40 | saver.save(sess, '/Users/ildoonet/repos/tf-openpose/tmp/chk', global_step=1) 41 | -------------------------------------------------------------------------------- /src/run_directory.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import time 4 | import glob 5 | import ast 6 | import os 7 | import dill 8 | 9 | import common 10 | import cv2 11 | import numpy as np 12 | from estimator import TfPoseEstimator 13 | from networks import get_graph_path, model_wh 14 | 15 | from lifting.prob_model import Prob3dPose 16 | from lifting.draw import plot_pose 17 | 18 | logger = logging.getLogger('TfPoseEstimator') 19 | logger.setLevel(logging.DEBUG) 20 | ch = logging.StreamHandler() 21 | ch.setLevel(logging.DEBUG) 22 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s') 23 | ch.setFormatter(formatter) 24 | logger.addHandler(ch) 25 | 26 | 27 | if __name__ == '__main__': 28 | parser = argparse.ArgumentParser(description='tf-pose-estimation run by folder') 29 | parser.add_argument('--folder', type=str, default='./images/') 30 | parser.add_argument('--resolution', type=str, default='432x368', help='network input resolution. default=432x368') 31 | parser.add_argument('--model', type=str, default='mobilenet_thin', help='cmu / mobilenet_thin') 32 | parser.add_argument('--scales', type=str, default='[None]', help='for multiple scales, eg. [1.0, (1.1, 0.05)]') 33 | args = parser.parse_args() 34 | scales = ast.literal_eval(args.scales) 35 | 36 | w, h = model_wh(args.resolution) 37 | e = TfPoseEstimator(get_graph_path(args.model), target_size=(w, h)) 38 | 39 | files_grabbed = glob.glob(os.path.join(args.folder, '*.jpg')) 40 | all_humans = dict() 41 | for i, file in enumerate(files_grabbed): 42 | # estimate human poses from a single image ! 43 | image = common.read_imgfile(file, None, None) 44 | t = time.time() 45 | humans = e.inference(image, scales=scales) 46 | elapsed = time.time() - t 47 | 48 | logger.info('inference image #%d: %s in %.4f seconds.' 
% (i, file, elapsed)) 49 | 50 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) 51 | cv2.imshow('tf-pose-estimation result', image) 52 | cv2.waitKey(5) 53 | 54 | all_humans[file.replace(args.folder, '')] = humans 55 | 56 | with open(os.path.join(args.folder, 'pose.dil'), 'wb') as f: 57 | dill.dump(all_humans, f, protocol=dill.HIGHEST_PROTOCOL) 58 | -------------------------------------------------------------------------------- /src/run_video.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import time 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | from estimator import TfPoseEstimator 9 | from networks import get_graph_path, model_wh 10 | 11 | logger = logging.getLogger('TfPoseEstimator-Video') 12 | logger.setLevel(logging.DEBUG) 13 | ch = logging.StreamHandler() 14 | ch.setLevel(logging.DEBUG) 15 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s') 16 | ch.setFormatter(formatter) 17 | logger.addHandler(ch) 18 | 19 | fps_time = 0 20 | 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser(description='tf-pose-estimation Video') 24 | parser.add_argument('--video', type=str, default='') 25 | parser.add_argument('--zoom', type=float, default=1.0) 26 | parser.add_argument('--resolution', type=str, default='432x368', help='network input resolution. default=432x368') 27 | parser.add_argument('--model', type=str, default='mobilenet_thin', help='cmu / mobilenet_thin') 28 | parser.add_argument('--show-process', type=bool, default=False, 29 | help='for debug purpose, if enabled, speed for inference is dropped.') 30 | args = parser.parse_args() 31 | 32 | logger.debug('initialization %s : %s' % (args.model, get_graph_path(args.model))) 33 | w, h = model_wh(args.resolution) 34 | e = TfPoseEstimator(get_graph_path(args.model), target_size=(w, h)) 35 | #logger.debug('cam read+') 36 | #cam = cv2.VideoCapture(args.camera) 37 | cap = cv2.VideoCapture(args.video) 38 | #ret_val, image = cap.read() 39 | #logger.info('cam image=%dx%d' % (image.shape[1], image.shape[0])) 40 | if (cap.isOpened()== False): 41 | print("Error opening video stream or file") 42 | while(cap.isOpened()): 43 | ret_val, image = cap.read() 44 | 45 | 46 | humans = e.inference(image) 47 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) 48 | 49 | #logger.debug('show+') 50 | cv2.putText(image, 51 | "FPS: %f" % (1.0 / (time.time() - fps_time)), 52 | (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 53 | (0, 255, 0), 2) 54 | cv2.imshow('tf-pose-estimation result', image) 55 | fps_time = time.time() 56 | if cv2.waitKey(1) == 27: 57 | break 58 | 59 | 60 | cv2.destroyAllWindows() 61 | logger.debug('finished+') 62 | -------------------------------------------------------------------------------- /src/run_webcam.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import time 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | from estimator import TfPoseEstimator 9 | from networks import get_graph_path, model_wh 10 | 11 | logger = logging.getLogger('TfPoseEstimator-WebCam') 12 | logger.setLevel(logging.DEBUG) 13 | ch = logging.StreamHandler() 14 | ch.setLevel(logging.DEBUG) 15 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s') 16 | ch.setFormatter(formatter) 17 | logger.addHandler(ch) 18 | 19 | fps_time = 0 20 | 21 | 22 | if __name__ == '__main__': 23 | parser = 
argparse.ArgumentParser(description='tf-pose-estimation realtime webcam') 24 | parser.add_argument('--camera', type=int, default=0) 25 | parser.add_argument('--zoom', type=float, default=1.0) 26 | parser.add_argument('--resolution', type=str, default='432x368', help='network input resolution. default=432x368') 27 | parser.add_argument('--model', type=str, default='mobilenet_thin', help='cmu / mobilenet_thin') 28 | parser.add_argument('--show-process', type=bool, default=False, 29 | help='for debugging; if enabled, inference speed is reduced.') 30 | args = parser.parse_args() 31 | 32 | logger.debug('initialization %s : %s' % (args.model, get_graph_path(args.model))) 33 | w, h = model_wh(args.resolution) 34 | e = TfPoseEstimator(get_graph_path(args.model), target_size=(w, h)) 35 | logger.debug('cam read+') 36 | cam = cv2.VideoCapture(args.camera) 37 | ret_val, image = cam.read() 38 | logger.info('cam image=%dx%d' % (image.shape[1], image.shape[0])) 39 | 40 | while True: 41 | ret_val, image = cam.read() 42 | 43 | logger.debug('image preprocess+') 44 | if args.zoom < 1.0: 45 | canvas = np.zeros_like(image) 46 | img_scaled = cv2.resize(image, None, fx=args.zoom, fy=args.zoom, interpolation=cv2.INTER_LINEAR) 47 | dx = (canvas.shape[1] - img_scaled.shape[1]) // 2 48 | dy = (canvas.shape[0] - img_scaled.shape[0]) // 2 49 | canvas[dy:dy + img_scaled.shape[0], dx:dx + img_scaled.shape[1]] = img_scaled 50 | image = canvas 51 | elif args.zoom > 1.0: 52 | img_scaled = cv2.resize(image, None, fx=args.zoom, fy=args.zoom, interpolation=cv2.INTER_LINEAR) 53 | dx = (img_scaled.shape[1] - image.shape[1]) // 2 54 | dy = (img_scaled.shape[0] - image.shape[0]) // 2 55 | image = img_scaled[dy:dy + image.shape[0], dx:dx + image.shape[1]] # center-crop back to the original frame size 56 | 57 | logger.debug('image process+') 58 | humans = e.inference(image) 59 | 60 | logger.debug('postprocess+') 61 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) 62 | 63 | logger.debug('show+') 64 | cv2.putText(image, 65 | "FPS: %f" % (1.0 / (time.time() - fps_time)), 66 | (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 67 | (0, 255, 0), 2) 68 | cv2.imshow('tf-pose-estimation result', image) 69 | fps_time = time.time() 70 | if cv2.waitKey(1) == 27: 71 | break 72 | logger.debug('finished+') 73 | 74 | cv2.destroyAllWindows() 75 | -------------------------------------------------------------------------------- /src/slim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/src/slim/__init__.py -------------------------------------------------------------------------------- /src/slim/nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/slim/nets/alexnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a model definition for AlexNet. 16 | 17 | This work was first described in: 18 | ImageNet Classification with Deep Convolutional Neural Networks 19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton 20 | 21 | and later refined in: 22 | One weird trick for parallelizing convolutional neural networks 23 | Alex Krizhevsky, 2014 24 | 25 | Here we provide the implementation proposed in "One weird trick" and not 26 | "ImageNet Classification", as per the paper, the LRN layers have been removed. 27 | 28 | Usage: 29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()): 30 | outputs, end_points = alexnet.alexnet_v2(inputs) 31 | 32 | @@alexnet_v2 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import tensorflow as tf 40 | 41 | slim = tf.contrib.slim 42 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 43 | 44 | 45 | def alexnet_v2_arg_scope(weight_decay=0.0005): 46 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 47 | activation_fn=tf.nn.relu, 48 | biases_initializer=tf.constant_initializer(0.1), 49 | weights_regularizer=slim.l2_regularizer(weight_decay)): 50 | with slim.arg_scope([slim.conv2d], padding='SAME'): 51 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 52 | return arg_sc 53 | 54 | 55 | def alexnet_v2(inputs, 56 | num_classes=1000, 57 | is_training=True, 58 | dropout_keep_prob=0.5, 59 | spatial_squeeze=True, 60 | scope='alexnet_v2', 61 | global_pool=False): 62 | """AlexNet version 2. 63 | 64 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf 65 | Parameters from: 66 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/ 67 | layers-imagenet-1gpu.cfg 68 | 69 | Note: All the fully_connected layers have been transformed to conv2d layers. 70 | To use in classification mode, resize input to 224x224 or set 71 | global_pool=True. To use in fully convolutional mode, set 72 | spatial_squeeze to false. 73 | The LRN layers have been removed and change the initializers from 74 | random_normal_initializer to xavier_initializer. 75 | 76 | Args: 77 | inputs: a tensor of size [batch_size, height, width, channels]. 78 | num_classes: the number of predicted classes. If 0 or None, the logits layer 79 | is omitted and the input features to the logits layer are returned instead. 80 | is_training: whether or not the model is being trained. 81 | dropout_keep_prob: the probability that activations are kept in the dropout 82 | layers during training. 83 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 84 | logits. Useful to remove unnecessary dimensions for classification. 85 | scope: Optional scope for the variables. 86 | global_pool: Optional boolean flag. If True, the input to the classification 87 | layer is avgpooled to size 1x1, for any input size. (This is not part 88 | of the original AlexNet.) 89 | 90 | Returns: 91 | net: the output of the logits layer (if num_classes is a non-zero integer), 92 | or the non-dropped-out input to the logits layer (if num_classes is 0 93 | or None). 94 | end_points: a dict of tensors with intermediate activations. 
95 | """ 96 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc: 97 | end_points_collection = sc.original_name_scope + '_end_points' 98 | # Collect outputs for conv2d, fully_connected and max_pool2d. 99 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 100 | outputs_collections=[end_points_collection]): 101 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 102 | scope='conv1') 103 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') 104 | net = slim.conv2d(net, 192, [5, 5], scope='conv2') 105 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') 106 | net = slim.conv2d(net, 384, [3, 3], scope='conv3') 107 | net = slim.conv2d(net, 384, [3, 3], scope='conv4') 108 | net = slim.conv2d(net, 256, [3, 3], scope='conv5') 109 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') 110 | 111 | # Use conv2d instead of fully_connected layers. 112 | with slim.arg_scope([slim.conv2d], 113 | weights_initializer=trunc_normal(0.005), 114 | biases_initializer=tf.constant_initializer(0.1)): 115 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID', 116 | scope='fc6') 117 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 118 | scope='dropout6') 119 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 120 | # Convert end_points_collection into a end_point dict. 121 | end_points = slim.utils.convert_collection_to_dict( 122 | end_points_collection) 123 | if global_pool: 124 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 125 | end_points['global_pool'] = net 126 | if num_classes: 127 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 128 | scope='dropout7') 129 | net = slim.conv2d(net, num_classes, [1, 1], 130 | activation_fn=None, 131 | normalizer_fn=None, 132 | biases_initializer=tf.zeros_initializer(), 133 | scope='fc8') 134 | if spatial_squeeze: 135 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 136 | end_points[sc.name + '/fc8'] = net 137 | return net, end_points 138 | alexnet_v2.default_image_size = 224 139 | -------------------------------------------------------------------------------- /src/slim/nets/alexnet_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.alexnet.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import alexnet 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class AlexnetV2Test(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 224, 224 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = alexnet.alexnet_v2(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 300, 400 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 4, 7, num_classes]) 50 | 51 | def testGlobalPool(self): 52 | batch_size = 1 53 | height, width = 300, 400 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False, 58 | global_pool=True) 59 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') 60 | self.assertListEqual(logits.get_shape().as_list(), 61 | [batch_size, 1, 1, num_classes]) 62 | 63 | def testEndPoints(self): 64 | batch_size = 5 65 | height, width = 224, 224 66 | num_classes = 1000 67 | with self.test_session(): 68 | inputs = tf.random_uniform((batch_size, height, width, 3)) 69 | _, end_points = alexnet.alexnet_v2(inputs, num_classes) 70 | expected_names = ['alexnet_v2/conv1', 71 | 'alexnet_v2/pool1', 72 | 'alexnet_v2/conv2', 73 | 'alexnet_v2/pool2', 74 | 'alexnet_v2/conv3', 75 | 'alexnet_v2/conv4', 76 | 'alexnet_v2/conv5', 77 | 'alexnet_v2/pool5', 78 | 'alexnet_v2/fc6', 79 | 'alexnet_v2/fc7', 80 | 'alexnet_v2/fc8' 81 | ] 82 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 83 | 84 | def testNoClasses(self): 85 | batch_size = 5 86 | height, width = 224, 224 87 | num_classes = None 88 | with self.test_session(): 89 | inputs = tf.random_uniform((batch_size, height, width, 3)) 90 | net, end_points = alexnet.alexnet_v2(inputs, num_classes) 91 | expected_names = ['alexnet_v2/conv1', 92 | 'alexnet_v2/pool1', 93 | 'alexnet_v2/conv2', 94 | 'alexnet_v2/pool2', 95 | 'alexnet_v2/conv3', 96 | 'alexnet_v2/conv4', 97 | 'alexnet_v2/conv5', 98 | 'alexnet_v2/pool5', 99 | 'alexnet_v2/fc6', 100 | 'alexnet_v2/fc7' 101 | ] 102 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 103 | self.assertTrue(net.op.name.startswith('alexnet_v2/fc7')) 104 | self.assertListEqual(net.get_shape().as_list(), 105 | [batch_size, 1, 1, 4096]) 106 | 107 | def testModelVariables(self): 108 | batch_size = 5 109 | height, width = 224, 224 110 | num_classes = 1000 111 | with self.test_session(): 112 | inputs = tf.random_uniform((batch_size, height, width, 3)) 113 | alexnet.alexnet_v2(inputs, num_classes) 114 | expected_names = ['alexnet_v2/conv1/weights', 115 | 'alexnet_v2/conv1/biases', 116 | 'alexnet_v2/conv2/weights', 117 | 'alexnet_v2/conv2/biases', 118 | 
'alexnet_v2/conv3/weights', 119 | 'alexnet_v2/conv3/biases', 120 | 'alexnet_v2/conv4/weights', 121 | 'alexnet_v2/conv4/biases', 122 | 'alexnet_v2/conv5/weights', 123 | 'alexnet_v2/conv5/biases', 124 | 'alexnet_v2/fc6/weights', 125 | 'alexnet_v2/fc6/biases', 126 | 'alexnet_v2/fc7/weights', 127 | 'alexnet_v2/fc7/biases', 128 | 'alexnet_v2/fc8/weights', 129 | 'alexnet_v2/fc8/biases', 130 | ] 131 | model_variables = [v.op.name for v in slim.get_model_variables()] 132 | self.assertSetEqual(set(model_variables), set(expected_names)) 133 | 134 | def testEvaluation(self): 135 | batch_size = 2 136 | height, width = 224, 224 137 | num_classes = 1000 138 | with self.test_session(): 139 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 140 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False) 141 | self.assertListEqual(logits.get_shape().as_list(), 142 | [batch_size, num_classes]) 143 | predictions = tf.argmax(logits, 1) 144 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 145 | 146 | def testTrainEvalWithReuse(self): 147 | train_batch_size = 2 148 | eval_batch_size = 1 149 | train_height, train_width = 224, 224 150 | eval_height, eval_width = 300, 400 151 | num_classes = 1000 152 | with self.test_session(): 153 | train_inputs = tf.random_uniform( 154 | (train_batch_size, train_height, train_width, 3)) 155 | logits, _ = alexnet.alexnet_v2(train_inputs) 156 | self.assertListEqual(logits.get_shape().as_list(), 157 | [train_batch_size, num_classes]) 158 | tf.get_variable_scope().reuse_variables() 159 | eval_inputs = tf.random_uniform( 160 | (eval_batch_size, eval_height, eval_width, 3)) 161 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False, 162 | spatial_squeeze=False) 163 | self.assertListEqual(logits.get_shape().as_list(), 164 | [eval_batch_size, 4, 7, num_classes]) 165 | logits = tf.reduce_mean(logits, [1, 2]) 166 | predictions = tf.argmax(logits, 1) 167 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 168 | 169 | def testForward(self): 170 | batch_size = 1 171 | height, width = 224, 224 172 | with self.test_session() as sess: 173 | inputs = tf.random_uniform((batch_size, height, width, 3)) 174 | logits, _ = alexnet.alexnet_v2(inputs) 175 | sess.run(tf.global_variables_initializer()) 176 | output = sess.run(logits) 177 | self.assertTrue(output.any()) 178 | 179 | if __name__ == '__main__': 180 | tf.test.main() 181 | -------------------------------------------------------------------------------- /src/slim/nets/cifarnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a variant of the CIFAR-10 model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev) 26 | 27 | 28 | def cifarnet(images, num_classes=10, is_training=False, 29 | dropout_keep_prob=0.5, 30 | prediction_fn=slim.softmax, 31 | scope='CifarNet'): 32 | """Creates a variant of the CifarNet model. 33 | 34 | Note that since the output is a set of 'logits', the values fall in the 35 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 36 | probability distribution over the characters, one will need to convert them 37 | using the softmax function: 38 | 39 | logits = cifarnet.cifarnet(images, is_training=False) 40 | probabilities = tf.nn.softmax(logits) 41 | predictions = tf.argmax(logits, 1) 42 | 43 | Args: 44 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 45 | num_classes: the number of classes in the dataset. If 0 or None, the logits 46 | layer is omitted and the input features to the logits layer are returned 47 | instead. 48 | is_training: specifies whether or not we're currently training the model. 49 | This variable will determine the behaviour of the dropout layer. 50 | dropout_keep_prob: the percentage of activation values that are retained. 51 | prediction_fn: a function to get predictions out of logits. 52 | scope: Optional variable_scope. 53 | 54 | Returns: 55 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 56 | is a non-zero integer, or the input to the logits layer if num_classes 57 | is 0 or None. 58 | end_points: a dictionary from components of the network to the corresponding 59 | activation. 60 | """ 61 | end_points = {} 62 | 63 | with tf.variable_scope(scope, 'CifarNet', [images]): 64 | net = slim.conv2d(images, 64, [5, 5], scope='conv1') 65 | end_points['conv1'] = net 66 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 67 | end_points['pool1'] = net 68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1') 69 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 70 | end_points['conv2'] = net 71 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2') 72 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 73 | end_points['pool2'] = net 74 | net = slim.flatten(net) 75 | end_points['Flatten'] = net 76 | net = slim.fully_connected(net, 384, scope='fc3') 77 | end_points['fc3'] = net 78 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 79 | scope='dropout3') 80 | net = slim.fully_connected(net, 192, scope='fc4') 81 | end_points['fc4'] = net 82 | if not num_classes: 83 | return net, end_points 84 | logits = slim.fully_connected(net, num_classes, 85 | biases_initializer=tf.zeros_initializer(), 86 | weights_initializer=trunc_normal(1/192.0), 87 | weights_regularizer=None, 88 | activation_fn=None, 89 | scope='logits') 90 | 91 | end_points['Logits'] = logits 92 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 93 | 94 | return logits, end_points 95 | cifarnet.default_image_size = 32 96 | 97 | 98 | def cifarnet_arg_scope(weight_decay=0.004): 99 | """Defines the default cifarnet argument scope. 
100 | 101 | Args: 102 | weight_decay: The weight decay to use for regularizing the model. 103 | 104 | Returns: 105 | An `arg_scope` to use for the inception v3 model. 106 | """ 107 | with slim.arg_scope( 108 | [slim.conv2d], 109 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2), 110 | activation_fn=tf.nn.relu): 111 | with slim.arg_scope( 112 | [slim.fully_connected], 113 | biases_initializer=tf.constant_initializer(0.1), 114 | weights_initializer=trunc_normal(0.04), 115 | weights_regularizer=slim.l2_regularizer(weight_decay), 116 | activation_fn=tf.nn.relu) as sc: 117 | return sc 118 | -------------------------------------------------------------------------------- /src/slim/nets/cyclegan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for tensorflow.contrib.slim.nets.cyclegan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import cyclegan 24 | 25 | 26 | # TODO(joelshor): Add a test to check generator endpoints. 
27 | class CycleganTest(tf.test.TestCase): 28 | 29 | def test_generator_inference(self): 30 | """Check one inference step.""" 31 | img_batch = tf.zeros([2, 32, 32, 3]) 32 | model_output, _ = cyclegan.cyclegan_generator_resnet(img_batch) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | sess.run(model_output) 36 | 37 | def _test_generator_graph_helper(self, shape): 38 | """Check that generator can take small and non-square inputs.""" 39 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(tf.ones(shape)) 40 | self.assertAllEqual(shape, output_imgs.shape.as_list()) 41 | 42 | def test_generator_graph_small(self): 43 | self._test_generator_graph_helper([4, 32, 32, 3]) 44 | 45 | def test_generator_graph_medium(self): 46 | self._test_generator_graph_helper([3, 128, 128, 3]) 47 | 48 | def test_generator_graph_nonsquare(self): 49 | self._test_generator_graph_helper([2, 80, 400, 3]) 50 | 51 | def test_generator_unknown_batch_dim(self): 52 | """Check that generator can take unknown batch dimension inputs.""" 53 | img = tf.placeholder(tf.float32, shape=[None, 32, None, 3]) 54 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(img) 55 | 56 | self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list()) 57 | 58 | def _input_and_output_same_shape_helper(self, kernel_size): 59 | img_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3]) 60 | output_img_batch, _ = cyclegan.cyclegan_generator_resnet( 61 | img_batch, kernel_size=kernel_size) 62 | 63 | self.assertAllEqual(img_batch.shape.as_list(), 64 | output_img_batch.shape.as_list()) 65 | 66 | def input_and_output_same_shape_kernel3(self): 67 | self._input_and_output_same_shape_helper(3) 68 | 69 | def input_and_output_same_shape_kernel4(self): 70 | self._input_and_output_same_shape_helper(4) 71 | 72 | def input_and_output_same_shape_kernel5(self): 73 | self._input_and_output_same_shape_helper(5) 74 | 75 | def input_and_output_same_shape_kernel6(self): 76 | self._input_and_output_same_shape_helper(6) 77 | 78 | def _error_if_height_not_multiple_of_four_helper(self, height): 79 | self.assertRaisesRegexp( 80 | ValueError, 81 | 'The input height must be a multiple of 4.', 82 | cyclegan.cyclegan_generator_resnet, 83 | tf.placeholder(tf.float32, shape=[None, height, 32, 3])) 84 | 85 | def test_error_if_height_not_multiple_of_four_height29(self): 86 | self._error_if_height_not_multiple_of_four_helper(29) 87 | 88 | def test_error_if_height_not_multiple_of_four_height30(self): 89 | self._error_if_height_not_multiple_of_four_helper(30) 90 | 91 | def test_error_if_height_not_multiple_of_four_height31(self): 92 | self._error_if_height_not_multiple_of_four_helper(31) 93 | 94 | def _error_if_width_not_multiple_of_four_helper(self, width): 95 | self.assertRaisesRegexp( 96 | ValueError, 97 | 'The input width must be a multiple of 4.', 98 | cyclegan.cyclegan_generator_resnet, 99 | tf.placeholder(tf.float32, shape=[None, 32, width, 3])) 100 | 101 | def test_error_if_width_not_multiple_of_four_width29(self): 102 | self._error_if_width_not_multiple_of_four_helper(29) 103 | 104 | def test_error_if_width_not_multiple_of_four_width30(self): 105 | self._error_if_width_not_multiple_of_four_helper(30) 106 | 107 | def test_error_if_width_not_multiple_of_four_width31(self): 108 | self._error_if_width_not_multiple_of_four_helper(31) 109 | 110 | 111 | if __name__ == '__main__': 112 | tf.test.main() 113 | -------------------------------------------------------------------------------- /src/slim/nets/dcgan.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """DCGAN generator and discriminator from https://arxiv.org/abs/1511.06434.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from math import log 21 | 22 | import tensorflow as tf 23 | slim = tf.contrib.slim 24 | 25 | 26 | def _validate_image_inputs(inputs): 27 | inputs.get_shape().assert_has_rank(4) 28 | inputs.get_shape()[1:3].assert_is_fully_defined() 29 | if inputs.get_shape()[1] != inputs.get_shape()[2]: 30 | raise ValueError('Input tensor does not have equal width and height: ', 31 | inputs.get_shape()[1:3]) 32 | width = inputs.get_shape().as_list()[1] 33 | if log(width, 2) != int(log(width, 2)): 34 | raise ValueError('Input tensor `width` is not a power of 2: ', width) 35 | 36 | 37 | # TODO(joelshor): Use fused batch norm by default. Investigate why some GAN 38 | # setups need the gradient of gradient FusedBatchNormGrad. 39 | def discriminator(inputs, 40 | depth=64, 41 | is_training=True, 42 | reuse=None, 43 | scope='Discriminator', 44 | fused_batch_norm=False): 45 | """Discriminator network for DCGAN. 46 | 47 | Construct discriminator network from inputs to the final endpoint. 48 | 49 | Args: 50 | inputs: A tensor of size [batch_size, height, width, channels]. Must be 51 | floating point. 52 | depth: Number of channels in first convolution layer. 53 | is_training: Whether the network is for training or not. 54 | reuse: Whether or not the network variables should be reused. `scope` 55 | must be given to be reused. 56 | scope: Optional variable_scope. 57 | fused_batch_norm: If `True`, use a faster, fused implementation of 58 | batch norm. 59 | 60 | Returns: 61 | logits: The pre-softmax activations, a tensor of size [batch_size, 1] 62 | end_points: a dictionary from components of the network to their activation. 63 | 64 | Raises: 65 | ValueError: If the input image shape is not 4-dimensional, if the spatial 66 | dimensions aren't defined at graph construction time, if the spatial 67 | dimensions aren't square, or if the spatial dimensions aren't a power of 68 | two. 
69 | """ 70 | 71 | normalizer_fn = slim.batch_norm 72 | normalizer_fn_args = { 73 | 'is_training': is_training, 74 | 'zero_debias_moving_mean': True, 75 | 'fused': fused_batch_norm, 76 | } 77 | 78 | _validate_image_inputs(inputs) 79 | inp_shape = inputs.get_shape().as_list()[1] 80 | 81 | end_points = {} 82 | with tf.variable_scope(scope, values=[inputs], reuse=reuse) as scope: 83 | with slim.arg_scope([normalizer_fn], **normalizer_fn_args): 84 | with slim.arg_scope([slim.conv2d], 85 | stride=2, 86 | kernel_size=4, 87 | activation_fn=tf.nn.leaky_relu): 88 | net = inputs 89 | for i in xrange(int(log(inp_shape, 2))): 90 | scope = 'conv%i' % (i + 1) 91 | current_depth = depth * 2**i 92 | normalizer_fn_ = None if i == 0 else normalizer_fn 93 | net = slim.conv2d( 94 | net, current_depth, normalizer_fn=normalizer_fn_, scope=scope) 95 | end_points[scope] = net 96 | 97 | logits = slim.conv2d(net, 1, kernel_size=1, stride=1, padding='VALID', 98 | normalizer_fn=None, activation_fn=None) 99 | logits = tf.reshape(logits, [-1, 1]) 100 | end_points['logits'] = logits 101 | 102 | return logits, end_points 103 | 104 | 105 | # TODO(joelshor): Use fused batch norm by default. Investigate why some GAN 106 | # setups need the gradient of gradient FusedBatchNormGrad. 107 | def generator(inputs, 108 | depth=64, 109 | final_size=32, 110 | num_outputs=3, 111 | is_training=True, 112 | reuse=None, 113 | scope='Generator', 114 | fused_batch_norm=False): 115 | """Generator network for DCGAN. 116 | 117 | Construct generator network from inputs to the final endpoint. 118 | 119 | Args: 120 | inputs: A tensor with any size N. [batch_size, N] 121 | depth: Number of channels in last deconvolution layer. 122 | final_size: The shape of the final output. 123 | num_outputs: Number of output features. For images, this is the number of 124 | channels. 125 | is_training: whether is training or not. 126 | reuse: Whether or not the network has its variables should be reused. scope 127 | must be given to be reused. 128 | scope: Optional variable_scope. 129 | fused_batch_norm: If `True`, use a faster, fused implementation of 130 | batch norm. 131 | 132 | Returns: 133 | logits: the pre-softmax activations, a tensor of size 134 | [batch_size, 32, 32, channels] 135 | end_points: a dictionary from components of the network to their activation. 136 | 137 | Raises: 138 | ValueError: If `inputs` is not 2-dimensional. 139 | ValueError: If `final_size` isn't a power of 2 or is less than 8. 140 | """ 141 | normalizer_fn = slim.batch_norm 142 | normalizer_fn_args = { 143 | 'is_training': is_training, 144 | 'zero_debias_moving_mean': True, 145 | 'fused': fused_batch_norm, 146 | } 147 | 148 | inputs.get_shape().assert_has_rank(2) 149 | if log(final_size, 2) != int(log(final_size, 2)): 150 | raise ValueError('`final_size` (%i) must be a power of 2.' % final_size) 151 | if final_size < 8: 152 | raise ValueError('`final_size` (%i) must be greater than 8.' % final_size) 153 | 154 | end_points = {} 155 | num_layers = int(log(final_size, 2)) - 1 156 | with tf.variable_scope(scope, values=[inputs], reuse=reuse) as scope: 157 | with slim.arg_scope([normalizer_fn], **normalizer_fn_args): 158 | with slim.arg_scope([slim.conv2d_transpose], 159 | normalizer_fn=normalizer_fn, 160 | stride=2, 161 | kernel_size=4): 162 | net = tf.expand_dims(tf.expand_dims(inputs, 1), 1) 163 | 164 | # First upscaling is different because it takes the input vector. 
165 | current_depth = depth * 2 ** (num_layers - 1) 166 | scope = 'deconv1' 167 | net = slim.conv2d_transpose( 168 | net, current_depth, stride=1, padding='VALID', scope=scope) 169 | end_points[scope] = net 170 | 171 | for i in xrange(2, num_layers): 172 | scope = 'deconv%i' % (i) 173 | current_depth = depth * 2 ** (num_layers - i) 174 | net = slim.conv2d_transpose(net, current_depth, scope=scope) 175 | end_points[scope] = net 176 | 177 | # Last layer has different normalizer and activation. 178 | scope = 'deconv%i' % (num_layers) 179 | net = slim.conv2d_transpose( 180 | net, depth, normalizer_fn=None, activation_fn=None, scope=scope) 181 | end_points[scope] = net 182 | 183 | # Convert to proper channels. 184 | scope = 'logits' 185 | logits = slim.conv2d( 186 | net, 187 | num_outputs, 188 | normalizer_fn=None, 189 | activation_fn=None, 190 | kernel_size=1, 191 | stride=1, 192 | padding='VALID', 193 | scope=scope) 194 | end_points[scope] = logits 195 | 196 | logits.get_shape().assert_has_rank(4) 197 | logits.get_shape().assert_is_compatible_with( 198 | [None, final_size, final_size, num_outputs]) 199 | 200 | return logits, end_points 201 | -------------------------------------------------------------------------------- /src/slim/nets/dcgan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for dcgan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from nets import dcgan 23 | 24 | 25 | class DCGANTest(tf.test.TestCase): 26 | 27 | def test_generator_run(self): 28 | tf.set_random_seed(1234) 29 | noise = tf.random_normal([100, 64]) 30 | image, _ = dcgan.generator(noise) 31 | with self.test_session() as sess: 32 | sess.run(tf.global_variables_initializer()) 33 | image.eval() 34 | 35 | def test_generator_graph(self): 36 | tf.set_random_seed(1234) 37 | # Check graph construction for a number of image size/depths and batch 38 | # sizes. 39 | for i, batch_size in zip(xrange(3, 7), xrange(3, 8)): 40 | tf.reset_default_graph() 41 | final_size = 2 ** i 42 | noise = tf.random_normal([batch_size, 64]) 43 | image, end_points = dcgan.generator( 44 | noise, 45 | depth=32, 46 | final_size=final_size) 47 | 48 | self.assertAllEqual([batch_size, final_size, final_size, 3], 49 | image.shape.as_list()) 50 | 51 | expected_names = ['deconv%i' % j for j in xrange(1, i)] + ['logits'] 52 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 53 | 54 | # Check layer depths. 
55 | for j in range(1, i): 56 | layer = end_points['deconv%i' % j] 57 | self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1]) 58 | 59 | def test_generator_invalid_input(self): 60 | wrong_dim_input = tf.zeros([5, 32, 32]) 61 | with self.assertRaises(ValueError): 62 | dcgan.generator(wrong_dim_input) 63 | 64 | correct_input = tf.zeros([3, 2]) 65 | with self.assertRaisesRegexp(ValueError, 'must be a power of 2'): 66 | dcgan.generator(correct_input, final_size=30) 67 | 68 | with self.assertRaisesRegexp(ValueError, 'must be greater than 8'): 69 | dcgan.generator(correct_input, final_size=4) 70 | 71 | def test_discriminator_run(self): 72 | image = tf.random_uniform([5, 32, 32, 3], -1, 1) 73 | output, _ = dcgan.discriminator(image) 74 | with self.test_session() as sess: 75 | sess.run(tf.global_variables_initializer()) 76 | output.eval() 77 | 78 | def test_discriminator_graph(self): 79 | # Check graph construction for a number of image size/depths and batch 80 | # sizes. 81 | for i, batch_size in zip(xrange(1, 6), xrange(3, 8)): 82 | tf.reset_default_graph() 83 | img_w = 2 ** i 84 | image = tf.random_uniform([batch_size, img_w, img_w, 3], -1, 1) 85 | output, end_points = dcgan.discriminator( 86 | image, 87 | depth=32) 88 | 89 | self.assertAllEqual([batch_size, 1], output.get_shape().as_list()) 90 | 91 | expected_names = ['conv%i' % j for j in xrange(1, i+1)] + ['logits'] 92 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 93 | 94 | # Check layer depths. 95 | for j in range(1, i+1): 96 | layer = end_points['conv%i' % j] 97 | self.assertEqual(32 * 2**(j-1), layer.get_shape().as_list()[-1]) 98 | 99 | def test_discriminator_invalid_input(self): 100 | wrong_dim_img = tf.zeros([5, 32, 32]) 101 | with self.assertRaises(ValueError): 102 | dcgan.discriminator(wrong_dim_img) 103 | 104 | spatially_undefined_shape = tf.placeholder(tf.float32, [5, 32, None, 3]) 105 | with self.assertRaises(ValueError): 106 | dcgan.discriminator(spatially_undefined_shape) 107 | 108 | not_square = tf.zeros([5, 32, 16, 3]) 109 | with self.assertRaisesRegexp(ValueError, 'not have equal width and height'): 110 | dcgan.discriminator(not_square) 111 | 112 | not_power_2 = tf.zeros([5, 30, 30, 3]) 113 | with self.assertRaisesRegexp(ValueError, 'not a power of 2'): 114 | dcgan.discriminator(not_power_2) 115 | 116 | 117 | if __name__ == '__main__': 118 | tf.test.main() 119 | -------------------------------------------------------------------------------- /src/slim/nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_resnet_v2 import inception_resnet_v2_base 25 | from nets.inception_v1 import inception_v1 26 | from nets.inception_v1 import inception_v1_arg_scope 27 | from nets.inception_v1 import inception_v1_base 28 | from nets.inception_v2 import inception_v2 29 | from nets.inception_v2 import inception_v2_arg_scope 30 | from nets.inception_v2 import inception_v2_base 31 | from nets.inception_v3 import inception_v3 32 | from nets.inception_v3 import inception_v3_arg_scope 33 | from nets.inception_v3 import inception_v3_base 34 | from nets.inception_v4 import inception_v4 35 | from nets.inception_v4 import inception_v4_arg_scope 36 | from nets.inception_v4 import inception_v4_base 37 | # pylint: enable=unused-import 38 | -------------------------------------------------------------------------------- /src/slim/nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001, 36 | activation_fn=tf.nn.relu): 37 | """Defines the default arg scope for inception models. 38 | 39 | Args: 40 | weight_decay: The weight decay to use for regularizing the model. 41 | use_batch_norm: "If `True`, batch_norm is applied after each convolution. 42 | batch_norm_decay: Decay for batch norm moving average. 43 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 44 | in batch norm. 45 | activation_fn: Activation function for conv2d. 46 | 47 | Returns: 48 | An `arg_scope` to use for the inception models. 49 | """ 50 | batch_norm_params = { 51 | # Decay for the moving averages. 52 | 'decay': batch_norm_decay, 53 | # epsilon to prevent 0s in variance. 54 | 'epsilon': batch_norm_epsilon, 55 | # collection containing update_ops. 
56 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 57 | # use fused batch norm if possible. 58 | 'fused': None, 59 | } 60 | if use_batch_norm: 61 | normalizer_fn = slim.batch_norm 62 | normalizer_params = batch_norm_params 63 | else: 64 | normalizer_fn = None 65 | normalizer_params = {} 66 | # Set weight_decay for weights in Conv and FC layers. 67 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 68 | weights_regularizer=slim.l2_regularizer(weight_decay)): 69 | with slim.arg_scope( 70 | [slim.conv2d], 71 | weights_initializer=slim.variance_scaling_initializer(), 72 | activation_fn=activation_fn, 73 | normalizer_fn=normalizer_fn, 74 | normalizer_params=normalizer_params) as sc: 75 | return sc 76 | -------------------------------------------------------------------------------- /src/slim/nets/lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the LeNet model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def lenet(images, num_classes=10, is_training=False, 27 | dropout_keep_prob=0.5, 28 | prediction_fn=slim.softmax, 29 | scope='LeNet'): 30 | """Creates a variant of the LeNet model. 31 | 32 | Note that since the output is a set of 'logits', the values fall in the 33 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 34 | probability distribution over the characters, one will need to convert them 35 | using the softmax function: 36 | 37 | logits = lenet.lenet(images, is_training=False) 38 | probabilities = tf.nn.softmax(logits) 39 | predictions = tf.argmax(logits, 1) 40 | 41 | Args: 42 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 43 | num_classes: the number of classes in the dataset. If 0 or None, the logits 44 | layer is omitted and the input features to the logits layer are returned 45 | instead. 46 | is_training: specifies whether or not we're currently training the model. 47 | This variable will determine the behaviour of the dropout layer. 48 | dropout_keep_prob: the percentage of activation values that are retained. 49 | prediction_fn: a function to get predictions out of logits. 50 | scope: Optional variable_scope. 51 | 52 | Returns: 53 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 54 | is a non-zero integer, or the inon-dropped-out nput to the logits layer 55 | if num_classes is 0 or None. 56 | end_points: a dictionary from components of the network to the corresponding 57 | activation. 
58 | """ 59 | end_points = {} 60 | 61 | with tf.variable_scope(scope, 'LeNet', [images]): 62 | net = end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1') 63 | net = end_points['pool1'] = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 64 | net = end_points['conv2'] = slim.conv2d(net, 64, [5, 5], scope='conv2') 65 | net = end_points['pool2'] = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 66 | net = slim.flatten(net) 67 | end_points['Flatten'] = net 68 | 69 | net = end_points['fc3'] = slim.fully_connected(net, 1024, scope='fc3') 70 | if not num_classes: 71 | return net, end_points 72 | net = end_points['dropout3'] = slim.dropout( 73 | net, dropout_keep_prob, is_training=is_training, scope='dropout3') 74 | logits = end_points['Logits'] = slim.fully_connected( 75 | net, num_classes, activation_fn=None, scope='fc4') 76 | 77 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 78 | 79 | return logits, end_points 80 | lenet.default_image_size = 28 81 | 82 | 83 | def lenet_arg_scope(weight_decay=0.0): 84 | """Defines the default lenet argument scope. 85 | 86 | Args: 87 | weight_decay: The weight decay to use for regularizing the model. 88 | 89 | Returns: 90 | An `arg_scope` to use for the inception v3 model. 91 | """ 92 | with slim.arg_scope( 93 | [slim.conv2d, slim.fully_connected], 94 | weights_regularizer=slim.l2_regularizer(weight_decay), 95 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 96 | activation_fn=tf.nn.relu) as sc: 97 | return sc 98 | -------------------------------------------------------------------------------- /src/slim/nets/mobilenet_v1.md: -------------------------------------------------------------------------------- 1 | # MobileNet_v1 2 | 3 | [MobileNets](https://arxiv.org/abs/1704.04861) are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as Inception, are used. MobileNets can be run efficiently on mobile devices with [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 4 | 5 | MobileNets trade off between latency, size and accuracy while comparing favorably with popular models from the literature. 6 | 7 | ![alt text](mobilenet_v1.png "MobileNet Graph") 8 | 9 | # Pre-trained Models 10 | 11 | Choose the right MobileNet model to fit your latency and size budget. The size of the network in memory and on disk is proportional to the number of parameters. The latency and power usage of the network scales with the number of Multiply-Accumulates (MACs) which measures the number of fused Multiplication and Addition operations. These MobileNet models have been trained on the 12 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 13 | image classification dataset. Accuracies were computed by evaluating using a single image crop. 
14 | 15 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 16 | :----:|:------------:|:----------:|:-------:|:-------:| 17 | [MobileNet_v1_1.0_224](http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz)|569|4.24|70.7|89.5| 18 | [MobileNet_v1_1.0_192](http://download.tensorflow.org/models/mobilenet_v1_1.0_192_2017_06_14.tar.gz)|418|4.24|69.3|88.9| 19 | [MobileNet_v1_1.0_160](http://download.tensorflow.org/models/mobilenet_v1_1.0_160_2017_06_14.tar.gz)|291|4.24|67.2|87.5| 20 | [MobileNet_v1_1.0_128](http://download.tensorflow.org/models/mobilenet_v1_1.0_128_2017_06_14.tar.gz)|186|4.24|64.1|85.3| 21 | [MobileNet_v1_0.75_224](http://download.tensorflow.org/models/mobilenet_v1_0.75_224_2017_06_14.tar.gz)|317|2.59|68.4|88.2| 22 | [MobileNet_v1_0.75_192](http://download.tensorflow.org/models/mobilenet_v1_0.75_192_2017_06_14.tar.gz)|233|2.59|67.4|87.3| 23 | [MobileNet_v1_0.75_160](http://download.tensorflow.org/models/mobilenet_v1_0.75_160_2017_06_14.tar.gz)|162|2.59|65.2|86.1| 24 | [MobileNet_v1_0.75_128](http://download.tensorflow.org/models/mobilenet_v1_0.75_128_2017_06_14.tar.gz)|104|2.59|61.8|83.6| 25 | [MobileNet_v1_0.50_224](http://download.tensorflow.org/models/mobilenet_v1_0.50_224_2017_06_14.tar.gz)|150|1.34|64.0|85.4| 26 | [MobileNet_v1_0.50_192](http://download.tensorflow.org/models/mobilenet_v1_0.50_192_2017_06_14.tar.gz)|110|1.34|62.1|84.0| 27 | [MobileNet_v1_0.50_160](http://download.tensorflow.org/models/mobilenet_v1_0.50_160_2017_06_14.tar.gz)|77|1.34|59.9|82.5| 28 | [MobileNet_v1_0.50_128](http://download.tensorflow.org/models/mobilenet_v1_0.50_128_2017_06_14.tar.gz)|49|1.34|56.2|79.6| 29 | [MobileNet_v1_0.25_224](http://download.tensorflow.org/models/mobilenet_v1_0.25_224_2017_06_14.tar.gz)|41|0.47|50.6|75.0| 30 | [MobileNet_v1_0.25_192](http://download.tensorflow.org/models/mobilenet_v1_0.25_192_2017_06_14.tar.gz)|34|0.47|49.0|73.6| 31 | [MobileNet_v1_0.25_160](http://download.tensorflow.org/models/mobilenet_v1_0.25_160_2017_06_14.tar.gz)|21|0.47|46.0|70.7| 32 | [MobileNet_v1_0.25_128](http://download.tensorflow.org/models/mobilenet_v1_0.25_128_2017_06_14.tar.gz)|14|0.47|41.3|66.2| 33 | 34 | 35 | Here is an example of how to download the MobileNet_v1_1.0_224 checkpoint: 36 | 37 | ```shell 38 | $ CHECKPOINT_DIR=/tmp/checkpoints 39 | $ mkdir ${CHECKPOINT_DIR} 40 | $ wget http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz 41 | $ tar -xvf mobilenet_v1_1.0_224_2017_06_14.tar.gz 42 | $ mv mobilenet_v1_1.0_224.ckpt.* ${CHECKPOINT_DIR} 43 | $ rm mobilenet_v1_1.0_224_2017_06_14.tar.gz 44 | ``` 45 | More information on integrating MobileNets into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md). 46 | 47 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 
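As a minimal sketch of using one of these checkpoints from Python (assuming TensorFlow 1.x, this `nets/` package on the import path, the `/tmp/checkpoints` layout from the download example above, and the 2017 TF-Slim signatures vendored in this directory), the checkpoint can be restored into the `mobilenet_v1` definition and run for inference. For meaningful predictions the input images should first go through the Inception preprocessing in `../preprocessing/inception_preprocessing.py`; the zero tensor below only demonstrates the expected shapes.

```python
import numpy as np
import tensorflow as tf

from nets import mobilenet_v1

slim = tf.contrib.slim

# Build the classification graph; these checkpoints use 1001 classes (ImageNet + background).
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=False)):
    logits, end_points = mobilenet_v1.mobilenet_v1(images, num_classes=1001, is_training=False)

# Restore the downloaded weights (checkpoint prefix, not a single file).
saver = tf.train.Saver(slim.get_model_variables('MobilenetV1'))
with tf.Session() as sess:
    saver.restore(sess, '/tmp/checkpoints/mobilenet_v1_1.0_224.ckpt')
    # Real inputs must be Inception-preprocessed (scaled to [-1, 1]); zeros are shape-only.
    probs = sess.run(end_points['Predictions'],
                     feed_dict={images: np.zeros((1, 224, 224, 3), np.float32)})
    print(probs.argmax(axis=1))
```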
48 | -------------------------------------------------------------------------------- /src/slim/nets/mobilenet_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattzheng/tf-pose-estimation-applied/a10de95853de9ec3b5b13efc7a141f12e0cacf3d/src/slim/nets/mobilenet_v1.png -------------------------------------------------------------------------------- /src/slim/nets/nasnet/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-Slim NASNet-A Implementation/Checkpoints 2 | This directory contains the code for the NASNet-A model from the paper 3 | [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012) by Zoph et al. 4 | In nasnet.py there are three different configurations of NASNet-A that are implementented. One of the models is the NASNet-A built for CIFAR-10 and the 5 | other two are variants of NASNet-A trained on ImageNet, which are listed below. 6 | 7 | # Pre-Trained Models 8 | Two NASNet-A checkpoints are available that have been trained on the 9 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 10 | image classification dataset. Accuracies were computed by evaluating using a single image crop. 11 | 12 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 13 | :----:|:------------:|:----------:|:-------:|:-------:| 14 | [NASNet-A_Mobile_224](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|564|5.3|74.0|91.6| 15 | [NASNet-A_Large_331](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|23800|88.9|82.7|96.2| 16 | 17 | 18 | Here is an example of how to download the NASNet-A_Mobile_224 checkpoint. The way to download the NASNet-A_Large_331 is the same. 19 | 20 | ```shell 21 | CHECKPOINT_DIR=/tmp/checkpoints 22 | mkdir ${CHECKPOINT_DIR} 23 | cd ${CHECKPOINT_DIR} 24 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz 25 | tar -xvf nasnet-a_mobile_04_10_2017.tar.gz 26 | rm nasnet-a_mobile_04_10_2017.tar.gz 27 | ``` 28 | More information on integrating NASNet Models into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md). 29 | 30 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 
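As a complement to the command-line examples in the next section, the mobile model can also be built directly in Python with the definitions in this directory. The sketch below assumes TensorFlow 1.x, `nets/` on the import path, and the `/tmp/checkpoints/model.ckpt` path used by the sample commands; note that the accuracies quoted above are obtained with the exponential-moving-average weights (see `--moving_average_decay` below).

```python
import tensorflow as tf

from nets.nasnet import nasnet

slim = tf.contrib.slim

# NASNet-A Mobile expects 224x224 inputs; the ImageNet checkpoints use 1001 classes.
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
    logits, end_points = nasnet.build_nasnet_mobile(images, num_classes=1001, is_training=False)

saver = tf.train.Saver(slim.get_model_variables())
with tf.Session() as sess:
    saver.restore(sess, '/tmp/checkpoints/model.ckpt')
    # end_points['Predictions'] holds the softmax output for preprocessed inputs.
```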
31 | 32 | ## Sample Commands for using NASNet-A Mobile and Large Checkpoints for Inference 33 | ------- 34 | Run eval with the NASNet-A mobile ImageNet model 35 | 36 | ```shell 37 | DATASET_DIR=/tmp/imagenet 38 | EVAL_DIR=/tmp/tfmodel/eval 39 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 40 | python tensorflow_models/research/slim/eval_image_classifier \ 41 | --checkpoint_path=${CHECKPOINT_DIR} \ 42 | --eval_dir=${EVAL_DIR} \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --dataset_name=imagenet \ 45 | --dataset_split_name=validation \ 46 | --model_name=nasnet_mobile \ 47 | --eval_image_size=224 \ 48 | --moving_average_decay=0.9999 49 | ``` 50 | 51 | Run eval with the NASNet-A large ImageNet model 52 | 53 | ```shell 54 | DATASET_DIR=/tmp/imagenet 55 | EVAL_DIR=/tmp/tfmodel/eval 56 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 57 | python tensorflow_models/research/slim/eval_image_classifier \ 58 | --checkpoint_path=${CHECKPOINT_DIR} \ 59 | --eval_dir=${EVAL_DIR} \ 60 | --dataset_dir=${DATASET_DIR} \ 61 | --dataset_name=imagenet \ 62 | --dataset_split_name=validation \ 63 | --model_name=nasnet_large \ 64 | --eval_image_size=331 \ 65 | --moving_average_decay=0.9999 66 | ``` 67 | -------------------------------------------------------------------------------- /src/slim/nets/nasnet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/slim/nets/nasnet/nasnet_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.nasnet.nasnet_utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets.nasnet import nasnet_utils 24 | 25 | 26 | class NasnetUtilsTest(tf.test.TestCase): 27 | 28 | def testCalcReductionLayers(self): 29 | num_cells = 18 30 | num_reduction_layers = 2 31 | reduction_layers = nasnet_utils.calc_reduction_layers( 32 | num_cells, num_reduction_layers) 33 | self.assertEqual(len(reduction_layers), 2) 34 | self.assertEqual(reduction_layers[0], 6) 35 | self.assertEqual(reduction_layers[1], 12) 36 | 37 | def testGetChannelIndex(self): 38 | data_formats = ['NHWC', 'NCHW'] 39 | for data_format in data_formats: 40 | index = nasnet_utils.get_channel_index(data_format) 41 | correct_index = 3 if data_format == 'NHWC' else 1 42 | self.assertEqual(index, correct_index) 43 | 44 | def testGetChannelDim(self): 45 | data_formats = ['NHWC', 'NCHW'] 46 | shape = [10, 20, 30, 40] 47 | for data_format in data_formats: 48 | dim = nasnet_utils.get_channel_dim(shape, data_format) 49 | correct_dim = shape[3] if data_format == 'NHWC' else shape[1] 50 | self.assertEqual(dim, correct_dim) 51 | 52 | def testGlobalAvgPool(self): 53 | data_formats = ['NHWC', 'NCHW'] 54 | inputs = tf.placeholder(tf.float32, (5, 10, 20, 10)) 55 | for data_format in data_formats: 56 | output = nasnet_utils.global_avg_pool( 57 | inputs, data_format) 58 | self.assertEqual(output.shape, [5, 10]) 59 | 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /src/slim/nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import functools 21 | 22 | import tensorflow as tf 23 | 24 | from nets import alexnet 25 | from nets import cifarnet 26 | from nets import inception 27 | from nets import lenet 28 | from nets import mobilenet_v1 29 | from nets import overfeat 30 | from nets import resnet_v1 31 | from nets import resnet_v2 32 | from nets import vgg 33 | from nets.nasnet import nasnet 34 | 35 | slim = tf.contrib.slim 36 | 37 | networks_map = {'alexnet_v2': alexnet.alexnet_v2, 38 | 'cifarnet': cifarnet.cifarnet, 39 | 'overfeat': overfeat.overfeat, 40 | 'vgg_a': vgg.vgg_a, 41 | 'vgg_16': vgg.vgg_16, 42 | 'vgg_19': vgg.vgg_19, 43 | 'inception_v1': inception.inception_v1, 44 | 'inception_v2': inception.inception_v2, 45 | 'inception_v3': inception.inception_v3, 46 | 'inception_v4': inception.inception_v4, 47 | 'inception_resnet_v2': inception.inception_resnet_v2, 48 | 'lenet': lenet.lenet, 49 | 'resnet_v1_50': resnet_v1.resnet_v1_50, 50 | 'resnet_v1_101': resnet_v1.resnet_v1_101, 51 | 'resnet_v1_152': resnet_v1.resnet_v1_152, 52 | 'resnet_v1_200': resnet_v1.resnet_v1_200, 53 | 'resnet_v2_50': resnet_v2.resnet_v2_50, 54 | 'resnet_v2_101': resnet_v2.resnet_v2_101, 55 | 'resnet_v2_152': resnet_v2.resnet_v2_152, 56 | 'resnet_v2_200': resnet_v2.resnet_v2_200, 57 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1, 58 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_075, 59 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_050, 60 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_025, 61 | 'nasnet_cifar': nasnet.build_nasnet_cifar, 62 | 'nasnet_mobile': nasnet.build_nasnet_mobile, 63 | 'nasnet_large': nasnet.build_nasnet_large, 64 | } 65 | 66 | arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope, 67 | 'cifarnet': cifarnet.cifarnet_arg_scope, 68 | 'overfeat': overfeat.overfeat_arg_scope, 69 | 'vgg_a': vgg.vgg_arg_scope, 70 | 'vgg_16': vgg.vgg_arg_scope, 71 | 'vgg_19': vgg.vgg_arg_scope, 72 | 'inception_v1': inception.inception_v3_arg_scope, 73 | 'inception_v2': inception.inception_v3_arg_scope, 74 | 'inception_v3': inception.inception_v3_arg_scope, 75 | 'inception_v4': inception.inception_v4_arg_scope, 76 | 'inception_resnet_v2': 77 | inception.inception_resnet_v2_arg_scope, 78 | 'lenet': lenet.lenet_arg_scope, 79 | 'resnet_v1_50': resnet_v1.resnet_arg_scope, 80 | 'resnet_v1_101': resnet_v1.resnet_arg_scope, 81 | 'resnet_v1_152': resnet_v1.resnet_arg_scope, 82 | 'resnet_v1_200': resnet_v1.resnet_arg_scope, 83 | 'resnet_v2_50': resnet_v2.resnet_arg_scope, 84 | 'resnet_v2_101': resnet_v2.resnet_arg_scope, 85 | 'resnet_v2_152': resnet_v2.resnet_arg_scope, 86 | 'resnet_v2_200': resnet_v2.resnet_arg_scope, 87 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope, 88 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_arg_scope, 89 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_arg_scope, 90 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_arg_scope, 91 | 'nasnet_cifar': nasnet.nasnet_cifar_arg_scope, 92 | 'nasnet_mobile': nasnet.nasnet_mobile_arg_scope, 93 | 'nasnet_large': nasnet.nasnet_large_arg_scope, 94 | } 95 | 96 | 97 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): 98 | """Returns a network_fn such as `logits, end_points = network_fn(images)`. 99 | 100 | Args: 101 | name: The name of the network. 
102 | num_classes: The number of classes to use for classification. If 0 or None, 103 | the logits layer is omitted and its input features are returned instead. 104 | weight_decay: The l2 coefficient for the model weights. 105 | is_training: `True` if the model is being used for training and `False` 106 | otherwise. 107 | 108 | Returns: 109 | network_fn: A function that applies the model to a batch of images. It has 110 | the following signature: 111 | net, end_points = network_fn(images) 112 | The `images` input is a tensor of shape [batch_size, height, width, 3] 113 | with height = width = network_fn.default_image_size. (The permissibility 114 | and treatment of other sizes depends on the network_fn.) 115 | The returned `end_points` are a dictionary of intermediate activations. 116 | The returned `net` is the topmost layer, depending on `num_classes`: 117 | If `num_classes` was a non-zero integer, `net` is a logits tensor 118 | of shape [batch_size, num_classes]. 119 | If `num_classes` was 0 or `None`, `net` is a tensor with the input 120 | to the logits layer of shape [batch_size, 1, 1, num_features] or 121 | [batch_size, num_features]. Dropout has not been applied to this 122 | (even if the network's original classification does); it remains for 123 | the caller to do this or not. 124 | 125 | Raises: 126 | ValueError: If network `name` is not recognized. 127 | """ 128 | if name not in networks_map: 129 | raise ValueError('Name of network unknown %s' % name) 130 | func = networks_map[name] 131 | @functools.wraps(func) 132 | def network_fn(images, **kwargs): 133 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay) 134 | with slim.arg_scope(arg_scope): 135 | return func(images, num_classes, is_training=is_training, **kwargs) 136 | if hasattr(func, 'default_image_size'): 137 | network_fn.default_image_size = func.default_image_size 138 | 139 | return network_fn 140 | -------------------------------------------------------------------------------- /src/slim/nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for slim.nets.nets_factory.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFnFirstHalf(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | # Wrap .keys() in list() so the slice also works under Python 3. 34 | for net in list(nets_factory.networks_map.keys())[:10]: 35 | with tf.Graph().as_default() as g, self.test_session(g): 36 | net_fn = nets_factory.get_network_fn(net, num_classes) 37 | # Most networks use 224 as their default_image_size 38 | image_size = getattr(net_fn, 'default_image_size', 224) 39 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 40 | logits, end_points = net_fn(inputs) 41 | self.assertTrue(isinstance(logits, tf.Tensor)) 42 | self.assertTrue(isinstance(end_points, dict)) 43 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 44 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 45 | 46 | def testGetNetworkFnSecondHalf(self): 47 | batch_size = 5 48 | num_classes = 1000 49 | for net in list(nets_factory.networks_map.keys())[10:]: 50 | with tf.Graph().as_default() as g, self.test_session(g): 51 | net_fn = nets_factory.get_network_fn(net, num_classes) 52 | # Most networks use 224 as their default_image_size 53 | image_size = getattr(net_fn, 'default_image_size', 224) 54 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 55 | logits, end_points = net_fn(inputs) 56 | self.assertTrue(isinstance(logits, tf.Tensor)) 57 | self.assertTrue(isinstance(end_points, dict)) 58 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 59 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /src/slim/nets/overfeat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the model definition for the OverFeat network. 
16 | 17 | The definition for the network was obtained from: 18 | OverFeat: Integrated Recognition, Localization and Detection using 19 | Convolutional Networks 20 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 21 | Yann LeCun, 2014 22 | http://arxiv.org/abs/1312.6229 23 | 24 | Usage: 25 | with slim.arg_scope(overfeat.overfeat_arg_scope()): 26 | outputs, end_points = overfeat.overfeat(inputs) 27 | 28 | @@overfeat 29 | """ 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow as tf 35 | 36 | slim = tf.contrib.slim 37 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 38 | 39 | 40 | def overfeat_arg_scope(weight_decay=0.0005): 41 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 42 | activation_fn=tf.nn.relu, 43 | weights_regularizer=slim.l2_regularizer(weight_decay), 44 | biases_initializer=tf.zeros_initializer()): 45 | with slim.arg_scope([slim.conv2d], padding='SAME'): 46 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 47 | return arg_sc 48 | 49 | 50 | def overfeat(inputs, 51 | num_classes=1000, 52 | is_training=True, 53 | dropout_keep_prob=0.5, 54 | spatial_squeeze=True, 55 | scope='overfeat', 56 | global_pool=False): 57 | """Contains the model definition for the OverFeat network. 58 | 59 | The definition for the network was obtained from: 60 | OverFeat: Integrated Recognition, Localization and Detection using 61 | Convolutional Networks 62 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 63 | Yann LeCun, 2014 64 | http://arxiv.org/abs/1312.6229 65 | 66 | Note: All the fully_connected layers have been transformed to conv2d layers. 67 | To use in classification mode, resize input to 231x231. To use in fully 68 | convolutional mode, set spatial_squeeze to false. 69 | 70 | Args: 71 | inputs: a tensor of size [batch_size, height, width, channels]. 72 | num_classes: number of predicted classes. If 0 or None, the logits layer is 73 | omitted and the input features to the logits layer are returned instead. 74 | is_training: whether or not the model is being trained. 75 | dropout_keep_prob: the probability that activations are kept in the dropout 76 | layers during training. 77 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 78 | outputs. Useful to remove unnecessary dimensions for classification. 79 | scope: Optional scope for the variables. 80 | global_pool: Optional boolean flag. If True, the input to the classification 81 | layer is avgpooled to size 1x1, for any input size. (This is not part 82 | of the original OverFeat.) 83 | 84 | Returns: 85 | net: the output of the logits layer (if num_classes is a non-zero integer), 86 | or the non-dropped-out input to the logits layer (if num_classes is 0 or 87 | None). 88 | end_points: a dict of tensors with intermediate activations. 
89 | """ 90 | with tf.variable_scope(scope, 'overfeat', [inputs]) as sc: 91 | end_points_collection = sc.original_name_scope + '_end_points' 92 | # Collect outputs for conv2d, fully_connected and max_pool2d 93 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 94 | outputs_collections=end_points_collection): 95 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 96 | scope='conv1') 97 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 98 | net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2') 99 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 100 | net = slim.conv2d(net, 512, [3, 3], scope='conv3') 101 | net = slim.conv2d(net, 1024, [3, 3], scope='conv4') 102 | net = slim.conv2d(net, 1024, [3, 3], scope='conv5') 103 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 104 | 105 | # Use conv2d instead of fully_connected layers. 106 | with slim.arg_scope([slim.conv2d], 107 | weights_initializer=trunc_normal(0.005), 108 | biases_initializer=tf.constant_initializer(0.1)): 109 | net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6') 110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 111 | scope='dropout6') 112 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 113 | # Convert end_points_collection into a end_point dict. 114 | end_points = slim.utils.convert_collection_to_dict( 115 | end_points_collection) 116 | if global_pool: 117 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 118 | end_points['global_pool'] = net 119 | if num_classes: 120 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 121 | scope='dropout7') 122 | net = slim.conv2d(net, num_classes, [1, 1], 123 | activation_fn=None, 124 | normalizer_fn=None, 125 | biases_initializer=tf.zeros_initializer(), 126 | scope='fc8') 127 | if spatial_squeeze: 128 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 129 | end_points[sc.name + '/fc8'] = net 130 | return net, end_points 131 | overfeat.default_image_size = 231 132 | -------------------------------------------------------------------------------- /src/slim/nets/overfeat_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.overfeat.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import overfeat 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class OverFeatTest(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 231, 231 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = overfeat.overfeat(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 281, 281 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 2, 2, num_classes]) 50 | 51 | def testGlobalPool(self): 52 | batch_size = 1 53 | height, width = 281, 281 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False, 58 | global_pool=True) 59 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 60 | self.assertListEqual(logits.get_shape().as_list(), 61 | [batch_size, 1, 1, num_classes]) 62 | 63 | def testEndPoints(self): 64 | batch_size = 5 65 | height, width = 231, 231 66 | num_classes = 1000 67 | with self.test_session(): 68 | inputs = tf.random_uniform((batch_size, height, width, 3)) 69 | _, end_points = overfeat.overfeat(inputs, num_classes) 70 | expected_names = ['overfeat/conv1', 71 | 'overfeat/pool1', 72 | 'overfeat/conv2', 73 | 'overfeat/pool2', 74 | 'overfeat/conv3', 75 | 'overfeat/conv4', 76 | 'overfeat/conv5', 77 | 'overfeat/pool5', 78 | 'overfeat/fc6', 79 | 'overfeat/fc7', 80 | 'overfeat/fc8' 81 | ] 82 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 83 | 84 | def testNoClasses(self): 85 | batch_size = 5 86 | height, width = 231, 231 87 | num_classes = None 88 | with self.test_session(): 89 | inputs = tf.random_uniform((batch_size, height, width, 3)) 90 | net, end_points = overfeat.overfeat(inputs, num_classes) 91 | expected_names = ['overfeat/conv1', 92 | 'overfeat/pool1', 93 | 'overfeat/conv2', 94 | 'overfeat/pool2', 95 | 'overfeat/conv3', 96 | 'overfeat/conv4', 97 | 'overfeat/conv5', 98 | 'overfeat/pool5', 99 | 'overfeat/fc6', 100 | 'overfeat/fc7' 101 | ] 102 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 103 | self.assertTrue(net.op.name.startswith('overfeat/fc7')) 104 | 105 | def testModelVariables(self): 106 | batch_size = 5 107 | height, width = 231, 231 108 | num_classes = 1000 109 | with self.test_session(): 110 | inputs = tf.random_uniform((batch_size, height, width, 3)) 111 | overfeat.overfeat(inputs, num_classes) 112 | expected_names = ['overfeat/conv1/weights', 113 | 'overfeat/conv1/biases', 114 | 'overfeat/conv2/weights', 115 | 'overfeat/conv2/biases', 116 | 'overfeat/conv3/weights', 117 | 'overfeat/conv3/biases', 118 | 'overfeat/conv4/weights', 119 | 'overfeat/conv4/biases', 120 | 'overfeat/conv5/weights', 121 | 
'overfeat/conv5/biases', 122 | 'overfeat/fc6/weights', 123 | 'overfeat/fc6/biases', 124 | 'overfeat/fc7/weights', 125 | 'overfeat/fc7/biases', 126 | 'overfeat/fc8/weights', 127 | 'overfeat/fc8/biases', 128 | ] 129 | model_variables = [v.op.name for v in slim.get_model_variables()] 130 | self.assertSetEqual(set(model_variables), set(expected_names)) 131 | 132 | def testEvaluation(self): 133 | batch_size = 2 134 | height, width = 231, 231 135 | num_classes = 1000 136 | with self.test_session(): 137 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 138 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False) 139 | self.assertListEqual(logits.get_shape().as_list(), 140 | [batch_size, num_classes]) 141 | predictions = tf.argmax(logits, 1) 142 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 143 | 144 | def testTrainEvalWithReuse(self): 145 | train_batch_size = 2 146 | eval_batch_size = 1 147 | train_height, train_width = 231, 231 148 | eval_height, eval_width = 281, 281 149 | num_classes = 1000 150 | with self.test_session(): 151 | train_inputs = tf.random_uniform( 152 | (train_batch_size, train_height, train_width, 3)) 153 | logits, _ = overfeat.overfeat(train_inputs) 154 | self.assertListEqual(logits.get_shape().as_list(), 155 | [train_batch_size, num_classes]) 156 | tf.get_variable_scope().reuse_variables() 157 | eval_inputs = tf.random_uniform( 158 | (eval_batch_size, eval_height, eval_width, 3)) 159 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False, 160 | spatial_squeeze=False) 161 | self.assertListEqual(logits.get_shape().as_list(), 162 | [eval_batch_size, 2, 2, num_classes]) 163 | logits = tf.reduce_mean(logits, [1, 2]) 164 | predictions = tf.argmax(logits, 1) 165 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 166 | 167 | def testForward(self): 168 | batch_size = 1 169 | height, width = 231, 231 170 | with self.test_session() as sess: 171 | inputs = tf.random_uniform((batch_size, height, width, 3)) 172 | logits, _ = overfeat.overfeat(inputs) 173 | sess.run(tf.global_variables_initializer()) 174 | output = sess.run(logits) 175 | self.assertTrue(output.any()) 176 | 177 | if __name__ == '__main__': 178 | tf.test.main() 179 | -------------------------------------------------------------------------------- /src/slim/nets/pix2pix_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================= 15 | """Tests for pix2pix.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from nets import pix2pix 23 | 24 | 25 | class GeneratorTest(tf.test.TestCase): 26 | 27 | def test_nonsquare_inputs_raise_exception(self): 28 | batch_size = 2 29 | height, width = 240, 320 30 | num_outputs = 4 31 | 32 | images = tf.ones((batch_size, height, width, 3)) 33 | 34 | with self.assertRaises(ValueError): 35 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 36 | pix2pix.pix2pix_generator( 37 | images, num_outputs, upsample_method='nn_upsample_conv') 38 | 39 | def _reduced_default_blocks(self): 40 | """Returns the default blocks, scaled down to make test run faster.""" 41 | return [pix2pix.Block(b.num_filters // 32, b.decoder_keep_prob) 42 | for b in pix2pix._default_generator_blocks()] 43 | 44 | def test_output_size_nn_upsample_conv(self): 45 | batch_size = 2 46 | height, width = 256, 256 47 | num_outputs = 4 48 | 49 | images = tf.ones((batch_size, height, width, 3)) 50 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 51 | logits, _ = pix2pix.pix2pix_generator( 52 | images, num_outputs, blocks=self._reduced_default_blocks(), 53 | upsample_method='nn_upsample_conv') 54 | 55 | with self.test_session() as session: 56 | session.run(tf.global_variables_initializer()) 57 | np_outputs = session.run(logits) 58 | self.assertListEqual([batch_size, height, width, num_outputs], 59 | list(np_outputs.shape)) 60 | 61 | def test_output_size_conv2d_transpose(self): 62 | batch_size = 2 63 | height, width = 256, 256 64 | num_outputs = 4 65 | 66 | images = tf.ones((batch_size, height, width, 3)) 67 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 68 | logits, _ = pix2pix.pix2pix_generator( 69 | images, num_outputs, blocks=self._reduced_default_blocks(), 70 | upsample_method='conv2d_transpose') 71 | 72 | with self.test_session() as session: 73 | session.run(tf.global_variables_initializer()) 74 | np_outputs = session.run(logits) 75 | self.assertListEqual([batch_size, height, width, num_outputs], 76 | list(np_outputs.shape)) 77 | 78 | def test_block_number_dictates_number_of_layers(self): 79 | batch_size = 2 80 | height, width = 256, 256 81 | num_outputs = 4 82 | 83 | images = tf.ones((batch_size, height, width, 3)) 84 | blocks = [ 85 | pix2pix.Block(64, 0.5), 86 | pix2pix.Block(128, 0), 87 | ] 88 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 89 | _, end_points = pix2pix.pix2pix_generator( 90 | images, num_outputs, blocks) 91 | 92 | num_encoder_layers = 0 93 | num_decoder_layers = 0 94 | for end_point in end_points: 95 | if end_point.startswith('encoder'): 96 | num_encoder_layers += 1 97 | elif end_point.startswith('decoder'): 98 | num_decoder_layers += 1 99 | 100 | self.assertEqual(num_encoder_layers, len(blocks)) 101 | self.assertEqual(num_decoder_layers, len(blocks)) 102 | 103 | 104 | class DiscriminatorTest(tf.test.TestCase): 105 | 106 | def _layer_output_size(self, input_size, kernel_size=4, stride=2, pad=2): 107 | return (input_size + pad * 2 - kernel_size) // stride + 1 108 | 109 | def test_four_layers(self): 110 | batch_size = 2 111 | input_size = 256 112 | 113 | output_size = self._layer_output_size(input_size) 114 | output_size = self._layer_output_size(output_size) 115 | output_size = self._layer_output_size(output_size) 116 | output_size 
= self._layer_output_size(output_size, stride=1) 117 | output_size = self._layer_output_size(output_size, stride=1) 118 | 119 | images = tf.ones((batch_size, input_size, input_size, 3)) 120 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 121 | logits, end_points = pix2pix.pix2pix_discriminator( 122 | images, num_filters=[64, 128, 256, 512]) 123 | self.assertListEqual([batch_size, output_size, output_size, 1], 124 | logits.shape.as_list()) 125 | self.assertListEqual([batch_size, output_size, output_size, 1], 126 | end_points['predictions'].shape.as_list()) 127 | 128 | def test_four_layers_no_padding(self): 129 | batch_size = 2 130 | input_size = 256 131 | 132 | output_size = self._layer_output_size(input_size, pad=0) 133 | output_size = self._layer_output_size(output_size, pad=0) 134 | output_size = self._layer_output_size(output_size, pad=0) 135 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 136 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 137 | 138 | images = tf.ones((batch_size, input_size, input_size, 3)) 139 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 140 | logits, end_points = pix2pix.pix2pix_discriminator( 141 | images, num_filters=[64, 128, 256, 512], padding=0) 142 | self.assertListEqual([batch_size, output_size, output_size, 1], 143 | logits.shape.as_list()) 144 | self.assertListEqual([batch_size, output_size, output_size, 1], 145 | end_points['predictions'].shape.as_list()) 146 | 147 | def test_four_layers_wrog_paddig(self): 148 | batch_size = 2 149 | input_size = 256 150 | 151 | images = tf.ones((batch_size, input_size, input_size, 3)) 152 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 153 | with self.assertRaises(TypeError): 154 | pix2pix.pix2pix_discriminator( 155 | images, num_filters=[64, 128, 256, 512], padding=1.5) 156 | 157 | def test_four_layers_negative_padding(self): 158 | batch_size = 2 159 | input_size = 256 160 | 161 | images = tf.ones((batch_size, input_size, input_size, 3)) 162 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 163 | with self.assertRaises(ValueError): 164 | pix2pix.pix2pix_discriminator( 165 | images, num_filters=[64, 128, 256, 512], padding=-1) 166 | 167 | if __name__ == '__main__': 168 | tf.test.main() 169 | -------------------------------------------------------------------------------- /src/slim/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/slim/preprocessing/cifarnet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images in CIFAR-10. 
16 | 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf 24 | 25 | _PADDING = 4 26 | 27 | slim = tf.contrib.slim 28 | 29 | 30 | def preprocess_for_train(image, 31 | output_height, 32 | output_width, 33 | padding=_PADDING, 34 | add_image_summaries=True): 35 | """Preprocesses the given image for training. 36 | 37 | Note that the actual resizing scale is sampled from 38 | [`resize_size_min`, `resize_size_max`]. 39 | 40 | Args: 41 | image: A `Tensor` representing an image of arbitrary size. 42 | output_height: The height of the image after preprocessing. 43 | output_width: The width of the image after preprocessing. 44 | padding: The amound of padding before and after each dimension of the image. 45 | add_image_summaries: Enable image summaries. 46 | 47 | Returns: 48 | A preprocessed image. 49 | """ 50 | if add_image_summaries: 51 | tf.summary.image('image', tf.expand_dims(image, 0)) 52 | 53 | # Transform the image to floats. 54 | image = tf.to_float(image) 55 | if padding > 0: 56 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]]) 57 | # Randomly crop a [height, width] section of the image. 58 | distorted_image = tf.random_crop(image, 59 | [output_height, output_width, 3]) 60 | 61 | # Randomly flip the image horizontally. 62 | distorted_image = tf.image.random_flip_left_right(distorted_image) 63 | 64 | if add_image_summaries: 65 | tf.summary.image('distorted_image', tf.expand_dims(distorted_image, 0)) 66 | 67 | # Because these operations are not commutative, consider randomizing 68 | # the order their operation. 69 | distorted_image = tf.image.random_brightness(distorted_image, 70 | max_delta=63) 71 | distorted_image = tf.image.random_contrast(distorted_image, 72 | lower=0.2, upper=1.8) 73 | # Subtract off the mean and divide by the variance of the pixels. 74 | return tf.image.per_image_standardization(distorted_image) 75 | 76 | 77 | def preprocess_for_eval(image, output_height, output_width, 78 | add_image_summaries=True): 79 | """Preprocesses the given image for evaluation. 80 | 81 | Args: 82 | image: A `Tensor` representing an image of arbitrary size. 83 | output_height: The height of the image after preprocessing. 84 | output_width: The width of the image after preprocessing. 85 | add_image_summaries: Enable image summaries. 86 | 87 | Returns: 88 | A preprocessed image. 89 | """ 90 | if add_image_summaries: 91 | tf.summary.image('image', tf.expand_dims(image, 0)) 92 | # Transform the image to floats. 93 | image = tf.to_float(image) 94 | 95 | # Resize and crop if needed. 96 | resized_image = tf.image.resize_image_with_crop_or_pad(image, 97 | output_width, 98 | output_height) 99 | if add_image_summaries: 100 | tf.summary.image('resized_image', tf.expand_dims(resized_image, 0)) 101 | 102 | # Subtract off the mean and divide by the variance of the pixels. 103 | return tf.image.per_image_standardization(resized_image) 104 | 105 | 106 | def preprocess_image(image, output_height, output_width, is_training=False, 107 | add_image_summaries=True): 108 | """Preprocesses the given image. 109 | 110 | Args: 111 | image: A `Tensor` representing an image of arbitrary size. 112 | output_height: The height of the image after preprocessing. 113 | output_width: The width of the image after preprocessing. 114 | is_training: `True` if we're preprocessing the image for training and 115 | `False` otherwise. 116 | add_image_summaries: Enable image summaries. 
117 | 118 | Returns: 119 | A preprocessed image. 120 | """ 121 | if is_training: 122 | return preprocess_for_train( 123 | image, output_height, output_width, 124 | add_image_summaries=add_image_summaries) 125 | else: 126 | return preprocess_for_eval( 127 | image, output_height, output_width, 128 | add_image_summaries=add_image_summaries) 129 | -------------------------------------------------------------------------------- /src/slim/preprocessing/lenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities for preprocessing.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def preprocess_image(image, output_height, output_width, is_training): 27 | """Preprocesses the given image. 28 | 29 | Args: 30 | image: A `Tensor` representing an image of arbitrary size. 31 | output_height: The height of the image after preprocessing. 32 | output_width: The width of the image after preprocessing. 33 | is_training: `True` if we're preprocessing the image for training and 34 | `False` otherwise. 35 | 36 | Returns: 37 | A preprocessed image. 38 | """ 39 | image = tf.to_float(image) 40 | image = tf.image.resize_image_with_crop_or_pad( 41 | image, output_width, output_height) 42 | image = tf.subtract(image, 128.0) 43 | image = tf.div(image, 128.0) 44 | return image 45 | -------------------------------------------------------------------------------- /src/slim/preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import cifarnet_preprocessing 24 | from preprocessing import inception_preprocessing 25 | from preprocessing import lenet_preprocessing 26 | from preprocessing import vgg_preprocessing 27 | 28 | slim = tf.contrib.slim 29 | 30 | 31 | def get_preprocessing(name, is_training=False): 32 | """Returns preprocessing_fn(image, height, width, **kwargs). 33 | 34 | Args: 35 | name: The name of the preprocessing function. 36 | is_training: `True` if the model is being used for training and `False` 37 | otherwise. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocessing a single image (pre-batch). 41 | It has the following signature: 42 | image = preprocessing_fn(image, output_height, output_width, ...). 43 | 44 | Raises: 45 | ValueError: If Preprocessing `name` is not recognized. 46 | """ 47 | preprocessing_fn_map = { 48 | 'cifarnet': cifarnet_preprocessing, 49 | 'inception': inception_preprocessing, 50 | 'inception_v1': inception_preprocessing, 51 | 'inception_v2': inception_preprocessing, 52 | 'inception_v3': inception_preprocessing, 53 | 'inception_v4': inception_preprocessing, 54 | 'inception_resnet_v2': inception_preprocessing, 55 | 'lenet': lenet_preprocessing, 56 | 'mobilenet_v1': inception_preprocessing, 57 | 'nasnet_mobile': inception_preprocessing, 58 | 'nasnet_large': inception_preprocessing, 59 | 'resnet_v1_50': vgg_preprocessing, 60 | 'resnet_v1_101': vgg_preprocessing, 61 | 'resnet_v1_152': vgg_preprocessing, 62 | 'resnet_v1_200': vgg_preprocessing, 63 | 'resnet_v2_50': vgg_preprocessing, 64 | 'resnet_v2_101': vgg_preprocessing, 65 | 'resnet_v2_152': vgg_preprocessing, 66 | 'resnet_v2_200': vgg_preprocessing, 67 | 'vgg': vgg_preprocessing, 68 | 'vgg_a': vgg_preprocessing, 69 | 'vgg_16': vgg_preprocessing, 70 | 'vgg_19': vgg_preprocessing, 71 | } 72 | 73 | if name not in preprocessing_fn_map: 74 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 75 | 76 | def preprocessing_fn(image, output_height, output_width, **kwargs): 77 | return preprocessing_fn_map[name].preprocess_image( 78 | image, output_height, output_width, is_training=is_training, **kwargs) 79 | 80 | return preprocessing_fn 81 | --------------------------------------------------------------------------------
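The two factory modules shown above, `nets_factory.get_network_fn` and `preprocessing_factory.get_preprocessing`, are meant to be used together: the same model name selects both the network builder and its matching preprocessing function. Below is a minimal sketch of that wiring, not part of the repository itself; it assumes `src/slim` is on `PYTHONPATH` so `nets` and `preprocessing` import cleanly, and the model name, class count, and placeholder shape are illustrative choices.

```python
# Illustrative sketch only: combining get_network_fn and get_preprocessing
# to build a single-image inference graph for one of the slim models.
import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory

model_name = 'mobilenet_v1'   # any key present in networks_map / preprocessing_fn_map
num_classes = 1001            # slim ImageNet checkpoints typically use 1001 classes

# The same name selects both the network builder and its preprocessing function.
network_fn = nets_factory.get_network_fn(model_name, num_classes, is_training=False)
preprocess_fn = preprocessing_factory.get_preprocessing(model_name, is_training=False)

image_size = network_fn.default_image_size  # e.g. 224 for mobilenet_v1

# A single raw uint8 image fed at run time (the static shape here is arbitrary).
raw_image = tf.placeholder(tf.uint8, shape=[256, 256, 3], name='raw_image')
image = preprocess_fn(raw_image, image_size, image_size)  # resize/crop + scaling
images = tf.expand_dims(image, 0)                         # add batch dimension

logits, end_points = network_fn(images)
predictions = tf.argmax(logits, axis=1)
```

With a restored checkpoint for the chosen model, running `predictions` in a session would yield the class index for the fed image; the point of the sketch is simply that both factories are keyed by the same model name, so the preprocessing always matches the network's expected input.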