├── .gitattributes
├── .gitignore
├── README.md
├── Sim3DR
│   ├── .gitignore
│   ├── Sim3DR.py
│   ├── __init__.py
│   ├── _init_paths.py
│   ├── build_sim3dr.sh
│   ├── lib
│   │   ├── rasterize.h
│   │   ├── rasterize.pyx
│   │   └── rasterize_kernel.cpp
│   ├── lighting.py
│   ├── readme.md
│   ├── setup.py
│   └── tests
│       ├── .gitignore
│       ├── CMakeLists.txt
│       ├── io.cpp
│       ├── io.h
│       └── test.cpp
├── config.py
├── convert_json_list_to_lmdb.py
├── data_loader_lmdb.py
├── data_loader_lmdb_augmenter.py
├── early_stop.py
├── evaluation
│   ├── evaluate_wider.py
│   └── jupyter_notebooks
│       ├── AFLW2000_annotations.txt
│       ├── BIWI_annotations.txt
│       ├── aflw_2000_3d_evaluation.ipynb
│       ├── biwi_evaluation.ipynb
│       ├── test_own_images.ipynb
│       └── visualize_trained_model_predictions.ipynb
├── generalized_rcnn.py
├── img2pose.py
├── license.md
├── losses.py
├── model_loader.py
├── models.py
├── pose_references
│   ├── reference_3d_5_points_trans.npy
│   ├── reference_3d_68_points_trans.npy
│   ├── triangles.npy
│   └── vertices_trans.npy
├── requirements.txt
├── rpn.py
├── run_face_alignment.py
├── teaser.jpeg
├── train.py
├── train_logger.py
└── utils
    ├── annotate_dataset.py
    ├── augmentation.py
    ├── dist.py
    ├── face_align.py
    ├── image_operations.py
    ├── json_loader.py
    ├── json_loader_300wlp.py
    ├── pose_operations.py
    └── renderer.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-documentation
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | parts/
18 | sdist/
19 | var/
20 | wheels/
21 | pip-wheel-metadata/
22 | share/python-wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .nox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | *.py,cover
49 | .hypothesis/
50 | .pytest_cache/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 | db.sqlite3-journal
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # IPython
79 | profile_default/
80 | ipython_config.py
81 |
82 | # pyenv
83 | .python-version
84 |
85 | # pipenv
86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
88 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
89 | # install all needed dependencies.
90 | #Pipfile.lock
91 |
92 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
93 | __pypackages__/
94 |
95 | # Celery stuff
96 | celerybeat-schedule
97 | celerybeat.pid
98 |
99 | # SageMath parsed files
100 | *.sage.py
101 |
102 | # Environments
103 | .env
104 | .venv
105 | env/
106 | venv/
107 | ENV/
108 | env.bak/
109 | venv.bak/
110 |
111 | # Spyder project settings
112 | .spyderproject
113 | .spyproject
114 |
115 | # Rope project settings
116 | .ropeproject
117 |
118 | # mkdocs documentation
119 | /site
120 |
121 | # mypy
122 | .mypy_cache/
123 | .dmypy.json
124 | dmypy.json
125 |
126 | # Pyre type checker
127 | .pyre/
128 |
129 | # User
130 | **/annotations/*
131 | **/datasets/*
132 | **/workspace/*
133 | **/workspace_old/*
134 | **/results/*
135 | **/models/*
136 | *.json
137 | *.sge
138 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # img2pose: Face Alignment and Detection via 6DoF, Face Pose Estimation
2 |
3 | ## Paper accepted to the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2021
4 |
5 | [](https://creativecommons.org/licenses/by-nc/4.0/)
6 | [](https://paperswithcode.com/sota/head-pose-estimation-on-aflw2000?p=img2pose-face-alignment-and-detection-via)
7 | [](https://paperswithcode.com/sota/head-pose-estimation-on-biwi?p=img2pose-face-alignment-and-detection-via)
8 |
9 |
10 |
11 | Figure 1: We estimate the 6DoF rigid transformation of a 3D face (rendered in silver), aligning it with even the tiniest faces, without face detection or facial landmark localization. Our estimated 3D face locations are rendered by descending distances from the camera, for coherent visualization.
12 |
13 |
14 | ## TL;DR
15 | This repository provides a novel method for six degrees of freedom (6DoF) pose estimation of multiple faces without the need for prior face detection. After prediction, one can visualize the detections (as shown in the figure above), customize the projected bounding boxes, or crop and align each face for further processing. See details below.
16 |
17 | ## Table of contents
18 |
19 |
20 | - [Paper details](#paper-details)
21 | * [Abstract](#abstract)
22 | * [Video Spotlight](#video-spotlight)
23 | * [Citation](#citation)
24 | - [Installation](#installation)
25 | - [Training](#training)
26 | * [Prepare WIDER FACE dataset](#prepare-wider-face-dataset)
27 | * [Train](#train)
28 | * [Training on your own dataset](#training-on-your-own-dataset)
29 | - [Testing](#testing)
30 | * [Visualizing trained model](#visualizing-trained-model)
31 | * [WIDER FACE dataset evaluation](#wider-face-dataset-evaluation)
32 | * [AFLW2000-3D dataset evaluation](#aflw2000-3d-dataset-evaluation)
33 | * [BIWI dataset evaluation](#biwi-dataset-evaluation)
34 | * [Testing on your own images](#testing-on-your-own-images)
35 | - [Output customization](#output-customization)
36 | - [Align faces](#align-faces)
37 | - [Resources](#resources)
38 | - [License](#license)
39 |
40 |
41 | ## Paper details
42 |
43 | [Vítor Albiero](https://vitoralbiero.netlify.app), Xingyu Chen, [Xi Yin](https://xiyinmsu.github.io/), Guan Pang, [Tal Hassner](https://talhassner.github.io/home/), "*img2pose: Face Alignment and Detection via 6DoF, Face Pose Estimation,*" CVPR, 2021, [arXiv:2012.07791](https://arxiv.org/abs/2012.07791)
44 |
45 |
46 | ### Abstract
47 | > We propose real-time, six degrees of freedom (6DoF), 3D face pose estimation without face detection or landmark localization. We observe that estimating the 6DoF rigid transformation of a face is a simpler problem than facial landmark detection, often used for 3D face alignment. In addition, 6DoF offers more information than face bounding box labels. We leverage these observations to make multiple contributions: (a) We describe an easily trained, efficient, Faster R-CNN--based model which regresses 6DoF pose for all faces in the photo, without preliminary face detection. (b) We explain how pose is converted and kept consistent between the input photo and arbitrary crops created while training and evaluating our model. (c) Finally, we show how face poses can replace detection bounding box training labels. Tests on AFLW2000-3D and BIWI show that our method runs in real time and outperforms state of the art (SotA) face pose estimators. Remarkably, our method also surpasses SotA models of comparable complexity on the WIDER FACE detection benchmark, despite not being optimized on bounding box labels.
48 |
49 | ### Video Spotlight
50 | [CVPR 2021 Spotlight](https://youtu.be/vDGlvpnzXGo)
51 |
52 | ### Citation
53 | If you use any part of our code or data, please cite our paper.
54 | ```
55 | @inproceedings{albiero2021img2pose,
56 | title={img2pose: Face Alignment and Detection via 6DoF, Face Pose Estimation},
57 | author={Albiero, Vítor and Chen, Xingyu and Yin, Xi and Pang, Guan and Hassner, Tal},
58 | booktitle={CVPR},
59 | year={2021},
60 | url={https://arxiv.org/abs/2012.07791},
61 | }
62 | ```
63 |
64 | ## Installation
65 |
66 |
67 | Install dependencies with Python 3.
68 | ```
69 | pip install -r requirements.txt
70 | ```
71 | Install the renderer, which is used to visualize predictions. The renderer implementation is forked from [here](https://github.com/cleardusk/3DDFA_V2/tree/master/Sim3DR).
72 | ```
73 | cd Sim3DR
74 | sh build_sim3dr.sh
75 | ```
76 |
77 | ## Training
78 | ### Prepare WIDER FACE dataset
79 | First, download our annotations as instructed in [Annotations](https://github.com/vitoralbiero/img2pose/wiki/Annotations).
80 |
81 | Download the [WIDER FACE](http://shuoyang1213.me/WIDERFACE/) dataset and extract it to datasets/WIDER_Face.
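
The commands in this section assume the archives are extracted into a layout along these lines (only the paths referenced by the scripts are shown; this is an illustrative sketch, not an exhaustive listing):

```
datasets/
└── WIDER_Face/
    ├── WIDER_train/
    │   └── images/
    ├── WIDER_val/
    │   └── images/
    └── wider_face_split/
        └── wider_face_val_bbx_gt.txt
```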
82 |
83 | Then, to create the train and validation files (LMDB), run the following scripts.
84 |
85 | ```
86 | python3 convert_json_list_to_lmdb.py \
87 | --json_list ./annotations/WIDER_train_annotations.txt \
88 | --dataset_path ./datasets/WIDER_Face/WIDER_train/images/ \
89 | --dest ./datasets/lmdb/ \
90 | --train
91 | ```
92 | This first script will generate an LMDB dataset containing the training images along with their annotations. It will also output pose mean and standard deviation files, which will be used for training and testing.
93 | ```
94 | python3 convert_json_list_to_lmdb.py \
95 | --json_list ./annotations/WIDER_val_annotations.txt \
96 | --dataset_path ./datasets/WIDER_Face/WIDER_val/images/ \
97 | --dest ./datasets/lmdb
98 | ```
99 | This second script will create an LMDB containing the validation images along with their annotations.
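
The pose mean and standard deviation written by the first script are plain NumPy arrays. A minimal sketch for sanity-checking them, assuming the default output paths used in the commands above:

```python
import numpy as np

# Files produced by the training-set conversion above
pose_mean = np.load("./datasets/lmdb/WIDER_train_annotations_pose_mean.npy")
pose_stddev = np.load("./datasets/lmdb/WIDER_train_annotations_pose_stddev.npy")

# Each is a vector over the 6DoF pose (rx, ry, rz, tx, ty, tz)
print("pose mean:  ", pose_mean)
print("pose stddev:", pose_stddev)
```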
100 |
101 | ### Train
102 | Once the LMDB train/val files are created, simply run the script below to start training.
103 | ```
104 | CUDA_VISIBLE_DEVICES=0 python3 train.py \
105 | --pose_mean ./datasets/lmdb/WIDER_train_annotations_pose_mean.npy \
106 | --pose_stddev ./datasets/lmdb/WIDER_train_annotations_pose_stddev.npy \
107 | --workspace ./workspace/ \
108 | --train_source ./datasets/lmdb/WIDER_train_annotations.lmdb \
109 | --val_source ./datasets/lmdb/WIDER_val_annotations.lmdb \
110 | --prefix trial_1 \
111 | --batch_size 2 \
112 | --lr_plateau \
113 | --early_stop \
114 | --random_flip \
115 | --random_crop \
116 | --max_size 1400
117 | ```
118 | To train with multiple GPUs (4 GPUs in the example below), use the script below.
119 | ```
120 | python3 -m torch.distributed.launch --nproc_per_node=4 --use_env train.py \
121 | --pose_mean ./datasets/lmdb/WIDER_train_annotations_pose_mean.npy \
122 | --pose_stddev ./datasets/lmdb/WIDER_train_annotations_pose_stddev.npy \
123 | --workspace ./workspace/ \
124 | --train_source ./datasets/lmdb/WIDER_train_annotations.lmdb \
125 | --val_source ./datasets/lmdb/WIDER_val_annotations.lmdb \
126 | --prefix trial_1 \
127 | --batch_size 2 \
128 | --lr_plateau \
129 | --early_stop \
130 | --random_flip \
131 | --random_crop \
132 | --max_size 1400 \
133 | --distributed
134 | ```
135 |
136 | ### Training on your own dataset
137 | If your dataset has facial landmarks and bounding boxes already annotated, store them into JSON files following the same format as in the [WIDER FACE annotations](https://github.com/vitoralbiero/img2pose/wiki/Annotations).
138 |
139 | If not, run the script below to annotate your dataset. You will need your own face detector, which must be imported inside the script.
140 | ```
141 | python3 utils/annotate_dataset.py
142 | --image_list list_of_images.txt
143 | --output_path ./annotations/dataset_name
144 | ```
145 | After the dataset is annotated, create a list pointing to the JSON files that were saved (see the sketch below). Then, follow the steps in [Prepare WIDER FACE dataset](https://github.com/vitoralbiero/img2pose#prepare-wider-face-dataset), replacing the WIDER annotations with your own dataset annotations. Once the LMDB and pose files are created, follow the steps in [Train](https://github.com/vitoralbiero/img2pose#train), replacing the WIDER LMDB and pose files with your own dataset files.
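
As an example, a minimal sketch for building such a list, assuming the JSON files were written to `./annotations/dataset_name` as in the command above (the output file name here is arbitrary):

```python
import glob
import os

# Collect the per-image JSON annotation files written by annotate_dataset.py
json_files = sorted(glob.glob("./annotations/dataset_name/*.json"))

# Write one path per line; pass this file as --json_list to convert_json_list_to_lmdb.py
with open("./annotations/dataset_name_annotations.txt", "w") as f:
    for json_file in json_files:
        f.write(os.path.abspath(json_file) + "\n")
```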
146 |
147 | ## Testing
148 | To evaluate with the pretrained model, download the model from [Model Zoo](https://github.com/vitoralbiero/img2pose/wiki/Model-Zoo), and extract it to the main folder. It will create a folder called models, which contains the model weights along with the pose mean and standard deviation files that were used for training.
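
The evaluation commands in this section assume the extracted folder looks roughly like this (file names taken from the commands below):

```
models/
├── img2pose_v1.pth
├── WIDER_train_pose_mean_v1.npy
└── WIDER_train_pose_stddev_v1.npy
```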
149 |
150 | If evaluating with your own trained model, change the pose mean and standard deviation files to the ones generated when creating your training LMDB.
151 |
152 | ### Visualizing trained model
153 | To visualize a trained model on the WIDER FACE validation set run the notebook [visualize_trained_model_predictions](evaluation/jupyter_notebooks/visualize_trained_model_predictions.ipynb).
154 |
155 | ### WIDER FACE dataset evaluation
156 | If you haven't already, download the [WIDER FACE](http://shuoyang1213.me/WIDERFACE/) dataset and extract it to datasets/WIDER_Face.
157 |
158 | Download the [pre-trained model](https://drive.google.com/file/d/1OvnZ7OUQFg2bAgFADhT7UnCkSaXst10O/view?usp=sharing).
159 |
160 | ```
161 | python3 evaluation/evaluate_wider.py \
162 | --dataset_path datasets/WIDER_Face/WIDER_val/images/ \
163 | --dataset_list datasets/WIDER_Face/wider_face_split/wider_face_val_bbx_gt.txt \
164 | --pose_mean models/WIDER_train_pose_mean_v1.npy \
165 | --pose_stddev models/WIDER_train_pose_stddev_v1.npy \
166 | --pretrained_path models/img2pose_v1.pth \
167 | --output_path results/WIDER_FACE/Val/
168 | ```
169 |
170 | To check mAP and plot curves, download the [eval tools](http://shuoyang1213.me/WIDERFACE/) and point to results/WIDER_FACE/Val.
171 |
172 | ### AFLW2000-3D dataset evaluation
173 | Download the [AFLW2000-3D](http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/Database/AFLW2000-3D.zip) dataset and unzip to datasets/AFLW2000.
174 |
175 | Download the [fine-tuned model](https://drive.google.com/file/d/1wSqPr9h1x_TOaxuN-Nu3OlTmhqnuf6rZ/view?usp=sharing).
176 |
177 | Run the notebook [aflw_2000_3d_evaluation](./evaluation/jupyter_notebooks/aflw_2000_3d_evaluation.ipynb).
178 |
179 | ### BIWI dataset evaluation
180 | Download the [BIWI](http://data.vision.ee.ethz.ch/cvl/gfanelli/kinect_head_pose_db.tgz) dataset and unzip to datasets/BIWI.
181 |
182 | Download the [fine-tuned model](https://drive.google.com/file/d/1wSqPr9h1x_TOaxuN-Nu3OlTmhqnuf6rZ/view?usp=sharing).
183 |
184 | Run the notebook [biwi_evaluation](./evaluation/jupyter_notebooks/biwi_evaluation.ipynb).
185 |
186 | ### Testing on your own images
187 |
188 | Run the notebook [test_own_images](./evaluation/jupyter_notebooks/test_own_images.ipynb).
189 |
190 | ## Output customization
191 |
192 | For every face detected, the model outputs by default:
193 | - Pose: rx, ry, rz, tx, ty, tz
194 | - Projected bounding boxes: left, top, right, bottom
195 | - Face scores: 0 to 1
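
Given these three per-face outputs, the sketch below shows one way to threshold and draw the projected boxes. It assumes you have already run the model (e.g., as in the notebooks above) and hold the results for one image as arrays `boxes` (N, 4), `scores` (N,), and `poses` (N, 6); the variable names, the OpenCV drawing calls, and the 0.9 threshold are illustrative and not part of the repository API.

```python
import cv2


def draw_detections(image_bgr, boxes, scores, poses, score_threshold=0.9):
    """Draw projected bounding boxes for faces above a confidence threshold."""
    out = image_bgr.copy()
    for (left, top, right, bottom), score, pose in zip(boxes, scores, poses):
        if score < score_threshold:
            continue
        # pose holds the 6DoF vector (rx, ry, rz, tx, ty, tz) described above
        cv2.rectangle(out, (int(left), int(top)), (int(right), int(bottom)), (0, 255, 0), 2)
        cv2.putText(out, f"{score:.2f}", (int(left), int(top) - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
    return out
```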
196 |
197 | Since the projected bounding box without expansion ends at the start of the forehead, we provide a way of expanding the forehead individually, along with the default x and y expansion.
198 |
199 | To customize the size of the projected bounding boxes, when creating the model change any of the bounding box expansion variables as shown below (a complete example can be seen at [visualize_trained_model_predictions](evaluation/jupyter_notebooks/visualize_trained_model_predictions.ipynb)).
200 | ```python
201 | # how much to expand in width
202 | bbox_x_factor = 1.1
203 | # how much to expand in height
204 | bbox_y_factor = 1.1
205 | # how much to expand in the forehead
206 | expand_forehead = 0.3
207 |
208 | img2pose_model = img2poseModel(
209 | ...,
210 | bbox_x_factor=bbox_x_factor,
211 | bbox_y_factor=bbox_y_factor,
212 | expand_forehead=expand_forehead,
213 | )
214 | ```
215 |
216 | ## Align faces
217 | To detect and align faces, simply run the command below, passing the path to the images you want to detect and align, and the path where the aligned faces should be saved.
218 | ```
219 | python3 run_face_alignment.py \
220 | --pose_mean models/WIDER_train_pose_mean_v1.npy \
221 | --pose_stddev models/WIDER_train_pose_stddev_v1.npy \
222 | --pretrained_path models/img2pose_v1.pth \
223 | --images_path image_path_or_list \
224 | --output_path path_to_save_aligned_faces
225 | ```
226 |
227 | ## Resources
228 | [Model Zoo](https://github.com/vitoralbiero/img2pose/wiki/Model-Zoo)
229 |
230 | [Annotations](https://github.com/vitoralbiero/img2pose/wiki/Annotations)
231 |
232 | [Data Zoo](https://github.com/vitoralbiero/img2pose/wiki/Data-Zoo)
233 |
234 | ## License
235 | Check [license](./license.md) for license details.
236 |
--------------------------------------------------------------------------------
/Sim3DR/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | cmake-build-debug/
3 | .idea/
4 | build/
5 | *.so
6 | data/
7 |
8 | lib/rasterize.cpp
--------------------------------------------------------------------------------
/Sim3DR/Sim3DR.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from . import _init_paths
4 | import numpy as np
5 | import Sim3DR_Cython
6 |
7 |
8 | def get_normal(vertices, triangles):
9 | normal = np.zeros_like(vertices, dtype=np.float32)
10 | Sim3DR_Cython.get_normal(normal, vertices, triangles, vertices.shape[0], triangles.shape[0])
11 | return normal
12 |
13 |
14 | def rasterize(vertices, triangles, colors, bg=None,
15 | height=None, width=None, channel=None,
16 | reverse=False):
17 | if bg is not None:
18 | height, width, channel = bg.shape
19 | else:
20 | assert height is not None and width is not None and channel is not None
21 | bg = np.zeros((height, width, channel), dtype=np.uint8)
22 |
23 | buffer = np.zeros((height, width), dtype=np.float32) - 1e8
24 |
25 | if colors.dtype != np.float32:
26 | colors = colors.astype(np.float32)
27 | Sim3DR_Cython.rasterize(bg, vertices, triangles, colors, buffer, triangles.shape[0], height, width, channel,
28 | reverse=reverse)
29 | return bg
30 |
--------------------------------------------------------------------------------
/Sim3DR/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from .Sim3DR import get_normal, rasterize
4 | from .lighting import RenderPipeline
5 |
--------------------------------------------------------------------------------
/Sim3DR/_init_paths.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | import os.path as osp
4 | import sys
5 |
6 |
7 | def add_path(path):
8 | if path not in sys.path:
9 | sys.path.insert(0, path)
10 |
11 |
12 | this_dir = osp.dirname(__file__)
13 | lib_path = osp.join(this_dir, '.')
14 | add_path(lib_path)
15 |
--------------------------------------------------------------------------------
/Sim3DR/build_sim3dr.sh:
--------------------------------------------------------------------------------
1 | python3 setup.py build_ext --inplace
--------------------------------------------------------------------------------
/Sim3DR/lib/rasterize.h:
--------------------------------------------------------------------------------
1 | #ifndef MESH_CORE_HPP_
2 | #define MESH_CORE_HPP_
3 |
4 | #include <cstdio>
5 | #include <iostream>
6 | #include <fstream>
7 | #include <cmath>
8 | #include <algorithm>
9 | #include <string>
10 |
11 | using namespace std;
12 |
13 | class Point3D {
14 | public:
15 | float x;
16 | float y;
17 | float z;
18 |
19 | public:
20 | Point3D() : x(0.f), y(0.f), z(0.f) {}
21 | Point3D(float x_, float y_, float z_) : x(x_), y(y_), z(z_) {}
22 |
23 | void initialize(float x_, float y_, float z_){
24 | this->x = x_; this->y = y_; this->z = z_;
25 | }
26 |
27 | Point3D cross(Point3D &p){
28 | Point3D c;
29 | c.x = this->y * p.z - this->z * p.y;
30 | c.y = this->z * p.x - this->x * p.z;
31 | c.z = this->x * p.y - this->y * p.x;
32 | return c;
33 | }
34 |
35 | float dot(Point3D &p) {
36 | return this->x * p.x + this->y * p.y + this->z * p.z;
37 | }
38 |
39 | Point3D operator-(const Point3D &p) {
40 | Point3D np;
41 | np.x = this->x - p.x;
42 | np.y = this->y - p.y;
43 | np.z = this->z - p.z;
44 | return np;
45 | }
46 |
47 | };
48 |
49 | class Point {
50 | public:
51 | float x;
52 | float y;
53 |
54 | public:
55 | Point() : x(0.f), y(0.f) {}
56 | Point(float x_, float y_) : x(x_), y(y_) {}
57 | float dot(Point p) {
58 | return this->x * p.x + this->y * p.y;
59 | }
60 |
61 | Point operator-(const Point &p) {
62 | Point np;
63 | np.x = this->x - p.x;
64 | np.y = this->y - p.y;
65 | return np;
66 | }
67 |
68 | Point operator+(const Point &p) {
69 | Point np;
70 | np.x = this->x + p.x;
71 | np.y = this->y + p.y;
72 | return np;
73 | }
74 |
75 | Point operator*(float s) {
76 | Point np;
77 | np.x = s * this->x;
78 | np.y = s * this->y;
79 | return np;
80 | }
81 | };
82 |
83 |
84 | bool is_point_in_tri(Point p, Point p0, Point p1, Point p2);
85 |
86 | void get_point_weight(float *weight, Point p, Point p0, Point p1, Point p2);
87 |
88 | void _get_tri_normal(float *tri_normal, float *vertices, int *triangles, int ntri, bool norm_flg);
89 |
90 | void _get_ver_normal(float *ver_normal, float *tri_normal, int *triangles, int nver, int ntri);
91 |
92 | void _get_normal(float *ver_normal, float *vertices, int *triangles, int nver, int ntri);
93 |
94 | void _rasterize_triangles(
95 | float *vertices, int *triangles, float *depth_buffer, int *triangle_buffer, float *barycentric_weight,
96 | int ntri, int h, int w);
97 |
98 | void _rasterize(
99 | unsigned char *image, float *vertices, int *triangles, float *colors,
100 | float *depth_buffer, int ntri, int h, int w, int c, float alpha, bool reverse);
101 |
102 | void _render_texture_core(
103 | float *image, float *vertices, int *triangles,
104 | float *texture, float *tex_coords, int *tex_triangles,
105 | float *depth_buffer,
106 | int nver, int tex_nver, int ntri,
107 | int h, int w, int c,
108 | int tex_h, int tex_w, int tex_c,
109 | int mapping_type);
110 |
111 | void _write_obj_with_colors_texture(string filename, string mtl_name,
112 | float *vertices, int *triangles, float *colors, float *uv_coords,
113 | int nver, int ntri, int ntexver);
114 |
115 | #endif
--------------------------------------------------------------------------------
/Sim3DR/lib/rasterize.pyx:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | cimport numpy as np
3 | # from libcpp.string cimport string
4 | cimport cython
5 | from libcpp cimport bool
6 |
7 | # from cpython import bool
8 |
9 | # use the Numpy-C-API from Cython
10 | np.import_array()
11 |
12 | # cdefine the signature of our c function
13 | cdef extern from "rasterize.h":
14 | void _rasterize_triangles(
15 | float*vertices, int*triangles, float*depth_buffer, int*triangle_buffer, float*barycentric_weight,
16 | int ntri, int h, int w
17 | )
18 |
19 | void _rasterize(
20 | unsigned char*image, float*vertices, int*triangles, float*colors, float*depth_buffer,
21 | int ntri, int h, int w, int c, float alpha, bool reverse
22 | )
23 |
24 | # void _render_texture_core(
25 | # float* image, float* vertices, int* triangles,
26 | # float* texture, float* tex_coords, int* tex_triangles,
27 | # float* depth_buffer,
28 | # int nver, int tex_nver, int ntri,
29 | # int h, int w, int c,
30 | # int tex_h, int tex_w, int tex_c,
31 | # int mapping_type)
32 |
33 | void _get_tri_normal(float *tri_normal, float *vertices, int *triangles, int nver, bool norm_flg)
34 | void _get_ver_normal(float *ver_normal, float*tri_normal, int*triangles, int nver, int ntri)
35 | void _get_normal(float *ver_normal, float *vertices, int *triangles, int nver, int ntri)
36 |
37 |
38 | # void _write_obj_with_colors_texture(string filename, string mtl_name,
39 | # float* vertices, int* triangles, float* colors, float* uv_coords,
40 | # int nver, int ntri, int ntexver)
41 |
42 | @cython.boundscheck(False)
43 | @cython.wraparound(False)
44 | def get_tri_normal(np.ndarray[float, ndim=2, mode="c"] tri_normal not None,
45 | np.ndarray[float, ndim=2, mode = "c"] vertices not None,
46 | np.ndarray[int, ndim=2, mode="c"] triangles not None,
47 | int ntri, bool norm_flg = False):
48 |     _get_tri_normal(<float*> np.PyArray_DATA(tri_normal), <float*> np.PyArray_DATA(vertices),
49 |                     <int*> np.PyArray_DATA(triangles), ntri, norm_flg)
50 |
51 | @cython.boundscheck(False) # turn off bounds-checking for entire function
52 | @cython.wraparound(False) # turn off negative index wrapping for entire function
53 | def get_ver_normal(np.ndarray[float, ndim=2, mode = "c"] ver_normal not None,
54 | np.ndarray[float, ndim=2, mode = "c"] tri_normal not None,
55 | np.ndarray[int, ndim=2, mode="c"] triangles not None,
56 | int nver, int ntri):
57 | _get_ver_normal(
58 |         <float*> np.PyArray_DATA(ver_normal), <float*> np.PyArray_DATA(tri_normal), <int*> np.PyArray_DATA(triangles),
59 | nver, ntri)
60 |
61 | @cython.boundscheck(False) # turn off bounds-checking for entire function
62 | @cython.wraparound(False) # turn off negative index wrapping for entire function
63 | def get_normal(np.ndarray[float, ndim=2, mode = "c"] ver_normal not None,
64 | np.ndarray[float, ndim=2, mode = "c"] vertices not None,
65 | np.ndarray[int, ndim=2, mode="c"] triangles not None,
66 | int nver, int ntri):
67 | _get_normal(
68 |         <float*> np.PyArray_DATA(ver_normal), <float*> np.PyArray_DATA(vertices), <int*> np.PyArray_DATA(triangles),
69 | nver, ntri)
70 |
71 |
72 | @cython.boundscheck(False) # turn off bounds-checking for entire function
73 | @cython.wraparound(False) # turn off negative index wrapping for entire function
74 | def rasterize_triangles(
75 | np.ndarray[float, ndim=2, mode = "c"] vertices not None,
76 | np.ndarray[int, ndim=2, mode="c"] triangles not None,
77 | np.ndarray[float, ndim=2, mode = "c"] depth_buffer not None,
78 | np.ndarray[int, ndim=2, mode = "c"] triangle_buffer not None,
79 | np.ndarray[float, ndim=2, mode = "c"] barycentric_weight not None,
80 | int ntri, int h, int w
81 | ):
82 | _rasterize_triangles(
83 |         <float*> np.PyArray_DATA(vertices), <int*> np.PyArray_DATA(triangles),
84 |         <float*> np.PyArray_DATA(depth_buffer), <int*> np.PyArray_DATA(triangle_buffer),
85 |         <float*> np.PyArray_DATA(barycentric_weight),
86 | ntri, h, w)
87 |
88 | @cython.boundscheck(False) # turn off bounds-checking for entire function
89 | @cython.wraparound(False) # turn off negative index wrapping for entire function
90 | def rasterize(np.ndarray[unsigned char, ndim=3, mode = "c"] image not None,
91 | np.ndarray[float, ndim=2, mode = "c"] vertices not None,
92 | np.ndarray[int, ndim=2, mode="c"] triangles not None,
93 | np.ndarray[float, ndim=2, mode = "c"] colors not None,
94 | np.ndarray[float, ndim=2, mode = "c"] depth_buffer not None,
95 | int ntri, int h, int w, int c, float alpha = 1, bool reverse = False
96 | ):
97 | _rasterize(
98 |         <unsigned char*> np.PyArray_DATA(image), <float*> np.PyArray_DATA(vertices),
99 |         <int*> np.PyArray_DATA(triangles),
100 |         <float*> np.PyArray_DATA(colors),
101 |         <float*> np.PyArray_DATA(depth_buffer),
102 | ntri, h, w, c, alpha, reverse)
103 |
104 | # def render_texture_core(np.ndarray[float, ndim=3, mode = "c"] image not None,
105 | # np.ndarray[float, ndim=2, mode = "c"] vertices not None,
106 | # np.ndarray[int, ndim=2, mode="c"] triangles not None,
107 | # np.ndarray[float, ndim=3, mode = "c"] texture not None,
108 | # np.ndarray[float, ndim=2, mode = "c"] tex_coords not None,
109 | # np.ndarray[int, ndim=2, mode="c"] tex_triangles not None,
110 | # np.ndarray[float, ndim=2, mode = "c"] depth_buffer not None,
111 | # int nver, int tex_nver, int ntri,
112 | # int h, int w, int c,
113 | # int tex_h, int tex_w, int tex_c,
114 | # int mapping_type
115 | # ):
116 | # _render_texture_core(
117 | # np.PyArray_DATA(image), np.PyArray_DATA(vertices), np.PyArray_DATA(triangles),
118 | # np.PyArray_DATA(texture), np.PyArray_DATA(tex_coords), np.PyArray_DATA(tex_triangles),
119 | # np.PyArray_DATA(depth_buffer),
120 | # nver, tex_nver, ntri,
121 | # h, w, c,
122 | # tex_h, tex_w, tex_c,
123 | # mapping_type)
124 | #
125 | # def write_obj_with_colors_texture_core(string filename, string mtl_name,
126 | # np.ndarray[float, ndim=2, mode = "c"] vertices not None,
127 | # np.ndarray[int, ndim=2, mode="c"] triangles not None,
128 | # np.ndarray[float, ndim=2, mode = "c"] colors not None,
129 | # np.ndarray[float, ndim=2, mode = "c"] uv_coords not None,
130 | # int nver, int ntri, int ntexver
131 | # ):
132 | # _write_obj_with_colors_texture(filename, mtl_name,
133 | # np.PyArray_DATA(vertices), np.PyArray_DATA(triangles), np.PyArray_DATA(colors), np.PyArray_DATA(uv_coords),
134 | # nver, ntri, ntexver)
135 |
--------------------------------------------------------------------------------
/Sim3DR/lighting.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | import numpy as np
4 | from .Sim3DR import get_normal, rasterize
5 |
6 | _norm = lambda arr: arr / np.sqrt(np.sum(arr ** 2, axis=1))[:, None]
7 |
8 |
9 | def norm_vertices(vertices):
10 | vertices -= vertices.min(0)[None, :]
11 | vertices /= vertices.max()
12 | vertices *= 2
13 | vertices -= vertices.max(0)[None, :] / 2
14 | return vertices
15 |
16 |
17 | def convert_type(obj):
18 | if isinstance(obj, tuple) or isinstance(obj, list):
19 | return np.array(obj, dtype=np.float32)[None, :]
20 | return obj
21 |
22 |
23 | class RenderPipeline(object):
24 | def __init__(self, **kwargs):
25 | self.intensity_ambient = convert_type(kwargs.get('intensity_ambient', 0.3))
26 | self.intensity_directional = convert_type(kwargs.get('intensity_directional', 0.6))
27 | self.intensity_specular = convert_type(kwargs.get('intensity_specular', 0.1))
28 | self.specular_exp = kwargs.get('specular_exp', 5)
29 | self.color_ambient = convert_type(kwargs.get('color_ambient', (1, 1, 1)))
30 | self.color_directional = convert_type(kwargs.get('color_directional', (1, 1, 1)))
31 | self.light_pos = convert_type(kwargs.get('light_pos', (0, 0, 5)))
32 | self.view_pos = convert_type(kwargs.get('view_pos', (0, 0, 5)))
33 |
34 | def update_light_pos(self, light_pos):
35 | self.light_pos = convert_type(light_pos)
36 |
37 | def __call__(self, vertices, triangles, bg, texture=None):
38 | normal = get_normal(vertices, triangles)
39 |
40 | # 2. lighting
41 | light = np.zeros_like(vertices, dtype=np.float32)
42 | # ambient component
43 | if self.intensity_ambient > 0:
44 | light += self.intensity_ambient * self.color_ambient
45 |
46 | vertices_n = norm_vertices(vertices.copy())
47 | if self.intensity_directional > 0:
48 | # diffuse component
49 | direction = _norm(self.light_pos - vertices_n)
50 | cos = np.sum(normal * direction, axis=1)[:, None]
51 | # cos = np.clip(cos, 0, 1)
52 | # todo: check below
53 | light += self.intensity_directional * (self.color_directional * np.clip(cos, 0, 1))
54 |
55 | # specular component
56 | if self.intensity_specular > 0:
57 | v2v = _norm(self.view_pos - vertices_n)
58 | reflection = 2 * cos * normal - direction
59 | spe = np.sum((v2v * reflection) ** self.specular_exp, axis=1)[:, None]
60 | spe = np.where(cos != 0, np.clip(spe, 0, 1), np.zeros_like(spe))
61 | light += self.intensity_specular * self.color_directional * np.clip(spe, 0, 1)
62 | light = np.clip(light, 0, 1)
63 |
64 | # 2. rasterization, [0, 1]
65 | if texture is None:
66 | render_img = rasterize(vertices, triangles, light, bg=bg)
67 | return render_img
68 | else:
69 | texture *= light
70 | render_img = rasterize(vertices, triangles, texture, bg=bg)
71 | return render_img
72 |
73 |
74 | def main():
75 | pass
76 |
77 |
78 | if __name__ == '__main__':
79 | main()
80 |
--------------------------------------------------------------------------------
/Sim3DR/readme.md:
--------------------------------------------------------------------------------
1 | ## Forked from https://github.com/cleardusk/3DDFA_V2/tree/master/Sim3DR
2 |
3 | ## Sim3DR
4 | This is a simple 3D renderer, written in C++ and Cython.
5 |
6 | ### Build Sim3DR
7 |
8 | ```shell script
9 | python3 setup.py build_ext --inplace
10 | ```
11 |
12 |
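
### Usage sketch

After building, the package exposes `rasterize`, `get_normal`, and `RenderPipeline` (see `Sim3DR.py` and `lighting.py`). A minimal sketch, run from the repository root; the random mesh below is only a stand-in for real geometry (its rendered output is meaningless), which must be C-contiguous float32 vertices of shape (N, 3), int32 triangles of shape (M, 3), and a uint8 background image:

```python
import numpy as np
from Sim3DR import RenderPipeline, rasterize

# Toy geometry standing in for a real face mesh
vertices = np.ascontiguousarray(np.random.rand(100, 3).astype(np.float32))
triangles = np.ascontiguousarray(np.random.randint(0, 100, (50, 3), dtype=np.int32))
bg = np.zeros((224, 224, 3), dtype=np.uint8)  # image to render onto

# Flat-colored rasterization
colors = np.full_like(vertices, 0.8)
img = rasterize(vertices, triangles, colors, bg=bg)

# Or render with the simple ambient/diffuse/specular lighting pipeline
renderer = RenderPipeline()
img_lit = renderer(vertices, triangles, bg)
```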
--------------------------------------------------------------------------------
/Sim3DR/setup.py:
--------------------------------------------------------------------------------
1 | '''
2 | python setup.py build_ext -i
3 | to compile
4 | '''
5 |
6 | from distutils.core import setup, Extension
7 | from Cython.Build import cythonize
8 | from Cython.Distutils import build_ext
9 | import numpy
10 |
11 | setup(
12 | name='Sim3DR_Cython', # not the package name
13 | cmdclass={'build_ext': build_ext},
14 | ext_modules=[Extension("Sim3DR_Cython",
15 | sources=["lib/rasterize.pyx", "lib/rasterize_kernel.cpp"],
16 | language='c++',
17 | include_dirs=[numpy.get_include()],
18 | extra_compile_args=["-std=c++11"])],
19 | )
20 |
--------------------------------------------------------------------------------
/Sim3DR/tests/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 |
--------------------------------------------------------------------------------
/Sim3DR/tests/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 2.8)
2 |
3 | set(TARGET test)
4 | project(${TARGET})
5 |
6 | #find_package( OpenCV REQUIRED )
7 | #include_directories( ${OpenCV_INCLUDE_DIRS} )
8 |
9 | #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC -O3")
10 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -std=c++11")
11 | add_executable(${TARGET} test.cpp rasterize_kernel.cpp io.cpp)
12 | target_include_directories(${TARGET} PRIVATE ${PROJECT_SOURCE_DIR})
13 |
--------------------------------------------------------------------------------
/Sim3DR/tests/io.cpp:
--------------------------------------------------------------------------------
1 | #include "io.h"
2 |
3 | //void load_obj(const string obj_fp, float* vertices, float* colors, float* triangles){
4 | // string line;
5 | // ifstream in(obj_fp);
6 | //
7 | // if(in.is_open()){
8 | // while (getline(in, line)){
9 | // stringstream ss(line);
10 | //
11 | // char t; // type: v, f
12 | // ss >> t;
13 | // if (t == 'v'){
14 | //
15 | // }
16 | // }
17 | // }
18 | //}
19 |
20 | void load_obj(const char *obj_fp, float *vertices, float *colors, int *triangles, int nver, int ntri) {
21 | FILE *fp;
22 | fp = fopen(obj_fp, "r");
23 |
24 | char t; // type: v or f
25 | if (fp != nullptr) {
26 | for (int i = 0; i < nver; ++i) {
27 | fscanf(fp, "%c", &t);
28 | for (int j = 0; j < 3; ++j)
29 | fscanf(fp, " %f", &vertices[3 * i + j]);
30 | for (int j = 0; j < 3; ++j)
31 | fscanf(fp, " %f", &colors[3 * i + j]);
32 | fscanf(fp, "\n");
33 | }
34 | // fscanf(fp, "%c", &t);
35 | for (int i = 0; i < ntri; ++i) {
36 | fscanf(fp, "%c", &t);
37 | for (int j = 0; j < 3; ++j) {
38 | fscanf(fp, " %d", &triangles[3 * i + j]);
39 | triangles[3 * i + j] -= 1;
40 | }
41 | fscanf(fp, "\n");
42 | }
43 |
44 | fclose(fp);
45 | }
46 | }
47 |
48 | void load_ply(const char *ply_fp, float *vertices, int *triangles, int nver, int ntri) {
49 | FILE *fp;
50 | fp = fopen(ply_fp, "r");
51 |
52 | // char s[256];
53 | char t;
54 | if (fp != nullptr) {
55 | // for (int i = 0; i < 9; ++i)
56 | // fscanf(fp, "%s", s);
57 | for (int i = 0; i < nver; ++i)
58 | fscanf(fp, "%f %f %f\n", &vertices[3 * i], &vertices[3 * i + 1], &vertices[3 * i + 2]);
59 |
60 | for (int i = 0; i < ntri; ++i)
61 | fscanf(fp, "%c %d %d %d\n", &t, &triangles[3 * i], &triangles[3 * i + 1], &triangles[3 * i + 2]);
62 |
63 | fclose(fp);
64 | }
65 | }
66 |
67 | void write_ppm(const char *filename, unsigned char *img, int h, int w, int c) {
68 | FILE *fp;
69 | //open file for output
70 | fp = fopen(filename, "wb");
71 | if (!fp) {
72 | fprintf(stderr, "Unable to open file '%s'\n", filename);
73 | exit(1);
74 | }
75 |
76 | //write the header file
77 | //image format
78 | fprintf(fp, "P6\n");
79 |
80 | //image size
81 | fprintf(fp, "%d %d\n", w, h);
82 |
83 | // rgb component depth
84 | fprintf(fp, "%d\n", MAX_PXL_VALUE);
85 |
86 | // pixel data
87 | fwrite(img, sizeof(unsigned char), size_t(h * w * c), fp);
88 | fclose(fp);
89 | }
--------------------------------------------------------------------------------
/Sim3DR/tests/io.h:
--------------------------------------------------------------------------------
1 | #ifndef IO_H_
2 | #define IO_H_
3 |
4 | #include <cstdio>
5 | #include <cstdlib>
6 | #include <iostream>
7 | #include <fstream>
8 | #include <sstream>
9 |
10 | using namespace std;
11 |
12 | #define MAX_PXL_VALUE 255
13 |
14 | void load_obj(const char* obj_fp, float* vertices, float* colors, int* triangles, int nver, int ntri);
15 | void load_ply(const char* ply_fp, float* vertices, int* triangles, int nver, int ntri);
16 |
17 |
18 | void write_ppm(const char *filename, unsigned char *img, int h, int w, int c);
19 |
20 | #endif
--------------------------------------------------------------------------------
/Sim3DR/tests/test.cpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * Testing cases
3 | */
4 |
5 | #include <iostream>
6 | #include <ctime>
7 | #include "rasterize.h"
8 | #include "io.h"
9 |
10 | void test_isPointInTri() {
11 | Point p0(0, 0);
12 | Point p1(1, 0);
13 | Point p2(1, 1);
14 |
15 | Point p(0.2, 0.2);
16 |
17 | if (is_point_in_tri(p, p0, p1, p2))
18 | std::cout << "In";
19 | else
20 | std::cout << "Out";
21 | std::cout << std::endl;
22 | }
23 |
24 | void test_getPointWeight() {
25 | Point p0(0, 0);
26 | Point p1(1, 0);
27 | Point p2(1, 1);
28 |
29 | Point p(0.2, 0.2);
30 |
31 | float weight[3];
32 | get_point_weight(weight, p, p0, p1, p2);
33 | std::cout << weight[0] << " " << weight[1] << " " << weight[2] << std::endl;
34 | }
35 |
36 | void test_get_tri_normal() {
37 | float tri_normal[3];
38 | // float vertices[9] = {1, 0, 0, 0, 0, 0, 0, 1, 0};
39 | float vertices[9] = {1, 1.1, 0, 0, 0, 0, 0, 0.6, 0.7};
40 | int triangles[3] = {0, 1, 2};
41 | int ntri = 1;
42 |
43 |     _get_tri_normal(tri_normal, vertices, triangles, ntri, true);
44 |
45 | for (int i = 0; i < 3; ++i)
46 | std::cout << tri_normal[i] << ", ";
47 | std::cout << std::endl;
48 | }
49 |
50 | void test_load_obj() {
51 | const char *fp = "../data/vd005_mesh.obj";
52 | int nver = 35709;
53 | int ntri = 70789;
54 |
55 |     auto *vertices = new float[3 * nver];
56 |     auto *colors = new float[3 * nver];
57 |     auto *triangles = new int[3 * ntri];
58 | load_obj(fp, vertices, colors, triangles, nver, ntri);
59 |
60 | delete[] vertices;
61 | delete[] colors;
62 | delete[] triangles;
63 | }
64 |
65 | void test_render() {
66 | // 1. loading obj
67 | // const char *fp = "/Users/gjz/gjzprojects/Sim3DR/data/vd005_mesh.obj";
68 | const char *fp = "/Users/gjz/gjzprojects/Sim3DR/data/face1.obj";
69 | int nver = 35709; //53215; //35709;
70 | int ntri = 70789; //105840;//70789;
71 |
72 | auto *vertices = new float[3 * nver];
73 | auto *colors = new float[3 * nver];
74 | auto *triangles = new int[3 * ntri];
75 | load_obj(fp, vertices, colors, triangles, nver, ntri);
76 |
77 | // 2. rendering
78 | int h = 224, w = 224, c = 3;
79 |
80 | // enlarging
81 | int scale = 4;
82 | h *= scale;
83 | w *= scale;
84 | for (int i = 0; i < nver * 3; ++i) vertices[i] *= scale;
85 |
86 | auto *image = new unsigned char[h * w * c]();
87 | auto *depth_buffer = new float[h * w]();
88 |
89 | for (int i = 0; i < h * w; ++i) depth_buffer[i] = -999999;
90 |
91 | clock_t t;
92 | t = clock();
93 |
94 |     _rasterize(image, vertices, triangles, colors, depth_buffer, ntri, h, w, c, 1.0f, true);
95 | t = clock() - t;
96 | double time_taken = ((double) t) / CLOCKS_PER_SEC; // in seconds
97 | printf("Render took %f seconds to execute \n", time_taken);
98 |
99 |
100 | // auto *image_char = new u_char[h * w * c]();
101 | // for (int i = 0; i < h * w * c; ++i)
102 | // image_char[i] = u_char(255 * image[i]);
103 | write_ppm("res.ppm", image, h, w, c);
104 |
105 | // delete[] image_char;
106 | delete[] vertices;
107 | delete[] colors;
108 | delete[] triangles;
109 | delete[] image;
110 | delete[] depth_buffer;
111 | }
112 |
113 | void test_light() {
114 | // 1. loading obj
115 | const char *fp = "/Users/gjz/gjzprojects/Sim3DR/data/emma_input_0_noheader.ply";
116 | int nver = 53215; //35709;
117 | int ntri = 105840; //70789;
118 |
119 | auto *vertices = new float[3 * nver];
120 | auto *colors = new float[3 * nver];
121 | auto *triangles = new int[3 * ntri];
122 | load_ply(fp, vertices, triangles, nver, ntri);
123 |
124 | // 2. rendering
125 | // int h = 1901, w = 3913, c = 3;
126 | int h = 2000, w = 4000, c = 3;
127 |
128 | // enlarging
129 | // int scale = 1;
130 | // h *= scale;
131 | // w *= scale;
132 | // for (int i = 0; i < nver * 3; ++i) vertices[i] *= scale;
133 |
134 | auto *image = new unsigned char[h * w * c]();
135 | auto *depth_buffer = new float[h * w]();
136 |
137 | for (int i = 0; i < h * w; ++i) depth_buffer[i] = -999999;
138 | for (int i = 0; i < 3 * nver; ++i) colors[i] = 0.8;
139 |
140 | clock_t t;
141 | t = clock();
142 |
143 |     _rasterize(image, vertices, triangles, colors, depth_buffer, ntri, h, w, c, 1.0f, true);
144 | t = clock() - t;
145 | double time_taken = ((double) t) / CLOCKS_PER_SEC; // in seconds
146 | printf("Render took %f seconds to execute \n", time_taken);
147 |
148 |
149 | // auto *image_char = new u_char[h * w * c]();
150 | // for (int i = 0; i < h * w * c; ++i)
151 | // image_char[i] = u_char(255 * image[i]);
152 | write_ppm("emma.ppm", image, h, w, c);
153 |
154 | // delete[] image_char;
155 | delete[] vertices;
156 | delete[] colors;
157 | delete[] triangles;
158 | delete[] image;
159 | delete[] depth_buffer;
160 | }
161 |
162 | int main(int argc, char *argv[]) {
163 | // std::cout << "Hello CMake!" << std::endl;
164 |
165 | // test_isPointInTri();
166 | // test_getPointWeight();
167 | // test_get_tri_normal();
168 | // test_load_obj();
169 | // test_render();
170 | test_light();
171 | return 0;
172 | }
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | from easydict import EasyDict
6 | from torch.nn import MSELoss
7 |
8 |
9 | class Config(EasyDict):
10 | def __init__(self, args):
11 | # workspace configuration
12 | self.prefix = args.prefix
13 | self.work_path = os.path.join(args.workspace, self.prefix)
14 | self.model_path = os.path.join(self.work_path, "models")
15 | try:
16 | self.create_path(self.model_path)
17 | except Exception as e:
18 | print(e)
19 |
20 | self.log_path = os.path.join(self.work_path, "log")
21 | try:
22 | self.create_path(self.log_path)
23 | except Exception as e:
24 | print(e)
25 |
26 | self.frequency_log = 20
27 |
28 | # training/validation configuration
29 | self.train_source = args.train_source
30 | self.val_source = args.val_source
31 |
32 | # network and training parameters
33 | self.pose_loss = MSELoss(reduction="sum")
34 | self.pose_mean = np.load(args.pose_mean)
35 | self.pose_stddev = np.load(args.pose_stddev)
36 | self.depth = args.depth
37 | self.lr = args.lr
38 | self.lr_plateau = args.lr_plateau
39 | self.early_stop = args.early_stop
40 | self.batch_size = args.batch_size
41 | self.workers = args.workers
42 | self.epochs = args.epochs
43 | self.min_size = args.min_size
44 | self.max_size = args.max_size
45 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46 | self.weight_decay = 5e-4
47 | self.momentum = 0.9
48 | self.pin_memory = True
49 |
50 | # resume from or load pretrained weights
51 | self.pretrained_path = args.pretrained_path
52 | self.resume_path = args.resume_path
53 |
54 | # online augmentation
55 | self.noise_augmentation = args.noise_augmentation
56 | self.contrast_augmentation = args.contrast_augmentation
57 | self.random_flip = args.random_flip
58 | self.random_crop = args.random_crop
59 |
60 | # 3d reference points to compute pose
61 | self.threed_5_points = args.threed_5_points
62 | self.threed_68_points = args.threed_68_points
63 |
64 | # distributed
65 | self.distributed = args.distributed
66 | if not args.distributed:
67 | self.gpu = 0
68 | else:
69 | self.gpu = args.gpu
70 |
71 | self.num_gpus = args.world_size
72 |
73 | def create_path(self, file_path):
74 | if not os.path.exists(file_path):
75 | os.makedirs(file_path)
76 |
--------------------------------------------------------------------------------
/convert_json_list_to_lmdb.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import lmdb
5 | import msgpack
6 | import numpy as np
7 | from tqdm import tqdm
8 |
9 | from utils.json_loader import JsonLoader
10 |
11 |
12 | def json_list_to_lmdb(args):
13 | cpu_available = os.cpu_count()
14 | if args.num_workers > cpu_available:
15 | args.num_workers = cpu_available
16 |
17 | threed_5_points = np.load(args.threed_5_points)
18 | threed_68_points = np.load(args.threed_68_points)
19 |
20 | print("Loading dataset from %s" % args.json_list)
21 | data_loader = JsonLoader(
22 | args.num_workers,
23 | args.json_list,
24 | threed_5_points,
25 | threed_68_points,
26 | args.dataset_path,
27 | )
28 |
29 | name = f"{os.path.split(args.json_list)[1][:-4]}.lmdb"
30 | lmdb_path = os.path.join(args.dest, name)
31 | isdir = os.path.isdir(lmdb_path)
32 |
33 | if os.path.isfile(lmdb_path):
34 | os.remove(lmdb_path)
35 |
36 | print(f"Generate LMDB to {lmdb_path}")
37 |
38 | size = len(data_loader) * 1200 * 1200 * 3
39 | print(f"LMDB max size: {size}")
40 |
41 | db = lmdb.open(
42 | lmdb_path,
43 | subdir=isdir,
44 | map_size=size * 2,
45 | readonly=False,
46 | meminit=False,
47 | map_async=True,
48 | )
49 |
50 | print(f"Total number of samples: {len(data_loader)}")
51 |
52 | all_pose_labels = []
53 |
54 | txn = db.begin(write=True)
55 |
56 | total_samples = 0
57 |
58 | for idx, data in tqdm(enumerate(data_loader)):
59 | image, global_pose_labels, bboxes, pose_labels, landmarks = data[0]
60 |
61 | if len(bboxes) == 0:
62 | continue
63 |
64 | has_pose = False
65 | for pose_label in pose_labels:
66 | if pose_label[0] != -9:
67 | all_pose_labels.append(pose_label)
68 | has_pose = True
69 |
70 | if not has_pose:
71 | continue
72 |
73 | txn.put(
74 | "{}".format(total_samples).encode("ascii"),
75 | msgpack.dumps((image, global_pose_labels, bboxes, pose_labels, landmarks)),
76 | )
77 | if idx % args.write_frequency == 0:
78 | print(f"[{idx}/{len(data_loader)}]")
79 | txn.commit()
80 | txn = db.begin(write=True)
81 |
82 | total_samples += 1
83 |
84 | print(total_samples)
85 |
86 | txn.commit()
87 | keys = ["{}".format(k).encode("ascii") for k in range(total_samples)]
88 | with db.begin(write=True) as txn:
89 | txn.put(b"__keys__", msgpack.dumps(keys))
90 | txn.put(b"__len__", msgpack.dumps(len(keys)))
91 |
92 | print("Flushing database ...")
93 | db.sync()
94 | db.close()
95 |
96 | if args.train:
97 | print("Saving pose mean and std dev.")
98 | all_pose_labels = np.asarray(all_pose_labels)
99 | pose_mean = np.mean(all_pose_labels, axis=0)
100 | pose_stddev = np.std(all_pose_labels, axis=0)
101 |
102 | save_file_path = os.path.join(args.dest, os.path.split(args.json_list)[1][:-4])
103 | np.save(f"{save_file_path}_pose_mean.npy", pose_mean)
104 | np.save(f"{save_file_path}_pose_stddev.npy", pose_stddev)
105 |
106 |
107 | def parse_args():
108 | parser = argparse.ArgumentParser()
109 | parser.add_argument(
110 | "--json_list",
111 | type=str,
112 | required=True,
113 | help="List of json files that contain frames annotations",
114 | )
115 | parser.add_argument(
116 | "--dataset_path",
117 | type=str,
118 | help="Path to the dataset images",
119 | )
120 | parser.add_argument("--num_workers", default=16, type=int)
121 | parser.add_argument(
122 | "--write_frequency", help="Frequency to save to file.", type=int, default=5000
123 | )
124 | parser.add_argument(
125 | "--dest", type=str, required=True, help="Path to save the lmdb file."
126 | )
127 | parser.add_argument(
128 | "--train", action="store_true", help="Dataset will be used for training."
129 | )
130 | parser.add_argument(
131 | "--threed_5_points",
132 | type=str,
133 | help="Reference 3D points to compute pose.",
134 | default="./pose_references/reference_3d_5_points_trans.npy",
135 | )
136 |
137 | parser.add_argument(
138 | "--threed_68_points",
139 | type=str,
140 | help="Reference 3D points to compute pose.",
141 | default="./pose_references/reference_3d_68_points_trans.npy",
142 | )
143 |
144 | args = parser.parse_args()
145 |
146 | if not os.path.exists(args.dest):
147 | os.makedirs(args.dest)
148 |
149 | return args
150 |
151 |
152 | if __name__ == "__main__":
153 | args = parse_args()
154 |
155 | json_list_to_lmdb(args)
156 |
--------------------------------------------------------------------------------
/data_loader_lmdb.py:
--------------------------------------------------------------------------------
1 | from os import path
2 |
3 | import lmdb
4 | import msgpack
5 | import numpy as np
6 | import six
7 | import torch
8 | from PIL import Image
9 | from torch.utils.data import BatchSampler, DataLoader, Dataset
10 | from torch.utils.data.distributed import DistributedSampler
11 | from torchvision import transforms
12 |
13 | import utils.augmentation as augmentation
14 | from utils.image_operations import expand_bbox_rectangle
15 | from utils.pose_operations import plot_3d_landmark, pose_bbox_to_full_image
16 |
17 |
18 | class LMDB(Dataset):
19 | def __init__(
20 | self,
21 | config,
22 | db_path,
23 | transform=None,
24 | pose_label_transform=None,
25 | augmentation_methods=None,
26 | ):
27 | self.config = config
28 | self.env = lmdb.open(
29 | db_path,
30 | subdir=path.isdir(db_path),
31 | readonly=True,
32 | lock=False,
33 | readahead=False,
34 | meminit=False,
35 | )
36 |
37 | with self.env.begin(write=False) as txn:
38 | self.length = msgpack.loads(txn.get(b"__len__"))
39 | self.keys = msgpack.loads(txn.get(b"__keys__"))
40 |
41 | self.transform = transform
42 | self.pose_label_transform = pose_label_transform
43 | self.augmentation_methods = augmentation_methods
44 | self.threed_68_points = np.load(self.config.threed_68_points)
45 |
46 | def __getitem__(self, index):
47 | img, target = None, None
48 | env = self.env
49 | with env.begin(write=False) as txn:
50 | byteflow = txn.get(self.keys[index])
51 | data = msgpack.loads(byteflow)
52 |
53 | # load image
54 | imgbuf = data[0]
55 | buf = six.BytesIO()
56 | buf.write(imgbuf)
57 | buf.seek(0)
58 | img = Image.open(buf).convert("RGB")
59 |
60 | # load local pose label
61 | pose_labels = np.asarray(data[3])
62 |
63 | # load bbox
64 | bbox_labels = np.asarray(data[2])
65 |
66 | # load landmarks label
67 | landmark_labels = data[4]
68 |
69 | # apply augmentations that are provided from the parent class
70 | for augmentation_method in self.augmentation_methods:
71 | img, _, _ = augmentation_method(img, None, None)
72 |
73 | # create global intrinsics
74 | (w, h) = img.size
75 | global_intrinsics = np.array(
76 | [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]]
77 | )
78 |
79 | img = np.array(img)
80 | projected_bbox_labels = []
81 | new_pose_labels = []
82 | # get projected bboxes
83 | for i in range(len(pose_labels)):
84 | pose_label = pose_labels[i]
85 | bbox = bbox_labels[i]
86 |
87 | lms = np.asarray(landmark_labels[i])
88 |
89 | # black out faces that do not have pose annotation
90 | if -1 in lms:
91 | img[int(bbox[1]) : int(bbox[3]), int(bbox[0]) : int(bbox[2]), :] = 0
92 | continue
93 |
94 | # convert to global image
95 | pose_label = pose_bbox_to_full_image(pose_label, global_intrinsics, bbox)
96 |
97 | # project points and get bbox
98 | projected_lms, _ = plot_3d_landmark(
99 | self.threed_68_points, pose_label, global_intrinsics
100 | )
101 | projected_bbox = expand_bbox_rectangle(
102 | w, h, 1.1, 1.1, projected_lms, roll=pose_label[2]
103 | )
104 |
105 | projected_bbox_labels.append(projected_bbox)
106 |
107 | new_pose_labels.append(pose_label)
108 |
109 | img = Image.fromarray(img)
110 |
111 | if self.transform is not None:
112 | img = self.transform(img)
113 |
114 | target = {
115 | "dofs": torch.from_numpy(np.asarray(new_pose_labels)).float(),
116 | "boxes": torch.from_numpy(np.asarray(projected_bbox_labels)).float(),
117 | "labels": torch.ones((len(projected_bbox_labels),), dtype=torch.int64),
118 | }
119 |
120 | return img, target
121 |
122 | def __len__(self):
123 | return self.length
124 |
125 |
126 | class LMDBDataLoader(DataLoader):
127 | def __init__(self, config, lmdb_path, train=True):
128 | self.config = config
129 |
130 | transform = transforms.Compose([transforms.ToTensor()])
131 |
132 | augmentation_methods = []
133 |
134 | if train:
135 | if self.config.noise_augmentation:
136 | augmentation_methods.append(augmentation.add_noise)
137 |
138 | if self.config.contrast_augmentation:
139 | augmentation_methods.append(augmentation.change_contrast)
140 |
141 | if self.config.pose_mean is not None:
142 | pose_label_transform = self.normalize_pose_labels
143 | else:
144 | pose_label_transform = None
145 |
146 | self._dataset = LMDB(
147 | config, lmdb_path, transform, pose_label_transform, augmentation_methods
148 | )
149 |
150 | if config.distributed:
151 | self._sampler = DistributedSampler(self._dataset, shuffle=False)
152 |
153 | if train:
154 | self._sampler = BatchSampler(
155 | self._sampler, config.batch_size, drop_last=True
156 | )
157 |
158 | super(LMDBDataLoader, self).__init__(
159 | self._dataset,
160 | batch_sampler=self._sampler,
161 | pin_memory=config.pin_memory,
162 | num_workers=config.workers,
163 | collate_fn=collate_fn,
164 | )
165 | else:
166 | super(LMDBDataLoader, self).__init__(
167 | self._dataset,
168 | config.batch_size,
169 | drop_last=False,
170 | sampler=self._sampler,
171 | pin_memory=config.pin_memory,
172 | num_workers=config.workers,
173 | collate_fn=collate_fn,
174 | )
175 |
176 | else:
177 | super(LMDBDataLoader, self).__init__(
178 | self._dataset,
179 | batch_size=config.batch_size,
180 | shuffle=train,
181 | pin_memory=config.pin_memory,
182 | num_workers=config.workers,
183 | drop_last=True,
184 | collate_fn=collate_fn,
185 | )
186 |
187 | def normalize_pose_labels(self, pose_labels):
188 | for i in range(len(pose_labels)):
189 | pose_labels[i] = (
190 | pose_labels[i] - self.config.pose_mean
191 | ) / self.config.pose_stddev
192 |
193 | return pose_labels
194 |
195 |
196 | def collate_fn(batch):
197 | return tuple(zip(*batch))
198 |
--------------------------------------------------------------------------------
/data_loader_lmdb_augmenter.py:
--------------------------------------------------------------------------------
1 | from os import path
2 |
3 | import lmdb
4 | import msgpack
5 | import numpy as np
6 | import six
7 | import torch
8 | from PIL import Image
9 | from torch.utils.data import BatchSampler, DataLoader, Dataset
10 | from torch.utils.data.distributed import DistributedSampler
11 | from torchvision import transforms
12 |
13 | import utils.augmentation as augmentation
14 | from utils.image_operations import expand_bbox_rectangle
15 | from utils.pose_operations import (get_pose, plot_3d_landmark,
16 | pose_bbox_to_full_image)
17 |
18 |
19 | class LMDB(Dataset):
20 | def __init__(
21 | self,
22 | config,
23 | db_path,
24 | transform=None,
25 | pose_label_transform=None,
26 | augmentation_methods=None,
27 | ):
28 | self.config = config
29 |
30 | self.env = lmdb.open(
31 | db_path,
32 | subdir=path.isdir(db_path),
33 | readonly=True,
34 | lock=False,
35 | readahead=False,
36 | meminit=False,
37 | )
38 |
39 | with self.env.begin(write=False) as txn:
40 | self.length = msgpack.loads(txn.get(b"__len__"))
41 | self.keys = msgpack.loads(txn.get(b"__keys__"))
42 |
43 | self.transform = transform
44 | self.pose_label_transform = pose_label_transform
45 | self.augmentation_methods = augmentation_methods
46 |
47 | self.threed_5_points = np.load(self.config.threed_5_points)
48 | self.threed_68_points = np.load(self.config.threed_68_points)
49 |
50 | def __getitem__(self, index):
51 | img, target = None, None
52 | env = self.env
53 | with env.begin(write=False) as txn:
54 | byteflow = txn.get(self.keys[index])
55 | data = msgpack.loads(byteflow)
56 |
57 | # load image
58 | imgbuf = data[0]
59 | buf = six.BytesIO()
60 | buf.write(imgbuf)
61 | buf.seek(0)
62 | img = Image.open(buf).convert("RGB")
63 |
64 | # load landmarks label
65 | landmark_labels = data[4]
66 |
67 | # load bbox
68 | bbox_labels = np.asarray(data[2])
69 |
70 | # apply augmentations that are provided from the parent class in creation order
71 | for augmentation_method in self.augmentation_methods:
72 | img, bbox_labels, landmark_labels = augmentation_method(
73 | img, bbox_labels, landmark_labels
74 | )
75 |
76 | # create global intrinsics
77 | (img_w, img_h) = img.size
78 | global_intrinsics = np.array(
79 | [[img_w + img_h, 0, img_w // 2], [0, img_w + img_h, img_h // 2], [0, 0, 1]]
80 | )
81 |
82 | projected_bbox_labels = []
83 | pose_labels = []
84 |
85 | img = np.array(img)
86 |
87 | # get pose labels
88 | for i in range(len(bbox_labels)):
89 | bbox = bbox_labels[i]
90 | lms = np.asarray(landmark_labels[i])
91 |
92 | # black out faces that do not have pose annotation
93 | if -1 in lms:
94 | img[int(bbox[1]) : int(bbox[3]), int(bbox[0]) : int(bbox[2]), :] = 0
95 | continue
96 |
97 |             # shift landmarks into the bbox coordinate frame
98 | bbox_lms = lms.copy()
99 | bbox_lms[:, 0] -= bbox[0]
100 | bbox_lms[:, 1] -= bbox[1]
101 |
102 |             # create bbox intrinsics
103 | w = int(bbox[2] - bbox[0])
104 | h = int(bbox[3] - bbox[1])
105 |
106 | bbox_intrinsics = np.array(
107 | [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]]
108 | )
109 |
110 | # get pose between gt points and 3D reference
111 | if len(bbox_lms) == 5:
112 | P, pose = get_pose(self.threed_5_points, bbox_lms, bbox_intrinsics)
113 | else:
114 | P, pose = get_pose(self.threed_68_points, bbox_lms, bbox_intrinsics)
115 |
116 | # convert to global image
117 | pose_label = pose_bbox_to_full_image(pose, global_intrinsics, bbox)
118 |
119 | # project points and get bbox
120 | projected_lms, _ = plot_3d_landmark(
121 | self.threed_68_points, pose_label, global_intrinsics
122 | )
123 | projected_bbox = expand_bbox_rectangle(
124 | img_w, img_h, 1.1, 1.1, projected_lms, roll=pose_label[2]
125 | )
126 |
127 | pose_labels.append(pose_label)
128 | projected_bbox_labels.append(projected_bbox)
129 |
130 | pose_labels = np.asarray(pose_labels)
131 |
132 | img = Image.fromarray(img)
133 |
134 | if self.transform is not None:
135 | img = self.transform(img)
136 |
137 | target = {
138 | "dofs": torch.from_numpy(pose_labels).float(),
139 | "boxes": torch.from_numpy(np.asarray(projected_bbox_labels)).float(),
140 | "labels": torch.ones((len(bbox_labels),), dtype=torch.int64),
141 | }
142 |
143 | return img, target
144 |
145 | def __len__(self):
146 | return self.length
147 |
148 |
149 | class LMDBDataLoaderAugmenter(DataLoader):
150 | def __init__(self, config, lmdb_path, train=True):
151 | self.config = config
152 |
153 | transform = transforms.Compose([transforms.ToTensor()])
154 |
155 | augmentation_methods = []
156 |
157 | if train:
158 | if self.config.random_flip:
159 | augmentation_methods.append(augmentation.random_flip)
160 |
161 | if self.config.random_crop:
162 | augmentation_methods.append(augmentation.random_crop)
163 |
164 | if self.config.noise_augmentation:
165 | augmentation_methods.append(augmentation.add_noise)
166 |
167 | if self.config.contrast_augmentation:
168 | augmentation_methods.append(augmentation.change_contrast)
169 |
170 | if self.config.pose_mean is not None:
171 | pose_label_transform = self.normalize_pose_labels
172 | else:
173 | pose_label_transform = None
174 |
175 | self._dataset = LMDB(
176 | config,
177 | lmdb_path,
178 | transform,
179 | pose_label_transform,
180 | augmentation_methods,
181 | )
182 |
183 | if config.distributed:
184 | self._sampler = DistributedSampler(self._dataset, shuffle=False)
185 |
186 | if train:
187 | self._sampler = BatchSampler(
188 | self._sampler, config.batch_size, drop_last=True
189 | )
190 |
191 | super(LMDBDataLoaderAugmenter, self).__init__(
192 | self._dataset,
193 | batch_sampler=self._sampler,
194 | pin_memory=config.pin_memory,
195 | num_workers=config.workers,
196 | collate_fn=collate_fn,
197 | )
198 | else:
199 | super(LMDBDataLoaderAugmenter, self).__init__(
200 | self._dataset,
201 | config.batch_size,
202 | drop_last=False,
203 | sampler=self._sampler,
204 | pin_memory=config.pin_memory,
205 | num_workers=config.workers,
206 | collate_fn=collate_fn,
207 | )
208 |
209 | else:
210 | super(LMDBDataLoaderAugmenter, self).__init__(
211 | self._dataset,
212 | batch_size=config.batch_size,
213 | shuffle=train,
214 | pin_memory=config.pin_memory,
215 | num_workers=config.workers,
216 | drop_last=True,
217 | collate_fn=collate_fn,
218 | )
219 |
220 | def normalize_pose_labels(self, pose_labels):
221 | for i in range(len(pose_labels)):
222 | pose_labels[i] = (
223 | pose_labels[i] - self.config.pose_mean
224 | ) / self.config.pose_stddev
225 |
226 | return pose_labels
227 |
228 |
229 | def collate_fn(batch):
230 | return tuple(zip(*batch))
231 |
--------------------------------------------------------------------------------
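
The `normalize_pose_labels` helper standardizes each 6DoF pose with the dataset mean and standard deviation. A small standalone sketch of the forward mapping and its inverse; the numeric values below are made up for illustration, the real statistics come from the WIDER_train_pose_*.npy files:

import numpy as np

pose = np.array([0.1, -0.2, 0.05, 3.0, -1.5, 60.0])       # [rx, ry, rz, tx, ty, tz]
pose_mean = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 50.0])     # illustrative values only
pose_stddev = np.array([0.5, 0.5, 0.5, 10.0, 10.0, 30.0]) # illustrative values only

normalized = (pose - pose_mean) / pose_stddev             # what normalize_pose_labels does
recovered = normalized * pose_stddev + pose_mean          # inverse mapping
assert np.allclose(recovered, pose)
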
/early_stop.py:
--------------------------------------------------------------------------------
1 | class EarlyStop:
2 | def __init__(self, patience=5, mode="max", threshold=0):
3 | self.patience = patience
4 | self.counter = 0
5 | self.best_score = None
6 | self.stop = False
7 | self.mode = mode
8 | self.threshold = threshold
9 | self.val_score = float("Inf")
10 | if mode == "max":
11 | self.val_score *= -1
12 |
13 | def __call__(self, val_score):
14 | if self.best_score is None:
15 | self.best_score = val_score
16 |
17 |         # if the validation score did not improve, increment the early stop counter
18 | elif (val_score < self.best_score + self.threshold and self.mode == "max") or (
19 | val_score > self.best_score + self.threshold and self.mode == "min"
20 | ):
21 | self.counter += 1
22 | print(f"Early stop counter: {self.counter} out of {self.patience}")
23 |
24 |             # if there is no improvement for patience consecutive checks, stop training early
25 | if self.counter >= self.patience:
26 | self.stop = True
27 | else:
28 | self.best_score = val_score
29 | self.counter = 0
30 |
--------------------------------------------------------------------------------
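
A minimal usage sketch for the EarlyStop class above, driving it with a short list of dummy validation scores:

from early_stop import EarlyStop

early_stop = EarlyStop(patience=3, mode="max")

for val_score in [0.70, 0.72, 0.71, 0.71, 0.70, 0.69]:  # dummy validation scores
    early_stop(val_score)
    if early_stop.stop:
        print(f"No improvement for {early_stop.patience} evaluations, stopping")
        break
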
/evaluation/evaluate_wider.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import time
5 |
6 | import numpy as np
7 | from PIL import Image, ImageOps
8 | from torchvision import transforms
9 | from tqdm import tqdm
10 |
11 | sys.path.append("./")
12 | from img2pose import img2poseModel
13 | from model_loader import load_model
14 |
15 |
16 | class WIDER_Eval:
17 | def __init__(self, args):
18 | self.threed_68_points = np.load(args.threed_68_points)
19 | self.nms_threshold = args.nms_threshold
20 | self.pose_mean = np.load(args.pose_mean)
21 | self.pose_stddev = np.load(args.pose_stddev)
22 |
23 | self.test_dataset = self.get_dataset(args)
24 | self.dataset_path = args.dataset_path
25 | self.img2pose_model = self.create_model(args)
26 |
27 | self.transform = transforms.Compose([transforms.ToTensor()])
28 | self.min_size = args.min_size
29 | self.max_size = args.max_size
30 | self.flip = len(args.min_size) > 1
31 | self.output_path = args.output_path
32 |
33 | def create_model(self, args):
34 | img2pose_model = img2poseModel(
35 | args.depth,
36 | args.min_size[-1],
37 | args.max_size,
38 | pose_mean=self.pose_mean,
39 | pose_stddev=self.pose_stddev,
40 | threed_68_points=self.threed_68_points,
41 | )
42 | load_model(
43 | img2pose_model.fpn_model,
44 | args.pretrained_path,
45 | cpu_mode=str(img2pose_model.device) == "cpu",
46 | model_only=True,
47 | )
48 | img2pose_model.evaluate()
49 |
50 | return img2pose_model
51 |
52 | def get_dataset(self, args):
53 | annotations = open(args.dataset_list)
54 | lines = annotations.readlines()
55 |
56 | test_dataset = []
57 |
58 | for i in range(len(lines)):
59 | lines[i] = str(lines[i].rstrip("\n"))
60 | if "--" in lines[i]:
61 | test_dataset.append(lines[i])
62 |
63 | return test_dataset
64 |
65 | def bbox_voting(self, bboxes, iou_thresh=0.6):
66 |         # bboxes: an N x 5 numpy array of boxes, each row [x1, y1, x2, y2, score]
67 |         # iou_thresh: boxes whose IoU with the current highest-scoring box exceeds
68 |         #             this threshold are grouped and merged by score-weighted voting.
69 |
70 | order = bboxes[:, 4].ravel().argsort()[::-1]
71 | bboxes = bboxes[order, :]
72 | areas = (bboxes[:, 2] - bboxes[:, 0] + 1) * (bboxes[:, 3] - bboxes[:, 1] + 1)
73 | voted_bboxes = np.zeros([0, 5])
74 | while bboxes.shape[0] > 0:
75 | xx1 = np.maximum(bboxes[0, 0], bboxes[:, 0])
76 | yy1 = np.maximum(bboxes[0, 1], bboxes[:, 1])
77 | xx2 = np.minimum(bboxes[0, 2], bboxes[:, 2])
78 | yy2 = np.minimum(bboxes[0, 3], bboxes[:, 3])
79 | w = np.maximum(0.0, xx2 - xx1 + 1)
80 | h = np.maximum(0.0, yy2 - yy1 + 1)
81 | inter = w * h
82 | overlaps = inter / (areas[0] + areas[:] - inter)
83 | merge_indexs = np.where(overlaps >= iou_thresh)[0]
84 | if merge_indexs.shape[0] == 0:
85 | bboxes = np.delete(bboxes, np.array([0]), 0)
86 | areas = np.delete(areas, np.array([0]), 0)
87 | continue
88 | bboxes_accu = bboxes[merge_indexs, :]
89 | bboxes = np.delete(bboxes, merge_indexs, 0)
90 | areas = np.delete(areas, merge_indexs, 0)
91 | # generate a new box by score voting and box voting
92 | bbox = np.zeros((1, 5))
93 | box_weights = (bboxes_accu[:, -1] / max(bboxes_accu[:, -1])) * overlaps[
94 | merge_indexs
95 | ]
96 | bboxes_accu[:, 0:4] = bboxes_accu[:, 0:4] * np.tile(
97 | box_weights.reshape((-1, 1)), (1, 4)
98 | )
99 | bbox[:, 0:4] = np.sum(bboxes_accu[:, 0:4], axis=0) / (np.sum(box_weights))
100 | bbox[0, 4] = np.sum(bboxes_accu[:, 4] * box_weights)
101 | voted_bboxes = np.row_stack((voted_bboxes, bbox))
102 |
103 | return voted_bboxes
104 |
105 | def get_scales(self, im):
106 | im_shape = im.size
107 | im_size_min = np.min(im_shape[0:2])
108 |
109 | scales = [float(scale) / im_size_min for scale in self.min_size]
110 |
111 | return scales
112 |
113 | def test(self):
114 | times = []
115 |
116 | for img_path in tqdm(self.test_dataset):
117 | img_full_path = os.path.join(self.dataset_path, img_path)
118 | img = Image.open(img_full_path).convert("RGB")
119 |
120 | bboxes = []
121 |
122 | (w, h) = img.size
123 | scales = self.get_scales(img)
124 |
125 | if self.flip:
126 | flip_list = (False, True)
127 | else:
128 | flip_list = (False,)
129 |
130 | for flip in flip_list:
131 | for scale in scales:
132 | run_img = img.copy()
133 |
134 | if flip:
135 | run_img = ImageOps.mirror(run_img)
136 |
137 | new_w = int(run_img.size[0] * scale)
138 | new_h = int(run_img.size[1] * scale)
139 |
140 | min_size = min(new_w, new_h)
141 | max_size = max(new_w, new_h)
142 |
143 | if len(scales) > 1:
144 | self.img2pose_model.fpn_model.module.set_max_min_size(
145 | max_size, min_size
146 | )
147 |
148 | time1 = time.time()
149 |
150 | res = self.img2pose_model.predict([self.transform(run_img)])
151 |
152 | time2 = time.time()
153 | times.append(time2 - time1)
154 |
155 | res = res[0]
156 |
157 | for i in range(len(res["scores"])):
158 | bbox = res["boxes"].cpu().numpy()[i].astype("int")
159 | score = res["scores"].cpu().numpy()[i]
160 | pose = res["dofs"].cpu().numpy()[i]
161 |
162 | if flip:
163 | bbox_copy = bbox.copy()
164 | bbox[0] = w - bbox_copy[2]
165 | bbox[2] = w - bbox_copy[0]
166 | pose[1:4] *= -1
167 |
168 | bboxes.append(np.append(bbox, score))
169 |
170 | bboxes = np.asarray(bboxes)
171 |
172 | if len(self.min_size) > 1:
173 | bboxes = self.bbox_voting(bboxes, self.nms_threshold)
174 |
175 | if np.ndim(bboxes) == 1 and len(bboxes) > 0:
176 | bboxes = bboxes[np.newaxis, :]
177 |
178 | output_path = os.path.join(self.output_path, os.path.split(img_path)[0])
179 |
180 | if not os.path.exists(output_path):
181 | os.makedirs(output_path)
182 |
183 | file_name = os.path.split(img_path)[-1]
184 | f = open(os.path.join(output_path, file_name[:-4] + ".txt"), "w")
185 | f.write(file_name + "\n")
186 | f.write(str(len(bboxes)) + "\n")
187 | for i in range(len(bboxes)):
188 | bbox = bboxes[i]
189 |
190 | f.write(
191 | f"{bbox[0]} {bbox[1]} {bbox[2]-bbox[0]} {bbox[3]-bbox[1]} {bbox[4]}\n"
192 | )
193 | f.close()
194 |
195 | print(f"Average time forward pass: {np.mean(np.asarray(times))}")
196 |
197 |
198 | def parse_args():
199 | parser = argparse.ArgumentParser(
200 |         description="Evaluate a trained img2pose model on the WIDER FACE dataset."
201 | )
202 | parser.add_argument(
203 | "--min_size",
204 | help="Min size",
205 | default="200, 300, 500, 800, 1100, 1400, 1700",
206 | type=str,
207 | )
208 | parser.add_argument("--max_size", help="Max size", default=1400, type=int)
209 | parser.add_argument(
210 | "--depth", help="Number of layers [18, 50 or 101].", default=18, type=int
211 | )
212 | parser.add_argument(
213 | "--pose_mean",
214 | help="Pose mean file path.",
215 | type=str,
216 | default="./models/WIDER_train_pose_mean_v1.npy",
217 | )
218 | parser.add_argument(
219 | "--pose_stddev",
220 | help="Pose stddev file path.",
221 | type=str,
222 | default="./models/WIDER_train_pose_stddev_v1.npy",
223 | )
224 |
225 | # training/validation configuration
226 | parser.add_argument("--output_path", help="Path to save predictions", required=True)
227 | parser.add_argument("--dataset_path", help="Path to the dataset", required=True)
228 | parser.add_argument("--dataset_list", help="Dataset list.")
229 |
230 | # resume from or load pretrained weights
231 | parser.add_argument(
232 | "--pretrained_path", help="Path to pretrained weights.", type=str
233 | )
234 | parser.add_argument("--nms_threshold", default=0.6, type=float)
235 |
236 | parser.add_argument(
237 | "--threed_68_points",
238 | type=str,
239 | help="Reference 3D points to compute pose.",
240 | default="./pose_references/reference_3d_68_points_trans.npy",
241 | )
242 |
243 | args = parser.parse_args()
244 |
245 | args.min_size = [int(item) for item in args.min_size.split(", ")]
246 |
247 | return args
248 |
249 |
250 | if __name__ == "__main__":
251 | args = parse_args()
252 |
253 | wider_eval = WIDER_Eval(args)
254 | wider_eval.test()
255 |
--------------------------------------------------------------------------------
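
The test() loop above writes one text file per image in the usual WIDER FACE submission format: image file name, number of detections, then one "x y w h score" line per box. A standalone sketch of that format; the file name and box values below are placeholders:

detections = [[10, 20, 110, 220, 0.98], [300, 40, 380, 150, 0.87]]  # [x1, y1, x2, y2, score]

with open("0_Parade_marchingband_1_20.txt", "w") as f:              # placeholder file name
    f.write("0_Parade_marchingband_1_20.jpg\n")
    f.write(f"{len(detections)}\n")
    for x1, y1, x2, y2, score in detections:
        f.write(f"{x1} {y1} {x2 - x1} {y2 - y1} {score}\n")
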
/evaluation/jupyter_notebooks/aflw_2000_3d_evaluation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Imports"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {
14 | "ExecuteTime": {
15 | "end_time": "2021-03-18T19:35:06.774927Z",
16 | "start_time": "2021-03-18T19:35:05.706994Z"
17 | },
18 | "scrolled": true
19 | },
20 | "outputs": [],
21 | "source": [
22 | "import sys\n",
23 | "sys.path.append('../../')\n",
24 | "import numpy as np\n",
25 | "import torch\n",
26 | "from torchvision import transforms\n",
27 | "from matplotlib import pyplot as plt\n",
28 | "from tqdm.notebook import tqdm\n",
29 | "from PIL import Image, ImageOps\n",
30 | "from scipy.spatial.transform import Rotation\n",
31 | "import pandas as pd\n",
32 | "from scipy.spatial import distance\n",
33 | "import time\n",
34 | "import os\n",
35 | "import math\n",
36 | "import scipy.io as sio\n",
37 | "from utils.renderer import Renderer\n",
38 | "from utils.image_operations import expand_bbox_rectangle\n",
39 | "from utils.pose_operations import get_pose\n",
40 | "from img2pose import img2poseModel\n",
41 | "from model_loader import load_model\n",
42 | "\n",
43 | "np.set_printoptions(suppress=True)"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {},
49 | "source": [
50 | "## Load dataset annotations "
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 2,
56 | "metadata": {
57 | "ExecuteTime": {
58 | "end_time": "2021-03-18T19:35:06.783194Z",
59 | "start_time": "2021-03-18T19:35:06.776810Z"
60 | },
61 | "scrolled": true
62 | },
63 | "outputs": [],
64 | "source": [
65 | "dataset_path = \"AFLW2000_annotations.txt\"\n",
66 | "test_dataset = pd.read_csv(dataset_path, delimiter=\" \", header=None)\n",
67 | "test_dataset = np.asarray(test_dataset).squeeze() "
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "## Declare useful functions"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 3,
80 | "metadata": {
81 | "ExecuteTime": {
82 | "end_time": "2021-03-18T19:35:06.797090Z",
83 | "start_time": "2021-03-18T19:35:06.785309Z"
84 | }
85 | },
86 | "outputs": [],
87 | "source": [
88 | "def bb_intersection_over_union(boxA, boxB):\n",
89 | " xA = max(boxA[0], boxB[0])\n",
90 | " yA = max(boxA[1], boxB[1])\n",
91 | " xB = min(boxA[2], boxB[2])\n",
92 | " yB = min(boxA[3], boxB[3])\n",
93 | "\n",
94 | " interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)\n",
95 | " boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)\n",
96 | " boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)\n",
97 | " iou = interArea / float(boxAArea + boxBArea - interArea)\n",
98 | " return iou\n",
99 | "\n",
100 | "def render_plot(img, pose_pred):\n",
101 | " (w, h) = img.size\n",
102 | " image_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]])\n",
103 | "\n",
104 | " trans_vertices = renderer.transform_vertices(img, [pose_pred])\n",
105 | " img = renderer.render(img, trans_vertices, alpha=1) \n",
106 | "\n",
107 | " plt.figure(figsize=(8, 8)) \n",
108 | "\n",
109 | " plt.imshow(img) \n",
110 | " plt.show()\n",
111 | " \n",
112 | "def convert_to_aflw(rotvec, is_rotvec=True):\n",
113 | " if is_rotvec:\n",
114 | " rotvec = Rotation.from_rotvec(rotvec).as_matrix()\n",
115 | " rot_mat_2 = np.transpose(rotvec)\n",
116 | " angle = Rotation.from_matrix(rot_mat_2).as_euler('xyz', degrees=True)\n",
117 | " \n",
118 | " return np.array([angle[0], -angle[1], -angle[2]])"
119 | ]
120 | },
121 | {
122 | "cell_type": "markdown",
123 | "metadata": {},
124 | "source": [
125 | "## Create the renderer for visualization"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": 4,
131 | "metadata": {
132 | "ExecuteTime": {
133 | "end_time": "2021-03-18T19:35:06.816811Z",
134 | "start_time": "2021-03-18T19:35:06.798493Z"
135 | }
136 | },
137 | "outputs": [],
138 | "source": [
139 | "renderer = Renderer(\n",
140 | " vertices_path=\"../../pose_references/vertices_trans.npy\", \n",
141 | " triangles_path=\"../../pose_references/triangles.npy\"\n",
142 | ")\n",
143 | "\n",
144 | "threed_points = np.load('../../pose_references/reference_3d_68_points_trans.npy')"
145 | ]
146 | },
147 | {
148 | "cell_type": "markdown",
149 | "metadata": {},
150 | "source": [
151 | "## Load model weights and pose mean and std deviation\n",
152 |     "To test other models, change MODEL_PATH along with the POSE_MEAN and POSE_STDDEV used for training"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": 5,
158 | "metadata": {
159 | "ExecuteTime": {
160 | "end_time": "2021-03-18T19:35:09.656797Z",
161 | "start_time": "2021-03-18T19:35:06.818766Z"
162 | }
163 | },
164 | "outputs": [
165 | {
166 | "name": "stdout",
167 | "output_type": "stream",
168 | "text": [
169 | "Model will use 1 GPUs!\n"
170 | ]
171 | }
172 | ],
173 | "source": [
174 | "transform = transforms.Compose([transforms.ToTensor()])\n",
175 | "\n",
176 | "DEPTH = 18\n",
177 | "MAX_SIZE = 1400\n",
178 | "MIN_SIZE = 400\n",
179 | "\n",
180 | "POSE_MEAN = \"../../models/WIDER_train_pose_mean_v1.npy\"\n",
181 | "POSE_STDDEV = \"../../models/WIDER_train_pose_stddev_v1.npy\"\n",
182 | "MODEL_PATH = \"../../models/img2pose_v1_ft_300w_lp.pth\"\n",
183 | "\n",
184 | "\n",
185 | "pose_mean = np.load(POSE_MEAN)\n",
186 | "pose_stddev = np.load(POSE_STDDEV)\n",
187 | "\n",
188 | "\n",
189 | "img2pose_model = img2poseModel(\n",
190 | " DEPTH, MIN_SIZE, MAX_SIZE, \n",
191 | " pose_mean=pose_mean, pose_stddev=pose_stddev,\n",
192 | " threed_68_points=threed_points,\n",
193 | " rpn_pre_nms_top_n_test=500,\n",
194 | " rpn_post_nms_top_n_test=10,\n",
195 | ")\n",
196 | "load_model(img2pose_model.fpn_model, MODEL_PATH, cpu_mode=str(img2pose_model.device) == \"cpu\", model_only=True)\n",
197 | "img2pose_model.evaluate()"
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "metadata": {},
203 | "source": [
204 | "## Run AFLW2000-3D evaluation\n",
205 |     "To visualize the predictions, change visualize to True and change total_imgs to the number of images desired."
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": 7,
211 | "metadata": {
212 | "ExecuteTime": {
213 | "end_time": "2021-03-18T19:36:41.278410Z",
214 | "start_time": "2021-03-18T19:35:34.696305Z"
215 | },
216 | "scrolled": false
217 | },
218 | "outputs": [
219 | {
220 | "data": {
221 | "application/vnd.jupyter.widget-view+json": {
222 | "model_id": "098a3413f3414f4db135d05411b36713",
223 | "version_major": 2,
224 | "version_minor": 0
225 | },
226 | "text/plain": [
227 | "HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))"
228 | ]
229 | },
230 | "metadata": {},
231 | "output_type": "display_data"
232 | },
233 | {
234 | "name": "stdout",
235 | "output_type": "stream",
236 | "text": [
237 | "\n",
238 | "Model failed on 0 images\n",
239 | "Yaw: 3.426 Pitch: 5.034 Roll: 3.278 MAE: 3.913\n",
240 | "H. Trans.: 0.028 V. Trans.: 0.038 Scale: 0.230 MAE: 0.099\n",
241 | "Average time 0.02748371853690392\n"
242 | ]
243 | }
244 | ],
245 | "source": [
246 | "visualize = False\n",
247 | "total_imgs = len(test_dataset)\n",
248 | "\n",
249 | "threshold = 0.0\n",
250 | "targets = []\n",
251 | "predictions = []\n",
252 | "\n",
253 | "total_failures = 0\n",
254 | "times = []\n",
255 | "\n",
256 | "for img_path in tqdm(test_dataset[:total_imgs]):\n",
257 | " img = Image.open(img_path).convert(\"RGB\")\n",
258 | "\n",
259 | " image_name = os.path.split(img_path)[1]\n",
260 | "\n",
261 | " ori_img = img.copy()\n",
262 | "\n",
263 | " (w, h) = ori_img.size\n",
264 | " image_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]])\n",
265 | "\n",
266 | " mat_contents = sio.loadmat(img_path[:-4] + \".mat\")\n",
267 | " target_points = np.asarray(mat_contents['pt3d_68']).T[:, :2]\n",
268 | "\n",
269 | " _, pose_target = get_pose(threed_points, target_points, image_intrinsics)\n",
270 | "\n",
271 | " target_bbox = expand_bbox_rectangle(w, h, 1.1, 1.1, target_points, roll=pose_target[2])\n",
272 | "\n",
273 | " pose_para = np.asarray(mat_contents['Pose_Para'])[0][:3]\n",
274 | " pose_para_degrees = pose_para[:3] * (180 / math.pi)\n",
275 | "\n",
276 | " if np.any(np.abs(pose_para_degrees) > 99):\n",
277 | " continue \n",
278 | "\n",
279 | " run_time = 0\n",
280 | " time1 = time.time()\n",
281 | " res = img2pose_model.predict([transform(img)])\n",
282 | " time2 = time.time()\n",
283 | " run_time += (time2 - time1)\n",
284 | "\n",
285 | " res = res[0]\n",
286 | "\n",
287 | " bboxes = res[\"boxes\"].cpu().numpy().astype('float')\n",
288 | " max_iou = 0\n",
289 | " best_index = -1\n",
290 | "\n",
291 | " for i in range(len(bboxes)):\n",
292 | " if res[\"scores\"][i] > threshold:\n",
293 | " bbox = bboxes[i]\n",
294 | " pose_pred = res[\"dofs\"].cpu().numpy()[i].astype('float')\n",
295 | " pose_pred = np.asarray(pose_pred.squeeze()) \n",
296 | " iou = bb_intersection_over_union(bbox, target_bbox)\n",
297 | "\n",
298 | " if iou > max_iou:\n",
299 | " max_iou = iou\n",
300 | " best_index = i \n",
301 | "\n",
302 | " if best_index >= 0:\n",
303 | " bbox = bboxes[best_index]\n",
304 | " pose_pred = res[\"dofs\"].cpu().numpy()[best_index].astype('float')\n",
305 | " pose_pred = np.asarray(pose_pred.squeeze()) \n",
306 | "\n",
307 | " \n",
308 | " if visualize and best_index >= 0: \n",
309 | " render_plot(ori_img.copy(), pose_pred)\n",
310 | "\n",
311 | " if len(bboxes) == 0:\n",
312 | " total_failures += 1\n",
313 | "\n",
314 | " continue\n",
315 | "\n",
316 | " times.append(run_time)\n",
317 | "\n",
318 | " pose_target[:3] = pose_para_degrees \n",
319 | " pose_pred[:3] = convert_to_aflw(pose_pred[:3])\n",
320 | "\n",
321 | " targets.append(pose_target)\n",
322 | " predictions.append(pose_pred)\n",
323 | "\n",
324 | "pose_mae = np.mean(abs(np.asarray(predictions) - np.asarray(targets)), axis=0)\n",
325 | "threed_pose = pose_mae[:3]\n",
326 | "trans_pose = pose_mae[3:]\n",
327 | "\n",
328 | "print(f\"Model failed on {total_failures} images\")\n",
329 | "print(f\"Yaw: {threed_pose[1]:.3f} Pitch: {threed_pose[0]:.3f} Roll: {threed_pose[2]:.3f} MAE: {np.mean(threed_pose):.3f}\")\n",
330 | "print(f\"H. Trans.: {trans_pose[0]:.3f} V. Trans.: {trans_pose[1]:.3f} Scale: {trans_pose[2]:.3f} MAE: {np.mean(trans_pose):.3f}\")\n",
331 | "print(f\"Average time {np.mean(np.asarray(times))}\")"
332 | ]
333 | }
334 | ],
335 | "metadata": {
336 | "bento_stylesheets": {
337 | "bento/extensions/flow/main.css": true,
338 | "bento/extensions/kernel_selector/main.css": true,
339 | "bento/extensions/kernel_ui/main.css": true,
340 | "bento/extensions/new_kernel/main.css": true,
341 | "bento/extensions/system_usage/main.css": true,
342 | "bento/extensions/theme/main.css": true
343 | },
344 | "disseminate_notebook_id": {
345 | "notebook_id": "336305087387084"
346 | },
347 | "disseminate_notebook_info": {
348 | "bento_version": "20200629-000305",
349 | "description": "Visualization of smaller model pose and expression qualitative results (trial 4).\nResNet-18 with sum of squared errors weighted equally for both pose and expression.\nFC layers for both pose and expression are fc1 512x512 and fc2 512 x output (output is either 6 or 72).\n",
350 | "hide_code": false,
351 | "hipster_group": "",
352 | "kernel_build_info": {
353 | "error": "The file located at '/data/users/valbiero/fbsource/fbcode/bento/kernels/local/deep_3d_face_modeling/TARGETS' could not be found."
354 | },
355 | "no_uii": true,
356 | "notebook_number": "296232",
357 | "others_can_edit": false,
358 | "reviewers": "",
359 | "revision_id": "1126967097689153",
360 | "tags": "",
361 | "tasks": "",
362 | "title": "Updated Model Pose and Expression Qualitative Results"
363 | },
364 | "kernelspec": {
365 | "display_name": "pytorch",
366 | "language": "python",
367 | "name": "pytorch"
368 | },
369 | "language_info": {
370 | "codemirror_mode": {
371 | "name": "ipython",
372 | "version": 3
373 | },
374 | "file_extension": ".py",
375 | "mimetype": "text/x-python",
376 | "name": "python",
377 | "nbconvert_exporter": "python",
378 | "pygments_lexer": "ipython3",
379 | "version": "3.8.1"
380 | },
381 | "notify_time": "30"
382 | },
383 | "nbformat": 4,
384 | "nbformat_minor": 2
385 | }
386 |
--------------------------------------------------------------------------------
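
The notebooks convert the predicted rotation vector to AFLW-style Euler angles with convert_to_aflw(): build the rotation matrix, transpose it, read off 'xyz' Euler angles in degrees, and negate yaw and roll. A standalone sketch of the same conversion; the example rotation vector is arbitrary:

import numpy as np
from scipy.spatial.transform import Rotation

rotvec = np.array([0.10, -0.30, 0.05])                        # arbitrary axis-angle rotation
rot_mat = Rotation.from_rotvec(rotvec).as_matrix().T
pitch, yaw, roll = Rotation.from_matrix(rot_mat).as_euler("xyz", degrees=True)
print(np.array([pitch, -yaw, -roll]))                         # [pitch, yaw, roll], AFLW convention
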
/evaluation/jupyter_notebooks/biwi_evaluation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Imports"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {
14 | "ExecuteTime": {
15 | "end_time": "2021-03-18T19:09:02.700914Z",
16 | "start_time": "2021-03-18T19:09:01.585915Z"
17 | },
18 | "scrolled": true
19 | },
20 | "outputs": [],
21 | "source": [
22 | "import sys\n",
23 | "sys.path.append('../../')\n",
24 | "import numpy as np\n",
25 | "import torch\n",
26 | "from torchvision import transforms\n",
27 | "from matplotlib import pyplot as plt\n",
28 | "from tqdm.notebook import tqdm\n",
29 | "from PIL import Image, ImageOps\n",
30 | "from scipy.spatial.transform import Rotation\n",
31 | "import pandas as pd\n",
32 | "from scipy.spatial import distance\n",
33 | "import time\n",
34 | "import os\n",
35 | "import math\n",
36 | "import scipy.io as sio\n",
37 | "from utils.renderer import Renderer\n",
38 | "from img2pose import img2poseModel\n",
39 | "from model_loader import load_model\n",
40 | "\n",
41 | "np.set_printoptions(suppress=True)"
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {},
47 | "source": [
48 | "## Declare useful functions"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 2,
54 | "metadata": {
55 | "ExecuteTime": {
56 | "end_time": "2021-03-18T19:09:02.712136Z",
57 | "start_time": "2021-03-18T19:09:02.702535Z"
58 | }
59 | },
60 | "outputs": [],
61 | "source": [
62 | "def bb_intersection_over_union(boxA, boxB):\n",
63 | " xA = max(boxA[0], boxB[0])\n",
64 | " yA = max(boxA[1], boxB[1])\n",
65 | " xB = min(boxA[2], boxB[2])\n",
66 | " yB = min(boxA[3], boxB[3])\n",
67 | "\n",
68 | " interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)\n",
69 | " boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)\n",
70 | " boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)\n",
71 | " iou = interArea / float(boxAArea + boxBArea - interArea)\n",
72 | " return iou\n",
73 | "\n",
74 | "def render_plot(img, pose_pred):\n",
75 | " (w, h) = img.size\n",
76 | " image_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]])\n",
77 | "\n",
78 | " trans_vertices = renderer.transform_vertices(img, [pose_pred])\n",
79 | " img = renderer.render(img, trans_vertices, alpha=1) \n",
80 | "\n",
81 | " plt.figure(figsize=(16, 16)) \n",
82 | "\n",
83 | " plt.imshow(img) \n",
84 | " plt.show()\n",
85 | " \n",
86 | "def convert_to_aflw(rotvec, is_rotvec=True):\n",
87 | " if is_rotvec:\n",
88 | " rotvec = Rotation.from_rotvec(rotvec).as_matrix()\n",
89 | " rot_mat_2 = np.transpose(rotvec)\n",
90 | " angle = Rotation.from_matrix(rot_mat_2).as_euler('xyz', degrees=True)\n",
91 | " \n",
92 | " return np.array([angle[0], -angle[1], -angle[2]])"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "metadata": {},
98 | "source": [
99 | "## Load BIWI dataset annotations "
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 3,
105 | "metadata": {
106 | "ExecuteTime": {
107 | "end_time": "2021-03-18T19:09:09.822126Z",
108 | "start_time": "2021-03-18T19:09:02.715404Z"
109 | }
110 | },
111 | "outputs": [],
112 | "source": [
113 | "dataset_path = \"./BIWI_annotations.txt\"\n",
114 | "dataset = pd.read_csv(dataset_path, delimiter=\" \", header=None)\n",
115 | "dataset = np.asarray(dataset).squeeze()\n",
116 | "\n",
117 | "pose_targets = []\n",
118 | "test_dataset = []\n",
119 | "\n",
120 | "for sample in dataset:\n",
121 | " img_path = sample[0]\n",
122 | " \n",
123 | " annotations = open(img_path.replace(\"_rgb.png\", \"_pose.txt\"))\n",
124 | " lines = annotations.readlines()\n",
125 | " \n",
126 | " pose_target = []\n",
127 | " for i in range(3):\n",
128 | " lines[i] = str(lines[i].rstrip(\"\\n\")) \n",
129 | " pose_target.append(lines[i].split(\" \")[:3])\n",
130 | " \n",
131 | " pose_target = np.asarray(pose_target).astype(float) \n",
132 | " pose_target = convert_to_aflw(pose_target, False)\n",
133 | " pose_targets.append(pose_target)\n",
134 | " \n",
135 | " test_dataset.append(img_path)"
136 | ]
137 | },
138 | {
139 | "cell_type": "markdown",
140 | "metadata": {},
141 | "source": [
142 | "## Create the renderer for visualization"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": 4,
148 | "metadata": {
149 | "ExecuteTime": {
150 | "end_time": "2021-03-18T19:09:09.833558Z",
151 | "start_time": "2021-03-18T19:09:09.825145Z"
152 | }
153 | },
154 | "outputs": [],
155 | "source": [
156 | "renderer = Renderer(\n",
157 | " vertices_path=\"../../pose_references/vertices_trans.npy\", \n",
158 | " triangles_path=\"../../pose_references/triangles.npy\"\n",
159 | ")\n",
160 | "\n",
161 | "threed_points = np.load('../../pose_references/reference_3d_68_points_trans.npy')"
162 | ]
163 | },
164 | {
165 | "cell_type": "markdown",
166 | "metadata": {},
167 | "source": [
168 | "## Load model weights and pose mean and std deviation\n",
169 |     "To test other models, change MODEL_PATH along with the POSE_MEAN and POSE_STDDEV used for training"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": 5,
175 | "metadata": {
176 | "ExecuteTime": {
177 | "end_time": "2021-03-18T19:09:12.607511Z",
178 | "start_time": "2021-03-18T19:09:09.835246Z"
179 | },
180 | "scrolled": true
181 | },
182 | "outputs": [
183 | {
184 | "name": "stdout",
185 | "output_type": "stream",
186 | "text": [
187 | "Model will use 1 GPUs!\n"
188 | ]
189 | }
190 | ],
191 | "source": [
192 | "transform = transforms.Compose([transforms.ToTensor()])\n",
193 | "\n",
194 | "DEPTH = 18\n",
195 | "MAX_SIZE = 1400\n",
196 | "MIN_SIZE = 700\n",
197 | "\n",
198 | "POSE_MEAN = \"../../models/WIDER_train_pose_mean_v1.npy\"\n",
199 | "POSE_STDDEV = \"../../models/WIDER_train_pose_stddev_v1.npy\"\n",
200 | "MODEL_PATH = \"../../models/img2pose_v1_ft_300w_lp.pth\"\n",
201 | "pose_mean = np.load(POSE_MEAN)\n",
202 | "pose_stddev = np.load(POSE_STDDEV)\n",
203 | "\n",
204 | "img2pose_model = img2poseModel(\n",
205 | " DEPTH, MIN_SIZE, MAX_SIZE, \n",
206 | " pose_mean=pose_mean, pose_stddev=pose_stddev,\n",
207 | " threed_68_points=threed_points,\n",
208 | " rpn_pre_nms_top_n_test=500,\n",
209 | " rpn_post_nms_top_n_test=10,\n",
210 | ")\n",
211 | "load_model(img2pose_model.fpn_model, MODEL_PATH, cpu_mode=str(img2pose_model.device) == \"cpu\", model_only=True)\n",
212 | "img2pose_model.evaluate()"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "metadata": {},
218 | "source": [
219 | "## Run BIWI evaluation\n",
220 |     "To visualize the predictions, change visualize to True and change total_imgs to the number of images desired."
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 7,
226 | "metadata": {
227 | "ExecuteTime": {
228 | "end_time": "2021-03-18T19:20:56.102388Z",
229 | "start_time": "2021-03-18T19:09:16.285597Z"
230 | },
231 | "scrolled": false
232 | },
233 | "outputs": [
234 | {
235 | "data": {
236 | "application/vnd.jupyter.widget-view+json": {
237 | "model_id": "097fef45c4534f7b818555e66dfad113",
238 | "version_major": 2,
239 | "version_minor": 0
240 | },
241 | "text/plain": [
242 | "HBox(children=(FloatProgress(value=0.0, max=13219.0), HTML(value='')))"
243 | ]
244 | },
245 | "metadata": {},
246 | "output_type": "display_data"
247 | },
248 | {
249 | "name": "stdout",
250 | "output_type": "stream",
251 | "text": [
252 | "\n",
253 | "Model failed on 0 images\n",
254 | "Yaw: 4.567 Pitch: 3.546 Roll: 3.244 MAE: 3.786\n",
255 | "Average time 0.034422722154255576\n"
256 | ]
257 | }
258 | ],
259 | "source": [
260 | "visualize = False\n",
261 | "total_imgs = len(test_dataset)\n",
262 | "threshold = 0.9\n",
263 | "\n",
264 | "predictions = []\n",
265 | "targets = []\n",
266 | "\n",
267 | "total_failures = 0\n",
268 | "times = []\n",
269 | "\n",
270 | "for j in tqdm(range(total_imgs)):\n",
271 | " img = Image.open(test_dataset[j]).convert(\"RGB\")\n",
272 | " (w, h) = img.size\n",
273 | " pose_target = pose_targets[j]\n",
274 | " ori_img = img.copy()\n",
275 | " \n",
276 | " time1 = time.time()\n",
277 | " res = img2pose_model.predict([transform(img)])\n",
278 | " time2 = time.time()\n",
279 | " times.append(time2 - time1)\n",
280 | "\n",
281 | " res = res[0]\n",
282 | "\n",
283 | " bboxes = res[\"boxes\"].cpu().numpy().astype('float')\n",
284 | "\n",
285 | " min_dist_center = float(\"Inf\")\n",
286 | " best_index = 0\n",
287 | "\n",
288 | " if len(bboxes) == 0:\n",
289 | " total_failures += 1 \n",
290 | " continue\n",
291 | "\n",
292 | " for i in range(len(bboxes)):\n",
293 | " if res[\"scores\"][i] > threshold:\n",
294 | " bbox = bboxes[i]\n",
295 | " bbox_center_x = bbox[0] + ((bbox[2] - bbox[0]) // 2)\n",
296 | " bbox_center_y = bbox[1] + ((bbox[3] - bbox[1]) // 2)\n",
297 | "\n",
298 | " dist_center = distance.euclidean([bbox_center_x, bbox_center_y], [w // 2, h // 2])\n",
299 | "\n",
300 | " if dist_center < min_dist_center:\n",
301 | " min_dist_center = dist_center\n",
302 | " best_index = i \n",
303 | "\n",
304 | " bbox = bboxes[best_index]\n",
305 | " pose_pred = res[\"dofs\"].cpu().numpy()[best_index].astype('float')\n",
306 | " pose_pred = np.asarray(pose_pred.squeeze())\n",
307 | "\n",
308 | " if best_index >= 0:\n",
309 | " bbox = bboxes[best_index]\n",
310 | " pose_pred = res[\"dofs\"].cpu().numpy()[best_index].astype('float')\n",
311 | " pose_pred = np.asarray(pose_pred.squeeze())\n",
312 | " \n",
313 | "\n",
314 | " if visualize and best_index >= 0: \n",
315 | " render_plot(ori_img.copy(), pose_pred)\n",
316 | "\n",
317 | " if len(bboxes) == 0:\n",
318 | " total_failures += 1\n",
319 | "\n",
320 | " continue\n",
321 | " \n",
322 | " pose_pred = convert_to_aflw(pose_pred[:3])\n",
323 | " \n",
324 | " predictions.append(pose_pred[:3])\n",
325 | " targets.append(pose_target[:3])\n",
326 | "\n",
327 | "pose_mae = np.mean(abs(np.asarray(predictions) - np.asarray(targets)), axis=0)\n",
328 | "threed_pose = pose_mae[:3]\n",
329 | "\n",
330 | "print(f\"Model failed on {total_failures} images\")\n",
331 | "print(f\"Yaw: {threed_pose[1]:.3f} Pitch: {threed_pose[0]:.3f} Roll: {threed_pose[2]:.3f} MAE: {np.mean(threed_pose):.3f}\")\n",
332 | "print(f\"Average time {np.mean(np.asarray(times))}\")"
333 | ]
334 | }
335 | ],
336 | "metadata": {
337 | "bento_stylesheets": {
338 | "bento/extensions/flow/main.css": true,
339 | "bento/extensions/kernel_selector/main.css": true,
340 | "bento/extensions/kernel_ui/main.css": true,
341 | "bento/extensions/new_kernel/main.css": true,
342 | "bento/extensions/system_usage/main.css": true,
343 | "bento/extensions/theme/main.css": true
344 | },
345 | "disseminate_notebook_id": {
346 | "notebook_id": "336305087387084"
347 | },
348 | "disseminate_notebook_info": {
349 | "bento_version": "20200629-000305",
350 | "description": "Visualization of smaller model pose and expression qualitative results (trial 4).\nResNet-18 with sum of squared errors weighted equally for both pose and expression.\nFC layers for both pose and expression are fc1 512x512 and fc2 512 x output (output is either 6 or 72).\n",
351 | "hide_code": false,
352 | "hipster_group": "",
353 | "kernel_build_info": {
354 | "error": "The file located at '/data/users/valbiero/fbsource/fbcode/bento/kernels/local/deep_3d_face_modeling/TARGETS' could not be found."
355 | },
356 | "no_uii": true,
357 | "notebook_number": "296232",
358 | "others_can_edit": false,
359 | "reviewers": "",
360 | "revision_id": "1126967097689153",
361 | "tags": "",
362 | "tasks": "",
363 | "title": "Updated Model Pose and Expression Qualitative Results"
364 | },
365 | "kernelspec": {
366 | "display_name": "pytorch",
367 | "language": "python",
368 | "name": "pytorch"
369 | },
370 | "language_info": {
371 | "codemirror_mode": {
372 | "name": "ipython",
373 | "version": 3
374 | },
375 | "file_extension": ".py",
376 | "mimetype": "text/x-python",
377 | "name": "python",
378 | "nbconvert_exporter": "python",
379 | "pygments_lexer": "ipython3",
380 | "version": "3.8.1"
381 | },
382 | "notify_time": "30"
383 | },
384 | "nbformat": 4,
385 | "nbformat_minor": 2
386 | }
387 |
--------------------------------------------------------------------------------
/evaluation/jupyter_notebooks/test_own_images.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Imports"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {
14 | "ExecuteTime": {
15 | "end_time": "2020-12-17T00:06:41.333846Z",
16 | "start_time": "2020-12-17T00:06:40.238883Z"
17 | },
18 | "scrolled": true
19 | },
20 | "outputs": [],
21 | "source": [
22 | "import sys\n",
23 | "sys.path.append('../../')\n",
24 | "import numpy as np\n",
25 | "import torch\n",
26 | "from torchvision import transforms\n",
27 | "from matplotlib import pyplot as plt\n",
28 | "from tqdm.notebook import tqdm\n",
29 | "from PIL import Image, ImageOps\n",
30 | "import matplotlib.patches as patches\n",
31 | "from scipy.spatial.transform import Rotation\n",
32 | "import pandas as pd\n",
33 | "from scipy.spatial import distance\n",
34 | "import time\n",
35 | "import os\n",
36 | "import math\n",
37 | "import scipy.io as sio\n",
38 | "from utils.renderer import Renderer\n",
39 | "from utils.image_operations import expand_bbox_rectangle\n",
40 | "from utils.pose_operations import get_pose\n",
41 | "from img2pose import img2poseModel\n",
42 | "from model_loader import load_model\n",
43 | "\n",
44 | "np.set_printoptions(suppress=True)"
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "metadata": {},
50 | "source": [
51 | "## Declare useful functions"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {
58 | "ExecuteTime": {
59 | "end_time": "2020-12-17T00:06:41.340777Z",
60 | "start_time": "2020-12-17T00:06:41.335946Z"
61 | }
62 | },
63 | "outputs": [],
64 | "source": [
65 | "def render_plot(img, poses, bboxes):\n",
66 | " (w, h) = img.size\n",
67 | " image_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]])\n",
68 | " \n",
69 | " trans_vertices = renderer.transform_vertices(img, poses)\n",
70 | " img = renderer.render(img, trans_vertices, alpha=1) \n",
71 | "\n",
72 | " plt.figure(figsize=(8, 8)) \n",
73 | " \n",
74 | " for bbox in bboxes:\n",
75 | " plt.gca().add_patch(patches.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],linewidth=3,edgecolor='b',facecolor='none')) \n",
76 | " \n",
77 | " plt.imshow(img) \n",
78 | " plt.show()"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | "## Create the renderer for visualization"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "metadata": {
92 | "ExecuteTime": {
93 | "end_time": "2020-12-17T00:06:41.357319Z",
94 | "start_time": "2020-12-17T00:06:41.342763Z"
95 | }
96 | },
97 | "outputs": [],
98 | "source": [
99 | "renderer = Renderer(\n",
100 | " vertices_path=\"../../pose_references/vertices_trans.npy\", \n",
101 | " triangles_path=\"../../pose_references/triangles.npy\"\n",
102 | ")\n",
103 | "\n",
104 | "threed_points = np.load('../../pose_references/reference_3d_68_points_trans.npy')"
105 | ]
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "metadata": {},
110 | "source": [
111 | "## Load model weights and pose mean and std deviation\n",
112 |     "To test other models, change MODEL_PATH along with the POSE_MEAN and POSE_STDDEV used for training"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {
119 | "ExecuteTime": {
120 | "end_time": "2020-12-17T00:06:44.297174Z",
121 | "start_time": "2020-12-17T00:06:41.359153Z"
122 | },
123 | "scrolled": true
124 | },
125 | "outputs": [],
126 | "source": [
127 | "transform = transforms.Compose([transforms.ToTensor()])\n",
128 | "\n",
129 | "DEPTH = 18\n",
130 | "MAX_SIZE = 1400\n",
131 | "MIN_SIZE = 600\n",
132 | "\n",
133 | "POSE_MEAN = \"../../models/WIDER_train_pose_mean_v1.npy\"\n",
134 | "POSE_STDDEV = \"../../models/WIDER_train_pose_stddev_v1.npy\"\n",
135 | "MODEL_PATH = \"../../models/img2pose_v1.pth\"\n",
136 | "\n",
137 | "pose_mean = np.load(POSE_MEAN)\n",
138 | "pose_stddev = np.load(POSE_STDDEV)\n",
139 | "\n",
140 | "img2pose_model = img2poseModel(\n",
141 | " DEPTH, MIN_SIZE, MAX_SIZE, \n",
142 | " pose_mean=pose_mean, pose_stddev=pose_stddev,\n",
143 | " threed_68_points=threed_points,\n",
144 | ")\n",
145 | "load_model(img2pose_model.fpn_model, MODEL_PATH, cpu_mode=str(img2pose_model.device) == \"cpu\", model_only=True)\n",
146 | "img2pose_model.evaluate()"
147 | ]
148 | },
149 | {
150 | "cell_type": "markdown",
151 | "metadata": {},
152 | "source": [
153 | "## Run on a folder or an image list\n",
154 |     "Give it a list with image paths, or a folder containing images"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {
161 | "ExecuteTime": {
162 | "end_time": "2020-12-17T00:06:54.309068Z",
163 | "start_time": "2020-12-17T00:06:51.244841Z"
164 | },
165 | "scrolled": false
166 | },
167 | "outputs": [],
168 | "source": [
169 | "# change to a folder with images, or another list containing image paths\n",
170 | "images_path = \"your_own_image_folder/list\"\n",
171 | "\n",
172 | "threshold = 0.9\n",
173 | "\n",
174 | "if os.path.isfile(images_path):\n",
175 | " img_paths = pd.read_csv(images_path, delimiter=\" \", header=None)\n",
176 | " img_paths = np.asarray(img_paths).squeeze()\n",
177 | "else:\n",
178 | " img_paths = [os.path.join(images_path, img_path) for img_path in os.listdir(images_path)]\n",
179 | "\n",
180 | "for img_path in tqdm(img_paths):\n",
181 | " img = Image.open(img_path).convert(\"RGB\")\n",
182 | " \n",
183 | " image_name = os.path.split(img_path)[1]\n",
184 | " \n",
185 | " (w, h) = img.size\n",
186 | " image_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]])\n",
187 | " \n",
188 | " res = img2pose_model.predict([transform(img)])[0]\n",
189 | "\n",
190 | " all_bboxes = res[\"boxes\"].cpu().numpy().astype('float')\n",
191 | "\n",
192 | " poses = []\n",
193 | " bboxes = []\n",
194 | " for i in range(len(all_bboxes)):\n",
195 | " if res[\"scores\"][i] > threshold:\n",
196 | " bbox = all_bboxes[i]\n",
197 | " pose_pred = res[\"dofs\"].cpu().numpy()[i].astype('float')\n",
198 | " pose_pred = pose_pred.squeeze()\n",
199 | "\n",
200 | " poses.append(pose_pred) \n",
201 | " bboxes.append(bbox)\n",
202 | "\n",
203 | " render_plot(img.copy(), poses, bboxes)"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "metadata": {},
210 | "outputs": [],
211 | "source": []
212 | }
213 | ],
214 | "metadata": {
215 | "bento_stylesheets": {
216 | "bento/extensions/flow/main.css": true,
217 | "bento/extensions/kernel_selector/main.css": true,
218 | "bento/extensions/kernel_ui/main.css": true,
219 | "bento/extensions/new_kernel/main.css": true,
220 | "bento/extensions/system_usage/main.css": true,
221 | "bento/extensions/theme/main.css": true
222 | },
223 | "disseminate_notebook_id": {
224 | "notebook_id": "336305087387084"
225 | },
226 | "disseminate_notebook_info": {
227 | "bento_version": "20200629-000305",
228 | "description": "Visualization of smaller model pose and expression qualitative results (trial 4).\nResNet-18 with sum of squared errors weighted equally for both pose and expression.\nFC layers for both pose and expression are fc1 512x512 and fc2 512 x output (output is either 6 or 72).\n",
229 | "hide_code": false,
230 | "hipster_group": "",
231 | "kernel_build_info": {
232 | "error": "The file located at '/data/users/valbiero/fbsource/fbcode/bento/kernels/local/deep_3d_face_modeling/TARGETS' could not be found."
233 | },
234 | "no_uii": true,
235 | "notebook_number": "296232",
236 | "others_can_edit": false,
237 | "reviewers": "",
238 | "revision_id": "1126967097689153",
239 | "tags": "",
240 | "tasks": "",
241 | "title": "Updated Model Pose and Expression Qualitative Results"
242 | },
243 | "kernelspec": {
244 | "display_name": "test",
245 | "language": "python",
246 | "name": "test"
247 | },
248 | "language_info": {
249 | "codemirror_mode": {
250 | "name": "ipython",
251 | "version": 3
252 | },
253 | "file_extension": ".py",
254 | "mimetype": "text/x-python",
255 | "name": "python",
256 | "nbconvert_exporter": "python",
257 | "pygments_lexer": "ipython3",
258 | "version": "3.9.1"
259 | },
260 | "notify_time": "30"
261 | },
262 | "nbformat": 4,
263 | "nbformat_minor": 2
264 | }
265 |
--------------------------------------------------------------------------------
/generalized_rcnn.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from collections import OrderedDict
3 |
4 | import torch
5 | from torch import Tensor, nn
6 | from torch.jit.annotations import Dict, List, Optional, Tuple
7 |
8 |
9 | class GeneralizedRCNN(nn.Module):
10 | """
11 | Main class for Generalized R-CNN.
12 |
13 | Arguments:
14 | backbone (nn.Module):
15 | rpn (nn.Module):
16 | roi_heads (nn.Module): takes the features + the proposals from the RPN
17 | and computes detections / masks from it.
18 | transform (nn.Module): performs the data transformation from the inputs
19 | to feed into the model
20 | """
21 |
22 | def __init__(self, backbone, rpn, roi_heads, transform):
23 | super(GeneralizedRCNN, self).__init__()
24 | self.transform = transform
25 | self.backbone = backbone
26 | self.rpn = rpn
27 | self.roi_heads = roi_heads
28 |         # used only in torchscript mode
29 | self._has_warned = False
30 |
31 | @torch.jit.unused
32 | def eager_outputs(self, losses, detections, evaluating):
33 | # type: (Dict[str, Tensor], List[Dict[str, Tensor]])
34 | # -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
35 | if evaluating:
36 | return losses
37 |
38 | return detections
39 |
40 | def forward(self, images, targets=None):
41 | # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
42 | # -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
43 | """
44 | Arguments:
45 | images (list[Tensor]): images to be processed
46 | targets (list[Dict[Tensor]]): ground-truth (optional)
47 |
48 | Returns:
49 | result (list[BoxList] or dict[Tensor]): the output from the model.
50 | During training, it returns a dict[Tensor] which contains the losses.
51 |                 During testing, it returns a list[BoxList] that contains additional fields
52 | like `scores`, `labels` and `mask` (for Mask R-CNN models).
53 |
54 | """
55 | if self.training and targets is None:
56 | raise ValueError("In training mode, targets should be passed")
57 | if self.training or targets is not None:
58 | assert targets is not None
59 | for target in targets:
60 | boxes = target["boxes"]
61 | if isinstance(boxes, torch.Tensor):
62 | if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
63 | raise ValueError(
64 | "Expected target boxes to be a tensor"
65 | "of shape [N, 4], got {:}.".format(boxes.shape)
66 | )
67 | else:
68 | raise ValueError(
69 | "Expected target boxes to be of type "
70 | "Tensor, got {:}.".format(type(boxes))
71 | )
72 |
73 | original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
74 | for img in images:
75 | val = img.shape[-2:]
76 | assert len(val) == 2
77 | original_image_sizes.append((val[0], val[1]))
78 |
79 | images, targets = self.transform(images, targets)
80 |
81 | # Check for degenerate boxes
82 | # TODO: Move this to a function
83 | if targets is not None:
84 | for target_idx, target in enumerate(targets):
85 | boxes = target["boxes"]
86 | degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
87 | if degenerate_boxes.any():
88 |                     # print the first degenerate box
89 | bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0]
90 | degen_bb: List[float] = boxes[bb_idx].tolist()
91 | raise ValueError(
92 | "All bounding boxes should have positive height and width."
93 |                         " Found invalid box {} for target at index {}.".format(
94 | degen_bb, target_idx
95 | )
96 | )
97 |
98 | features = self.backbone(images.tensors)
99 | if isinstance(features, torch.Tensor):
100 | features = OrderedDict([("0", features)])
101 | proposals, proposal_losses = self.rpn(images, features, targets)
102 | detections, detector_losses = self.roi_heads(
103 | features, proposals, images.image_sizes, targets
104 | )
105 | detections = self.transform.postprocess(
106 | detections, images.image_sizes, original_image_sizes
107 | )
108 |
109 | losses = {}
110 | losses.update(detector_losses)
111 | losses.update(proposal_losses)
112 |
113 | if torch.jit.is_scripting():
114 | if not self._has_warned:
115 | warnings.warn(
116 | "RCNN always returns a (Losses, Detections) tuple in scripting"
117 | )
118 | self._has_warned = True
119 | return (losses, detections)
120 | else:
121 | return self.eager_outputs(losses, detections, targets is not None)
122 |
--------------------------------------------------------------------------------
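
forward() above validates the targets before running the backbone: any box whose x2/y2 is not strictly greater than x1/y1 is treated as degenerate and aborts the call. A standalone sketch of that check, with dummy boxes:

import torch

boxes = torch.tensor([[10.0, 10.0, 50.0, 60.0],   # valid box
                      [30.0, 40.0, 30.0, 80.0]])  # zero width -> degenerate
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
    bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0]
    print("degenerate box:", boxes[bb_idx].tolist())
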
/img2pose.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import DataParallel, Module
3 | from torch.nn.parallel import DistributedDataParallel
4 | from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
5 |
6 | from model_loader import load_model
7 | from models import FasterDoFRCNN
8 |
9 |
10 | class WrappedModel(Module):
11 | def __init__(self, module):
12 | super(WrappedModel, self).__init__()
13 | self.module = module
14 |
15 | def forward(self, images, targets=None):
16 | return self.module(images, targets)
17 |
18 |
19 | class img2poseModel:
20 | def __init__(
21 | self,
22 | depth,
23 | min_size,
24 | max_size,
25 | model_path=None,
26 | device=None,
27 | pose_mean=None,
28 | pose_stddev=None,
29 | distributed=False,
30 | gpu=0,
31 | threed_68_points=None,
32 | threed_5_points=None,
33 | rpn_pre_nms_top_n_test=6000,
34 | rpn_post_nms_top_n_test=1000,
35 | bbox_x_factor=1.1,
36 | bbox_y_factor=1.1,
37 | expand_forehead=0.3,
38 | ):
39 | self.depth = depth
40 | self.min_size = min_size
41 | self.max_size = max_size
42 | self.model_path = model_path
43 | self.distributed = distributed
44 | self.gpu = gpu
45 |
46 | if device is None:
47 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
48 | else:
49 | self.device = device
50 |
51 | # create network backbone
52 | backbone = resnet_fpn_backbone(f"resnet{self.depth}", pretrained=True)
53 |
54 | if pose_mean is not None:
55 | pose_mean = torch.tensor(pose_mean)
56 | pose_stddev = torch.tensor(pose_stddev)
57 |
58 | if threed_68_points is not None:
59 | threed_68_points = torch.tensor(threed_68_points)
60 |
61 | if threed_5_points is not None:
62 | threed_5_points = torch.tensor(threed_5_points)
63 |
64 |         # create the detection model (FasterDoFRCNN) on top of the FPN backbone
65 | self.fpn_model = FasterDoFRCNN(
66 | backbone,
67 | 2,
68 | min_size=self.min_size,
69 | max_size=self.max_size,
70 | pose_mean=pose_mean,
71 | pose_stddev=pose_stddev,
72 | threed_68_points=threed_68_points,
73 | threed_5_points=threed_5_points,
74 | rpn_pre_nms_top_n_test=rpn_pre_nms_top_n_test,
75 | rpn_post_nms_top_n_test=rpn_post_nms_top_n_test,
76 | bbox_x_factor=bbox_x_factor,
77 | bbox_y_factor=bbox_y_factor,
78 | expand_forehead=expand_forehead,
79 | )
80 |
81 |         # keep a reference to the model without DataParallel/DDP wrappers, used for saving/loading state dicts
82 | self.fpn_model_without_ddp = self.fpn_model
83 |
84 | if self.distributed:
85 | self.fpn_model = self.fpn_model.to(self.device)
86 | self.fpn_model = DistributedDataParallel(
87 | self.fpn_model, device_ids=[self.gpu]
88 | )
89 | self.fpn_model_without_ddp = self.fpn_model.module
90 |
91 | print("Model will use distributed mode!")
92 |
93 | elif str(self.device) == "cpu":
94 | self.fpn_model = WrappedModel(self.fpn_model)
95 | self.fpn_model_without_ddp = self.fpn_model
96 |
97 | print("Model will run on CPU!")
98 |
99 | else:
100 | self.fpn_model = DataParallel(self.fpn_model)
101 | self.fpn_model = self.fpn_model.to(self.device)
102 | self.fpn_model_without_ddp = self.fpn_model
103 |
104 | print(f"Model will use {torch.cuda.device_count()} GPUs!")
105 |
106 | if self.model_path is not None:
107 | self.load_saved_model(self.model_path)
108 | self.evaluate()
109 |
110 | def load_saved_model(self, model_path):
111 | load_model(
112 | self.fpn_model_without_ddp, model_path, cpu_mode=str(self.device) == "cpu"
113 | )
114 |
115 | def evaluate(self):
116 | self.fpn_model.eval()
117 |
118 | def train(self):
119 | self.fpn_model.train()
120 |
121 | def run_model(self, imgs, targets=None):
122 | outputs = self.fpn_model(imgs, targets)
123 |
124 | return outputs
125 |
126 | def forward(self, imgs, targets):
127 | losses = self.run_model(imgs, targets)
128 |
129 | return losses
130 |
131 | def predict(self, imgs):
132 | assert self.fpn_model.training is False
133 |
134 | with torch.no_grad():
135 | predictions = self.run_model(imgs)
136 |
137 | return predictions
138 |
--------------------------------------------------------------------------------
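
A minimal end-to-end inference sketch with img2poseModel, mirroring the evaluation notebooks above; the weight, statistics, and image paths below are assumptions and should point at your local copies:

import numpy as np
from PIL import Image
from torchvision import transforms

from img2pose import img2poseModel
from model_loader import load_model

threed_points = np.load("pose_references/reference_3d_68_points_trans.npy")
pose_mean = np.load("models/WIDER_train_pose_mean_v1.npy")      # assumed download location
pose_stddev = np.load("models/WIDER_train_pose_stddev_v1.npy")  # assumed download location

model = img2poseModel(
    18, 600, 1400,                                              # depth, min_size, max_size
    pose_mean=pose_mean, pose_stddev=pose_stddev,
    threed_68_points=threed_points,
)
load_model(model.fpn_model, "models/img2pose_v1.pth",           # assumed checkpoint path
           cpu_mode=str(model.device) == "cpu", model_only=True)
model.evaluate()

img = Image.open("example.jpg").convert("RGB")                  # any test image
res = model.predict([transforms.ToTensor()(img)])[0]
print(res["boxes"].shape, res["scores"].shape, res["dofs"].shape)
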
/losses.py:
--------------------------------------------------------------------------------
1 | from itertools import chain, repeat
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from utils.pose_operations import plot_3d_landmark_torch, pose_full_image_to_bbox
7 |
8 |
9 | def fastrcnn_loss(
10 | class_logits,
11 | class_labels,
12 | dof_regression,
13 | labels,
14 | dof_regression_targets,
15 | proposals,
16 | image_shapes,
17 | pose_mean=None,
18 | pose_stddev=None,
19 | threed_points=None,
20 | ):
21 | # # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
22 | """
23 | Computes the loss for Faster R-CNN.
24 |
25 | Arguments:
26 | class_logits (Tensor)
27 | dof_regression (Tensor)
28 |         labels (list[Tensor])
29 |         dof_regression_targets (Tensor)
30 |
31 | Returns:
32 | classification_loss (Tensor)
33 | dof_loss (Tensor)
34 | points_loss (Tensor)
35 | """
36 | img_size = [
37 | (boxes_in_image.shape[0], image_shapes[i])
38 | for i, boxes_in_image in enumerate(proposals)
39 | ]
40 | img_size = list(chain.from_iterable(repeat(j, i) for i, j in img_size))
41 |
42 | labels = torch.cat(labels, dim=0)
43 | class_labels = torch.cat(class_labels, dim=0)
44 | dof_regression_targets = torch.cat(dof_regression_targets, dim=0)
45 | proposals = torch.cat(proposals, dim=0)
46 | classification_loss = F.cross_entropy(class_logits, class_labels)
47 |
48 | # get indices that correspond to the regression targets for
49 | # the corresponding ground truth labels, to be used with
50 | # advanced indexing
51 | sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
52 | labels_pos = labels[sampled_pos_inds_subset]
53 | N = dof_regression.shape[0]
54 | dof_regression = dof_regression.reshape(N, -1, 6)
55 | dof_regression = dof_regression[sampled_pos_inds_subset, labels_pos]
56 | prop_regression = proposals[sampled_pos_inds_subset]
57 |
58 | dof_regression_targets = dof_regression_targets[sampled_pos_inds_subset]
59 |
60 | all_target_calibration_points = None
61 | all_pred_calibration_points = None
62 |
63 | for i in range(prop_regression.shape[0]):
64 | (h, w) = img_size[i]
65 | global_intrinsics = torch.Tensor(
66 | [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]]
67 | ).to(proposals[0].device)
68 |
69 | threed_points = threed_points.to(proposals[0].device)
70 |
71 | h = prop_regression[i, 3] - prop_regression[i, 1]
72 | w = prop_regression[i, 2] - prop_regression[i, 0]
73 | local_intrinsics = torch.Tensor(
74 | [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]]
75 | ).to(proposals[0].device)
76 |
77 | # calibration points projection
78 | local_dof_regression = (
79 | dof_regression[i, :] * pose_stddev.to(proposals[0].device)
80 | ) + pose_mean.to(proposals[0].device)
81 |
82 | pred_calibration_points = plot_3d_landmark_torch(
83 | threed_points, local_dof_regression.float(), local_intrinsics
84 | ).unsqueeze(0)
85 |
86 | # pose conversion for the pose loss
87 | dof_regression_targets[i, :] = torch.from_numpy(
88 | pose_full_image_to_bbox(
89 | dof_regression_targets[i, :].cpu().numpy(),
90 | global_intrinsics.cpu().numpy(),
91 | prop_regression[i, :].cpu().numpy(),
92 | )
93 | ).to(proposals[0].device)
94 |
95 | # target calibration points projection
96 | target_calibration_points = plot_3d_landmark_torch(
97 | threed_points, dof_regression_targets[i, :], local_intrinsics
98 | ).unsqueeze(0)
99 |
100 | if all_target_calibration_points is None:
101 | all_target_calibration_points = target_calibration_points
102 | else:
103 | all_target_calibration_points = torch.cat(
104 | (all_target_calibration_points, target_calibration_points)
105 | )
106 | if all_pred_calibration_points is None:
107 | all_pred_calibration_points = pred_calibration_points
108 | else:
109 | all_pred_calibration_points = torch.cat(
110 | (all_pred_calibration_points, pred_calibration_points)
111 | )
112 |
113 | if pose_mean is not None:
114 | dof_regression_targets[i, :] = (
115 | dof_regression_targets[i, :] - pose_mean.to(proposals[0].device)
116 | ) / pose_stddev.to(proposals[0].device)
117 |
118 | points_loss = F.l1_loss(all_target_calibration_points, all_pred_calibration_points)
119 |
120 | dof_loss = (
121 | F.mse_loss(
122 | dof_regression,
123 | dof_regression_targets,
124 | reduction="sum",
125 | )
126 | / dof_regression.shape[0]
127 | )
128 |
129 | return classification_loss, dof_loss, points_loss
130 |
--------------------------------------------------------------------------------
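Note on losses.py above: the loss works in a normalized pose space, de-normalizing predictions with pose_stddev/pose_mean before projecting calibration points and normalizing the bbox-space targets before the MSE term. A small sketch of just that transform pair, with fabricated numbers:

# Sketch of the (de)normalization used inside fastrcnn_loss; all values are fabricated.
import torch
import torch.nn.functional as F

pose_mean = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 1.5])
pose_stddev = torch.tensor([0.3, 0.3, 0.3, 0.2, 0.2, 0.5])

pred_normalized = torch.tensor([0.1, -0.2, 0.0, 0.05, 0.02, 0.4])  # network output
pred_dof = pred_normalized * pose_stddev + pose_mean                # used for point projection

target_dof = torch.tensor([0.05, -0.1, 0.0, 0.01, 0.0, 1.7])        # pose in bbox space
target_normalized = (target_dof - pose_mean) / pose_stddev          # used for the MSE term

N = 1  # number of positive proposals in this toy case
dof_loss = F.mse_loss(
    pred_normalized.unsqueeze(0), target_normalized.unsqueeze(0), reduction="sum"
) / N
print(pred_dof, dof_loss)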
/model_loader.py:
--------------------------------------------------------------------------------
1 | from os import path
2 |
3 | import torch
4 |
5 | try:
6 | from utils.dist import is_main_process
7 | except Exception as e:
8 | print(e)
9 |
10 |
11 | def save_model(fpn_model, optimizer, config, val_loss=0, step=0, model_only=False):
12 | if is_main_process():
13 | save_path = config.model_path
14 |
15 | if model_only:
16 | torch.save(
17 | {"fpn_model": fpn_model.state_dict()},
18 | path.join(save_path, f"model_val_loss_{val_loss:.4f}_step_{step}.pth"),
19 | )
20 | else:
21 | torch.save(
22 | {
23 | "fpn_model": fpn_model.state_dict(),
24 | "optimizer": optimizer.state_dict(),
25 | },
26 | path.join(save_path, f"model_val_loss_{val_loss:.4f}_step_{step}.pth"),
27 | )
28 |
29 |
30 | def load_model(fpn_model, model_path, model_only=True, optimizer=None, cpu_mode=False):
31 | if cpu_mode:
32 | checkpoint = torch.load(model_path, map_location="cpu")
33 | else:
34 | checkpoint = torch.load(model_path)
35 |
36 | fpn_model.load_state_dict(checkpoint["fpn_model"])
37 |
38 | if not model_only:
39 | optimizer.load_state_dict(checkpoint["optimizer"])
40 |
--------------------------------------------------------------------------------
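Note on model_loader.py above: a hypothetical checkpoint round trip, assuming the repository's utils package is importable and that config.model_path points at an existing directory; the config class and the linear layer standing in for fpn_model are placeholders.

# Hypothetical save/load round trip; config class, paths, and dummy model are placeholders.
import torch
from torch import optim

from model_loader import load_model, save_model


class DummyConfig:
    model_path = "./workspace/models"  # directory that must already exist


model = torch.nn.Linear(10, 6)  # stands in for fpn_model
optimizer = optim.SGD(model.parameters(), lr=0.001)

save_model(model, optimizer, DummyConfig(), val_loss=0.1234, step=100)
load_model(
    model,
    "./workspace/models/model_val_loss_0.1234_step_100.pth",
    model_only=False,
    optimizer=optimizer,
)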
/pose_references/reference_3d_5_points_trans.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vitoralbiero/img2pose/fd5473efb83d78c530afc7db12b7c2aa631ea5cb/pose_references/reference_3d_5_points_trans.npy
--------------------------------------------------------------------------------
/pose_references/reference_3d_68_points_trans.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vitoralbiero/img2pose/fd5473efb83d78c530afc7db12b7c2aa631ea5cb/pose_references/reference_3d_68_points_trans.npy
--------------------------------------------------------------------------------
/pose_references/triangles.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vitoralbiero/img2pose/fd5473efb83d78c530afc7db12b7c2aa631ea5cb/pose_references/triangles.npy
--------------------------------------------------------------------------------
/pose_references/vertices_trans.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vitoralbiero/img2pose/fd5473efb83d78c530afc7db12b7c2aa631ea5cb/pose_references/vertices_trans.npy
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.11.0
2 | argon2-cffi==20.1.0
3 | async-generator==1.10
4 | attrs==20.3.0
5 | backcall==0.2.0
6 | bleach==3.2.1
7 | cachetools==4.2.0
8 | certifi==2020.12.5
9 | cffi==1.14.4
10 | chardet==3.0.4
11 | cycler==0.10.0
12 | Cython==0.29.21
13 | dataclasses==0.6
14 | decorator==4.4.2
15 | defusedxml==0.6.0
16 | easydict==1.9
17 | entrypoints==0.3
18 | google-auth==1.24.0
19 | google-auth-oauthlib==0.4.2
20 | grpcio==1.34.0
21 | idna==2.10
22 | IProgress==0.4
23 | ipykernel==5.4.2
24 | ipython==7.19.0
25 | ipython-genutils==0.2.0
26 | ipywidgets==7.5.1
27 | jedi==0.17.2
28 | Jinja2==2.11.2
29 | jsonschema==3.2.0
30 | jupyter-client==6.1.7
31 | jupyter-core==4.7.0
32 | jupyterlab-pygments==0.1.2
33 | kiwisolver==1.3.1
34 | lmdb==1.0.0
35 | Markdown==3.3.3
36 | MarkupSafe==1.1.1
37 | matplotlib==3.3.3
38 | mistune==0.8.4
39 | msgpack==1.0.1
40 | nbclient==0.5.1
41 | nbconvert==6.0.7
42 | nbformat==5.0.8
43 | nest-asyncio==1.4.3
44 | notebook==6.1.5
45 | numpy==1.19.4
46 | oauthlib==3.1.0
47 | opencv-python==4.4.0.46
48 | packaging==20.8
49 | pandas==1.1.5
50 | pandocfilters==1.4.3
51 | parso==0.7.1
52 | pexpect==4.8.0
53 | pickleshare==0.7.5
54 | Pillow==8.0.1
55 | prometheus-client==0.9.0
56 | prompt-toolkit==3.0.8
57 | protobuf==3.14.0
58 | ptyprocess==0.6.0
59 | pyasn1==0.4.8
60 | pyasn1-modules==0.2.8
61 | pycparser==2.20
62 | Pygments==2.7.3
63 | pyparsing==2.4.7
64 | pyrsistent==0.17.3
65 | python-dateutil==2.8.1
66 | pytz==2020.4
67 | pyzmq==20.0.0
68 | requests==2.25.0
69 | requests-oauthlib==1.3.0
70 | rsa==4.6
71 | scikit-image==0.18.0
72 | scipy==1.5.4
73 | Send2Trash==1.5.0
74 | six==1.15.0
75 | tensorboard==2.4.0
76 | tensorboard-plugin-wit==1.7.0
77 | terminado==0.9.1
78 | testpath==0.4.4
79 | torch==1.7.1
80 | torchvision==0.8.2
81 | tornado==6.1
82 | tqdm==4.54.1
83 | traitlets==5.0.5
84 | typing-extensions==3.7.4.3
85 | urllib3==1.26.2
86 | wcwidth==0.2.5
87 | webencodings==0.5.1
88 | Werkzeug==1.0.1
89 | widgetsnbextension==3.5.1
90 |
--------------------------------------------------------------------------------
/run_face_alignment.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import pandas as pd
6 | from PIL import Image
7 | from torchvision import transforms
8 | from tqdm import tqdm
9 |
10 | from img2pose import img2poseModel
11 | from model_loader import load_model
12 | from utils.pose_operations import align_faces
13 |
14 |
15 | class img2pose:
16 | def __init__(self, args):
17 | self.threed_5_points = np.load(args.threed_5_points)
18 | self.threed_68_points = np.load(args.threed_68_points)
19 | self.nms_threshold = args.nms_threshold
20 |
21 | self.pose_mean = np.load(args.pose_mean)
22 | self.pose_stddev = np.load(args.pose_stddev)
23 | self.model = self.create_model(args)
24 |
25 | self.transform = transforms.Compose([transforms.ToTensor()])
26 | self.min_size = (args.min_size,)
27 | self.max_size = args.max_size
28 |
29 | self.max_faces = args.max_faces
30 | self.face_size = args.face_size
31 | self.order_method = args.order_method
32 | self.det_threshold = args.det_threshold
33 |
34 | images_path = args.images_path
35 | if os.path.isfile(images_path):
36 | self.image_list = pd.read_csv(images_path, delimiter=" ", header=None)
37 | self.image_list = np.asarray(self.image_list).squeeze()
38 | else:
39 | self.image_list = [
40 | os.path.join(images_path, img_path)
41 | for img_path in os.listdir(images_path)
42 | ]
43 |
44 | self.output_path = args.output_path
45 |
46 | def create_model(self, args):
47 | img2pose_model = img2poseModel(
48 | args.depth,
49 | args.min_size,
50 | args.max_size,
51 | pose_mean=self.pose_mean,
52 | pose_stddev=self.pose_stddev,
53 | threed_68_points=self.threed_68_points,
54 | )
55 | load_model(
56 | img2pose_model.fpn_model,
57 | args.pretrained_path,
58 | cpu_mode=str(img2pose_model.device) == "cpu",
59 | model_only=True,
60 | )
61 | img2pose_model.evaluate()
62 |
63 | return img2pose_model
64 |
65 | def align(self):
66 | for img_path in tqdm(self.image_list):
67 | image_name = os.path.split(img_path)[-1]
68 | img = Image.open(img_path).convert("RGB")
69 |
70 | res = self.model.predict([self.transform(img)])[0]
71 |
72 | all_scores = res["scores"].cpu().numpy().astype("float")
73 | all_poses = res["dofs"].cpu().numpy().astype("float")
74 |
75 | all_poses = all_poses[all_scores > self.det_threshold]
76 | all_scores = all_scores[all_scores > self.det_threshold]
77 |
78 | if len(all_poses) > 0:
79 | if self.order_method == "confidence":
80 | order = np.argsort(all_scores)[::-1]
81 |
82 | elif self.order_method == "position":
83 | distance_center = np.sqrt(
84 | all_poses[:, 3] ** 2
85 | + all_poses[:, 4] ** 2
86 | )
87 |
88 | order = np.argsort(distance_center)
89 |
90 | top_poses = all_poses[order][: self.max_faces]
91 |
92 | sub_folder = os.path.basename(
93 | os.path.normpath(os.path.split(img_path)[0])
94 | )
95 | output_path = os.path.join(self.output_path, sub_folder)
96 | if not os.path.exists(output_path):
97 | os.makedirs(output_path)
98 |
99 | for i in range(len(top_poses)):
100 | save_name = image_name
101 | if len(top_poses) > 1:
102 | name, ext = image_name.rsplit(".", 1)
103 | save_name = f"{name}_{i}.{ext}"
104 |
105 | aligned_face = align_faces(self.threed_5_points, img, top_poses[i])[
106 | 0
107 | ]
108 | aligned_face = aligned_face.resize((self.face_size, self.face_size))
109 | aligned_face.save(os.path.join(output_path, save_name))
110 | else:
111 | print(f"No face detected above the threshold {self.det_threshold}!")
112 |
113 |
114 | def parse_args():
115 | parser = argparse.ArgumentParser(
116 | description="Align the top n faces, ordered by score or distance to the image center."
117 | )
118 | parser.add_argument("--max_faces", help="Top n faces to save.", default=1, type=int)
119 | parser.add_argument(
120 | "--order_method",
121 | help="How to order faces [confidence, position].",
122 | default="position",
123 | type=str,
124 | )
125 | parser.add_argument(
126 | "--face_size",
127 | help="Image size to save aligned faces [112 or 224].",
128 | default=224,
129 | type=int,
130 | )
131 | parser.add_argument("--min_size", help="Image min size", default=400, type=int)
132 | parser.add_argument("--max_size", help="Image max size", default=1400, type=int)
133 | parser.add_argument(
134 | "--depth", help="Number of layers [18, 50 or 101].", default=18, type=int
135 | )
136 | parser.add_argument(
137 | "--pose_mean",
138 | help="Pose mean file path.",
139 | type=str,
140 | default="./models/WIDER_train_pose_mean_v1.npy",
141 | )
142 | parser.add_argument(
143 | "--pose_stddev",
144 | help="Pose stddev file path.",
145 | type=str,
146 | default="./models/WIDER_train_pose_stddev_v1.npy",
147 | )
148 |
149 | parser.add_argument(
150 | "--pretrained_path",
151 | help="Path to pretrained weights.",
152 | type=str,
153 | default="./models/img2pose_v1.pth",
154 | )
155 |
156 | parser.add_argument(
157 | "--threed_5_points",
158 | type=str,
159 | help="Reference 3D points to align the face.",
160 | default="./pose_references/reference_3d_5_points_trans.npy",
161 | )
162 |
163 | parser.add_argument(
164 | "--threed_68_points",
165 | type=str,
166 | help="Reference 3D points to project bbox.",
167 | default="./pose_references/reference_3d_68_points_trans.npy",
168 | )
169 |
170 | parser.add_argument("--nms_threshold", default=0.6, type=float)
171 | parser.add_argument(
172 | "--det_threshold", help="Detection threshold.", default=0.7, type=float
173 | )
174 | parser.add_argument("--images_path", help="Image list, or folder.", required=True)
175 | parser.add_argument("--output_path", help="Path to save predictions", required=True)
176 |
177 | args = parser.parse_args()
178 |
179 | if not os.path.exists(args.output_path):
180 | os.makedirs(args.output_path)
181 |
182 | return args
183 |
184 |
185 | if __name__ == "__main__":
186 | args = parse_args()
187 |
188 | img2pose = img2pose(args)
189 | img2pose.align()
190 |
--------------------------------------------------------------------------------
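Note on run_face_alignment.py above: the two --order_method options reduce to a single argsort each. A toy illustration with fabricated scores and poses (translation components at indices 3 and 4, as used in the script):

# Toy illustration of the two face-ordering strategies; scores and poses are fabricated.
import numpy as np

all_scores = np.array([0.95, 0.80, 0.99])
all_poses = np.array([
    [0.0, 0.0, 0.0, 50.0, 10.0, 1.0],
    [0.0, 0.0, 0.0, 5.0, 2.0, 1.0],
    [0.0, 0.0, 0.0, 90.0, 40.0, 1.0],
])

order_by_confidence = np.argsort(all_scores)[::-1]  # highest score first -> [2 0 1]
distance_center = np.sqrt(all_poses[:, 3] ** 2 + all_poses[:, 4] ** 2)
order_by_position = np.argsort(distance_center)     # closest to the center first -> [1 0 2]
print(order_by_confidence, order_by_position)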
/teaser.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vitoralbiero/img2pose/fd5473efb83d78c530afc7db12b7c2aa631ea5cb/teaser.jpeg
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | from os import path
4 |
5 | import numpy as np
6 | import torch
7 | from torch import optim
8 | from torch.optim.lr_scheduler import ReduceLROnPlateau
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | from config import Config
12 | from data_loader_lmdb import LMDBDataLoader
13 | from data_loader_lmdb_augmenter import LMDBDataLoaderAugmenter
14 | from early_stop import EarlyStop
15 | from img2pose import img2poseModel
16 | from model_loader import load_model, save_model
17 | from train_logger import TrainLogger
18 | from utils.dist import init_distributed_mode, is_main_process, reduce_dict
19 |
20 |
21 | class Train:
22 | def __init__(self, config):
23 | self.config = config
24 |
25 | if is_main_process():
26 | # start tensorboard summary writer
27 | self.writer = SummaryWriter(config.log_path)
28 |
29 | # load training dataset generator
30 | if self.config.random_flip or self.config.random_crop:
31 | self.train_loader = LMDBDataLoaderAugmenter(
32 | self.config, self.config.train_source
33 | )
34 | else:
35 | self.train_loader = LMDBDataLoader(self.config, self.config.train_source)
36 | print(f"Training with {len(self.train_loader.dataset)} images.")
37 |
38 | # loads validation dataset generator if a validation dataset is given
39 | if self.config.val_source is not None:
40 | self.val_loader = LMDBDataLoader(self.config, self.config.val_source, False)
41 |
42 | # creates model
43 | self.img2pose_model = img2poseModel(
44 | depth=self.config.depth,
45 | min_size=self.config.min_size,
46 | max_size=self.config.max_size,
47 | device=self.config.device,
48 | pose_mean=self.config.pose_mean,
49 | pose_stddev=self.config.pose_stddev,
50 | distributed=self.config.distributed,
51 | gpu=self.config.gpu,
52 | threed_68_points=np.load(self.config.threed_68_points),
53 | threed_5_points=np.load(self.config.threed_5_points),
54 | )
55 | # optimizer for the backbone and heads
56 | if args.optimizer == "Adam":
57 | self.optimizer = optim.Adam(
58 | self.img2pose_model.fpn_model.parameters(),
59 | lr=self.config.lr,
60 | weight_decay=self.config.weight_decay,
61 | )
62 | elif args.optimizer == "SGD":
63 | self.optimizer = optim.SGD(
64 | self.img2pose_model.fpn_model.parameters(),
65 | lr=self.config.lr,
66 | weight_decay=self.config.weight_decay,
67 | momentum=self.config.momentum,
68 | )
69 | else:
70 | raise Exception("No optimizer found, please select either SGD or Adam.")
71 |
72 | # loads the model and optimizer state so that training can resume where it stopped
73 | if self.config.resume_path:
74 | print(f"Resuming training from {self.config.resume_path}")
75 | load_model(
76 | self.img2pose_model.fpn_model,
77 | self.config.resume_path,
78 | model_only=False,
79 | optimizer=self.optimizer,
80 | cpu_mode=str(self.config.device) == "cpu",
81 | )
82 |
83 | # loads a pretrained model without loading the optimizer
84 | if self.config.pretrained_path:
85 | print(f"Loading pretrained weights from {self.config.pretrained_path}")
86 | load_model(
87 | self.img2pose_model.fpn_model,
88 | self.config.pretrained_path,
89 | model_only=True,
90 | cpu_mode=str(self.config.device) == "cpu",
91 | )
92 |
93 | if is_main_process():
94 | # saves configuration to a file for easier retrieval later
95 | print(self.config)
96 | self.save_file(self.config, "config.txt")
97 |
98 | if is_main_process():
99 | # saves optimizer config to a file for easier retrieval later
100 | print(self.optimizer)
101 | self.save_file(self.optimizer, "optimizer.txt")
102 |
103 | self.tensorboard_loss_every = max(len(self.train_loader) // 100, 1)
104 |
105 | # reduce learning rate when the validation loss stops decreasing
106 | if self.config.lr_plateau:
107 | self.scheduler = ReduceLROnPlateau(
108 | self.optimizer,
109 | mode="min",
110 | factor=0.1,
111 | patience=3,
112 | verbose=True,
113 | threshold=0.001,
114 | cooldown=1,
115 | min_lr=0.00001,
116 | )
117 |
118 | # stops training before the defined number of epochs if the validation loss stops decreasing
119 | if self.config.early_stop:
120 | self.early_stop = EarlyStop(mode="min", patience=5)
121 |
122 | def run(self):
123 | self.img2pose_model.train()
124 |
125 | # accumulate running loss to log into tensorboard
126 | running_losses = {}
127 | running_losses["loss"] = 0
128 |
129 | step = 0
130 |
131 | # tracks the best step and loss, printed every time validation runs
132 | self.best_step = 0
133 | self.best_val_loss = float("Inf")
134 |
135 | for epoch in range(self.config.epochs):
136 | train_logger = TrainLogger(
137 | self.config.batch_size, self.config.frequency_log, self.config.num_gpus
138 | )
139 | idx = 0
140 | for idx, data in enumerate(self.train_loader):
141 | imgs, targets = data
142 |
143 | imgs = [image.to(self.config.device) for image in imgs]
144 | targets = [
145 | {k: v.to(self.config.device) for k, v in t.items()} for t in targets
146 | ]
147 |
148 | self.optimizer.zero_grad()
149 |
150 | # forward pass
151 | losses = self.img2pose_model.forward(imgs, targets)
152 |
153 | loss = sum(loss for loss in losses.values())
154 |
155 | # does a backward propagation through the network
156 | loss.backward()
157 |
158 | torch.nn.utils.clip_grad_norm_(
159 | self.img2pose_model.fpn_model.parameters(), 10
160 | )
161 |
162 | self.optimizer.step()
163 |
164 | if self.config.distributed:
165 | losses = reduce_dict(losses)
166 | loss = sum(loss for loss in losses.values())
167 |
168 | for loss_name in losses.keys():
169 | if loss_name in running_losses:
170 | running_losses[loss_name] += losses[loss_name].item()
171 | else:
172 | running_losses[loss_name] = losses[loss_name].item()
173 |
174 | running_losses["loss"] += loss.item()
175 |
176 | # saves loss into tensorboard
177 | if step % self.tensorboard_loss_every == 0 and step != 0:
178 | for loss_name in running_losses.keys():
179 | if is_main_process():
180 | self.writer.add_scalar(
181 | f"train_{loss_name}",
182 | running_losses[loss_name] / self.tensorboard_loss_every,
183 | step,
184 | )
185 |
186 | running_losses[loss_name] = 0
187 |
188 | train_logger(
189 | epoch, self.config.epochs, idx, len(self.train_loader), loss.item()
190 | )
191 | step += 1
192 |
193 | # evaluate model using validation set (if set)
194 | if self.config.val_source is not None:
195 | val_loss = self.evaluate(step)
196 |
197 | else:
198 | # otherwise just save the model
199 | save_model(
200 | self.img2pose_model.fpn_model_without_ddp,
201 | self.optimizer,
202 | self.config,
203 | step=step,
204 | )
205 |
206 | # if validation loss stops decreasing, decrease lr
207 | if self.config.lr_plateau and self.config.val_source is not None:
208 | self.scheduler.step(val_loss)
209 |
210 | # early stop model to prevent overfitting
211 | if self.config.early_stop and self.config.val_source is not None:
212 | self.early_stop(val_loss)
213 | if self.early_stop.stop:
214 | print("Early stopping model...")
215 | break
216 |
217 | if self.config.val_source is not None:
218 | val_loss = self.evaluate(step)
219 |
220 | def checkpoint(self, val_loss, step):
221 | if val_loss < self.best_val_loss:
222 | self.best_val_loss = val_loss
223 | self.best_step = step
224 |
225 | save_model(
226 | self.img2pose_model.fpn_model_without_ddp,
227 | self.optimizer,
228 | self.config,
229 | val_loss,
230 | step,
231 | )
232 |
233 | def reduce_lr(self):
234 | for params in self.optimizer.param_groups:
235 | params["lr"] /= 10
236 |
237 | print("Reducing learning rate...")
238 | print(self.optimizer)
239 |
240 | def evaluate(self, step):
241 | val_losses = {}
242 | val_losses["loss"] = 0
243 |
244 | print("Evaluating model...")
245 | with torch.no_grad():
246 | for data in iter(self.val_loader):
247 | imgs, targets = data
248 |
249 | imgs = [image.to(self.config.device) for image in imgs]
250 | targets = [
251 | {k: v.to(self.config.device) for k, v in t.items()} for t in targets
252 | ]
253 |
254 | if self.config.distributed:
255 | torch.cuda.synchronize()
256 |
257 | losses = self.img2pose_model.forward(imgs, targets)
258 |
259 | if self.config.distributed:
260 | losses = reduce_dict(losses)
261 |
262 | loss = sum(loss for loss in losses.values())
263 |
264 | for loss_name in losses.keys():
265 | if loss_name in val_losses:
266 | val_losses[loss_name] += losses[loss_name].item()
267 | else:
268 | val_losses[loss_name] = losses[loss_name].item()
269 |
270 | val_losses["loss"] += loss.item()
271 |
272 | for loss_name in val_losses.keys():
273 | if is_main_process():
274 | self.writer.add_scalar(
275 | f"val_{loss_name}",
276 | round(val_losses[loss_name] / len(self.val_loader), 6),
277 | step,
278 | )
279 |
280 | val_loss = round(val_losses["loss"] / len(self.val_loader), 6)
281 | self.checkpoint(val_loss, step)
282 |
283 | print(
284 | "Current validation loss: "
285 | + f"{val_loss:.6f} at step {step}"
286 | + " - Best validation loss: "
287 | + f"{self.best_val_loss:.6f} at step {self.best_step}"
288 | )
289 |
290 | self.img2pose_model.train()
291 |
292 | return val_loss
293 |
294 | def save_file(self, string, file_name):
295 | with open(path.join(self.config.work_path, file_name), "w") as file:
296 | file.write(str(string))
297 | file.close()
298 |
299 |
300 | def parse_args():
301 | parser = argparse.ArgumentParser(
302 | description="Train a deep network to predict 3D expression and 6DOF pose."
303 | )
304 | # network and training parameters
305 | parser.add_argument(
306 | "--min_size", help="Min size", default="640, 672, 704, 736, 768, 800", type=str
307 | )
308 | parser.add_argument("--max_size", help="Max size", default=1400, type=int)
309 | parser.add_argument("--epochs", help="Number of epochs.", default=100, type=int)
310 | parser.add_argument(
311 | "--depth", help="Number of layers [18, 50 or 101].", default=18, type=int
312 | )
313 | parser.add_argument("--lr", help="Learning rate.", default=0.001, type=float)
314 | parser.add_argument(
315 | "--optimizer", help="Optimizer (SGD or Adam).", default="SGD", type=str
316 | )
317 | parser.add_argument("--batch_size", help="Batch size.", default=2, type=int)
318 | parser.add_argument(
319 | "--lr_plateau", help="Reduce lr on plateau.", action="store_true"
320 | )
321 | parser.add_argument("--early_stop", help="Use early stop.", action="store_true")
322 | parser.add_argument("--workers", help="Workers number.", default=4, type=int)
323 | parser.add_argument(
324 | "--pose_mean", help="Pose mean file path.", type=str, required=True
325 | )
326 | parser.add_argument(
327 | "--pose_stddev", help="Pose stddev file path.", type=str, required=True
328 | )
329 |
330 | # training/validation configuration
331 | parser.add_argument(
332 | "--workspace", help="Worskspace path to save models and logs.", required=True
333 | )
334 | parser.add_argument(
335 | "--train_source", help="Path to the dataset train LMDB file.", required=True
336 | )
337 | parser.add_argument(
338 | "--val_source", help="Path to the dataset validation LMDB file."
339 | )
340 |
341 | parser.add_argument(
342 | "--prefix", help="Prefix to save the model.", type=str, required=True
343 | )
344 |
345 | # resume from or load pretrained weights
346 | parser.add_argument(
347 | "--pretrained_path", help="Path to pretrained weights.", type=str
348 | )
349 | parser.add_argument(
350 | "--resume_path", help="Path to load model to resume training.", type=str
351 | )
352 |
353 | # online augmentation
354 | parser.add_argument("--noise_augmentation", action="store_true")
355 | parser.add_argument("--contrast_augmentation", action="store_true")
356 | parser.add_argument("--random_flip", action="store_true")
357 | parser.add_argument("--random_crop", action="store_true")
358 |
359 | # distributed training parameters
360 | parser.add_argument(
361 | "--world-size", default=1, type=int, help="number of distributed processes"
362 | )
363 | parser.add_argument(
364 | "--dist-url", default="env://", help="url used to set up distributed training"
365 | )
366 | parser.add_argument(
367 | "--distributed", help="Use distributed training", action="store_true"
368 | )
369 |
370 | # reference points to create pose labels
371 | parser.add_argument(
372 | "--threed_5_points",
373 | type=str,
374 | help="Reference 3D points to compute pose.",
375 | default="./pose_references/reference_3d_5_points_trans.npy",
376 | )
377 |
378 | parser.add_argument(
379 | "--threed_68_points",
380 | type=str,
381 | help="Reference 3D points to compute pose.",
382 | default="./pose_references/reference_3d_68_points_trans.npy",
383 | )
384 |
385 | args = parser.parse_args()
386 |
387 | args.min_size = [int(item) for item in args.min_size.split(",")]
388 |
389 | return args
390 |
391 |
392 | if __name__ == "__main__":
393 | args = parse_args()
394 |
395 | if args.distributed:
396 | init_distributed_mode(args)
397 |
398 | config = Config(args)
399 |
400 | torch.manual_seed(42)
401 | np.random.seed(42)
402 | random.seed(42)
403 |
404 | train = Train(config)
405 | train.run()
406 |
--------------------------------------------------------------------------------
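Note on train.py above: a sketch of how the --lr_plateau and --early_stop options interact on a stalling validation loss. The validation losses are fabricated, and early_stop.EarlyStop is assumed to expose exactly the interface used in train.py (constructed with mode/patience, called with a loss, exposing a .stop flag).

# Sketch of the lr_plateau / early_stop behavior with fabricated validation losses.
import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from early_stop import EarlyStop

model = torch.nn.Linear(4, 6)  # stands in for the FPN model
optimizer = optim.SGD(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=3, verbose=True,
    threshold=0.001, cooldown=1, min_lr=0.00001,
)
early_stop = EarlyStop(mode="min", patience=5)

for epoch, val_loss in enumerate([0.9, 0.8, 0.79, 0.79, 0.79, 0.79, 0.79, 0.79, 0.79]):
    scheduler.step(val_loss)  # lowers the lr once the loss plateaus
    early_stop(val_loss)      # patience=5 epochs without improvement
    if early_stop.stop:
        print(f"Early stopping at epoch {epoch}")
        break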
/train_logger.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
4 | class TrainLogger(object):
5 | def __init__(self, batch_size, frequency=50, num_gpus=1):
6 | self.num_gpus = num_gpus
7 | self.batch_size = batch_size * num_gpus
8 | self.frequency = frequency
9 | self.init = False
10 | self.tic = 0
11 | self.last_batch = 0
12 | self.running_loss = 0
13 |
14 | def __call__(self, epoch, total_epochs, batch, total, loss):
15 | if self.last_batch > batch:
16 | self.init = False
17 | self.last_batch = batch
18 |
19 | if self.init:
20 | self.running_loss += loss
21 | if batch % self.frequency == 0:
22 | speed = self.frequency * self.batch_size / (time.time() - self.tic)
23 | self.running_loss = self.running_loss / self.frequency
24 |
25 | batch, total = batch * self.num_gpus, total * self.num_gpus
26 | log = (
27 | f"Epoch: [{epoch + 1}-{total_epochs}] Batch: [{batch}-{total}] "
28 | + f"Speed: {speed:.2f} samples/sec Loss: {self.running_loss:.5f}"
29 | )
30 | print(log)
31 |
32 | self.running_loss = 0
33 | self.tic = time.time()
34 | else:
35 | self.init = True
36 | self.tic = time.time()
37 |
--------------------------------------------------------------------------------
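Note on train_logger.py above: a usage sketch with fabricated losses; frequency=2 so the logger prints after only a couple of batches.

# Usage sketch for TrainLogger with fabricated losses.
import random

from train_logger import TrainLogger

logger = TrainLogger(batch_size=2, frequency=2, num_gpus=1)

for epoch in range(1):
    for batch in range(6):
        loss = random.uniform(0.5, 1.0)
        logger(epoch, 1, batch, 6, loss)  # prints speed and mean loss every 2 batches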
/utils/annotate_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import sys
4 |
5 | import cv2
6 | import numpy as np
7 | import pandas as pd
8 | from tqdm import tqdm
9 |
10 | # modify to your RetinaFace path
11 | sys.path.append("../insightface/RetinaFace/")
12 | import os
13 |
14 | from retinaface import RetinaFace
15 |
16 |
17 | def face_landmark_detection(detector, args):
18 | image_paths = pd.read_csv(args.image_list, delimiter=" ", header=None)
19 | image_paths = np.asarray(image_paths).squeeze()
20 |
21 | for i in tqdm(range(len(image_paths))):
22 | img_json = {}
23 |
24 | img_path = image_paths[i]
25 | img_name = os.path.split(img_path)[-1]
26 |
27 | output_path = os.path.join(args.output_path, os.path.split(img_name)[0])
28 | file_output_path = os.path.join(
29 | output_path, f"{os.path.split(img_name)[-1][:-4]}.json"
30 | )
31 |
32 | if os.path.isfile(file_output_path):
33 | print(f"Skipping file {img_name} as it was already processed.")
34 | continue
35 |
36 | im = cv2.imread(img_path)
37 |
38 | pyramid = True
39 | do_flip = False
40 |
41 | if not pyramid:
42 | # only the last target_size/max_size pair below takes effect; the earlier
43 | # assignments were dead code and are kept only as alternative settings:
44 | # target_size, max_size = 1200, 1600
45 | # target_size, max_size = 1504, 2000
46 | target_size = 1600
47 | max_size = 2150
48 | im_shape = im.shape
49 | im_size_min = np.min(im_shape[0:2])
50 | im_size_max = np.max(im_shape[0:2])
51 | im_scale = float(target_size) / float(im_size_min)
52 | # prevent bigger axis from being more than max_size:
53 | if np.round(im_scale * im_size_max) > max_size:
54 | im_scale = float(max_size) / float(im_size_max)
55 | scales = [im_scale]
56 | else:
57 | do_flip = True
58 | TEST_SCALES = [500, 800, 1100, 1400, 1700]
59 | target_size = 800
60 | max_size = 1200
61 | im_shape = im.shape
62 | im_size_min = np.min(im_shape[0:2])
63 | im_size_max = np.max(im_shape[0:2])
64 | im_scale = float(target_size) / float(im_size_min)
65 | # prevent bigger axis from being more than max_size:
66 | if np.round(im_scale * im_size_max) > max_size:
67 | im_scale = float(max_size) / float(im_size_max)
68 | scales = [float(scale) / target_size * im_scale for scale in TEST_SCALES]
69 |
70 | faces, landmarks = detector.detect(
71 | im, threshold=args.thresh, scales=scales, do_flip=do_flip
72 | )
73 |
74 | bboxes_list = []
75 | landmarks_list = []
76 |
77 | if faces is not None:
78 | for i in range(faces.shape[0]):
79 | bbox = faces[i].astype(np.float32)
80 | if landmarks is not None:
81 | landmark5 = landmarks[i].astype(np.float32)
82 |
83 | bboxes_list.append(bbox.tolist())
84 | landmarks_list.append(landmark5.tolist())
85 |
86 | if len(landmarks_list) > 0:
87 | img_json["image_path"] = img_path
88 | img_json["bboxes"] = bboxes_list
89 | img_json["landmarks"] = landmarks_list
90 |
91 | if not os.path.isdir(output_path):
92 | os.makedirs(output_path)
93 |
94 | with open(
95 | file_output_path,
96 | "w",
97 | ) as output_file:
98 | json.dump(img_json, output_file)
99 |
100 |
101 | def parse_args():
102 | parser = argparse.ArgumentParser()
103 | parser.add_argument(
104 | "--image_list",
105 | type=str,
106 | required=True,
107 | help="List with path to images.",
108 | )
109 | parser.add_argument(
110 | "--output_path", type=str, required=True, help="Path to save the json files."
111 | )
112 | parser.add_argument(
113 | "--thresh", type=float, default=0.8, help="Face detection threshold."
114 | )
115 | parser.add_argument(
116 | "--model_path",
117 | type=str,
118 | default="../insightface/RetinaFace/models/R50/R50",
119 | help="Model path for detector.",
120 | )
121 |
122 | args = parser.parse_args()
123 |
124 | if not os.path.exists(args.output_path):
125 | os.makedirs(args.output_path)
126 |
127 | return args
128 |
129 |
130 | if __name__ == "__main__":
131 | args = parse_args()
132 |
133 | detector = RetinaFace(args.model_path, 0, 0, "net3", vote=False)
134 |
135 | face_landmark_detection(detector, args)
136 |
--------------------------------------------------------------------------------
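Note on utils/annotate_dataset.py above: the pyramid branch turns the fixed TEST_SCALES into per-image scale factors. A standalone recomputation for a hypothetical 600x900 image:

# Recomputes the RetinaFace pyramid scales for a hypothetical 600x900 image.
import numpy as np

TEST_SCALES = [500, 800, 1100, 1400, 1700]
target_size, max_size = 800, 1200

im_shape = (600, 900, 3)
im_size_min = np.min(im_shape[0:2])                 # 600
im_size_max = np.max(im_shape[0:2])                 # 900
im_scale = float(target_size) / float(im_size_min)  # ~1.33
if np.round(im_scale * im_size_max) > max_size:     # 1200 > 1200 is False here
    im_scale = float(max_size) / float(im_size_max)

scales = [float(scale) / target_size * im_scale for scale in TEST_SCALES]
print(scales)  # roughly [0.83, 1.33, 1.83, 2.33, 2.83]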
/utils/augmentation.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import cv2
4 | import numpy as np
5 | from PIL import Image, ImageEnhance, ImageOps
6 |
7 |
8 | def random_crop(img, bboxes, landmarks):
9 | useable_landmarks = False
10 | for landmark in landmarks:
11 | if -1 not in landmark:
12 | useable_landmarks = True
13 | break
14 |
15 | total_attempts = 10
16 | searching = True
17 | attempt = 0
18 |
19 | while searching:
20 | crop_img = img.copy()
21 | (w, h) = crop_img.size
22 | crop_size = random.uniform(0.7, 1)
23 |
24 | if attempt == total_attempts:
25 | return img, bboxes, landmarks
26 |
27 | crop_x = int(w * crop_size)
28 | crop_y = int(h * crop_size)
29 |
30 | start_x = random.randint(0, w - crop_x)
31 | start_y = (start_x // w) * h
32 |
33 | crop_bbox = [start_x, start_y, start_x + crop_x, start_y + crop_y]
34 |
35 | new_bboxes, new_lms = _adjust_bboxes_landmarks(
36 | bboxes.copy(), landmarks.copy(), crop_bbox
37 | )
38 |
39 | if len(new_bboxes) > 0:
40 | if useable_landmarks:
41 | for lms in new_lms:
42 | if -1 not in lms:
43 | searching = False
44 | break
45 | else:
46 | searching = False
47 |
48 | if not searching:
49 | crop_img = crop_img.crop(
50 | (crop_bbox[0], crop_bbox[1], crop_bbox[2], crop_bbox[3])
51 | )
52 |
53 | attempt += 1
54 |
55 | return crop_img, new_bboxes, new_lms
56 |
57 |
58 | def _adjust_bboxes_landmarks(bboxes, landmarks, crop_bbox):
59 | new_bboxes = []
60 | new_lms = []
61 | for i in range(len(bboxes)):
62 | bbox = bboxes[i]
63 | lms = np.asarray(landmarks[i])
64 |
65 | bbox_center_x = bbox[0] + ((bbox[2] - bbox[0]) // 2)
66 | bbox_center_y = bbox[1] + ((bbox[3] - bbox[1]) // 2)
67 |
68 | if (
69 | bbox_center_x > crop_bbox[0]
70 | and bbox_center_x < crop_bbox[2]
71 | and bbox_center_y > crop_bbox[1]
72 | and bbox_center_y < crop_bbox[3]
73 | ):
74 | bbox[[0, 2]] -= crop_bbox[0]
75 | bbox[[1, 3]] -= crop_bbox[1]
76 |
77 | bbox[0] = max(bbox[0], 0)
78 | bbox[1] = max(bbox[1], 0)
79 | bbox[2] = min(bbox[2], crop_bbox[2] - crop_bbox[0])
80 | bbox[3] = min(bbox[3], crop_bbox[3] - crop_bbox[1])
81 |
82 | add_lm = True
83 | for lm in lms:
84 | if (
85 | lm[0] < crop_bbox[0]
86 | or lm[0] > crop_bbox[2]
87 | or lm[1] < crop_bbox[1]
88 | or lm[1] > crop_bbox[3]
89 | ):
90 | add_lm = False
91 | break
92 |
93 | if add_lm:
94 | lms[:, 0] -= crop_bbox[0]
95 | lms[:, 1] -= crop_bbox[1]
96 |
97 | new_lms.append(lms.tolist())
98 | new_bboxes.append(bbox)
99 |
100 | return new_bboxes, new_lms
101 |
102 |
103 | def random_flip(img, bboxes, all_landmarks):
104 | flip = random.randint(0, 1)
105 |
106 | if flip == 1:
107 | # flip image
108 | img = ImageOps.mirror(img)
109 |
110 | # flip bboxes
111 | old_bboxes = bboxes.copy()
112 | (w, h) = img.size
113 | bboxes[:, 0] = w - old_bboxes[:, 2]
114 | bboxes[:, 2] = w - old_bboxes[:, 0]
115 |
116 | for i in range(len(all_landmarks)):
117 | landmarks = np.asarray(all_landmarks[i])
118 |
119 | if -1 in landmarks:
120 | continue
121 |
122 | all_landmarks[i] = flip_landmarks(landmarks, w).tolist()
123 |
124 | return img, bboxes, all_landmarks
125 |
126 |
127 | def flip_landmarks(landmarks, w):
128 | if len(landmarks) == 5:
129 | order = [1, 0, 2, 4, 3]
130 | else:
131 | order = [
132 | 16,
133 | 15,
134 | 14,
135 | 13,
136 | 12,
137 | 11,
138 | 10,
139 | 9,
140 | 8,
141 | 7,
142 | 6,
143 | 5,
144 | 4,
145 | 3,
146 | 2,
147 | 1,
148 | 0,
149 | 26,
150 | 25,
151 | 24,
152 | 23,
153 | 22,
154 | 21,
155 | 20,
156 | 19,
157 | 18,
158 | 17,
159 | 27,
160 | 28,
161 | 29,
162 | 30,
163 | 35,
164 | 34,
165 | 33,
166 | 32,
167 | 31,
168 | 45,
169 | 44,
170 | 43,
171 | 42,
172 | 47,
173 | 46,
174 | 39,
175 | 38,
176 | 37,
177 | 36,
178 | 41,
179 | 40,
180 | 54,
181 | 53,
182 | 52,
183 | 51,
184 | 50,
185 | 49,
186 | 48,
187 | 59,
188 | 58,
189 | 57,
190 | 56,
191 | 55,
192 | 64,
193 | 63,
194 | 62,
195 | 61,
196 | 60,
197 | 67,
198 | 66,
199 | 65,
200 | ]
201 |
202 | # flip landmarks
203 | landmarks[:, 0] = w - landmarks[:, 0]
204 | flandmarks = landmarks.copy()
205 | for idx, a in enumerate(order):
206 | flandmarks[idx, :] = landmarks[a, :]
207 |
208 | return flandmarks
209 |
210 |
211 | def rotate(img, landmarks, bbox):
212 | angle = random.gauss(0, 1) * 30
213 |
214 | (h, w) = img.shape[:2]
215 | (cX, cY) = (w // 2, h // 2)
216 |
217 | # Transform Image
218 | new_img = _rotate_img(img, angle, cX, cY, h, w)
219 |
220 | # Transform Landmarks
221 | new_landmarks = _rotate_landmarks(landmarks, angle, cX, cY, h, w)
222 |
223 | # Transform Bounding Box
224 | x1, y1, x2, y2 = bbox["left"], bbox["top"], bbox["right"], bbox["bottom"]
225 | bounding_box_8pts = np.array([x1, y1, x2, y1, x2, y2, x1, y2])
226 |
227 | rot_bounding_box_8pts = _rotate_bbox(bounding_box_8pts, angle, cX, cY, h, w)
228 | rot_bounding_box = _get_enclosing_bbox(rot_bounding_box_8pts, y2 - y1, x2 - x1)[
229 | 0
230 | ].astype("float")
231 |
232 | new_bbox = {
233 | "left": rot_bounding_box[0],
234 | "top": rot_bounding_box[1],
235 | "right": rot_bounding_box[2],
236 | "bottom": rot_bounding_box[3],
237 | }
238 |
239 | # validate that bbox boundaries are within the image otherwise do not apply rotation
240 | if (
241 | new_bbox["top"] > 0
242 | and new_bbox["bottom"] < h
243 | and bbox["left"] > 0
244 | and bbox["right"] < w
245 | ):
246 | img = new_img
247 | landmarks = new_landmarks
248 | bbox = new_bbox
249 |
250 | return img, landmarks, bbox
251 |
252 |
253 | def scale(img, bbox):
254 | scale = random.uniform(0.75, 1.25)
255 | bbox = _scale_bbox(img, bbox, scale)
256 |
257 | return bbox
258 |
259 |
260 | def translate_vertical(img, bbox):
261 | bbox_height = bbox["bottom"] - bbox["top"]
262 | vtrans = random.uniform(-0.1, 0.1) * bbox_height
263 |
264 | # check if bbox boundaries are within image, otherwise do not move bbox
265 | if bbox["top"] + vtrans > 0 and bbox["bottom"] + vtrans < img.shape[0]:
266 | bbox["top"] += vtrans
267 | bbox["bottom"] += vtrans
268 |
269 | return bbox
270 |
271 |
272 | def translate_horizontal(img, bbox):
273 | bbox_width = bbox["right"] - bbox["left"]
274 | htrans = random.uniform(-0.1, 0.1) * bbox_width
275 |
276 | # check if bbox boundaries are within image, otherwise do not move bbox
277 | if bbox["left"] + htrans > 0 and bbox["right"] + htrans < img.shape[1]:
278 | bbox["left"] += htrans
279 | bbox["right"] += htrans
280 |
281 | return bbox
282 |
283 |
284 | def change_contrast(img, bboxes, landmarks):
285 | change = random.randint(0, 1)
286 | if change == 1:
287 | factor = random.uniform(0.5, 1.5)
288 | enhancer = ImageEnhance.Contrast(img)
289 | img = enhancer.enhance(factor)
290 |
291 | return img, bboxes, landmarks
292 |
293 |
294 | def add_noise(img, bboxes, landmarks):
295 | add_noise = random.randint(0, 4)
296 | if add_noise == 4:
297 | noise_types = ["gauss", "s&p", "poisson"]
298 | noise_idx = random.randint(0, 2)
299 | noise_type = noise_types[noise_idx]
300 |
301 | img = np.array(img)
302 |
303 | if noise_type == "gauss":
304 | row, col, ch = img.shape
305 | mean = 0
306 | var = 0.1
307 | sigma = var ** 0.5
308 | gauss = np.random.normal(mean, sigma, (row, col, ch))
309 | gauss = gauss.reshape(row, col, ch)
310 | img = img + gauss
311 |
312 | elif noise_type == "s&p":
313 | row, col, ch = img.shape
314 | s_vs_p = 0.5
315 | amount = 0.004
316 | # Salt mode
317 | num_salt = np.ceil(amount * (img.size / ch) * s_vs_p)
318 | coords = [
319 | np.random.randint(0, i - 1, int(num_salt)) for i in img.shape[0:2]
320 | ]
321 | img[tuple(coords)] = 255
322 | # Pepper mode
323 | num_pepper = np.ceil(amount * (img.size / ch) * (1.0 - s_vs_p))
324 | coords = [
325 | np.random.randint(0, i - 1, int(num_pepper)) for i in img.shape[0:2]
326 | ]
327 | img[tuple(coords)] = 0
328 |
329 | elif noise_type == "poisson":
330 | vals = len(np.unique(img))
331 | vals = 2 ** np.ceil(np.log2(vals))
332 | img = np.random.poisson(img * vals) / float(vals)
333 |
334 | img = Image.fromarray(img.astype("uint8"))
335 |
336 | return img, bboxes, landmarks
337 |
338 |
339 | def _rotate_img(img, angle, cX, cY, h, w):
340 | M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
341 | cos = np.abs(M[0, 0])
342 | sin = np.abs(M[0, 1])
343 |
344 | nW = int((h * sin) + (w * cos))
345 | nH = int((h * cos) + (w * sin))
346 |
347 | M[0, 2] += (nW / 2) - cX
348 | M[1, 2] += (nH / 2) - cY
349 |
350 | img = cv2.warpAffine(img, M, (nW, nH))
351 |
352 | return img
353 |
354 |
355 | def _rotate_landmarks(landmarks, angle, cx, cy, h, w):
356 | M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)
357 | cos = np.abs(M[0, 0])
358 | sin = np.abs(M[0, 1])
359 |
360 | nW = int((h * sin) + (w * cos))
361 | nH = int((h * cos) + (w * sin))
362 |
363 | M[0, 2] += (nW / 2) - cx
364 | M[1, 2] += (nH / 2) - cy
365 |
366 | landmarks = np.append(landmarks, np.ones((landmarks.shape[0], 1)), axis=1)
367 | calculated = (np.dot(M, landmarks.T)).T
368 | return calculated.astype("int")
369 |
370 |
371 | def _rotate_bbox(corners, angle, cx, cy, h, w):
372 | corners = corners.reshape(-1, 2)
373 | corners = np.hstack(
374 | (corners, np.ones((corners.shape[0], 1), dtype=type(corners[0][0])))
375 | )
376 | M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)
377 |
378 | cos = np.abs(M[0, 0])
379 | sin = np.abs(M[0, 1])
380 |
381 | nW = int((h * sin) + (w * cos))
382 | nH = int((h * cos) + (w * sin))
383 | # adjust the rotation matrix to take into account translation
384 | M[0, 2] += (nW / 2) - cx
385 | M[1, 2] += (nH / 2) - cy
386 | # Prepare the vector to be transformed
387 | calculated = np.dot(M, corners.T).T
388 | calculated = calculated.reshape(-1, 8)
389 |
390 | return calculated
391 |
392 |
393 | def _get_enclosing_bbox(corners, original_height, original_width):
394 | x = corners[:, [0, 2, 4, 6]]
395 | y = corners[:, [1, 3, 5, 7]]
396 | xmin = np.min(x, 1).reshape(-1, 1)
397 | ymin = np.min(y, 1).reshape(-1, 1)
398 | xmax = np.max(x, 1).reshape(-1, 1)
399 | ymax = np.max(y, 1).reshape(-1, 1)
400 |
401 | height = ymax - ymin
402 | width = xmax - xmin
403 |
404 | diff_height = height - original_height
405 | diff_width = width - original_width
406 |
407 | ymax -= diff_height // 2
408 | ymin += diff_height // 2
409 | xmax -= diff_width // 2
410 | xmin += diff_width // 2
411 |
412 | final = np.hstack((xmin, ymin, xmax, ymax, corners[:, 8:]))
413 |
414 | return final.astype("int")
415 |
416 |
417 | def _scale_bbox(img, bbox, scale):
418 | height = bbox["bottom"] - bbox["top"]
419 | width = bbox["right"] - bbox["left"]
420 |
421 | new_height = height * scale
422 | new_width = width * scale
423 |
424 | diff_height = height - new_height
425 | diff_width = width - new_width
426 |
427 | # check if bbox boundaries are within image, otherwise do not scale bbox
428 | if (
429 | bbox["top"] + (diff_height // 2) > 0
430 | and bbox["bottom"] - (diff_height // 2) < img.shape[0]
431 | and bbox["left"] + (diff_width // 2) > 0
432 | and bbox["right"] - (diff_width // 2) < img.shape[1]
433 | ):
434 | bbox["bottom"] -= diff_height // 2
435 | bbox["top"] += diff_height // 2
436 | bbox["right"] -= diff_width // 2
437 | bbox["left"] += diff_width // 2
438 |
439 | return bbox
440 |
--------------------------------------------------------------------------------
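Note on utils/augmentation.py above: a sketch of flip_landmarks on a fabricated, slightly asymmetric 5-point set in a 100-pixel-wide image; x coordinates are mirrored and the left/right pairs swap indices.

# Sketch of flip_landmarks on fabricated 5-point landmarks
# (left eye, right eye, nose, left mouth corner, right mouth corner).
import numpy as np

from utils.augmentation import flip_landmarks

landmarks = np.array([
    [28.0, 40.0],  # left eye
    [70.0, 42.0],  # right eye
    [48.0, 60.0],  # nose
    [35.0, 82.0],  # left mouth corner
    [65.0, 80.0],  # right mouth corner
])

flipped = flip_landmarks(landmarks.copy(), 100)
print(flipped)  # index 0 is now the mirrored right eye, so the list still starts with the left eye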
/utils/dist.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.distributed as dist
5 |
6 |
7 | def init_distributed_mode(args):
8 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
9 | args.rank = int(os.environ["RANK"])
10 | args.world_size = int(os.environ["WORLD_SIZE"])
11 | args.gpu = int(os.environ["LOCAL_RANK"])
12 | elif "SLURM_PROCID" in os.environ:
13 | args.rank = int(os.environ["SLURM_PROCID"])
14 | args.gpu = args.rank % torch.cuda.device_count()
15 | else:
16 | print("Not using distributed mode")
17 | args.distributed = False
18 |
19 | return
20 |
21 | args.distributed = True
22 |
23 | torch.cuda.set_device(args.gpu)
24 | args.dist_backend = "nccl"
25 | print(
26 | "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
27 | )
28 | torch.distributed.init_process_group(
29 | backend=args.dist_backend,
30 | init_method=args.dist_url,
31 | world_size=args.world_size,
32 | rank=args.rank,
33 | )
34 | torch.distributed.barrier()
35 | setup_for_distributed(args.rank == 0)
36 |
37 | print(args)
38 |
39 |
40 | def setup_for_distributed(is_master):
41 | """
42 | This function disables printing when not in the master process
43 | """
44 | import builtins as __builtin__
45 |
46 | builtin_print = __builtin__.print
47 |
48 | def print(*args, **kwargs):
49 | force = kwargs.pop("force", False)
50 | if is_master or force:
51 | builtin_print(*args, **kwargs)
52 |
53 | __builtin__.print = print
54 |
55 |
56 | def reduce_dict(input_dict, average=True):
57 | """
58 | Args:
59 | input_dict (dict): all the values will be reduced
60 | average (bool): whether to do average or sum
61 | Reduce the values in the dictionary from all processes so that all processes
62 | have the averaged results. Returns a dict with the same fields as
63 | input_dict, after reduction.
64 | """
65 | world_size = get_world_size()
66 | if world_size < 2:
67 | return input_dict
68 | with torch.no_grad():
69 | names = []
70 | values = []
71 | # sort the keys so that they are consistent across processes
72 | for k in sorted(input_dict.keys()):
73 | names.append(k)
74 | values.append(input_dict[k])
75 | values = torch.stack(values, dim=0)
76 |
77 | dist.all_reduce(values)
78 | if average:
79 | values /= world_size
80 | reduced_dict = {k: v for k, v in zip(names, values)}
81 | return reduced_dict
82 |
83 |
84 | def is_dist_avail_and_initialized():
85 | if not dist.is_available():
86 | return False
87 | if not dist.is_initialized():
88 | return False
89 |
90 | return True
91 |
92 |
93 | def get_world_size():
94 | if not is_dist_avail_and_initialized():
95 | return 1
96 | return dist.get_world_size()
97 |
98 |
99 | def get_rank():
100 | if not is_dist_avail_and_initialized():
101 | return 0
102 | return dist.get_rank()
103 |
104 |
105 | def is_main_process():
106 | return get_rank() == 0
107 |
108 |
109 | def save_on_master(*args, **kwargs):
110 | if is_main_process():
111 | torch.save(*args, **kwargs)
112 |
--------------------------------------------------------------------------------
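Note on utils/dist.py above: outside of an initialized process group reduce_dict is a no-op, which is why train.py only calls it under --distributed. A tiny single-process sketch:

# reduce_dict is a no-op when world_size < 2, so this runs without any distributed setup.
import torch

from utils.dist import get_world_size, reduce_dict

losses = {"classification": torch.tensor(0.7), "dof": torch.tensor(0.2)}

print(get_world_size())     # 1 outside of torch.distributed
print(reduce_dict(losses))  # returned unchanged; with N processes the values would be averaged across ranks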
/utils/face_align.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adapted from
3 | https://github.com/deepinsight/insightface/blob/master/recognition/common/face_align.py
4 | """
5 |
6 | import cv2
7 | import numpy as np
8 | from skimage import transform as trans
9 |
10 | # <-- left profile
11 | src1 = np.array(
12 | [
13 | [51.642, 50.115],
14 | [57.617, 49.990],
15 | [35.740, 69.007],
16 | [51.157, 89.050],
17 | [57.025, 89.702],
18 | ],
19 | dtype=np.float32,
20 | )
21 | # <--left
22 | src2 = np.array(
23 | [
24 | [45.031, 50.118],
25 | [65.568, 50.872],
26 | [39.677, 68.111],
27 | [45.177, 86.190],
28 | [64.246, 86.758],
29 | ],
30 | dtype=np.float32,
31 | )
32 |
33 | # ---frontal
34 | src3 = np.array(
35 | [
36 | [39.730, 51.138],
37 | [72.270, 51.138],
38 | [56.000, 68.493],
39 | [42.463, 87.010],
40 | [69.537, 87.010],
41 | ],
42 | dtype=np.float32,
43 | )
44 |
45 | # -->right
46 | src4 = np.array(
47 | [
48 | [46.845, 50.872],
49 | [67.382, 50.118],
50 | [72.737, 68.111],
51 | [48.167, 86.758],
52 | [67.236, 86.190],
53 | ],
54 | dtype=np.float32,
55 | )
56 |
57 | # -->right profile
58 | src5 = np.array(
59 | [
60 | [54.796, 49.990],
61 | [60.771, 50.115],
62 | [76.673, 69.007],
63 | [55.388, 89.702],
64 | [61.257, 89.050],
65 | ],
66 | dtype=np.float32,
67 | )
68 |
69 | src = np.array([src1, src2, src3, src4, src5])
70 | src_map = {112: src, 224: src * 2}
71 |
72 |
73 | def estimate_norm(lmk, image_size=224):
74 | assert lmk.shape == (5, 2)
75 | tform = trans.SimilarityTransform()
76 | lmk_tran = np.insert(lmk, 2, values=np.ones(5), axis=1)
77 | min_M = []
78 | min_index = []
79 | min_error = float("inf")
80 |
81 | src = src_map[image_size]
82 | for i in np.arange(src.shape[0]):
83 | tform.estimate(lmk, src[i])
84 | M = tform.params[0:2, :]
85 | results = np.dot(M, lmk_tran.T)
86 | results = results.T
87 | error = np.sum(np.sqrt(np.sum((results - src[i]) ** 2, axis=1)))
88 |
89 | # find the src template that is closest to the projected (predicted) points
90 | if error < min_error:
91 | min_error = error
92 | min_M = M
93 | min_index = i
94 |
95 | return min_M, min_index
96 |
97 |
98 | def norm_crop(img, landmark, image_size=224):
99 | M, pose_index = estimate_norm(landmark, image_size)
100 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
101 |
102 | return warped
103 |
--------------------------------------------------------------------------------
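Note on utils/face_align.py above: a sketch of aligning a face to 224x224 from 5 detected landmarks; the image path and landmark coordinates are fabricated.

# Sketch: align a face to 224x224 with estimate_norm/norm_crop; inputs are fabricated.
import cv2
import numpy as np

from utils.face_align import estimate_norm, norm_crop

img = cv2.imread("some_face.jpg")  # placeholder path, BGR image
lmk = np.array(
    [[120.0, 150.0], [180.0, 150.0], [150.0, 190.0], [128.0, 230.0], [172.0, 230.0]],
    dtype=np.float32,
)  # eyes, nose, mouth corners

M, pose_index = estimate_norm(lmk, image_size=224)  # best-fitting reference template
aligned = norm_crop(img, lmk, image_size=224)       # 224x224 aligned crop
print(pose_index, aligned.shape)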
/utils/image_operations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from PIL import Image
4 |
5 |
6 | def expand_bbox_rectangle(
7 | w, h, bbox_x_factor=2.0, bbox_y_factor=2.0, lms=None, expand_forehead=0.3, roll=0
8 | ):
9 | # get a good bbox for the facial landmarks
10 | min_pt_x = np.min(lms[:, 0], axis=0)
11 | max_pt_x = np.max(lms[:, 0], axis=0)
12 |
13 | min_pt_y = np.min(lms[:, 1], axis=0)
14 | max_pt_y = np.max(lms[:, 1], axis=0)
15 |
16 | # find out the bbox of the crop region
17 | bbox_size_x = int(np.max(max_pt_x - min_pt_x) * bbox_x_factor)
18 | center_pt_x = 0.5 * min_pt_x + 0.5 * max_pt_x
19 |
20 | bbox_size_y = int(np.max(max_pt_y - min_pt_y) * bbox_y_factor)
21 | center_pt_y = 0.5 * min_pt_y + 0.5 * max_pt_y
22 |
23 | bbox_min_x, bbox_max_x = (
24 | center_pt_x - bbox_size_x * 0.5,
25 | center_pt_x + bbox_size_x * 0.5,
26 | )
27 |
28 | bbox_min_y, bbox_max_y = (
29 | center_pt_y - bbox_size_y * 0.5,
30 | center_pt_y + bbox_size_y * 0.5,
31 | )
32 |
33 | if abs(roll) > 2.5:
34 | expand_forehead_size = expand_forehead * np.max(max_pt_y - min_pt_y)
35 | bbox_max_y += expand_forehead_size
36 |
37 | elif roll > 1:
38 | expand_forehead_size = expand_forehead * np.max(max_pt_x - min_pt_x)
39 | bbox_max_x += expand_forehead_size
40 |
41 | elif roll < -1:
42 | expand_forehead_size = expand_forehead * np.max(max_pt_x - min_pt_x)
43 | bbox_min_x -= expand_forehead_size
44 |
45 | else:
46 | expand_forehead_size = expand_forehead * np.max(max_pt_y - min_pt_y)
47 | bbox_min_y -= expand_forehead_size
48 |
49 | bbox_min_x = bbox_min_x.astype(np.int32)
50 | bbox_max_x = bbox_max_x.astype(np.int32)
51 | bbox_min_y = bbox_min_y.astype(np.int32)
52 | bbox_max_y = bbox_max_y.astype(np.int32)
53 |
54 | # compute necessary padding
55 | padding_left = abs(min(bbox_min_x, 0))
56 | padding_top = abs(min(bbox_min_y, 0))
57 | padding_right = max(bbox_max_x - w, 0)
58 | padding_bottom = max(bbox_max_y - h, 0)
59 |
60 | # crop the image properly by computing proper crop bounds
61 | crop_left = 0 if padding_left > 0 else bbox_min_x
62 | crop_top = 0 if padding_top > 0 else bbox_min_y
63 | crop_right = w if padding_right > 0 else bbox_max_x
64 | crop_bottom = h if padding_bottom > 0 else bbox_max_y
65 |
66 | return np.array([crop_left, crop_top, crop_right, crop_bottom])
67 |
68 |
69 | def expand_bbox_rectangle_tensor(
70 | w, h, bbox_x_factor=2.0, bbox_y_factor=2.0, lms=None, expand_forehead=0.3, roll=0
71 | ):
72 | # get a good bbox for the facial landmarks
73 | min_pt_x = torch.min(lms[:, 0], axis=0)[0]
74 | max_pt_x = torch.max(lms[:, 0], axis=0)[0]
75 |
76 | min_pt_y = torch.min(lms[:, 1], axis=0)[0]
77 | max_pt_y = torch.max(lms[:, 1], axis=0)[0]
78 |
79 | # find out the bbox of the crop region
80 | bbox_size_x = int(torch.max(max_pt_x - min_pt_x) * bbox_x_factor)
81 | center_pt_x = 0.5 * min_pt_x + 0.5 * max_pt_x
82 |
83 | bbox_size_y = int(torch.max(max_pt_y - min_pt_y) * bbox_y_factor)
84 | center_pt_y = 0.5 * min_pt_y + 0.5 * max_pt_y
85 |
86 | bbox_min_x, bbox_max_x = (
87 | center_pt_x - bbox_size_x * 0.5,
88 | center_pt_x + bbox_size_x * 0.5,
89 | )
90 |
91 | bbox_min_y, bbox_max_y = (
92 | center_pt_y - bbox_size_y * 0.5,
93 | center_pt_y + bbox_size_y * 0.5,
94 | )
95 |
96 | if abs(roll) > 2.5:
97 | expand_forehead_size = expand_forehead * torch.max(max_pt_y - min_pt_y)
98 | bbox_max_y += expand_forehead_size
99 |
100 | elif roll > 1:
101 | expand_forehead_size = expand_forehead * torch.max(max_pt_x - min_pt_x)
102 | bbox_max_x += expand_forehead_size
103 |
104 | elif roll < -1:
105 | expand_forehead_size = expand_forehead * torch.max(max_pt_x - min_pt_x)
106 | bbox_min_x -= expand_forehead_size
107 |
108 | else:
109 | expand_forehead_size = expand_forehead * torch.max(max_pt_y - min_pt_y)
110 | bbox_min_y -= expand_forehead_size
111 |
112 | bbox_min_x = bbox_min_x.int()
113 | bbox_max_x = bbox_max_x.int()
114 | bbox_min_y = bbox_min_y.int()
115 | bbox_max_y = bbox_max_y.int()
116 |
117 | # compute necessary padding
118 | padding_left = abs(min(bbox_min_x, 0))
119 | padding_top = abs(min(bbox_min_y, 0))
120 | padding_right = max(bbox_max_x - w, 0)
121 | padding_bottom = max(bbox_max_y - h, 0)
122 |
123 | # crop the image properly by computing proper crop bounds
124 | crop_left = 0 if padding_left > 0 else bbox_min_x
125 | crop_top = 0 if padding_top > 0 else bbox_min_y
126 | crop_right = w if padding_right > 0 else bbox_max_x
127 | crop_bottom = h if padding_bottom > 0 else bbox_max_y
128 |
129 | return (
130 | torch.tensor([crop_left, crop_top, crop_right, crop_bottom])
131 | .float()
132 | .to(lms.device)
133 | )
134 |
135 |
136 | def preprocess_image_wrt_face(img, bbox_size_factor=2.0, lms=None):
137 | w, h = img.size
138 |
139 | if lms is None:
140 | return None, None, None
141 |
142 | # get a good bbox for the facial landmarks
143 | min_pt = np.min(lms, axis=0)
144 | max_pt = np.max(lms, axis=0)
145 |
146 | # find out the bbox of the crop region
147 | bbox_size = int(np.max(max_pt - min_pt) * bbox_size_factor)
148 | center_pt = 0.5 * min_pt + 0.5 * max_pt
149 | bbox_min, bbox_max = center_pt - bbox_size * 0.5, center_pt + bbox_size * 0.5
150 | bbox_min = bbox_min.astype(np.int32)
151 | bbox_max = bbox_max.astype(np.int32)
152 |
153 | # compute necessary padding
154 | padding_left = abs(min(bbox_min[0], 0))
155 | padding_top = abs(min(bbox_min[1], 0))
156 | padding_right = max(bbox_max[0] - w, 0)
157 | padding_bottom = max(bbox_max[1] - h, 0)
158 |
159 | # crop the image properly by computing proper crop bounds
160 | crop_left = 0 if padding_left > 0 else bbox_min[0]
161 | crop_top = 0 if padding_top > 0 else bbox_min[1]
162 | crop_right = w if padding_right > 0 else bbox_max[0]
163 | crop_bottom = h if padding_bottom > 0 else bbox_max[1]
164 |
165 | cropped_image = img.crop((crop_left, crop_top, crop_right, crop_bottom))
166 |
167 | # copy the cropped image to padded image
168 | padded_image = Image.new(img.mode, (bbox_size, bbox_size))
169 | padded_image.paste(
170 | cropped_image,
171 | (
172 | padding_left,
173 | padding_top,
174 | padding_left + crop_right - crop_left,
175 | padding_top + crop_bottom - crop_top,
176 | ),
177 | )
178 |
179 | bbox = {}
180 | bbox["left"] = crop_left - padding_left
181 | bbox["right"] = crop_right + padding_right
182 | bbox["top"] = crop_top - padding_top
183 | bbox["bottom"] = crop_bottom + padding_bottom
184 |
185 | return padded_image, lms, bbox
186 |
187 |
188 | def bbox_is_dict(bbox):
189 | # check if the bbox is not a dict and convert it if needed
190 | if not isinstance(bbox, dict):
191 | temp_bbox = {}
192 | temp_bbox["left"] = bbox[0]
193 | temp_bbox["top"] = bbox[1]
194 | temp_bbox["right"] = bbox[2]
195 | temp_bbox["bottom"] = bbox[3]
196 | bbox = temp_bbox
197 |
198 | return bbox
199 |
200 |
201 | def pad_image_no_crop(img, bbox):
202 | bbox = bbox_is_dict(bbox)
203 |
204 | w, h = img.size
205 |
206 | # checks if the bbox is going outside of the image and if so expands the image
207 | if bbox["left"] < 0 or bbox["top"] < 0 or bbox["right"] > w or bbox["bottom"] > h:
208 | padding_left = abs(min(bbox["left"], 0))
209 | padding_top = abs(min(bbox["top"], 0))
210 | padding_right = max(bbox["right"] - w, 0)
211 | padding_bottom = max(bbox["bottom"] - h, 0)
212 |
213 | height = h + padding_top + padding_bottom
214 | width = w + padding_left + padding_right
215 |
216 | padded_image = Image.new(img.mode, (width, height))
217 | padded_image.paste(
218 | img, (padding_left, padding_top, padding_left + w, padding_top + h)
219 | )
220 |
221 | img = padded_image
222 | bbox["left"] += padding_left
223 | bbox["top"] += padding_top
224 | bbox["right"] += padding_left
225 | bbox["bottom"] += padding_top
226 |
227 | return img, bbox
228 |
229 |
230 | def crop_face_bbox_expanded(img, bbox, bbox_size_factor=2.0):
231 | bbox = bbox_is_dict(bbox)
232 |
233 | # get image size
234 | w, h = img.size
235 |
236 | # transform bounding box to 8 points (4,2)
237 | x1, y1, x2, y2 = bbox["left"], bbox["top"], bbox["right"], bbox["bottom"]
238 | bounding_box_8pts = np.array([x1, y1, x2, y1, x2, y2, x1, y2])
239 | bounding_box_8pts = bounding_box_8pts.reshape(-1, 2)
240 |
241 | # get a good bbox for the facial landmarks
242 | min_pt = np.min(bounding_box_8pts, axis=0)
243 | max_pt = np.max(bounding_box_8pts, axis=0)
244 |
245 | # find out the bbox of the crop region
246 | bbox_size = int(np.max(max_pt - min_pt) * bbox_size_factor)
247 | center_pt = 0.5 * min_pt + 0.5 * max_pt
248 | bbox_min, bbox_max = center_pt - bbox_size * 0.5, center_pt + bbox_size * 0.5
249 | bbox_min = bbox_min.astype(np.int32)
250 | bbox_max = bbox_max.astype(np.int32)
251 |
252 | # compute necessary padding
253 | padding_left = abs(min(bbox_min[0], 0))
254 | padding_top = abs(min(bbox_min[1], 0))
255 | padding_right = max(bbox_max[0] - w, 0)
256 | padding_bottom = max(bbox_max[1] - h, 0)
257 |
258 | # crop the image properly by computing proper crop bounds
259 | crop_left = 0 if padding_left > 0 else bbox_min[0]
260 | crop_top = 0 if padding_top > 0 else bbox_min[1]
261 | crop_right = w if padding_right > 0 else bbox_max[0]
262 | crop_bottom = h if padding_bottom > 0 else bbox_max[1]
263 |
264 | cropped_image = img.crop((crop_left, crop_top, crop_right, crop_bottom))
265 |
266 | # copy the cropped image to padded image
267 | padded_image = Image.new(img.mode, (bbox_size, bbox_size))
268 | padded_image.paste(
269 | cropped_image,
270 | (
271 | padding_left,
272 | padding_top,
273 | padding_left + crop_right - crop_left,
274 | padding_top + crop_bottom - crop_top,
275 | ),
276 | )
277 |
278 | bbox_padded = {}
279 | bbox_padded["left"] = crop_left - padding_left
280 | bbox_padded["right"] = crop_right + padding_right
281 | bbox_padded["top"] = crop_top - padding_top
282 | bbox_padded["bottom"] = crop_bottom + padding_bottom
283 |
284 | bbox = {}
285 | bbox["left"] = crop_left
286 | bbox["right"] = crop_right
287 | bbox["top"] = crop_top
288 | bbox["bottom"] = crop_bottom
289 |
290 | return padded_image, bbox_padded, bbox
291 |
292 |
293 | def resize_image(img, min_size=600, max_size=1000):
294 | width = img.width
295 | height = img.height
296 | w, h, scale = width, height, 1.0
297 |
298 | if width < height:
299 | if width < min_size:
300 | w = min_size
301 | h = int(height * min_size / width)
302 | scale = float(min_size) / float(width)
303 | elif width > max_size:
304 | w = max_size
305 | h = int(height * max_size / width)
306 | scale = float(max_size) / float(width)
307 | else:
308 | if height < min_size:
309 | w = int(width * min_size / height)
310 | h = min_size
311 | scale = float(min_size) / float(height)
312 | elif height > max_size:
313 | w = int(width * max_size / height)
314 | h = max_size
315 | scale = float(max_size) / float(height)
316 |
317 | img_resized = img.resize((w, h))
318 | img_resize_info = [h, w, scale]
319 |
320 | return img_resized, img_resize_info
321 |
--------------------------------------------------------------------------------
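A minimal usage sketch for the crop/pad helpers above, assuming the utils package is importable from the repository root; the image path and bbox values are invented for illustration.

from PIL import Image

from utils.image_operations import (
    bbox_is_dict,
    crop_face_bbox_expanded,
    pad_image_no_crop,
    resize_image,
)

img = Image.open("sample_face.jpg")           # hypothetical input image
bbox = bbox_is_dict([110, 80, 250, 240])      # [left, top, right, bottom] -> dict

# expand the detection 2x and get a square crop, padded where it leaves the image
padded_face, bbox_padded, bbox_clipped = crop_face_bbox_expanded(img, bbox, 2.0)

# or only pad the original image so the bbox fits entirely inside it
img_padded, bbox_shifted = pad_image_no_crop(img, bbox)

# clamp the shorter side into the [600, 1000] range, roughly preserving aspect ratio
img_resized, (new_h, new_w, scale) = resize_image(img, min_size=600, max_size=1000)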
/utils/json_loader.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from PIL import Image
8 | from torch.utils.data import DataLoader
9 | from torchvision.datasets import ImageFolder
10 | from tqdm import tqdm
11 |
12 | from .image_operations import bbox_is_dict
13 | from .pose_operations import get_pose, pose_bbox_to_full_image
14 |
15 |
16 | class FramesJsonList(ImageFolder):
17 | def __init__(self, threed_5_points, threed_68_points, json_list, dataset_path=None):
18 | self.samples = []
19 | self.bboxes = []
20 | self.landmarks = []
21 | self.threed_5_points = threed_5_points
22 | self.threed_68_points = threed_68_points
23 | self.dataset_path = dataset_path
24 |
25 | image_paths = pd.read_csv(json_list, delimiter=" ", header=None)
26 | image_paths = np.asarray(image_paths).squeeze()
27 |
28 | print("Loading frames paths...")
29 | for image_path in tqdm(image_paths):
30 | with open(image_path) as f:
31 | image_json = json.load(f)
32 |
33 | # path to the image
34 | img_path = image_json["image_path"]
35 | # if not absolute path, append the dataset path
36 | if self.dataset_path is not None:
37 | img_path = os.path.join(self.dataset_path, img_path)
38 | self.samples.append(img_path)
39 |
40 | # landmarks used to create pose labels
41 | self.landmarks.append(image_json["landmarks"])
42 |
43 | # load bboxes
44 | self.bboxes.append(image_json["bboxes"])
45 |
46 | def __len__(self):
47 | return len(self.samples)
48 |
49 | def __getitem__(self, index):
50 | image_path = Path(self.samples[index])
51 |
52 | img = Image.open(image_path)
53 |
54 | (w, h) = img.size
55 | global_intrinsics = np.array(
56 | [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]]
57 | )
58 | bboxes = self.bboxes[index]
59 | landmarks = self.landmarks[index]
60 |
61 | bbox_labels = []
62 | landmark_labels = []
63 | pose_labels = []
64 | global_pose_labels = []
65 |
66 | for i in range(len(bboxes)):
67 | bbox = np.asarray(bboxes[i])[:4].astype(int)
68 | landmark = np.asarray(landmarks[i])[:, :2].astype(float)
69 |
70 |             # skip bboxes with non-positive width or height
71 | if bbox[0] >= bbox[2] or bbox[1] >= bbox[3]:
72 | continue
73 |
74 | if -1 in landmark:
75 | global_pose_labels.append([-9, -9, -9, -9, -9, -9])
76 | pose_labels.append([-9, -9, -9, -9, -9, -9])
77 |
78 | else:
79 | landmark[:, 0] -= bbox[0]
80 | landmark[:, 1] -= bbox[1]
81 |
82 | w = int(bbox[2] - bbox[0])
83 | h = int(bbox[3] - bbox[1])
84 |
85 | bbox_intrinsics = np.array(
86 | [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]]
87 | )
88 |
89 | if len(landmark) == 5:
90 | P, pose = get_pose(self.threed_5_points, landmark, bbox_intrinsics)
91 | else:
92 | P, pose = get_pose(
93 | self.threed_68_points,
94 | landmark,
95 | bbox_intrinsics,
96 | )
97 |
98 | pose_labels.append(pose.tolist())
99 |
100 | global_pose = pose_bbox_to_full_image(
101 | pose, global_intrinsics, bbox_is_dict(bbox)
102 | )
103 |
104 | global_pose_labels.append(global_pose.tolist())
105 |
106 | bbox_labels.append(bbox.tolist())
107 | landmark_labels.append(self.landmarks[index][i])
108 |
109 | with open(image_path, "rb") as f:
110 | raw_img = f.read()
111 |
112 | return (
113 | raw_img,
114 | global_pose_labels,
115 | bbox_labels,
116 | pose_labels,
117 | landmark_labels,
118 | )
119 |
120 |
121 | class JsonLoader(DataLoader):
122 | def __init__(
123 | self, workers, json_list, threed_5_points, threed_68_points, dataset_path=None
124 | ):
125 | self._dataset = FramesJsonList(
126 | threed_5_points, threed_68_points, json_list, dataset_path
127 | )
128 |
129 | super(JsonLoader, self).__init__(
130 | self._dataset, num_workers=workers, collate_fn=lambda x: x
131 | )
132 |
--------------------------------------------------------------------------------
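A hedged sketch of how JsonLoader could be instantiated; the json list path is a placeholder (a text file listing one per-image json path per line), while the .npy reference files ship under pose_references/.

import numpy as np

from utils.json_loader import JsonLoader

threed_5_points = np.load("pose_references/reference_3d_5_points_trans.npy")
threed_68_points = np.load("pose_references/reference_3d_68_points_trans.npy")

loader = JsonLoader(
    workers=4,
    json_list="train_json_list.txt",   # hypothetical list of per-image json paths
    threed_5_points=threed_5_points,
    threed_68_points=threed_68_points,
    dataset_path=None,                 # set when the json image paths are relative
)

for batch in loader:
    # collate_fn is the identity, so each batch is a list with one sample tuple
    raw_img, global_poses, bboxes, poses, landmarks = batch[0]
    break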
/utils/json_loader_300wlp.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from PIL import Image
8 | from scipy.spatial.transform import Rotation
9 | from torch.utils.data import DataLoader
10 | from torchvision.datasets import ImageFolder
11 | from tqdm import tqdm
12 |
13 | from .image_operations import bbox_is_dict, expand_bbox_rectangle
14 | from .pose_operations import get_pose, pose_full_image_to_bbox
15 |
16 |
17 | class FramesJsonList(ImageFolder):
18 | def __init__(self, threed_5_points, threed_68_points, json_list, dataset_path=None):
19 | self.samples = []
20 | self.bboxes = []
21 | self.landmarks = []
22 | self.poses_para = []
23 | self.threed_5_points = threed_5_points
24 | self.threed_68_points = threed_68_points
25 | self.dataset_path = dataset_path
26 |
27 | image_paths = pd.read_csv(json_list, delimiter=",", header=None)
28 | image_paths = np.asarray(image_paths).squeeze()
29 |
30 | print("Loading frames paths...")
31 | for image_path in tqdm(image_paths):
32 | with open(image_path) as f:
33 | image_json = json.load(f)
34 |
35 | # path to the image
36 | img_path = image_json["image_path"]
37 | # if not absolute path, append the dataset path
38 | if self.dataset_path is not None:
39 | img_path = os.path.join(self.dataset_path, img_path)
40 | self.samples.append(img_path)
41 |
42 | # landmarks used to create pose labels
43 | self.landmarks.append(image_json["landmarks"])
44 |
45 | # load bboxes
46 | self.bboxes.append(image_json["bboxes"])
47 |
48 |             # load pose parameters
49 | self.poses_para.append(image_json["pose_para"])
50 |
51 | def __len__(self):
52 | return len(self.samples)
53 |
54 | def __getitem__(self, index):
55 | image_path = Path(self.samples[index])
56 |
57 | img = Image.open(image_path)
58 |
59 | (img_w, img_h) = img.size
60 | global_intrinsics = np.array(
61 | [[img_w + img_h, 0, img_w // 2], [0, img_w + img_h, img_h // 2], [0, 0, 1]]
62 | )
63 | bboxes = self.bboxes[index]
64 | landmarks = self.landmarks[index]
65 | pose_para = self.poses_para[index]
66 |
67 | bbox_labels = []
68 | landmark_labels = []
69 | pose_labels = []
70 | global_pose_labels = []
71 |
72 | for i in range(len(bboxes)):
73 | bbox = np.asarray(bboxes[i])[:4].astype(int)
74 | landmark = np.asarray(landmarks[i])[:, :2].astype(float)
75 |             sample_pose_para = np.asarray(pose_para[i])[:3].astype(float)  # avoid shadowing the per-image list
76 |
77 |             # skip bboxes with non-positive width or height
78 | if bbox[0] >= bbox[2] or bbox[1] >= bbox[3]:
79 | continue
80 |
81 | if -1 in landmark:
82 | global_pose_labels.append([-9, -9, -9, -9, -9, -9])
83 | pose_labels.append([-9, -9, -9, -9, -9, -9])
84 |
85 | else:
86 | P, pose = get_pose(
87 | self.threed_68_points,
88 | landmark,
89 | global_intrinsics,
90 | )
91 |
92 |                 pose[:3] = self.convert_aflw(sample_pose_para)
93 | global_pose_labels.append(pose.tolist())
94 |
95 | projected_bbox = expand_bbox_rectangle(
96 | img_w, img_h, 1.1, 1.1, landmark, roll=pose[2]
97 | )
98 |
99 | local_pose = pose_full_image_to_bbox(
100 | pose,
101 | global_intrinsics,
102 | bbox_is_dict(projected_bbox),
103 | )
104 |
105 | pose_labels.append(local_pose.tolist())
106 |
107 | bbox_labels.append(projected_bbox.tolist())
108 | landmark_labels.append(self.landmarks[index][i])
109 |
110 | with open(image_path, "rb") as f:
111 | raw_img = f.read()
112 |
113 | return (
114 | raw_img,
115 | global_pose_labels,
116 | bbox_labels,
117 | pose_labels,
118 | landmark_labels,
119 | )
120 |
121 | def convert_aflw(self, angle):
122 | rot_mat_1 = Rotation.from_euler(
123 | "xyz", [angle[0], -angle[1], -angle[2]], degrees=False
124 | ).as_matrix()
125 | rot_mat_2 = np.transpose(rot_mat_1)
126 | return Rotation.from_matrix(rot_mat_2).as_rotvec()
127 |
128 |
129 | class JsonLoader(DataLoader):
130 | def __init__(
131 | self, workers, json_list, threed_5_points, threed_68_points, dataset_path=None
132 | ):
133 | self._dataset = FramesJsonList(
134 | threed_5_points, threed_68_points, json_list, dataset_path
135 | )
136 |
137 | super(JsonLoader, self).__init__(
138 | self._dataset, num_workers=workers, collate_fn=lambda x: x
139 | )
140 |
--------------------------------------------------------------------------------
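convert_aflw maps a 300W-LP pose_para triplet onto the rotation-vector form used for the pose labels above. A standalone sketch of that conversion; the sample angles are made up.

import numpy as np
from scipy.spatial.transform import Rotation

def convert_aflw(angle):
    # same math as FramesJsonList.convert_aflw above: negate yaw and roll,
    # build the rotation matrix, transpose it, and return a rotation vector
    rot_mat = Rotation.from_euler(
        "xyz", [angle[0], -angle[1], -angle[2]], degrees=False
    ).as_matrix()
    return Rotation.from_matrix(rot_mat.T).as_rotvec()

pose_para = np.array([0.1, -0.4, 0.05])  # hypothetical pitch, yaw, roll in radians
rotvec = convert_aflw(pose_para)         # 3-vector used as the rotation half of the 6DoF label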
/utils/pose_operations.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 | from scipy.spatial.transform import Rotation
6 |
7 | from .face_align import norm_crop
8 | from .image_operations import bbox_is_dict, expand_bbox_rectangle
9 |
10 |
11 | def bbox_dict_to_np(bbox):
12 | bbox_np = np.zeros(shape=4)
13 | bbox_np[0] = bbox["left"]
14 | bbox_np[1] = bbox["top"]
15 | bbox_np[2] = bbox["right"]
16 | bbox_np[3] = bbox["bottom"]
17 |
18 | return bbox_np
19 |
20 |
21 | def quat_to_rotation_mat_tensor(quat):
22 | x = quat[0]
23 | y = quat[1]
24 | z = quat[2]
25 | w = quat[3]
26 | x2 = x * x
27 | y2 = y * y
28 | z2 = z * z
29 | w2 = w * w
30 | xy = x * y
31 | zw = z * w
32 | xz = x * z
33 | yw = y * w
34 | yz = y * z
35 | xw = x * w
36 | matrix = torch.zeros(3, 3).to(quat.device)
37 | matrix[0, 0] = x2 - y2 - z2 + w2
38 | matrix[1, 0] = 2 * (xy + zw)
39 | matrix[2, 0] = 2 * (xz - yw)
40 | matrix[0, 1] = 2 * (xy - zw)
41 | matrix[1, 1] = -x2 + y2 - z2 + w2
42 | matrix[2, 1] = 2 * (yz + xw)
43 | matrix[0, 2] = 2 * (xz + yw)
44 | matrix[1, 2] = 2 * (yz - xw)
45 | matrix[2, 2] = -x2 - y2 + z2 + w2
46 | return matrix
47 |
48 |
49 | def from_rotvec_tensor(rotvec):
50 | norm = torch.norm(rotvec)
51 | small_angle = norm <= 1e-3
52 | scale = 0
53 | if small_angle:
54 | scale = 0.5 - norm ** 2 / 48 + norm ** 4 / 3840
55 | else:
56 | scale = torch.sin(norm / 2) / norm
57 | quat = torch.zeros(4).to(rotvec.device)
58 | quat[0:3] = scale * rotvec
59 | quat[3] = torch.cos(norm / 2)
60 |
61 | return quat_to_rotation_mat_tensor(quat)
62 |
63 |
64 | def transform_points_tensor(points, pose):
65 | return torch.matmul(points, from_rotvec_tensor(pose[:3]).T) + pose[3:]
66 |
67 |
68 | def get_bbox_intrinsics(image_intrinsics, bbox):
69 |     # principal point of the crop (bbox center)
70 | bbox_center_x = bbox["left"] + ((bbox["right"] - bbox["left"]) // 2)
71 | bbox_center_y = bbox["top"] + ((bbox["bottom"] - bbox["top"]) // 2)
72 |
73 | # create a camera intrinsics from the bbox center
74 | bbox_intrinsics = image_intrinsics.copy()
75 | bbox_intrinsics[0, 2] = bbox_center_x
76 | bbox_intrinsics[1, 2] = bbox_center_y
77 |
78 | return bbox_intrinsics
79 |
80 |
81 | def get_bbox_intrinsics_np(image_intrinsics, bbox):
82 |     # principal point of the crop (bbox center)
83 | bbox_center_x = bbox[0] + ((bbox[2] - bbox[0]) // 2)
84 | bbox_center_y = bbox[1] + ((bbox[3] - bbox[1]) // 2)
85 |
86 | # create a camera intrinsics from the bbox center
87 | bbox_intrinsics = image_intrinsics.copy()
88 | bbox_intrinsics[0, 2] = bbox_center_x
89 | bbox_intrinsics[1, 2] = bbox_center_y
90 |
91 | return bbox_intrinsics
92 |
93 |
94 | def pose_full_image_to_bbox(pose, image_intrinsics, bbox):
95 | # check if bbox is np or dict
96 | bbox = bbox_is_dict(bbox)
97 |
98 | # rotation vector
99 | rvec = pose[:3].copy()
100 |
101 | # translation and scale vector
102 | tvec = pose[3:].copy()
103 |
104 | # get camera intrinsics using bbox
105 | bbox_intrinsics = get_bbox_intrinsics(image_intrinsics, bbox)
106 |
107 | # focal length
108 | focal_length = image_intrinsics[0, 0]
109 |
110 | # bbox_size
111 | bbox_width = bbox["right"] - bbox["left"]
112 | bbox_height = bbox["bottom"] - bbox["top"]
113 | bbox_size = bbox_width + bbox_height
114 |
115 | # project crop points using the full image camera intrinsics
116 | projected_point = image_intrinsics.dot(tvec.T)
117 |
118 | # reverse the projected points using the crop camera intrinsics
119 | tvec = projected_point.dot(np.linalg.inv(bbox_intrinsics.T))
120 |
121 | # adjust scale
122 | tvec[2] /= focal_length / bbox_size
123 |
124 | # same for rotation
125 | rmat = Rotation.from_rotvec(rvec).as_matrix()
126 |     # project the rotation using the full image camera intrinsics
127 | projected_point = image_intrinsics.dot(rmat)
128 |     # reverse the projection using the crop camera intrinsics
129 | rmat = np.linalg.inv(bbox_intrinsics).dot(projected_point)
130 | rvec = Rotation.from_matrix(rmat).as_rotvec()
131 |
132 | return np.concatenate([rvec, tvec])
133 |
134 |
135 | def pose_bbox_to_full_image(pose, image_intrinsics, bbox):
136 | # check if bbox is np or dict
137 | bbox = bbox_is_dict(bbox)
138 |
139 | # rotation vector
140 | rvec = pose[:3].copy()
141 |
142 | # translation and scale vector
143 | tvec = pose[3:].copy()
144 |
145 | # get camera intrinsics using bbox
146 | bbox_intrinsics = get_bbox_intrinsics(image_intrinsics, bbox)
147 |
148 | # focal length
149 | focal_length = image_intrinsics[0, 0]
150 |
151 | # bbox_size
152 | bbox_width = bbox["right"] - bbox["left"]
153 | bbox_height = bbox["bottom"] - bbox["top"]
154 | bbox_size = bbox_width + bbox_height
155 |
156 | # adjust scale
157 | tvec[2] *= focal_length / bbox_size
158 |
159 | # project crop points using the crop camera intrinsics
160 | projected_point = bbox_intrinsics.dot(tvec.T)
161 |
162 | # reverse the projected points using the full image camera intrinsics
163 | tvec = projected_point.dot(np.linalg.inv(image_intrinsics.T))
164 |
165 | # same for rotation
166 | rmat = Rotation.from_rotvec(rvec).as_matrix()
167 | # project crop points using the crop camera intrinsics
168 | projected_point = bbox_intrinsics.dot(rmat)
169 | # reverse the projected points using the full image camera intrinsics
170 | rmat = np.linalg.inv(image_intrinsics).dot(projected_point)
171 | rvec = Rotation.from_matrix(rmat).as_rotvec()
172 |
173 | return np.concatenate([rvec, tvec])
174 |
175 |
176 | def plot_3d_landmark(verts, campose, intrinsics):
177 | lm_3d_trans = transform_points(verts, campose)
178 |
179 | # project to image plane
180 | lms_3d_trans_proj = intrinsics.dot(lm_3d_trans.T).T
181 | lms_projected = (
182 | lms_3d_trans_proj[:, :2] / np.tile(lms_3d_trans_proj[:, 2], (2, 1)).T
183 | )
184 |
185 | return lms_projected, lms_3d_trans_proj
186 |
187 |
188 | def plot_3d_landmark_torch(verts, campose, intrinsics):
189 | lm_3d_trans = transform_points_tensor(verts, campose)
190 |
191 | # project to image plane
192 | lms_3d_trans_proj = torch.matmul(intrinsics, lm_3d_trans.T).T
193 | lms_projected = lms_3d_trans_proj[:, :2] / lms_3d_trans_proj[:, 2].repeat(2, 1).T
194 |
195 | return lms_projected
196 |
197 |
198 | def transform_points(points, pose):
199 | return points.dot(Rotation.from_rotvec(pose[:3]).as_matrix().T) + pose[3:]
200 |
201 |
202 | def get_pose(vertices, twod_landmarks, camera_intrinsics, initial_pose=None):
203 | threed_landmarks = vertices
204 | twod_landmarks = np.asarray(twod_landmarks).astype("float32")
205 |
206 | # if initial_pose is provided, use it as a guess to solve new pose
207 | if initial_pose is not None:
208 | initial_pose = np.asarray(initial_pose)
209 | retval, rvecs, tvecs = cv2.solvePnP(
210 | threed_landmarks,
211 | twod_landmarks,
212 | camera_intrinsics,
213 | None,
214 | rvec=initial_pose[:3],
215 | tvec=initial_pose[3:],
216 | flags=cv2.SOLVEPNP_EPNP,
217 | useExtrinsicGuess=True,
218 | )
219 | else:
220 | retval, rvecs, tvecs = cv2.solvePnP(
221 | threed_landmarks,
222 | twod_landmarks,
223 | camera_intrinsics,
224 | None,
225 | flags=cv2.SOLVEPNP_EPNP,
226 | )
227 |
228 | rotation_mat = np.zeros(shape=(3, 3))
229 | R = cv2.Rodrigues(rvecs, rotation_mat)[0]
230 |
231 | RT = np.column_stack((R, tvecs))
232 | P = np.matmul(camera_intrinsics, RT)
233 | dof = np.append(rvecs, tvecs)
234 |
235 | return P, dof
236 |
237 |
238 | def transform_pose_global_project_bbox(
239 | boxes,
240 | dofs,
241 | pose_mean,
242 | pose_stddev,
243 | image_shape,
244 | threed_68_points=None,
245 | bbox_x_factor=1.1,
246 | bbox_y_factor=1.1,
247 | expand_forehead=0.3,
248 | ):
249 | if len(dofs) == 0:
250 | return boxes, dofs
251 |
252 | device = dofs.device
253 |
254 | boxes = boxes.cpu().numpy()
255 | dofs = dofs.cpu().numpy()
256 |
257 |     # convert the 3D reference points once, if they were provided
258 |     if threed_68_points is not None:
259 |         threed_68_points = threed_68_points.numpy()
260 |
261 |     (h, w) = image_shape
262 |     global_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]])
263 |
264 |     # de-normalize the predicted 6DoF values with the stored mean/stddev
265 | pose_mean = pose_mean.numpy()
266 | pose_stddev = pose_stddev.numpy()
267 |
268 | dof_mean = pose_mean
269 | dof_std = pose_stddev
270 | dofs = dofs * dof_std + dof_mean
271 |
272 | projected_boxes = []
273 | global_dofs = []
274 |
275 | for i in range(dofs.shape[0]):
276 | global_dof = pose_bbox_to_full_image(dofs[i], global_intrinsics, boxes[i])
277 | global_dofs.append(global_dof)
278 |
279 | if threed_68_points is not None:
280 | # project points and get bbox
281 | projected_lms, _ = plot_3d_landmark(
282 | threed_68_points, global_dof, global_intrinsics
283 | )
284 | projected_bbox = expand_bbox_rectangle(
285 | w,
286 | h,
287 | bbox_x_factor=bbox_x_factor,
288 | bbox_y_factor=bbox_y_factor,
289 | lms=projected_lms,
290 | roll=global_dof[2],
291 | expand_forehead=expand_forehead,
292 | )
293 | else:
294 | projected_bbox = boxes[i]
295 |
296 | projected_boxes.append(projected_bbox)
297 |
298 | global_dofs = torch.from_numpy(np.asarray(global_dofs)).float()
299 | projected_boxes = torch.from_numpy(np.asarray(projected_boxes)).float()
300 |
301 | return projected_boxes.to(device), global_dofs.to(device)
302 |
303 |
304 | def align_faces(threed_5_points, img, poses, face_size=224):
305 | if len(poses) == 0:
306 | return None
307 | elif np.ndim(poses) == 1:
308 | poses = poses[np.newaxis, :]
309 |
310 | (w, h) = img.size
311 | global_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]])
312 |
313 | faces_aligned = []
314 |
315 | for pose in poses:
316 | proj_lms, _ = plot_3d_landmark(
317 | threed_5_points, np.asarray(pose), global_intrinsics
318 | )
319 | face_aligned = norm_crop(np.asarray(img).copy(), proj_lms, face_size)
320 | faces_aligned.append(Image.fromarray(face_aligned))
321 |
322 | return faces_aligned
323 |
--------------------------------------------------------------------------------
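A hedged sketch tying get_pose, pose_full_image_to_bbox, and pose_bbox_to_full_image together on synthetic data; all numbers below are invented, and the reference .npy file is the one shipped in pose_references/.

import numpy as np

from utils.pose_operations import (
    get_pose,
    pose_bbox_to_full_image,
    pose_full_image_to_bbox,
    transform_points,
)

threed_68_points = np.load("pose_references/reference_3d_68_points_trans.npy")

w, h = 640, 480
global_intrinsics = np.array(
    [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]], dtype=np.float64
)

# synthesize 2D landmarks by projecting the reference points with a known pose
gt_pose = np.array([0.1, -0.3, 0.05, 0.0, 0.0, 5.0])
lms_3d = transform_points(threed_68_points, gt_pose)
lms_2d = (global_intrinsics @ lms_3d.T).T
lms_2d = lms_2d[:, :2] / lms_2d[:, 2:]

# recover the 6DoF pose from the 68 landmarks via PnP
P, dof = get_pose(threed_68_points, lms_2d, global_intrinsics)

# express the same pose relative to a face crop and map it back
bbox = {"left": 200, "top": 120, "right": 440, "bottom": 360}
local_dof = pose_full_image_to_bbox(dof, global_intrinsics, bbox)
dof_back = pose_bbox_to_full_image(local_dof, global_intrinsics, bbox)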
/utils/renderer.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from Sim3DR import RenderPipeline
4 |
5 | from .pose_operations import plot_3d_landmark
6 |
7 |
8 | def _to_ctype(arr):
9 | if not arr.flags.c_contiguous:
10 | return arr.copy(order="C")
11 | return arr
12 |
13 |
14 | def get_colors(img, ver):
15 | h, w, _ = img.shape
16 | ver[0, :] = np.minimum(np.maximum(ver[0, :], 0), w - 1) # x
17 | ver[1, :] = np.minimum(np.maximum(ver[1, :], 0), h - 1) # y
18 | ind = np.round(ver).astype(np.int32)
19 | colors = img[ind[1, :], ind[0, :], :] / 255.0 # n x 3
20 |
21 | return colors.copy()
22 |
23 |
24 | class Renderer:
25 | def __init__(
26 | self,
27 | vertices_path="../pose_references/vertices_trans.npy",
28 | triangles_path="../pose_references/triangles.npy",
29 | ):
30 | self.vertices = np.load(vertices_path)
31 | self.triangles = _to_ctype(np.load(triangles_path).T)
32 | self.vertices[:, 0] *= -1
33 |
34 | self.cfg = {
35 | "intensity_ambient": 0.3,
36 | "color_ambient": (1, 1, 1),
37 | "intensity_directional": 0.6,
38 | "color_directional": (1, 1, 1),
39 | "intensity_specular": 0.1,
40 | "specular_exp": 5,
41 | "light_pos": (0, 0, 5),
42 | "view_pos": (0, 0, 5),
43 | }
44 |
45 | self.render_app = RenderPipeline(**self.cfg)
46 |
47 | def transform_vertices(self, img, poses, global_intrinsics=None):
48 | (w, h) = img.size
49 | if global_intrinsics is None:
50 | global_intrinsics = np.array(
51 | [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]]
52 | )
53 |
54 | transformed_vertices = []
55 | for pose in poses:
56 | projected_lms = np.zeros_like(self.vertices)
57 | projected_lms[:, :2], lms_3d_trans_proj = plot_3d_landmark(
58 | self.vertices, pose, global_intrinsics
59 | )
60 | projected_lms[:, 2] = lms_3d_trans_proj[:, 2] * -1
61 |
62 | range_x = np.max(projected_lms[:, 0]) - np.min(projected_lms[:, 0])
63 | range_y = np.max(projected_lms[:, 1]) - np.min(projected_lms[:, 1])
64 |
65 | s = (h + w) / pose[5]
66 | projected_lms[:, 2] *= s
67 | projected_lms[:, 2] += (range_x + range_y) * 3
68 |
69 | transformed_vertices.append(projected_lms)
70 |
71 | return transformed_vertices
72 |
73 | def render(self, img, transformed_vertices, alpha=0.9, save_path=None):
74 | img = np.asarray(img)
75 | overlap = img.copy()
76 |
77 | for vertices in transformed_vertices:
78 |             vertices = _to_ctype(vertices)  # ensure a C-contiguous copy for the rasterizer
79 | overlap = self.render_app(vertices, self.triangles, overlap)
80 |
81 | res = cv2.addWeighted(img, 1 - alpha, overlap, alpha, 0)
82 |
83 | if save_path is not None:
84 | cv2.imwrite(save_path, res)
85 |             print(f"Saved visualization result to {save_path}")
86 |
87 | return res
88 |
89 | def save_to_obj(self, img, ver_lst, height, save_path):
90 | n_obj = len(ver_lst) # count obj
91 |
92 | if n_obj <= 0:
93 | return
94 |
95 | n_vertex = ver_lst[0].T.shape[1]
96 | n_face = self.triangles.shape[0]
97 |
98 | with open(save_path, "w") as f:
99 | for i in range(n_obj):
100 | ver = ver_lst[i].T
101 | colors = get_colors(img, ver)
102 |
103 | for j in range(n_vertex):
104 | x, y, z = ver[:, j]
105 | f.write(
106 | f"v {x:.2f} {height - y:.2f} {z:.2f} {colors[j, 2]:.2f} "
107 | f"{colors[j, 1]:.2f} {colors[j, 0]:.2f}\n"
108 | )
109 |
110 | for i in range(n_obj):
111 | offset = i * n_vertex
112 | for j in range(n_face):
113 | idx1, idx2, idx3 = self.triangles[j] # m x 3
114 | f.write(
115 | f"f {idx3 + 1 + offset} {idx2 + 1 + offset} "
116 | f"{idx1 + 1 + offset}\n"
117 | )
118 |
119 |         print(f"Dump to {save_path}")
120 |
--------------------------------------------------------------------------------
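A hedged usage sketch for Renderer; the explicit paths assume it runs from the repository root (the constructor defaults expect to be run from inside utils/), and the input image and pose are invented.

import numpy as np
from PIL import Image

from utils.renderer import Renderer

renderer = Renderer(
    vertices_path="pose_references/vertices_trans.npy",
    triangles_path="pose_references/triangles.npy",
)

img = Image.open("sample_face.jpg")                   # hypothetical input image
poses = [np.array([0.1, -0.3, 0.05, 0.0, 0.0, 5.0])]  # one 6DoF pose per detected face

trans_vertices = renderer.transform_vertices(img, poses)
rendered = renderer.render(img, trans_vertices, alpha=0.9, save_path="render_vis.jpg")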