├── .dockerignore ├── .gitattributes ├── .gitignore ├── CMakeLists.txt ├── Dockerfile ├── LICENSE ├── README.md ├── docker ├── Dockerfile └── README.md ├── etcs ├── dance.mp4 ├── feature.md ├── inference_result2.png ├── loss_ll_heat.png ├── loss_ll_paf.png ├── openpose_macbook13_mobilenet2.gif ├── openpose_macbook_cmu.gif ├── openpose_macbook_mobilenet3.gif ├── openpose_p40_cmu.gif ├── openpose_p40_mobilenet.gif ├── openpose_tx2_mobilenet3.gif ├── ros.md └── training.md ├── images ├── apink1.jpg ├── apink1_crop.jpg ├── apink1_crop_s1.jpg ├── apink2.jpg ├── apink3.jpg ├── cat.jpg ├── golf.jpg ├── hand1.jpg ├── hand1_small.jpg ├── hand2.jpg ├── handsup1.jpg ├── p1.jpg ├── p2.jpg ├── p3.jpg ├── p3_dance.png ├── ski.jpg └── valid_person1.jpg ├── launch └── demo_video.launch ├── models ├── graph │ ├── cmu_640x360 │ │ └── download.sh │ ├── cmu_640x480 │ │ └── download.sh │ └── mobilenet_thin_432x368 │ │ ├── graph.pb │ │ ├── graph_freeze.pb │ │ └── graph_opt.pb ├── numpy │ └── download.sh └── pretrained │ ├── mobilenet_v1_0.50_224_2017_06_14 │ └── download.sh │ ├── mobilenet_v1_0.75_224_2017_06_14 │ └── download.sh │ └── mobilenet_v1_1.0_224_2017_06_14 │ └── download.sh ├── msg ├── BodyPartElm.msg ├── Person.msg └── Persons.msg ├── package.xml ├── requirements.txt ├── scripts ├── broadcaster_ros.py └── visualization.py └── src ├── __init__.py ├── common.py ├── datum_pb2.py ├── estimator.py ├── lifting ├── __init__.py ├── config.py ├── draw.py ├── models │ └── prob_model_params.mat ├── prob_model.py └── upright_fast.py ├── network_base.py ├── network_cmu.py ├── network_dsconv.py ├── network_mobilenet.py ├── network_mobilenet_thin.py ├── networks.py ├── pose_augment.py ├── pose_datamaster.py ├── pose_dataset.py ├── pose_dataworker.py ├── pose_stats.py ├── run.py ├── run_checkpoint.py ├── run_webcam.py ├── slim ├── __init__.py ├── nets │ ├── __init__.py │ ├── alexnet.py │ ├── alexnet_test.py │ ├── cifarnet.py │ ├── cyclegan.py │ ├── cyclegan_test.py │ ├── dcgan.py │ ├── dcgan_test.py │ ├── inception.py │ ├── inception_resnet_v2.py │ ├── inception_resnet_v2_test.py │ ├── inception_utils.py │ ├── inception_v1.py │ ├── inception_v1_test.py │ ├── inception_v2.py │ ├── inception_v2_test.py │ ├── inception_v3.py │ ├── inception_v3_test.py │ ├── inception_v4.py │ ├── inception_v4_test.py │ ├── lenet.py │ ├── mobilenet_v1.md │ ├── mobilenet_v1.png │ ├── mobilenet_v1.py │ ├── mobilenet_v1_test.py │ ├── nasnet │ │ ├── README.md │ │ ├── __init__.py │ │ ├── nasnet.py │ │ ├── nasnet_test.py │ │ ├── nasnet_utils.py │ │ └── nasnet_utils_test.py │ ├── nets_factory.py │ ├── nets_factory_test.py │ ├── overfeat.py │ ├── overfeat_test.py │ ├── pix2pix.py │ ├── pix2pix_test.py │ ├── resnet_utils.py │ ├── resnet_v1.py │ ├── resnet_v1_test.py │ ├── resnet_v2.py │ ├── resnet_v2_test.py │ ├── vgg.py │ └── vgg_test.py └── preprocessing │ ├── __init__.py │ ├── cifarnet_preprocessing.py │ ├── inception_preprocessing.py │ ├── lenet_preprocessing.py │ ├── preprocessing_factory.py │ └── vgg_preprocessing.py └── train.py /.dockerignore: -------------------------------------------------------------------------------- 1 | ./models 2 | ./models/* 3 | ./models/*/* 4 | ./models/*/*/* 5 | ./models/*/*/*/* 6 | models 7 | models/* 8 | models/*/* 9 | models/*/*/* 10 | models/*/*/*/* 11 | *.meta 12 | *.index 13 | *.data-* 14 | *.ckpt.* 15 | 16 | ./tests 17 | ./tests/* 18 | tests 19 | graph*.pb 20 | chk*.meta 21 | lifting 22 | slim 23 | 24 | -------------------------------------------------------------------------------- /.gitattributes: 
-------------------------------------------------------------------------------- 1 | models/numpy/*.npy filter=lfs diff=lfs merge=lfs -text 2 | *.ckpt* filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # models 104 | *ckpt* 105 | *.npy 106 | timeline*.json 107 | 108 | tmp 109 | models/graph/cmu_*/*.pb 110 | models/trained/*/checkpoint 111 | models/trained/*/*.pb 112 | models/trained/*/model-*.data-* 113 | models/trained/*/model-*.index 114 | models/trained/*/model-*.meta -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(tfpose_ros) 3 | 4 | ## Add support for C++11, supported in ROS Kinetic and newer 5 | add_definitions(-std=c++11) 6 | 7 | find_package(catkin REQUIRED COMPONENTS 8 | roscpp 9 | rospy 10 | std_msgs 11 | message_generation 12 | ) 13 | 14 | # Generate messages in the 'msg' folder 15 | add_message_files( 16 | FILES 17 | BodyPartElm.msg 18 | Person.msg 19 | Persons.msg 20 | ) 21 | 22 | generate_messages( 23 | DEPENDENCIES std_msgs 24 | ) 25 | 26 | catkin_package( 27 | CATKIN_DEPENDS rospy message_generation message_runtime 28 | ) 29 | 30 | install() 31 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM idock.daumkakao.io/kakaobrain/deepcloud-sshd:openpose-preprocess 2 | 3 | COPY ./*.py /root/tf-openpose/ 4 | WORKDIR /root/tf-openpose/ 5 | 6 | RUN cd /root/tf-openpose/ && pip3 install -r requirements.txt 7 | 8 | ENTRYPOINT ["python3", "pose_dataworker.py"] 9 | 
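#
# A minimal usage sketch (an assumption, not taken from this repository: the image tag
# "tf-openpose-worker" and HOST:PORT are placeholders, and the base image above must be
# reachable from your registry). It shows how this worker image could be built and
# started; arguments given after the image name are appended to the ENTRYPOINT, i.e.
# passed to pose_dataworker.py (the --master flag is the one described in etcs/training.md):
#
#   docker build -t tf-openpose-worker .
#   docker run --rm tf-openpose-worker --master=tcp://HOST:PORT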
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tf-pose-estimation 2 | 3 | 'Openpose' for human pose estimation has been implemented using Tensorflow. It also provides several variants that make some changes to the network structure for **real-time processing on the CPU or low-power embedded devices.** 4 | 5 | 6 | **You can even run this on your Macbook with decent FPS!** 7 | 8 | Original Repo(Caffe) : https://github.com/CMU-Perceptual-Computing-Lab/openpose 9 | 10 | | CMU's Original Model<br/>
on Macbook Pro 15" | Mobilenet Variant
on Macbook Pro 15" | Mobilenet Variant
on Jetson TX2 | 11 | |:---------|:--------------------|:----------------| 12 | | ![cmu-model](/etcs/openpose_macbook_cmu.gif) | ![mb-model-macbook](/etcs/openpose_macbook_mobilenet3.gif) | ![mb-model-tx2](/etcs/openpose_tx2_mobilenet3.gif) | 13 | | **~0.6 FPS** | **~4.2 FPS** @ 368x368 | **~10 FPS** @ 368x368 | 14 | | 2.8GHz Quad-core i7 | 2.8GHz Quad-core i7 | Jetson TX2 Embedded Board | 15 | 16 | Implemented features are listed here : [features](./etcs/feature.md) 17 | 18 | ## Install 19 | 20 | ### Dependencies 21 | 22 | You need the dependencies below. 23 | 24 | - python3 25 | - tensorflow 1.4.1+ 26 | - opencv3, protobuf, python3-tk 27 | 28 | ### Install 29 | 30 | ```bash 31 | $ git clone https://www.github.com/ildoonet/tf-openpose 32 | $ cd tf-openpose 33 | $ pip3 install -r requirements.txt 34 | ``` 35 | 36 | ## Models 37 | 38 | I have tried several model variations to find an optimized network architecture. Some of them are listed below, and checkpoint files are provided for research purposes. 39 | 40 | - cmu 41 | - the model based on the VGG pretrained network, as described in the original paper. 42 | - I converted the weights in Caffe format for use in tensorflow. 43 | - [pretrained weight download](https://www.dropbox.com/s/xh5s7sb7remu8tx/openpose_coco.npy?dl=0) 44 | 45 | - dsconv 46 | - Same architecture as the cmu version except that it uses the **depthwise separable convolution** from mobilenet. 47 | - I trained it using 'transfer learning', but it does not provide enough speed and accuracy. 48 | 49 | - mobilenet 50 | - Based on the mobilenet paper, 12 convolutional layers are used as feature-extraction layers. 51 | - To improve detection of small persons, **minor modifications** to the architecture have been made. 52 | - Three models were trained according to network-size parameters. 53 | - mobilenet 54 | - 368x368 : [checkpoint weight download](https://www.dropbox.com/s/09xivpuboecge56/mobilenet_0.75_0.50_model-388003.zip?dl=0) 55 | - mobilenet_fast 56 | - mobilenet_accurate 57 | - The published models are not the best ones, but you can test them before training a model from scratch. 58 | 59 | ### Download Tensorflow Graph File(pb file) 60 | 61 | Before running the demo, you should download the graph files. You can also deploy these graphs on mobile or other platforms. 62 | 63 | - cmu_640x360 64 | - cmu_640x480 65 | - mobilenet_thin_432x368 66 | 67 | CMU's model graphs are too large for git, so I uploaded them to dropbox. You should download them if you want to use cmu's original model. 68 | 69 | ``` 70 | $ cd models/graph/cmu_640x360 71 | $ bash download.sh 72 | $ cd models/graph/cmu_640x480 73 | $ bash download.sh 74 | ``` 75 | 76 | ### Inference Time 77 | 78 | | Dataset | Model | Inference Time<br/>
Macbook Pro i5 3.1GHz | Inference Time<br/>
Jetson TX2 | 79 | |---------|--------------------|----------------:|----------------:| 80 | | Coco | cmu | 10.0s @ 368x368 | OOM @ 368x368
5.5s @ 320x240| 81 | | Coco | dsconv | 1.10s @ 368x368 | 82 | | Coco | mobilenet_accurate | 0.40s @ 368x368 | 0.18s @ 368x368 | 83 | | Coco | mobilenet | 0.24s @ 368x368 | 0.10s @ 368x368 | 84 | | Coco | mobilenet_fast | 0.16s @ 368x368 | 0.07s @ 368x368 | 85 | 86 | ## Demo 87 | 88 | ### Test Inference 89 | 90 | You can test the inference feature with a single image. 91 | 92 | ``` 93 | $ python3 run.py --model=mobilenet_thin_432x368 --image=... 94 | ``` 95 | 96 | The image flag MUST be relative to the src folder with no "~", i.e: 97 | ``` 98 | --image ../../Desktop 99 | ``` 100 | 101 | Then you will see the screen as below with pafmap, heatmap, result and etc. 102 | 103 | ![inferent_result](./etcs/inference_result2.png) 104 | 105 | ### Realtime Webcam 106 | 107 | ``` 108 | $ python3 run_webcam.py --model=mobilenet_thin_432x368 --camera=0 109 | ``` 110 | 111 | Then you will see the realtime webcam screen with estimated poses as below. This [Realtime Result](./etcs/openpose_macbook13_mobilenet2.gif) was recored on macbook pro 13" with 3.1Ghz Dual-Core CPU. 112 | 113 | ## Python Usage 114 | 115 | This pose estimator provides simple python classes that you can use in your applications. 116 | 117 | See [run.py](run.py) or [run_webcam.py](run_webcam.py) as references. 118 | 119 | ```python 120 | e = TfPoseEstimator(get_graph_path(args.model), target_size=(w, h)) 121 | humans = e.inference(image) 122 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) 123 | ``` 124 | 125 | ## ROS Support 126 | 127 | See : [etcs/ros.md](./etcs/ros.md) 128 | 129 | ## Training 130 | 131 | See : [etcs/training.md](./etcs/training.md) 132 | 133 | ## References 134 | 135 | ### OpenPose 136 | 137 | [1] https://github.com/CMU-Perceptual-Computing-Lab/openpose 138 | 139 | [2] Training Codes : https://github.com/ZheC/Realtime_Multi-Person_Pose_Estimation 140 | 141 | [3] Custom Caffe by Openpose : https://github.com/CMU-Perceptual-Computing-Lab/caffe_train 142 | 143 | [4] Keras Openpose : https://github.com/michalfaber/keras_Realtime_Multi-Person_Pose_Estimation 144 | 145 | ### Lifting from the deep 146 | 147 | [1] Arxiv Paper : https://arxiv.org/abs/1701.00295 148 | 149 | [2] https://github.com/DenisTome/Lifting-from-the-Deep-release 150 | 151 | ### Mobilenet 152 | 153 | [1] Original Paper : https://arxiv.org/abs/1704.04861 154 | 155 | [2] Pretrained model : https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.md 156 | 157 | ### Libraries 158 | 159 | [1] Tensorpack : https://github.com/ppwwyyxx/tensorpack 160 | 161 | ### Tensorflow Tips 162 | 163 | [1] Freeze graph : https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py 164 | 165 | [2] Optimize graph : https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2 166 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:16.04 2 | 3 | #ENV http_proxy=http://10.41.249.28:8080 https_proxy=http://10.41.249.28:8080 4 | 5 | RUN apt-get update -yq && apt-get install -yq build-essential cmake git pkg-config wget zip && \ 6 | apt-get install -yq libjpeg8-dev libtiff5-dev libjasper-dev libpng12-dev && \ 7 | apt-get install -yq libavcodec-dev libavformat-dev libswscale-dev libv4l-dev && \ 8 | apt-get install -yq libgtk2.0-dev && \ 9 | apt-get install -yq libatlas-base-dev gfortran && \ 10 | apt-get install -yq python3 python3-dev python3-pip 
python3-setuptools python3-tk git && \ 11 | apt-get remove -yq python-pip python3-pip && wget https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && \ 12 | pip3 install numpy && \ 13 | cd ~ && git clone https://github.com/Itseez/opencv.git && \ 14 | cd opencv && mkdir build && cd build && \ 15 | cmake -D CMAKE_BUILD_TYPE=RELEASE \ 16 | -D CMAKE_INSTALL_PREFIX=/usr/local \ 17 | -D INSTALL_PYTHON_EXAMPLES=ON \ 18 | -D BUILD_opencv_python3=yes -D PYTHON_EXECUTABLE=/usr/bin/python3 .. && \ 19 | make -j8 && make install && rm -rf /root/opencv/ && \ 20 | mkdir -p /root/tf-openpose && \ 21 | rm -rf /tmp/*.tar.gz && \ 22 | apt-get clean && rm -rf /tmp/* /var/tmp* /var/lib/apt/lists/* && \ 23 | rm -f /etc/ssh/ssh_host_* && rm -rf /usr/share/man?? /usr/share/man/??_* 24 | 25 | COPY . /root/tf-openpose/ 26 | WORKDIR /root/tf-openpose/ 27 | 28 | RUN cd /root/tf-openpose/ && pip3 install -U setuptools && \ 29 | pip3 install tensorflow && pip3 install -r requirements.txt 30 | 31 | RUN cd /root && git clone https://github.com/cocodataset/cocoapi && \ 32 | pip3 install cython && \ 33 | cd cocoapi/PythonAPI && python3 setup.py build_ext --inplace && python3 setup.py build_ext install && \ 34 | mkdir /coco && cd /coco && wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip && \ 35 | unzip annotations_trainval2017.zip && rm -rf annotations_trainval2017.zip 36 | 37 | ENTRYPOINT ["python3", "pose_dataworker.py"] 38 | 39 | #ENV http_proxy= https_proxy= 40 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | ``` 2 | $ docker build --tag idock.daumkakao.io/kakaobrain/deepcloud-sshd:openpose-preprocess -f ./docker/full/ . 3 | ``` 4 | 5 | ``` 6 | $ docker build --tag idock.daumkakao.io/kakaobrain/deepcloud-sshd:openpose-preprocess -f ./docker/update/Dockerfile . 7 | ``` -------------------------------------------------------------------------------- /etcs/dance.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/dance.mp4 -------------------------------------------------------------------------------- /etcs/feature.md: -------------------------------------------------------------------------------- 1 | ## Features 2 | 3 | - [x] CMU's original network architecture and weights. 4 | 5 | - [x] Transfer Original Weights to Tensorflow 6 | 7 | - [x] Training Code with multi-gpus 8 | 9 | - [x] Evaluate with test dataset 10 | 11 | - [ ] Inference 12 | 13 | - [x] Post processing from network output. 14 | 15 | - [x] Faster post-processing 16 | 17 | - [ ] Multi-Scale Inference 18 | 19 | - [x] Faster network variants using custom mobilenet architecture. 20 | 21 | - [x] Depthwise Separable Convolution Version 22 | 23 | - [x] Mobilenet Version 24 | 25 | - [ ] Demos 26 | 27 | - [x] Realtime Webcam Demo 28 | 29 | - [x] Image File Demo 30 | 31 | - [ ] Video File Demo 32 | 33 | - [x] ROS Support. See [./etcs/ros.md](ros readme). 
-------------------------------------------------------------------------------- /etcs/inference_result2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/inference_result2.png -------------------------------------------------------------------------------- /etcs/loss_ll_heat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/loss_ll_heat.png -------------------------------------------------------------------------------- /etcs/loss_ll_paf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/loss_ll_paf.png -------------------------------------------------------------------------------- /etcs/openpose_macbook13_mobilenet2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/openpose_macbook13_mobilenet2.gif -------------------------------------------------------------------------------- /etcs/openpose_macbook_cmu.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/openpose_macbook_cmu.gif -------------------------------------------------------------------------------- /etcs/openpose_macbook_mobilenet3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/openpose_macbook_mobilenet3.gif -------------------------------------------------------------------------------- /etcs/openpose_p40_cmu.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/openpose_p40_cmu.gif -------------------------------------------------------------------------------- /etcs/openpose_p40_mobilenet.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/openpose_p40_mobilenet.gif -------------------------------------------------------------------------------- /etcs/openpose_tx2_mobilenet3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/openpose_tx2_mobilenet3.gif -------------------------------------------------------------------------------- /etcs/ros.md: -------------------------------------------------------------------------------- 1 | # tf-pose-estimation for ROS 2 | 3 | Human pose estimation is expected to use on mobile robots which need human interactions. 4 | 5 | ## Installation 6 | 7 | Cloning this repository under src folder in your ros workstation. And the same should be carried out as [README.md](README.md). 
8 | 9 | ``` 10 | $ cd $(ros-workspace) 11 | $ cd src 12 | $ git clone https://github.com/ildoonet/tf-pose-estimation 13 | $ pip install -r tf-pose-estimation/requirements.txt 14 | ``` 15 | 16 | The demo launch file has the following dependencies: 17 | 18 | - video_stream_opencv 19 | - image_view 20 | - ros_video_recorder : https://github.com/ildoonet/ros-video-recorder 21 | 22 | ## Video/camera demo 23 | 24 | | CMU<br/>
640x360 | Mobilenet_Thin
432x368 | 25 | |:----------------|:---------------------------| 26 | | ![cmu-model](/etcs/openpose_p40_cmu.gif) | ![cmu-model](/etcs/openpose_p40_mobilenet.gif) | 27 | 28 | The tests above were run on a P40 GPU. The latency between the current video frame and the processed frame is much lower with the mobilenet version. 29 | 30 | Source : https://www.youtube.com/watch?v=rSZnyBuc6tc 31 | 32 | ``` 33 | $ roslaunch tfpose_ros demo_video.launch 34 | ``` 35 | 36 | You can set the 'video' argument to launch a realtime demo using your camera. See the launch file: [demo_video.launch](./launch/demo_video.launch). -------------------------------------------------------------------------------- /etcs/training.md: -------------------------------------------------------------------------------- 1 | ## Training 2 | 3 | ### Coco Dataset 4 | 5 | You should download the COCO dataset from http://cocodataset.org/#download 6 | 7 | Also, you need to install cocoapi to parse the annotations easily : https://github.com/cocodataset/cocoapi 8 | 9 | ``` 10 | $ git clone https://github.com/cocodataset/cocoapi 11 | $ cd cocoapi/PythonAPI 12 | $ python3 setup.py build_ext --inplace 13 | $ python3 setup.py build_ext install 14 | ``` 15 | 16 | ### Augmentation 17 | 18 | CMU Perceptual Computing Lab has modified Caffe to provide data augmentation. See : https://github.com/CMU-Perceptual-Computing-Lab/caffe_train 19 | 20 | I implemented the augmentation code in the same way as the original version; see [pose_dataset.py](pose_dataset.py) and [pose_augment.py](pose_augment.py). This includes scaling, rotation, flipping and cropping. 21 | 22 | This process can be a bottleneck during training, so if you have enough computing resources, see the 'Run for Faster Training' section below. 23 | 24 | ### Run 25 | 26 | ``` 27 | $ python3 train.py --model=cmu --datapath={datapath} --batchsize=64 --lr=0.001 --modelpath={path-to-save} 28 | 29 | 2017-09-27 15:58:50,307 INFO Restore pretrained weights... 30 | ``` 31 | 32 | If you want to reproduce the original paper's results, the following settings are recommended. 33 | 34 | - model : vgg 35 | - lr : 0.0001 or 0.00004 36 | - input-width = input-height = 368x368 or 432x368 37 | - batchsize : 10 (I trained with batch sizes up to 128, and they also trained well) 38 | 39 | | Heatmap Loss | PAFmap(Part Affinity Field) Loss | 40 | |-------------------------------------------|------------------------------------------| 41 | | ![train_loss_cmu](/etcs/loss_ll_heat.png) | ![train_loss_cmu](/etcs/loss_ll_paf.png) | 42 | 43 | As you can see in the table above, the training loss converged with almost the same trend as in the original paper. 44 | 45 | The mobilenet versions have slightly worse loss values than the original one: training losses are 3~8% larger, and validation losses are 5~14% larger. 46 | 47 | 48 | ### Run for Faster Training 49 | 50 | If you have enough computing resources across multiple nodes, you can launch multiple workers on those nodes to help with data preparation. 51 | 52 | ``` 53 | worker-node1$ python3 pose_dataworker.py --master=tcp://host:port 54 | worker-node2$ python3 pose_dataworker.py --master=tcp://host:port 55 | worker-node3$ python3 pose_dataworker.py --master=tcp://host:port 56 | ... 57 | ``` 58 | 59 | After the above preparation, you can launch the training script with the 'remote-data' argument. 60 | 61 | ``` 62 | $ python3 train.py --remote-data=tcp://0.0.0.0:port 63 | 64 | 2017-09-27 15:58:50,307 INFO Restore pretrained weights... 65 | ``` 66 | 67 | Also, you can quickly train with multiple gpus.
This automatically splits batch into multiple gpus for forward/backward computations. 68 | 69 | ``` 70 | $ python3 train.py --remote-data=tcp://0.0.0.0:port --gpus=8 71 | 72 | 2017-09-27 15:58:50,307 INFO Restore pretrained weights... 73 | ``` 74 | 75 | I trained models within a day with 8 gpus and multiple pre-processing nodes with 48 core cpus. 76 | 77 | ### Model Optimization for Inference 78 | 79 | After trained a model, I optimized models by folding batch normalization to convolutional layers and removing redundant operations. 80 | 81 | Firstly, the model should be frozen. 82 | 83 | ```bash 84 | $ python3 -m tensorflow.python.tools.freeze_graph \ 85 | --input_graph=... \ 86 | --output_graph=... \ 87 | --input_checkpoint=... \ 88 | --output_node_names="Openpose/concat_stage7" 89 | ``` 90 | 91 | And the optimization can be performed on the frozen model via graph transform provided by tensorflow. 92 | 93 | ```bash 94 | $ bazel build tensorflow/tools/graph_transforms:transform_graph 95 | $ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ 96 | --in_graph=... \ 97 | --out_graph=... \ 98 | --inputs='image:0' \ 99 | --outputs='Openpose/concat_stage7:0' \ 100 | --transforms=' 101 | strip_unused_nodes(type=float, shape="1,368,368,3") 102 | remove_nodes(op=Identity, op=CheckNumerics) 103 | fold_constants(ignoreError=False) 104 | fold_old_batch_norms 105 | fold_batch_norms' 106 | ``` 107 | 108 | Also, It is promising to quantize neural network in 8 bit to get futher improvement for speed. In my case, this will make inference less accurate and take more time on Intel's CPUs. 109 | 110 | ``` 111 | $ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ 112 | --in_graph=/Users/ildoonet/repos/tf-openpose/tmp/cmu_640x480/graph_opt.pb \ 113 | --out_graph=/Users/ildoonet/repos/tf-openpose/tmp/cmu_640x480/graph_q.pb \ 114 | --inputs='image' \ 115 | --outputs='Openpose/concat_stage7:0' \ 116 | --transforms='add_default_attributes strip_unused_nodes(type=float, shape="1,360,640,3") 117 | remove_nodes(op=Identity, op=CheckNumerics) fold_constants(ignore_errors=true) 118 | fold_batch_norms fold_old_batch_norms quantize_weights quantize_nodes 119 | strip_unused_nodes sort_by_execution_order' 120 | ``` -------------------------------------------------------------------------------- /images/apink1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/apink1.jpg -------------------------------------------------------------------------------- /images/apink1_crop.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/apink1_crop.jpg -------------------------------------------------------------------------------- /images/apink1_crop_s1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/apink1_crop_s1.jpg -------------------------------------------------------------------------------- /images/apink2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/apink2.jpg -------------------------------------------------------------------------------- 
/images/apink3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/apink3.jpg -------------------------------------------------------------------------------- /images/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/cat.jpg -------------------------------------------------------------------------------- /images/golf.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/golf.jpg -------------------------------------------------------------------------------- /images/hand1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/hand1.jpg -------------------------------------------------------------------------------- /images/hand1_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/hand1_small.jpg -------------------------------------------------------------------------------- /images/hand2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/hand2.jpg -------------------------------------------------------------------------------- /images/handsup1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/handsup1.jpg -------------------------------------------------------------------------------- /images/p1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/p1.jpg -------------------------------------------------------------------------------- /images/p2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/p2.jpg -------------------------------------------------------------------------------- /images/p3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/p3.jpg -------------------------------------------------------------------------------- /images/p3_dance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/p3_dance.png -------------------------------------------------------------------------------- /images/ski.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/ski.jpg 
-------------------------------------------------------------------------------- /images/valid_person1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/valid_person1.jpg -------------------------------------------------------------------------------- /launch/demo_video.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /models/graph/cmu_640x360/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "download model graph : cmu_640x360" 4 | 5 | wget http://download1662.mediafire.com/9w41h1qg57og/fdhqpmdrdzafoc4/graph_freeze.pb -O graph_freeze.pb 6 | wget http://download1650.mediafire.com/35hcd7ukp3fg/n6qnqz00g1pjf7d/graph_opt.pb -O graph_opt.pb 7 | wget http://download1193.mediafire.com/eaoeszlwevfg/38hyjrwfdsyqsbq/graph_q.pb -O graph_q.pb 8 | wget http://download1477.mediafire.com/5mujvsj810xg/a2a0nc8i1oj5iam/graph.pb -O graph.pb 9 | -------------------------------------------------------------------------------- /models/graph/cmu_640x480/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "download model graph : cmu_640x480" 4 | 5 | wget http://download1515.mediafire.com/66p3hu7p6lfg/q7fbue8dqr3ytrb/graph_freeze.pb -O graph_freeze.pb 6 | wget http://download1640.mediafire.com/vqciqfcbz7qg/eolfk6t1t3yb191/graph_opt.pb -O graph_opt.pb 7 | wget http://download843.mediafire.com/zczmlmayrrng/s6d01qvmlkfxgzr/graph_q.pb -O graph_q.pb 8 | wget http://download938.mediafire.com/3mootio0u5ag/ae7hud583cx259z/graph.pb -O graph.pb 9 | -------------------------------------------------------------------------------- /models/graph/mobilenet_thin_432x368/graph_freeze.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/models/graph/mobilenet_thin_432x368/graph_freeze.pb -------------------------------------------------------------------------------- /models/graph/mobilenet_thin_432x368/graph_opt.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/models/graph/mobilenet_thin_432x368/graph_opt.pb -------------------------------------------------------------------------------- /models/numpy/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://www.mediafire.com/file/ropayv77vklvf56/openpose_coco.npy 4 | wget http://www.mediafire.com/file/7e73ddj31rzw6qq/openpose_vgg16.npy 5 | -------------------------------------------------------------------------------- /models/pretrained/mobilenet_v1_0.50_224_2017_06_14/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://www.mediafire.com/file/meu73iq8rxlsd3g/mobilenet_v1_0.50_224.ckpt.data-00000-of-00001 4 | wget 
http://www.mediafire.com/file/7u6iupfkcaxk5hx/mobilenet_v1_0.50_224.ckpt.index 5 | wget http://www.mediafire.com/file/zp8y4d0ytzharzz/mobilenet_v1_0.50_224.ckpt.meta 6 | -------------------------------------------------------------------------------- /models/pretrained/mobilenet_v1_0.75_224_2017_06_14/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://www.mediafire.com/file/kibz0x9e7h11ueb/mobilenet_v1_0.75_224.ckpt.data-00000-of-00001 4 | wget http://www.mediafire.com/file/t8909eaikvc6ea2/mobilenet_v1_0.75_224.ckpt.index 5 | wget http://www.mediafire.com/file/6jjnbn1aged614x/mobilenet_v1_0.75_224.ckpt.meta 6 | -------------------------------------------------------------------------------- /models/pretrained/mobilenet_v1_1.0_224_2017_06_14/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://www.mediafire.com/file/oh6njnz9lgoqwdj/mobilenet_v1_1.0_224.ckpt.data-00000-of-00001 4 | wget http://www.mediafire.com/file/61qln0tbac4ny9o/mobilenet_v1_1.0_224.ckpt.meta 5 | wget http://www.mediafire.com/file/2111rh6tb5fl1lr/mobilenet_v1_1.0_224.ckpt.index 6 | -------------------------------------------------------------------------------- /msg/BodyPartElm.msg: -------------------------------------------------------------------------------- 1 | uint32 part_id 2 | float32 x 3 | float32 y 4 | float32 confidence -------------------------------------------------------------------------------- /msg/Person.msg: -------------------------------------------------------------------------------- 1 | BodyPartElm[] body_part -------------------------------------------------------------------------------- /msg/Persons.msg: -------------------------------------------------------------------------------- 1 | Person[] persons 2 | uint32 image_w 3 | uint32 image_h 4 | Header header -------------------------------------------------------------------------------- /package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | tfpose_ros 4 | 0.1.0 5 | ROS Package for Human Pose Estimation Implemented in Tensorflow 6 | 7 | Curtis Kim 8 | 9 | GNU2.0 10 | 11 | catkin 12 | 13 | rospy 14 | message_generation 15 | std_msgs 16 | cv_bridge 17 | 18 | message_generation 19 | message_runtime 20 | std_msgs 21 | rospy 22 | 23 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argparse 2 | matplotlib 3 | scipy 4 | tqdm 5 | requests 6 | fire 7 | git+https://github.com/ppwwyyxx/tensorpack.git -------------------------------------------------------------------------------- /scripts/broadcaster_ros.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import time 3 | import os 4 | import sys 5 | import ast 6 | 7 | from threading import Lock 8 | import rospy 9 | import rospkg 10 | from cv_bridge import CvBridge, CvBridgeError 11 | from sensor_msgs.msg import Image 12 | from tfpose_ros.msg import Persons, Person, BodyPartElm 13 | 14 | from estimator import TfPoseEstimator 15 | from networks import model_wh, get_graph_path 16 | 17 | 18 | def humans_to_msg(humans): 19 | persons = Persons() 20 | 21 | for human in humans: 22 | person = Person() 23 | 24 | for k in human.body_parts: 25 | body_part = human.body_parts[k] 26 | 27 | body_part_msg = BodyPartElm() 
28 | body_part_msg.part_id = body_part.part_idx 29 | body_part_msg.x = body_part.x 30 | body_part_msg.y = body_part.y 31 | body_part_msg.confidence = body_part.score 32 | person.body_part.append(body_part_msg) 33 | persons.persons.append(person) 34 | 35 | return persons 36 | 37 | 38 | def callback_image(data): 39 | # et = time.time() 40 | try: 41 | cv_image = cv_bridge.imgmsg_to_cv2(data, "bgr8") 42 | except CvBridgeError as e: 43 | rospy.logerr('[ros-video-recorder][VideoFrames] Converting Image Error. ' + str(e)) 44 | return 45 | 46 | acquired = tf_lock.acquire(False) 47 | if not acquired: 48 | return 49 | 50 | try: 51 | humans = pose_estimator.inference(cv_image, scales) 52 | finally: 53 | tf_lock.release() 54 | 55 | msg = humans_to_msg(humans) 56 | msg.image_w = data.width 57 | msg.image_h = data.height 58 | msg.header = data.header 59 | 60 | pub_pose.publish(msg) 61 | # rospy.loginfo(time.time() - et) 62 | 63 | 64 | if __name__ == '__main__': 65 | rospy.loginfo('initialization+') 66 | rospy.init_node('TfPoseEstimatorROS', anonymous=True) 67 | 68 | # parameters 69 | image_topic = rospy.get_param('~camera', '') 70 | model = rospy.get_param('~model', 'cmu_640x480') 71 | scales = rospy.get_param('~scales', '[None]') 72 | scales = ast.literal_eval(scales) 73 | tf_lock = Lock() 74 | 75 | rospy.loginfo('[TfPoseEstimatorROS] scales(%d)=%s' % (len(scales), str(scales))) 76 | 77 | if not image_topic: 78 | rospy.logerr('Parameter \'camera\' is not provided.') 79 | sys.exit(-1) 80 | 81 | try: 82 | w, h = model_wh(model) 83 | graph_path = get_graph_path(model) 84 | 85 | rospack = rospkg.RosPack() 86 | graph_path = os.path.join(rospack.get_path('tfpose_ros'), graph_path) 87 | except Exception as e: 88 | rospy.logerr('invalid model: %s, e=%s' % (model, e)) 89 | sys.exit(-1) 90 | 91 | pose_estimator = TfPoseEstimator(graph_path, target_size=(w, h)) 92 | cv_bridge = CvBridge() 93 | 94 | rospy.Subscriber(image_topic, Image, callback_image, queue_size=1, buff_size=2**24) 95 | pub_pose = rospy.Publisher('~pose', Persons, queue_size=1) 96 | 97 | rospy.loginfo('start+') 98 | rospy.spin() 99 | rospy.loginfo('finished') 100 | -------------------------------------------------------------------------------- /scripts/visualization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import time 3 | import cv2 4 | import rospy 5 | from sensor_msgs.msg import Image 6 | from cv_bridge import CvBridge, CvBridgeError 7 | 8 | from tfpose_ros.msg import Persons, Person, BodyPartElm 9 | from estimator import Human, BodyPart, TfPoseEstimator 10 | 11 | 12 | class VideoFrames: 13 | """ 14 | Reference : ros-video-recorder 15 | https://github.com/ildoonet/ros-video-recorder/blob/master/scripts/recorder.py 16 | """ 17 | def __init__(self, image_topic): 18 | self.image_sub = rospy.Subscriber(image_topic, Image, self.callback_image, queue_size=1) 19 | self.bridge = CvBridge() 20 | self.frames = [] 21 | 22 | def callback_image(self, data): 23 | try: 24 | cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8") 25 | except CvBridgeError as e: 26 | rospy.logerr('Converting Image Error. 
' + str(e)) 27 | return 28 | 29 | self.frames.append((data.header.stamp, cv_image)) 30 | 31 | def get_latest(self, at_time, remove_older=True): 32 | fs = [x for x in self.frames if x[0] <= at_time] 33 | if len(fs) == 0: 34 | return None 35 | 36 | f = fs[-1] 37 | if remove_older: 38 | self.frames = self.frames[len(fs) - 1:] 39 | 40 | return f[1] 41 | 42 | 43 | def cb_pose(data): 44 | # get image with pose time 45 | t = data.header.stamp 46 | image = vf.get_latest(t, remove_older=True) 47 | if image is None: 48 | rospy.logwarn('No received images.') 49 | return 50 | 51 | h, w = image.shape[:2] 52 | if resize_ratio > 0: 53 | image = cv2.resize(image, (int(resize_ratio*w), int(resize_ratio*h)), interpolation=cv2.INTER_LINEAR) 54 | 55 | # ros topic to Person instance 56 | humans = [] 57 | for p_idx, person in enumerate(data.persons): 58 | human = Human([]) 59 | for body_part in person.body_part: 60 | part = BodyPart('', body_part.part_id, body_part.x, body_part.y, body_part.confidence) 61 | human.body_parts[body_part.part_id] = part 62 | 63 | humans.append(human) 64 | 65 | # draw 66 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) 67 | pub_img.publish(cv_bridge.cv2_to_imgmsg(image, 'bgr8')) 68 | 69 | 70 | if __name__ == '__main__': 71 | rospy.loginfo('initialization+') 72 | rospy.init_node('TfPoseEstimatorROS-Visualization', anonymous=True) 73 | 74 | # topics params 75 | image_topic = rospy.get_param('~camera', '') 76 | pose_topic = rospy.get_param('~pose', '/pose_estimator/pose') 77 | 78 | resize_ratio = float(rospy.get_param('~resize_ratio', '-1')) 79 | 80 | # publishers 81 | pub_img = rospy.Publisher('~output', Image, queue_size=1) 82 | 83 | # initialization 84 | cv_bridge = CvBridge() 85 | vf = VideoFrames(image_topic) 86 | rospy.wait_for_message(image_topic, Image, timeout=30) 87 | 88 | # subscribers 89 | rospy.Subscriber(pose_topic, Persons, cb_pose, queue_size=1) 90 | 91 | # run 92 | rospy.spin() 93 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/src/__init__.py -------------------------------------------------------------------------------- /src/common.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import tensorflow as tf 4 | import cv2 5 | 6 | 7 | regularizer_conv = 0.004 8 | regularizer_dsconv = 0.0004 9 | batchnorm_fused = True 10 | activation_fn = tf.nn.relu 11 | 12 | 13 | class CocoPart(Enum): 14 | Nose = 0 15 | Neck = 1 16 | RShoulder = 2 17 | RElbow = 3 18 | RWrist = 4 19 | LShoulder = 5 20 | LElbow = 6 21 | LWrist = 7 22 | RHip = 8 23 | RKnee = 9 24 | RAnkle = 10 25 | LHip = 11 26 | LKnee = 12 27 | LAnkle = 13 28 | REye = 14 29 | LEye = 15 30 | REar = 16 31 | LEar = 17 32 | Background = 18 33 | 34 | 35 | class MPIIPart(Enum): 36 | RAnkle = 0 37 | RKnee = 1 38 | RHip = 2 39 | LHip = 3 40 | LKnee = 4 41 | LAnkle = 5 42 | RWrist = 6 43 | RElbow = 7 44 | RShoulder = 8 45 | LShoulder = 9 46 | LElbow = 10 47 | LWrist = 11 48 | Neck = 12 49 | Head = 13 50 | 51 | @staticmethod 52 | def from_coco(human): 53 | # t = { 54 | # MPIIPart.RAnkle: CocoPart.RAnkle, 55 | # MPIIPart.RKnee: CocoPart.RKnee, 56 | # MPIIPart.RHip: CocoPart.RHip, 57 | # MPIIPart.LHip: CocoPart.LHip, 58 | # MPIIPart.LKnee: CocoPart.LKnee, 59 | # MPIIPart.LAnkle: CocoPart.LAnkle, 60 | # MPIIPart.RWrist: 
CocoPart.RWrist, 61 | # MPIIPart.RElbow: CocoPart.RElbow, 62 | # MPIIPart.RShoulder: CocoPart.RShoulder, 63 | # MPIIPart.LShoulder: CocoPart.LShoulder, 64 | # MPIIPart.LElbow: CocoPart.LElbow, 65 | # MPIIPart.LWrist: CocoPart.LWrist, 66 | # MPIIPart.Neck: CocoPart.Neck, 67 | # MPIIPart.Nose: CocoPart.Nose, 68 | # } 69 | 70 | t = [ 71 | (MPIIPart.Head, CocoPart.Nose), 72 | (MPIIPart.Neck, CocoPart.Neck), 73 | (MPIIPart.RShoulder, CocoPart.RShoulder), 74 | (MPIIPart.RElbow, CocoPart.RElbow), 75 | (MPIIPart.RWrist, CocoPart.RWrist), 76 | (MPIIPart.LShoulder, CocoPart.LShoulder), 77 | (MPIIPart.LElbow, CocoPart.LElbow), 78 | (MPIIPart.LWrist, CocoPart.LWrist), 79 | (MPIIPart.RHip, CocoPart.RHip), 80 | (MPIIPart.RKnee, CocoPart.RKnee), 81 | (MPIIPart.RAnkle, CocoPart.RAnkle), 82 | (MPIIPart.LHip, CocoPart.LHip), 83 | (MPIIPart.LKnee, CocoPart.LKnee), 84 | (MPIIPart.LAnkle, CocoPart.LAnkle), 85 | ] 86 | 87 | pose_2d_mpii = [] 88 | visibilty = [] 89 | for mpi, coco in t: 90 | if coco.value not in human.body_parts.keys(): 91 | pose_2d_mpii.append((0, 0)) 92 | visibilty.append(False) 93 | continue 94 | pose_2d_mpii.append((human.body_parts[coco.value].x, human.body_parts[coco.value].y)) 95 | visibilty.append(True) 96 | return pose_2d_mpii, visibilty 97 | 98 | CocoPairs = [ 99 | (1, 2), (1, 5), (2, 3), (3, 4), (5, 6), (6, 7), (1, 8), (8, 9), (9, 10), (1, 11), 100 | (11, 12), (12, 13), (1, 0), (0, 14), (14, 16), (0, 15), (15, 17), (2, 16), (5, 17) 101 | ] # = 19 102 | CocoPairsRender = CocoPairs[:-2] 103 | CocoPairsNetwork = [ 104 | (12, 13), (20, 21), (14, 15), (16, 17), (22, 23), (24, 25), (0, 1), (2, 3), (4, 5), 105 | (6, 7), (8, 9), (10, 11), (28, 29), (30, 31), (34, 35), (32, 33), (36, 37), (18, 19), (26, 27) 106 | ] # = 19 107 | 108 | CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 109 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 110 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 111 | 112 | 113 | def read_imgfile(path, width, height): 114 | val_image = cv2.imread(path, cv2.IMREAD_COLOR) 115 | if width is not None and height is not None: 116 | val_image = cv2.resize(val_image, (width, height)) 117 | return val_image 118 | 119 | 120 | def get_sample_images(w, h): 121 | val_image = [ 122 | read_imgfile('./images/p1.jpg', w, h), 123 | read_imgfile('./images/p2.jpg', w, h), 124 | read_imgfile('./images/p3.jpg', w, h), 125 | read_imgfile('./images/golf.jpg', w, h), 126 | read_imgfile('./images/hand1.jpg', w, h), 127 | read_imgfile('./images/hand2.jpg', w, h), 128 | read_imgfile('./images/apink1_crop.jpg', w, h), 129 | read_imgfile('./images/ski.jpg', w, h), 130 | read_imgfile('./images/apink2.jpg', w, h), 131 | read_imgfile('./images/apink3.jpg', w, h), 132 | read_imgfile('./images/handsup1.jpg', w, h), 133 | read_imgfile('./images/p3_dance.png', w, h), 134 | ] 135 | return val_image 136 | -------------------------------------------------------------------------------- /src/datum_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: datum.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='datum.proto', 20 | package='', 21 | serialized_pb=_b('\n\x0b\x64\x61tum.proto\"\x81\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _DATUM = _descriptor.Descriptor( 29 | name='Datum', 30 | full_name='Datum', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='channels', full_name='Datum.channels', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='Datum.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='width', full_name='Datum.width', index=2, 51 | number=3, type=5, cpp_type=1, label=1, 52 | has_default_value=False, default_value=0, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='data', full_name='Datum.data', index=3, 58 | number=4, type=12, cpp_type=9, label=1, 59 | has_default_value=False, default_value=_b(""), 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='label', full_name='Datum.label', index=4, 65 | number=5, type=5, cpp_type=1, label=1, 66 | has_default_value=False, default_value=0, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | _descriptor.FieldDescriptor( 71 | name='float_data', full_name='Datum.float_data', index=5, 72 | number=6, type=2, cpp_type=6, label=3, 73 | has_default_value=False, default_value=[], 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None), 77 | _descriptor.FieldDescriptor( 78 | name='encoded', full_name='Datum.encoded', index=6, 79 | number=7, type=8, cpp_type=7, label=1, 80 | has_default_value=True, default_value=False, 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None), 84 | ], 85 | extensions=[ 86 | ], 87 | nested_types=[], 88 | enum_types=[ 89 | ], 90 | options=None, 91 | is_extendable=False, 92 | extension_ranges=[], 93 | 
oneofs=[ 94 | ], 95 | serialized_start=16, 96 | serialized_end=145, 97 | ) 98 | 99 | DESCRIPTOR.message_types_by_name['Datum'] = _DATUM 100 | 101 | Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict( 102 | DESCRIPTOR = _DATUM, 103 | __module__ = 'datum_pb2' 104 | # @@protoc_insertion_point(class_scope:Datum) 105 | )) 106 | _sym_db.RegisterMessage(Datum) 107 | 108 | 109 | # @@protoc_insertion_point(module_scope) 110 | -------------------------------------------------------------------------------- /src/lifting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/src/lifting/__init__.py -------------------------------------------------------------------------------- /src/lifting/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mar 23 11:57 2017 4 | 5 | @author: Denis Tome' 6 | """ 7 | 8 | __all__ = [ 9 | 'VISIBLE_PART', 10 | 'MIN_NUM_JOINTS', 11 | 'CENTER_TR', 12 | 'SIGMA', 13 | 'STRIDE', 14 | 'SIGMA_CENTER', 15 | 'INPUT_SIZE', 16 | 'OUTPUT_SIZE', 17 | 'NUM_JOINTS', 18 | 'NUM_OUTPUT', 19 | 'H36M_NUM_JOINTS', 20 | 'JOINT_DRAW_SIZE', 21 | 'LIMB_DRAW_SIZE' 22 | ] 23 | 24 | # threshold 25 | VISIBLE_PART = 1e-3 26 | MIN_NUM_JOINTS = 5 27 | CENTER_TR = 0.4 28 | 29 | # net attributes 30 | SIGMA = 7 31 | STRIDE = 8 32 | SIGMA_CENTER = 21 33 | INPUT_SIZE = 368 34 | OUTPUT_SIZE = 46 35 | NUM_JOINTS = 14 36 | NUM_OUTPUT = NUM_JOINTS + 1 37 | H36M_NUM_JOINTS = 17 38 | 39 | # draw options 40 | JOINT_DRAW_SIZE = 3 41 | LIMB_DRAW_SIZE = 2 42 | NORMALISATION_COEFFICIENT = 1280*720 43 | -------------------------------------------------------------------------------- /src/lifting/draw.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mar 23 15:04 2017 4 | 5 | @author: Denis Tome' 6 | """ 7 | import cv2 8 | import numpy as np 9 | from .config import JOINT_DRAW_SIZE 10 | from .config import LIMB_DRAW_SIZE 11 | from .config import NORMALISATION_COEFFICIENT 12 | import matplotlib.pyplot as plt 13 | import math 14 | 15 | __all__ = [ 16 | 'draw_limbs', 17 | 'plot_pose' 18 | ] 19 | 20 | 21 | def draw_limbs(image, pose_2d, visible): 22 | """Draw the 2D pose without the occluded/not visible joints.""" 23 | 24 | _COLORS = [ 25 | [0, 0, 255], [0, 170, 255], [0, 255, 170], [0, 255, 0], 26 | [170, 255, 0], [255, 170, 0], [255, 0, 0], [255, 0, 170], 27 | [170, 0, 255] 28 | ] 29 | _LIMBS = np.array([0, 1, 2, 3, 3, 4, 5, 6, 6, 7, 8, 9, 30 | 9, 10, 11, 12, 12, 13]).reshape((-1, 2)) 31 | 32 | _NORMALISATION_FACTOR = int(math.floor(math.sqrt(image.shape[0] * image.shape[1] / NORMALISATION_COEFFICIENT))) 33 | 34 | for oid in range(pose_2d.shape[0]): 35 | for lid, (p0, p1) in enumerate(_LIMBS): 36 | if not (visible[oid][p0] and visible[oid][p1]): 37 | continue 38 | y0, x0 = pose_2d[oid][p0] 39 | y1, x1 = pose_2d[oid][p1] 40 | cv2.circle(image, (x0, y0), JOINT_DRAW_SIZE *_NORMALISATION_FACTOR , _COLORS[lid], -1) 41 | cv2.circle(image, (x1, y1), JOINT_DRAW_SIZE*_NORMALISATION_FACTOR , _COLORS[lid], -1) 42 | cv2.line(image, (x0, y0), (x1, y1), 43 | _COLORS[lid], LIMB_DRAW_SIZE*_NORMALISATION_FACTOR , 16) 44 | 45 | 46 | def plot_pose(pose): 47 | """Plot the 3D pose showing the joint connections.""" 48 | import mpl_toolkits.mplot3d.axes3d as p3 49 | 50 | _CONNECTION = [ 51 | [0, 1], [1, 
2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7], [7, 8], 52 | [8, 9], [9, 10], [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], 53 | [15, 16]] 54 | 55 | def joint_color(j): 56 | """ 57 | TODO: 'j' shadows name 'j' from outer scope 58 | """ 59 | 60 | colors = [(0, 0, 0), (255, 0, 255), (0, 0, 255), 61 | (0, 255, 255), (255, 0, 0), (0, 255, 0)] 62 | _c = 0 63 | if j in range(1, 4): 64 | _c = 1 65 | if j in range(4, 7): 66 | _c = 2 67 | if j in range(9, 11): 68 | _c = 3 69 | if j in range(11, 14): 70 | _c = 4 71 | if j in range(14, 17): 72 | _c = 5 73 | return colors[_c] 74 | 75 | assert (pose.ndim == 2) 76 | assert (pose.shape[0] == 3) 77 | fig = plt.figure() 78 | ax = fig.gca(projection='3d') 79 | for c in _CONNECTION: 80 | col = '#%02x%02x%02x' % joint_color(c[0]) 81 | ax.plot([pose[0, c[0]], pose[0, c[1]]], 82 | [pose[1, c[0]], pose[1, c[1]]], 83 | [pose[2, c[0]], pose[2, c[1]]], c=col) 84 | for j in range(pose.shape[1]): 85 | col = '#%02x%02x%02x' % joint_color(j) 86 | ax.scatter(pose[0, j], pose[1, j], pose[2, j], 87 | c=col, marker='o', edgecolor=col) 88 | smallest = pose.min() 89 | largest = pose.max() 90 | ax.set_xlim3d(smallest, largest) 91 | ax.set_ylim3d(smallest, largest) 92 | ax.set_zlim3d(smallest, largest) 93 | 94 | return fig 95 | 96 | 97 | -------------------------------------------------------------------------------- /src/lifting/models/prob_model_params.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/src/lifting/models/prob_model_params.mat -------------------------------------------------------------------------------- /src/lifting/prob_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Apr 21 13:53 2017 4 | 5 | @author: Denis Tome' 6 | """ 7 | import numpy as np 8 | import scipy.io as sio 9 | from lifting.upright_fast import pick_e 10 | 11 | from lifting import config 12 | 13 | __all__ = ['Prob3dPose'] 14 | 15 | 16 | class Prob3dPose: 17 | 18 | def __init__(self, prob_model_path): 19 | model_param = sio.loadmat(prob_model_path) 20 | self.mu = np.reshape( 21 | model_param['mu'], (model_param['mu'].shape[0], 3, -1)) 22 | self.e = np.reshape(model_param['e'], (model_param['e'].shape[ 23 | 0], model_param['e'].shape[1], 3, -1)) 24 | self.sigma = model_param['sigma'] 25 | self.cam = np.array( 26 | [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]]) 27 | 28 | @staticmethod 29 | def cost3d(model, gt): 30 | """3d error in mm""" 31 | out = np.sqrt(((gt - model) ** 2).sum(1)).mean(-1) 32 | return out 33 | 34 | @staticmethod 35 | def renorm_gt(gt): 36 | """Compel gt data to have mean joint length of one""" 37 | _POSE_TREE = np.asarray([ 38 | [0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7], [7, 8], 39 | [8, 9], [9, 10], [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], 40 | [15, 16]]).T 41 | scale = np.sqrt(((gt[:, :, _POSE_TREE[0]] - 42 | gt[:, :, _POSE_TREE[1]]) ** 2).sum(2).sum(1)) 43 | return gt / scale[:, np.newaxis, np.newaxis] 44 | 45 | @staticmethod 46 | def build_model(a, e, s0): 47 | """Build 3D model""" 48 | assert (s0.shape[1] == 3) 49 | assert (e.shape[2] == 3) 50 | assert (a.shape[1] == e.shape[1]) 51 | out = np.einsum('...i,...ijk', a, e) 52 | out += s0 53 | return out 54 | 55 | @staticmethod 56 | def build_and_rot_model(a, e, s0, r): 57 | """ 58 | Build model and rotate according to the identified rotation matrix 59 | """ 60 | from 
numpy.core.umath_tests import matrix_multiply 61 | 62 | r2 = Prob3dPose.upgrade_r(r.T).transpose((0, 2, 1)) 63 | mod = Prob3dPose.build_model(a, e, s0) 64 | mod = matrix_multiply(r2, mod) 65 | return mod 66 | 67 | @staticmethod 68 | def upgrade_r(r): 69 | """ 70 | Upgrades complex parameterisation of planar rotation to tensor 71 | containing per frame 3x3 rotation matrices 72 | """ 73 | assert (r.ndim == 2) 74 | # Technically optional assert, but if this fails data is probably 75 | # transposed 76 | assert (r.shape[1] == 2) 77 | assert (np.all(np.isfinite(r))) 78 | norm = np.sqrt((r[:, :2] ** 2).sum(1)) 79 | assert (np.all(norm > 0)) 80 | r /= norm[:, np.newaxis] 81 | assert (np.all(np.isfinite(r))) 82 | newr = np.zeros((r.shape[0], 3, 3)) 83 | newr[:, :2, 0] = r[:, :2] 84 | newr[:, 2, 2] = 1 85 | newr[:, 1::-1, 1] = r[:, :2] 86 | newr[:, 0, 1] *= -1 87 | return newr 88 | 89 | @staticmethod 90 | def centre(data_2d): 91 | """center data according to each of the coordiante components""" 92 | return (data_2d.T - data_2d.mean(1)).T 93 | 94 | @staticmethod 95 | def centre_all(data): 96 | """center all data""" 97 | if data.ndim == 2: 98 | return Prob3dPose.centre(data) 99 | return (data.transpose(2, 0, 1) - data.mean(2)).transpose(1, 2, 0) 100 | 101 | @staticmethod 102 | def normalise_data(d2, weights): 103 | """Normalise data according to height""" 104 | 105 | # the joints with weight set to 0 should not be considered in the 106 | # normalisation process 107 | d2 = d2.reshape(d2.shape[0], -1, 2).transpose(0, 2, 1) 108 | idx_consider = weights[0, 0].astype(np.bool) 109 | if np.sum(weights[:, 0].sum(1) >= config.MIN_NUM_JOINTS) == 0: 110 | raise Exception('Not enough 2D joints identified to generate 3D pose') 111 | d2[:, :, idx_consider] = Prob3dPose.centre_all(d2[:, :, idx_consider]) 112 | 113 | # Height normalisation (2 meters) 114 | m2 = d2[:, 1, idx_consider].min(1) / 2.0 115 | m2 -= d2[:, 1, idx_consider].max(1) / 2.0 116 | crap = m2 == 0 117 | m2[crap] = 1.0 118 | d2[:, :, idx_consider] /= m2[:, np.newaxis, np.newaxis] 119 | return d2, m2 120 | 121 | @staticmethod 122 | def transform_joints(pose_2d, visible_joints): 123 | """ 124 | Transform the set of joints according to what the probabilistic model 125 | expects as input. 126 | 127 | It returns the new set of joints of each of the people and the set of 128 | weights for the joints. 
129 | """ 130 | 131 | _H36M_ORDER = [8, 9, 10, 11, 12, 13, 1, 0, 5, 6, 7, 2, 3, 4] 132 | _W_POS = [1, 2, 3, 4, 5, 6, 8, 10, 11, 12, 13, 14, 15, 16] 133 | 134 | def swap_xy(poses): 135 | tmp = np.copy(poses[:, :, 0]) 136 | poses[:, :, 0] = poses[:, :, 1] 137 | poses[:, :, 1] = tmp 138 | return poses 139 | 140 | assert (pose_2d.ndim == 3) 141 | new_pose = pose_2d.copy() 142 | # new_pose = swap_xy(new_pose) # not used 143 | new_pose = new_pose[:, _H36M_ORDER] 144 | 145 | # defining weights according to occlusions 146 | weights = np.zeros((pose_2d.shape[0], 2, config.H36M_NUM_JOINTS)) 147 | ordered_visibility = np.repeat( 148 | visible_joints[:, _H36M_ORDER, np.newaxis], 2, 2 149 | ).transpose([0, 2, 1]) 150 | weights[:, :, _W_POS] = ordered_visibility 151 | return new_pose, weights 152 | 153 | def affine_estimate(self, w, depth_reg=0.085, weights=None, scale=10.0, 154 | scale_mean=0.0016 * 1.8 * 1.2, scale_std=1.2 * 0, 155 | cap_scale=-0.00129): 156 | """ 157 | Quick switch to allow reconstruction at unknown scale returns a,r 158 | and scale 159 | """ 160 | weights = np.zeros((0, 0, 0)) if weights is None else weights 161 | 162 | s = np.empty((self.sigma.shape[0], self.sigma.shape[1] + 4)) # e,y,x,z 163 | s[:, :4] = 10 ** -5 # Tiny but makes stuff well-posed 164 | s[:, 0] = scale_std 165 | s[:, 4:] = self.sigma 166 | s[:, 4:-1] *= scale 167 | 168 | e2 = np.zeros((self.e.shape[0], self.e.shape[ 169 | 1] + 4, 3, self.e.shape[3])) 170 | e2[:, 1, 0] = 1.0 171 | e2[:, 2, 1] = 1.0 172 | e2[:, 3, 0] = 1.0 173 | # This makes the least_squares problem ill posed, as X,Z are 174 | # interchangable 175 | # Hence regularisation above to speed convergence and stop blow-up 176 | e2[:, 0] = self.mu 177 | e2[:, 4:] = self.e 178 | t_m = np.zeros_like(self.mu) 179 | 180 | res, a, r = pick_e(w, e2, t_m, self.cam, s, weights=weights, 181 | interval=0.01, depth_reg=depth_reg, 182 | scale_prior=scale_mean) 183 | 184 | scale = a[:, :, 0] 185 | reestimate = scale > cap_scale 186 | m = self.mu * cap_scale 187 | for i in range(scale.shape[0]): 188 | if reestimate[i].sum() > 0: 189 | ehat = e2[i:i + 1, 1:] 190 | mhat = m[i:i + 1] 191 | shat = s[i:i + 1, 1:] 192 | (res2, a2, r2) = pick_e( 193 | w[reestimate[i]], ehat, mhat, self.cam, shat, 194 | weights=weights[reestimate[i]], 195 | interval=0.01, depth_reg=depth_reg, 196 | scale_prior=scale_mean 197 | ) 198 | res[i:i + 1, reestimate[i]] = res2 199 | a[i:i + 1, reestimate[i], 1:] = a2 200 | a[i:i + 1, reestimate[i], 0] = cap_scale 201 | r[i:i + 1, :, reestimate[i]] = r2 202 | scale = a[:, :, 0] 203 | a = a[:, :, 1:] / a[:, :, 0][:, :, np.newaxis] 204 | return res, e2[:, 1:], a, r, scale 205 | 206 | def better_rec(self, w, model, s=1, weights=1, damp_z=1): 207 | """Quick switch to allow reconstruction at unknown scale 208 | returns a,r and scale""" 209 | from numpy.core.umath_tests import matrix_multiply 210 | proj = matrix_multiply(self.cam[np.newaxis], model) 211 | proj[:, :2] = (proj[:, :2] * s + w * weights) / (s + weights) 212 | proj[:, 2] *= damp_z 213 | out = matrix_multiply(self.cam.T[np.newaxis], proj) 214 | return out 215 | 216 | def create_rec(self, w2, weights, res_weight=1): 217 | """Reconstruct 3D pose given a 2D pose""" 218 | _SIGMA_SCALING = 5.2 219 | 220 | res, e, a, r, scale = self.affine_estimate( 221 | w2, scale=_SIGMA_SCALING, weights=weights, 222 | depth_reg=0, cap_scale=-0.001, scale_mean=-0.003 223 | ) 224 | 225 | remaining_dims = 3 * w2.shape[2] - e.shape[1] 226 | assert (remaining_dims >= 0) 227 | llambda = -np.log(self.sigma) 228 | lgdet = 
np.sum(llambda[:, :-1], 1) + llambda[:, -1] * remaining_dims 229 | score = (res * res_weight + lgdet[:, np.newaxis] * (scale ** 2)) 230 | best = np.argmin(score, 0) 231 | index = np.arange(best.shape[0]) 232 | a2 = a[best, index] 233 | r2 = r[best, :, index].T 234 | rec = Prob3dPose.build_and_rot_model(a2, e[best], self.mu[best], r2) 235 | rec *= -np.abs(scale[best, index])[:, np.newaxis, np.newaxis] 236 | 237 | rec = self.better_rec(w2, rec, 1, 1.55 * weights, 1) * -1 238 | rec = Prob3dPose.renorm_gt(rec) 239 | rec *= 0.97 240 | return rec 241 | 242 | def compute_3d(self, pose_2d, weights): 243 | """Reconstruct 3D poses given 2D estimations""" 244 | 245 | _J_POS = [1, 2, 3, 4, 5, 6, 8, 10, 11, 12, 13, 14, 15, 16] 246 | _SCALE_3D = 1174.88312988 247 | 248 | if pose_2d.shape[1] != config.H36M_NUM_JOINTS: 249 | # need to call the linear regressor 250 | reg_joints = np.zeros( 251 | (pose_2d.shape[0], config.H36M_NUM_JOINTS, 2)) 252 | for oid, singe_pose in enumerate(pose_2d): 253 | reg_joints[oid, _J_POS] = singe_pose 254 | 255 | norm_pose, _ = Prob3dPose.normalise_data(reg_joints, weights) 256 | else: 257 | norm_pose, _ = Prob3dPose.normalise_data(pose_2d, weights) 258 | 259 | pose_3d = self.create_rec(norm_pose, weights) * _SCALE_3D 260 | return pose_3d 261 | -------------------------------------------------------------------------------- /src/network_cmu.py: -------------------------------------------------------------------------------- 1 | import network_base 2 | import tensorflow as tf 3 | 4 | 5 | class CmuNetwork(network_base.BaseNetwork): 6 | def setup(self): 7 | (self.feed('image') 8 | .normalize_vgg(name='preprocess') 9 | .conv(3, 3, 64, 1, 1, name='conv1_1') 10 | .conv(3, 3, 64, 1, 1, name='conv1_2') 11 | .max_pool(2, 2, 2, 2, name='pool1_stage1') 12 | .conv(3, 3, 128, 1, 1, name='conv2_1') 13 | .conv(3, 3, 128, 1, 1, name='conv2_2') 14 | .max_pool(2, 2, 2, 2, name='pool2_stage1') 15 | .conv(3, 3, 256, 1, 1, name='conv3_1') 16 | .conv(3, 3, 256, 1, 1, name='conv3_2') 17 | .conv(3, 3, 256, 1, 1, name='conv3_3') 18 | .conv(3, 3, 256, 1, 1, name='conv3_4') 19 | .max_pool(2, 2, 2, 2, name='pool3_stage1') 20 | .conv(3, 3, 512, 1, 1, name='conv4_1') 21 | .conv(3, 3, 512, 1, 1, name='conv4_2') 22 | .conv(3, 3, 256, 1, 1, name='conv4_3_CPM') 23 | .conv(3, 3, 128, 1, 1, name='conv4_4_CPM') # ***** 24 | 25 | .conv(3, 3, 128, 1, 1, name='conv5_1_CPM_L1') 26 | .conv(3, 3, 128, 1, 1, name='conv5_2_CPM_L1') 27 | .conv(3, 3, 128, 1, 1, name='conv5_3_CPM_L1') 28 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L1') 29 | .conv(1, 1, 38, 1, 1, relu=False, name='conv5_5_CPM_L1')) 30 | 31 | (self.feed('conv4_4_CPM') 32 | .conv(3, 3, 128, 1, 1, name='conv5_1_CPM_L2') 33 | .conv(3, 3, 128, 1, 1, name='conv5_2_CPM_L2') 34 | .conv(3, 3, 128, 1, 1, name='conv5_3_CPM_L2') 35 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L2') 36 | .conv(1, 1, 19, 1, 1, relu=False, name='conv5_5_CPM_L2')) 37 | 38 | (self.feed('conv5_5_CPM_L1', 39 | 'conv5_5_CPM_L2', 40 | 'conv4_4_CPM') 41 | .concat(3, name='concat_stage2') 42 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage2_L1') 43 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage2_L1') 44 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage2_L1') 45 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage2_L1') 46 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage2_L1') 47 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L1') 48 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage2_L1')) 49 | 50 | (self.feed('concat_stage2') 51 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage2_L2') 52 | .conv(7, 7, 128, 1, 1, 
name='Mconv2_stage2_L2') 53 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage2_L2') 54 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage2_L2') 55 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage2_L2') 56 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L2') 57 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage2_L2')) 58 | 59 | (self.feed('Mconv7_stage2_L1', 60 | 'Mconv7_stage2_L2', 61 | 'conv4_4_CPM') 62 | .concat(3, name='concat_stage3') 63 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage3_L1') 64 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage3_L1') 65 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage3_L1') 66 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage3_L1') 67 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage3_L1') 68 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L1') 69 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage3_L1')) 70 | 71 | (self.feed('concat_stage3') 72 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage3_L2') 73 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage3_L2') 74 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage3_L2') 75 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage3_L2') 76 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage3_L2') 77 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L2') 78 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage3_L2')) 79 | 80 | (self.feed('Mconv7_stage3_L1', 81 | 'Mconv7_stage3_L2', 82 | 'conv4_4_CPM') 83 | .concat(3, name='concat_stage4') 84 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage4_L1') 85 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage4_L1') 86 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage4_L1') 87 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage4_L1') 88 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage4_L1') 89 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L1') 90 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage4_L1')) 91 | 92 | (self.feed('concat_stage4') 93 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage4_L2') 94 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage4_L2') 95 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage4_L2') 96 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage4_L2') 97 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage4_L2') 98 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L2') 99 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage4_L2')) 100 | 101 | (self.feed('Mconv7_stage4_L1', 102 | 'Mconv7_stage4_L2', 103 | 'conv4_4_CPM') 104 | .concat(3, name='concat_stage5') 105 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage5_L1') 106 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage5_L1') 107 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage5_L1') 108 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage5_L1') 109 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage5_L1') 110 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L1') 111 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage5_L1')) 112 | 113 | (self.feed('concat_stage5') 114 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage5_L2') 115 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage5_L2') 116 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage5_L2') 117 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage5_L2') 118 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage5_L2') 119 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L2') 120 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage5_L2')) 121 | 122 | (self.feed('Mconv7_stage5_L1', 123 | 'Mconv7_stage5_L2', 124 | 'conv4_4_CPM') 125 | .concat(3, name='concat_stage6') 126 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage6_L1') 127 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage6_L1') 128 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage6_L1') 129 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage6_L1') 130 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage6_L1') 131 | .conv(1, 1, 128, 1, 1, 
name='Mconv6_stage6_L1') 132 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage6_L1')) 133 | 134 | (self.feed('concat_stage6') 135 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage6_L2') 136 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage6_L2') 137 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage6_L2') 138 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage6_L2') 139 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage6_L2') 140 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L2') 141 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage6_L2')) 142 | 143 | with tf.variable_scope('Openpose'): 144 | (self.feed('Mconv7_stage6_L2', 145 | 'Mconv7_stage6_L1') 146 | .concat(3, name='concat_stage7')) 147 | 148 | def loss_l1_l2(self): 149 | l1s = [] 150 | l2s = [] 151 | for layer_name in self.layers.keys(): 152 | if 'Mconv7' in layer_name and '_L1' in layer_name: 153 | l1s.append(self.layers[layer_name]) 154 | if 'Mconv7' in layer_name and '_L2' in layer_name: 155 | l2s.append(self.layers[layer_name]) 156 | 157 | return l1s, l2s 158 | 159 | def loss_last(self): 160 | return self.get_output('Mconv7_stage6_L1'), self.get_output('Mconv7_stage6_L2') 161 | 162 | def restorable_variables(self): 163 | return None -------------------------------------------------------------------------------- /src/network_dsconv.py: -------------------------------------------------------------------------------- 1 | import network_base 2 | 3 | 4 | class DSConvNetwork(network_base.BaseNetwork): 5 | def __init__(self, inputs, trainable=True, conv_width=1.0): 6 | self.conv_width = conv_width 7 | network_base.BaseNetwork.__init__(self, inputs, trainable) 8 | 9 | def setup(self): 10 | (self.feed('image') 11 | .conv(3, 3, 64, 1, 1, name='conv1_1', trainable=False) 12 | # .conv(3, 3, 64, 1, 1, name='conv1_2', trainable=True) # TODO 13 | .separable_conv(3, 3, round(self.conv_width * 64), 2, name='conv1_2') 14 | # .max_pool(2, 2, 2, 2, name='pool1_stage1') 15 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv2_1') 16 | .separable_conv(3, 3, round(self.conv_width * 128), 2, name='conv2_2') 17 | # .max_pool(2, 2, 2, 2, name='pool2_stage1') 18 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_1') 19 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_2') 20 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_3') 21 | .separable_conv(3, 3, round(self.conv_width * 256), 2, name='conv3_4') 22 | # .max_pool(2, 2, 2, 2, name='pool3_stage1') 23 | .separable_conv(3, 3, round(self.conv_width * 512), 1, name='conv4_1') 24 | .separable_conv(3, 3, round(self.conv_width * 512), 1, name='conv4_2') 25 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv4_3_CPM') 26 | .separable_conv(3, 3, 128, 1, name='conv4_4_CPM') 27 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_1_CPM_L1') 28 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_2_CPM_L1') 29 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_3_CPM_L1') 30 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L1') 31 | .conv(1, 1, 38, 1, 1, relu=False, name='conv5_5_CPM_L1')) 32 | 33 | (self.feed('conv4_4_CPM') 34 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_1_CPM_L2') 35 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_2_CPM_L2') 36 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_3_CPM_L2') 37 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L2') 38 | .conv(1, 1, 19, 1, 1, relu=False, name='conv5_5_CPM_L2')) 39 | 40 | 
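# stage-1 heads end here, mirroring CmuNetwork above: conv5_5_CPM_L1 is the 38-channel part-affinity-field branch
# and conv5_5_CPM_L2 the 19-channel keypoint-heatmap branch; stages 2-6 below re-concatenate both with the shared conv4_4_CPM features and refine them.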
(self.feed('conv5_5_CPM_L1', 41 | 'conv5_5_CPM_L2', 42 | 'conv4_4_CPM') 43 | .concat(3, name='concat_stage2') 44 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage2_L1') 45 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage2_L1') 46 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage2_L1') 47 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage2_L1') 48 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage2_L1') 49 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L1') 50 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage2_L1')) 51 | 52 | (self.feed('concat_stage2') 53 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage2_L2') 54 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage2_L2') 55 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage2_L2') 56 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage2_L2') 57 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage2_L2') 58 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L2') 59 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage2_L2')) 60 | 61 | (self.feed('Mconv7_stage2_L1', 62 | 'Mconv7_stage2_L2', 63 | 'conv4_4_CPM') 64 | .concat(3, name='concat_stage3') 65 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage3_L1') 66 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage3_L1') 67 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage3_L1') 68 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage3_L1') 69 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage3_L1') 70 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L1') 71 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage3_L1')) 72 | 73 | (self.feed('concat_stage3') 74 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage3_L2') 75 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage3_L2') 76 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage3_L2') 77 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage3_L2') 78 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage3_L2') 79 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L2') 80 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage3_L2')) 81 | 82 | (self.feed('Mconv7_stage3_L1', 83 | 'Mconv7_stage3_L2', 84 | 'conv4_4_CPM') 85 | .concat(3, name='concat_stage4') 86 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage4_L1') 87 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage4_L1') 88 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage4_L1') 89 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage4_L1') 90 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage4_L1') 91 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L1') 92 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage4_L1')) 93 | 94 | (self.feed('concat_stage4') 95 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage4_L2') 96 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage4_L2') 97 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage4_L2') 98 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage4_L2') 99 | .separable_conv(7, 7, 
round(self.conv_width * 128), 1, name='Mconv5_stage4_L2') 100 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L2') 101 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage4_L2')) 102 | 103 | (self.feed('Mconv7_stage4_L1', 104 | 'Mconv7_stage4_L2', 105 | 'conv4_4_CPM') 106 | .concat(3, name='concat_stage5') 107 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage5_L1') 108 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage5_L1') 109 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage5_L1') 110 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage5_L1') 111 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage5_L1') 112 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L1') 113 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage5_L1')) 114 | 115 | (self.feed('concat_stage5') 116 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage5_L2') 117 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage5_L2') 118 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage5_L2') 119 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage5_L2') 120 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage5_L2') 121 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L2') 122 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage5_L2')) 123 | 124 | (self.feed('Mconv7_stage5_L1', 125 | 'Mconv7_stage5_L2', 126 | 'conv4_4_CPM') 127 | .concat(3, name='concat_stage6') 128 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage6_L1') 129 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage6_L1') 130 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage6_L1') 131 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage6_L1') 132 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage6_L1') 133 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L1') 134 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage6_L1')) 135 | 136 | (self.feed('concat_stage6') 137 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage6_L2') 138 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage6_L2') 139 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage6_L2') 140 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage6_L2') 141 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage6_L2') 142 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L2') 143 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage6_L2')) 144 | 145 | (self.feed('Mconv7_stage6_L2', 146 | 'Mconv7_stage6_L1') 147 | .concat(3, name='concat_stage7')) 148 | -------------------------------------------------------------------------------- /src/network_mobilenet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import network_base 4 | 5 | 6 | class MobilenetNetwork(network_base.BaseNetwork): 7 | def __init__(self, inputs, trainable=True, conv_width=1.0, conv_width2=None): 8 | self.conv_width = conv_width 9 | self.conv_width2 = conv_width2 if conv_width2 else conv_width 10 | self.num_refine = 4 11 | network_base.BaseNetwork.__init__(self, inputs, trainable) 12 | 13 | def setup(self): 14 | min_depth = 8 15 | depth = lambda d: max(int(d * self.conv_width), min_depth) 16 | depth2 = lambda d: max(int(d * 
self.conv_width2), min_depth) 17 | 18 | with tf.variable_scope(None, 'MobilenetV1'): 19 | (self.feed('image') 20 | .convb(3, 3, depth(32), 2, name='Conv2d_0') 21 | .separable_conv(3, 3, depth(64), 1, name='Conv2d_1') 22 | .separable_conv(3, 3, depth(128), 2, name='Conv2d_2') 23 | .separable_conv(3, 3, depth(128), 1, name='Conv2d_3') 24 | .separable_conv(3, 3, depth(256), 2, name='Conv2d_4') 25 | .separable_conv(3, 3, depth(256), 1, name='Conv2d_5') 26 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_6') 27 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_7') 28 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_8') 29 | # .separable_conv(3, 3, depth(512), 1, name='Conv2d_9') 30 | # .separable_conv(3, 3, depth(512), 1, name='Conv2d_10') 31 | # .separable_conv(3, 3, depth(512), 1, name='Conv2d_11') 32 | # .separable_conv(3, 3, depth(1024), 2, name='Conv2d_12') 33 | # .separable_conv(3, 3, depth(1024), 1, name='Conv2d_13') 34 | ) 35 | 36 | (self.feed('Conv2d_1').max_pool(2, 2, 2, 2, name='Conv2d_1_pool')) 37 | (self.feed('Conv2d_7').upsample(2, name='Conv2d_7_upsample')) 38 | 39 | (self.feed('Conv2d_1_pool', 'Conv2d_3', 'Conv2d_7_upsample') 40 | .concat(3, name='feat_concat')) 41 | 42 | feature_lv = 'feat_concat' 43 | with tf.variable_scope(None, 'Openpose'): 44 | prefix = 'MConv_Stage1' 45 | (self.feed(feature_lv) 46 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_1') 47 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_2') 48 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_3') 49 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L1_4') 50 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5')) 51 | 52 | (self.feed(feature_lv) 53 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_1') 54 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_2') 55 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_3') 56 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L2_4') 57 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5')) 58 | 59 | for stage_id in range(self.num_refine): 60 | prefix_prev = 'MConv_Stage%d' % (stage_id + 1) 61 | prefix = 'MConv_Stage%d' % (stage_id + 2) 62 | (self.feed(prefix_prev + '_L1_5', 63 | prefix_prev + '_L2_5', 64 | feature_lv) 65 | .concat(3, name=prefix + '_concat') 66 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L1_1') 67 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L1_2') 68 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L1_3') 69 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L1_4') 70 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5')) 71 | 72 | (self.feed(prefix + '_concat') 73 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L2_1') 74 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L2_2') 75 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L2_3') 76 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L2_4') 77 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5')) 78 | 79 | # final result 80 | (self.feed('MConv_Stage%d_L2_5' % self.get_refine_num(), 81 | 'MConv_Stage%d_L1_5' % self.get_refine_num()) 82 | .concat(3, name='concat_stage7')) 83 | 84 | def loss_l1_l2(self): 85 | l1s = [] 86 | l2s = [] 87 | for layer_name in sorted(self.layers.keys()): 88 | if '_L1_5' in layer_name: 89 | l1s.append(self.layers[layer_name]) 90 | if '_L2_5' in layer_name: 91 | l2s.append(self.layers[layer_name]) 92 | 93 | return l1s, l2s 94 | 95 | def loss_last(self): 96 | return 
self.get_output('MConv_Stage%d_L1_5' % self.get_refine_num()), \ 97 | self.get_output('MConv_Stage%d_L2_5' % self.get_refine_num()) 98 | 99 | def restorable_variables(self): 100 | vs = {v.op.name: v for v in tf.global_variables() if 101 | 'MobilenetV1/Conv2d' in v.op.name and 102 | 'RMSProp' not in v.op.name and 'Momentum' not in v.op.name and 'Ada' not in v.op.name 103 | } 104 | return vs 105 | 106 | def get_refine_num(self): 107 | return self.num_refine + 1 108 | -------------------------------------------------------------------------------- /src/network_mobilenet_thin.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import network_base 4 | 5 | 6 | class MobilenetNetworkThin(network_base.BaseNetwork): 7 | def __init__(self, inputs, trainable=True, conv_width=1.0, conv_width2=None): 8 | self.conv_width = conv_width 9 | self.conv_width2 = conv_width2 if conv_width2 else conv_width 10 | network_base.BaseNetwork.__init__(self, inputs, trainable) 11 | 12 | def setup(self): 13 | min_depth = 8 14 | depth = lambda d: max(int(d * self.conv_width), min_depth) 15 | depth2 = lambda d: max(int(d * self.conv_width2), min_depth) 16 | 17 | with tf.variable_scope(None, 'MobilenetV1'): 18 | (self.feed('image') 19 | .convb(3, 3, depth(32), 2, name='Conv2d_0') 20 | .separable_conv(3, 3, depth(64), 1, name='Conv2d_1') 21 | .separable_conv(3, 3, depth(128), 2, name='Conv2d_2') 22 | .separable_conv(3, 3, depth(128), 1, name='Conv2d_3') 23 | .separable_conv(3, 3, depth(256), 2, name='Conv2d_4') 24 | .separable_conv(3, 3, depth(256), 1, name='Conv2d_5') 25 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_6') 26 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_7') 27 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_8') 28 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_9') 29 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_10') 30 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_11') 31 | # .separable_conv(3, 3, depth(1024), 2, name='Conv2d_12') 32 | # .separable_conv(3, 3, depth(1024), 1, name='Conv2d_13') 33 | ) 34 | 35 | (self.feed('Conv2d_3').max_pool(2, 2, 2, 2, name='Conv2d_3_pool')) 36 | 37 | (self.feed('Conv2d_3_pool', 'Conv2d_7', 'Conv2d_11') 38 | .concat(3, name='feat_concat')) 39 | 40 | feature_lv = 'feat_concat' 41 | with tf.variable_scope(None, 'Openpose'): 42 | prefix = 'MConv_Stage1' 43 | (self.feed(feature_lv) 44 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_1') 45 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_2') 46 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_3') 47 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L1_4') 48 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5')) 49 | 50 | (self.feed(feature_lv) 51 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_1') 52 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_2') 53 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_3') 54 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L2_4') 55 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5')) 56 | 57 | for stage_id in range(5): 58 | prefix_prev = 'MConv_Stage%d' % (stage_id + 1) 59 | prefix = 'MConv_Stage%d' % (stage_id + 2) 60 | (self.feed(prefix_prev + '_L1_5', 61 | prefix_prev + '_L2_5', 62 | feature_lv) 63 | .concat(3, name=prefix + '_concat') 64 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_1') 65 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_2') 66 | .separable_conv(3, 
3, depth2(128), 1, name=prefix + '_L1_3') 67 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L1_4') 68 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5')) 69 | 70 | (self.feed(prefix + '_concat') 71 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_1') 72 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_2') 73 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_3') 74 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L2_4') 75 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5')) 76 | 77 | # final result 78 | (self.feed('MConv_Stage6_L2_5', 79 | 'MConv_Stage6_L1_5') 80 | .concat(3, name='concat_stage7')) 81 | 82 | def loss_l1_l2(self): 83 | l1s = [] 84 | l2s = [] 85 | for layer_name in sorted(self.layers.keys()): 86 | if '_L1_5' in layer_name: 87 | l1s.append(self.layers[layer_name]) 88 | if '_L2_5' in layer_name: 89 | l2s.append(self.layers[layer_name]) 90 | 91 | return l1s, l2s 92 | 93 | def loss_last(self): 94 | return self.get_output('MConv_Stage6_L1_5'), self.get_output('MConv_Stage6_L2_5') 95 | 96 | def restorable_variables(self): 97 | vs = {v.op.name: v for v in tf.global_variables() if 98 | 'MobilenetV1/Conv2d' in v.op.name and 99 | # 'global_step' not in v.op.name and 100 | # 'beta1_power' not in v.op.name and 'beta2_power' not in v.op.name and 101 | 'RMSProp' not in v.op.name and 'Momentum' not in v.op.name and 102 | 'Ada' not in v.op.name and 'Adam' not in v.op.name 103 | } 104 | return vs 105 | -------------------------------------------------------------------------------- /src/networks.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | from network_mobilenet import MobilenetNetwork 5 | from network_mobilenet_thin import MobilenetNetworkThin 6 | 7 | from network_cmu import CmuNetwork 8 | 9 | 10 | def _get_base_path(): 11 | if not os.environ.get('OPENPOSE_MODEL', ''): 12 | return './models' 13 | return os.environ.get('OPENPOSE_MODEL') 14 | 15 | 16 | def get_network(type, placeholder_input, sess_for_load=None, trainable=True): 17 | if type == 'mobilenet': 18 | net = MobilenetNetwork({'image': placeholder_input}, conv_width=0.75, conv_width2=1.00, trainable=trainable) 19 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt' 20 | last_layer = 'MConv_Stage6_L{aux}_5' 21 | elif type == 'mobilenet_fast': 22 | net = MobilenetNetwork({'image': placeholder_input}, conv_width=0.5, conv_width2=0.5, trainable=trainable) 23 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt' 24 | last_layer = 'MConv_Stage6_L{aux}_5' 25 | elif type == 'mobilenet_accurate': 26 | net = MobilenetNetwork({'image': placeholder_input}, conv_width=1.00, conv_width2=1.00, trainable=trainable) 27 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt' 28 | last_layer = 'MConv_Stage6_L{aux}_5' 29 | 30 | elif type == 'mobilenet_thin': 31 | net = MobilenetNetworkThin({'image': placeholder_input}, conv_width=0.75, conv_width2=0.50, trainable=trainable) 32 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_1.0_224.ckpt' 33 | last_layer = 'MConv_Stage6_L{aux}_5' 34 | 35 | elif type == 'cmu': 36 | net = CmuNetwork({'image': placeholder_input}, trainable=trainable) 37 | pretrain_path = 'numpy/openpose_coco.npy' 38 | last_layer = 'Mconv7_stage6_L{aux}' 39 | elif type == 'vgg': 40 | net = CmuNetwork({'image': placeholder_input}, 
trainable=trainable) 41 | pretrain_path = 'numpy/openpose_vgg16.npy' 42 | last_layer = 'Mconv7_stage6_L{aux}' 43 | else: 44 | raise Exception('Invalid Mode.') 45 | 46 | if sess_for_load is not None: 47 | if type == 'cmu' or type == 'vgg': 48 | net.load(os.path.join(_get_base_path(), pretrain_path), sess_for_load) 49 | else: 50 | s = '%dx%d' % (placeholder_input.shape[2], placeholder_input.shape[1]) 51 | ckpts = { 52 | 'mobilenet': 'trained/mobilenet_%s/model-246038' % s, 53 | 'mobilenet_thin': 'trained/mobilenet_thin_%s/model-449003' % s, 54 | 'mobilenet_fast': 'trained/mobilenet_fast_%s/model-189000' % s, 55 | 'mobilenet_accurate': 'trained/mobilenet_accurate/model-170000' 56 | } 57 | loader = tf.train.Saver() 58 | loader.restore(sess_for_load, os.path.join(_get_base_path(), ckpts[type])) 59 | 60 | return net, os.path.join(_get_base_path(), pretrain_path), last_layer 61 | 62 | 63 | def get_graph_path(model_name): 64 | return { 65 | 'cmu_640x480': './models/graph/cmu_640x480/graph_opt.pb', 66 | 'cmuq_640x480': './models/graph/cmu_640x480/graph_q.pb', 67 | 68 | 'cmu_640x360': './models/graph/cmu_640x360/graph_opt.pb', 69 | 'cmuq_640x360': './models/graph/cmu_640x360/graph_q.pb', 70 | 71 | 'mobilenet_thin_432x368': './models/graph/mobilenet_thin_432x368/graph_opt.pb', 72 | }[model_name] 73 | 74 | 75 | def model_wh(model_name): 76 | width, height = model_name.split('_')[-1].split('x') 77 | return int(width), int(height) 78 | -------------------------------------------------------------------------------- /src/pose_augment.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import cv2 5 | import numpy as np 6 | from tensorpack.dataflow.imgaug.geometry import RotationAndCropValid 7 | 8 | from common import CocoPart 9 | 10 | _network_w = 368 11 | _network_h = 368 12 | _scale = 2 13 | 14 | 15 | def set_network_input_wh(w, h): 16 | global _network_w, _network_h 17 | _network_w, _network_h = w, h 18 | 19 | 20 | def set_network_scale(scale): 21 | global _scale 22 | _scale = scale 23 | 24 | 25 | def pose_random_scale(meta): 26 | scalew = random.uniform(0.8, 1.2) 27 | scaleh = random.uniform(0.8, 1.2) 28 | neww = int(meta.width * scalew) 29 | newh = int(meta.height * scaleh) 30 | dst = cv2.resize(meta.img, (neww, newh), interpolation=cv2.INTER_AREA) 31 | 32 | # adjust meta data 33 | adjust_joint_list = [] 34 | for joint in meta.joint_list: 35 | adjust_joint = [] 36 | for point in joint: 37 | if point[0] < -100 or point[1] < -100: 38 | adjust_joint.append((-1000, -1000)) 39 | continue 40 | # if point[0] <= 0 or point[1] <= 0 or int(point[0] * scalew + 0.5) > neww or int( 41 | # point[1] * scaleh + 0.5) > newh: 42 | # adjust_joint.append((-1, -1)) 43 | # continue 44 | adjust_joint.append((int(point[0] * scalew + 0.5), int(point[1] * scaleh + 0.5))) 45 | adjust_joint_list.append(adjust_joint) 46 | 47 | meta.joint_list = adjust_joint_list 48 | meta.width, meta.height = neww, newh 49 | meta.img = dst 50 | return meta 51 | 52 | 53 | def pose_resize_shortestedge_fixed(meta): 54 | ratio_w = _network_w / meta.width 55 | ratio_h = _network_h / meta.height 56 | ratio = max(ratio_w, ratio_h) 57 | return pose_resize_shortestedge(meta, int(min(meta.width * ratio + 0.5, meta.height * ratio + 0.5))) 58 | 59 | 60 | def pose_resize_shortestedge_random(meta): 61 | ratio_w = _network_w / meta.width 62 | ratio_h = _network_h / meta.height 63 | ratio = min(ratio_w, ratio_h) 64 | target_size = int(min(meta.width * ratio + 0.5, meta.height * ratio + 
0.5)) 65 | target_size = int(target_size * random.uniform(0.95, 1.6)) 66 | # target_size = int(min(_network_w, _network_h) * random.uniform(0.7, 1.5)) 67 | return pose_resize_shortestedge(meta, target_size) 68 | 69 | 70 | def pose_resize_shortestedge(meta, target_size): 71 | global _network_w, _network_h 72 | img = meta.img 73 | 74 | # adjust image 75 | scale = target_size / min(meta.height, meta.width) 76 | if meta.height < meta.width: 77 | newh, neww = target_size, int(scale * meta.width + 0.5) 78 | else: 79 | newh, neww = int(scale * meta.height + 0.5), target_size 80 | 81 | dst = cv2.resize(img, (neww, newh), interpolation=cv2.INTER_AREA) 82 | 83 | pw = ph = 0 84 | if neww < _network_w or newh < _network_h: 85 | pw = max(0, (_network_w - neww) // 2) 86 | ph = max(0, (_network_h - newh) // 2) 87 | mw = (_network_w - neww) % 2 88 | mh = (_network_h - newh) % 2 89 | color = random.randint(0, 255) 90 | dst = cv2.copyMakeBorder(dst, ph, ph+mh, pw, pw+mw, cv2.BORDER_CONSTANT, value=(color, 0, 0)) 91 | 92 | # adjust meta data 93 | adjust_joint_list = [] 94 | for joint in meta.joint_list: 95 | adjust_joint = [] 96 | for point in joint: 97 | if point[0] < -100 or point[1] < -100: 98 | adjust_joint.append((-1000, -1000)) 99 | continue 100 | # if point[0] <= 0 or point[1] <= 0 or int(point[0]*scale+0.5) > neww or int(point[1]*scale+0.5) > newh: 101 | # adjust_joint.append((-1, -1)) 102 | # continue 103 | adjust_joint.append((int(point[0]*scale+0.5) + pw, int(point[1]*scale+0.5) + ph)) 104 | adjust_joint_list.append(adjust_joint) 105 | 106 | meta.joint_list = adjust_joint_list 107 | meta.width, meta.height = neww + pw * 2, newh + ph * 2 108 | meta.img = dst 109 | return meta 110 | 111 | 112 | def pose_crop_center(meta): 113 | global _network_w, _network_h 114 | target_size = (_network_w, _network_h) 115 | x = (meta.width - target_size[0]) // 2 if meta.width > target_size[0] else 0 116 | y = (meta.height - target_size[1]) // 2 if meta.height > target_size[1] else 0 117 | 118 | return pose_crop(meta, x, y, target_size[0], target_size[1]) 119 | 120 | 121 | def pose_crop_random(meta): 122 | global _network_w, _network_h 123 | target_size = (_network_w, _network_h) 124 | 125 | for _ in range(50): 126 | x = random.randrange(0, meta.width - target_size[0]) if meta.width > target_size[0] else 0 127 | y = random.randrange(0, meta.height - target_size[1]) if meta.height > target_size[1] else 0 128 | 129 | # check whether any face is inside the box to generate a reasonably-balanced datasets 130 | for joint in meta.joint_list: 131 | if x <= joint[CocoPart.Nose.value][0] < x + target_size[0] and y <= joint[CocoPart.Nose.value][1] < y + target_size[1]: 132 | break 133 | 134 | return pose_crop(meta, x, y, target_size[0], target_size[1]) 135 | 136 | 137 | def pose_crop(meta, x, y, w, h): 138 | # adjust image 139 | target_size = (w, h) 140 | 141 | img = meta.img 142 | resized = img[y:y+target_size[1], x:x+target_size[0], :] 143 | 144 | # adjust meta data 145 | adjust_joint_list = [] 146 | for joint in meta.joint_list: 147 | adjust_joint = [] 148 | for point in joint: 149 | if point[0] < -100 or point[1] < -100: 150 | adjust_joint.append((-1000, -1000)) 151 | continue 152 | # if point[0] <= 0 or point[1] <= 0: 153 | # adjust_joint.append((-1000, -1000)) 154 | # continue 155 | new_x, new_y = point[0] - x, point[1] - y 156 | # if new_x <= 0 or new_y <= 0 or new_x > target_size[0] or new_y > target_size[1]: 157 | # adjust_joint.append((-1, -1)) 158 | # continue 159 | adjust_joint.append((new_x, new_y)) 160 | 
adjust_joint_list.append(adjust_joint) 161 | 162 | meta.joint_list = adjust_joint_list 163 | meta.width, meta.height = target_size 164 | meta.img = resized 165 | return meta 166 | 167 | 168 | def pose_flip(meta): 169 | r = random.uniform(0, 1.0) 170 | if r > 0.5: 171 | return meta 172 | 173 | img = meta.img 174 | img = cv2.flip(img, 1) 175 | 176 | # flip meta 177 | flip_list = [CocoPart.Nose, CocoPart.Neck, CocoPart.LShoulder, CocoPart.LElbow, CocoPart.LWrist, CocoPart.RShoulder, CocoPart.RElbow, CocoPart.RWrist, 178 | CocoPart.LHip, CocoPart.LKnee, CocoPart.LAnkle, CocoPart.RHip, CocoPart.RKnee, CocoPart.RAnkle, 179 | CocoPart.LEye, CocoPart.REye, CocoPart.LEar, CocoPart.REar, CocoPart.Background] 180 | adjust_joint_list = [] 181 | for joint in meta.joint_list: 182 | adjust_joint = [] 183 | for cocopart in flip_list: 184 | point = joint[cocopart.value] 185 | if point[0] < -100 or point[1] < -100: 186 | adjust_joint.append((-1000, -1000)) 187 | continue 188 | # if point[0] <= 0 or point[1] <= 0: 189 | # adjust_joint.append((-1, -1)) 190 | # continue 191 | adjust_joint.append((meta.width - point[0], point[1])) 192 | adjust_joint_list.append(adjust_joint) 193 | 194 | meta.joint_list = adjust_joint_list 195 | 196 | meta.img = img 197 | return meta 198 | 199 | 200 | def pose_rotation(meta): 201 | deg = random.uniform(-15.0, 15.0) 202 | img = meta.img 203 | 204 | center = (img.shape[1] * 0.5, img.shape[0] * 0.5) # x, y 205 | rot_m = cv2.getRotationMatrix2D((int(center[0]), int(center[1])), deg, 1) 206 | ret = cv2.warpAffine(img, rot_m, img.shape[1::-1], flags=cv2.INTER_AREA, borderMode=cv2.BORDER_CONSTANT) 207 | if img.ndim == 3 and ret.ndim == 2: 208 | ret = ret[:, :, np.newaxis] 209 | neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg) 210 | neww = min(neww, ret.shape[1]) 211 | newh = min(newh, ret.shape[0]) 212 | newx = int(center[0] - neww * 0.5) 213 | newy = int(center[1] - newh * 0.5) 214 | # print(ret.shape, deg, newx, newy, neww, newh) 215 | img = ret[newy:newy + newh, newx:newx + neww] 216 | 217 | # adjust meta data 218 | adjust_joint_list = [] 219 | for joint in meta.joint_list: 220 | adjust_joint = [] 221 | for point in joint: 222 | if point[0] < -100 or point[1] < -100: 223 | adjust_joint.append((-1000, -1000)) 224 | continue 225 | # if point[0] <= 0 or point[1] <= 0: 226 | # adjust_joint.append((-1, -1)) 227 | # continue 228 | x, y = _rotate_coord((meta.width, meta.height), (newx, newy), point, deg) 229 | adjust_joint.append((x, y)) 230 | adjust_joint_list.append(adjust_joint) 231 | 232 | meta.joint_list = adjust_joint_list 233 | meta.width, meta.height = neww, newh 234 | meta.img = img 235 | 236 | return meta 237 | 238 | 239 | def _rotate_coord(shape, newxy, point, angle): 240 | angle = -1 * angle / 180.0 * math.pi 241 | 242 | ox, oy = shape 243 | px, py = point 244 | 245 | ox /= 2 246 | oy /= 2 247 | 248 | qx = math.cos(angle) * (px - ox) - math.sin(angle) * (py - oy) 249 | qy = math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy) 250 | 251 | new_x, new_y = newxy 252 | 253 | qx += ox - new_x 254 | qy += oy - new_y 255 | 256 | return int(qx + 0.5), int(qy + 0.5) 257 | 258 | 259 | def pose_to_img(meta_l): 260 | global _network_w, _network_h, _scale 261 | return [meta_l[0].img.astype(np.float16), 262 | meta_l[0].get_heatmap(target_size=(_network_w // _scale, _network_h // _scale)), 263 | meta_l[0].get_vectormap(target_size=(_network_w // _scale, _network_h // _scale))] 264 | 
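All of the augmenters above mutate and return the same metadata object, so they compose by plain chaining; the only contract is that the object carries img, width, height and a joint_list of per-person (x, y) tuples in CocoPart order, with (-1000, -1000) as the sentinel for occluded joints. Below is a minimal sketch of that contract, using a hypothetical PoseMeta holder in place of the dataset's own metadata class; the real dataflow wires the same functions through tensorpack's MapDataComponent, as pose_stats.py further below shows. The final pose_to_img() step is left out because it also needs the dataset's get_heatmap()/get_vectormap() methods.

import numpy as np

from pose_augment import (set_network_input_wh, set_network_scale,
                          pose_random_scale, pose_rotation, pose_flip,
                          pose_resize_shortestedge_random, pose_crop_random)


class PoseMeta:
    """Hypothetical stand-in for the dataset metadata these augmenters expect."""

    def __init__(self, img, joint_list):
        self.img = img
        self.height, self.width = img.shape[:2]
        # one list of (x, y) joints per person, 19 entries in CocoPart order;
        # (-1000, -1000) marks an occluded/missing joint
        self.joint_list = joint_list


if __name__ == '__main__':
    set_network_input_wh(368, 368)
    set_network_scale(2)

    img = np.zeros((480, 640, 3), dtype=np.uint8)
    person = [(320, 100)] + [(-1000, -1000)] * 18   # only the nose is annotated
    meta = PoseMeta(img, [person])

    # same order sample_augmentations() in pose_stats.py applies them
    for aug in (pose_random_scale, pose_rotation, pose_flip,
                pose_resize_shortestedge_random, pose_crop_random):
        meta = aug(meta)

    print(meta.img.shape, meta.joint_list[0][0])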
-------------------------------------------------------------------------------- /src/pose_datamaster.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import time 4 | 5 | from tensorpack.dataflow.remote import RemoteDataZMQ 6 | 7 | from pose_dataset import CocoPose 8 | 9 | logging.basicConfig(level=logging.DEBUG, format='[lmdb_dataset] %(asctime)s %(levelname)s %(message)s') 10 | 11 | if __name__ == '__main__': 12 | """ 13 | Speed Test for Getting Input batches from other nodes 14 | """ 15 | parser = argparse.ArgumentParser(description='Worker for preparing input batches.') 16 | parser.add_argument('--listen', type=str, default='tcp://0.0.0.0:1027') 17 | parser.add_argument('--show', type=bool, default=False) 18 | args = parser.parse_args() 19 | 20 | df = RemoteDataZMQ(args.listen) 21 | 22 | logging.info('tcp queue start') 23 | df.reset_state() 24 | t = time.time() 25 | for i, dp in enumerate(df.get_data()): 26 | if i == 100: 27 | break 28 | logging.info('Input batch %d received.' % i) 29 | if i == 0: 30 | for d in dp: 31 | logging.info('%d dp shape={}'.format(d.shape)) 32 | 33 | if args.show: 34 | CocoPose.display_image(dp[0][0], dp[1][0], dp[2][0]) 35 | 36 | logging.info('Speed Test Done for 100 Batches in %f seconds.' % (time.time() - t)) 37 | -------------------------------------------------------------------------------- /src/pose_dataworker.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from tensorpack.dataflow.remote import send_dataflow_zmq 4 | 5 | from pose_dataset import get_dataflow_batch 6 | from pose_augment import set_network_input_wh, set_network_scale 7 | 8 | if __name__ == '__main__': 9 | """ 10 | OpenPose Data Preparation might be a bottleneck for training. 11 | You can run multiple workers to generate input batches in multi-nodes to make training process faster. 
12 | """ 13 | parser = argparse.ArgumentParser(description='Worker for preparing input batches.') 14 | parser.add_argument('--datapath', type=str, default='/coco/annotations/') 15 | parser.add_argument('--imgpath', type=str, default='/coco/') 16 | parser.add_argument('--batchsize', type=int, default=64) 17 | parser.add_argument('--train', type=bool, default=True) 18 | parser.add_argument('--master', type=str, default='tcp://csi-cluster-gpu20.dakao.io:1027') 19 | parser.add_argument('--input-width', type=int, default=368) 20 | parser.add_argument('--input-height', type=int, default=368) 21 | parser.add_argument('--scale-factor', type=int, default=2) 22 | args = parser.parse_args() 23 | 24 | set_network_input_wh(args.input_width, args.input_height) 25 | set_network_scale(args.scale_factor) 26 | 27 | df = get_dataflow_batch(args.datapath, args.train, args.batchsize, args.imgpath) 28 | 29 | send_dataflow_zmq(df, args.master, hwm=10) 30 | -------------------------------------------------------------------------------- /src/pose_stats.py: -------------------------------------------------------------------------------- 1 | from pose_dataset import CocoPose 2 | from tensorpack import imgaug 3 | from tensorpack.dataflow.common import MapDataComponent, MapData 4 | from tensorpack.dataflow.image import AugmentImageComponent 5 | 6 | from pose_augment import * 7 | 8 | 9 | def get_idx_hands_up(): 10 | from src.pose_augment import set_network_input_wh 11 | set_network_input_wh(368, 368) 12 | 13 | show_sample = True 14 | db = CocoPoseLMDB('/data/public/rw/coco-pose-estimation-lmdb/', is_train=True, decode_img=show_sample) 15 | db.reset_state() 16 | total_cnt = 0 17 | handup_cnt = 0 18 | for idx, metas in enumerate(db.get_data()): 19 | meta = metas[0] 20 | if len(meta.joint_list) <= 0: 21 | continue 22 | body = meta.joint_list[0] 23 | if body[CocoPart.Neck.value][1] <= 0: 24 | continue 25 | if body[CocoPart.LWrist.value][1] <= 0: 26 | continue 27 | if body[CocoPart.RWrist.value][1] <= 0: 28 | continue 29 | 30 | if body[CocoPart.Neck.value][1] > body[CocoPart.LWrist.value][1] or body[CocoPart.Neck.value][1] > body[CocoPart.RWrist.value][1]: 31 | print(meta.idx) 32 | handup_cnt += 1 33 | 34 | if show_sample: 35 | l1, l2, l3 = pose_to_img(metas) 36 | CocoPose.display_image(l1, l2, l3) 37 | 38 | total_cnt += 1 39 | 40 | print('%d / %d' % (handup_cnt, total_cnt)) 41 | 42 | 43 | def sample_augmentations(): 44 | ds = CocoPose('/data/public/rw/coco-pose-estimation-lmdb/', is_train=False, only_idx=0) 45 | ds = MapDataComponent(ds, pose_random_scale) 46 | ds = MapDataComponent(ds, pose_rotation) 47 | ds = MapDataComponent(ds, pose_flip) 48 | ds = MapDataComponent(ds, pose_resize_shortestedge_random) 49 | ds = MapDataComponent(ds, pose_crop_random) 50 | ds = MapData(ds, pose_to_img) 51 | augs = [ 52 | imgaug.RandomApplyAug(imgaug.RandomChooseAug([ 53 | imgaug.GaussianBlur(3), 54 | imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01), 55 | imgaug.RandomOrderAug([ 56 | imgaug.BrightnessScale((0.8, 1.2), clip=False), 57 | imgaug.Contrast((0.8, 1.2), clip=False), 58 | # imgaug.Saturation(0.4, rgb=True), 59 | ]), 60 | ]), 0.7), 61 | ] 62 | ds = AugmentImageComponent(ds, augs) 63 | 64 | ds.reset_state() 65 | for l1, l2, l3 in ds.get_data(): 66 | CocoPose.display_image(l1, l2, l3) 67 | 68 | 69 | if __name__ == '__main__': 70 | # codes for tests 71 | # get_idx_hands_up() 72 | 73 | # show augmentation samples 74 | sample_augmentations() 75 | -------------------------------------------------------------------------------- 
/src/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import time 4 | import ast 5 | 6 | import common 7 | import cv2 8 | import numpy as np 9 | from estimator import TfPoseEstimator 10 | from networks import get_graph_path, model_wh 11 | 12 | from lifting.prob_model import Prob3dPose 13 | from lifting.draw import plot_pose 14 | 15 | logger = logging.getLogger('TfPoseEstimator') 16 | logger.setLevel(logging.DEBUG) 17 | ch = logging.StreamHandler() 18 | ch.setLevel(logging.DEBUG) 19 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s') 20 | ch.setFormatter(formatter) 21 | logger.addHandler(ch) 22 | 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser(description='tf-pose-estimation run') 26 | # parser.add_argument('--image', type=str, default='/Users/ildoonet/Downloads/me.jpg') 27 | parser.add_argument('--image', type=str, default='./images/apink2.jpg') 28 | # parser.add_argument('--model', type=str, default='mobilenet_320x240', help='cmu / mobilenet_320x240') 29 | parser.add_argument('--model', type=str, default='mobilenet_thin_432x368', help='cmu_640x480 / cmu_640x360 / mobilenet_thin_432x368') 30 | parser.add_argument('--scales', type=str, default='[None]', help='for multiple scales, eg. [1.0, (1.1, 0.05)]') 31 | args = parser.parse_args() 32 | scales = ast.literal_eval(args.scales) 33 | 34 | w, h = model_wh(args.model) 35 | e = TfPoseEstimator(get_graph_path(args.model), target_size=(w, h)) 36 | 37 | # estimate human poses from a single image ! 38 | image = common.read_imgfile(args.image, None, None) 39 | # image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21) 40 | t = time.time() 41 | humans = e.inference(image, scales=scales) 42 | # humans = e.inference(image, scales=[None, (0.7, 0.5, 1.8)]) 43 | # humans = e.inference(image, scales=[(1.2, 0.05)]) 44 | # humans = e.inference(image, scales=[(0.2, 0.2, 1.4)]) 45 | elapsed = time.time() - t 46 | 47 | logger.info('inference image: %s in %.4f seconds.'
% (args.image, elapsed)) 48 | 49 | image = cv2.imread(args.image, cv2.IMREAD_COLOR) 50 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) 51 | cv2.imshow('tf-pose-estimation result', image) 52 | cv2.waitKey() 53 | 54 | import sys 55 | sys.exit(0) 56 | 57 | logger.info('3d lifting initialization.') 58 | poseLifting = Prob3dPose('./src/lifting/models/prob_model_params.mat') 59 | 60 | image_h, image_w = image.shape[:2] 61 | standard_w = 640 62 | standard_h = 480 63 | 64 | pose_2d_mpiis = [] 65 | visibilities = [] 66 | for human in humans: 67 | pose_2d_mpii, visibility = common.MPIIPart.from_coco(human) 68 | pose_2d_mpiis.append([(int(x * standard_w + 0.5), int(y * standard_h + 0.5)) for x, y in pose_2d_mpii]) 69 | visibilities.append(visibility) 70 | 71 | pose_2d_mpiis = np.array(pose_2d_mpiis) 72 | visibilities = np.array(visibilities) 73 | transformed_pose2d, weights = poseLifting.transform_joints(pose_2d_mpiis, visibilities) 74 | pose_3d = poseLifting.compute_3d(transformed_pose2d, weights) 75 | 76 | import matplotlib.pyplot as plt 77 | 78 | fig = plt.figure() 79 | a = fig.add_subplot(2, 2, 1) 80 | a.set_title('Result') 81 | plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) 82 | 83 | # show network output 84 | a = fig.add_subplot(2, 2, 2) 85 | # plt.imshow(CocoPose.get_bgimg(inp, target_size=(heatmap.shape[1], heatmap.shape[0])), alpha=0.5) 86 | tmp = np.amax(e.heatMat, axis=2) 87 | plt.imshow(tmp, cmap=plt.cm.gray, alpha=0.5) 88 | plt.colorbar() 89 | 90 | tmp2 = e.pafMat.transpose((2, 0, 1)) 91 | tmp2_odd = np.amax(np.absolute(tmp2[::2, :, :]), axis=0) 92 | tmp2_even = np.amax(np.absolute(tmp2[1::2, :, :]), axis=0) 93 | 94 | a = fig.add_subplot(2, 2, 3) 95 | a.set_title('Vectormap-x') 96 | # plt.imshow(CocoPose.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5) 97 | plt.imshow(tmp2_odd, cmap=plt.cm.gray, alpha=0.5) 98 | plt.colorbar() 99 | 100 | a = fig.add_subplot(2, 2, 4) 101 | a.set_title('Vectormap-y') 102 | # plt.imshow(CocoPose.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5) 103 | plt.imshow(tmp2_even, cmap=plt.cm.gray, alpha=0.5) 104 | plt.colorbar() 105 | 106 | for i, single_3d in enumerate(pose_3d): 107 | plot_pose(single_3d) 108 | plt.show() 109 | 110 | pass 111 | -------------------------------------------------------------------------------- /src/run_checkpoint.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | import tensorflow as tf 5 | 6 | from networks import get_network 7 | 8 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') 9 | 10 | config = tf.ConfigProto() 11 | config.gpu_options.allocator_type = 'BFC' 12 | config.gpu_options.per_process_gpu_memory_fraction = 0.95 13 | config.gpu_options.allow_growth = True 14 | 15 | 16 | if __name__ == '__main__': 17 | """ 18 | Use this script to just save graph and checkpoint. 19 | While training, checkpoints are saved. You can test them with this python code. 
20 | """ 21 | parser = argparse.ArgumentParser(description='Tensorflow Pose Estimation Graph Extractor') 22 | parser.add_argument('--model', type=str, default='cmu', help='cmu / mobilenet / mobilenet_thin') 23 | args = parser.parse_args() 24 | 25 | input_node = tf.placeholder(tf.float32, shape=(1, 368, 432, 3), name='image') 26 | 27 | with tf.Session(config=config) as sess: 28 | net, _, last_layer = get_network(args.model, input_node, sess, trainable=False) 29 | 30 | tf.train.write_graph(sess.graph_def, './tmp', 'graph.pb', as_text=True) 31 | 32 | graph = tf.get_default_graph() 33 | dir(graph) 34 | for n in tf.get_default_graph().as_graph_def().node: 35 | if 'concat_stage' not in n.name: 36 | continue 37 | print(n.name) 38 | 39 | saver = tf.train.Saver(max_to_keep=100) 40 | saver.save(sess, '/Users/ildoonet/repos/tf-openpose/tmp/chk', global_step=1) 41 | -------------------------------------------------------------------------------- /src/run_webcam.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import time 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | from estimator import TfPoseEstimator 9 | from networks import get_graph_path, model_wh 10 | 11 | logger = logging.getLogger('TfPoseEstimator-WebCam') 12 | logger.setLevel(logging.DEBUG) 13 | ch = logging.StreamHandler() 14 | ch.setLevel(logging.DEBUG) 15 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s') 16 | ch.setFormatter(formatter) 17 | logger.addHandler(ch) 18 | 19 | fps_time = 0 20 | 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser(description='tf-pose-estimation realtime webcam') 24 | parser.add_argument('--camera', type=int, default=0) 25 | parser.add_argument('--zoom', type=float, default=1.0) 26 | parser.add_argument('--model', type=str, default='mobilenet_thin_432x368', help='cmu_640x480 / cmu_640x360 / mobilenet_thin_432x368') 27 | parser.add_argument('--show-process', type=bool, default=False, 28 | help='for debug purpose, if enabled, speed for inference is dropped.') 29 | args = parser.parse_args() 30 | 31 | logger.debug('initialization %s : %s' % (args.model, get_graph_path(args.model))) 32 | w, h = model_wh(args.model) 33 | e = TfPoseEstimator(get_graph_path(args.model), target_size=(w, h)) 34 | logger.debug('cam read+') 35 | cam = cv2.VideoCapture(args.camera) 36 | ret_val, image = cam.read() 37 | logger.info('cam image=%dx%d' % (image.shape[1], image.shape[0])) 38 | 39 | while True: 40 | ret_val, image = cam.read() 41 | 42 | logger.debug('image preprocess+') 43 | if args.zoom < 1.0: 44 | canvas = np.zeros_like(image) 45 | img_scaled = cv2.resize(image, None, fx=args.zoom, fy=args.zoom, interpolation=cv2.INTER_LINEAR) 46 | dx = (canvas.shape[1] - img_scaled.shape[1]) // 2 47 | dy = (canvas.shape[0] - img_scaled.shape[0]) // 2 48 | canvas[dy:dy + img_scaled.shape[0], dx:dx + img_scaled.shape[1]] = img_scaled 49 | image = canvas 50 | elif args.zoom > 1.0: 51 | img_scaled = cv2.resize(image, None, fx=args.zoom, fy=args.zoom, interpolation=cv2.INTER_LINEAR) 52 | dx = (img_scaled.shape[1] - image.shape[1]) // 2 53 | dy = (img_scaled.shape[0] - image.shape[0]) // 2 54 | image = img_scaled[dy:image.shape[0], dx:image.shape[1]] 55 | 56 | logger.debug('image process+') 57 | humans = e.inference(image) 58 | 59 | logger.debug('postprocess+') 60 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False) 61 | 62 | logger.debug('show+') 63 | cv2.putText(image, 64 | "FPS: %f" % (1.0 / 
(time.time() - fps_time)), 65 | (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 66 | (0, 255, 0), 2) 67 | cv2.imshow('tf-pose-estimation result', image) 68 | fps_time = time.time() 69 | if cv2.waitKey(1) == 27: 70 | break 71 | logger.debug('finished+') 72 | 73 | cv2.destroyAllWindows() 74 | -------------------------------------------------------------------------------- /src/slim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/src/slim/__init__.py -------------------------------------------------------------------------------- /src/slim/nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/slim/nets/alexnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a model definition for AlexNet. 16 | 17 | This work was first described in: 18 | ImageNet Classification with Deep Convolutional Neural Networks 19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton 20 | 21 | and later refined in: 22 | One weird trick for parallelizing convolutional neural networks 23 | Alex Krizhevsky, 2014 24 | 25 | Here we provide the implementation proposed in "One weird trick" and not 26 | "ImageNet Classification", as per the paper, the LRN layers have been removed. 27 | 28 | Usage: 29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()): 30 | outputs, end_points = alexnet.alexnet_v2(inputs) 31 | 32 | @@alexnet_v2 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import tensorflow as tf 40 | 41 | slim = tf.contrib.slim 42 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 43 | 44 | 45 | def alexnet_v2_arg_scope(weight_decay=0.0005): 46 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 47 | activation_fn=tf.nn.relu, 48 | biases_initializer=tf.constant_initializer(0.1), 49 | weights_regularizer=slim.l2_regularizer(weight_decay)): 50 | with slim.arg_scope([slim.conv2d], padding='SAME'): 51 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 52 | return arg_sc 53 | 54 | 55 | def alexnet_v2(inputs, 56 | num_classes=1000, 57 | is_training=True, 58 | dropout_keep_prob=0.5, 59 | spatial_squeeze=True, 60 | scope='alexnet_v2', 61 | global_pool=False): 62 | """AlexNet version 2. 
63 | 64 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf 65 | Parameters from: 66 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/ 67 | layers-imagenet-1gpu.cfg 68 | 69 | Note: All the fully_connected layers have been transformed to conv2d layers. 70 | To use in classification mode, resize input to 224x224 or set 71 | global_pool=True. To use in fully convolutional mode, set 72 | spatial_squeeze to false. 73 | The LRN layers have been removed and change the initializers from 74 | random_normal_initializer to xavier_initializer. 75 | 76 | Args: 77 | inputs: a tensor of size [batch_size, height, width, channels]. 78 | num_classes: the number of predicted classes. If 0 or None, the logits layer 79 | is omitted and the input features to the logits layer are returned instead. 80 | is_training: whether or not the model is being trained. 81 | dropout_keep_prob: the probability that activations are kept in the dropout 82 | layers during training. 83 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 84 | logits. Useful to remove unnecessary dimensions for classification. 85 | scope: Optional scope for the variables. 86 | global_pool: Optional boolean flag. If True, the input to the classification 87 | layer is avgpooled to size 1x1, for any input size. (This is not part 88 | of the original AlexNet.) 89 | 90 | Returns: 91 | net: the output of the logits layer (if num_classes is a non-zero integer), 92 | or the non-dropped-out input to the logits layer (if num_classes is 0 93 | or None). 94 | end_points: a dict of tensors with intermediate activations. 95 | """ 96 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc: 97 | end_points_collection = sc.original_name_scope + '_end_points' 98 | # Collect outputs for conv2d, fully_connected and max_pool2d. 99 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 100 | outputs_collections=[end_points_collection]): 101 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 102 | scope='conv1') 103 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') 104 | net = slim.conv2d(net, 192, [5, 5], scope='conv2') 105 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') 106 | net = slim.conv2d(net, 384, [3, 3], scope='conv3') 107 | net = slim.conv2d(net, 384, [3, 3], scope='conv4') 108 | net = slim.conv2d(net, 256, [3, 3], scope='conv5') 109 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') 110 | 111 | # Use conv2d instead of fully_connected layers. 112 | with slim.arg_scope([slim.conv2d], 113 | weights_initializer=trunc_normal(0.005), 114 | biases_initializer=tf.constant_initializer(0.1)): 115 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID', 116 | scope='fc6') 117 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 118 | scope='dropout6') 119 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 120 | # Convert end_points_collection into a end_point dict. 
121 | end_points = slim.utils.convert_collection_to_dict( 122 | end_points_collection) 123 | if global_pool: 124 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 125 | end_points['global_pool'] = net 126 | if num_classes: 127 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 128 | scope='dropout7') 129 | net = slim.conv2d(net, num_classes, [1, 1], 130 | activation_fn=None, 131 | normalizer_fn=None, 132 | biases_initializer=tf.zeros_initializer(), 133 | scope='fc8') 134 | if spatial_squeeze: 135 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 136 | end_points[sc.name + '/fc8'] = net 137 | return net, end_points 138 | alexnet_v2.default_image_size = 224 139 | -------------------------------------------------------------------------------- /src/slim/nets/alexnet_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for slim.nets.alexnet.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import alexnet 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class AlexnetV2Test(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 224, 224 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = alexnet.alexnet_v2(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 300, 400 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 4, 7, num_classes]) 50 | 51 | def testGlobalPool(self): 52 | batch_size = 1 53 | height, width = 300, 400 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False, 58 | global_pool=True) 59 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') 60 | self.assertListEqual(logits.get_shape().as_list(), 61 | [batch_size, 1, 1, num_classes]) 62 | 63 | def testEndPoints(self): 64 | batch_size = 5 65 | height, width = 224, 224 66 | num_classes = 1000 67 | with self.test_session(): 68 | inputs = tf.random_uniform((batch_size, height, width, 3)) 69 | _, end_points = 
alexnet.alexnet_v2(inputs, num_classes) 70 | expected_names = ['alexnet_v2/conv1', 71 | 'alexnet_v2/pool1', 72 | 'alexnet_v2/conv2', 73 | 'alexnet_v2/pool2', 74 | 'alexnet_v2/conv3', 75 | 'alexnet_v2/conv4', 76 | 'alexnet_v2/conv5', 77 | 'alexnet_v2/pool5', 78 | 'alexnet_v2/fc6', 79 | 'alexnet_v2/fc7', 80 | 'alexnet_v2/fc8' 81 | ] 82 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 83 | 84 | def testNoClasses(self): 85 | batch_size = 5 86 | height, width = 224, 224 87 | num_classes = None 88 | with self.test_session(): 89 | inputs = tf.random_uniform((batch_size, height, width, 3)) 90 | net, end_points = alexnet.alexnet_v2(inputs, num_classes) 91 | expected_names = ['alexnet_v2/conv1', 92 | 'alexnet_v2/pool1', 93 | 'alexnet_v2/conv2', 94 | 'alexnet_v2/pool2', 95 | 'alexnet_v2/conv3', 96 | 'alexnet_v2/conv4', 97 | 'alexnet_v2/conv5', 98 | 'alexnet_v2/pool5', 99 | 'alexnet_v2/fc6', 100 | 'alexnet_v2/fc7' 101 | ] 102 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 103 | self.assertTrue(net.op.name.startswith('alexnet_v2/fc7')) 104 | self.assertListEqual(net.get_shape().as_list(), 105 | [batch_size, 1, 1, 4096]) 106 | 107 | def testModelVariables(self): 108 | batch_size = 5 109 | height, width = 224, 224 110 | num_classes = 1000 111 | with self.test_session(): 112 | inputs = tf.random_uniform((batch_size, height, width, 3)) 113 | alexnet.alexnet_v2(inputs, num_classes) 114 | expected_names = ['alexnet_v2/conv1/weights', 115 | 'alexnet_v2/conv1/biases', 116 | 'alexnet_v2/conv2/weights', 117 | 'alexnet_v2/conv2/biases', 118 | 'alexnet_v2/conv3/weights', 119 | 'alexnet_v2/conv3/biases', 120 | 'alexnet_v2/conv4/weights', 121 | 'alexnet_v2/conv4/biases', 122 | 'alexnet_v2/conv5/weights', 123 | 'alexnet_v2/conv5/biases', 124 | 'alexnet_v2/fc6/weights', 125 | 'alexnet_v2/fc6/biases', 126 | 'alexnet_v2/fc7/weights', 127 | 'alexnet_v2/fc7/biases', 128 | 'alexnet_v2/fc8/weights', 129 | 'alexnet_v2/fc8/biases', 130 | ] 131 | model_variables = [v.op.name for v in slim.get_model_variables()] 132 | self.assertSetEqual(set(model_variables), set(expected_names)) 133 | 134 | def testEvaluation(self): 135 | batch_size = 2 136 | height, width = 224, 224 137 | num_classes = 1000 138 | with self.test_session(): 139 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 140 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False) 141 | self.assertListEqual(logits.get_shape().as_list(), 142 | [batch_size, num_classes]) 143 | predictions = tf.argmax(logits, 1) 144 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 145 | 146 | def testTrainEvalWithReuse(self): 147 | train_batch_size = 2 148 | eval_batch_size = 1 149 | train_height, train_width = 224, 224 150 | eval_height, eval_width = 300, 400 151 | num_classes = 1000 152 | with self.test_session(): 153 | train_inputs = tf.random_uniform( 154 | (train_batch_size, train_height, train_width, 3)) 155 | logits, _ = alexnet.alexnet_v2(train_inputs) 156 | self.assertListEqual(logits.get_shape().as_list(), 157 | [train_batch_size, num_classes]) 158 | tf.get_variable_scope().reuse_variables() 159 | eval_inputs = tf.random_uniform( 160 | (eval_batch_size, eval_height, eval_width, 3)) 161 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False, 162 | spatial_squeeze=False) 163 | self.assertListEqual(logits.get_shape().as_list(), 164 | [eval_batch_size, 4, 7, num_classes]) 165 | logits = tf.reduce_mean(logits, [1, 2]) 166 | predictions = tf.argmax(logits, 1) 167 | 
self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 168 | 169 | def testForward(self): 170 | batch_size = 1 171 | height, width = 224, 224 172 | with self.test_session() as sess: 173 | inputs = tf.random_uniform((batch_size, height, width, 3)) 174 | logits, _ = alexnet.alexnet_v2(inputs) 175 | sess.run(tf.global_variables_initializer()) 176 | output = sess.run(logits) 177 | self.assertTrue(output.any()) 178 | 179 | if __name__ == '__main__': 180 | tf.test.main() 181 | -------------------------------------------------------------------------------- /src/slim/nets/cifarnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the CIFAR-10 model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev) 26 | 27 | 28 | def cifarnet(images, num_classes=10, is_training=False, 29 | dropout_keep_prob=0.5, 30 | prediction_fn=slim.softmax, 31 | scope='CifarNet'): 32 | """Creates a variant of the CifarNet model. 33 | 34 | Note that since the output is a set of 'logits', the values fall in the 35 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 36 | probability distribution over the characters, one will need to convert them 37 | using the softmax function: 38 | 39 | logits = cifarnet.cifarnet(images, is_training=False) 40 | probabilities = tf.nn.softmax(logits) 41 | predictions = tf.argmax(logits, 1) 42 | 43 | Args: 44 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 45 | num_classes: the number of classes in the dataset. If 0 or None, the logits 46 | layer is omitted and the input features to the logits layer are returned 47 | instead. 48 | is_training: specifies whether or not we're currently training the model. 49 | This variable will determine the behaviour of the dropout layer. 50 | dropout_keep_prob: the percentage of activation values that are retained. 51 | prediction_fn: a function to get predictions out of logits. 52 | scope: Optional variable_scope. 53 | 54 | Returns: 55 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 56 | is a non-zero integer, or the input to the logits layer if num_classes 57 | is 0 or None. 58 | end_points: a dictionary from components of the network to the corresponding 59 | activation. 
60 | """ 61 | end_points = {} 62 | 63 | with tf.variable_scope(scope, 'CifarNet', [images]): 64 | net = slim.conv2d(images, 64, [5, 5], scope='conv1') 65 | end_points['conv1'] = net 66 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 67 | end_points['pool1'] = net 68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1') 69 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 70 | end_points['conv2'] = net 71 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2') 72 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 73 | end_points['pool2'] = net 74 | net = slim.flatten(net) 75 | end_points['Flatten'] = net 76 | net = slim.fully_connected(net, 384, scope='fc3') 77 | end_points['fc3'] = net 78 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 79 | scope='dropout3') 80 | net = slim.fully_connected(net, 192, scope='fc4') 81 | end_points['fc4'] = net 82 | if not num_classes: 83 | return net, end_points 84 | logits = slim.fully_connected(net, num_classes, 85 | biases_initializer=tf.zeros_initializer(), 86 | weights_initializer=trunc_normal(1/192.0), 87 | weights_regularizer=None, 88 | activation_fn=None, 89 | scope='logits') 90 | 91 | end_points['Logits'] = logits 92 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 93 | 94 | return logits, end_points 95 | cifarnet.default_image_size = 32 96 | 97 | 98 | def cifarnet_arg_scope(weight_decay=0.004): 99 | """Defines the default cifarnet argument scope. 100 | 101 | Args: 102 | weight_decay: The weight decay to use for regularizing the model. 103 | 104 | Returns: 105 | An `arg_scope` to use for the inception v3 model. 106 | """ 107 | with slim.arg_scope( 108 | [slim.conv2d], 109 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2), 110 | activation_fn=tf.nn.relu): 111 | with slim.arg_scope( 112 | [slim.fully_connected], 113 | biases_initializer=tf.constant_initializer(0.1), 114 | weights_initializer=trunc_normal(0.04), 115 | weights_regularizer=slim.l2_regularizer(weight_decay), 116 | activation_fn=tf.nn.relu) as sc: 117 | return sc 118 | -------------------------------------------------------------------------------- /src/slim/nets/cyclegan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for tensorflow.contrib.slim.nets.cyclegan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import cyclegan 24 | 25 | 26 | # TODO(joelshor): Add a test to check generator endpoints. 
27 | class CycleganTest(tf.test.TestCase): 28 | 29 | def test_generator_inference(self): 30 | """Check one inference step.""" 31 | img_batch = tf.zeros([2, 32, 32, 3]) 32 | model_output, _ = cyclegan.cyclegan_generator_resnet(img_batch) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | sess.run(model_output) 36 | 37 | def _test_generator_graph_helper(self, shape): 38 | """Check that generator can take small and non-square inputs.""" 39 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(tf.ones(shape)) 40 | self.assertAllEqual(shape, output_imgs.shape.as_list()) 41 | 42 | def test_generator_graph_small(self): 43 | self._test_generator_graph_helper([4, 32, 32, 3]) 44 | 45 | def test_generator_graph_medium(self): 46 | self._test_generator_graph_helper([3, 128, 128, 3]) 47 | 48 | def test_generator_graph_nonsquare(self): 49 | self._test_generator_graph_helper([2, 80, 400, 3]) 50 | 51 | def test_generator_unknown_batch_dim(self): 52 | """Check that generator can take unknown batch dimension inputs.""" 53 | img = tf.placeholder(tf.float32, shape=[None, 32, None, 3]) 54 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(img) 55 | 56 | self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list()) 57 | 58 | def _input_and_output_same_shape_helper(self, kernel_size): 59 | img_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3]) 60 | output_img_batch, _ = cyclegan.cyclegan_generator_resnet( 61 | img_batch, kernel_size=kernel_size) 62 | 63 | self.assertAllEqual(img_batch.shape.as_list(), 64 | output_img_batch.shape.as_list()) 65 | 66 | def input_and_output_same_shape_kernel3(self): 67 | self._input_and_output_same_shape_helper(3) 68 | 69 | def input_and_output_same_shape_kernel4(self): 70 | self._input_and_output_same_shape_helper(4) 71 | 72 | def input_and_output_same_shape_kernel5(self): 73 | self._input_and_output_same_shape_helper(5) 74 | 75 | def input_and_output_same_shape_kernel6(self): 76 | self._input_and_output_same_shape_helper(6) 77 | 78 | def _error_if_height_not_multiple_of_four_helper(self, height): 79 | self.assertRaisesRegexp( 80 | ValueError, 81 | 'The input height must be a multiple of 4.', 82 | cyclegan.cyclegan_generator_resnet, 83 | tf.placeholder(tf.float32, shape=[None, height, 32, 3])) 84 | 85 | def test_error_if_height_not_multiple_of_four_height29(self): 86 | self._error_if_height_not_multiple_of_four_helper(29) 87 | 88 | def test_error_if_height_not_multiple_of_four_height30(self): 89 | self._error_if_height_not_multiple_of_four_helper(30) 90 | 91 | def test_error_if_height_not_multiple_of_four_height31(self): 92 | self._error_if_height_not_multiple_of_four_helper(31) 93 | 94 | def _error_if_width_not_multiple_of_four_helper(self, width): 95 | self.assertRaisesRegexp( 96 | ValueError, 97 | 'The input width must be a multiple of 4.', 98 | cyclegan.cyclegan_generator_resnet, 99 | tf.placeholder(tf.float32, shape=[None, 32, width, 3])) 100 | 101 | def test_error_if_width_not_multiple_of_four_width29(self): 102 | self._error_if_width_not_multiple_of_four_helper(29) 103 | 104 | def test_error_if_width_not_multiple_of_four_width30(self): 105 | self._error_if_width_not_multiple_of_four_helper(30) 106 | 107 | def test_error_if_width_not_multiple_of_four_width31(self): 108 | self._error_if_width_not_multiple_of_four_helper(31) 109 | 110 | 111 | if __name__ == '__main__': 112 | tf.test.main() 113 | -------------------------------------------------------------------------------- /src/slim/nets/dcgan.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """DCGAN generator and discriminator from https://arxiv.org/abs/1511.06434.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from math import log 21 | 22 | import tensorflow as tf 23 | slim = tf.contrib.slim 24 | 25 | 26 | def _validate_image_inputs(inputs): 27 | inputs.get_shape().assert_has_rank(4) 28 | inputs.get_shape()[1:3].assert_is_fully_defined() 29 | if inputs.get_shape()[1] != inputs.get_shape()[2]: 30 | raise ValueError('Input tensor does not have equal width and height: ', 31 | inputs.get_shape()[1:3]) 32 | width = inputs.get_shape().as_list()[1] 33 | if log(width, 2) != int(log(width, 2)): 34 | raise ValueError('Input tensor `width` is not a power of 2: ', width) 35 | 36 | 37 | # TODO(joelshor): Use fused batch norm by default. Investigate why some GAN 38 | # setups need the gradient of gradient FusedBatchNormGrad. 39 | def discriminator(inputs, 40 | depth=64, 41 | is_training=True, 42 | reuse=None, 43 | scope='Discriminator', 44 | fused_batch_norm=False): 45 | """Discriminator network for DCGAN. 46 | 47 | Construct discriminator network from inputs to the final endpoint. 48 | 49 | Args: 50 | inputs: A tensor of size [batch_size, height, width, channels]. Must be 51 | floating point. 52 | depth: Number of channels in first convolution layer. 53 | is_training: Whether the network is for training or not. 54 | reuse: Whether or not the network variables should be reused. `scope` 55 | must be given to be reused. 56 | scope: Optional variable_scope. 57 | fused_batch_norm: If `True`, use a faster, fused implementation of 58 | batch norm. 59 | 60 | Returns: 61 | logits: The pre-softmax activations, a tensor of size [batch_size, 1] 62 | end_points: a dictionary from components of the network to their activation. 63 | 64 | Raises: 65 | ValueError: If the input image shape is not 4-dimensional, if the spatial 66 | dimensions aren't defined at graph construction time, if the spatial 67 | dimensions aren't square, or if the spatial dimensions aren't a power of 68 | two. 
69 | """ 70 | 71 | normalizer_fn = slim.batch_norm 72 | normalizer_fn_args = { 73 | 'is_training': is_training, 74 | 'zero_debias_moving_mean': True, 75 | 'fused': fused_batch_norm, 76 | } 77 | 78 | _validate_image_inputs(inputs) 79 | inp_shape = inputs.get_shape().as_list()[1] 80 | 81 | end_points = {} 82 | with tf.variable_scope(scope, values=[inputs], reuse=reuse) as scope: 83 | with slim.arg_scope([normalizer_fn], **normalizer_fn_args): 84 | with slim.arg_scope([slim.conv2d], 85 | stride=2, 86 | kernel_size=4, 87 | activation_fn=tf.nn.leaky_relu): 88 | net = inputs 89 | for i in xrange(int(log(inp_shape, 2))): 90 | scope = 'conv%i' % (i + 1) 91 | current_depth = depth * 2**i 92 | normalizer_fn_ = None if i == 0 else normalizer_fn 93 | net = slim.conv2d( 94 | net, current_depth, normalizer_fn=normalizer_fn_, scope=scope) 95 | end_points[scope] = net 96 | 97 | logits = slim.conv2d(net, 1, kernel_size=1, stride=1, padding='VALID', 98 | normalizer_fn=None, activation_fn=None) 99 | logits = tf.reshape(logits, [-1, 1]) 100 | end_points['logits'] = logits 101 | 102 | return logits, end_points 103 | 104 | 105 | # TODO(joelshor): Use fused batch norm by default. Investigate why some GAN 106 | # setups need the gradient of gradient FusedBatchNormGrad. 107 | def generator(inputs, 108 | depth=64, 109 | final_size=32, 110 | num_outputs=3, 111 | is_training=True, 112 | reuse=None, 113 | scope='Generator', 114 | fused_batch_norm=False): 115 | """Generator network for DCGAN. 116 | 117 | Construct generator network from inputs to the final endpoint. 118 | 119 | Args: 120 | inputs: A tensor with any size N. [batch_size, N] 121 | depth: Number of channels in last deconvolution layer. 122 | final_size: The shape of the final output. 123 | num_outputs: Number of output features. For images, this is the number of 124 | channels. 125 | is_training: whether is training or not. 126 | reuse: Whether or not the network has its variables should be reused. scope 127 | must be given to be reused. 128 | scope: Optional variable_scope. 129 | fused_batch_norm: If `True`, use a faster, fused implementation of 130 | batch norm. 131 | 132 | Returns: 133 | logits: the pre-softmax activations, a tensor of size 134 | [batch_size, 32, 32, channels] 135 | end_points: a dictionary from components of the network to their activation. 136 | 137 | Raises: 138 | ValueError: If `inputs` is not 2-dimensional. 139 | ValueError: If `final_size` isn't a power of 2 or is less than 8. 140 | """ 141 | normalizer_fn = slim.batch_norm 142 | normalizer_fn_args = { 143 | 'is_training': is_training, 144 | 'zero_debias_moving_mean': True, 145 | 'fused': fused_batch_norm, 146 | } 147 | 148 | inputs.get_shape().assert_has_rank(2) 149 | if log(final_size, 2) != int(log(final_size, 2)): 150 | raise ValueError('`final_size` (%i) must be a power of 2.' % final_size) 151 | if final_size < 8: 152 | raise ValueError('`final_size` (%i) must be greater than 8.' % final_size) 153 | 154 | end_points = {} 155 | num_layers = int(log(final_size, 2)) - 1 156 | with tf.variable_scope(scope, values=[inputs], reuse=reuse) as scope: 157 | with slim.arg_scope([normalizer_fn], **normalizer_fn_args): 158 | with slim.arg_scope([slim.conv2d_transpose], 159 | normalizer_fn=normalizer_fn, 160 | stride=2, 161 | kernel_size=4): 162 | net = tf.expand_dims(tf.expand_dims(inputs, 1), 1) 163 | 164 | # First upscaling is different because it takes the input vector. 
165 | current_depth = depth * 2 ** (num_layers - 1) 166 | scope = 'deconv1' 167 | net = slim.conv2d_transpose( 168 | net, current_depth, stride=1, padding='VALID', scope=scope) 169 | end_points[scope] = net 170 | 171 | for i in xrange(2, num_layers): 172 | scope = 'deconv%i' % (i) 173 | current_depth = depth * 2 ** (num_layers - i) 174 | net = slim.conv2d_transpose(net, current_depth, scope=scope) 175 | end_points[scope] = net 176 | 177 | # Last layer has different normalizer and activation. 178 | scope = 'deconv%i' % (num_layers) 179 | net = slim.conv2d_transpose( 180 | net, depth, normalizer_fn=None, activation_fn=None, scope=scope) 181 | end_points[scope] = net 182 | 183 | # Convert to proper channels. 184 | scope = 'logits' 185 | logits = slim.conv2d( 186 | net, 187 | num_outputs, 188 | normalizer_fn=None, 189 | activation_fn=None, 190 | kernel_size=1, 191 | stride=1, 192 | padding='VALID', 193 | scope=scope) 194 | end_points[scope] = logits 195 | 196 | logits.get_shape().assert_has_rank(4) 197 | logits.get_shape().assert_is_compatible_with( 198 | [None, final_size, final_size, num_outputs]) 199 | 200 | return logits, end_points 201 | -------------------------------------------------------------------------------- /src/slim/nets/dcgan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for dcgan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from nets import dcgan 23 | 24 | 25 | class DCGANTest(tf.test.TestCase): 26 | 27 | def test_generator_run(self): 28 | tf.set_random_seed(1234) 29 | noise = tf.random_normal([100, 64]) 30 | image, _ = dcgan.generator(noise) 31 | with self.test_session() as sess: 32 | sess.run(tf.global_variables_initializer()) 33 | image.eval() 34 | 35 | def test_generator_graph(self): 36 | tf.set_random_seed(1234) 37 | # Check graph construction for a number of image size/depths and batch 38 | # sizes. 39 | for i, batch_size in zip(xrange(3, 7), xrange(3, 8)): 40 | tf.reset_default_graph() 41 | final_size = 2 ** i 42 | noise = tf.random_normal([batch_size, 64]) 43 | image, end_points = dcgan.generator( 44 | noise, 45 | depth=32, 46 | final_size=final_size) 47 | 48 | self.assertAllEqual([batch_size, final_size, final_size, 3], 49 | image.shape.as_list()) 50 | 51 | expected_names = ['deconv%i' % j for j in xrange(1, i)] + ['logits'] 52 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 53 | 54 | # Check layer depths. 
55 | for j in range(1, i): 56 | layer = end_points['deconv%i' % j] 57 | self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1]) 58 | 59 | def test_generator_invalid_input(self): 60 | wrong_dim_input = tf.zeros([5, 32, 32]) 61 | with self.assertRaises(ValueError): 62 | dcgan.generator(wrong_dim_input) 63 | 64 | correct_input = tf.zeros([3, 2]) 65 | with self.assertRaisesRegexp(ValueError, 'must be a power of 2'): 66 | dcgan.generator(correct_input, final_size=30) 67 | 68 | with self.assertRaisesRegexp(ValueError, 'must be greater than 8'): 69 | dcgan.generator(correct_input, final_size=4) 70 | 71 | def test_discriminator_run(self): 72 | image = tf.random_uniform([5, 32, 32, 3], -1, 1) 73 | output, _ = dcgan.discriminator(image) 74 | with self.test_session() as sess: 75 | sess.run(tf.global_variables_initializer()) 76 | output.eval() 77 | 78 | def test_discriminator_graph(self): 79 | # Check graph construction for a number of image size/depths and batch 80 | # sizes. 81 | for i, batch_size in zip(xrange(1, 6), xrange(3, 8)): 82 | tf.reset_default_graph() 83 | img_w = 2 ** i 84 | image = tf.random_uniform([batch_size, img_w, img_w, 3], -1, 1) 85 | output, end_points = dcgan.discriminator( 86 | image, 87 | depth=32) 88 | 89 | self.assertAllEqual([batch_size, 1], output.get_shape().as_list()) 90 | 91 | expected_names = ['conv%i' % j for j in xrange(1, i+1)] + ['logits'] 92 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 93 | 94 | # Check layer depths. 95 | for j in range(1, i+1): 96 | layer = end_points['conv%i' % j] 97 | self.assertEqual(32 * 2**(j-1), layer.get_shape().as_list()[-1]) 98 | 99 | def test_discriminator_invalid_input(self): 100 | wrong_dim_img = tf.zeros([5, 32, 32]) 101 | with self.assertRaises(ValueError): 102 | dcgan.discriminator(wrong_dim_img) 103 | 104 | spatially_undefined_shape = tf.placeholder(tf.float32, [5, 32, None, 3]) 105 | with self.assertRaises(ValueError): 106 | dcgan.discriminator(spatially_undefined_shape) 107 | 108 | not_square = tf.zeros([5, 32, 16, 3]) 109 | with self.assertRaisesRegexp(ValueError, 'not have equal width and height'): 110 | dcgan.discriminator(not_square) 111 | 112 | not_power_2 = tf.zeros([5, 30, 30, 3]) 113 | with self.assertRaisesRegexp(ValueError, 'not a power of 2'): 114 | dcgan.discriminator(not_power_2) 115 | 116 | 117 | if __name__ == '__main__': 118 | tf.test.main() 119 | -------------------------------------------------------------------------------- /src/slim/nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_resnet_v2 import inception_resnet_v2_base 25 | from nets.inception_v1 import inception_v1 26 | from nets.inception_v1 import inception_v1_arg_scope 27 | from nets.inception_v1 import inception_v1_base 28 | from nets.inception_v2 import inception_v2 29 | from nets.inception_v2 import inception_v2_arg_scope 30 | from nets.inception_v2 import inception_v2_base 31 | from nets.inception_v3 import inception_v3 32 | from nets.inception_v3 import inception_v3_arg_scope 33 | from nets.inception_v3 import inception_v3_base 34 | from nets.inception_v4 import inception_v4 35 | from nets.inception_v4 import inception_v4_arg_scope 36 | from nets.inception_v4 import inception_v4_base 37 | # pylint: enable=unused-import 38 | -------------------------------------------------------------------------------- /src/slim/nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001, 36 | activation_fn=tf.nn.relu): 37 | """Defines the default arg scope for inception models. 38 | 39 | Args: 40 | weight_decay: The weight decay to use for regularizing the model. 41 | use_batch_norm: "If `True`, batch_norm is applied after each convolution. 42 | batch_norm_decay: Decay for batch norm moving average. 43 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 44 | in batch norm. 45 | activation_fn: Activation function for conv2d. 46 | 47 | Returns: 48 | An `arg_scope` to use for the inception models. 49 | """ 50 | batch_norm_params = { 51 | # Decay for the moving averages. 52 | 'decay': batch_norm_decay, 53 | # epsilon to prevent 0s in variance. 54 | 'epsilon': batch_norm_epsilon, 55 | # collection containing update_ops. 
56 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 57 | # use fused batch norm if possible. 58 | 'fused': None, 59 | } 60 | if use_batch_norm: 61 | normalizer_fn = slim.batch_norm 62 | normalizer_params = batch_norm_params 63 | else: 64 | normalizer_fn = None 65 | normalizer_params = {} 66 | # Set weight_decay for weights in Conv and FC layers. 67 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 68 | weights_regularizer=slim.l2_regularizer(weight_decay)): 69 | with slim.arg_scope( 70 | [slim.conv2d], 71 | weights_initializer=slim.variance_scaling_initializer(), 72 | activation_fn=activation_fn, 73 | normalizer_fn=normalizer_fn, 74 | normalizer_params=normalizer_params) as sc: 75 | return sc 76 | -------------------------------------------------------------------------------- /src/slim/nets/lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the LeNet model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def lenet(images, num_classes=10, is_training=False, 27 | dropout_keep_prob=0.5, 28 | prediction_fn=slim.softmax, 29 | scope='LeNet'): 30 | """Creates a variant of the LeNet model. 31 | 32 | Note that since the output is a set of 'logits', the values fall in the 33 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 34 | probability distribution over the characters, one will need to convert them 35 | using the softmax function: 36 | 37 | logits = lenet.lenet(images, is_training=False) 38 | probabilities = tf.nn.softmax(logits) 39 | predictions = tf.argmax(logits, 1) 40 | 41 | Args: 42 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 43 | num_classes: the number of classes in the dataset. If 0 or None, the logits 44 | layer is omitted and the input features to the logits layer are returned 45 | instead. 46 | is_training: specifies whether or not we're currently training the model. 47 | This variable will determine the behaviour of the dropout layer. 48 | dropout_keep_prob: the percentage of activation values that are retained. 49 | prediction_fn: a function to get predictions out of logits. 50 | scope: Optional variable_scope. 51 | 52 | Returns: 53 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 54 | is a non-zero integer, or the non-dropped-out input to the logits layer 55 | if num_classes is 0 or None. 56 | end_points: a dictionary from components of the network to the corresponding 57 | activation. 
58 | """ 59 | end_points = {} 60 | 61 | with tf.variable_scope(scope, 'LeNet', [images]): 62 | net = end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1') 63 | net = end_points['pool1'] = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 64 | net = end_points['conv2'] = slim.conv2d(net, 64, [5, 5], scope='conv2') 65 | net = end_points['pool2'] = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 66 | net = slim.flatten(net) 67 | end_points['Flatten'] = net 68 | 69 | net = end_points['fc3'] = slim.fully_connected(net, 1024, scope='fc3') 70 | if not num_classes: 71 | return net, end_points 72 | net = end_points['dropout3'] = slim.dropout( 73 | net, dropout_keep_prob, is_training=is_training, scope='dropout3') 74 | logits = end_points['Logits'] = slim.fully_connected( 75 | net, num_classes, activation_fn=None, scope='fc4') 76 | 77 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 78 | 79 | return logits, end_points 80 | lenet.default_image_size = 28 81 | 82 | 83 | def lenet_arg_scope(weight_decay=0.0): 84 | """Defines the default lenet argument scope. 85 | 86 | Args: 87 | weight_decay: The weight decay to use for regularizing the model. 88 | 89 | Returns: 90 | An `arg_scope` to use for the inception v3 model. 91 | """ 92 | with slim.arg_scope( 93 | [slim.conv2d, slim.fully_connected], 94 | weights_regularizer=slim.l2_regularizer(weight_decay), 95 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 96 | activation_fn=tf.nn.relu) as sc: 97 | return sc 98 | -------------------------------------------------------------------------------- /src/slim/nets/mobilenet_v1.md: -------------------------------------------------------------------------------- 1 | # MobileNet_v1 2 | 3 | [MobileNets](https://arxiv.org/abs/1704.04861) are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as Inception, are used. MobileNets can be run efficiently on mobile devices with [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 4 | 5 | MobileNets trade off between latency, size and accuracy while comparing favorably with popular models from the literature. 6 | 7 | ![alt text](mobilenet_v1.png "MobileNet Graph") 8 | 9 | # Pre-trained Models 10 | 11 | Choose the right MobileNet model to fit your latency and size budget. The size of the network in memory and on disk is proportional to the number of parameters. The latency and power usage of the network scales with the number of Multiply-Accumulates (MACs) which measures the number of fused Multiplication and Addition operations. These MobileNet models have been trained on the 12 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 13 | image classification dataset. Accuracies were computed by evaluating using a single image crop. 
14 | 15 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 16 | :----:|:------------:|:----------:|:-------:|:-------:| 17 | [MobileNet_v1_1.0_224](http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz)|569|4.24|70.7|89.5| 18 | [MobileNet_v1_1.0_192](http://download.tensorflow.org/models/mobilenet_v1_1.0_192_2017_06_14.tar.gz)|418|4.24|69.3|88.9| 19 | [MobileNet_v1_1.0_160](http://download.tensorflow.org/models/mobilenet_v1_1.0_160_2017_06_14.tar.gz)|291|4.24|67.2|87.5| 20 | [MobileNet_v1_1.0_128](http://download.tensorflow.org/models/mobilenet_v1_1.0_128_2017_06_14.tar.gz)|186|4.24|64.1|85.3| 21 | [MobileNet_v1_0.75_224](http://download.tensorflow.org/models/mobilenet_v1_0.75_224_2017_06_14.tar.gz)|317|2.59|68.4|88.2| 22 | [MobileNet_v1_0.75_192](http://download.tensorflow.org/models/mobilenet_v1_0.75_192_2017_06_14.tar.gz)|233|2.59|67.4|87.3| 23 | [MobileNet_v1_0.75_160](http://download.tensorflow.org/models/mobilenet_v1_0.75_160_2017_06_14.tar.gz)|162|2.59|65.2|86.1| 24 | [MobileNet_v1_0.75_128](http://download.tensorflow.org/models/mobilenet_v1_0.75_128_2017_06_14.tar.gz)|104|2.59|61.8|83.6| 25 | [MobileNet_v1_0.50_224](http://download.tensorflow.org/models/mobilenet_v1_0.50_224_2017_06_14.tar.gz)|150|1.34|64.0|85.4| 26 | [MobileNet_v1_0.50_192](http://download.tensorflow.org/models/mobilenet_v1_0.50_192_2017_06_14.tar.gz)|110|1.34|62.1|84.0| 27 | [MobileNet_v1_0.50_160](http://download.tensorflow.org/models/mobilenet_v1_0.50_160_2017_06_14.tar.gz)|77|1.34|59.9|82.5| 28 | [MobileNet_v1_0.50_128](http://download.tensorflow.org/models/mobilenet_v1_0.50_128_2017_06_14.tar.gz)|49|1.34|56.2|79.6| 29 | [MobileNet_v1_0.25_224](http://download.tensorflow.org/models/mobilenet_v1_0.25_224_2017_06_14.tar.gz)|41|0.47|50.6|75.0| 30 | [MobileNet_v1_0.25_192](http://download.tensorflow.org/models/mobilenet_v1_0.25_192_2017_06_14.tar.gz)|34|0.47|49.0|73.6| 31 | [MobileNet_v1_0.25_160](http://download.tensorflow.org/models/mobilenet_v1_0.25_160_2017_06_14.tar.gz)|21|0.47|46.0|70.7| 32 | [MobileNet_v1_0.25_128](http://download.tensorflow.org/models/mobilenet_v1_0.25_128_2017_06_14.tar.gz)|14|0.47|41.3|66.2| 33 | 34 | 35 | Here is an example of how to download the MobileNet_v1_1.0_224 checkpoint: 36 | 37 | ```shell 38 | $ CHECKPOINT_DIR=/tmp/checkpoints 39 | $ mkdir ${CHECKPOINT_DIR} 40 | $ wget http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz 41 | $ tar -xvf mobilenet_v1_1.0_224_2017_06_14.tar.gz 42 | $ mv mobilenet_v1_1.0_224.ckpt.* ${CHECKPOINT_DIR} 43 | $ rm mobilenet_v1_1.0_224_2017_06_14.tar.gz 44 | ``` 45 | More information on integrating MobileNets into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md). 46 | 47 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 
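Once the checkpoint files are in `${CHECKPOINT_DIR}`, a minimal sketch of restoring them with the copy of TF-Slim bundled in this repository might look like the following. This is illustrative only, not part of the released tooling: the placeholder shape, the 1001-class head of the released ImageNet checkpoints, and the `/tmp/checkpoints` path are assumptions carried over from the shell example above.

```python
import tensorflow as tf
from nets import mobilenet_v1  # assumes the bundled slim directory is on PYTHONPATH

slim = tf.contrib.slim

# Build MobileNet_v1 1.0/224 for inference; the released ImageNet checkpoints
# use a 1001-class head (1000 classes plus background).
images = tf.placeholder(tf.float32, [None, 224, 224, 3], name='input')
with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=False)):
    logits, end_points = mobilenet_v1.mobilenet_v1(images, num_classes=1001, is_training=False)

# Restore the weights downloaded above (the path is an assumption; adjust as needed).
saver = tf.train.Saver(slim.get_model_variables('MobilenetV1'))
with tf.Session() as sess:
    saver.restore(sess, '/tmp/checkpoints/mobilenet_v1_1.0_224.ckpt')
    # end_points['Predictions'] holds per-class probabilities for a fed image batch.
```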
48 | -------------------------------------------------------------------------------- /src/slim/nets/mobilenet_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/src/slim/nets/mobilenet_v1.png -------------------------------------------------------------------------------- /src/slim/nets/nasnet/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-Slim NASNet-A Implementation/Checkpoints 2 | This directory contains the code for the NASNet-A model from the paper 3 | [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012) by Zoph et al. 4 | In nasnet.py there are three different configurations of NASNet-A that are implementented. One of the models is the NASNet-A built for CIFAR-10 and the 5 | other two are variants of NASNet-A trained on ImageNet, which are listed below. 6 | 7 | # Pre-Trained Models 8 | Two NASNet-A checkpoints are available that have been trained on the 9 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 10 | image classification dataset. Accuracies were computed by evaluating using a single image crop. 11 | 12 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 13 | :----:|:------------:|:----------:|:-------:|:-------:| 14 | [NASNet-A_Mobile_224](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|564|5.3|74.0|91.6| 15 | [NASNet-A_Large_331](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|23800|88.9|82.7|96.2| 16 | 17 | 18 | Here is an example of how to download the NASNet-A_Mobile_224 checkpoint. The way to download the NASNet-A_Large_331 is the same. 19 | 20 | ```shell 21 | CHECKPOINT_DIR=/tmp/checkpoints 22 | mkdir ${CHECKPOINT_DIR} 23 | cd ${CHECKPOINT_DIR} 24 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz 25 | tar -xvf nasnet-a_mobile_04_10_2017.tar.gz 26 | rm nasnet-a_mobile_04_10_2017.tar.gz 27 | ``` 28 | More information on integrating NASNet Models into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md). 29 | 30 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 
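In addition to the command-line evaluation examples below, the checkpoints can also be built programmatically through the bundled `nets_factory`. The sketch that follows is illustrative only; the 1001-class head and the default image size are assumptions based on the released slim checkpoints.

```python
import tensorflow as tf
from nets import nets_factory  # assumes the bundled slim directory is on PYTHONPATH

# Construct NASNet-A Mobile for inference via the factory registry.
network_fn = nets_factory.get_network_fn('nasnet_mobile', num_classes=1001, is_training=False)
size = network_fn.default_image_size  # 224 for the mobile variant
images = tf.placeholder(tf.float32, [None, size, size, 3])
logits, end_points = network_fn(images)
```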
31 | 32 | ## Sample Commands for using NASNet-A Mobile and Large Checkpoints for Inference 33 | ------- 34 | Run eval with the NASNet-A mobile ImageNet model 35 | 36 | ```shell 37 | DATASET_DIR=/tmp/imagenet 38 | EVAL_DIR=/tmp/tfmodel/eval 39 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 40 | python tensorflow_models/research/slim/eval_image_classifier \ 41 | --checkpoint_path=${CHECKPOINT_DIR} \ 42 | --eval_dir=${EVAL_DIR} \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --dataset_name=imagenet \ 45 | --dataset_split_name=validation \ 46 | --model_name=nasnet_mobile \ 47 | --eval_image_size=224 \ 48 | --moving_average_decay=0.9999 49 | ``` 50 | 51 | Run eval with the NASNet-A large ImageNet model 52 | 53 | ```shell 54 | DATASET_DIR=/tmp/imagenet 55 | EVAL_DIR=/tmp/tfmodel/eval 56 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 57 | python tensorflow_models/research/slim/eval_image_classifier \ 58 | --checkpoint_path=${CHECKPOINT_DIR} \ 59 | --eval_dir=${EVAL_DIR} \ 60 | --dataset_dir=${DATASET_DIR} \ 61 | --dataset_name=imagenet \ 62 | --dataset_split_name=validation \ 63 | --model_name=nasnet_large \ 64 | --eval_image_size=331 \ 65 | --moving_average_decay=0.9999 66 | ``` 67 | -------------------------------------------------------------------------------- /src/slim/nets/nasnet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/slim/nets/nasnet/nasnet_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.nasnet.nasnet_utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets.nasnet import nasnet_utils 24 | 25 | 26 | class NasnetUtilsTest(tf.test.TestCase): 27 | 28 | def testCalcReductionLayers(self): 29 | num_cells = 18 30 | num_reduction_layers = 2 31 | reduction_layers = nasnet_utils.calc_reduction_layers( 32 | num_cells, num_reduction_layers) 33 | self.assertEqual(len(reduction_layers), 2) 34 | self.assertEqual(reduction_layers[0], 6) 35 | self.assertEqual(reduction_layers[1], 12) 36 | 37 | def testGetChannelIndex(self): 38 | data_formats = ['NHWC', 'NCHW'] 39 | for data_format in data_formats: 40 | index = nasnet_utils.get_channel_index(data_format) 41 | correct_index = 3 if data_format == 'NHWC' else 1 42 | self.assertEqual(index, correct_index) 43 | 44 | def testGetChannelDim(self): 45 | data_formats = ['NHWC', 'NCHW'] 46 | shape = [10, 20, 30, 40] 47 | for data_format in data_formats: 48 | dim = nasnet_utils.get_channel_dim(shape, data_format) 49 | correct_dim = shape[3] if data_format == 'NHWC' else shape[1] 50 | self.assertEqual(dim, correct_dim) 51 | 52 | def testGlobalAvgPool(self): 53 | data_formats = ['NHWC', 'NCHW'] 54 | inputs = tf.placeholder(tf.float32, (5, 10, 20, 10)) 55 | for data_format in data_formats: 56 | output = nasnet_utils.global_avg_pool( 57 | inputs, data_format) 58 | self.assertEqual(output.shape, [5, 10]) 59 | 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /src/slim/nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import functools 21 | 22 | import tensorflow as tf 23 | 24 | from nets import alexnet 25 | from nets import cifarnet 26 | from nets import inception 27 | from nets import lenet 28 | from nets import mobilenet_v1 29 | from nets import overfeat 30 | from nets import resnet_v1 31 | from nets import resnet_v2 32 | from nets import vgg 33 | from nets.nasnet import nasnet 34 | 35 | slim = tf.contrib.slim 36 | 37 | networks_map = {'alexnet_v2': alexnet.alexnet_v2, 38 | 'cifarnet': cifarnet.cifarnet, 39 | 'overfeat': overfeat.overfeat, 40 | 'vgg_a': vgg.vgg_a, 41 | 'vgg_16': vgg.vgg_16, 42 | 'vgg_19': vgg.vgg_19, 43 | 'inception_v1': inception.inception_v1, 44 | 'inception_v2': inception.inception_v2, 45 | 'inception_v3': inception.inception_v3, 46 | 'inception_v4': inception.inception_v4, 47 | 'inception_resnet_v2': inception.inception_resnet_v2, 48 | 'lenet': lenet.lenet, 49 | 'resnet_v1_50': resnet_v1.resnet_v1_50, 50 | 'resnet_v1_101': resnet_v1.resnet_v1_101, 51 | 'resnet_v1_152': resnet_v1.resnet_v1_152, 52 | 'resnet_v1_200': resnet_v1.resnet_v1_200, 53 | 'resnet_v2_50': resnet_v2.resnet_v2_50, 54 | 'resnet_v2_101': resnet_v2.resnet_v2_101, 55 | 'resnet_v2_152': resnet_v2.resnet_v2_152, 56 | 'resnet_v2_200': resnet_v2.resnet_v2_200, 57 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1, 58 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_075, 59 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_050, 60 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_025, 61 | 'nasnet_cifar': nasnet.build_nasnet_cifar, 62 | 'nasnet_mobile': nasnet.build_nasnet_mobile, 63 | 'nasnet_large': nasnet.build_nasnet_large, 64 | } 65 | 66 | arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope, 67 | 'cifarnet': cifarnet.cifarnet_arg_scope, 68 | 'overfeat': overfeat.overfeat_arg_scope, 69 | 'vgg_a': vgg.vgg_arg_scope, 70 | 'vgg_16': vgg.vgg_arg_scope, 71 | 'vgg_19': vgg.vgg_arg_scope, 72 | 'inception_v1': inception.inception_v3_arg_scope, 73 | 'inception_v2': inception.inception_v3_arg_scope, 74 | 'inception_v3': inception.inception_v3_arg_scope, 75 | 'inception_v4': inception.inception_v4_arg_scope, 76 | 'inception_resnet_v2': 77 | inception.inception_resnet_v2_arg_scope, 78 | 'lenet': lenet.lenet_arg_scope, 79 | 'resnet_v1_50': resnet_v1.resnet_arg_scope, 80 | 'resnet_v1_101': resnet_v1.resnet_arg_scope, 81 | 'resnet_v1_152': resnet_v1.resnet_arg_scope, 82 | 'resnet_v1_200': resnet_v1.resnet_arg_scope, 83 | 'resnet_v2_50': resnet_v2.resnet_arg_scope, 84 | 'resnet_v2_101': resnet_v2.resnet_arg_scope, 85 | 'resnet_v2_152': resnet_v2.resnet_arg_scope, 86 | 'resnet_v2_200': resnet_v2.resnet_arg_scope, 87 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope, 88 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_arg_scope, 89 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_arg_scope, 90 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_arg_scope, 91 | 'nasnet_cifar': nasnet.nasnet_cifar_arg_scope, 92 | 'nasnet_mobile': nasnet.nasnet_mobile_arg_scope, 93 | 'nasnet_large': nasnet.nasnet_large_arg_scope, 94 | } 95 | 96 | 97 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): 98 | """Returns a network_fn such as `logits, end_points = network_fn(images)`. 99 | 100 | Args: 101 | name: The name of the network. 
102 | num_classes: The number of classes to use for classification. If 0 or None, 103 | the logits layer is omitted and its input features are returned instead. 104 | weight_decay: The l2 coefficient for the model weights. 105 | is_training: `True` if the model is being used for training and `False` 106 | otherwise. 107 | 108 | Returns: 109 | network_fn: A function that applies the model to a batch of images. It has 110 | the following signature: 111 | net, end_points = network_fn(images) 112 | The `images` input is a tensor of shape [batch_size, height, width, 3] 113 | with height = width = network_fn.default_image_size. (The permissibility 114 | and treatment of other sizes depends on the network_fn.) 115 | The returned `end_points` are a dictionary of intermediate activations. 116 | The returned `net` is the topmost layer, depending on `num_classes`: 117 | If `num_classes` was a non-zero integer, `net` is a logits tensor 118 | of shape [batch_size, num_classes]. 119 | If `num_classes` was 0 or `None`, `net` is a tensor with the input 120 | to the logits layer of shape [batch_size, 1, 1, num_features] or 121 | [batch_size, num_features]. Dropout has not been applied to this 122 | (even if the network's original classification does); it remains for 123 | the caller to do this or not. 124 | 125 | Raises: 126 | ValueError: If network `name` is not recognized. 127 | """ 128 | if name not in networks_map: 129 | raise ValueError('Name of network unknown %s' % name) 130 | func = networks_map[name] 131 | @functools.wraps(func) 132 | def network_fn(images, **kwargs): 133 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay) 134 | with slim.arg_scope(arg_scope): 135 | return func(images, num_classes, is_training=is_training, **kwargs) 136 | if hasattr(func, 'default_image_size'): 137 | network_fn.default_image_size = func.default_image_size 138 | 139 | return network_fn 140 | -------------------------------------------------------------------------------- /src/slim/nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for slim.nets.nets_factory.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFnFirstHalf(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | for net in list(nets_factory.networks_map.keys())[:10]: 34 | with tf.Graph().as_default() as g, self.test_session(g): 35 | net_fn = nets_factory.get_network_fn(net, num_classes) 36 | # Most networks use 224 as their default_image_size 37 | image_size = getattr(net_fn, 'default_image_size', 224) 38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 39 | logits, end_points = net_fn(inputs) 40 | self.assertTrue(isinstance(logits, tf.Tensor)) 41 | self.assertTrue(isinstance(end_points, dict)) 42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 44 | 45 | def testGetNetworkFnSecondHalf(self): 46 | batch_size = 5 47 | num_classes = 1000 48 | for net in list(nets_factory.networks_map.keys())[10:]: 49 | with tf.Graph().as_default() as g, self.test_session(g): 50 | net_fn = nets_factory.get_network_fn(net, num_classes) 51 | # Most networks use 224 as their default_image_size 52 | image_size = getattr(net_fn, 'default_image_size', 224) 53 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 54 | logits, end_points = net_fn(inputs) 55 | self.assertTrue(isinstance(logits, tf.Tensor)) 56 | self.assertTrue(isinstance(end_points, dict)) 57 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 58 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 59 | 60 | if __name__ == '__main__': 61 | tf.test.main() 62 | -------------------------------------------------------------------------------- /src/slim/nets/overfeat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the model definition for the OverFeat network.
16 | 17 | The definition for the network was obtained from: 18 | OverFeat: Integrated Recognition, Localization and Detection using 19 | Convolutional Networks 20 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 21 | Yann LeCun, 2014 22 | http://arxiv.org/abs/1312.6229 23 | 24 | Usage: 25 | with slim.arg_scope(overfeat.overfeat_arg_scope()): 26 | outputs, end_points = overfeat.overfeat(inputs) 27 | 28 | @@overfeat 29 | """ 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow as tf 35 | 36 | slim = tf.contrib.slim 37 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 38 | 39 | 40 | def overfeat_arg_scope(weight_decay=0.0005): 41 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 42 | activation_fn=tf.nn.relu, 43 | weights_regularizer=slim.l2_regularizer(weight_decay), 44 | biases_initializer=tf.zeros_initializer()): 45 | with slim.arg_scope([slim.conv2d], padding='SAME'): 46 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 47 | return arg_sc 48 | 49 | 50 | def overfeat(inputs, 51 | num_classes=1000, 52 | is_training=True, 53 | dropout_keep_prob=0.5, 54 | spatial_squeeze=True, 55 | scope='overfeat', 56 | global_pool=False): 57 | """Contains the model definition for the OverFeat network. 58 | 59 | The definition for the network was obtained from: 60 | OverFeat: Integrated Recognition, Localization and Detection using 61 | Convolutional Networks 62 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 63 | Yann LeCun, 2014 64 | http://arxiv.org/abs/1312.6229 65 | 66 | Note: All the fully_connected layers have been transformed to conv2d layers. 67 | To use in classification mode, resize input to 231x231. To use in fully 68 | convolutional mode, set spatial_squeeze to false. 69 | 70 | Args: 71 | inputs: a tensor of size [batch_size, height, width, channels]. 72 | num_classes: number of predicted classes. If 0 or None, the logits layer is 73 | omitted and the input features to the logits layer are returned instead. 74 | is_training: whether or not the model is being trained. 75 | dropout_keep_prob: the probability that activations are kept in the dropout 76 | layers during training. 77 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 78 | outputs. Useful to remove unnecessary dimensions for classification. 79 | scope: Optional scope for the variables. 80 | global_pool: Optional boolean flag. If True, the input to the classification 81 | layer is avgpooled to size 1x1, for any input size. (This is not part 82 | of the original OverFeat.) 83 | 84 | Returns: 85 | net: the output of the logits layer (if num_classes is a non-zero integer), 86 | or the non-dropped-out input to the logits layer (if num_classes is 0 or 87 | None). 88 | end_points: a dict of tensors with intermediate activations. 
89 | """ 90 | with tf.variable_scope(scope, 'overfeat', [inputs]) as sc: 91 | end_points_collection = sc.original_name_scope + '_end_points' 92 | # Collect outputs for conv2d, fully_connected and max_pool2d 93 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 94 | outputs_collections=end_points_collection): 95 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 96 | scope='conv1') 97 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 98 | net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2') 99 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 100 | net = slim.conv2d(net, 512, [3, 3], scope='conv3') 101 | net = slim.conv2d(net, 1024, [3, 3], scope='conv4') 102 | net = slim.conv2d(net, 1024, [3, 3], scope='conv5') 103 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 104 | 105 | # Use conv2d instead of fully_connected layers. 106 | with slim.arg_scope([slim.conv2d], 107 | weights_initializer=trunc_normal(0.005), 108 | biases_initializer=tf.constant_initializer(0.1)): 109 | net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6') 110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 111 | scope='dropout6') 112 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 113 | # Convert end_points_collection into a end_point dict. 114 | end_points = slim.utils.convert_collection_to_dict( 115 | end_points_collection) 116 | if global_pool: 117 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 118 | end_points['global_pool'] = net 119 | if num_classes: 120 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 121 | scope='dropout7') 122 | net = slim.conv2d(net, num_classes, [1, 1], 123 | activation_fn=None, 124 | normalizer_fn=None, 125 | biases_initializer=tf.zeros_initializer(), 126 | scope='fc8') 127 | if spatial_squeeze: 128 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 129 | end_points[sc.name + '/fc8'] = net 130 | return net, end_points 131 | overfeat.default_image_size = 231 132 | -------------------------------------------------------------------------------- /src/slim/nets/overfeat_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.overfeat.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import overfeat 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class OverFeatTest(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 231, 231 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = overfeat.overfeat(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 281, 281 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 2, 2, num_classes]) 50 | 51 | def testGlobalPool(self): 52 | batch_size = 1 53 | height, width = 281, 281 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False, 58 | global_pool=True) 59 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 60 | self.assertListEqual(logits.get_shape().as_list(), 61 | [batch_size, 1, 1, num_classes]) 62 | 63 | def testEndPoints(self): 64 | batch_size = 5 65 | height, width = 231, 231 66 | num_classes = 1000 67 | with self.test_session(): 68 | inputs = tf.random_uniform((batch_size, height, width, 3)) 69 | _, end_points = overfeat.overfeat(inputs, num_classes) 70 | expected_names = ['overfeat/conv1', 71 | 'overfeat/pool1', 72 | 'overfeat/conv2', 73 | 'overfeat/pool2', 74 | 'overfeat/conv3', 75 | 'overfeat/conv4', 76 | 'overfeat/conv5', 77 | 'overfeat/pool5', 78 | 'overfeat/fc6', 79 | 'overfeat/fc7', 80 | 'overfeat/fc8' 81 | ] 82 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 83 | 84 | def testNoClasses(self): 85 | batch_size = 5 86 | height, width = 231, 231 87 | num_classes = None 88 | with self.test_session(): 89 | inputs = tf.random_uniform((batch_size, height, width, 3)) 90 | net, end_points = overfeat.overfeat(inputs, num_classes) 91 | expected_names = ['overfeat/conv1', 92 | 'overfeat/pool1', 93 | 'overfeat/conv2', 94 | 'overfeat/pool2', 95 | 'overfeat/conv3', 96 | 'overfeat/conv4', 97 | 'overfeat/conv5', 98 | 'overfeat/pool5', 99 | 'overfeat/fc6', 100 | 'overfeat/fc7' 101 | ] 102 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 103 | self.assertTrue(net.op.name.startswith('overfeat/fc7')) 104 | 105 | def testModelVariables(self): 106 | batch_size = 5 107 | height, width = 231, 231 108 | num_classes = 1000 109 | with self.test_session(): 110 | inputs = tf.random_uniform((batch_size, height, width, 3)) 111 | overfeat.overfeat(inputs, num_classes) 112 | expected_names = ['overfeat/conv1/weights', 113 | 'overfeat/conv1/biases', 114 | 'overfeat/conv2/weights', 115 | 'overfeat/conv2/biases', 116 | 'overfeat/conv3/weights', 117 | 'overfeat/conv3/biases', 118 | 'overfeat/conv4/weights', 119 | 'overfeat/conv4/biases', 120 | 'overfeat/conv5/weights', 121 | 
'overfeat/conv5/biases', 122 | 'overfeat/fc6/weights', 123 | 'overfeat/fc6/biases', 124 | 'overfeat/fc7/weights', 125 | 'overfeat/fc7/biases', 126 | 'overfeat/fc8/weights', 127 | 'overfeat/fc8/biases', 128 | ] 129 | model_variables = [v.op.name for v in slim.get_model_variables()] 130 | self.assertSetEqual(set(model_variables), set(expected_names)) 131 | 132 | def testEvaluation(self): 133 | batch_size = 2 134 | height, width = 231, 231 135 | num_classes = 1000 136 | with self.test_session(): 137 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 138 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False) 139 | self.assertListEqual(logits.get_shape().as_list(), 140 | [batch_size, num_classes]) 141 | predictions = tf.argmax(logits, 1) 142 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 143 | 144 | def testTrainEvalWithReuse(self): 145 | train_batch_size = 2 146 | eval_batch_size = 1 147 | train_height, train_width = 231, 231 148 | eval_height, eval_width = 281, 281 149 | num_classes = 1000 150 | with self.test_session(): 151 | train_inputs = tf.random_uniform( 152 | (train_batch_size, train_height, train_width, 3)) 153 | logits, _ = overfeat.overfeat(train_inputs) 154 | self.assertListEqual(logits.get_shape().as_list(), 155 | [train_batch_size, num_classes]) 156 | tf.get_variable_scope().reuse_variables() 157 | eval_inputs = tf.random_uniform( 158 | (eval_batch_size, eval_height, eval_width, 3)) 159 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False, 160 | spatial_squeeze=False) 161 | self.assertListEqual(logits.get_shape().as_list(), 162 | [eval_batch_size, 2, 2, num_classes]) 163 | logits = tf.reduce_mean(logits, [1, 2]) 164 | predictions = tf.argmax(logits, 1) 165 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 166 | 167 | def testForward(self): 168 | batch_size = 1 169 | height, width = 231, 231 170 | with self.test_session() as sess: 171 | inputs = tf.random_uniform((batch_size, height, width, 3)) 172 | logits, _ = overfeat.overfeat(inputs) 173 | sess.run(tf.global_variables_initializer()) 174 | output = sess.run(logits) 175 | self.assertTrue(output.any()) 176 | 177 | if __name__ == '__main__': 178 | tf.test.main() 179 | -------------------------------------------------------------------------------- /src/slim/nets/pix2pix_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================= 15 | """Tests for pix2pix.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from nets import pix2pix 23 | 24 | 25 | class GeneratorTest(tf.test.TestCase): 26 | 27 | def test_nonsquare_inputs_raise_exception(self): 28 | batch_size = 2 29 | height, width = 240, 320 30 | num_outputs = 4 31 | 32 | images = tf.ones((batch_size, height, width, 3)) 33 | 34 | with self.assertRaises(ValueError): 35 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 36 | pix2pix.pix2pix_generator( 37 | images, num_outputs, upsample_method='nn_upsample_conv') 38 | 39 | def _reduced_default_blocks(self): 40 | """Returns the default blocks, scaled down to make test run faster.""" 41 | return [pix2pix.Block(b.num_filters // 32, b.decoder_keep_prob) 42 | for b in pix2pix._default_generator_blocks()] 43 | 44 | def test_output_size_nn_upsample_conv(self): 45 | batch_size = 2 46 | height, width = 256, 256 47 | num_outputs = 4 48 | 49 | images = tf.ones((batch_size, height, width, 3)) 50 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 51 | logits, _ = pix2pix.pix2pix_generator( 52 | images, num_outputs, blocks=self._reduced_default_blocks(), 53 | upsample_method='nn_upsample_conv') 54 | 55 | with self.test_session() as session: 56 | session.run(tf.global_variables_initializer()) 57 | np_outputs = session.run(logits) 58 | self.assertListEqual([batch_size, height, width, num_outputs], 59 | list(np_outputs.shape)) 60 | 61 | def test_output_size_conv2d_transpose(self): 62 | batch_size = 2 63 | height, width = 256, 256 64 | num_outputs = 4 65 | 66 | images = tf.ones((batch_size, height, width, 3)) 67 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 68 | logits, _ = pix2pix.pix2pix_generator( 69 | images, num_outputs, blocks=self._reduced_default_blocks(), 70 | upsample_method='conv2d_transpose') 71 | 72 | with self.test_session() as session: 73 | session.run(tf.global_variables_initializer()) 74 | np_outputs = session.run(logits) 75 | self.assertListEqual([batch_size, height, width, num_outputs], 76 | list(np_outputs.shape)) 77 | 78 | def test_block_number_dictates_number_of_layers(self): 79 | batch_size = 2 80 | height, width = 256, 256 81 | num_outputs = 4 82 | 83 | images = tf.ones((batch_size, height, width, 3)) 84 | blocks = [ 85 | pix2pix.Block(64, 0.5), 86 | pix2pix.Block(128, 0), 87 | ] 88 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 89 | _, end_points = pix2pix.pix2pix_generator( 90 | images, num_outputs, blocks) 91 | 92 | num_encoder_layers = 0 93 | num_decoder_layers = 0 94 | for end_point in end_points: 95 | if end_point.startswith('encoder'): 96 | num_encoder_layers += 1 97 | elif end_point.startswith('decoder'): 98 | num_decoder_layers += 1 99 | 100 | self.assertEqual(num_encoder_layers, len(blocks)) 101 | self.assertEqual(num_decoder_layers, len(blocks)) 102 | 103 | 104 | class DiscriminatorTest(tf.test.TestCase): 105 | 106 | def _layer_output_size(self, input_size, kernel_size=4, stride=2, pad=2): 107 | return (input_size + pad * 2 - kernel_size) // stride + 1 108 | 109 | def test_four_layers(self): 110 | batch_size = 2 111 | input_size = 256 112 | 113 | output_size = self._layer_output_size(input_size) 114 | output_size = self._layer_output_size(output_size) 115 | output_size = self._layer_output_size(output_size) 116 | output_size 
= self._layer_output_size(output_size, stride=1) 117 | output_size = self._layer_output_size(output_size, stride=1) 118 | 119 | images = tf.ones((batch_size, input_size, input_size, 3)) 120 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 121 | logits, end_points = pix2pix.pix2pix_discriminator( 122 | images, num_filters=[64, 128, 256, 512]) 123 | self.assertListEqual([batch_size, output_size, output_size, 1], 124 | logits.shape.as_list()) 125 | self.assertListEqual([batch_size, output_size, output_size, 1], 126 | end_points['predictions'].shape.as_list()) 127 | 128 | def test_four_layers_no_padding(self): 129 | batch_size = 2 130 | input_size = 256 131 | 132 | output_size = self._layer_output_size(input_size, pad=0) 133 | output_size = self._layer_output_size(output_size, pad=0) 134 | output_size = self._layer_output_size(output_size, pad=0) 135 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 136 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 137 | 138 | images = tf.ones((batch_size, input_size, input_size, 3)) 139 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 140 | logits, end_points = pix2pix.pix2pix_discriminator( 141 | images, num_filters=[64, 128, 256, 512], padding=0) 142 | self.assertListEqual([batch_size, output_size, output_size, 1], 143 | logits.shape.as_list()) 144 | self.assertListEqual([batch_size, output_size, output_size, 1], 145 | end_points['predictions'].shape.as_list()) 146 | 147 | def test_four_layers_wrog_paddig(self): 148 | batch_size = 2 149 | input_size = 256 150 | 151 | images = tf.ones((batch_size, input_size, input_size, 3)) 152 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 153 | with self.assertRaises(TypeError): 154 | pix2pix.pix2pix_discriminator( 155 | images, num_filters=[64, 128, 256, 512], padding=1.5) 156 | 157 | def test_four_layers_negative_padding(self): 158 | batch_size = 2 159 | input_size = 256 160 | 161 | images = tf.ones((batch_size, input_size, input_size, 3)) 162 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 163 | with self.assertRaises(ValueError): 164 | pix2pix.pix2pix_discriminator( 165 | images, num_filters=[64, 128, 256, 512], padding=-1) 166 | 167 | if __name__ == '__main__': 168 | tf.test.main() 169 | -------------------------------------------------------------------------------- /src/slim/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/slim/preprocessing/cifarnet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images in CIFAR-10. 
16 | 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf 24 | 25 | _PADDING = 4 26 | 27 | slim = tf.contrib.slim 28 | 29 | 30 | def preprocess_for_train(image, 31 | output_height, 32 | output_width, 33 | padding=_PADDING, 34 | add_image_summaries=True): 35 | """Preprocesses the given image for training. 36 | 37 | Note that the actual resizing scale is sampled from 38 | [`resize_size_min`, `resize_size_max`]. 39 | 40 | Args: 41 | image: A `Tensor` representing an image of arbitrary size. 42 | output_height: The height of the image after preprocessing. 43 | output_width: The width of the image after preprocessing. 44 | padding: The amound of padding before and after each dimension of the image. 45 | add_image_summaries: Enable image summaries. 46 | 47 | Returns: 48 | A preprocessed image. 49 | """ 50 | if add_image_summaries: 51 | tf.summary.image('image', tf.expand_dims(image, 0)) 52 | 53 | # Transform the image to floats. 54 | image = tf.to_float(image) 55 | if padding > 0: 56 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]]) 57 | # Randomly crop a [height, width] section of the image. 58 | distorted_image = tf.random_crop(image, 59 | [output_height, output_width, 3]) 60 | 61 | # Randomly flip the image horizontally. 62 | distorted_image = tf.image.random_flip_left_right(distorted_image) 63 | 64 | if add_image_summaries: 65 | tf.summary.image('distorted_image', tf.expand_dims(distorted_image, 0)) 66 | 67 | # Because these operations are not commutative, consider randomizing 68 | # the order their operation. 69 | distorted_image = tf.image.random_brightness(distorted_image, 70 | max_delta=63) 71 | distorted_image = tf.image.random_contrast(distorted_image, 72 | lower=0.2, upper=1.8) 73 | # Subtract off the mean and divide by the variance of the pixels. 74 | return tf.image.per_image_standardization(distorted_image) 75 | 76 | 77 | def preprocess_for_eval(image, output_height, output_width, 78 | add_image_summaries=True): 79 | """Preprocesses the given image for evaluation. 80 | 81 | Args: 82 | image: A `Tensor` representing an image of arbitrary size. 83 | output_height: The height of the image after preprocessing. 84 | output_width: The width of the image after preprocessing. 85 | add_image_summaries: Enable image summaries. 86 | 87 | Returns: 88 | A preprocessed image. 89 | """ 90 | if add_image_summaries: 91 | tf.summary.image('image', tf.expand_dims(image, 0)) 92 | # Transform the image to floats. 93 | image = tf.to_float(image) 94 | 95 | # Resize and crop if needed. 96 | resized_image = tf.image.resize_image_with_crop_or_pad(image, 97 | output_width, 98 | output_height) 99 | if add_image_summaries: 100 | tf.summary.image('resized_image', tf.expand_dims(resized_image, 0)) 101 | 102 | # Subtract off the mean and divide by the variance of the pixels. 103 | return tf.image.per_image_standardization(resized_image) 104 | 105 | 106 | def preprocess_image(image, output_height, output_width, is_training=False, 107 | add_image_summaries=True): 108 | """Preprocesses the given image. 109 | 110 | Args: 111 | image: A `Tensor` representing an image of arbitrary size. 112 | output_height: The height of the image after preprocessing. 113 | output_width: The width of the image after preprocessing. 114 | is_training: `True` if we're preprocessing the image for training and 115 | `False` otherwise. 116 | add_image_summaries: Enable image summaries. 
117 | 118 | Returns: 119 | A preprocessed image. 120 | """ 121 | if is_training: 122 | return preprocess_for_train( 123 | image, output_height, output_width, 124 | add_image_summaries=add_image_summaries) 125 | else: 126 | return preprocess_for_eval( 127 | image, output_height, output_width, 128 | add_image_summaries=add_image_summaries) 129 | -------------------------------------------------------------------------------- /src/slim/preprocessing/lenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities for preprocessing.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def preprocess_image(image, output_height, output_width, is_training): 27 | """Preprocesses the given image. 28 | 29 | Args: 30 | image: A `Tensor` representing an image of arbitrary size. 31 | output_height: The height of the image after preprocessing. 32 | output_width: The width of the image after preprocessing. 33 | is_training: `True` if we're preprocessing the image for training and 34 | `False` otherwise. 35 | 36 | Returns: 37 | A preprocessed image. 38 | """ 39 | image = tf.to_float(image) 40 | image = tf.image.resize_image_with_crop_or_pad( 41 | image, output_width, output_height) 42 | image = tf.subtract(image, 128.0) 43 | image = tf.div(image, 128.0) 44 | return image 45 | -------------------------------------------------------------------------------- /src/slim/preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import cifarnet_preprocessing 24 | from preprocessing import inception_preprocessing 25 | from preprocessing import lenet_preprocessing 26 | from preprocessing import vgg_preprocessing 27 | 28 | slim = tf.contrib.slim 29 | 30 | 31 | def get_preprocessing(name, is_training=False): 32 | """Returns preprocessing_fn(image, height, width, **kwargs). 33 | 34 | Args: 35 | name: The name of the preprocessing function. 36 | is_training: `True` if the model is being used for training and `False` 37 | otherwise. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocessing a single image (pre-batch). 41 | It has the following signature: 42 | image = preprocessing_fn(image, output_height, output_width, ...). 43 | 44 | Raises: 45 | ValueError: If Preprocessing `name` is not recognized. 46 | """ 47 | preprocessing_fn_map = { 48 | 'cifarnet': cifarnet_preprocessing, 49 | 'inception': inception_preprocessing, 50 | 'inception_v1': inception_preprocessing, 51 | 'inception_v2': inception_preprocessing, 52 | 'inception_v3': inception_preprocessing, 53 | 'inception_v4': inception_preprocessing, 54 | 'inception_resnet_v2': inception_preprocessing, 55 | 'lenet': lenet_preprocessing, 56 | 'mobilenet_v1': inception_preprocessing, 57 | 'nasnet_mobile': inception_preprocessing, 58 | 'nasnet_large': inception_preprocessing, 59 | 'resnet_v1_50': vgg_preprocessing, 60 | 'resnet_v1_101': vgg_preprocessing, 61 | 'resnet_v1_152': vgg_preprocessing, 62 | 'resnet_v1_200': vgg_preprocessing, 63 | 'resnet_v2_50': vgg_preprocessing, 64 | 'resnet_v2_101': vgg_preprocessing, 65 | 'resnet_v2_152': vgg_preprocessing, 66 | 'resnet_v2_200': vgg_preprocessing, 67 | 'vgg': vgg_preprocessing, 68 | 'vgg_a': vgg_preprocessing, 69 | 'vgg_16': vgg_preprocessing, 70 | 'vgg_19': vgg_preprocessing, 71 | } 72 | 73 | if name not in preprocessing_fn_map: 74 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 75 | 76 | def preprocessing_fn(image, output_height, output_width, **kwargs): 77 | return preprocessing_fn_map[name].preprocess_image( 78 | image, output_height, output_width, is_training=is_training, **kwargs) 79 | 80 | return preprocessing_fn 81 | --------------------------------------------------------------------------------
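The two factories above are designed to be used together: get_preprocessing returns a per-image preprocessing_fn matched to a network name, and get_network_fn builds the corresponding model graph. The following is a minimal usage sketch, not part of the upstream slim sources: it assumes TensorFlow 1.x with tf.contrib.slim available and the src/slim directory on PYTHONPATH (so the nets and preprocessing packages import exactly as in the files above). The input file is the repository's images/cat.jpg, but any RGB image works, and the variables are simply initialized for illustration rather than restored from a trained checkpoint.

import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory

# Build the model function and the matching preprocessing function by name.
# num_classes=1001 follows the slim ImageNet convention (1000 classes + background).
network_fn = nets_factory.get_network_fn('mobilenet_v1', num_classes=1001, is_training=False)
preprocess_fn = preprocessing_factory.get_preprocessing('mobilenet_v1', is_training=False)

# Decode one image and preprocess it to the network's default input resolution.
image_size = network_fn.default_image_size  # 224 for mobilenet_v1
raw = tf.image.decode_jpeg(tf.read_file('images/cat.jpg'), channels=3)
image = preprocess_fn(raw, image_size, image_size)

# network_fn expects a batch of shape [batch_size, height, width, 3].
logits, end_points = network_fn(tf.expand_dims(image, 0))
probabilities = tf.nn.softmax(logits)

with tf.Session() as sess:
  # For meaningful predictions a pretrained slim checkpoint would be restored here.
  sess.run(tf.global_variables_initializer())
  print(sess.run(probabilities).shape)  # (1, 1001)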