├── .dockerignore
├── .gitattributes
├── .gitignore
├── CMakeLists.txt
├── Dockerfile
├── LICENSE
├── README.md
├── docker
├── Dockerfile
└── README.md
├── etcs
├── dance.mp4
├── feature.md
├── inference_result2.png
├── loss_ll_heat.png
├── loss_ll_paf.png
├── openpose_macbook13_mobilenet2.gif
├── openpose_macbook_cmu.gif
├── openpose_macbook_mobilenet3.gif
├── openpose_p40_cmu.gif
├── openpose_p40_mobilenet.gif
├── openpose_tx2_mobilenet3.gif
├── ros.md
└── training.md
├── images
├── apink1.jpg
├── apink1_crop.jpg
├── apink1_crop_s1.jpg
├── apink2.jpg
├── apink3.jpg
├── cat.jpg
├── golf.jpg
├── hand1.jpg
├── hand1_small.jpg
├── hand2.jpg
├── handsup1.jpg
├── p1.jpg
├── p2.jpg
├── p3.jpg
├── p3_dance.png
├── ski.jpg
└── valid_person1.jpg
├── launch
└── demo_video.launch
├── models
├── graph
│ ├── cmu_640x360
│ │ └── download.sh
│ ├── cmu_640x480
│ │ └── download.sh
│ └── mobilenet_thin_432x368
│ │ ├── graph.pb
│ │ ├── graph_freeze.pb
│ │ └── graph_opt.pb
├── numpy
│ └── download.sh
└── pretrained
│ ├── mobilenet_v1_0.50_224_2017_06_14
│ └── download.sh
│ ├── mobilenet_v1_0.75_224_2017_06_14
│ └── download.sh
│ └── mobilenet_v1_1.0_224_2017_06_14
│ └── download.sh
├── msg
├── BodyPartElm.msg
├── Person.msg
└── Persons.msg
├── package.xml
├── requirements.txt
├── scripts
├── broadcaster_ros.py
└── visualization.py
└── src
├── __init__.py
├── common.py
├── datum_pb2.py
├── estimator.py
├── lifting
├── __init__.py
├── config.py
├── draw.py
├── models
│ └── prob_model_params.mat
├── prob_model.py
└── upright_fast.py
├── network_base.py
├── network_cmu.py
├── network_dsconv.py
├── network_mobilenet.py
├── network_mobilenet_thin.py
├── networks.py
├── pose_augment.py
├── pose_datamaster.py
├── pose_dataset.py
├── pose_dataworker.py
├── pose_stats.py
├── run.py
├── run_checkpoint.py
├── run_webcam.py
├── slim
├── __init__.py
├── nets
│ ├── __init__.py
│ ├── alexnet.py
│ ├── alexnet_test.py
│ ├── cifarnet.py
│ ├── cyclegan.py
│ ├── cyclegan_test.py
│ ├── dcgan.py
│ ├── dcgan_test.py
│ ├── inception.py
│ ├── inception_resnet_v2.py
│ ├── inception_resnet_v2_test.py
│ ├── inception_utils.py
│ ├── inception_v1.py
│ ├── inception_v1_test.py
│ ├── inception_v2.py
│ ├── inception_v2_test.py
│ ├── inception_v3.py
│ ├── inception_v3_test.py
│ ├── inception_v4.py
│ ├── inception_v4_test.py
│ ├── lenet.py
│ ├── mobilenet_v1.md
│ ├── mobilenet_v1.png
│ ├── mobilenet_v1.py
│ ├── mobilenet_v1_test.py
│ ├── nasnet
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── nasnet.py
│ │ ├── nasnet_test.py
│ │ ├── nasnet_utils.py
│ │ └── nasnet_utils_test.py
│ ├── nets_factory.py
│ ├── nets_factory_test.py
│ ├── overfeat.py
│ ├── overfeat_test.py
│ ├── pix2pix.py
│ ├── pix2pix_test.py
│ ├── resnet_utils.py
│ ├── resnet_v1.py
│ ├── resnet_v1_test.py
│ ├── resnet_v2.py
│ ├── resnet_v2_test.py
│ ├── vgg.py
│ └── vgg_test.py
└── preprocessing
│ ├── __init__.py
│ ├── cifarnet_preprocessing.py
│ ├── inception_preprocessing.py
│ ├── lenet_preprocessing.py
│ ├── preprocessing_factory.py
│ └── vgg_preprocessing.py
└── train.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | ./models
2 | ./models/*
3 | ./models/*/*
4 | ./models/*/*/*
5 | ./models/*/*/*/*
6 | models
7 | models/*
8 | models/*/*
9 | models/*/*/*
10 | models/*/*/*/*
11 | *.meta
12 | *.index
13 | *.data-*
14 | *.ckpt.*
15 |
16 | ./tests
17 | ./tests/*
18 | tests
19 | graph*.pb
20 | chk*.meta
21 | lifting
22 | slim
23 |
24 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | models/numpy/*.npy filter=lfs diff=lfs merge=lfs -text
2 | *.ckpt* filter=lfs diff=lfs merge=lfs -text
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # models
104 | *ckpt*
105 | *.npy
106 | timeline*.json
107 |
108 | tmp
109 | models/graph/cmu_*/*.pb
110 | models/trained/*/checkpoint
111 | models/trained/*/*.pb
112 | models/trained/*/model-*.data-*
113 | models/trained/*/model-*.index
114 | models/trained/*/model-*.meta
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 2.8.3)
2 | project(tfpose_ros)
3 |
4 | ## Add support for C++11, supported in ROS Kinetic and newer
5 | add_definitions(-std=c++11)
6 |
7 | find_package(catkin REQUIRED COMPONENTS
8 | roscpp
9 | rospy
10 | std_msgs
11 | message_generation
12 | )
13 |
14 | # Generate messages in the 'msg' folder
15 | add_message_files(
16 | FILES
17 | BodyPartElm.msg
18 | Person.msg
19 | Persons.msg
20 | )
21 |
22 | generate_messages(
23 | DEPENDENCIES std_msgs
24 | )
25 |
26 | catkin_package(
27 | CATKIN_DEPENDS rospy message_generation message_runtime
28 | )
29 |
30 | install()
31 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM idock.daumkakao.io/kakaobrain/deepcloud-sshd:openpose-preprocess
2 |
3 | COPY ./*.py /root/tf-openpose/
4 | WORKDIR /root/tf-openpose/
5 |
6 | RUN cd /root/tf-openpose/ && pip3 install -r requirements.txt
7 |
8 | ENTRYPOINT ["python3", "pose_dataworker.py"]
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tf-pose-estimation
2 |
3 | 'Openpose', a human pose estimation algorithm, has been implemented here using Tensorflow. This repository also provides several variants with changes to the network structure for **real-time processing on the CPU or low-power embedded devices.**
4 |
5 |
6 | **You can even run this on your macbook with decent FPS!**
7 |
8 | Original Repo(Caffe) : https://github.com/CMU-Perceptual-Computing-Lab/openpose
9 |
10 | | CMU's Original Model on Macbook Pro 15" | Mobilenet Variant on Macbook Pro 15" | Mobilenet Variant on Jetson TX2 |
11 | |:---------|:--------------------|:----------------|
12 | |  |  |  |
13 | | **~0.6 FPS** | **~4.2 FPS** @ 368x368 | **~10 FPS** @ 368x368 |
14 | | 2.8GHz Quad-core i7 | 2.8GHz Quad-core i7 | Jetson TX2 Embedded Board |
15 |
16 | Implemented features are listed here : [features](./etcs/feature.md)
17 |
18 | ## Install
19 |
20 | ### Dependencies
21 |
22 | You need dependencies below.
23 |
24 | - python3
25 | - tensorflow 1.4.1+
26 | - opencv3, protobuf, python3-tk
27 |
28 | ### Install
29 |
30 | ```bash
31 | $ git clone https://www.github.com/ildoonet/tf-openpose
32 | $ cd tf-openpose
33 | $ pip3 install -r requirements.txt
34 | ```
35 |
36 | ## Models
37 |
38 | I have tried multiple variations of models to find an optimized network architecture. Some of them are listed below, and checkpoint files are provided for research purposes.
39 |
40 | - cmu
41 | - the model based on the VGG pretrained network described in the original paper.
42 | - I converted the weights from Caffe format for use in Tensorflow.
43 | - [pretrained weight download](https://www.dropbox.com/s/xh5s7sb7remu8tx/openpose_coco.npy?dl=0)
44 |
45 | - dsconv
46 | - Same architecture as the cmu version except for the **depthwise separable convolution** of mobilenet.
47 | - I trained it using 'transfer learning', but it does not provide enough speed or accuracy.
48 |
49 | - mobilenet
50 | - Based on the mobilenet paper, 12 convolutional layers are used as feature-extraction layers.
51 | - To improve detection of small persons, **minor modifications** to the architecture have been made.
52 | - Three models were trained with different network-size parameters.
53 | - mobilenet
54 | - 368x368 : [checkpoint weight download](https://www.dropbox.com/s/09xivpuboecge56/mobilenet_0.75_0.50_model-388003.zip?dl=0)
55 | - mobilenet_fast
56 | - mobilenet_accurate
57 | - The published models are not the best ones, but you can test them before training a model from scratch.
58 |
59 | ### Download Tensorflow Graph File (pb file)
60 |
61 | Before running the demo, you should download the graph files. You can deploy these graphs on mobile or other platforms.
62 |
63 | - cmu_640x360
64 | - cmu_640x480
65 | - mobilenet_thin_432x368
66 |
67 | CMU's model graphs are too large for git, so I uploaded them on dropbox. You should download them if you want to use cmu's original model.
68 |
69 | ```
70 | $ cd models/graph/cmu_640x360
71 | $ bash download.sh
72 | $ cd ../cmu_640x480
73 | $ bash download.sh
74 | ```
75 |
76 | ### Inference Time
77 |
78 | | Dataset | Model | Inference Time<br/>Macbook Pro i5 3.1G | Inference Time<br/>Jetson TX2 |
79 | |---------|--------------------|----------------:|----------------:|
80 | | Coco | cmu | 10.0s @ 368x368 | OOM @ 368x368<br/>5.5s @ 320x240 |
81 | | Coco | dsconv | 1.10s @ 368x368 |
82 | | Coco | mobilenet_accurate | 0.40s @ 368x368 | 0.18s @ 368x368 |
83 | | Coco | mobilenet | 0.24s @ 368x368 | 0.10s @ 368x368 |
84 | | Coco | mobilenet_fast | 0.16s @ 368x368 | 0.07s @ 368x368 |
85 |
86 | ## Demo
87 |
88 | ### Test Inference
89 |
90 | You can test the inference feature with a single image.
91 |
92 | ```
93 | $ python3 run.py --model=mobilenet_thin_432x368 --image=...
94 | ```
95 |
96 | The image flag MUST be a path relative to the src folder, with no "~", e.g.:
97 | ```
98 | --image ../../Desktop
99 | ```
100 |
101 | Then you will see a screen like the one below, with the pafmap, heatmap, final result, and so on.
102 |
103 | 
104 |
105 | ### Realtime Webcam
106 |
107 | ```
108 | $ python3 run_webcam.py --model=mobilenet_thin_432x368 --camera=0
109 | ```
110 |
111 | Then you will see the realtime webcam screen with estimated poses as below. This [Realtime Result](./etcs/openpose_macbook13_mobilenet2.gif) was recorded on a Macbook Pro 13" with a 3.1GHz dual-core CPU.
112 |
113 | ## Python Usage
114 |
115 | This pose estimator provides simple python classes that you can use in your applications.
116 |
117 | See [run.py](run.py) or [run_webcam.py](run_webcam.py) as references.
118 |
119 | ```python
120 | e = TfPoseEstimator(get_graph_path(args.model), target_size=(w, h))
121 | humans = e.inference(image)
122 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False)
123 | ```
124 |
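For a fuller, self-contained example, here is a minimal sketch (not part of the repo itself) that ties these calls together with `model_wh` and `read_imgfile` from this package; adjust the paths to your working directory and make sure the mobilenet_thin_432x368 graph is available under models/graph.

```python
# Minimal sketch, assuming the src/ modules are importable and the graph file exists.
import cv2

from common import read_imgfile
from estimator import TfPoseEstimator
from networks import get_graph_path, model_wh

model = 'mobilenet_thin_432x368'
w, h = model_wh(model)                            # input size parsed from the model name
e = TfPoseEstimator(get_graph_path(model), target_size=(w, h))

image = read_imgfile('../images/p1.jpg', w, h)    # any sample under images/ works; adjust the path
humans = e.inference(image)                       # detected humans and their body parts
image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False)
cv2.imwrite('result.png', image)                  # save the rendered skeleton overlay
```
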
125 | ## ROS Support
126 |
127 | See : [etcs/ros.md](./etcs/ros.md)
128 |
129 | ## Training
130 |
131 | See : [etcs/training.md](./etcs/training.md)
132 |
133 | ## References
134 |
135 | ### OpenPose
136 |
137 | [1] https://github.com/CMU-Perceptual-Computing-Lab/openpose
138 |
139 | [2] Training Codes : https://github.com/ZheC/Realtime_Multi-Person_Pose_Estimation
140 |
141 | [3] Custom Caffe by Openpose : https://github.com/CMU-Perceptual-Computing-Lab/caffe_train
142 |
143 | [4] Keras Openpose : https://github.com/michalfaber/keras_Realtime_Multi-Person_Pose_Estimation
144 |
145 | ### Lifting from the deep
146 |
147 | [1] Arxiv Paper : https://arxiv.org/abs/1701.00295
148 |
149 | [2] https://github.com/DenisTome/Lifting-from-the-Deep-release
150 |
151 | ### Mobilenet
152 |
153 | [1] Original Paper : https://arxiv.org/abs/1704.04861
154 |
155 | [2] Pretrained model : https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.md
156 |
157 | ### Libraries
158 |
159 | [1] Tensorpack : https://github.com/ppwwyyxx/tensorpack
160 |
161 | ### Tensorflow Tips
162 |
163 | [1] Freeze graph : https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py
164 |
165 | [2] Optimize graph : https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2
166 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:16.04
2 |
3 | #ENV http_proxy=http://10.41.249.28:8080 https_proxy=http://10.41.249.28:8080
4 |
5 | RUN apt-get update -yq && apt-get install -yq build-essential cmake git pkg-config wget zip && \
6 | apt-get install -yq libjpeg8-dev libtiff5-dev libjasper-dev libpng12-dev && \
7 | apt-get install -yq libavcodec-dev libavformat-dev libswscale-dev libv4l-dev && \
8 | apt-get install -yq libgtk2.0-dev && \
9 | apt-get install -yq libatlas-base-dev gfortran && \
10 | apt-get install -yq python3 python3-dev python3-pip python3-setuptools python3-tk git && \
11 | apt-get remove -yq python-pip python3-pip && wget https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && \
12 | pip3 install numpy && \
13 | cd ~ && git clone https://github.com/Itseez/opencv.git && \
14 | cd opencv && mkdir build && cd build && \
15 | cmake -D CMAKE_BUILD_TYPE=RELEASE \
16 | -D CMAKE_INSTALL_PREFIX=/usr/local \
17 | -D INSTALL_PYTHON_EXAMPLES=ON \
18 | -D BUILD_opencv_python3=yes -D PYTHON_EXECUTABLE=/usr/bin/python3 .. && \
19 | make -j8 && make install && rm -rf /root/opencv/ && \
20 | mkdir -p /root/tf-openpose && \
21 | rm -rf /tmp/*.tar.gz && \
22 | apt-get clean && rm -rf /tmp/* /var/tmp* /var/lib/apt/lists/* && \
23 | rm -f /etc/ssh/ssh_host_* && rm -rf /usr/share/man?? /usr/share/man/??_*
24 |
25 | COPY . /root/tf-openpose/
26 | WORKDIR /root/tf-openpose/
27 |
28 | RUN cd /root/tf-openpose/ && pip3 install -U setuptools && \
29 | pip3 install tensorflow && pip3 install -r requirements.txt
30 |
31 | RUN cd /root && git clone https://github.com/cocodataset/cocoapi && \
32 | pip3 install cython && \
33 | cd cocoapi/PythonAPI && python3 setup.py build_ext --inplace && python3 setup.py build_ext install && \
34 | mkdir /coco && cd /coco && wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip && \
35 | unzip annotations_trainval2017.zip && rm -rf annotations_trainval2017.zip
36 |
37 | ENTRYPOINT ["python3", "pose_dataworker.py"]
38 |
39 | #ENV http_proxy= https_proxy=
40 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | ```
2 | $ docker build --tag idock.daumkakao.io/kakaobrain/deepcloud-sshd:openpose-preprocess -f ./docker/full/ .
3 | ```
4 |
5 | ```
6 | $ docker build --tag idock.daumkakao.io/kakaobrain/deepcloud-sshd:openpose-preprocess -f ./docker/update/Dockerfile .
7 | ```
--------------------------------------------------------------------------------
/etcs/dance.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/dance.mp4
--------------------------------------------------------------------------------
/etcs/feature.md:
--------------------------------------------------------------------------------
1 | ## Features
2 |
3 | - [x] CMU's original network architecture and weights.
4 |
5 | - [x] Transfer Original Weights to Tensorflow
6 |
7 | - [x] Training Code with multi-gpus
8 |
9 | - [x] Evaluate with test dataset
10 |
11 | - [ ] Inference
12 |
13 | - [x] Post processing from network output.
14 |
15 | - [x] Faster post-processing
16 |
17 | - [ ] Multi-Scale Inference
18 |
19 | - [x] Faster network variants using custom mobilenet architecture.
20 |
21 | - [x] Depthwise Separable Convolution Version
22 |
23 | - [x] Mobilenet Version
24 |
25 | - [ ] Demos
26 |
27 | - [x] Realtime Webcam Demo
28 |
29 | - [x] Image File Demo
30 |
31 | - [ ] Video File Demo
32 |
33 | - [x] ROS Support. See the [ros readme](./etcs/ros.md).
--------------------------------------------------------------------------------
/etcs/inference_result2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/inference_result2.png
--------------------------------------------------------------------------------
/etcs/loss_ll_heat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/loss_ll_heat.png
--------------------------------------------------------------------------------
/etcs/loss_ll_paf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/loss_ll_paf.png
--------------------------------------------------------------------------------
/etcs/openpose_macbook13_mobilenet2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/openpose_macbook13_mobilenet2.gif
--------------------------------------------------------------------------------
/etcs/openpose_macbook_cmu.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/openpose_macbook_cmu.gif
--------------------------------------------------------------------------------
/etcs/openpose_macbook_mobilenet3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/openpose_macbook_mobilenet3.gif
--------------------------------------------------------------------------------
/etcs/openpose_p40_cmu.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/openpose_p40_cmu.gif
--------------------------------------------------------------------------------
/etcs/openpose_p40_mobilenet.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/openpose_p40_mobilenet.gif
--------------------------------------------------------------------------------
/etcs/openpose_tx2_mobilenet3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/etcs/openpose_tx2_mobilenet3.gif
--------------------------------------------------------------------------------
/etcs/ros.md:
--------------------------------------------------------------------------------
1 | # tf-pose-estimation for ROS
2 |
3 | Human pose estimation is expected to be used on mobile robots that need human interaction.
4 |
5 | ## Installation
6 |
7 | Clone this repository under the src folder of your ROS workspace, and carry out the same installation steps as described in [README.md](README.md).
8 |
9 | ```
10 | $ cd $(ros-workspace)
11 | $ cd src
12 | $ git clone https://github.com/ildoonet/tf-pose-estimation
13 | $ pip install -r tf-pose-estimation/requirements.txt
14 | ```
15 |
16 | The following dependencies are needed to launch the demo:
17 |
18 | - video_stream_opencv
19 | - image_view
20 | - ros_video_recorder : https://github.com/ildoonet/ros-video-recorder
21 |
22 | ## Video/camera demo
23 |
24 | | CMU<br/>640x360 | Mobilenet_Thin<br/>432x368 |
25 | |:----------------|:---------------------------|
26 | |  |  |
27 |
28 | The tests above were run on a P40 GPU. Latency between the current video frame and the processed frame is much lower with the mobilenet version.
29 |
30 | Source : https://www.youtube.com/watch?v=rSZnyBuc6tc
31 |
32 | ```
33 | $ roslaunch tfpose_ros demo_video.launch
34 | ```
35 |
36 | You can specify the 'video' argument to launch a realtime video demo using your camera. See the [ros launch file](./launch/demo_video.launch).
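
For example, assuming the launch argument is indeed named `video` (the launch file is not reproduced here), a hypothetical invocation could look like:

```
$ roslaunch tfpose_ros demo_video.launch video:=/path/to/video.mp4
```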
--------------------------------------------------------------------------------
/etcs/training.md:
--------------------------------------------------------------------------------
1 | ## Training
2 |
3 | ### Coco Dataset
4 |
5 | You should download COCO Dataset from http://cocodataset.org/#download
6 |
7 | Also, you need to install cocoapi for easy parsing : https://github.com/cocodataset/cocoapi
8 |
9 | ```
10 | $ git clone https://github.com/cocodataset/cocoapi
11 | $ cd cocoapi/PythonAPI
12 | $ python3 setup.py build_ext --inplace
13 | $ python3 setup.py build_ext install
14 | ```
15 |
16 | ### Augmentation
17 |
18 | CMU Perceptual Computing Lab has modified Caffe to provide data augmentation. See : https://github.com/CMU-Perceptual-Computing-Lab/caffe_train
19 |
20 | I implemented the augmentation code in the same way as the original version; see [pose_dataset.py](pose_dataset.py) and [pose_augment.py](pose_augment.py). This includes scaling, rotation, flipping, and cropping.
21 |
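As a rough illustration only (this is not the actual pose_augment.py API), the core of such an affine augmentation can be sketched with OpenCV; the key detail is that the joint coordinates must be transformed with the same matrix as the image.

```python
# Illustrative sketch, not the repo's implementation.
import cv2
import numpy as np

def augment(image, joints, scale=1.2, angle=15.0, flip=True):
    """Apply the same scale/rotation/flip to an image and its (x, y) joint array."""
    h, w = image.shape[:2]
    m = cv2.getRotationMatrix2D((w / 2, h / 2), angle, scale)      # rotation + scaling around the center
    image = cv2.warpAffine(image, m, (w, h))
    joints = np.hstack([joints, np.ones((len(joints), 1))]) @ m.T  # same affine transform on the keypoints
    if flip:
        image = cv2.flip(image, 1)                                 # horizontal flip
        joints[:, 0] = w - 1 - joints[:, 0]                        # mirror x (left/right part labels must also be swapped)
    return image, joints
```
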
22 | This process can be a bottleneck for training, so if you have enough computing resources, please see the [Run for Faster Training](#run-for-faster-training) section.
23 |
24 | ### Run
25 |
26 | ```
27 | $ python3 train.py --model=cmu --datapath={datapath} --batchsize=64 --lr=0.001 --modelpath={path-to-save}
28 |
29 | 2017-09-27 15:58:50,307 INFO Restore pretrained weights...
30 | ```
31 |
32 | If you want to reproduce the original paper's result, the following settings are recommended (an example command follows the list).
33 |
34 | - model : vgg
35 | - lr : 0.0001 or 0.00004
36 | - input-width = input-height = 368x368 or 432x368
37 | - batchsize : 10 (I trained with batch sizes up to 128, and they trained well)
38 |
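Putting these recommended values together with the flags shown in the command above, a reproduction run might look like the following (the input-size options are listed above, but their exact flag names are not shown here, so check train.py):

```
$ python3 train.py --model=vgg --datapath={datapath} --batchsize=10 --lr=0.0001 --modelpath={path-to-save}
```
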
39 | | Heatmap Loss | PAFmap(Part Affinity Field) Loss |
40 | |-------------------------------------------|------------------------------------------|
41 | |  |  |
42 |
43 | As you can see from the table above, the training loss converged with almost the same trend as in the original paper.
44 |
45 | The mobilenet versions have slightly worse loss values than the original one: training losses are 3~8% larger, while validation losses are 5~14% larger.
46 |
47 |
48 | ### Run for Faster Training
49 |
50 | If you have enough computing resources in multiple nodes, you can launch multiple workers on nodes to help data preparation.
51 |
52 | ```
53 | worker-node1$ python3 pose_dataworker.py --master=tcp://host:port
54 | worker-node2$ python3 pose_dataworker.py --master=tcp://host:port
55 | worker-node3$ python3 pose_dataworker.py --master=tcp://host:port
56 | ...
57 | ```
58 |
59 | After the above preparation, you can launch the training script with the 'remote-data' argument.
60 |
61 | ```
62 | $ python3 train.py --remote-data=tcp://0.0.0.0:port
63 |
64 | 2017-09-27 15:58:50,307 INFO Restore pretrained weights...
65 | ```
66 |
67 | You can also train faster with multiple GPUs. This automatically splits each batch across the GPUs for forward/backward computation.
68 |
69 | ```
70 | $ python3 train.py --remote-data=tcp://0.0.0.0:port --gpus=8
71 |
72 | 2017-09-27 15:58:50,307 INFO Restore pretrained weights...
73 | ```
74 |
75 | I trained the models within a day using 8 GPUs and multiple pre-processing nodes with 48-core CPUs.
76 |
77 | ### Model Optimization for Inference
78 |
79 | After training a model, I optimized it by folding batch normalization into the convolutional layers and removing redundant operations.
80 |
81 | Firstly, the model should be frozen.
82 |
83 | ```bash
84 | $ python3 -m tensorflow.python.tools.freeze_graph \
85 | --input_graph=... \
86 | --output_graph=... \
87 | --input_checkpoint=... \
88 | --output_node_names="Openpose/concat_stage7"
89 | ```
90 |
91 | The optimization can then be performed on the frozen model via the graph transform tool provided by Tensorflow.
92 |
93 | ```bash
94 | $ bazel build tensorflow/tools/graph_transforms:transform_graph
95 | $ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
96 | --in_graph=... \
97 | --out_graph=... \
98 | --inputs='image:0' \
99 | --outputs='Openpose/concat_stage7:0' \
100 | --transforms='
101 | strip_unused_nodes(type=float, shape="1,368,368,3")
102 | remove_nodes(op=Identity, op=CheckNumerics)
103 | fold_constants(ignore_errors=False)
104 | fold_old_batch_norms
105 | fold_batch_norms'
106 | ```
107 |
108 | It is also promising to quantize the neural network to 8 bits for a further speed improvement. In my case, however, this made inference less accurate and slower on Intel CPUs.
109 |
110 | ```
111 | $ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
112 | --in_graph=/Users/ildoonet/repos/tf-openpose/tmp/cmu_640x480/graph_opt.pb \
113 | --out_graph=/Users/ildoonet/repos/tf-openpose/tmp/cmu_640x480/graph_q.pb \
114 | --inputs='image' \
115 | --outputs='Openpose/concat_stage7:0' \
116 | --transforms='add_default_attributes strip_unused_nodes(type=float, shape="1,360,640,3")
117 | remove_nodes(op=Identity, op=CheckNumerics) fold_constants(ignore_errors=true)
118 | fold_batch_norms fold_old_batch_norms quantize_weights quantize_nodes
119 | strip_unused_nodes sort_by_execution_order'
120 | ```
--------------------------------------------------------------------------------
/images/apink1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/apink1.jpg
--------------------------------------------------------------------------------
/images/apink1_crop.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/apink1_crop.jpg
--------------------------------------------------------------------------------
/images/apink1_crop_s1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/apink1_crop_s1.jpg
--------------------------------------------------------------------------------
/images/apink2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/apink2.jpg
--------------------------------------------------------------------------------
/images/apink3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/apink3.jpg
--------------------------------------------------------------------------------
/images/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/cat.jpg
--------------------------------------------------------------------------------
/images/golf.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/golf.jpg
--------------------------------------------------------------------------------
/images/hand1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/hand1.jpg
--------------------------------------------------------------------------------
/images/hand1_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/hand1_small.jpg
--------------------------------------------------------------------------------
/images/hand2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/hand2.jpg
--------------------------------------------------------------------------------
/images/handsup1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/handsup1.jpg
--------------------------------------------------------------------------------
/images/p1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/p1.jpg
--------------------------------------------------------------------------------
/images/p2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/p2.jpg
--------------------------------------------------------------------------------
/images/p3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/p3.jpg
--------------------------------------------------------------------------------
/images/p3_dance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/p3_dance.png
--------------------------------------------------------------------------------
/images/ski.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/ski.jpg
--------------------------------------------------------------------------------
/images/valid_person1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/images/valid_person1.jpg
--------------------------------------------------------------------------------
/launch/demo_video.launch:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/models/graph/cmu_640x360/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "download model graph : cmu_640x360"
4 |
5 | wget http://download1662.mediafire.com/9w41h1qg57og/fdhqpmdrdzafoc4/graph_freeze.pb -O graph_freeze.pb
6 | wget http://download1650.mediafire.com/35hcd7ukp3fg/n6qnqz00g1pjf7d/graph_opt.pb -O graph_opt.pb
7 | wget http://download1193.mediafire.com/eaoeszlwevfg/38hyjrwfdsyqsbq/graph_q.pb -O graph_q.pb
8 | wget http://download1477.mediafire.com/5mujvsj810xg/a2a0nc8i1oj5iam/graph.pb -O graph.pb
9 |
--------------------------------------------------------------------------------
/models/graph/cmu_640x480/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "download model graph : cmu_640x480"
4 |
5 | wget http://download1515.mediafire.com/66p3hu7p6lfg/q7fbue8dqr3ytrb/graph_freeze.pb -O graph_freeze.pb
6 | wget http://download1640.mediafire.com/vqciqfcbz7qg/eolfk6t1t3yb191/graph_opt.pb -O graph_opt.pb
7 | wget http://download843.mediafire.com/zczmlmayrrng/s6d01qvmlkfxgzr/graph_q.pb -O graph_q.pb
8 | wget http://download938.mediafire.com/3mootio0u5ag/ae7hud583cx259z/graph.pb -O graph.pb
9 |
--------------------------------------------------------------------------------
/models/graph/mobilenet_thin_432x368/graph_freeze.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/models/graph/mobilenet_thin_432x368/graph_freeze.pb
--------------------------------------------------------------------------------
/models/graph/mobilenet_thin_432x368/graph_opt.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/models/graph/mobilenet_thin_432x368/graph_opt.pb
--------------------------------------------------------------------------------
/models/numpy/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget http://www.mediafire.com/file/ropayv77vklvf56/openpose_coco.npy
4 | wget http://www.mediafire.com/file/7e73ddj31rzw6qq/openpose_vgg16.npy
5 |
--------------------------------------------------------------------------------
/models/pretrained/mobilenet_v1_0.50_224_2017_06_14/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget http://www.mediafire.com/file/meu73iq8rxlsd3g/mobilenet_v1_0.50_224.ckpt.data-00000-of-00001
4 | wget http://www.mediafire.com/file/7u6iupfkcaxk5hx/mobilenet_v1_0.50_224.ckpt.index
5 | wget http://www.mediafire.com/file/zp8y4d0ytzharzz/mobilenet_v1_0.50_224.ckpt.meta
6 |
--------------------------------------------------------------------------------
/models/pretrained/mobilenet_v1_0.75_224_2017_06_14/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget http://www.mediafire.com/file/kibz0x9e7h11ueb/mobilenet_v1_0.75_224.ckpt.data-00000-of-00001
4 | wget http://www.mediafire.com/file/t8909eaikvc6ea2/mobilenet_v1_0.75_224.ckpt.index
5 | wget http://www.mediafire.com/file/6jjnbn1aged614x/mobilenet_v1_0.75_224.ckpt.meta
6 |
--------------------------------------------------------------------------------
/models/pretrained/mobilenet_v1_1.0_224_2017_06_14/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget http://www.mediafire.com/file/oh6njnz9lgoqwdj/mobilenet_v1_1.0_224.ckpt.data-00000-of-00001
4 | wget http://www.mediafire.com/file/61qln0tbac4ny9o/mobilenet_v1_1.0_224.ckpt.meta
5 | wget http://www.mediafire.com/file/2111rh6tb5fl1lr/mobilenet_v1_1.0_224.ckpt.index
6 |
--------------------------------------------------------------------------------
/msg/BodyPartElm.msg:
--------------------------------------------------------------------------------
1 | uint32 part_id
2 | float32 x
3 | float32 y
4 | float32 confidence
--------------------------------------------------------------------------------
/msg/Person.msg:
--------------------------------------------------------------------------------
1 | BodyPartElm[] body_part
--------------------------------------------------------------------------------
/msg/Persons.msg:
--------------------------------------------------------------------------------
1 | Person[] persons
2 | uint32 image_w
3 | uint32 image_h
4 | Header header
--------------------------------------------------------------------------------
/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <package>
3 |   <name>tfpose_ros</name>
4 |   <version>0.1.0</version>
5 |   <description>ROS Package for Human Pose Estimation Implemented in Tensorflow</description>
6 |
7 |   <maintainer>Curtis Kim</maintainer>
8 |
9 |   <license>GNU2.0</license>
10 |
11 |   <buildtool_depend>catkin</buildtool_depend>
12 |
13 |   <build_depend>rospy</build_depend>
14 |   <build_depend>message_generation</build_depend>
15 |   <build_depend>std_msgs</build_depend>
16 |   <build_depend>cv_bridge</build_depend>
17 |
18 |   <run_depend>message_generation</run_depend>
19 |   <run_depend>message_runtime</run_depend>
20 |   <run_depend>std_msgs</run_depend>
21 |   <run_depend>rospy</run_depend>
22 | </package>
23 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | argparse
2 | matplotlib
3 | scipy
4 | tqdm
5 | requests
6 | fire
7 | git+https://github.com/ppwwyyxx/tensorpack.git
--------------------------------------------------------------------------------
/scripts/broadcaster_ros.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import time
3 | import os
4 | import sys
5 | import ast
6 |
7 | from threading import Lock
8 | import rospy
9 | import rospkg
10 | from cv_bridge import CvBridge, CvBridgeError
11 | from sensor_msgs.msg import Image
12 | from tfpose_ros.msg import Persons, Person, BodyPartElm
13 |
14 | from estimator import TfPoseEstimator
15 | from networks import model_wh, get_graph_path
16 |
17 |
18 | def humans_to_msg(humans):
19 | persons = Persons()
20 |
21 | for human in humans:
22 | person = Person()
23 |
24 | for k in human.body_parts:
25 | body_part = human.body_parts[k]
26 |
27 | body_part_msg = BodyPartElm()
28 | body_part_msg.part_id = body_part.part_idx
29 | body_part_msg.x = body_part.x
30 | body_part_msg.y = body_part.y
31 | body_part_msg.confidence = body_part.score
32 | person.body_part.append(body_part_msg)
33 | persons.persons.append(person)
34 |
35 | return persons
36 |
37 |
38 | def callback_image(data):
39 | # et = time.time()
40 | try:
41 | cv_image = cv_bridge.imgmsg_to_cv2(data, "bgr8")
42 | except CvBridgeError as e:
43 | rospy.logerr('[ros-video-recorder][VideoFrames] Converting Image Error. ' + str(e))
44 | return
45 |
46 | acquired = tf_lock.acquire(False)  # non-blocking: skip this frame if a previous frame is still being processed
47 | if not acquired:
48 | return
49 |
50 | try:
51 | humans = pose_estimator.inference(cv_image, scales)
52 | finally:
53 | tf_lock.release()
54 |
55 | msg = humans_to_msg(humans)
56 | msg.image_w = data.width
57 | msg.image_h = data.height
58 | msg.header = data.header
59 |
60 | pub_pose.publish(msg)
61 | # rospy.loginfo(time.time() - et)
62 |
63 |
64 | if __name__ == '__main__':
65 | rospy.loginfo('initialization+')
66 | rospy.init_node('TfPoseEstimatorROS', anonymous=True)
67 |
68 | # parameters
69 | image_topic = rospy.get_param('~camera', '')
70 | model = rospy.get_param('~model', 'cmu_640x480')
71 | scales = rospy.get_param('~scales', '[None]')
72 | scales = ast.literal_eval(scales)
73 | tf_lock = Lock()
74 |
75 | rospy.loginfo('[TfPoseEstimatorROS] scales(%d)=%s' % (len(scales), str(scales)))
76 |
77 | if not image_topic:
78 | rospy.logerr('Parameter \'camera\' is not provided.')
79 | sys.exit(-1)
80 |
81 | try:
82 | w, h = model_wh(model)
83 | graph_path = get_graph_path(model)
84 |
85 | rospack = rospkg.RosPack()
86 | graph_path = os.path.join(rospack.get_path('tfpose_ros'), graph_path)
87 | except Exception as e:
88 | rospy.logerr('invalid model: %s, e=%s' % (model, e))
89 | sys.exit(-1)
90 |
91 | pose_estimator = TfPoseEstimator(graph_path, target_size=(w, h))
92 | cv_bridge = CvBridge()
93 |
94 | rospy.Subscriber(image_topic, Image, callback_image, queue_size=1, buff_size=2**24)
95 | pub_pose = rospy.Publisher('~pose', Persons, queue_size=1)
96 |
97 | rospy.loginfo('start+')
98 | rospy.spin()
99 | rospy.loginfo('finished')
100 |
--------------------------------------------------------------------------------
/scripts/visualization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import time
3 | import cv2
4 | import rospy
5 | from sensor_msgs.msg import Image
6 | from cv_bridge import CvBridge, CvBridgeError
7 |
8 | from tfpose_ros.msg import Persons, Person, BodyPartElm
9 | from estimator import Human, BodyPart, TfPoseEstimator
10 |
11 |
12 | class VideoFrames:
13 | """
14 | Reference : ros-video-recorder
15 | https://github.com/ildoonet/ros-video-recorder/blob/master/scripts/recorder.py
16 | """
17 | def __init__(self, image_topic):
18 | self.image_sub = rospy.Subscriber(image_topic, Image, self.callback_image, queue_size=1)
19 | self.bridge = CvBridge()
20 | self.frames = []
21 |
22 | def callback_image(self, data):
23 | try:
24 | cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
25 | except CvBridgeError as e:
26 | rospy.logerr('Converting Image Error. ' + str(e))
27 | return
28 |
29 | self.frames.append((data.header.stamp, cv_image))
30 |
31 | def get_latest(self, at_time, remove_older=True):
32 | fs = [x for x in self.frames if x[0] <= at_time]
33 | if len(fs) == 0:
34 | return None
35 |
36 | f = fs[-1]
37 | if remove_older:
38 | self.frames = self.frames[len(fs) - 1:]
39 |
40 | return f[1]
41 |
42 |
43 | def cb_pose(data):
44 | # get image with pose time
45 | t = data.header.stamp
46 | image = vf.get_latest(t, remove_older=True)
47 | if image is None:
48 | rospy.logwarn('No received images.')
49 | return
50 |
51 | h, w = image.shape[:2]
52 | if resize_ratio > 0:
53 | image = cv2.resize(image, (int(resize_ratio*w), int(resize_ratio*h)), interpolation=cv2.INTER_LINEAR)
54 |
55 | # ros topic to Person instance
56 | humans = []
57 | for p_idx, person in enumerate(data.persons):
58 | human = Human([])
59 | for body_part in person.body_part:
60 | part = BodyPart('', body_part.part_id, body_part.x, body_part.y, body_part.confidence)
61 | human.body_parts[body_part.part_id] = part
62 |
63 | humans.append(human)
64 |
65 | # draw
66 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False)
67 | pub_img.publish(cv_bridge.cv2_to_imgmsg(image, 'bgr8'))
68 |
69 |
70 | if __name__ == '__main__':
71 | rospy.loginfo('initialization+')
72 | rospy.init_node('TfPoseEstimatorROS-Visualization', anonymous=True)
73 |
74 | # topics params
75 | image_topic = rospy.get_param('~camera', '')
76 | pose_topic = rospy.get_param('~pose', '/pose_estimator/pose')
77 |
78 | resize_ratio = float(rospy.get_param('~resize_ratio', '-1'))
79 |
80 | # publishers
81 | pub_img = rospy.Publisher('~output', Image, queue_size=1)
82 |
83 | # initialization
84 | cv_bridge = CvBridge()
85 | vf = VideoFrames(image_topic)
86 | rospy.wait_for_message(image_topic, Image, timeout=30)
87 |
88 | # subscribers
89 | rospy.Subscriber(pose_topic, Persons, cb_pose, queue_size=1)
90 |
91 | # run
92 | rospy.spin()
93 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/src/__init__.py
--------------------------------------------------------------------------------
/src/common.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | import tensorflow as tf
4 | import cv2
5 |
6 |
7 | regularizer_conv = 0.004
8 | regularizer_dsconv = 0.0004
9 | batchnorm_fused = True
10 | activation_fn = tf.nn.relu
11 |
12 |
13 | class CocoPart(Enum):
14 | Nose = 0
15 | Neck = 1
16 | RShoulder = 2
17 | RElbow = 3
18 | RWrist = 4
19 | LShoulder = 5
20 | LElbow = 6
21 | LWrist = 7
22 | RHip = 8
23 | RKnee = 9
24 | RAnkle = 10
25 | LHip = 11
26 | LKnee = 12
27 | LAnkle = 13
28 | REye = 14
29 | LEye = 15
30 | REar = 16
31 | LEar = 17
32 | Background = 18
33 |
34 |
35 | class MPIIPart(Enum):
36 | RAnkle = 0
37 | RKnee = 1
38 | RHip = 2
39 | LHip = 3
40 | LKnee = 4
41 | LAnkle = 5
42 | RWrist = 6
43 | RElbow = 7
44 | RShoulder = 8
45 | LShoulder = 9
46 | LElbow = 10
47 | LWrist = 11
48 | Neck = 12
49 | Head = 13
50 |
51 | @staticmethod
52 | def from_coco(human):
53 | # t = {
54 | # MPIIPart.RAnkle: CocoPart.RAnkle,
55 | # MPIIPart.RKnee: CocoPart.RKnee,
56 | # MPIIPart.RHip: CocoPart.RHip,
57 | # MPIIPart.LHip: CocoPart.LHip,
58 | # MPIIPart.LKnee: CocoPart.LKnee,
59 | # MPIIPart.LAnkle: CocoPart.LAnkle,
60 | # MPIIPart.RWrist: CocoPart.RWrist,
61 | # MPIIPart.RElbow: CocoPart.RElbow,
62 | # MPIIPart.RShoulder: CocoPart.RShoulder,
63 | # MPIIPart.LShoulder: CocoPart.LShoulder,
64 | # MPIIPart.LElbow: CocoPart.LElbow,
65 | # MPIIPart.LWrist: CocoPart.LWrist,
66 | # MPIIPart.Neck: CocoPart.Neck,
67 | # MPIIPart.Nose: CocoPart.Nose,
68 | # }
69 |
70 | t = [
71 | (MPIIPart.Head, CocoPart.Nose),
72 | (MPIIPart.Neck, CocoPart.Neck),
73 | (MPIIPart.RShoulder, CocoPart.RShoulder),
74 | (MPIIPart.RElbow, CocoPart.RElbow),
75 | (MPIIPart.RWrist, CocoPart.RWrist),
76 | (MPIIPart.LShoulder, CocoPart.LShoulder),
77 | (MPIIPart.LElbow, CocoPart.LElbow),
78 | (MPIIPart.LWrist, CocoPart.LWrist),
79 | (MPIIPart.RHip, CocoPart.RHip),
80 | (MPIIPart.RKnee, CocoPart.RKnee),
81 | (MPIIPart.RAnkle, CocoPart.RAnkle),
82 | (MPIIPart.LHip, CocoPart.LHip),
83 | (MPIIPart.LKnee, CocoPart.LKnee),
84 | (MPIIPart.LAnkle, CocoPart.LAnkle),
85 | ]
86 |
87 | pose_2d_mpii = []
88 | visibilty = []
89 | for mpi, coco in t:
90 | if coco.value not in human.body_parts.keys():
91 | pose_2d_mpii.append((0, 0))
92 | visibilty.append(False)
93 | continue
94 | pose_2d_mpii.append((human.body_parts[coco.value].x, human.body_parts[coco.value].y))
95 | visibilty.append(True)
96 | return pose_2d_mpii, visibilty
97 |
98 | CocoPairs = [
99 | (1, 2), (1, 5), (2, 3), (3, 4), (5, 6), (6, 7), (1, 8), (8, 9), (9, 10), (1, 11),
100 | (11, 12), (12, 13), (1, 0), (0, 14), (14, 16), (0, 15), (15, 17), (2, 16), (5, 17)
101 | ] # = 19
102 | CocoPairsRender = CocoPairs[:-2]
103 | CocoPairsNetwork = [
104 | (12, 13), (20, 21), (14, 15), (16, 17), (22, 23), (24, 25), (0, 1), (2, 3), (4, 5),
105 | (6, 7), (8, 9), (10, 11), (28, 29), (30, 31), (34, 35), (32, 33), (36, 37), (18, 19), (26, 27)
106 | ] # = 19
107 |
108 | CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
109 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
110 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
111 |
112 |
113 | def read_imgfile(path, width, height):
114 | val_image = cv2.imread(path, cv2.IMREAD_COLOR)
115 | if width is not None and height is not None:
116 | val_image = cv2.resize(val_image, (width, height))
117 | return val_image
118 |
119 |
120 | def get_sample_images(w, h):
121 | val_image = [
122 | read_imgfile('./images/p1.jpg', w, h),
123 | read_imgfile('./images/p2.jpg', w, h),
124 | read_imgfile('./images/p3.jpg', w, h),
125 | read_imgfile('./images/golf.jpg', w, h),
126 | read_imgfile('./images/hand1.jpg', w, h),
127 | read_imgfile('./images/hand2.jpg', w, h),
128 | read_imgfile('./images/apink1_crop.jpg', w, h),
129 | read_imgfile('./images/ski.jpg', w, h),
130 | read_imgfile('./images/apink2.jpg', w, h),
131 | read_imgfile('./images/apink3.jpg', w, h),
132 | read_imgfile('./images/handsup1.jpg', w, h),
133 | read_imgfile('./images/p3_dance.png', w, h),
134 | ]
135 | return val_image
136 |
--------------------------------------------------------------------------------
/src/datum_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: datum.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='datum.proto',
20 | package='',
21 | serialized_pb=_b('\n\x0b\x64\x61tum.proto\"\x81\x01\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02\x12\x16\n\x07\x65ncoded\x18\x07 \x01(\x08:\x05\x66\x61lse')
22 | )
23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
24 |
25 |
26 |
27 |
28 | _DATUM = _descriptor.Descriptor(
29 | name='Datum',
30 | full_name='Datum',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='channels', full_name='Datum.channels', index=0,
37 | number=1, type=5, cpp_type=1, label=1,
38 | has_default_value=False, default_value=0,
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None),
42 | _descriptor.FieldDescriptor(
43 | name='height', full_name='Datum.height', index=1,
44 | number=2, type=5, cpp_type=1, label=1,
45 | has_default_value=False, default_value=0,
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None),
49 | _descriptor.FieldDescriptor(
50 | name='width', full_name='Datum.width', index=2,
51 | number=3, type=5, cpp_type=1, label=1,
52 | has_default_value=False, default_value=0,
53 | message_type=None, enum_type=None, containing_type=None,
54 | is_extension=False, extension_scope=None,
55 | options=None),
56 | _descriptor.FieldDescriptor(
57 | name='data', full_name='Datum.data', index=3,
58 | number=4, type=12, cpp_type=9, label=1,
59 | has_default_value=False, default_value=_b(""),
60 | message_type=None, enum_type=None, containing_type=None,
61 | is_extension=False, extension_scope=None,
62 | options=None),
63 | _descriptor.FieldDescriptor(
64 | name='label', full_name='Datum.label', index=4,
65 | number=5, type=5, cpp_type=1, label=1,
66 | has_default_value=False, default_value=0,
67 | message_type=None, enum_type=None, containing_type=None,
68 | is_extension=False, extension_scope=None,
69 | options=None),
70 | _descriptor.FieldDescriptor(
71 | name='float_data', full_name='Datum.float_data', index=5,
72 | number=6, type=2, cpp_type=6, label=3,
73 | has_default_value=False, default_value=[],
74 | message_type=None, enum_type=None, containing_type=None,
75 | is_extension=False, extension_scope=None,
76 | options=None),
77 | _descriptor.FieldDescriptor(
78 | name='encoded', full_name='Datum.encoded', index=6,
79 | number=7, type=8, cpp_type=7, label=1,
80 | has_default_value=True, default_value=False,
81 | message_type=None, enum_type=None, containing_type=None,
82 | is_extension=False, extension_scope=None,
83 | options=None),
84 | ],
85 | extensions=[
86 | ],
87 | nested_types=[],
88 | enum_types=[
89 | ],
90 | options=None,
91 | is_extendable=False,
92 | extension_ranges=[],
93 | oneofs=[
94 | ],
95 | serialized_start=16,
96 | serialized_end=145,
97 | )
98 |
99 | DESCRIPTOR.message_types_by_name['Datum'] = _DATUM
100 |
101 | Datum = _reflection.GeneratedProtocolMessageType('Datum', (_message.Message,), dict(
102 | DESCRIPTOR = _DATUM,
103 | __module__ = 'datum_pb2'
104 | # @@protoc_insertion_point(class_scope:Datum)
105 | ))
106 | _sym_db.RegisterMessage(Datum)
107 |
108 |
109 | # @@protoc_insertion_point(module_scope)
110 |
--------------------------------------------------------------------------------
/src/lifting/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/src/lifting/__init__.py
--------------------------------------------------------------------------------
/src/lifting/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mar 23 11:57 2017
4 |
5 | @author: Denis Tome'
6 | """
7 |
8 | __all__ = [
9 | 'VISIBLE_PART',
10 | 'MIN_NUM_JOINTS',
11 | 'CENTER_TR',
12 | 'SIGMA',
13 | 'STRIDE',
14 | 'SIGMA_CENTER',
15 | 'INPUT_SIZE',
16 | 'OUTPUT_SIZE',
17 | 'NUM_JOINTS',
18 | 'NUM_OUTPUT',
19 | 'H36M_NUM_JOINTS',
20 | 'JOINT_DRAW_SIZE',
21 | 'LIMB_DRAW_SIZE'
22 | ]
23 |
24 | # threshold
25 | VISIBLE_PART = 1e-3
26 | MIN_NUM_JOINTS = 5
27 | CENTER_TR = 0.4
28 |
29 | # net attributes
30 | SIGMA = 7
31 | STRIDE = 8
32 | SIGMA_CENTER = 21
33 | INPUT_SIZE = 368
34 | OUTPUT_SIZE = 46
35 | NUM_JOINTS = 14
36 | NUM_OUTPUT = NUM_JOINTS + 1
37 | H36M_NUM_JOINTS = 17
38 |
39 | # draw options
40 | JOINT_DRAW_SIZE = 3
41 | LIMB_DRAW_SIZE = 2
42 | NORMALISATION_COEFFICIENT = 1280*720
43 |
--------------------------------------------------------------------------------
/src/lifting/draw.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mar 23 15:04 2017
4 |
5 | @author: Denis Tome'
6 | """
7 | import cv2
8 | import numpy as np
9 | from .config import JOINT_DRAW_SIZE
10 | from .config import LIMB_DRAW_SIZE
11 | from .config import NORMALISATION_COEFFICIENT
12 | import matplotlib.pyplot as plt
13 | import math
14 |
15 | __all__ = [
16 | 'draw_limbs',
17 | 'plot_pose'
18 | ]
19 |
20 |
21 | def draw_limbs(image, pose_2d, visible):
22 | """Draw the 2D pose without the occluded/not visible joints."""
23 |
24 | _COLORS = [
25 | [0, 0, 255], [0, 170, 255], [0, 255, 170], [0, 255, 0],
26 | [170, 255, 0], [255, 170, 0], [255, 0, 0], [255, 0, 170],
27 | [170, 0, 255]
28 | ]
29 | _LIMBS = np.array([0, 1, 2, 3, 3, 4, 5, 6, 6, 7, 8, 9,
30 | 9, 10, 11, 12, 12, 13]).reshape((-1, 2))
31 |
32 | _NORMALISATION_FACTOR = int(math.floor(math.sqrt(image.shape[0] * image.shape[1] / NORMALISATION_COEFFICIENT)))
33 |
34 | for oid in range(pose_2d.shape[0]):
35 | for lid, (p0, p1) in enumerate(_LIMBS):
36 | if not (visible[oid][p0] and visible[oid][p1]):
37 | continue
38 | y0, x0 = pose_2d[oid][p0]
39 | y1, x1 = pose_2d[oid][p1]
40 | cv2.circle(image, (x0, y0), JOINT_DRAW_SIZE *_NORMALISATION_FACTOR , _COLORS[lid], -1)
41 | cv2.circle(image, (x1, y1), JOINT_DRAW_SIZE*_NORMALISATION_FACTOR , _COLORS[lid], -1)
42 | cv2.line(image, (x0, y0), (x1, y1),
43 | _COLORS[lid], LIMB_DRAW_SIZE*_NORMALISATION_FACTOR , 16)
44 |
45 |
46 | def plot_pose(pose):
47 | """Plot the 3D pose showing the joint connections."""
48 | import mpl_toolkits.mplot3d.axes3d as p3
49 |
50 | _CONNECTION = [
51 | [0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7], [7, 8],
52 | [8, 9], [9, 10], [8, 11], [11, 12], [12, 13], [8, 14], [14, 15],
53 | [15, 16]]
54 |
55 | def joint_color(j):
56 | """
57 | TODO: 'j' shadows name 'j' from outer scope
58 | """
59 |
60 | colors = [(0, 0, 0), (255, 0, 255), (0, 0, 255),
61 | (0, 255, 255), (255, 0, 0), (0, 255, 0)]
62 | _c = 0
63 | if j in range(1, 4):
64 | _c = 1
65 | if j in range(4, 7):
66 | _c = 2
67 | if j in range(9, 11):
68 | _c = 3
69 | if j in range(11, 14):
70 | _c = 4
71 | if j in range(14, 17):
72 | _c = 5
73 | return colors[_c]
74 |
75 | assert (pose.ndim == 2)
76 | assert (pose.shape[0] == 3)
77 | fig = plt.figure()
78 | ax = fig.gca(projection='3d')
79 | for c in _CONNECTION:
80 | col = '#%02x%02x%02x' % joint_color(c[0])
81 | ax.plot([pose[0, c[0]], pose[0, c[1]]],
82 | [pose[1, c[0]], pose[1, c[1]]],
83 | [pose[2, c[0]], pose[2, c[1]]], c=col)
84 | for j in range(pose.shape[1]):
85 | col = '#%02x%02x%02x' % joint_color(j)
86 | ax.scatter(pose[0, j], pose[1, j], pose[2, j],
87 | c=col, marker='o', edgecolor=col)
88 | smallest = pose.min()
89 | largest = pose.max()
90 | ax.set_xlim3d(smallest, largest)
91 | ax.set_ylim3d(smallest, largest)
92 | ax.set_zlim3d(smallest, largest)
93 |
94 | return fig
95 |
96 |
97 |
--------------------------------------------------------------------------------
/src/lifting/models/prob_model_params.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/src/lifting/models/prob_model_params.mat
--------------------------------------------------------------------------------
/src/lifting/prob_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Apr 21 13:53 2017
4 |
5 | @author: Denis Tome'
6 | """
7 | import numpy as np
8 | import scipy.io as sio
9 | from lifting.upright_fast import pick_e
10 |
11 | from lifting import config
12 |
13 | __all__ = ['Prob3dPose']
14 |
15 |
16 | class Prob3dPose:
17 |
18 | def __init__(self, prob_model_path):
19 | model_param = sio.loadmat(prob_model_path)
20 | self.mu = np.reshape(
21 | model_param['mu'], (model_param['mu'].shape[0], 3, -1))
22 |         self.e = np.reshape(model_param['e'], (model_param['e'].shape[0],
23 |                                                model_param['e'].shape[1], 3, -1))
24 | self.sigma = model_param['sigma']
25 | self.cam = np.array(
26 | [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
27 |
28 | @staticmethod
29 | def cost3d(model, gt):
30 | """3d error in mm"""
31 | out = np.sqrt(((gt - model) ** 2).sum(1)).mean(-1)
32 | return out
33 |
34 | @staticmethod
35 | def renorm_gt(gt):
36 | """Compel gt data to have mean joint length of one"""
37 | _POSE_TREE = np.asarray([
38 | [0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7], [7, 8],
39 | [8, 9], [9, 10], [8, 11], [11, 12], [12, 13], [8, 14], [14, 15],
40 | [15, 16]]).T
41 | scale = np.sqrt(((gt[:, :, _POSE_TREE[0]] -
42 | gt[:, :, _POSE_TREE[1]]) ** 2).sum(2).sum(1))
43 | return gt / scale[:, np.newaxis, np.newaxis]
44 |
45 | @staticmethod
46 | def build_model(a, e, s0):
47 | """Build 3D model"""
48 | assert (s0.shape[1] == 3)
49 | assert (e.shape[2] == 3)
50 | assert (a.shape[1] == e.shape[1])
51 | out = np.einsum('...i,...ijk', a, e)
52 | out += s0
53 | return out
54 |
55 | @staticmethod
56 | def build_and_rot_model(a, e, s0, r):
57 | """
58 | Build model and rotate according to the identified rotation matrix
59 | """
60 | from numpy.core.umath_tests import matrix_multiply
61 |
62 | r2 = Prob3dPose.upgrade_r(r.T).transpose((0, 2, 1))
63 | mod = Prob3dPose.build_model(a, e, s0)
64 | mod = matrix_multiply(r2, mod)
65 | return mod
66 |
67 | @staticmethod
68 | def upgrade_r(r):
69 | """
70 | Upgrades complex parameterisation of planar rotation to tensor
71 | containing per frame 3x3 rotation matrices
72 | """
73 | assert (r.ndim == 2)
74 | # Technically optional assert, but if this fails data is probably
75 | # transposed
76 | assert (r.shape[1] == 2)
77 | assert (np.all(np.isfinite(r)))
78 | norm = np.sqrt((r[:, :2] ** 2).sum(1))
79 | assert (np.all(norm > 0))
80 | r /= norm[:, np.newaxis]
81 | assert (np.all(np.isfinite(r)))
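        # with r[:, 0] = cos(t) and r[:, 1] = sin(t) after normalisation, build
        # the per-frame rotation about the third axis:
        #   [[cos t, -sin t, 0],
        #    [sin t,  cos t, 0],
        #    [0,      0,     1]]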
82 | newr = np.zeros((r.shape[0], 3, 3))
83 | newr[:, :2, 0] = r[:, :2]
84 | newr[:, 2, 2] = 1
85 | newr[:, 1::-1, 1] = r[:, :2]
86 | newr[:, 0, 1] *= -1
87 | return newr
88 |
89 | @staticmethod
90 | def centre(data_2d):
91 | """center data according to each of the coordiante components"""
92 | return (data_2d.T - data_2d.mean(1)).T
93 |
94 | @staticmethod
95 | def centre_all(data):
96 | """center all data"""
97 | if data.ndim == 2:
98 | return Prob3dPose.centre(data)
99 | return (data.transpose(2, 0, 1) - data.mean(2)).transpose(1, 2, 0)
100 |
101 | @staticmethod
102 | def normalise_data(d2, weights):
103 | """Normalise data according to height"""
104 |
105 | # the joints with weight set to 0 should not be considered in the
106 | # normalisation process
107 | d2 = d2.reshape(d2.shape[0], -1, 2).transpose(0, 2, 1)
108 |         idx_consider = weights[0, 0].astype(bool)
109 | if np.sum(weights[:, 0].sum(1) >= config.MIN_NUM_JOINTS) == 0:
110 | raise Exception('Not enough 2D joints identified to generate 3D pose')
111 | d2[:, :, idx_consider] = Prob3dPose.centre_all(d2[:, :, idx_consider])
112 |
113 | # Height normalisation (2 meters)
114 | m2 = d2[:, 1, idx_consider].min(1) / 2.0
115 | m2 -= d2[:, 1, idx_consider].max(1) / 2.0
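        # m2 = (y_min - y_max) / 2, i.e. minus half the subject's pixel height;
        # dividing by it below rescales the vertical extent to 2 units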
116 | crap = m2 == 0
117 | m2[crap] = 1.0
118 | d2[:, :, idx_consider] /= m2[:, np.newaxis, np.newaxis]
119 | return d2, m2
120 |
121 | @staticmethod
122 | def transform_joints(pose_2d, visible_joints):
123 | """
124 | Transform the set of joints according to what the probabilistic model
125 | expects as input.
126 |
127 | It returns the new set of joints of each of the people and the set of
128 | weights for the joints.
129 | """
130 |
131 | _H36M_ORDER = [8, 9, 10, 11, 12, 13, 1, 0, 5, 6, 7, 2, 3, 4]
132 | _W_POS = [1, 2, 3, 4, 5, 6, 8, 10, 11, 12, 13, 14, 15, 16]
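        # _H36M_ORDER reorders the 14 detected joints into the H3.6M convention;
        # _W_POS lists the H3.6M slots they fill -- the remaining joints keep zero weight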
133 |
134 | def swap_xy(poses):
135 | tmp = np.copy(poses[:, :, 0])
136 | poses[:, :, 0] = poses[:, :, 1]
137 | poses[:, :, 1] = tmp
138 | return poses
139 |
140 | assert (pose_2d.ndim == 3)
141 | new_pose = pose_2d.copy()
142 | # new_pose = swap_xy(new_pose) # not used
143 | new_pose = new_pose[:, _H36M_ORDER]
144 |
145 | # defining weights according to occlusions
146 | weights = np.zeros((pose_2d.shape[0], 2, config.H36M_NUM_JOINTS))
147 | ordered_visibility = np.repeat(
148 | visible_joints[:, _H36M_ORDER, np.newaxis], 2, 2
149 | ).transpose([0, 2, 1])
150 | weights[:, :, _W_POS] = ordered_visibility
151 | return new_pose, weights
152 |
153 | def affine_estimate(self, w, depth_reg=0.085, weights=None, scale=10.0,
154 | scale_mean=0.0016 * 1.8 * 1.2, scale_std=1.2 * 0,
155 | cap_scale=-0.00129):
156 | """
157 | Quick switch to allow reconstruction at unknown scale returns a,r
158 | and scale
159 | """
160 | weights = np.zeros((0, 0, 0)) if weights is None else weights
161 |
162 | s = np.empty((self.sigma.shape[0], self.sigma.shape[1] + 4)) # e,y,x,z
163 | s[:, :4] = 10 ** -5 # Tiny but makes stuff well-posed
164 | s[:, 0] = scale_std
165 | s[:, 4:] = self.sigma
166 | s[:, 4:-1] *= scale
167 |
168 |         e2 = np.zeros((self.e.shape[0], self.e.shape[1] + 4,
169 |                        3, self.e.shape[3]))
170 | e2[:, 1, 0] = 1.0
171 | e2[:, 2, 1] = 1.0
172 | e2[:, 3, 0] = 1.0
173 | # This makes the least_squares problem ill posed, as X,Z are
174 | # interchangable
175 | # Hence regularisation above to speed convergence and stop blow-up
176 | e2[:, 0] = self.mu
177 | e2[:, 4:] = self.e
178 | t_m = np.zeros_like(self.mu)
179 |
180 | res, a, r = pick_e(w, e2, t_m, self.cam, s, weights=weights,
181 | interval=0.01, depth_reg=depth_reg,
182 | scale_prior=scale_mean)
183 |
184 | scale = a[:, :, 0]
185 | reestimate = scale > cap_scale
186 | m = self.mu * cap_scale
187 | for i in range(scale.shape[0]):
188 | if reestimate[i].sum() > 0:
189 | ehat = e2[i:i + 1, 1:]
190 | mhat = m[i:i + 1]
191 | shat = s[i:i + 1, 1:]
192 | (res2, a2, r2) = pick_e(
193 | w[reestimate[i]], ehat, mhat, self.cam, shat,
194 | weights=weights[reestimate[i]],
195 | interval=0.01, depth_reg=depth_reg,
196 | scale_prior=scale_mean
197 | )
198 | res[i:i + 1, reestimate[i]] = res2
199 | a[i:i + 1, reestimate[i], 1:] = a2
200 | a[i:i + 1, reestimate[i], 0] = cap_scale
201 | r[i:i + 1, :, reestimate[i]] = r2
202 | scale = a[:, :, 0]
203 | a = a[:, :, 1:] / a[:, :, 0][:, :, np.newaxis]
204 | return res, e2[:, 1:], a, r, scale
205 |
206 | def better_rec(self, w, model, s=1, weights=1, damp_z=1):
207 | """Quick switch to allow reconstruction at unknown scale
208 | returns a,r and scale"""
209 | from numpy.core.umath_tests import matrix_multiply
210 | proj = matrix_multiply(self.cam[np.newaxis], model)
211 | proj[:, :2] = (proj[:, :2] * s + w * weights) / (s + weights)
212 | proj[:, 2] *= damp_z
213 | out = matrix_multiply(self.cam.T[np.newaxis], proj)
214 | return out
215 |
216 | def create_rec(self, w2, weights, res_weight=1):
217 | """Reconstruct 3D pose given a 2D pose"""
218 | _SIGMA_SCALING = 5.2
219 |
220 | res, e, a, r, scale = self.affine_estimate(
221 | w2, scale=_SIGMA_SCALING, weights=weights,
222 | depth_reg=0, cap_scale=-0.001, scale_mean=-0.003
223 | )
224 |
225 | remaining_dims = 3 * w2.shape[2] - e.shape[1]
226 | assert (remaining_dims >= 0)
227 | llambda = -np.log(self.sigma)
228 | lgdet = np.sum(llambda[:, :-1], 1) + llambda[:, -1] * remaining_dims
229 | score = (res * res_weight + lgdet[:, np.newaxis] * (scale ** 2))
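        # for each person, pick the mixture component with the lowest combined residual / complexity score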
230 | best = np.argmin(score, 0)
231 | index = np.arange(best.shape[0])
232 | a2 = a[best, index]
233 | r2 = r[best, :, index].T
234 | rec = Prob3dPose.build_and_rot_model(a2, e[best], self.mu[best], r2)
235 | rec *= -np.abs(scale[best, index])[:, np.newaxis, np.newaxis]
236 |
237 | rec = self.better_rec(w2, rec, 1, 1.55 * weights, 1) * -1
238 | rec = Prob3dPose.renorm_gt(rec)
239 | rec *= 0.97
240 | return rec
241 |
242 | def compute_3d(self, pose_2d, weights):
243 | """Reconstruct 3D poses given 2D estimations"""
244 |
245 | _J_POS = [1, 2, 3, 4, 5, 6, 8, 10, 11, 12, 13, 14, 15, 16]
246 | _SCALE_3D = 1174.88312988
247 |
248 | if pose_2d.shape[1] != config.H36M_NUM_JOINTS:
249 | # need to call the linear regressor
250 | reg_joints = np.zeros(
251 | (pose_2d.shape[0], config.H36M_NUM_JOINTS, 2))
252 |             for oid, single_pose in enumerate(pose_2d):
253 |                 reg_joints[oid, _J_POS] = single_pose
254 |
255 | norm_pose, _ = Prob3dPose.normalise_data(reg_joints, weights)
256 | else:
257 | norm_pose, _ = Prob3dPose.normalise_data(pose_2d, weights)
258 |
259 | pose_3d = self.create_rec(norm_pose, weights) * _SCALE_3D
260 | return pose_3d
261 |
--------------------------------------------------------------------------------
/src/network_cmu.py:
--------------------------------------------------------------------------------
1 | import network_base
2 | import tensorflow as tf
3 |
4 |
5 | class CmuNetwork(network_base.BaseNetwork):
6 | def setup(self):
7 | (self.feed('image')
8 | .normalize_vgg(name='preprocess')
9 | .conv(3, 3, 64, 1, 1, name='conv1_1')
10 | .conv(3, 3, 64, 1, 1, name='conv1_2')
11 | .max_pool(2, 2, 2, 2, name='pool1_stage1')
12 | .conv(3, 3, 128, 1, 1, name='conv2_1')
13 | .conv(3, 3, 128, 1, 1, name='conv2_2')
14 | .max_pool(2, 2, 2, 2, name='pool2_stage1')
15 | .conv(3, 3, 256, 1, 1, name='conv3_1')
16 | .conv(3, 3, 256, 1, 1, name='conv3_2')
17 | .conv(3, 3, 256, 1, 1, name='conv3_3')
18 | .conv(3, 3, 256, 1, 1, name='conv3_4')
19 | .max_pool(2, 2, 2, 2, name='pool3_stage1')
20 | .conv(3, 3, 512, 1, 1, name='conv4_1')
21 | .conv(3, 3, 512, 1, 1, name='conv4_2')
22 | .conv(3, 3, 256, 1, 1, name='conv4_3_CPM')
23 | .conv(3, 3, 128, 1, 1, name='conv4_4_CPM') # *****
24 |
25 | .conv(3, 3, 128, 1, 1, name='conv5_1_CPM_L1')
26 | .conv(3, 3, 128, 1, 1, name='conv5_2_CPM_L1')
27 | .conv(3, 3, 128, 1, 1, name='conv5_3_CPM_L1')
28 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L1')
29 | .conv(1, 1, 38, 1, 1, relu=False, name='conv5_5_CPM_L1'))
30 |
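        # L1 branches output 38 channels (part-affinity fields, 19 limbs x 2);
        # L2 branches output 19 channels (18 keypoint heatmaps + background)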
31 | (self.feed('conv4_4_CPM')
32 | .conv(3, 3, 128, 1, 1, name='conv5_1_CPM_L2')
33 | .conv(3, 3, 128, 1, 1, name='conv5_2_CPM_L2')
34 | .conv(3, 3, 128, 1, 1, name='conv5_3_CPM_L2')
35 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L2')
36 | .conv(1, 1, 19, 1, 1, relu=False, name='conv5_5_CPM_L2'))
37 |
38 | (self.feed('conv5_5_CPM_L1',
39 | 'conv5_5_CPM_L2',
40 | 'conv4_4_CPM')
41 | .concat(3, name='concat_stage2')
42 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage2_L1')
43 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage2_L1')
44 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage2_L1')
45 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage2_L1')
46 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage2_L1')
47 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L1')
48 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage2_L1'))
49 |
50 | (self.feed('concat_stage2')
51 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage2_L2')
52 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage2_L2')
53 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage2_L2')
54 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage2_L2')
55 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage2_L2')
56 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L2')
57 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage2_L2'))
58 |
59 | (self.feed('Mconv7_stage2_L1',
60 | 'Mconv7_stage2_L2',
61 | 'conv4_4_CPM')
62 | .concat(3, name='concat_stage3')
63 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage3_L1')
64 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage3_L1')
65 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage3_L1')
66 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage3_L1')
67 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage3_L1')
68 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L1')
69 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage3_L1'))
70 |
71 | (self.feed('concat_stage3')
72 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage3_L2')
73 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage3_L2')
74 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage3_L2')
75 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage3_L2')
76 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage3_L2')
77 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L2')
78 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage3_L2'))
79 |
80 | (self.feed('Mconv7_stage3_L1',
81 | 'Mconv7_stage3_L2',
82 | 'conv4_4_CPM')
83 | .concat(3, name='concat_stage4')
84 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage4_L1')
85 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage4_L1')
86 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage4_L1')
87 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage4_L1')
88 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage4_L1')
89 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L1')
90 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage4_L1'))
91 |
92 | (self.feed('concat_stage4')
93 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage4_L2')
94 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage4_L2')
95 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage4_L2')
96 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage4_L2')
97 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage4_L2')
98 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L2')
99 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage4_L2'))
100 |
101 | (self.feed('Mconv7_stage4_L1',
102 | 'Mconv7_stage4_L2',
103 | 'conv4_4_CPM')
104 | .concat(3, name='concat_stage5')
105 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage5_L1')
106 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage5_L1')
107 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage5_L1')
108 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage5_L1')
109 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage5_L1')
110 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L1')
111 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage5_L1'))
112 |
113 | (self.feed('concat_stage5')
114 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage5_L2')
115 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage5_L2')
116 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage5_L2')
117 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage5_L2')
118 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage5_L2')
119 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L2')
120 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage5_L2'))
121 |
122 | (self.feed('Mconv7_stage5_L1',
123 | 'Mconv7_stage5_L2',
124 | 'conv4_4_CPM')
125 | .concat(3, name='concat_stage6')
126 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage6_L1')
127 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage6_L1')
128 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage6_L1')
129 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage6_L1')
130 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage6_L1')
131 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L1')
132 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage6_L1'))
133 |
134 | (self.feed('concat_stage6')
135 | .conv(7, 7, 128, 1, 1, name='Mconv1_stage6_L2')
136 | .conv(7, 7, 128, 1, 1, name='Mconv2_stage6_L2')
137 | .conv(7, 7, 128, 1, 1, name='Mconv3_stage6_L2')
138 | .conv(7, 7, 128, 1, 1, name='Mconv4_stage6_L2')
139 | .conv(7, 7, 128, 1, 1, name='Mconv5_stage6_L2')
140 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L2')
141 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage6_L2'))
142 |
143 | with tf.variable_scope('Openpose'):
144 | (self.feed('Mconv7_stage6_L2',
145 | 'Mconv7_stage6_L1')
146 | .concat(3, name='concat_stage7'))
147 |
148 | def loss_l1_l2(self):
149 | l1s = []
150 | l2s = []
151 | for layer_name in self.layers.keys():
152 | if 'Mconv7' in layer_name and '_L1' in layer_name:
153 | l1s.append(self.layers[layer_name])
154 | if 'Mconv7' in layer_name and '_L2' in layer_name:
155 | l2s.append(self.layers[layer_name])
156 |
157 | return l1s, l2s
158 |
159 | def loss_last(self):
160 | return self.get_output('Mconv7_stage6_L1'), self.get_output('Mconv7_stage6_L2')
161 |
162 | def restorable_variables(self):
163 | return None
--------------------------------------------------------------------------------
/src/network_dsconv.py:
--------------------------------------------------------------------------------
1 | import network_base
2 |
3 |
4 | class DSConvNetwork(network_base.BaseNetwork):
5 | def __init__(self, inputs, trainable=True, conv_width=1.0):
6 | self.conv_width = conv_width
7 | network_base.BaseNetwork.__init__(self, inputs, trainable)
8 |
9 | def setup(self):
10 | (self.feed('image')
11 | .conv(3, 3, 64, 1, 1, name='conv1_1', trainable=False)
12 | # .conv(3, 3, 64, 1, 1, name='conv1_2', trainable=True) # TODO
13 | .separable_conv(3, 3, round(self.conv_width * 64), 2, name='conv1_2')
14 | # .max_pool(2, 2, 2, 2, name='pool1_stage1')
15 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv2_1')
16 | .separable_conv(3, 3, round(self.conv_width * 128), 2, name='conv2_2')
17 | # .max_pool(2, 2, 2, 2, name='pool2_stage1')
18 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_1')
19 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_2')
20 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv3_3')
21 | .separable_conv(3, 3, round(self.conv_width * 256), 2, name='conv3_4')
22 | # .max_pool(2, 2, 2, 2, name='pool3_stage1')
23 | .separable_conv(3, 3, round(self.conv_width * 512), 1, name='conv4_1')
24 | .separable_conv(3, 3, round(self.conv_width * 512), 1, name='conv4_2')
25 | .separable_conv(3, 3, round(self.conv_width * 256), 1, name='conv4_3_CPM')
26 | .separable_conv(3, 3, 128, 1, name='conv4_4_CPM')
27 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_1_CPM_L1')
28 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_2_CPM_L1')
29 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_3_CPM_L1')
30 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L1')
31 | .conv(1, 1, 38, 1, 1, relu=False, name='conv5_5_CPM_L1'))
32 |
33 | (self.feed('conv4_4_CPM')
34 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_1_CPM_L2')
35 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_2_CPM_L2')
36 | .separable_conv(3, 3, round(self.conv_width * 128), 1, name='conv5_3_CPM_L2')
37 | .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_L2')
38 | .conv(1, 1, 19, 1, 1, relu=False, name='conv5_5_CPM_L2'))
39 |
40 | (self.feed('conv5_5_CPM_L1',
41 | 'conv5_5_CPM_L2',
42 | 'conv4_4_CPM')
43 | .concat(3, name='concat_stage2')
44 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage2_L1')
45 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage2_L1')
46 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage2_L1')
47 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage2_L1')
48 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage2_L1')
49 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L1')
50 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage2_L1'))
51 |
52 | (self.feed('concat_stage2')
53 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage2_L2')
54 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage2_L2')
55 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage2_L2')
56 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage2_L2')
57 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage2_L2')
58 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage2_L2')
59 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage2_L2'))
60 |
61 | (self.feed('Mconv7_stage2_L1',
62 | 'Mconv7_stage2_L2',
63 | 'conv4_4_CPM')
64 | .concat(3, name='concat_stage3')
65 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage3_L1')
66 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage3_L1')
67 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage3_L1')
68 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage3_L1')
69 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage3_L1')
70 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L1')
71 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage3_L1'))
72 |
73 | (self.feed('concat_stage3')
74 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage3_L2')
75 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage3_L2')
76 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage3_L2')
77 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage3_L2')
78 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage3_L2')
79 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage3_L2')
80 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage3_L2'))
81 |
82 | (self.feed('Mconv7_stage3_L1',
83 | 'Mconv7_stage3_L2',
84 | 'conv4_4_CPM')
85 | .concat(3, name='concat_stage4')
86 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage4_L1')
87 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage4_L1')
88 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage4_L1')
89 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage4_L1')
90 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage4_L1')
91 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L1')
92 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage4_L1'))
93 |
94 | (self.feed('concat_stage4')
95 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage4_L2')
96 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage4_L2')
97 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage4_L2')
98 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage4_L2')
99 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage4_L2')
100 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage4_L2')
101 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage4_L2'))
102 |
103 | (self.feed('Mconv7_stage4_L1',
104 | 'Mconv7_stage4_L2',
105 | 'conv4_4_CPM')
106 | .concat(3, name='concat_stage5')
107 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage5_L1')
108 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage5_L1')
109 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage5_L1')
110 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage5_L1')
111 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage5_L1')
112 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L1')
113 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage5_L1'))
114 |
115 | (self.feed('concat_stage5')
116 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage5_L2')
117 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage5_L2')
118 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage5_L2')
119 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage5_L2')
120 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage5_L2')
121 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage5_L2')
122 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage5_L2'))
123 |
124 | (self.feed('Mconv7_stage5_L1',
125 | 'Mconv7_stage5_L2',
126 | 'conv4_4_CPM')
127 | .concat(3, name='concat_stage6')
128 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage6_L1')
129 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage6_L1')
130 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage6_L1')
131 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage6_L1')
132 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage6_L1')
133 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L1')
134 | .conv(1, 1, 38, 1, 1, relu=False, name='Mconv7_stage6_L1'))
135 |
136 | (self.feed('concat_stage6')
137 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv1_stage6_L2')
138 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv2_stage6_L2')
139 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv3_stage6_L2')
140 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv4_stage6_L2')
141 | .separable_conv(7, 7, round(self.conv_width * 128), 1, name='Mconv5_stage6_L2')
142 | .conv(1, 1, 128, 1, 1, name='Mconv6_stage6_L2')
143 | .conv(1, 1, 19, 1, 1, relu=False, name='Mconv7_stage6_L2'))
144 |
145 | (self.feed('Mconv7_stage6_L2',
146 | 'Mconv7_stage6_L1')
147 | .concat(3, name='concat_stage7'))
148 |
--------------------------------------------------------------------------------
/src/network_mobilenet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | import network_base
4 |
5 |
6 | class MobilenetNetwork(network_base.BaseNetwork):
7 | def __init__(self, inputs, trainable=True, conv_width=1.0, conv_width2=None):
8 | self.conv_width = conv_width
9 | self.conv_width2 = conv_width2 if conv_width2 else conv_width
10 | self.num_refine = 4
11 | network_base.BaseNetwork.__init__(self, inputs, trainable)
12 |
13 | def setup(self):
14 | min_depth = 8
15 | depth = lambda d: max(int(d * self.conv_width), min_depth)
16 | depth2 = lambda d: max(int(d * self.conv_width2), min_depth)
17 |
18 | with tf.variable_scope(None, 'MobilenetV1'):
19 | (self.feed('image')
20 | .convb(3, 3, depth(32), 2, name='Conv2d_0')
21 | .separable_conv(3, 3, depth(64), 1, name='Conv2d_1')
22 | .separable_conv(3, 3, depth(128), 2, name='Conv2d_2')
23 | .separable_conv(3, 3, depth(128), 1, name='Conv2d_3')
24 | .separable_conv(3, 3, depth(256), 2, name='Conv2d_4')
25 | .separable_conv(3, 3, depth(256), 1, name='Conv2d_5')
26 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_6')
27 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_7')
28 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_8')
29 | # .separable_conv(3, 3, depth(512), 1, name='Conv2d_9')
30 | # .separable_conv(3, 3, depth(512), 1, name='Conv2d_10')
31 | # .separable_conv(3, 3, depth(512), 1, name='Conv2d_11')
32 | # .separable_conv(3, 3, depth(1024), 2, name='Conv2d_12')
33 | # .separable_conv(3, 3, depth(1024), 1, name='Conv2d_13')
34 | )
35 |
36 | (self.feed('Conv2d_1').max_pool(2, 2, 2, 2, name='Conv2d_1_pool'))
37 | (self.feed('Conv2d_7').upsample(2, name='Conv2d_7_upsample'))
38 |
39 | (self.feed('Conv2d_1_pool', 'Conv2d_3', 'Conv2d_7_upsample')
40 | .concat(3, name='feat_concat'))
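        # fuse shallow and deep MobileNet features at a common spatial resolution
        # before feeding the OpenPose refinement stages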
41 |
42 | feature_lv = 'feat_concat'
43 | with tf.variable_scope(None, 'Openpose'):
44 | prefix = 'MConv_Stage1'
45 | (self.feed(feature_lv)
46 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_1')
47 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_2')
48 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_3')
49 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L1_4')
50 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5'))
51 |
52 | (self.feed(feature_lv)
53 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_1')
54 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_2')
55 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_3')
56 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L2_4')
57 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5'))
58 |
59 | for stage_id in range(self.num_refine):
60 | prefix_prev = 'MConv_Stage%d' % (stage_id + 1)
61 | prefix = 'MConv_Stage%d' % (stage_id + 2)
62 | (self.feed(prefix_prev + '_L1_5',
63 | prefix_prev + '_L2_5',
64 | feature_lv)
65 | .concat(3, name=prefix + '_concat')
66 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L1_1')
67 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L1_2')
68 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L1_3')
69 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L1_4')
70 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5'))
71 |
72 | (self.feed(prefix + '_concat')
73 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L2_1')
74 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L2_2')
75 | .separable_conv(7, 7, depth2(128), 1, name=prefix + '_L2_3')
76 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L2_4')
77 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5'))
78 |
79 | # final result
80 | (self.feed('MConv_Stage%d_L2_5' % self.get_refine_num(),
81 | 'MConv_Stage%d_L1_5' % self.get_refine_num())
82 | .concat(3, name='concat_stage7'))
83 |
84 | def loss_l1_l2(self):
85 | l1s = []
86 | l2s = []
87 | for layer_name in sorted(self.layers.keys()):
88 | if '_L1_5' in layer_name:
89 | l1s.append(self.layers[layer_name])
90 | if '_L2_5' in layer_name:
91 | l2s.append(self.layers[layer_name])
92 |
93 | return l1s, l2s
94 |
95 | def loss_last(self):
96 | return self.get_output('MConv_Stage%d_L1_5' % self.get_refine_num()), \
97 | self.get_output('MConv_Stage%d_L2_5' % self.get_refine_num())
98 |
99 | def restorable_variables(self):
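        # restore only the ImageNet-pretrained MobilenetV1 conv weights,
        # skipping optimizer slot variables (RMSProp / Momentum / Ada*)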
100 | vs = {v.op.name: v for v in tf.global_variables() if
101 | 'MobilenetV1/Conv2d' in v.op.name and
102 | 'RMSProp' not in v.op.name and 'Momentum' not in v.op.name and 'Ada' not in v.op.name
103 | }
104 | return vs
105 |
106 | def get_refine_num(self):
107 | return self.num_refine + 1
108 |
--------------------------------------------------------------------------------
/src/network_mobilenet_thin.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | import network_base
4 |
5 |
6 | class MobilenetNetworkThin(network_base.BaseNetwork):
7 | def __init__(self, inputs, trainable=True, conv_width=1.0, conv_width2=None):
8 | self.conv_width = conv_width
9 | self.conv_width2 = conv_width2 if conv_width2 else conv_width
10 | network_base.BaseNetwork.__init__(self, inputs, trainable)
11 |
12 | def setup(self):
13 | min_depth = 8
14 | depth = lambda d: max(int(d * self.conv_width), min_depth)
15 | depth2 = lambda d: max(int(d * self.conv_width2), min_depth)
16 |
17 | with tf.variable_scope(None, 'MobilenetV1'):
18 | (self.feed('image')
19 | .convb(3, 3, depth(32), 2, name='Conv2d_0')
20 | .separable_conv(3, 3, depth(64), 1, name='Conv2d_1')
21 | .separable_conv(3, 3, depth(128), 2, name='Conv2d_2')
22 | .separable_conv(3, 3, depth(128), 1, name='Conv2d_3')
23 | .separable_conv(3, 3, depth(256), 2, name='Conv2d_4')
24 | .separable_conv(3, 3, depth(256), 1, name='Conv2d_5')
25 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_6')
26 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_7')
27 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_8')
28 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_9')
29 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_10')
30 | .separable_conv(3, 3, depth(512), 1, name='Conv2d_11')
31 | # .separable_conv(3, 3, depth(1024), 2, name='Conv2d_12')
32 | # .separable_conv(3, 3, depth(1024), 1, name='Conv2d_13')
33 | )
34 |
35 | (self.feed('Conv2d_3').max_pool(2, 2, 2, 2, name='Conv2d_3_pool'))
36 |
37 | (self.feed('Conv2d_3_pool', 'Conv2d_7', 'Conv2d_11')
38 | .concat(3, name='feat_concat'))
39 |
40 | feature_lv = 'feat_concat'
41 | with tf.variable_scope(None, 'Openpose'):
42 | prefix = 'MConv_Stage1'
43 | (self.feed(feature_lv)
44 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_1')
45 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_2')
46 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_3')
47 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L1_4')
48 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5'))
49 |
50 | (self.feed(feature_lv)
51 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_1')
52 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_2')
53 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_3')
54 | .separable_conv(1, 1, depth2(512), 1, name=prefix + '_L2_4')
55 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5'))
56 |
57 | for stage_id in range(5):
58 | prefix_prev = 'MConv_Stage%d' % (stage_id + 1)
59 | prefix = 'MConv_Stage%d' % (stage_id + 2)
60 | (self.feed(prefix_prev + '_L1_5',
61 | prefix_prev + '_L2_5',
62 | feature_lv)
63 | .concat(3, name=prefix + '_concat')
64 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_1')
65 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_2')
66 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L1_3')
67 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L1_4')
68 | .separable_conv(1, 1, 38, 1, relu=False, name=prefix + '_L1_5'))
69 |
70 | (self.feed(prefix + '_concat')
71 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_1')
72 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_2')
73 | .separable_conv(3, 3, depth2(128), 1, name=prefix + '_L2_3')
74 | .separable_conv(1, 1, depth2(128), 1, name=prefix + '_L2_4')
75 | .separable_conv(1, 1, 19, 1, relu=False, name=prefix + '_L2_5'))
76 |
77 | # final result
78 | (self.feed('MConv_Stage6_L2_5',
79 | 'MConv_Stage6_L1_5')
80 | .concat(3, name='concat_stage7'))
81 |
82 | def loss_l1_l2(self):
83 | l1s = []
84 | l2s = []
85 | for layer_name in sorted(self.layers.keys()):
86 | if '_L1_5' in layer_name:
87 | l1s.append(self.layers[layer_name])
88 | if '_L2_5' in layer_name:
89 | l2s.append(self.layers[layer_name])
90 |
91 | return l1s, l2s
92 |
93 | def loss_last(self):
94 | return self.get_output('MConv_Stage6_L1_5'), self.get_output('MConv_Stage6_L2_5')
95 |
96 | def restorable_variables(self):
97 | vs = {v.op.name: v for v in tf.global_variables() if
98 | 'MobilenetV1/Conv2d' in v.op.name and
99 | # 'global_step' not in v.op.name and
100 | # 'beta1_power' not in v.op.name and 'beta2_power' not in v.op.name and
101 | 'RMSProp' not in v.op.name and 'Momentum' not in v.op.name and
102 | 'Ada' not in v.op.name and 'Adam' not in v.op.name
103 | }
104 | return vs
105 |
--------------------------------------------------------------------------------
/src/networks.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import tensorflow as tf
4 | from network_mobilenet import MobilenetNetwork
5 | from network_mobilenet_thin import MobilenetNetworkThin
6 |
7 | from network_cmu import CmuNetwork
8 |
9 |
10 | def _get_base_path():
11 | if not os.environ.get('OPENPOSE_MODEL', ''):
12 | return './models'
13 | return os.environ.get('OPENPOSE_MODEL')
14 |
15 |
16 | def get_network(type, placeholder_input, sess_for_load=None, trainable=True):
17 | if type == 'mobilenet':
18 | net = MobilenetNetwork({'image': placeholder_input}, conv_width=0.75, conv_width2=1.00, trainable=trainable)
19 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt'
20 | last_layer = 'MConv_Stage6_L{aux}_5'
21 | elif type == 'mobilenet_fast':
22 | net = MobilenetNetwork({'image': placeholder_input}, conv_width=0.5, conv_width2=0.5, trainable=trainable)
23 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt'
24 | last_layer = 'MConv_Stage6_L{aux}_5'
25 | elif type == 'mobilenet_accurate':
26 | net = MobilenetNetwork({'image': placeholder_input}, conv_width=1.00, conv_width2=1.00, trainable=trainable)
27 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt'
28 | last_layer = 'MConv_Stage6_L{aux}_5'
29 |
30 | elif type == 'mobilenet_thin':
31 | net = MobilenetNetworkThin({'image': placeholder_input}, conv_width=0.75, conv_width2=0.50, trainable=trainable)
32 | pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_1.0_224.ckpt'
33 | last_layer = 'MConv_Stage6_L{aux}_5'
34 |
35 | elif type == 'cmu':
36 | net = CmuNetwork({'image': placeholder_input}, trainable=trainable)
37 | pretrain_path = 'numpy/openpose_coco.npy'
38 | last_layer = 'Mconv7_stage6_L{aux}'
39 | elif type == 'vgg':
40 | net = CmuNetwork({'image': placeholder_input}, trainable=trainable)
41 | pretrain_path = 'numpy/openpose_vgg16.npy'
42 | last_layer = 'Mconv7_stage6_L{aux}'
43 | else:
44 | raise Exception('Invalid Mode.')
45 |
46 | if sess_for_load is not None:
47 | if type == 'cmu' or type == 'vgg':
48 | net.load(os.path.join(_get_base_path(), pretrain_path), sess_for_load)
49 | else:
50 | s = '%dx%d' % (placeholder_input.shape[2], placeholder_input.shape[1])
51 | ckpts = {
52 | 'mobilenet': 'trained/mobilenet_%s/model-246038' % s,
53 | 'mobilenet_thin': 'trained/mobilenet_thin_%s/model-449003' % s,
54 | 'mobilenet_fast': 'trained/mobilenet_fast_%s/model-189000' % s,
55 | 'mobilenet_accurate': 'trained/mobilenet_accurate/model-170000'
56 | }
57 | loader = tf.train.Saver()
58 | loader.restore(sess_for_load, os.path.join(_get_base_path(), ckpts[type]))
59 |
60 | return net, os.path.join(_get_base_path(), pretrain_path), last_layer
61 |
62 |
63 | def get_graph_path(model_name):
64 | return {
65 | 'cmu_640x480': './models/graph/cmu_640x480/graph_opt.pb',
66 | 'cmuq_640x480': './models/graph/cmu_640x480/graph_q.pb',
67 |
68 | 'cmu_640x360': './models/graph/cmu_640x360/graph_opt.pb',
69 | 'cmuq_640x360': './models/graph/cmu_640x360/graph_q.pb',
70 |
71 | 'mobilenet_thin_432x368': './models/graph/mobilenet_thin_432x368/graph_opt.pb',
72 | }[model_name]
73 |
74 |
75 | def model_wh(model_name):
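    # e.g. 'mobilenet_thin_432x368' -> (432, 368)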
76 | width, height = model_name.split('_')[-1].split('x')
77 | return int(width), int(height)
78 |
--------------------------------------------------------------------------------
/src/pose_augment.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | import cv2
5 | import numpy as np
6 | from tensorpack.dataflow.imgaug.geometry import RotationAndCropValid
7 |
8 | from common import CocoPart
9 |
10 | _network_w = 368
11 | _network_h = 368
12 | _scale = 2
13 |
14 |
15 | def set_network_input_wh(w, h):
16 | global _network_w, _network_h
17 | _network_w, _network_h = w, h
18 |
19 |
20 | def set_network_scale(scale):
21 | global _scale
22 | _scale = scale
23 |
24 |
25 | def pose_random_scale(meta):
26 | scalew = random.uniform(0.8, 1.2)
27 | scaleh = random.uniform(0.8, 1.2)
28 | neww = int(meta.width * scalew)
29 | newh = int(meta.height * scaleh)
30 | dst = cv2.resize(meta.img, (neww, newh), interpolation=cv2.INTER_AREA)
31 |
32 | # adjust meta data
33 | adjust_joint_list = []
34 | for joint in meta.joint_list:
35 | adjust_joint = []
36 | for point in joint:
37 | if point[0] < -100 or point[1] < -100:
38 | adjust_joint.append((-1000, -1000))
39 | continue
40 | # if point[0] <= 0 or point[1] <= 0 or int(point[0] * scalew + 0.5) > neww or int(
41 | # point[1] * scaleh + 0.5) > newh:
42 | # adjust_joint.append((-1, -1))
43 | # continue
44 | adjust_joint.append((int(point[0] * scalew + 0.5), int(point[1] * scaleh + 0.5)))
45 | adjust_joint_list.append(adjust_joint)
46 |
47 | meta.joint_list = adjust_joint_list
48 | meta.width, meta.height = neww, newh
49 | meta.img = dst
50 | return meta
51 |
52 |
53 | def pose_resize_shortestedge_fixed(meta):
54 | ratio_w = _network_w / meta.width
55 | ratio_h = _network_h / meta.height
56 | ratio = max(ratio_w, ratio_h)
57 | return pose_resize_shortestedge(meta, int(min(meta.width * ratio + 0.5, meta.height * ratio + 0.5)))
58 |
59 |
60 | def pose_resize_shortestedge_random(meta):
61 | ratio_w = _network_w / meta.width
62 | ratio_h = _network_h / meta.height
63 | ratio = min(ratio_w, ratio_h)
64 | target_size = int(min(meta.width * ratio + 0.5, meta.height * ratio + 0.5))
65 | target_size = int(target_size * random.uniform(0.95, 1.6))
66 | # target_size = int(min(_network_w, _network_h) * random.uniform(0.7, 1.5))
67 | return pose_resize_shortestedge(meta, target_size)
68 |
69 |
70 | def pose_resize_shortestedge(meta, target_size):
71 | global _network_w, _network_h
72 | img = meta.img
73 |
74 | # adjust image
75 | scale = target_size / min(meta.height, meta.width)
76 | if meta.height < meta.width:
77 | newh, neww = target_size, int(scale * meta.width + 0.5)
78 | else:
79 | newh, neww = int(scale * meta.height + 0.5), target_size
80 |
81 | dst = cv2.resize(img, (neww, newh), interpolation=cv2.INTER_AREA)
82 |
83 | pw = ph = 0
84 | if neww < _network_w or newh < _network_h:
85 | pw = max(0, (_network_w - neww) // 2)
86 | ph = max(0, (_network_h - newh) // 2)
87 | mw = (_network_w - neww) % 2
88 | mh = (_network_h - newh) % 2
89 | color = random.randint(0, 255)
90 | dst = cv2.copyMakeBorder(dst, ph, ph+mh, pw, pw+mw, cv2.BORDER_CONSTANT, value=(color, 0, 0))
91 |
92 | # adjust meta data
93 | adjust_joint_list = []
94 | for joint in meta.joint_list:
95 | adjust_joint = []
96 | for point in joint:
97 | if point[0] < -100 or point[1] < -100:
98 | adjust_joint.append((-1000, -1000))
99 | continue
100 | # if point[0] <= 0 or point[1] <= 0 or int(point[0]*scale+0.5) > neww or int(point[1]*scale+0.5) > newh:
101 | # adjust_joint.append((-1, -1))
102 | # continue
103 | adjust_joint.append((int(point[0]*scale+0.5) + pw, int(point[1]*scale+0.5) + ph))
104 | adjust_joint_list.append(adjust_joint)
105 |
106 | meta.joint_list = adjust_joint_list
107 | meta.width, meta.height = neww + pw * 2, newh + ph * 2
108 | meta.img = dst
109 | return meta
110 |
111 |
112 | def pose_crop_center(meta):
113 | global _network_w, _network_h
114 | target_size = (_network_w, _network_h)
115 | x = (meta.width - target_size[0]) // 2 if meta.width > target_size[0] else 0
116 | y = (meta.height - target_size[1]) // 2 if meta.height > target_size[1] else 0
117 |
118 | return pose_crop(meta, x, y, target_size[0], target_size[1])
119 |
120 |
121 | def pose_crop_random(meta):
122 | global _network_w, _network_h
123 | target_size = (_network_w, _network_h)
124 |
125 | for _ in range(50):
126 | x = random.randrange(0, meta.width - target_size[0]) if meta.width > target_size[0] else 0
127 | y = random.randrange(0, meta.height - target_size[1]) if meta.height > target_size[1] else 0
128 |
129 | # check whether any face is inside the box to generate a reasonably-balanced datasets
130 | for joint in meta.joint_list:
131 | if x <= joint[CocoPart.Nose.value][0] < x + target_size[0] and y <= joint[CocoPart.Nose.value][1] < y + target_size[1]:
132 | break
133 |
134 | return pose_crop(meta, x, y, target_size[0], target_size[1])
135 |
136 |
137 | def pose_crop(meta, x, y, w, h):
138 | # adjust image
139 | target_size = (w, h)
140 |
141 | img = meta.img
142 | resized = img[y:y+target_size[1], x:x+target_size[0], :]
143 |
144 | # adjust meta data
145 | adjust_joint_list = []
146 | for joint in meta.joint_list:
147 | adjust_joint = []
148 | for point in joint:
149 | if point[0] < -100 or point[1] < -100:
150 | adjust_joint.append((-1000, -1000))
151 | continue
152 | # if point[0] <= 0 or point[1] <= 0:
153 | # adjust_joint.append((-1000, -1000))
154 | # continue
155 | new_x, new_y = point[0] - x, point[1] - y
156 | # if new_x <= 0 or new_y <= 0 or new_x > target_size[0] or new_y > target_size[1]:
157 | # adjust_joint.append((-1, -1))
158 | # continue
159 | adjust_joint.append((new_x, new_y))
160 | adjust_joint_list.append(adjust_joint)
161 |
162 | meta.joint_list = adjust_joint_list
163 | meta.width, meta.height = target_size
164 | meta.img = resized
165 | return meta
166 |
167 |
168 | def pose_flip(meta):
169 | r = random.uniform(0, 1.0)
170 | if r > 0.5:
171 | return meta
172 |
173 | img = meta.img
174 | img = cv2.flip(img, 1)
175 |
176 | # flip meta
177 | flip_list = [CocoPart.Nose, CocoPart.Neck, CocoPart.LShoulder, CocoPart.LElbow, CocoPart.LWrist, CocoPart.RShoulder, CocoPart.RElbow, CocoPart.RWrist,
178 | CocoPart.LHip, CocoPart.LKnee, CocoPart.LAnkle, CocoPart.RHip, CocoPart.RKnee, CocoPart.RAnkle,
179 | CocoPart.LEye, CocoPart.REye, CocoPart.LEar, CocoPart.REar, CocoPart.Background]
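    # after mirroring the image, each joint slot is filled from its left/right
    # counterpart so the part labels stay consistent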
180 | adjust_joint_list = []
181 | for joint in meta.joint_list:
182 | adjust_joint = []
183 | for cocopart in flip_list:
184 | point = joint[cocopart.value]
185 | if point[0] < -100 or point[1] < -100:
186 | adjust_joint.append((-1000, -1000))
187 | continue
188 | # if point[0] <= 0 or point[1] <= 0:
189 | # adjust_joint.append((-1, -1))
190 | # continue
191 | adjust_joint.append((meta.width - point[0], point[1]))
192 | adjust_joint_list.append(adjust_joint)
193 |
194 | meta.joint_list = adjust_joint_list
195 |
196 | meta.img = img
197 | return meta
198 |
199 |
200 | def pose_rotation(meta):
201 | deg = random.uniform(-15.0, 15.0)
202 | img = meta.img
203 |
204 | center = (img.shape[1] * 0.5, img.shape[0] * 0.5) # x, y
205 | rot_m = cv2.getRotationMatrix2D((int(center[0]), int(center[1])), deg, 1)
206 | ret = cv2.warpAffine(img, rot_m, img.shape[1::-1], flags=cv2.INTER_AREA, borderMode=cv2.BORDER_CONSTANT)
207 | if img.ndim == 3 and ret.ndim == 2:
208 | ret = ret[:, :, np.newaxis]
209 | neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg)
210 | neww = min(neww, ret.shape[1])
211 | newh = min(newh, ret.shape[0])
212 | newx = int(center[0] - neww * 0.5)
213 | newy = int(center[1] - newh * 0.5)
214 | # print(ret.shape, deg, newx, newy, neww, newh)
215 | img = ret[newy:newy + newh, newx:newx + neww]
216 |
217 | # adjust meta data
218 | adjust_joint_list = []
219 | for joint in meta.joint_list:
220 | adjust_joint = []
221 | for point in joint:
222 | if point[0] < -100 or point[1] < -100:
223 | adjust_joint.append((-1000, -1000))
224 | continue
225 | # if point[0] <= 0 or point[1] <= 0:
226 | # adjust_joint.append((-1, -1))
227 | # continue
228 | x, y = _rotate_coord((meta.width, meta.height), (newx, newy), point, deg)
229 | adjust_joint.append((x, y))
230 | adjust_joint_list.append(adjust_joint)
231 |
232 | meta.joint_list = adjust_joint_list
233 | meta.width, meta.height = neww, newh
234 | meta.img = img
235 |
236 | return meta
237 |
238 |
239 | def _rotate_coord(shape, newxy, point, angle):
240 | angle = -1 * angle / 180.0 * math.pi
241 |
242 | ox, oy = shape
243 | px, py = point
244 |
245 | ox /= 2
246 | oy /= 2
247 |
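    # rotate (px, py) about the original image centre (ox, oy), then translate
    # into the cropped frame whose top-left corner is newxy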
248 | qx = math.cos(angle) * (px - ox) - math.sin(angle) * (py - oy)
249 | qy = math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy)
250 |
251 | new_x, new_y = newxy
252 |
253 | qx += ox - new_x
254 | qy += oy - new_y
255 |
256 | return int(qx + 0.5), int(qy + 0.5)
257 |
258 |
259 | def pose_to_img(meta_l):
260 | global _network_w, _network_h, _scale
261 | return [meta_l[0].img.astype(np.float16),
262 | meta_l[0].get_heatmap(target_size=(_network_w // _scale, _network_h // _scale)),
263 | meta_l[0].get_vectormap(target_size=(_network_w // _scale, _network_h // _scale))]
264 |
--------------------------------------------------------------------------------
/src/pose_datamaster.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import time
4 |
5 | from tensorpack.dataflow.remote import RemoteDataZMQ
6 |
7 | from pose_dataset import CocoPose
8 |
9 | logging.basicConfig(level=logging.DEBUG, format='[lmdb_dataset] %(asctime)s %(levelname)s %(message)s')
10 |
11 | if __name__ == '__main__':
12 | """
13 | Speed Test for Getting Input batches from other nodes
14 | """
15 | parser = argparse.ArgumentParser(description='Worker for preparing input batches.')
16 | parser.add_argument('--listen', type=str, default='tcp://0.0.0.0:1027')
17 | parser.add_argument('--show', type=bool, default=False)
18 | args = parser.parse_args()
19 |
20 | df = RemoteDataZMQ(args.listen)
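    # receives augmented training batches pushed over ZMQ by pose_dataworker.py instances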
21 |
22 | logging.info('tcp queue start')
23 | df.reset_state()
24 | t = time.time()
25 | for i, dp in enumerate(df.get_data()):
26 | if i == 100:
27 | break
28 | logging.info('Input batch %d received.' % i)
29 | if i == 0:
30 | for d in dp:
31 |                 logging.info('dp shape={}'.format(d.shape))
32 |
33 | if args.show:
34 | CocoPose.display_image(dp[0][0], dp[1][0], dp[2][0])
35 |
36 | logging.info('Speed Test Done for 100 Batches in %f seconds.' % (time.time() - t))
37 |
--------------------------------------------------------------------------------
/src/pose_dataworker.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from tensorpack.dataflow.remote import send_dataflow_zmq
4 |
5 | from pose_dataset import get_dataflow_batch
6 | from pose_augment import set_network_input_wh, set_network_scale
7 |
8 | if __name__ == '__main__':
9 | """
10 | OpenPose Data Preparation might be a bottleneck for training.
11 | You can run multiple workers to generate input batches in multi-nodes to make training process faster.
12 | """
13 | parser = argparse.ArgumentParser(description='Worker for preparing input batches.')
14 | parser.add_argument('--datapath', type=str, default='/coco/annotations/')
15 | parser.add_argument('--imgpath', type=str, default='/coco/')
16 | parser.add_argument('--batchsize', type=int, default=64)
17 | parser.add_argument('--train', type=bool, default=True)
18 | parser.add_argument('--master', type=str, default='tcp://csi-cluster-gpu20.dakao.io:1027')
19 | parser.add_argument('--input-width', type=int, default=368)
20 | parser.add_argument('--input-height', type=int, default=368)
21 | parser.add_argument('--scale-factor', type=int, default=2)
22 | args = parser.parse_args()
23 |
24 | set_network_input_wh(args.input_width, args.input_height)
25 | set_network_scale(args.scale_factor)
26 |
27 | df = get_dataflow_batch(args.datapath, args.train, args.batchsize, args.imgpath)
28 |
29 | send_dataflow_zmq(df, args.master, hwm=10)
30 |
--------------------------------------------------------------------------------
/src/pose_stats.py:
--------------------------------------------------------------------------------
1 | from pose_dataset import CocoPose
2 | from tensorpack import imgaug
3 | from tensorpack.dataflow.common import MapDataComponent, MapData
4 | from tensorpack.dataflow.image import AugmentImageComponent
5 |
6 | from pose_augment import *
7 |
8 |
9 | def get_idx_hands_up():
10 |     from pose_augment import set_network_input_wh
11 | set_network_input_wh(368, 368)
12 |
13 | show_sample = True
14 |     db = CocoPose('/data/public/rw/coco-pose-estimation-lmdb/', is_train=True, decode_img=show_sample)
15 | db.reset_state()
16 | total_cnt = 0
17 | handup_cnt = 0
18 | for idx, metas in enumerate(db.get_data()):
19 | meta = metas[0]
20 | if len(meta.joint_list) <= 0:
21 | continue
22 | body = meta.joint_list[0]
23 | if body[CocoPart.Neck.value][1] <= 0:
24 | continue
25 | if body[CocoPart.LWrist.value][1] <= 0:
26 | continue
27 | if body[CocoPart.RWrist.value][1] <= 0:
28 | continue
29 |
30 | if body[CocoPart.Neck.value][1] > body[CocoPart.LWrist.value][1] or body[CocoPart.Neck.value][1] > body[CocoPart.RWrist.value][1]:
31 | print(meta.idx)
32 | handup_cnt += 1
33 |
34 | if show_sample:
35 | l1, l2, l3 = pose_to_img(metas)
36 | CocoPose.display_image(l1, l2, l3)
37 |
38 | total_cnt += 1
39 |
40 | print('%d / %d' % (handup_cnt, total_cnt))
41 |
42 |
43 | def sample_augmentations():
44 | ds = CocoPose('/data/public/rw/coco-pose-estimation-lmdb/', is_train=False, only_idx=0)
45 | ds = MapDataComponent(ds, pose_random_scale)
46 | ds = MapDataComponent(ds, pose_rotation)
47 | ds = MapDataComponent(ds, pose_flip)
48 | ds = MapDataComponent(ds, pose_resize_shortestedge_random)
49 | ds = MapDataComponent(ds, pose_crop_random)
50 | ds = MapData(ds, pose_to_img)
51 | augs = [
52 | imgaug.RandomApplyAug(imgaug.RandomChooseAug([
53 | imgaug.GaussianBlur(3),
54 | imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01),
55 | imgaug.RandomOrderAug([
56 | imgaug.BrightnessScale((0.8, 1.2), clip=False),
57 | imgaug.Contrast((0.8, 1.2), clip=False),
58 | # imgaug.Saturation(0.4, rgb=True),
59 | ]),
60 | ]), 0.7),
61 | ]
62 | ds = AugmentImageComponent(ds, augs)
63 |
64 | ds.reset_state()
65 | for l1, l2, l3 in ds.get_data():
66 | CocoPose.display_image(l1, l2, l3)
67 |
68 |
69 | if __name__ == '__main__':
70 | # codes for tests
71 | # get_idx_hands_up()
72 |
73 | # show augmentation samples
74 | sample_augmentations()
75 |
--------------------------------------------------------------------------------
/src/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import time
4 | import ast
5 |
6 | import common
7 | import cv2
8 | import numpy as np
9 | from estimator import TfPoseEstimator
10 | from networks import get_graph_path, model_wh
11 |
12 | from lifting.prob_model import Prob3dPose
13 | from lifting.draw import plot_pose
14 |
15 | logger = logging.getLogger('TfPoseEstimator')
16 | logger.setLevel(logging.DEBUG)
17 | ch = logging.StreamHandler()
18 | ch.setLevel(logging.DEBUG)
19 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s')
20 | ch.setFormatter(formatter)
21 | logger.addHandler(ch)
22 |
23 |
24 | if __name__ == '__main__':
25 | parser = argparse.ArgumentParser(description='tf-pose-estimation run')
26 | # parser.add_argument('--image', type=str, default='/Users/ildoonet/Downloads/me.jpg')
27 | parser.add_argument('--image', type=str, default='./images/apink2.jpg')
28 | # parser.add_argument('--model', type=str, default='mobilenet_320x240', help='cmu / mobilenet_320x240')
29 | parser.add_argument('--model', type=str, default='mobilenet_thin_432x368', help='cmu_640x480 / cmu_640x360 / mobilenet_thin_432x368')
30 | parser.add_argument('--scales', type=str, default='[None]', help='for multiple scales, eg. [1.0, (1.1, 0.05)]')
31 | args = parser.parse_args()
32 |     scales = ast.literal_eval(args.scales)
33 |
34 | w, h = model_wh(args.model)
35 | e = TfPoseEstimator(get_graph_path(args.model), target_size=(w, h))
36 |
37 | # estimate human poses from a single image !
38 | image = common.read_imgfile(args.image, None, None)
39 | # image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
40 | t = time.time()
41 |     humans = e.inference(image, scales=scales)
42 | # humans = e.inference(image, scales=[None, (0.7, 0.5, 1.8)])
43 | # humans = e.inference(image, scales=[(1.2, 0.05)])
44 | # humans = e.inference(image, scales=[(0.2, 0.2, 1.4)])
45 | elapsed = time.time() - t
46 |
47 | logger.info('inference image: %s in %.4f seconds.' % (args.image, elapsed))
48 |
49 | image = cv2.imread(args.image, cv2.IMREAD_COLOR)
50 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False)
51 | cv2.imshow('tf-pose-estimation result', image)
52 | cv2.waitKey()
53 |
54 |     import sys  # early exit: the 3D-lifting demo below is skipped; remove this and the next line to run it
55 | sys.exit(0)
56 |
57 | logger.info('3d lifting initialization.')
58 | poseLifting = Prob3dPose('./src/lifting/models/prob_model_params.mat')
59 |
60 | image_h, image_w = image.shape[:2]
61 | standard_w = 640
62 | standard_h = 480
63 |
64 | pose_2d_mpiis = []
65 | visibilities = []
66 | for human in humans:
67 | pose_2d_mpii, visibility = common.MPIIPart.from_coco(human)
68 | pose_2d_mpiis.append([(int(x * standard_w + 0.5), int(y * standard_h + 0.5)) for x, y in pose_2d_mpii])
69 | visibilities.append(visibility)
70 |
71 | pose_2d_mpiis = np.array(pose_2d_mpiis)
72 | visibilities = np.array(visibilities)
73 | transformed_pose2d, weights = poseLifting.transform_joints(pose_2d_mpiis, visibilities)
74 | pose_3d = poseLifting.compute_3d(transformed_pose2d, weights)
75 |
76 | import matplotlib.pyplot as plt
77 |
78 | fig = plt.figure()
79 | a = fig.add_subplot(2, 2, 1)
80 | a.set_title('Result')
81 | plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
82 |
83 | # show network output
84 | a = fig.add_subplot(2, 2, 2)
85 | # plt.imshow(CocoPose.get_bgimg(inp, target_size=(heatmap.shape[1], heatmap.shape[0])), alpha=0.5)
86 | tmp = np.amax(e.heatMat, axis=2)
87 | plt.imshow(tmp, cmap=plt.cm.gray, alpha=0.5)
88 | plt.colorbar()
89 |
90 | tmp2 = e.pafMat.transpose((2, 0, 1))
91 | tmp2_odd = np.amax(np.absolute(tmp2[::2, :, :]), axis=0)
92 | tmp2_even = np.amax(np.absolute(tmp2[1::2, :, :]), axis=0)
93 |
94 | a = fig.add_subplot(2, 2, 3)
95 | a.set_title('Vectormap-x')
96 | # plt.imshow(CocoPose.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5)
97 | plt.imshow(tmp2_odd, cmap=plt.cm.gray, alpha=0.5)
98 | plt.colorbar()
99 |
100 | a = fig.add_subplot(2, 2, 4)
101 | a.set_title('Vectormap-y')
102 | # plt.imshow(CocoPose.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5)
103 | plt.imshow(tmp2_even, cmap=plt.cm.gray, alpha=0.5)
104 | plt.colorbar()
105 |
106 | for i, single_3d in enumerate(pose_3d):
107 | plot_pose(single_3d)
108 | plt.show()
109 |
110 | pass
111 |
--------------------------------------------------------------------------------
/src/run_checkpoint.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 |
4 | import tensorflow as tf
5 |
6 | from networks import get_network
7 |
8 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
9 |
10 | config = tf.ConfigProto()
11 | config.gpu_options.allocator_type = 'BFC'
12 | config.gpu_options.per_process_gpu_memory_fraction = 0.95
13 | config.gpu_options.allow_growth = True
14 |
15 |
16 | if __name__ == '__main__':
17 | """
18 |     Use this script to save the graph definition and a checkpoint for the chosen model.
19 |     Checkpoints saved during training can be loaded and verified with this script.
20 | """
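    # Hypothetical example invocation, using the --model flag defined just below:
    #   python src/run_checkpoint.py --model=cmu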
21 | parser = argparse.ArgumentParser(description='Tensorflow Pose Estimation Graph Extractor')
22 | parser.add_argument('--model', type=str, default='cmu', help='cmu / mobilenet / mobilenet_thin')
23 | args = parser.parse_args()
24 |
25 | input_node = tf.placeholder(tf.float32, shape=(1, 368, 432, 3), name='image')
26 |
27 | with tf.Session(config=config) as sess:
28 | net, _, last_layer = get_network(args.model, input_node, sess, trainable=False)
29 |
30 | tf.train.write_graph(sess.graph_def, './tmp', 'graph.pb', as_text=True)
31 |
32 | graph = tf.get_default_graph()
33 | dir(graph)
34 | for n in tf.get_default_graph().as_graph_def().node:
35 | if 'concat_stage' not in n.name:
36 | continue
37 | print(n.name)
38 |
39 | saver = tf.train.Saver(max_to_keep=100)
40 | saver.save(sess, '/Users/ildoonet/repos/tf-openpose/tmp/chk', global_step=1)
41 |
--------------------------------------------------------------------------------
/src/run_webcam.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import time
4 |
5 | import cv2
6 | import numpy as np
7 |
8 | from estimator import TfPoseEstimator
9 | from networks import get_graph_path, model_wh
10 |
11 | logger = logging.getLogger('TfPoseEstimator-WebCam')
12 | logger.setLevel(logging.DEBUG)
13 | ch = logging.StreamHandler()
14 | ch.setLevel(logging.DEBUG)
15 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s')
16 | ch.setFormatter(formatter)
17 | logger.addHandler(ch)
18 |
19 | fps_time = 0
20 |
21 |
22 | if __name__ == '__main__':
23 | parser = argparse.ArgumentParser(description='tf-pose-estimation realtime webcam')
24 | parser.add_argument('--camera', type=int, default=0)
25 | parser.add_argument('--zoom', type=float, default=1.0)
26 | parser.add_argument('--model', type=str, default='mobilenet_thin_432x368', help='cmu_640x480 / cmu_640x360 / mobilenet_thin_432x368')
27 | parser.add_argument('--show-process', type=bool, default=False,
28 |                         help='for debug purposes; if enabled, inference speed is reduced.')
29 | args = parser.parse_args()
30 |
31 | logger.debug('initialization %s : %s' % (args.model, get_graph_path(args.model)))
32 | w, h = model_wh(args.model)
33 | e = TfPoseEstimator(get_graph_path(args.model), target_size=(w, h))
34 | logger.debug('cam read+')
35 | cam = cv2.VideoCapture(args.camera)
36 | ret_val, image = cam.read()
37 | logger.info('cam image=%dx%d' % (image.shape[1], image.shape[0]))
38 |
39 | while True:
40 | ret_val, image = cam.read()
41 |
42 | logger.debug('image preprocess+')
43 | if args.zoom < 1.0:
44 | canvas = np.zeros_like(image)
45 | img_scaled = cv2.resize(image, None, fx=args.zoom, fy=args.zoom, interpolation=cv2.INTER_LINEAR)
46 | dx = (canvas.shape[1] - img_scaled.shape[1]) // 2
47 | dy = (canvas.shape[0] - img_scaled.shape[0]) // 2
48 | canvas[dy:dy + img_scaled.shape[0], dx:dx + img_scaled.shape[1]] = img_scaled
49 | image = canvas
50 | elif args.zoom > 1.0:
51 | img_scaled = cv2.resize(image, None, fx=args.zoom, fy=args.zoom, interpolation=cv2.INTER_LINEAR)
52 | dx = (img_scaled.shape[1] - image.shape[1]) // 2
53 | dy = (img_scaled.shape[0] - image.shape[0]) // 2
54 |             image = img_scaled[dy:dy + image.shape[0], dx:dx + image.shape[1]]
55 |
56 | logger.debug('image process+')
57 | humans = e.inference(image)
58 |
59 | logger.debug('postprocess+')
60 | image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False)
61 |
62 | logger.debug('show+')
63 | cv2.putText(image,
64 | "FPS: %f" % (1.0 / (time.time() - fps_time)),
65 | (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
66 | (0, 255, 0), 2)
67 | cv2.imshow('tf-pose-estimation result', image)
68 | fps_time = time.time()
69 | if cv2.waitKey(1) == 27:
70 | break
71 | logger.debug('finished+')
72 |
73 | cv2.destroyAllWindows()
74 |
--------------------------------------------------------------------------------
/src/slim/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/src/slim/__init__.py
--------------------------------------------------------------------------------
/src/slim/nets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/slim/nets/alexnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a model definition for AlexNet.
16 |
17 | This work was first described in:
18 | ImageNet Classification with Deep Convolutional Neural Networks
19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton
20 |
21 | and later refined in:
22 | One weird trick for parallelizing convolutional neural networks
23 | Alex Krizhevsky, 2014
24 |
25 | Here we provide the implementation proposed in "One weird trick" and not
26 | "ImageNet Classification"; as per the paper, the LRN layers have been removed.
27 |
28 | Usage:
29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
30 | outputs, end_points = alexnet.alexnet_v2(inputs)
31 |
32 | @@alexnet_v2
33 | """
34 |
35 | from __future__ import absolute_import
36 | from __future__ import division
37 | from __future__ import print_function
38 |
39 | import tensorflow as tf
40 |
41 | slim = tf.contrib.slim
42 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
43 |
44 |
45 | def alexnet_v2_arg_scope(weight_decay=0.0005):
46 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
47 | activation_fn=tf.nn.relu,
48 | biases_initializer=tf.constant_initializer(0.1),
49 | weights_regularizer=slim.l2_regularizer(weight_decay)):
50 | with slim.arg_scope([slim.conv2d], padding='SAME'):
51 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
52 | return arg_sc
53 |
54 |
55 | def alexnet_v2(inputs,
56 | num_classes=1000,
57 | is_training=True,
58 | dropout_keep_prob=0.5,
59 | spatial_squeeze=True,
60 | scope='alexnet_v2',
61 | global_pool=False):
62 | """AlexNet version 2.
63 |
64 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf
65 | Parameters from:
66 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/
67 | layers-imagenet-1gpu.cfg
68 |
69 | Note: All the fully_connected layers have been transformed to conv2d layers.
70 | To use in classification mode, resize input to 224x224 or set
71 | global_pool=True. To use in fully convolutional mode, set
72 | spatial_squeeze to false.
73 |           The LRN layers have been removed and the initializers changed from
74 | random_normal_initializer to xavier_initializer.
75 |
76 | Args:
77 | inputs: a tensor of size [batch_size, height, width, channels].
78 | num_classes: the number of predicted classes. If 0 or None, the logits layer
79 | is omitted and the input features to the logits layer are returned instead.
80 | is_training: whether or not the model is being trained.
81 | dropout_keep_prob: the probability that activations are kept in the dropout
82 | layers during training.
83 |     spatial_squeeze: whether or not to squeeze the spatial dimensions of the
84 | logits. Useful to remove unnecessary dimensions for classification.
85 | scope: Optional scope for the variables.
86 | global_pool: Optional boolean flag. If True, the input to the classification
87 | layer is avgpooled to size 1x1, for any input size. (This is not part
88 | of the original AlexNet.)
89 |
90 | Returns:
91 | net: the output of the logits layer (if num_classes is a non-zero integer),
92 | or the non-dropped-out input to the logits layer (if num_classes is 0
93 | or None).
94 | end_points: a dict of tensors with intermediate activations.
95 | """
96 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc:
97 | end_points_collection = sc.original_name_scope + '_end_points'
98 | # Collect outputs for conv2d, fully_connected and max_pool2d.
99 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
100 | outputs_collections=[end_points_collection]):
101 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
102 | scope='conv1')
103 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1')
104 | net = slim.conv2d(net, 192, [5, 5], scope='conv2')
105 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2')
106 | net = slim.conv2d(net, 384, [3, 3], scope='conv3')
107 | net = slim.conv2d(net, 384, [3, 3], scope='conv4')
108 | net = slim.conv2d(net, 256, [3, 3], scope='conv5')
109 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5')
110 |
111 | # Use conv2d instead of fully_connected layers.
112 | with slim.arg_scope([slim.conv2d],
113 | weights_initializer=trunc_normal(0.005),
114 | biases_initializer=tf.constant_initializer(0.1)):
115 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID',
116 | scope='fc6')
117 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
118 | scope='dropout6')
119 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
120 | # Convert end_points_collection into a end_point dict.
121 | end_points = slim.utils.convert_collection_to_dict(
122 | end_points_collection)
123 | if global_pool:
124 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
125 | end_points['global_pool'] = net
126 | if num_classes:
127 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
128 | scope='dropout7')
129 | net = slim.conv2d(net, num_classes, [1, 1],
130 | activation_fn=None,
131 | normalizer_fn=None,
132 | biases_initializer=tf.zeros_initializer(),
133 | scope='fc8')
134 | if spatial_squeeze:
135 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
136 | end_points[sc.name + '/fc8'] = net
137 | return net, end_points
138 | alexnet_v2.default_image_size = 224
139 |
--------------------------------------------------------------------------------
/src/slim/nets/alexnet_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for slim.nets.alexnet."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import tensorflow as tf
21 |
22 | from nets import alexnet
23 |
24 | slim = tf.contrib.slim
25 |
26 |
27 | class AlexnetV2Test(tf.test.TestCase):
28 |
29 | def testBuild(self):
30 | batch_size = 5
31 | height, width = 224, 224
32 | num_classes = 1000
33 | with self.test_session():
34 | inputs = tf.random_uniform((batch_size, height, width, 3))
35 | logits, _ = alexnet.alexnet_v2(inputs, num_classes)
36 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed')
37 | self.assertListEqual(logits.get_shape().as_list(),
38 | [batch_size, num_classes])
39 |
40 | def testFullyConvolutional(self):
41 | batch_size = 1
42 | height, width = 300, 400
43 | num_classes = 1000
44 | with self.test_session():
45 | inputs = tf.random_uniform((batch_size, height, width, 3))
46 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False)
47 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd')
48 | self.assertListEqual(logits.get_shape().as_list(),
49 | [batch_size, 4, 7, num_classes])
50 |
51 | def testGlobalPool(self):
52 | batch_size = 1
53 | height, width = 300, 400
54 | num_classes = 1000
55 | with self.test_session():
56 | inputs = tf.random_uniform((batch_size, height, width, 3))
57 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False,
58 | global_pool=True)
59 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd')
60 | self.assertListEqual(logits.get_shape().as_list(),
61 | [batch_size, 1, 1, num_classes])
62 |
63 | def testEndPoints(self):
64 | batch_size = 5
65 | height, width = 224, 224
66 | num_classes = 1000
67 | with self.test_session():
68 | inputs = tf.random_uniform((batch_size, height, width, 3))
69 | _, end_points = alexnet.alexnet_v2(inputs, num_classes)
70 | expected_names = ['alexnet_v2/conv1',
71 | 'alexnet_v2/pool1',
72 | 'alexnet_v2/conv2',
73 | 'alexnet_v2/pool2',
74 | 'alexnet_v2/conv3',
75 | 'alexnet_v2/conv4',
76 | 'alexnet_v2/conv5',
77 | 'alexnet_v2/pool5',
78 | 'alexnet_v2/fc6',
79 | 'alexnet_v2/fc7',
80 | 'alexnet_v2/fc8'
81 | ]
82 | self.assertSetEqual(set(end_points.keys()), set(expected_names))
83 |
84 | def testNoClasses(self):
85 | batch_size = 5
86 | height, width = 224, 224
87 | num_classes = None
88 | with self.test_session():
89 | inputs = tf.random_uniform((batch_size, height, width, 3))
90 | net, end_points = alexnet.alexnet_v2(inputs, num_classes)
91 | expected_names = ['alexnet_v2/conv1',
92 | 'alexnet_v2/pool1',
93 | 'alexnet_v2/conv2',
94 | 'alexnet_v2/pool2',
95 | 'alexnet_v2/conv3',
96 | 'alexnet_v2/conv4',
97 | 'alexnet_v2/conv5',
98 | 'alexnet_v2/pool5',
99 | 'alexnet_v2/fc6',
100 | 'alexnet_v2/fc7'
101 | ]
102 | self.assertSetEqual(set(end_points.keys()), set(expected_names))
103 | self.assertTrue(net.op.name.startswith('alexnet_v2/fc7'))
104 | self.assertListEqual(net.get_shape().as_list(),
105 | [batch_size, 1, 1, 4096])
106 |
107 | def testModelVariables(self):
108 | batch_size = 5
109 | height, width = 224, 224
110 | num_classes = 1000
111 | with self.test_session():
112 | inputs = tf.random_uniform((batch_size, height, width, 3))
113 | alexnet.alexnet_v2(inputs, num_classes)
114 | expected_names = ['alexnet_v2/conv1/weights',
115 | 'alexnet_v2/conv1/biases',
116 | 'alexnet_v2/conv2/weights',
117 | 'alexnet_v2/conv2/biases',
118 | 'alexnet_v2/conv3/weights',
119 | 'alexnet_v2/conv3/biases',
120 | 'alexnet_v2/conv4/weights',
121 | 'alexnet_v2/conv4/biases',
122 | 'alexnet_v2/conv5/weights',
123 | 'alexnet_v2/conv5/biases',
124 | 'alexnet_v2/fc6/weights',
125 | 'alexnet_v2/fc6/biases',
126 | 'alexnet_v2/fc7/weights',
127 | 'alexnet_v2/fc7/biases',
128 | 'alexnet_v2/fc8/weights',
129 | 'alexnet_v2/fc8/biases',
130 | ]
131 | model_variables = [v.op.name for v in slim.get_model_variables()]
132 | self.assertSetEqual(set(model_variables), set(expected_names))
133 |
134 | def testEvaluation(self):
135 | batch_size = 2
136 | height, width = 224, 224
137 | num_classes = 1000
138 | with self.test_session():
139 | eval_inputs = tf.random_uniform((batch_size, height, width, 3))
140 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False)
141 | self.assertListEqual(logits.get_shape().as_list(),
142 | [batch_size, num_classes])
143 | predictions = tf.argmax(logits, 1)
144 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
145 |
146 | def testTrainEvalWithReuse(self):
147 | train_batch_size = 2
148 | eval_batch_size = 1
149 | train_height, train_width = 224, 224
150 | eval_height, eval_width = 300, 400
151 | num_classes = 1000
152 | with self.test_session():
153 | train_inputs = tf.random_uniform(
154 | (train_batch_size, train_height, train_width, 3))
155 | logits, _ = alexnet.alexnet_v2(train_inputs)
156 | self.assertListEqual(logits.get_shape().as_list(),
157 | [train_batch_size, num_classes])
158 | tf.get_variable_scope().reuse_variables()
159 | eval_inputs = tf.random_uniform(
160 | (eval_batch_size, eval_height, eval_width, 3))
161 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False,
162 | spatial_squeeze=False)
163 | self.assertListEqual(logits.get_shape().as_list(),
164 | [eval_batch_size, 4, 7, num_classes])
165 | logits = tf.reduce_mean(logits, [1, 2])
166 | predictions = tf.argmax(logits, 1)
167 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
168 |
169 | def testForward(self):
170 | batch_size = 1
171 | height, width = 224, 224
172 | with self.test_session() as sess:
173 | inputs = tf.random_uniform((batch_size, height, width, 3))
174 | logits, _ = alexnet.alexnet_v2(inputs)
175 | sess.run(tf.global_variables_initializer())
176 | output = sess.run(logits)
177 | self.assertTrue(output.any())
178 |
179 | if __name__ == '__main__':
180 | tf.test.main()
181 |
--------------------------------------------------------------------------------
/src/slim/nets/cifarnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a variant of the CIFAR-10 model definition."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | slim = tf.contrib.slim
24 |
25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev)
26 |
27 |
28 | def cifarnet(images, num_classes=10, is_training=False,
29 | dropout_keep_prob=0.5,
30 | prediction_fn=slim.softmax,
31 | scope='CifarNet'):
32 | """Creates a variant of the CifarNet model.
33 |
34 | Note that since the output is a set of 'logits', the values fall in the
35 | interval of (-infinity, infinity). Consequently, to convert the outputs to a
36 | probability distribution over the characters, one will need to convert them
37 | using the softmax function:
38 |
39 | logits = cifarnet.cifarnet(images, is_training=False)
40 | probabilities = tf.nn.softmax(logits)
41 | predictions = tf.argmax(logits, 1)
42 |
43 | Args:
44 | images: A batch of `Tensors` of size [batch_size, height, width, channels].
45 | num_classes: the number of classes in the dataset. If 0 or None, the logits
46 | layer is omitted and the input features to the logits layer are returned
47 | instead.
48 | is_training: specifies whether or not we're currently training the model.
49 | This variable will determine the behaviour of the dropout layer.
50 | dropout_keep_prob: the percentage of activation values that are retained.
51 | prediction_fn: a function to get predictions out of logits.
52 | scope: Optional variable_scope.
53 |
54 | Returns:
55 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes
56 | is a non-zero integer, or the input to the logits layer if num_classes
57 | is 0 or None.
58 | end_points: a dictionary from components of the network to the corresponding
59 | activation.
60 | """
61 | end_points = {}
62 |
63 | with tf.variable_scope(scope, 'CifarNet', [images]):
64 | net = slim.conv2d(images, 64, [5, 5], scope='conv1')
65 | end_points['conv1'] = net
66 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
67 | end_points['pool1'] = net
68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1')
69 | net = slim.conv2d(net, 64, [5, 5], scope='conv2')
70 | end_points['conv2'] = net
71 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2')
72 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
73 | end_points['pool2'] = net
74 | net = slim.flatten(net)
75 | end_points['Flatten'] = net
76 | net = slim.fully_connected(net, 384, scope='fc3')
77 | end_points['fc3'] = net
78 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
79 | scope='dropout3')
80 | net = slim.fully_connected(net, 192, scope='fc4')
81 | end_points['fc4'] = net
82 | if not num_classes:
83 | return net, end_points
84 | logits = slim.fully_connected(net, num_classes,
85 | biases_initializer=tf.zeros_initializer(),
86 | weights_initializer=trunc_normal(1/192.0),
87 | weights_regularizer=None,
88 | activation_fn=None,
89 | scope='logits')
90 |
91 | end_points['Logits'] = logits
92 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
93 |
94 | return logits, end_points
95 | cifarnet.default_image_size = 32
96 |
97 |
98 | def cifarnet_arg_scope(weight_decay=0.004):
99 | """Defines the default cifarnet argument scope.
100 |
101 | Args:
102 | weight_decay: The weight decay to use for regularizing the model.
103 |
104 | Returns:
105 | An `arg_scope` to use for the inception v3 model.
106 | """
107 | with slim.arg_scope(
108 | [slim.conv2d],
109 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2),
110 | activation_fn=tf.nn.relu):
111 | with slim.arg_scope(
112 | [slim.fully_connected],
113 | biases_initializer=tf.constant_initializer(0.1),
114 | weights_initializer=trunc_normal(0.04),
115 | weights_regularizer=slim.l2_regularizer(weight_decay),
116 | activation_fn=tf.nn.relu) as sc:
117 | return sc
118 |
--------------------------------------------------------------------------------
/src/slim/nets/cyclegan_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for tensorflow.contrib.slim.nets.cyclegan."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from nets import cyclegan
24 |
25 |
26 | # TODO(joelshor): Add a test to check generator endpoints.
27 | class CycleganTest(tf.test.TestCase):
28 |
29 | def test_generator_inference(self):
30 | """Check one inference step."""
31 | img_batch = tf.zeros([2, 32, 32, 3])
32 | model_output, _ = cyclegan.cyclegan_generator_resnet(img_batch)
33 | with self.test_session() as sess:
34 | sess.run(tf.global_variables_initializer())
35 | sess.run(model_output)
36 |
37 | def _test_generator_graph_helper(self, shape):
38 | """Check that generator can take small and non-square inputs."""
39 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(tf.ones(shape))
40 | self.assertAllEqual(shape, output_imgs.shape.as_list())
41 |
42 | def test_generator_graph_small(self):
43 | self._test_generator_graph_helper([4, 32, 32, 3])
44 |
45 | def test_generator_graph_medium(self):
46 | self._test_generator_graph_helper([3, 128, 128, 3])
47 |
48 | def test_generator_graph_nonsquare(self):
49 | self._test_generator_graph_helper([2, 80, 400, 3])
50 |
51 | def test_generator_unknown_batch_dim(self):
52 | """Check that generator can take unknown batch dimension inputs."""
53 | img = tf.placeholder(tf.float32, shape=[None, 32, None, 3])
54 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(img)
55 |
56 | self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list())
57 |
58 | def _input_and_output_same_shape_helper(self, kernel_size):
59 | img_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
60 | output_img_batch, _ = cyclegan.cyclegan_generator_resnet(
61 | img_batch, kernel_size=kernel_size)
62 |
63 | self.assertAllEqual(img_batch.shape.as_list(),
64 | output_img_batch.shape.as_list())
65 |
66 |   def test_input_and_output_same_shape_kernel3(self):
67 |     self._input_and_output_same_shape_helper(3)
68 |
69 |   def test_input_and_output_same_shape_kernel4(self):
70 |     self._input_and_output_same_shape_helper(4)
71 |
72 |   def test_input_and_output_same_shape_kernel5(self):
73 |     self._input_and_output_same_shape_helper(5)
74 |
75 |   def test_input_and_output_same_shape_kernel6(self):
76 | self._input_and_output_same_shape_helper(6)
77 |
78 | def _error_if_height_not_multiple_of_four_helper(self, height):
79 | self.assertRaisesRegexp(
80 | ValueError,
81 | 'The input height must be a multiple of 4.',
82 | cyclegan.cyclegan_generator_resnet,
83 | tf.placeholder(tf.float32, shape=[None, height, 32, 3]))
84 |
85 | def test_error_if_height_not_multiple_of_four_height29(self):
86 | self._error_if_height_not_multiple_of_four_helper(29)
87 |
88 | def test_error_if_height_not_multiple_of_four_height30(self):
89 | self._error_if_height_not_multiple_of_four_helper(30)
90 |
91 | def test_error_if_height_not_multiple_of_four_height31(self):
92 | self._error_if_height_not_multiple_of_four_helper(31)
93 |
94 | def _error_if_width_not_multiple_of_four_helper(self, width):
95 | self.assertRaisesRegexp(
96 | ValueError,
97 | 'The input width must be a multiple of 4.',
98 | cyclegan.cyclegan_generator_resnet,
99 | tf.placeholder(tf.float32, shape=[None, 32, width, 3]))
100 |
101 | def test_error_if_width_not_multiple_of_four_width29(self):
102 | self._error_if_width_not_multiple_of_four_helper(29)
103 |
104 | def test_error_if_width_not_multiple_of_four_width30(self):
105 | self._error_if_width_not_multiple_of_four_helper(30)
106 |
107 | def test_error_if_width_not_multiple_of_four_width31(self):
108 | self._error_if_width_not_multiple_of_four_helper(31)
109 |
110 |
111 | if __name__ == '__main__':
112 | tf.test.main()
113 |
--------------------------------------------------------------------------------
/src/slim/nets/dcgan.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """DCGAN generator and discriminator from https://arxiv.org/abs/1511.06434."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | from math import log
21 |
22 | import tensorflow as tf
23 | slim = tf.contrib.slim
24 |
25 |
26 | def _validate_image_inputs(inputs):
27 | inputs.get_shape().assert_has_rank(4)
28 | inputs.get_shape()[1:3].assert_is_fully_defined()
29 | if inputs.get_shape()[1] != inputs.get_shape()[2]:
30 | raise ValueError('Input tensor does not have equal width and height: ',
31 | inputs.get_shape()[1:3])
32 | width = inputs.get_shape().as_list()[1]
33 | if log(width, 2) != int(log(width, 2)):
34 | raise ValueError('Input tensor `width` is not a power of 2: ', width)
35 |
36 |
37 | # TODO(joelshor): Use fused batch norm by default. Investigate why some GAN
38 | # setups need the gradient of gradient FusedBatchNormGrad.
39 | def discriminator(inputs,
40 | depth=64,
41 | is_training=True,
42 | reuse=None,
43 | scope='Discriminator',
44 | fused_batch_norm=False):
45 | """Discriminator network for DCGAN.
46 |
47 | Construct discriminator network from inputs to the final endpoint.
48 |
49 | Args:
50 | inputs: A tensor of size [batch_size, height, width, channels]. Must be
51 | floating point.
52 | depth: Number of channels in first convolution layer.
53 | is_training: Whether the network is for training or not.
54 | reuse: Whether or not the network variables should be reused. `scope`
55 | must be given to be reused.
56 | scope: Optional variable_scope.
57 | fused_batch_norm: If `True`, use a faster, fused implementation of
58 | batch norm.
59 |
60 | Returns:
61 | logits: The pre-softmax activations, a tensor of size [batch_size, 1]
62 | end_points: a dictionary from components of the network to their activation.
63 |
64 | Raises:
65 | ValueError: If the input image shape is not 4-dimensional, if the spatial
66 | dimensions aren't defined at graph construction time, if the spatial
67 | dimensions aren't square, or if the spatial dimensions aren't a power of
68 | two.
69 | """
70 |
71 | normalizer_fn = slim.batch_norm
72 | normalizer_fn_args = {
73 | 'is_training': is_training,
74 | 'zero_debias_moving_mean': True,
75 | 'fused': fused_batch_norm,
76 | }
77 |
78 | _validate_image_inputs(inputs)
79 | inp_shape = inputs.get_shape().as_list()[1]
80 |
81 | end_points = {}
82 | with tf.variable_scope(scope, values=[inputs], reuse=reuse) as scope:
83 | with slim.arg_scope([normalizer_fn], **normalizer_fn_args):
84 | with slim.arg_scope([slim.conv2d],
85 | stride=2,
86 | kernel_size=4,
87 | activation_fn=tf.nn.leaky_relu):
88 | net = inputs
89 | for i in xrange(int(log(inp_shape, 2))):
90 | scope = 'conv%i' % (i + 1)
91 | current_depth = depth * 2**i
92 | normalizer_fn_ = None if i == 0 else normalizer_fn
93 | net = slim.conv2d(
94 | net, current_depth, normalizer_fn=normalizer_fn_, scope=scope)
95 | end_points[scope] = net
96 |
97 | logits = slim.conv2d(net, 1, kernel_size=1, stride=1, padding='VALID',
98 | normalizer_fn=None, activation_fn=None)
99 | logits = tf.reshape(logits, [-1, 1])
100 | end_points['logits'] = logits
101 |
102 | return logits, end_points
103 |
104 |
105 | # TODO(joelshor): Use fused batch norm by default. Investigate why some GAN
106 | # setups need the gradient of gradient FusedBatchNormGrad.
107 | def generator(inputs,
108 | depth=64,
109 | final_size=32,
110 | num_outputs=3,
111 | is_training=True,
112 | reuse=None,
113 | scope='Generator',
114 | fused_batch_norm=False):
115 | """Generator network for DCGAN.
116 |
117 | Construct generator network from inputs to the final endpoint.
118 |
119 | Args:
120 |     inputs: A 2-D tensor of shape [batch_size, N], where N can be any size.
121 | depth: Number of channels in last deconvolution layer.
122 |     final_size: The spatial width/height of the (square) final output.
123 | num_outputs: Number of output features. For images, this is the number of
124 | channels.
125 |     is_training: Whether the network is being trained or not.
126 |     reuse: Whether or not the network variables should be reused. `scope`
127 | must be given to be reused.
128 | scope: Optional variable_scope.
129 | fused_batch_norm: If `True`, use a faster, fused implementation of
130 | batch norm.
131 |
132 | Returns:
133 | logits: the pre-softmax activations, a tensor of size
134 | [batch_size, 32, 32, channels]
135 | end_points: a dictionary from components of the network to their activation.
136 |
137 | Raises:
138 | ValueError: If `inputs` is not 2-dimensional.
139 | ValueError: If `final_size` isn't a power of 2 or is less than 8.
140 | """
141 | normalizer_fn = slim.batch_norm
142 | normalizer_fn_args = {
143 | 'is_training': is_training,
144 | 'zero_debias_moving_mean': True,
145 | 'fused': fused_batch_norm,
146 | }
147 |
148 | inputs.get_shape().assert_has_rank(2)
149 | if log(final_size, 2) != int(log(final_size, 2)):
150 | raise ValueError('`final_size` (%i) must be a power of 2.' % final_size)
151 | if final_size < 8:
152 | raise ValueError('`final_size` (%i) must be greater than 8.' % final_size)
153 |
154 | end_points = {}
155 | num_layers = int(log(final_size, 2)) - 1
156 | with tf.variable_scope(scope, values=[inputs], reuse=reuse) as scope:
157 | with slim.arg_scope([normalizer_fn], **normalizer_fn_args):
158 | with slim.arg_scope([slim.conv2d_transpose],
159 | normalizer_fn=normalizer_fn,
160 | stride=2,
161 | kernel_size=4):
162 | net = tf.expand_dims(tf.expand_dims(inputs, 1), 1)
163 |
164 | # First upscaling is different because it takes the input vector.
165 | current_depth = depth * 2 ** (num_layers - 1)
166 | scope = 'deconv1'
167 | net = slim.conv2d_transpose(
168 | net, current_depth, stride=1, padding='VALID', scope=scope)
169 | end_points[scope] = net
170 |
171 | for i in xrange(2, num_layers):
172 | scope = 'deconv%i' % (i)
173 | current_depth = depth * 2 ** (num_layers - i)
174 | net = slim.conv2d_transpose(net, current_depth, scope=scope)
175 | end_points[scope] = net
176 |
177 | # Last layer has different normalizer and activation.
178 | scope = 'deconv%i' % (num_layers)
179 | net = slim.conv2d_transpose(
180 | net, depth, normalizer_fn=None, activation_fn=None, scope=scope)
181 | end_points[scope] = net
182 |
183 | # Convert to proper channels.
184 | scope = 'logits'
185 | logits = slim.conv2d(
186 | net,
187 | num_outputs,
188 | normalizer_fn=None,
189 | activation_fn=None,
190 | kernel_size=1,
191 | stride=1,
192 | padding='VALID',
193 | scope=scope)
194 | end_points[scope] = logits
195 |
196 | logits.get_shape().assert_has_rank(4)
197 | logits.get_shape().assert_is_compatible_with(
198 | [None, final_size, final_size, num_outputs])
199 |
200 | return logits, end_points
201 |
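# Usage sketch for the two networks above (shapes follow the docstrings and mirror
# dcgan_test.py; the batch size of 8 and the 64-dim noise vector are arbitrary choices):
#
#   noise = tf.random_normal([8, 64])
#   fake_images, _ = generator(noise, final_size=32)   # -> [8, 32, 32, 3]
#   disc_logits, _ = discriminator(fake_images)        # -> [8, 1]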
--------------------------------------------------------------------------------
/src/slim/nets/dcgan_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for dcgan."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 | from nets import dcgan
23 |
24 |
25 | class DCGANTest(tf.test.TestCase):
26 |
27 | def test_generator_run(self):
28 | tf.set_random_seed(1234)
29 | noise = tf.random_normal([100, 64])
30 | image, _ = dcgan.generator(noise)
31 | with self.test_session() as sess:
32 | sess.run(tf.global_variables_initializer())
33 | image.eval()
34 |
35 | def test_generator_graph(self):
36 | tf.set_random_seed(1234)
37 | # Check graph construction for a number of image size/depths and batch
38 | # sizes.
39 | for i, batch_size in zip(xrange(3, 7), xrange(3, 8)):
40 | tf.reset_default_graph()
41 | final_size = 2 ** i
42 | noise = tf.random_normal([batch_size, 64])
43 | image, end_points = dcgan.generator(
44 | noise,
45 | depth=32,
46 | final_size=final_size)
47 |
48 | self.assertAllEqual([batch_size, final_size, final_size, 3],
49 | image.shape.as_list())
50 |
51 | expected_names = ['deconv%i' % j for j in xrange(1, i)] + ['logits']
52 | self.assertSetEqual(set(expected_names), set(end_points.keys()))
53 |
54 | # Check layer depths.
55 | for j in range(1, i):
56 | layer = end_points['deconv%i' % j]
57 | self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1])
58 |
59 | def test_generator_invalid_input(self):
60 | wrong_dim_input = tf.zeros([5, 32, 32])
61 | with self.assertRaises(ValueError):
62 | dcgan.generator(wrong_dim_input)
63 |
64 | correct_input = tf.zeros([3, 2])
65 | with self.assertRaisesRegexp(ValueError, 'must be a power of 2'):
66 | dcgan.generator(correct_input, final_size=30)
67 |
68 | with self.assertRaisesRegexp(ValueError, 'must be greater than 8'):
69 | dcgan.generator(correct_input, final_size=4)
70 |
71 | def test_discriminator_run(self):
72 | image = tf.random_uniform([5, 32, 32, 3], -1, 1)
73 | output, _ = dcgan.discriminator(image)
74 | with self.test_session() as sess:
75 | sess.run(tf.global_variables_initializer())
76 | output.eval()
77 |
78 | def test_discriminator_graph(self):
79 | # Check graph construction for a number of image size/depths and batch
80 | # sizes.
81 | for i, batch_size in zip(xrange(1, 6), xrange(3, 8)):
82 | tf.reset_default_graph()
83 | img_w = 2 ** i
84 | image = tf.random_uniform([batch_size, img_w, img_w, 3], -1, 1)
85 | output, end_points = dcgan.discriminator(
86 | image,
87 | depth=32)
88 |
89 | self.assertAllEqual([batch_size, 1], output.get_shape().as_list())
90 |
91 | expected_names = ['conv%i' % j for j in xrange(1, i+1)] + ['logits']
92 | self.assertSetEqual(set(expected_names), set(end_points.keys()))
93 |
94 | # Check layer depths.
95 | for j in range(1, i+1):
96 | layer = end_points['conv%i' % j]
97 | self.assertEqual(32 * 2**(j-1), layer.get_shape().as_list()[-1])
98 |
99 | def test_discriminator_invalid_input(self):
100 | wrong_dim_img = tf.zeros([5, 32, 32])
101 | with self.assertRaises(ValueError):
102 | dcgan.discriminator(wrong_dim_img)
103 |
104 | spatially_undefined_shape = tf.placeholder(tf.float32, [5, 32, None, 3])
105 | with self.assertRaises(ValueError):
106 | dcgan.discriminator(spatially_undefined_shape)
107 |
108 | not_square = tf.zeros([5, 32, 16, 3])
109 | with self.assertRaisesRegexp(ValueError, 'not have equal width and height'):
110 | dcgan.discriminator(not_square)
111 |
112 | not_power_2 = tf.zeros([5, 30, 30, 3])
113 | with self.assertRaisesRegexp(ValueError, 'not a power of 2'):
114 | dcgan.discriminator(not_power_2)
115 |
116 |
117 | if __name__ == '__main__':
118 | tf.test.main()
119 |
--------------------------------------------------------------------------------
/src/slim/nets/inception.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Brings all inception models under one namespace."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | # pylint: disable=unused-import
22 | from nets.inception_resnet_v2 import inception_resnet_v2
23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope
24 | from nets.inception_resnet_v2 import inception_resnet_v2_base
25 | from nets.inception_v1 import inception_v1
26 | from nets.inception_v1 import inception_v1_arg_scope
27 | from nets.inception_v1 import inception_v1_base
28 | from nets.inception_v2 import inception_v2
29 | from nets.inception_v2 import inception_v2_arg_scope
30 | from nets.inception_v2 import inception_v2_base
31 | from nets.inception_v3 import inception_v3
32 | from nets.inception_v3 import inception_v3_arg_scope
33 | from nets.inception_v3 import inception_v3_base
34 | from nets.inception_v4 import inception_v4
35 | from nets.inception_v4 import inception_v4_arg_scope
36 | from nets.inception_v4 import inception_v4_base
37 | # pylint: enable=unused-import
38 |
--------------------------------------------------------------------------------
/src/slim/nets/inception_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains common code shared by all inception models.
16 |
17 | Usage of arg scope:
18 | with slim.arg_scope(inception_arg_scope()):
19 | logits, end_points = inception.inception_v3(images, num_classes,
20 | is_training=is_training)
21 |
22 | """
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | import tensorflow as tf
28 |
29 | slim = tf.contrib.slim
30 |
31 |
32 | def inception_arg_scope(weight_decay=0.00004,
33 | use_batch_norm=True,
34 | batch_norm_decay=0.9997,
35 | batch_norm_epsilon=0.001,
36 | activation_fn=tf.nn.relu):
37 | """Defines the default arg scope for inception models.
38 |
39 | Args:
40 | weight_decay: The weight decay to use for regularizing the model.
41 |     use_batch_norm: If `True`, batch_norm is applied after each convolution.
42 | batch_norm_decay: Decay for batch norm moving average.
43 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero
44 | in batch norm.
45 | activation_fn: Activation function for conv2d.
46 |
47 | Returns:
48 | An `arg_scope` to use for the inception models.
49 | """
50 | batch_norm_params = {
51 | # Decay for the moving averages.
52 | 'decay': batch_norm_decay,
53 | # epsilon to prevent 0s in variance.
54 | 'epsilon': batch_norm_epsilon,
55 | # collection containing update_ops.
56 | 'updates_collections': tf.GraphKeys.UPDATE_OPS,
57 | # use fused batch norm if possible.
58 | 'fused': None,
59 | }
60 | if use_batch_norm:
61 | normalizer_fn = slim.batch_norm
62 | normalizer_params = batch_norm_params
63 | else:
64 | normalizer_fn = None
65 | normalizer_params = {}
66 | # Set weight_decay for weights in Conv and FC layers.
67 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
68 | weights_regularizer=slim.l2_regularizer(weight_decay)):
69 | with slim.arg_scope(
70 | [slim.conv2d],
71 | weights_initializer=slim.variance_scaling_initializer(),
72 | activation_fn=activation_fn,
73 | normalizer_fn=normalizer_fn,
74 | normalizer_params=normalizer_params) as sc:
75 | return sc
76 |
--------------------------------------------------------------------------------
/src/slim/nets/lenet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a variant of the LeNet model definition."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | slim = tf.contrib.slim
24 |
25 |
26 | def lenet(images, num_classes=10, is_training=False,
27 | dropout_keep_prob=0.5,
28 | prediction_fn=slim.softmax,
29 | scope='LeNet'):
30 | """Creates a variant of the LeNet model.
31 |
32 | Note that since the output is a set of 'logits', the values fall in the
33 | interval of (-infinity, infinity). Consequently, to convert the outputs to a
34 | probability distribution over the characters, one will need to convert them
35 | using the softmax function:
36 |
37 | logits = lenet.lenet(images, is_training=False)
38 | probabilities = tf.nn.softmax(logits)
39 | predictions = tf.argmax(logits, 1)
40 |
41 | Args:
42 | images: A batch of `Tensors` of size [batch_size, height, width, channels].
43 | num_classes: the number of classes in the dataset. If 0 or None, the logits
44 | layer is omitted and the input features to the logits layer are returned
45 | instead.
46 | is_training: specifies whether or not we're currently training the model.
47 | This variable will determine the behaviour of the dropout layer.
48 | dropout_keep_prob: the percentage of activation values that are retained.
49 | prediction_fn: a function to get predictions out of logits.
50 | scope: Optional variable_scope.
51 |
52 | Returns:
53 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes
54 |       is a non-zero integer, or the non-dropped-out input to the logits layer
55 | if num_classes is 0 or None.
56 | end_points: a dictionary from components of the network to the corresponding
57 | activation.
58 | """
59 | end_points = {}
60 |
61 | with tf.variable_scope(scope, 'LeNet', [images]):
62 | net = end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
63 | net = end_points['pool1'] = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
64 | net = end_points['conv2'] = slim.conv2d(net, 64, [5, 5], scope='conv2')
65 | net = end_points['pool2'] = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
66 | net = slim.flatten(net)
67 | end_points['Flatten'] = net
68 |
69 | net = end_points['fc3'] = slim.fully_connected(net, 1024, scope='fc3')
70 | if not num_classes:
71 | return net, end_points
72 | net = end_points['dropout3'] = slim.dropout(
73 | net, dropout_keep_prob, is_training=is_training, scope='dropout3')
74 | logits = end_points['Logits'] = slim.fully_connected(
75 | net, num_classes, activation_fn=None, scope='fc4')
76 |
77 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
78 |
79 | return logits, end_points
80 | lenet.default_image_size = 28
81 |
82 |
83 | def lenet_arg_scope(weight_decay=0.0):
84 | """Defines the default lenet argument scope.
85 |
86 | Args:
87 | weight_decay: The weight decay to use for regularizing the model.
88 |
89 | Returns:
90 | An `arg_scope` to use for the inception v3 model.
91 | """
92 | with slim.arg_scope(
93 | [slim.conv2d, slim.fully_connected],
94 | weights_regularizer=slim.l2_regularizer(weight_decay),
95 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
96 | activation_fn=tf.nn.relu) as sc:
97 | return sc
98 |
--------------------------------------------------------------------------------
/src/slim/nets/mobilenet_v1.md:
--------------------------------------------------------------------------------
1 | # MobileNet_v1
2 |
3 | [MobileNets](https://arxiv.org/abs/1704.04861) are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as Inception, are used. MobileNets can be run efficiently on mobile devices with [TensorFlow Mobile](https://www.tensorflow.org/mobile/).
4 |
5 | MobileNets trade off between latency, size and accuracy while comparing favorably with popular models from the literature.
6 |
7 | 
8 |
9 | # Pre-trained Models
10 |
11 | Choose the right MobileNet model to fit your latency and size budget. The size of the network in memory and on disk is proportional to the number of parameters. The latency and power usage of the network scale with the number of Multiply-Accumulates (MACs), which measures the number of fused Multiplication and Addition operations. These MobileNet models have been trained on the
12 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/)
13 | image classification dataset. Accuracies were computed by evaluating using a single image crop.
14 |
15 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy |
16 | :----:|:------------:|:----------:|:-------:|:-------:|
17 | [MobileNet_v1_1.0_224](http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz)|569|4.24|70.7|89.5|
18 | [MobileNet_v1_1.0_192](http://download.tensorflow.org/models/mobilenet_v1_1.0_192_2017_06_14.tar.gz)|418|4.24|69.3|88.9|
19 | [MobileNet_v1_1.0_160](http://download.tensorflow.org/models/mobilenet_v1_1.0_160_2017_06_14.tar.gz)|291|4.24|67.2|87.5|
20 | [MobileNet_v1_1.0_128](http://download.tensorflow.org/models/mobilenet_v1_1.0_128_2017_06_14.tar.gz)|186|4.24|64.1|85.3|
21 | [MobileNet_v1_0.75_224](http://download.tensorflow.org/models/mobilenet_v1_0.75_224_2017_06_14.tar.gz)|317|2.59|68.4|88.2|
22 | [MobileNet_v1_0.75_192](http://download.tensorflow.org/models/mobilenet_v1_0.75_192_2017_06_14.tar.gz)|233|2.59|67.4|87.3|
23 | [MobileNet_v1_0.75_160](http://download.tensorflow.org/models/mobilenet_v1_0.75_160_2017_06_14.tar.gz)|162|2.59|65.2|86.1|
24 | [MobileNet_v1_0.75_128](http://download.tensorflow.org/models/mobilenet_v1_0.75_128_2017_06_14.tar.gz)|104|2.59|61.8|83.6|
25 | [MobileNet_v1_0.50_224](http://download.tensorflow.org/models/mobilenet_v1_0.50_224_2017_06_14.tar.gz)|150|1.34|64.0|85.4|
26 | [MobileNet_v1_0.50_192](http://download.tensorflow.org/models/mobilenet_v1_0.50_192_2017_06_14.tar.gz)|110|1.34|62.1|84.0|
27 | [MobileNet_v1_0.50_160](http://download.tensorflow.org/models/mobilenet_v1_0.50_160_2017_06_14.tar.gz)|77|1.34|59.9|82.5|
28 | [MobileNet_v1_0.50_128](http://download.tensorflow.org/models/mobilenet_v1_0.50_128_2017_06_14.tar.gz)|49|1.34|56.2|79.6|
29 | [MobileNet_v1_0.25_224](http://download.tensorflow.org/models/mobilenet_v1_0.25_224_2017_06_14.tar.gz)|41|0.47|50.6|75.0|
30 | [MobileNet_v1_0.25_192](http://download.tensorflow.org/models/mobilenet_v1_0.25_192_2017_06_14.tar.gz)|34|0.47|49.0|73.6|
31 | [MobileNet_v1_0.25_160](http://download.tensorflow.org/models/mobilenet_v1_0.25_160_2017_06_14.tar.gz)|21|0.47|46.0|70.7|
32 | [MobileNet_v1_0.25_128](http://download.tensorflow.org/models/mobilenet_v1_0.25_128_2017_06_14.tar.gz)|14|0.47|41.3|66.2|
33 |
34 |
35 | Here is an example of how to download the MobileNet_v1_1.0_224 checkpoint:
36 |
37 | ```shell
38 | $ CHECKPOINT_DIR=/tmp/checkpoints
39 | $ mkdir ${CHECKPOINT_DIR}
40 | $ wget http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz
41 | $ tar -xvf mobilenet_v1_1.0_224_2017_06_14.tar.gz
42 | $ mv mobilenet_v1_1.0_224.ckpt.* ${CHECKPOINT_DIR}
43 | $ rm mobilenet_v1_1.0_224_2017_06_14.tar.gz
44 | ```
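A minimal sketch of restoring the extracted checkpoint into the TF-Slim MobileNet_v1 graph (assuming the `nets` package layout used in this repository and the 1001-class ILSVRC checkpoint):

```python
import tensorflow as tf
from nets import mobilenet_v1

slim = tf.contrib.slim

# Build the MobileNet_v1_1.0_224 graph under the arg scope used for training,
# so the variable names match the downloaded checkpoint. `images` is expected
# to be preprocessed (e.g. scaled as in inception preprocessing) before feeding.
images = tf.placeholder(tf.float32, [None, 224, 224, 3], name='images')
with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope()):
    logits, end_points = mobilenet_v1.mobilenet_v1(
        images, num_classes=1001, is_training=False)

saver = tf.train.Saver(slim.get_model_variables('MobilenetV1'))
with tf.Session() as sess:
    saver.restore(sess, '/tmp/checkpoints/mobilenet_v1_1.0_224.ckpt')
    # end_points['Predictions'] now yields class probabilities for `images`.
```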
45 | More information on integrating MobileNets into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md).
46 |
47 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/).
48 |
--------------------------------------------------------------------------------
/src/slim/nets/mobilenet_v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZheC/tf-pose-estimation/b5a216a6ca51767a208c226a33b1a7f38cb04295/src/slim/nets/mobilenet_v1.png
--------------------------------------------------------------------------------
/src/slim/nets/nasnet/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow-Slim NASNet-A Implementation/Checkpoints
2 | This directory contains the code for the NASNet-A model from the paper
3 | [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012) by Zoph et al.
4 | In nasnet.py there are three different configurations of NASNet-A that are implemented. One of the models is the NASNet-A built for CIFAR-10 and the
5 | other two are variants of NASNet-A trained on ImageNet, which are listed below.
6 |
7 | # Pre-Trained Models
8 | Two NASNet-A checkpoints are available that have been trained on the
9 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/)
10 | image classification dataset. Accuracies were computed by evaluating using a single image crop.
11 |
12 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy |
13 | :----:|:------------:|:----------:|:-------:|:-------:|
14 | [NASNet-A_Mobile_224](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|564|5.3|74.0|91.6|
15 | [NASNet-A_Large_331](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|23800|88.9|82.7|96.2|
16 |
17 |
18 | Here is an example of how to download the NASNet-A_Mobile_224 checkpoint. The way to download the NASNet-A_Large_331 is the same.
19 |
20 | ```shell
21 | CHECKPOINT_DIR=/tmp/checkpoints
22 | mkdir ${CHECKPOINT_DIR}
23 | cd ${CHECKPOINT_DIR}
24 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz
25 | tar -xvf nasnet-a_mobile_04_10_2017.tar.gz
26 | rm nasnet-a_mobile_04_10_2017.tar.gz
27 | ```
28 | More information on integrating NASNet Models into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md).
29 |
30 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/).
31 |
32 | ## Sample Commands for using NASNet-A Mobile and Large Checkpoints for Inference
33 | -------
34 | Run eval with the NASNet-A mobile ImageNet model
35 |
36 | ```shell
37 | DATASET_DIR=/tmp/imagenet
38 | EVAL_DIR=/tmp/tfmodel/eval
39 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt
40 | python tensorflow_models/research/slim/eval_image_classifier \
41 | --checkpoint_path=${CHECKPOINT_DIR} \
42 | --eval_dir=${EVAL_DIR} \
43 | --dataset_dir=${DATASET_DIR} \
44 | --dataset_name=imagenet \
45 | --dataset_split_name=validation \
46 | --model_name=nasnet_mobile \
47 | --eval_image_size=224 \
48 | --moving_average_decay=0.9999
49 | ```
50 |
51 | Run eval with the NASNet-A large ImageNet model
52 |
53 | ```shell
54 | DATASET_DIR=/tmp/imagenet
55 | EVAL_DIR=/tmp/tfmodel/eval
56 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt
57 | python tensorflow_models/research/slim/eval_image_classifier \
58 | --checkpoint_path=${CHECKPOINT_DIR} \
59 | --eval_dir=${EVAL_DIR} \
60 | --dataset_dir=${DATASET_DIR} \
61 | --dataset_name=imagenet \
62 | --dataset_split_name=validation \
63 | --model_name=nasnet_large \
64 | --eval_image_size=331 \
65 | --moving_average_decay=0.9999
66 | ```
67 |
--------------------------------------------------------------------------------
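A minimal Python sketch of restoring the NASNet-A_Mobile_224 checkpoint downloaded above through `nets_factory` from this repository's `src/slim` tree. The `/tmp/checkpoints/model.ckpt` path comes from the shell example; `num_classes=1001` (1000 ImageNet classes plus a background class) is an assumption about the released slim checkpoints, not something stated in this README.

```python
import tensorflow as tf
from nets import nets_factory

# Build the NASNet-A mobile graph for inference.
network_fn = nets_factory.get_network_fn('nasnet_mobile', num_classes=1001,
                                         is_training=False)
size = network_fn.default_image_size  # 224 for nasnet_mobile
images = tf.placeholder(tf.float32, [None, size, size, 3])
logits, end_points = network_fn(images)

saver = tf.train.Saver()
with tf.Session() as sess:
    # Path taken from the download example above.
    saver.restore(sess, '/tmp/checkpoints/model.ckpt')
    # sess.run(logits, feed_dict={images: batch}) now yields class scores.
```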
/src/slim/nets/nasnet/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/slim/nets/nasnet/nasnet_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for slim.nets.nasnet.nasnet_utils."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from nets.nasnet import nasnet_utils
24 |
25 |
26 | class NasnetUtilsTest(tf.test.TestCase):
27 |
28 | def testCalcReductionLayers(self):
29 | num_cells = 18
30 | num_reduction_layers = 2
31 | reduction_layers = nasnet_utils.calc_reduction_layers(
32 | num_cells, num_reduction_layers)
33 | self.assertEqual(len(reduction_layers), 2)
34 | self.assertEqual(reduction_layers[0], 6)
35 | self.assertEqual(reduction_layers[1], 12)
36 |
37 | def testGetChannelIndex(self):
38 | data_formats = ['NHWC', 'NCHW']
39 | for data_format in data_formats:
40 | index = nasnet_utils.get_channel_index(data_format)
41 | correct_index = 3 if data_format == 'NHWC' else 1
42 | self.assertEqual(index, correct_index)
43 |
44 | def testGetChannelDim(self):
45 | data_formats = ['NHWC', 'NCHW']
46 | shape = [10, 20, 30, 40]
47 | for data_format in data_formats:
48 | dim = nasnet_utils.get_channel_dim(shape, data_format)
49 | correct_dim = shape[3] if data_format == 'NHWC' else shape[1]
50 | self.assertEqual(dim, correct_dim)
51 |
52 | def testGlobalAvgPool(self):
53 | data_formats = ['NHWC', 'NCHW']
54 | inputs = tf.placeholder(tf.float32, (5, 10, 20, 10))
55 | for data_format in data_formats:
56 | output = nasnet_utils.global_avg_pool(
57 | inputs, data_format)
58 | self.assertEqual(output.shape, [5, 10])
59 |
60 |
61 | if __name__ == '__main__':
62 | tf.test.main()
63 |
--------------------------------------------------------------------------------
/src/slim/nets/nets_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a factory for building various models."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import functools
21 |
22 | import tensorflow as tf
23 |
24 | from nets import alexnet
25 | from nets import cifarnet
26 | from nets import inception
27 | from nets import lenet
28 | from nets import mobilenet_v1
29 | from nets import overfeat
30 | from nets import resnet_v1
31 | from nets import resnet_v2
32 | from nets import vgg
33 | from nets.nasnet import nasnet
34 |
35 | slim = tf.contrib.slim
36 |
37 | networks_map = {'alexnet_v2': alexnet.alexnet_v2,
38 | 'cifarnet': cifarnet.cifarnet,
39 | 'overfeat': overfeat.overfeat,
40 | 'vgg_a': vgg.vgg_a,
41 | 'vgg_16': vgg.vgg_16,
42 | 'vgg_19': vgg.vgg_19,
43 | 'inception_v1': inception.inception_v1,
44 | 'inception_v2': inception.inception_v2,
45 | 'inception_v3': inception.inception_v3,
46 | 'inception_v4': inception.inception_v4,
47 | 'inception_resnet_v2': inception.inception_resnet_v2,
48 | 'lenet': lenet.lenet,
49 | 'resnet_v1_50': resnet_v1.resnet_v1_50,
50 | 'resnet_v1_101': resnet_v1.resnet_v1_101,
51 | 'resnet_v1_152': resnet_v1.resnet_v1_152,
52 | 'resnet_v1_200': resnet_v1.resnet_v1_200,
53 | 'resnet_v2_50': resnet_v2.resnet_v2_50,
54 | 'resnet_v2_101': resnet_v2.resnet_v2_101,
55 | 'resnet_v2_152': resnet_v2.resnet_v2_152,
56 | 'resnet_v2_200': resnet_v2.resnet_v2_200,
57 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1,
58 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_075,
59 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_050,
60 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_025,
61 | 'nasnet_cifar': nasnet.build_nasnet_cifar,
62 | 'nasnet_mobile': nasnet.build_nasnet_mobile,
63 | 'nasnet_large': nasnet.build_nasnet_large,
64 | }
65 |
66 | arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope,
67 | 'cifarnet': cifarnet.cifarnet_arg_scope,
68 | 'overfeat': overfeat.overfeat_arg_scope,
69 | 'vgg_a': vgg.vgg_arg_scope,
70 | 'vgg_16': vgg.vgg_arg_scope,
71 | 'vgg_19': vgg.vgg_arg_scope,
72 | 'inception_v1': inception.inception_v3_arg_scope,
73 | 'inception_v2': inception.inception_v3_arg_scope,
74 | 'inception_v3': inception.inception_v3_arg_scope,
75 | 'inception_v4': inception.inception_v4_arg_scope,
76 | 'inception_resnet_v2':
77 | inception.inception_resnet_v2_arg_scope,
78 | 'lenet': lenet.lenet_arg_scope,
79 | 'resnet_v1_50': resnet_v1.resnet_arg_scope,
80 | 'resnet_v1_101': resnet_v1.resnet_arg_scope,
81 | 'resnet_v1_152': resnet_v1.resnet_arg_scope,
82 | 'resnet_v1_200': resnet_v1.resnet_arg_scope,
83 | 'resnet_v2_50': resnet_v2.resnet_arg_scope,
84 | 'resnet_v2_101': resnet_v2.resnet_arg_scope,
85 | 'resnet_v2_152': resnet_v2.resnet_arg_scope,
86 | 'resnet_v2_200': resnet_v2.resnet_arg_scope,
87 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope,
88 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_arg_scope,
89 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_arg_scope,
90 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_arg_scope,
91 | 'nasnet_cifar': nasnet.nasnet_cifar_arg_scope,
92 | 'nasnet_mobile': nasnet.nasnet_mobile_arg_scope,
93 | 'nasnet_large': nasnet.nasnet_large_arg_scope,
94 | }
95 |
96 |
97 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
98 | """Returns a network_fn such as `logits, end_points = network_fn(images)`.
99 |
100 | Args:
101 | name: The name of the network.
102 | num_classes: The number of classes to use for classification. If 0 or None,
103 | the logits layer is omitted and its input features are returned instead.
104 | weight_decay: The l2 coefficient for the model weights.
105 | is_training: `True` if the model is being used for training and `False`
106 | otherwise.
107 |
108 | Returns:
109 | network_fn: A function that applies the model to a batch of images. It has
110 | the following signature:
111 | net, end_points = network_fn(images)
112 | The `images` input is a tensor of shape [batch_size, height, width, 3]
113 | with height = width = network_fn.default_image_size. (The permissibility
114 | and treatment of other sizes depends on the network_fn.)
115 | The returned `end_points` are a dictionary of intermediate activations.
116 | The returned `net` is the topmost layer, depending on `num_classes`:
117 | If `num_classes` was a non-zero integer, `net` is a logits tensor
118 | of shape [batch_size, num_classes].
119 | If `num_classes` was 0 or `None`, `net` is a tensor with the input
120 | to the logits layer of shape [batch_size, 1, 1, num_features] or
121 | [batch_size, num_features]. Dropout has not been applied to this
122 |       (even if the network's original classification head does); it remains
123 |       for the caller to do this or not.
124 |
125 | Raises:
126 | ValueError: If network `name` is not recognized.
127 | """
128 | if name not in networks_map:
129 | raise ValueError('Name of network unknown %s' % name)
130 | func = networks_map[name]
131 | @functools.wraps(func)
132 | def network_fn(images, **kwargs):
133 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
134 | with slim.arg_scope(arg_scope):
135 | return func(images, num_classes, is_training=is_training, **kwargs)
136 | if hasattr(func, 'default_image_size'):
137 | network_fn.default_image_size = func.default_image_size
138 |
139 | return network_fn
140 |
--------------------------------------------------------------------------------
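A short usage sketch of the documented signature `net, end_points = network_fn(images)`; the choice of `mobilenet_v1` and the batch size of 4 are illustrative, not prescribed by this file.

```python
import tensorflow as tf
from nets import nets_factory

network_fn = nets_factory.get_network_fn('mobilenet_v1', num_classes=1000,
                                         weight_decay=0.0, is_training=False)
size = network_fn.default_image_size            # 224 for mobilenet_v1
images = tf.random_uniform([4, size, size, 3])  # stand-in for a real batch
logits, end_points = network_fn(images)

print(logits.get_shape())             # (4, 1000) logits tensor
print(sorted(end_points.keys())[:3])  # a few intermediate activation names
```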
/src/slim/nets/nets_factory_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for slim.inception."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 |
23 | import tensorflow as tf
24 |
25 | from nets import nets_factory
26 |
27 |
28 | class NetworksTest(tf.test.TestCase):
29 |
30 | def testGetNetworkFnFirstHalf(self):
31 | batch_size = 5
32 | num_classes = 1000
33 |     for net in list(nets_factory.networks_map.keys())[:10]:
34 | with tf.Graph().as_default() as g, self.test_session(g):
35 | net_fn = nets_factory.get_network_fn(net, num_classes)
36 | # Most networks use 224 as their default_image_size
37 | image_size = getattr(net_fn, 'default_image_size', 224)
38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
39 | logits, end_points = net_fn(inputs)
40 | self.assertTrue(isinstance(logits, tf.Tensor))
41 | self.assertTrue(isinstance(end_points, dict))
42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size)
43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
44 |
45 | def testGetNetworkFnSecondHalf(self):
46 | batch_size = 5
47 | num_classes = 1000
48 |     for net in list(nets_factory.networks_map.keys())[10:]:
49 | with tf.Graph().as_default() as g, self.test_session(g):
50 | net_fn = nets_factory.get_network_fn(net, num_classes)
51 | # Most networks use 224 as their default_image_size
52 | image_size = getattr(net_fn, 'default_image_size', 224)
53 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
54 | logits, end_points = net_fn(inputs)
55 | self.assertTrue(isinstance(logits, tf.Tensor))
56 | self.assertTrue(isinstance(end_points, dict))
57 | self.assertEqual(logits.get_shape().as_list()[0], batch_size)
58 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
59 |
60 | if __name__ == '__main__':
61 | tf.test.main()
62 |
--------------------------------------------------------------------------------
/src/slim/nets/overfeat.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains the model definition for the OverFeat network.
16 |
17 | The definition for the network was obtained from:
18 | OverFeat: Integrated Recognition, Localization and Detection using
19 | Convolutional Networks
20 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and
21 | Yann LeCun, 2014
22 | http://arxiv.org/abs/1312.6229
23 |
24 | Usage:
25 | with slim.arg_scope(overfeat.overfeat_arg_scope()):
26 | outputs, end_points = overfeat.overfeat(inputs)
27 |
28 | @@overfeat
29 | """
30 | from __future__ import absolute_import
31 | from __future__ import division
32 | from __future__ import print_function
33 |
34 | import tensorflow as tf
35 |
36 | slim = tf.contrib.slim
37 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
38 |
39 |
40 | def overfeat_arg_scope(weight_decay=0.0005):
41 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
42 | activation_fn=tf.nn.relu,
43 | weights_regularizer=slim.l2_regularizer(weight_decay),
44 | biases_initializer=tf.zeros_initializer()):
45 | with slim.arg_scope([slim.conv2d], padding='SAME'):
46 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
47 | return arg_sc
48 |
49 |
50 | def overfeat(inputs,
51 | num_classes=1000,
52 | is_training=True,
53 | dropout_keep_prob=0.5,
54 | spatial_squeeze=True,
55 | scope='overfeat',
56 | global_pool=False):
57 | """Contains the model definition for the OverFeat network.
58 |
59 | The definition for the network was obtained from:
60 | OverFeat: Integrated Recognition, Localization and Detection using
61 | Convolutional Networks
62 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and
63 | Yann LeCun, 2014
64 | http://arxiv.org/abs/1312.6229
65 |
66 | Note: All the fully_connected layers have been transformed to conv2d layers.
67 | To use in classification mode, resize input to 231x231. To use in fully
68 | convolutional mode, set spatial_squeeze to false.
69 |
70 | Args:
71 | inputs: a tensor of size [batch_size, height, width, channels].
72 | num_classes: number of predicted classes. If 0 or None, the logits layer is
73 | omitted and the input features to the logits layer are returned instead.
74 | is_training: whether or not the model is being trained.
75 | dropout_keep_prob: the probability that activations are kept in the dropout
76 | layers during training.
77 |     spatial_squeeze: whether or not to squeeze the spatial dimensions of the
78 | outputs. Useful to remove unnecessary dimensions for classification.
79 | scope: Optional scope for the variables.
80 | global_pool: Optional boolean flag. If True, the input to the classification
81 | layer is avgpooled to size 1x1, for any input size. (This is not part
82 | of the original OverFeat.)
83 |
84 | Returns:
85 | net: the output of the logits layer (if num_classes is a non-zero integer),
86 | or the non-dropped-out input to the logits layer (if num_classes is 0 or
87 | None).
88 | end_points: a dict of tensors with intermediate activations.
89 | """
90 | with tf.variable_scope(scope, 'overfeat', [inputs]) as sc:
91 | end_points_collection = sc.original_name_scope + '_end_points'
92 | # Collect outputs for conv2d, fully_connected and max_pool2d
93 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
94 | outputs_collections=end_points_collection):
95 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
96 | scope='conv1')
97 | net = slim.max_pool2d(net, [2, 2], scope='pool1')
98 | net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2')
99 | net = slim.max_pool2d(net, [2, 2], scope='pool2')
100 | net = slim.conv2d(net, 512, [3, 3], scope='conv3')
101 | net = slim.conv2d(net, 1024, [3, 3], scope='conv4')
102 | net = slim.conv2d(net, 1024, [3, 3], scope='conv5')
103 | net = slim.max_pool2d(net, [2, 2], scope='pool5')
104 |
105 | # Use conv2d instead of fully_connected layers.
106 | with slim.arg_scope([slim.conv2d],
107 | weights_initializer=trunc_normal(0.005),
108 | biases_initializer=tf.constant_initializer(0.1)):
109 | net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6')
110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
111 | scope='dropout6')
112 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
113 | # Convert end_points_collection into a end_point dict.
114 | end_points = slim.utils.convert_collection_to_dict(
115 | end_points_collection)
116 | if global_pool:
117 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
118 | end_points['global_pool'] = net
119 | if num_classes:
120 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
121 | scope='dropout7')
122 | net = slim.conv2d(net, num_classes, [1, 1],
123 | activation_fn=None,
124 | normalizer_fn=None,
125 | biases_initializer=tf.zeros_initializer(),
126 | scope='fc8')
127 | if spatial_squeeze:
128 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
129 | end_points[sc.name + '/fc8'] = net
130 | return net, end_points
131 | overfeat.default_image_size = 231
132 |
--------------------------------------------------------------------------------
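A sketch of the two usage modes named in the docstring above. The output shapes follow from the network definition (and agree with the tests below), but the snippet itself is an illustration, not part of the repository.

```python
import tensorflow as tf
from nets import overfeat

slim = tf.contrib.slim

# Classification mode: 231x231 inputs, spatial dimensions squeezed away.
with tf.Graph().as_default():
    images = tf.random_uniform([1, 231, 231, 3])
    with slim.arg_scope(overfeat.overfeat_arg_scope()):
        logits, _ = overfeat.overfeat(images, num_classes=1000)
    # logits has shape [1, 1000]

# Fully-convolutional mode: larger inputs, keep the spatial grid of logits.
with tf.Graph().as_default():
    big = tf.random_uniform([1, 281, 281, 3])
    with slim.arg_scope(overfeat.overfeat_arg_scope()):
        fmap, _ = overfeat.overfeat(big, num_classes=1000,
                                    spatial_squeeze=False)
    # fmap has shape [1, 2, 2, 1000]
```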
/src/slim/nets/overfeat_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for slim.nets.overfeat."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import tensorflow as tf
21 |
22 | from nets import overfeat
23 |
24 | slim = tf.contrib.slim
25 |
26 |
27 | class OverFeatTest(tf.test.TestCase):
28 |
29 | def testBuild(self):
30 | batch_size = 5
31 | height, width = 231, 231
32 | num_classes = 1000
33 | with self.test_session():
34 | inputs = tf.random_uniform((batch_size, height, width, 3))
35 | logits, _ = overfeat.overfeat(inputs, num_classes)
36 | self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed')
37 | self.assertListEqual(logits.get_shape().as_list(),
38 | [batch_size, num_classes])
39 |
40 | def testFullyConvolutional(self):
41 | batch_size = 1
42 | height, width = 281, 281
43 | num_classes = 1000
44 | with self.test_session():
45 | inputs = tf.random_uniform((batch_size, height, width, 3))
46 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False)
47 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd')
48 | self.assertListEqual(logits.get_shape().as_list(),
49 | [batch_size, 2, 2, num_classes])
50 |
51 | def testGlobalPool(self):
52 | batch_size = 1
53 | height, width = 281, 281
54 | num_classes = 1000
55 | with self.test_session():
56 | inputs = tf.random_uniform((batch_size, height, width, 3))
57 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False,
58 | global_pool=True)
59 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd')
60 | self.assertListEqual(logits.get_shape().as_list(),
61 | [batch_size, 1, 1, num_classes])
62 |
63 | def testEndPoints(self):
64 | batch_size = 5
65 | height, width = 231, 231
66 | num_classes = 1000
67 | with self.test_session():
68 | inputs = tf.random_uniform((batch_size, height, width, 3))
69 | _, end_points = overfeat.overfeat(inputs, num_classes)
70 | expected_names = ['overfeat/conv1',
71 | 'overfeat/pool1',
72 | 'overfeat/conv2',
73 | 'overfeat/pool2',
74 | 'overfeat/conv3',
75 | 'overfeat/conv4',
76 | 'overfeat/conv5',
77 | 'overfeat/pool5',
78 | 'overfeat/fc6',
79 | 'overfeat/fc7',
80 | 'overfeat/fc8'
81 | ]
82 | self.assertSetEqual(set(end_points.keys()), set(expected_names))
83 |
84 | def testNoClasses(self):
85 | batch_size = 5
86 | height, width = 231, 231
87 | num_classes = None
88 | with self.test_session():
89 | inputs = tf.random_uniform((batch_size, height, width, 3))
90 | net, end_points = overfeat.overfeat(inputs, num_classes)
91 | expected_names = ['overfeat/conv1',
92 | 'overfeat/pool1',
93 | 'overfeat/conv2',
94 | 'overfeat/pool2',
95 | 'overfeat/conv3',
96 | 'overfeat/conv4',
97 | 'overfeat/conv5',
98 | 'overfeat/pool5',
99 | 'overfeat/fc6',
100 | 'overfeat/fc7'
101 | ]
102 | self.assertSetEqual(set(end_points.keys()), set(expected_names))
103 | self.assertTrue(net.op.name.startswith('overfeat/fc7'))
104 |
105 | def testModelVariables(self):
106 | batch_size = 5
107 | height, width = 231, 231
108 | num_classes = 1000
109 | with self.test_session():
110 | inputs = tf.random_uniform((batch_size, height, width, 3))
111 | overfeat.overfeat(inputs, num_classes)
112 | expected_names = ['overfeat/conv1/weights',
113 | 'overfeat/conv1/biases',
114 | 'overfeat/conv2/weights',
115 | 'overfeat/conv2/biases',
116 | 'overfeat/conv3/weights',
117 | 'overfeat/conv3/biases',
118 | 'overfeat/conv4/weights',
119 | 'overfeat/conv4/biases',
120 | 'overfeat/conv5/weights',
121 | 'overfeat/conv5/biases',
122 | 'overfeat/fc6/weights',
123 | 'overfeat/fc6/biases',
124 | 'overfeat/fc7/weights',
125 | 'overfeat/fc7/biases',
126 | 'overfeat/fc8/weights',
127 | 'overfeat/fc8/biases',
128 | ]
129 | model_variables = [v.op.name for v in slim.get_model_variables()]
130 | self.assertSetEqual(set(model_variables), set(expected_names))
131 |
132 | def testEvaluation(self):
133 | batch_size = 2
134 | height, width = 231, 231
135 | num_classes = 1000
136 | with self.test_session():
137 | eval_inputs = tf.random_uniform((batch_size, height, width, 3))
138 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False)
139 | self.assertListEqual(logits.get_shape().as_list(),
140 | [batch_size, num_classes])
141 | predictions = tf.argmax(logits, 1)
142 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
143 |
144 | def testTrainEvalWithReuse(self):
145 | train_batch_size = 2
146 | eval_batch_size = 1
147 | train_height, train_width = 231, 231
148 | eval_height, eval_width = 281, 281
149 | num_classes = 1000
150 | with self.test_session():
151 | train_inputs = tf.random_uniform(
152 | (train_batch_size, train_height, train_width, 3))
153 | logits, _ = overfeat.overfeat(train_inputs)
154 | self.assertListEqual(logits.get_shape().as_list(),
155 | [train_batch_size, num_classes])
156 | tf.get_variable_scope().reuse_variables()
157 | eval_inputs = tf.random_uniform(
158 | (eval_batch_size, eval_height, eval_width, 3))
159 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False,
160 | spatial_squeeze=False)
161 | self.assertListEqual(logits.get_shape().as_list(),
162 | [eval_batch_size, 2, 2, num_classes])
163 | logits = tf.reduce_mean(logits, [1, 2])
164 | predictions = tf.argmax(logits, 1)
165 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
166 |
167 | def testForward(self):
168 | batch_size = 1
169 | height, width = 231, 231
170 | with self.test_session() as sess:
171 | inputs = tf.random_uniform((batch_size, height, width, 3))
172 | logits, _ = overfeat.overfeat(inputs)
173 | sess.run(tf.global_variables_initializer())
174 | output = sess.run(logits)
175 | self.assertTrue(output.any())
176 |
177 | if __name__ == '__main__':
178 | tf.test.main()
179 |
--------------------------------------------------------------------------------
/src/slim/nets/pix2pix_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # =============================================================================
15 | """Tests for pix2pix."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 | from nets import pix2pix
23 |
24 |
25 | class GeneratorTest(tf.test.TestCase):
26 |
27 | def test_nonsquare_inputs_raise_exception(self):
28 | batch_size = 2
29 | height, width = 240, 320
30 | num_outputs = 4
31 |
32 | images = tf.ones((batch_size, height, width, 3))
33 |
34 | with self.assertRaises(ValueError):
35 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
36 | pix2pix.pix2pix_generator(
37 | images, num_outputs, upsample_method='nn_upsample_conv')
38 |
39 | def _reduced_default_blocks(self):
40 | """Returns the default blocks, scaled down to make test run faster."""
41 | return [pix2pix.Block(b.num_filters // 32, b.decoder_keep_prob)
42 | for b in pix2pix._default_generator_blocks()]
43 |
44 | def test_output_size_nn_upsample_conv(self):
45 | batch_size = 2
46 | height, width = 256, 256
47 | num_outputs = 4
48 |
49 | images = tf.ones((batch_size, height, width, 3))
50 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
51 | logits, _ = pix2pix.pix2pix_generator(
52 | images, num_outputs, blocks=self._reduced_default_blocks(),
53 | upsample_method='nn_upsample_conv')
54 |
55 | with self.test_session() as session:
56 | session.run(tf.global_variables_initializer())
57 | np_outputs = session.run(logits)
58 | self.assertListEqual([batch_size, height, width, num_outputs],
59 | list(np_outputs.shape))
60 |
61 | def test_output_size_conv2d_transpose(self):
62 | batch_size = 2
63 | height, width = 256, 256
64 | num_outputs = 4
65 |
66 | images = tf.ones((batch_size, height, width, 3))
67 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
68 | logits, _ = pix2pix.pix2pix_generator(
69 | images, num_outputs, blocks=self._reduced_default_blocks(),
70 | upsample_method='conv2d_transpose')
71 |
72 | with self.test_session() as session:
73 | session.run(tf.global_variables_initializer())
74 | np_outputs = session.run(logits)
75 | self.assertListEqual([batch_size, height, width, num_outputs],
76 | list(np_outputs.shape))
77 |
78 | def test_block_number_dictates_number_of_layers(self):
79 | batch_size = 2
80 | height, width = 256, 256
81 | num_outputs = 4
82 |
83 | images = tf.ones((batch_size, height, width, 3))
84 | blocks = [
85 | pix2pix.Block(64, 0.5),
86 | pix2pix.Block(128, 0),
87 | ]
88 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
89 | _, end_points = pix2pix.pix2pix_generator(
90 | images, num_outputs, blocks)
91 |
92 | num_encoder_layers = 0
93 | num_decoder_layers = 0
94 | for end_point in end_points:
95 | if end_point.startswith('encoder'):
96 | num_encoder_layers += 1
97 | elif end_point.startswith('decoder'):
98 | num_decoder_layers += 1
99 |
100 | self.assertEqual(num_encoder_layers, len(blocks))
101 | self.assertEqual(num_decoder_layers, len(blocks))
102 |
103 |
104 | class DiscriminatorTest(tf.test.TestCase):
105 |
106 | def _layer_output_size(self, input_size, kernel_size=4, stride=2, pad=2):
107 | return (input_size + pad * 2 - kernel_size) // stride + 1
108 |
109 | def test_four_layers(self):
110 | batch_size = 2
111 | input_size = 256
112 |
113 | output_size = self._layer_output_size(input_size)
114 | output_size = self._layer_output_size(output_size)
115 | output_size = self._layer_output_size(output_size)
116 | output_size = self._layer_output_size(output_size, stride=1)
117 | output_size = self._layer_output_size(output_size, stride=1)
118 |
119 | images = tf.ones((batch_size, input_size, input_size, 3))
120 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
121 | logits, end_points = pix2pix.pix2pix_discriminator(
122 | images, num_filters=[64, 128, 256, 512])
123 | self.assertListEqual([batch_size, output_size, output_size, 1],
124 | logits.shape.as_list())
125 | self.assertListEqual([batch_size, output_size, output_size, 1],
126 | end_points['predictions'].shape.as_list())
127 |
128 | def test_four_layers_no_padding(self):
129 | batch_size = 2
130 | input_size = 256
131 |
132 | output_size = self._layer_output_size(input_size, pad=0)
133 | output_size = self._layer_output_size(output_size, pad=0)
134 | output_size = self._layer_output_size(output_size, pad=0)
135 | output_size = self._layer_output_size(output_size, stride=1, pad=0)
136 | output_size = self._layer_output_size(output_size, stride=1, pad=0)
137 |
138 | images = tf.ones((batch_size, input_size, input_size, 3))
139 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
140 | logits, end_points = pix2pix.pix2pix_discriminator(
141 | images, num_filters=[64, 128, 256, 512], padding=0)
142 | self.assertListEqual([batch_size, output_size, output_size, 1],
143 | logits.shape.as_list())
144 | self.assertListEqual([batch_size, output_size, output_size, 1],
145 | end_points['predictions'].shape.as_list())
146 |
147 |   def test_four_layers_wrong_padding(self):
148 | batch_size = 2
149 | input_size = 256
150 |
151 | images = tf.ones((batch_size, input_size, input_size, 3))
152 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
153 | with self.assertRaises(TypeError):
154 | pix2pix.pix2pix_discriminator(
155 | images, num_filters=[64, 128, 256, 512], padding=1.5)
156 |
157 | def test_four_layers_negative_padding(self):
158 | batch_size = 2
159 | input_size = 256
160 |
161 | images = tf.ones((batch_size, input_size, input_size, 3))
162 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
163 | with self.assertRaises(ValueError):
164 | pix2pix.pix2pix_discriminator(
165 | images, num_filters=[64, 128, 256, 512], padding=-1)
166 |
167 | if __name__ == '__main__':
168 | tf.test.main()
169 |
--------------------------------------------------------------------------------
/src/slim/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/slim/preprocessing/cifarnet_preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Provides utilities to preprocess images in CIFAR-10.
16 |
17 | """
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow as tf
24 |
25 | _PADDING = 4
26 |
27 | slim = tf.contrib.slim
28 |
29 |
30 | def preprocess_for_train(image,
31 | output_height,
32 | output_width,
33 | padding=_PADDING,
34 | add_image_summaries=True):
35 | """Preprocesses the given image for training.
36 |
37 |   The image is padded, randomly cropped to the output size, randomly flipped
38 |   horizontally, and has its brightness and contrast randomly distorted.
39 |
40 | Args:
41 | image: A `Tensor` representing an image of arbitrary size.
42 | output_height: The height of the image after preprocessing.
43 | output_width: The width of the image after preprocessing.
44 |     padding: The amount of padding before and after each dimension of the image.
45 | add_image_summaries: Enable image summaries.
46 |
47 | Returns:
48 | A preprocessed image.
49 | """
50 | if add_image_summaries:
51 | tf.summary.image('image', tf.expand_dims(image, 0))
52 |
53 | # Transform the image to floats.
54 | image = tf.to_float(image)
55 | if padding > 0:
56 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]])
57 | # Randomly crop a [height, width] section of the image.
58 | distorted_image = tf.random_crop(image,
59 | [output_height, output_width, 3])
60 |
61 | # Randomly flip the image horizontally.
62 | distorted_image = tf.image.random_flip_left_right(distorted_image)
63 |
64 | if add_image_summaries:
65 | tf.summary.image('distorted_image', tf.expand_dims(distorted_image, 0))
66 |
67 | # Because these operations are not commutative, consider randomizing
68 |   # the order of their operation.
69 | distorted_image = tf.image.random_brightness(distorted_image,
70 | max_delta=63)
71 | distorted_image = tf.image.random_contrast(distorted_image,
72 | lower=0.2, upper=1.8)
73 | # Subtract off the mean and divide by the variance of the pixels.
74 | return tf.image.per_image_standardization(distorted_image)
75 |
76 |
77 | def preprocess_for_eval(image, output_height, output_width,
78 | add_image_summaries=True):
79 | """Preprocesses the given image for evaluation.
80 |
81 | Args:
82 | image: A `Tensor` representing an image of arbitrary size.
83 | output_height: The height of the image after preprocessing.
84 | output_width: The width of the image after preprocessing.
85 | add_image_summaries: Enable image summaries.
86 |
87 | Returns:
88 | A preprocessed image.
89 | """
90 | if add_image_summaries:
91 | tf.summary.image('image', tf.expand_dims(image, 0))
92 | # Transform the image to floats.
93 | image = tf.to_float(image)
94 |
95 | # Resize and crop if needed.
96 | resized_image = tf.image.resize_image_with_crop_or_pad(image,
97 |                                                           output_height,
98 |                                                           output_width)
99 | if add_image_summaries:
100 | tf.summary.image('resized_image', tf.expand_dims(resized_image, 0))
101 |
102 | # Subtract off the mean and divide by the variance of the pixels.
103 | return tf.image.per_image_standardization(resized_image)
104 |
105 |
106 | def preprocess_image(image, output_height, output_width, is_training=False,
107 | add_image_summaries=True):
108 | """Preprocesses the given image.
109 |
110 | Args:
111 | image: A `Tensor` representing an image of arbitrary size.
112 | output_height: The height of the image after preprocessing.
113 | output_width: The width of the image after preprocessing.
114 | is_training: `True` if we're preprocessing the image for training and
115 | `False` otherwise.
116 | add_image_summaries: Enable image summaries.
117 |
118 | Returns:
119 | A preprocessed image.
120 | """
121 | if is_training:
122 | return preprocess_for_train(
123 | image, output_height, output_width,
124 | add_image_summaries=add_image_summaries)
125 | else:
126 | return preprocess_for_eval(
127 | image, output_height, output_width,
128 | add_image_summaries=add_image_summaries)
129 |
--------------------------------------------------------------------------------
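A quick sketch of the two code paths above for a single image; the 32x32 output size is an assumption (CIFAR-10's native resolution).

```python
import tensorflow as tf
from preprocessing import cifarnet_preprocessing

# Stand-in for a decoded 32x32 CIFAR-10 image.
image = tf.cast(tf.random_uniform([32, 32, 3], maxval=256, dtype=tf.int32),
                tf.uint8)

# Training path: pad by 4 px, random 32x32 crop, random flip,
# brightness/contrast jitter, per-image standardization.
train_img = cifarnet_preprocessing.preprocess_image(
    image, 32, 32, is_training=True, add_image_summaries=False)

# Evaluation path: deterministic center crop/pad, per-image standardization.
eval_img = cifarnet_preprocessing.preprocess_image(
    image, 32, 32, is_training=False, add_image_summaries=False)
```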
/src/slim/preprocessing/lenet_preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Provides utilities for preprocessing."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | slim = tf.contrib.slim
24 |
25 |
26 | def preprocess_image(image, output_height, output_width, is_training):
27 | """Preprocesses the given image.
28 |
29 | Args:
30 | image: A `Tensor` representing an image of arbitrary size.
31 | output_height: The height of the image after preprocessing.
32 | output_width: The width of the image after preprocessing.
33 | is_training: `True` if we're preprocessing the image for training and
34 | `False` otherwise.
35 |
36 | Returns:
37 | A preprocessed image.
38 | """
39 | image = tf.to_float(image)
40 | image = tf.image.resize_image_with_crop_or_pad(
41 |       image, output_height, output_width)
42 | image = tf.subtract(image, 128.0)
43 | image = tf.div(image, 128.0)
44 | return image
45 |
--------------------------------------------------------------------------------
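The function above crops or pads to the output size and maps pixel values from [0, 255] to roughly [-1, 1) via (x - 128) / 128. A tiny sketch, with a 28x28 output size assumed (LeNet/MNIST):

```python
import tensorflow as tf
from preprocessing import lenet_preprocessing

# Stand-in for a decoded grayscale image.
raw = tf.cast(tf.random_uniform([32, 32, 1], maxval=256, dtype=tf.int32),
              tf.uint8)
img = lenet_preprocessing.preprocess_image(raw, 28, 28, is_training=False)
# img is float32 with shape [28, 28, 1] and values in [-1.0, 127/128].
```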
/src/slim/preprocessing/preprocessing_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a factory for building various models."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from preprocessing import cifarnet_preprocessing
24 | from preprocessing import inception_preprocessing
25 | from preprocessing import lenet_preprocessing
26 | from preprocessing import vgg_preprocessing
27 |
28 | slim = tf.contrib.slim
29 |
30 |
31 | def get_preprocessing(name, is_training=False):
32 | """Returns preprocessing_fn(image, height, width, **kwargs).
33 |
34 | Args:
35 | name: The name of the preprocessing function.
36 | is_training: `True` if the model is being used for training and `False`
37 | otherwise.
38 |
39 | Returns:
40 |     preprocessing_fn: A function that preprocesses a single image (pre-batch).
41 | It has the following signature:
42 | image = preprocessing_fn(image, output_height, output_width, ...).
43 |
44 | Raises:
45 | ValueError: If Preprocessing `name` is not recognized.
46 | """
47 | preprocessing_fn_map = {
48 | 'cifarnet': cifarnet_preprocessing,
49 | 'inception': inception_preprocessing,
50 | 'inception_v1': inception_preprocessing,
51 | 'inception_v2': inception_preprocessing,
52 | 'inception_v3': inception_preprocessing,
53 | 'inception_v4': inception_preprocessing,
54 | 'inception_resnet_v2': inception_preprocessing,
55 | 'lenet': lenet_preprocessing,
56 | 'mobilenet_v1': inception_preprocessing,
57 | 'nasnet_mobile': inception_preprocessing,
58 | 'nasnet_large': inception_preprocessing,
59 | 'resnet_v1_50': vgg_preprocessing,
60 | 'resnet_v1_101': vgg_preprocessing,
61 | 'resnet_v1_152': vgg_preprocessing,
62 | 'resnet_v1_200': vgg_preprocessing,
63 | 'resnet_v2_50': vgg_preprocessing,
64 | 'resnet_v2_101': vgg_preprocessing,
65 | 'resnet_v2_152': vgg_preprocessing,
66 | 'resnet_v2_200': vgg_preprocessing,
67 | 'vgg': vgg_preprocessing,
68 | 'vgg_a': vgg_preprocessing,
69 | 'vgg_16': vgg_preprocessing,
70 | 'vgg_19': vgg_preprocessing,
71 | }
72 |
73 | if name not in preprocessing_fn_map:
74 | raise ValueError('Preprocessing name [%s] was not recognized' % name)
75 |
76 | def preprocessing_fn(image, output_height, output_width, **kwargs):
77 | return preprocessing_fn_map[name].preprocess_image(
78 | image, output_height, output_width, is_training=is_training, **kwargs)
79 |
80 | return preprocessing_fn
81 |
--------------------------------------------------------------------------------
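A short sketch of wiring the factory above to a model name; `mobilenet_v1` and the 224x224 output size are illustrative choices (in practice the size would come from the corresponding network_fn).

```python
import tensorflow as tf
from preprocessing import preprocessing_factory

preprocess_fn = preprocessing_factory.get_preprocessing('mobilenet_v1',
                                                        is_training=False)

# Stand-in for a decoded JPEG; uint8 in [0, 255].
raw = tf.cast(tf.random_uniform([300, 400, 3], maxval=256, dtype=tf.int32),
              tf.uint8)
image = preprocess_fn(raw, 224, 224)  # inception-style preprocessing, ~[-1, 1]
batch = tf.expand_dims(image, 0)      # [1, 224, 224, 3], ready for a network_fn
```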