├── .gitattributes ├── .gitignore ├── README.md ├── Sim3DR ├── .gitignore ├── Sim3DR.py ├── __init__.py ├── _init_paths.py ├── build_sim3dr.sh ├── lib │ ├── rasterize.h │ ├── rasterize.pyx │ └── rasterize_kernel.cpp ├── lighting.py ├── readme.md ├── setup.py └── tests │ ├── .gitignore │ ├── CMakeLists.txt │ ├── io.cpp │ ├── io.h │ └── test.cpp ├── config.py ├── convert_json_list_to_lmdb.py ├── data_loader_lmdb.py ├── data_loader_lmdb_augmenter.py ├── early_stop.py ├── evaluation ├── evaluate_wider.py └── jupyter_notebooks │ ├── AFLW2000_annotations.txt │ ├── BIWI_annotations.txt │ ├── aflw_2000_3d_evaluation.ipynb │ ├── biwi_evaluation.ipynb │ ├── test_own_images.ipynb │ └── visualize_trained_model_predictions.ipynb ├── generalized_rcnn.py ├── img2pose.py ├── license.md ├── losses.py ├── model_loader.py ├── models.py ├── pose_references ├── reference_3d_5_points_trans.npy ├── reference_3d_68_points_trans.npy ├── triangles.npy └── vertices_trans.npy ├── requirements.txt ├── rpn.py ├── run_face_alignment.py ├── teaser.jpeg ├── train.py ├── train_logger.py └── utils ├── annotate_dataset.py ├── augmentation.py ├── dist.py ├── face_align.py ├── image_operations.py ├── json_loader.py ├── json_loader_300wlp.py ├── pose_operations.py └── renderer.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | parts/ 18 | sdist/ 19 | var/ 20 | wheels/ 21 | pip-wheel-metadata/ 22 | share/python-wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | *.py,cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | db.sqlite3-journal 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # pipenv 86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 88 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 89 | # install all needed dependencies. 90 | #Pipfile.lock 91 | 92 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 93 | __pypackages__/ 94 | 95 | # Celery stuff 96 | celerybeat-schedule 97 | celerybeat.pid 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | .dmypy.json 124 | dmypy.json 125 | 126 | # Pyre type checker 127 | .pyre/ 128 | 129 | # User 130 | **/annotations/* 131 | **/datasets/* 132 | **/workspace/* 133 | **/workspace_old/* 134 | **/results/* 135 | **/models/* 136 | *.json 137 | *.sge 138 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # img2pose: Face Alignment and Detection via 6DoF, Face Pose Estimation 2 | 3 | ## Paper accepted to the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2021 4 | 5 | [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-CC%20BY--NC%204.0-lightgrey.svg)](https://creativecommons.org/licenses/by-nc/4.0/) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/img2pose-face-alignment-and-detection-via/head-pose-estimation-on-aflw2000)](https://paperswithcode.com/sota/head-pose-estimation-on-aflw2000?p=img2pose-face-alignment-and-detection-via) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/img2pose-face-alignment-and-detection-via/head-pose-estimation-on-biwi)](https://paperswithcode.com/sota/head-pose-estimation-on-biwi?p=img2pose-face-alignment-and-detection-via) 8 | 9 |
10 | 11 |
Figure 1: We estimate the 6DoF rigid transformation of a 3D face (rendered in silver), aligning it with even the tiniest faces, without face detection or facial landmark localization. Our estimated 3D face locations are rendered by descending distances from the camera, for coherent visualization.
12 |
13 | 14 | ## TL;DR 15 | This repository provides a novel method for six degrees of freedom (6DoF) face pose estimation on multiple faces without the need for prior face detection. After prediction, one can visualize the detections (as shown in the figure above), customize projected bounding boxes, or crop and align each face for further processing. See details below. 16 | 17 | ## Table of contents 18 | 19 | 20 | - [Paper details](#paper-details) 21 | * [Abstract](#abstract) 22 | * [Video Spotlight](#video-spotlight) 23 | * [Citation](#citation) 24 | - [Installation](#installation) 25 | - [Training](#training) 26 | * [Prepare WIDER FACE dataset](#prepare-wider-face-dataset) 27 | * [Train](#train) 28 | * [Training on your own dataset](#training-on-your-own-dataset) 29 | - [Testing](#testing) 30 | * [Visualizing trained model](#visualizing-trained-model) 31 | * [WIDER FACE dataset evaluation](#wider-face-dataset-evaluation) 32 | * [AFLW2000-3D dataset evaluation](#aflw2000-3d-dataset-evaluation) 33 | * [BIWI dataset evaluation](#biwi-dataset-evaluation) 34 | * [Testing on your own images](#testing-on-your-own-images) 35 | - [Output customization](#output-customization) 36 | - [Align faces](#align-faces) 37 | - [Resources](#resources) 38 | - [License](#license) 39 | 40 | 41 | ## Paper details 42 | 43 | [Vítor Albiero](https://vitoralbiero.netlify.app), Xingyu Chen, [Xi Yin](https://xiyinmsu.github.io/), Guan Pang, [Tal Hassner](https://talhassner.github.io/home/), "*img2pose: Face Alignment and Detection via 6DoF, Face Pose Estimation,*" CVPR, 2021, [arXiv:2012.07791](https://arxiv.org/abs/2012.07791) 44 | 45 | 46 | ### Abstract 47 | > We propose real-time, six degrees of freedom (6DoF), 3D face pose estimation without face detection or landmark localization. We observe that estimating the 6DoF rigid transformation of a face is a simpler problem than facial landmark detection, often used for 3D face alignment. In addition, 6DoF offers more information than face bounding box labels. We leverage these observations to make multiple contributions: (a) We describe an easily trained, efficient, Faster R-CNN--based model which regresses 6DoF pose for all faces in the photo, without preliminary face detection. (b) We explain how pose is converted and kept consistent between the input photo and arbitrary crops created while training and evaluating our model. (c) Finally, we show how face poses can replace detection bounding box training labels. Tests on AFLW2000-3D and BIWI show that our method runs at real-time and outperforms state of the art (SotA) face pose estimators. Remarkably, our method also surpasses SotA models of comparable complexity on the WIDER FACE detection benchmark, despite not being optimized on bounding box labels. 48 | 49 | ### Video Spotlight 50 | [CVPR 2021 Spotlight](https://youtu.be/vDGlvpnzXGo) 51 | 52 | ### Citation 53 | If you use any part of our code or data, please cite our paper. 54 | ``` 55 | @inproceedings{albiero2021img2pose, 56 | title={img2pose: Face Alignment and Detection via 6DoF, Face Pose Estimation}, 57 | author={Albiero, Vítor and Chen, Xingyu and Yin, Xi and Pang, Guan and Hassner, Tal}, 58 | booktitle={CVPR}, 59 | year={2021}, 60 | url={https://arxiv.org/abs/2012.07791}, 61 | } 62 | ``` 63 | 64 | ## Installation 65 | 66 | 67 | Install dependencies with Python 3. 68 | ``` 69 | pip install -r requirements.txt 70 | ``` 71 | Install the renderer, which is used to visualize predictions. 
The renderer implementation is forked from [here](https://github.com/cleardusk/3DDFA_V2/tree/master/Sim3DR). 72 | ``` 73 | cd Sim3DR 74 | sh build_sim3dr.sh 75 | ``` 76 | 77 | ## Training 78 | ### Prepare WIDER FACE dataset 79 | First, download our annotations as instructed in [Annotations](https://github.com/vitoralbiero/img2pose/wiki/Annotations). 80 | 81 | Download the [WIDER FACE](http://shuoyang1213.me/WIDERFACE/) dataset and extract it to datasets/WIDER_Face. 82 | 83 | Then, to create the train and validation files (LMDB), run the following scripts. 84 | 85 | ``` 86 | python3 convert_json_list_to_lmdb.py \ 87 | --json_list ./annotations/WIDER_train_annotations.txt \ 88 | --dataset_path ./datasets/WIDER_Face/WIDER_train/images/ \ 89 | --dest ./datasets/lmdb/ \ 90 | --train 91 | ``` 92 | This first script will generate an LMDB dataset, which contains the training images along with their annotations. It will also output pose mean and standard deviation files, which will be used for training and testing. 93 | ``` 94 | python3 convert_json_list_to_lmdb.py \ 95 | --json_list ./annotations/WIDER_val_annotations.txt \ 96 | --dataset_path ./datasets/WIDER_Face/WIDER_val/images/ \ 97 | --dest ./datasets/lmdb 98 | ``` 99 | This second script will create an LMDB containing the validation images along with their annotations. 100 | 101 | ### Train 102 | Once the LMDB train/val files are created, to start training simply run the script below. 103 | ``` 104 | CUDA_VISIBLE_DEVICES=0 python3 train.py \ 105 | --pose_mean ./datasets/lmdb/WIDER_train_annotations_pose_mean.npy \ 106 | --pose_stddev ./datasets/lmdb/WIDER_train_annotations_pose_stddev.npy \ 107 | --workspace ./workspace/ \ 108 | --train_source ./datasets/lmdb/WIDER_train_annotations.lmdb \ 109 | --val_source ./datasets/lmdb/WIDER_val_annotations.lmdb \ 110 | --prefix trial_1 \ 111 | --batch_size 2 \ 112 | --lr_plateau \ 113 | --early_stop \ 114 | --random_flip \ 115 | --random_crop \ 116 | --max_size 1400 117 | ``` 118 | To train with multiple GPUs (4 GPUs in the example below), use the script below. 119 | ``` 120 | python3 -m torch.distributed.launch --nproc_per_node=4 --use_env train.py \ 121 | --pose_mean ./datasets/lmdb/WIDER_train_annotations_pose_mean.npy \ 122 | --pose_stddev ./datasets/lmdb/WIDER_train_annotations_pose_stddev.npy \ 123 | --workspace ./workspace/ \ 124 | --train_source ./datasets/lmdb/WIDER_train_annotations.lmdb \ 125 | --val_source ./datasets/lmdb/WIDER_val_annotations.lmdb \ 126 | --prefix trial_1 \ 127 | --batch_size 2 \ 128 | --lr_plateau \ 129 | --early_stop \ 130 | --random_flip \ 131 | --random_crop \ 132 | --max_size 1400 \ 133 | --distributed 134 | ``` 135 | 136 | ### Training on your own dataset 137 | If your dataset has facial landmarks and bounding boxes already annotated, store them into JSON files following the same format as in the [WIDER FACE annotations](https://github.com/vitoralbiero/img2pose/wiki/Annotations). 138 | 139 | If not, run the script below to annotate your dataset. You will need a face detector, which you should import inside the script. 140 | ``` 141 | python3 utils/annotate_dataset.py \ 142 | --image_list list_of_images.txt \ 143 | --output_path ./annotations/dataset_name 144 | ``` 145 | After the dataset is annotated, create a list pointing to the JSON files that were saved. Then, follow the steps in [Prepare WIDER FACE dataset](https://github.com/vitoralbiero/img2pose#prepare-wider-face-dataset) replacing the WIDER annotations with your own dataset annotations. 
Once the LMDB and pose files are created, follow the steps in [Train](https://github.com/vitoralbiero/img2pose#train) replacing the WIDER LMDB and pose files with your own dataset files. 146 | 147 | ## Testing 148 | To evaluate with the pretrained model, download the model from [Model Zoo](https://github.com/vitoralbiero/img2pose/wiki/Model-Zoo), and extract it to the main folder. It will create a folder called models, which contains the model weights and the pose mean and standard deviation files that were used for training. 149 | 150 | If evaluating with your own trained model, change the pose mean and standard deviation files to the ones used during training. 151 | 152 | ### Visualizing trained model 153 | To visualize a trained model on the WIDER FACE validation set, run the notebook [visualize_trained_model_predictions](evaluation/jupyter_notebooks/visualize_trained_model_predictions.ipynb). 154 | 155 | ### WIDER FACE dataset evaluation 156 | If you haven't done so already, download the [WIDER FACE](http://shuoyang1213.me/WIDERFACE/) dataset and extract it to datasets/WIDER_Face. 157 | 158 | Download the [pre-trained model](https://drive.google.com/file/d/1OvnZ7OUQFg2bAgFADhT7UnCkSaXst10O/view?usp=sharing). 159 | 160 | ``` 161 | python3 evaluation/evaluate_wider.py \ 162 | --dataset_path datasets/WIDER_Face/WIDER_val/images/ \ 163 | --dataset_list datasets/WIDER_Face/wider_face_split/wider_face_val_bbx_gt.txt \ 164 | --pose_mean models/WIDER_train_pose_mean_v1.npy \ 165 | --pose_stddev models/WIDER_train_pose_stddev_v1.npy \ 166 | --pretrained_path models/img2pose_v1.pth \ 167 | --output_path results/WIDER_FACE/Val/ 168 | ``` 169 | 170 | To check mAP and plot curves, download the [eval tools](http://shuoyang1213.me/WIDERFACE/) and point them to results/WIDER_FACE/Val. 171 | 172 | ### AFLW2000-3D dataset evaluation 173 | Download the [AFLW2000-3D](http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/Database/AFLW2000-3D.zip) dataset and unzip it to datasets/AFLW2000. 174 | 175 | Download the [fine-tuned model](https://drive.google.com/file/d/1wSqPr9h1x_TOaxuN-Nu3OlTmhqnuf6rZ/view?usp=sharing). 176 | 177 | Run the notebook [aflw_2000_3d_evaluation](./evaluation/jupyter_notebooks/aflw_2000_3d_evaluation.ipynb). 178 | 179 | ### BIWI dataset evaluation 180 | Download the [BIWI](http://data.vision.ee.ethz.ch/cvl/gfanelli/kinect_head_pose_db.tgz) dataset and unzip it to datasets/BIWI. 181 | 182 | Download the [fine-tuned model](https://drive.google.com/file/d/1wSqPr9h1x_TOaxuN-Nu3OlTmhqnuf6rZ/view?usp=sharing). 183 | 184 | Run the notebook [biwi_evaluation](./evaluation/jupyter_notebooks/biwi_evaluation.ipynb). 185 | 186 | ### Testing on your own images 187 | 188 | Run the notebook [test_own_images](./evaluation/jupyter_notebooks/test_own_images.ipynb). 189 | 190 | ## Output customization 191 | 192 | For every face detected, the model outputs by default: 193 | - Pose: rx, ry, rz, tx, ty, tz 194 | - Projected bounding boxes: left, top, right, bottom 195 | - Face scores: 0 to 1 196 | 197 | Since the projected bounding box without expansion ends at the start of the forehead, we provide a way of expanding the forehead individually, along with the default x and y expansion. 198 | 199 | To customize the size of the projected bounding boxes, change any of the bounding box expansion variables when creating the model, as shown below (a complete example can be seen at [visualize_trained_model_predictions](evaluation/jupyter_notebooks/visualize_trained_model_predictions.ipynb)). 
200 | ```python 201 | # how much to expand in width 202 | bbox_x_factor = 1.1 203 | # how much to expand in height 204 | bbox_y_factor = 1.1 205 | # how much to expand in the forehead 206 | expand_forehead = 0.3 207 | 208 | img2pose_model = img2poseModel( 209 | ..., 210 | bbox_x_factor=bbox_x_factor, 211 | bbox_y_factor=bbox_y_factor, 212 | expand_forehead=expand_forehead, 213 | ) 214 | ``` 215 | 216 | ## Align faces 217 | To detect and align faces, simply run the command below, passing the path to the images you want to detect and align and the path to save them. 218 | ``` 219 | python3 run_face_alignment.py \ 220 | --pose_mean models/WIDER_train_pose_mean_v1.npy \ 221 | --pose_stddev models/WIDER_train_pose_stddev_v1.npy \ 222 | --pretrained_path models/img2pose_v1.pth \ 223 | --images_path image_path_or_list \ 224 | --output_path path_to_save_aligned_faces 225 | ``` 226 | 227 | ## Resources 228 | [Model Zoo](https://github.com/vitoralbiero/img2pose/wiki/Model-Zoo) 229 | 230 | [Annotations](https://github.com/vitoralbiero/img2pose/wiki/Annotations) 231 | 232 | [Data Zoo](https://github.com/vitoralbiero/img2pose/wiki/Data-Zoo) 233 | 234 | ## License 235 | Check [license](./license.md) for license details. 236 | -------------------------------------------------------------------------------- /Sim3DR/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | cmake-build-debug/ 3 | .idea/ 4 | build/ 5 | *.so 6 | data/ 7 | 8 | lib/rasterize.cpp -------------------------------------------------------------------------------- /Sim3DR/Sim3DR.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from . import _init_paths 4 | import numpy as np 5 | import Sim3DR_Cython 6 | 7 | 8 | def get_normal(vertices, triangles): 9 | normal = np.zeros_like(vertices, dtype=np.float32) 10 | Sim3DR_Cython.get_normal(normal, vertices, triangles, vertices.shape[0], triangles.shape[0]) 11 | return normal 12 | 13 | 14 | def rasterize(vertices, triangles, colors, bg=None, 15 | height=None, width=None, channel=None, 16 | reverse=False): 17 | if bg is not None: 18 | height, width, channel = bg.shape 19 | else: 20 | assert height is not None and width is not None and channel is not None 21 | bg = np.zeros((height, width, channel), dtype=np.uint8) 22 | 23 | buffer = np.zeros((height, width), dtype=np.float32) - 1e8 24 | 25 | if colors.dtype != np.float32: 26 | colors = colors.astype(np.float32) 27 | Sim3DR_Cython.rasterize(bg, vertices, triangles, colors, buffer, triangles.shape[0], height, width, channel, 28 | reverse=reverse) 29 | return bg 30 | -------------------------------------------------------------------------------- /Sim3DR/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from .Sim3DR import get_normal, rasterize 4 | from .lighting import RenderPipeline 5 | -------------------------------------------------------------------------------- /Sim3DR/_init_paths.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os.path as osp 4 | import sys 5 | 6 | 7 | def add_path(path): 8 | if path not in sys.path: 9 | sys.path.insert(0, path) 10 | 11 | 12 | this_dir = osp.dirname(__file__) 13 | lib_path = osp.join(this_dir, '.') 14 | add_path(lib_path) 15 | -------------------------------------------------------------------------------- /Sim3DR/build_sim3dr.sh: 
-------------------------------------------------------------------------------- 1 | python3 setup.py build_ext --inplace -------------------------------------------------------------------------------- /Sim3DR/lib/rasterize.h: -------------------------------------------------------------------------------- 1 | #ifndef MESH_CORE_HPP_ 2 | #define MESH_CORE_HPP_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | using namespace std; 12 | 13 | class Point3D { 14 | public: 15 | float x; 16 | float y; 17 | float z; 18 | 19 | public: 20 | Point3D() : x(0.f), y(0.f), z(0.f) {} 21 | Point3D(float x_, float y_, float z_) : x(x_), y(y_), z(z_) {} 22 | 23 | void initialize(float x_, float y_, float z_){ 24 | this->x = x_; this->y = y_; this->z = z_; 25 | } 26 | 27 | Point3D cross(Point3D &p){ 28 | Point3D c; 29 | c.x = this->y * p.z - this->z * p.y; 30 | c.y = this->z * p.x - this->x * p.z; 31 | c.z = this->x * p.y - this->y * p.x; 32 | return c; 33 | } 34 | 35 | float dot(Point3D &p) { 36 | return this->x * p.x + this->y * p.y + this->z * p.z; 37 | } 38 | 39 | Point3D operator-(const Point3D &p) { 40 | Point3D np; 41 | np.x = this->x - p.x; 42 | np.y = this->y - p.y; 43 | np.z = this->z - p.z; 44 | return np; 45 | } 46 | 47 | }; 48 | 49 | class Point { 50 | public: 51 | float x; 52 | float y; 53 | 54 | public: 55 | Point() : x(0.f), y(0.f) {} 56 | Point(float x_, float y_) : x(x_), y(y_) {} 57 | float dot(Point p) { 58 | return this->x * p.x + this->y * p.y; 59 | } 60 | 61 | Point operator-(const Point &p) { 62 | Point np; 63 | np.x = this->x - p.x; 64 | np.y = this->y - p.y; 65 | return np; 66 | } 67 | 68 | Point operator+(const Point &p) { 69 | Point np; 70 | np.x = this->x + p.x; 71 | np.y = this->y + p.y; 72 | return np; 73 | } 74 | 75 | Point operator*(float s) { 76 | Point np; 77 | np.x = s * this->x; 78 | np.y = s * this->y; 79 | return np; 80 | } 81 | }; 82 | 83 | 84 | bool is_point_in_tri(Point p, Point p0, Point p1, Point p2); 85 | 86 | void get_point_weight(float *weight, Point p, Point p0, Point p1, Point p2); 87 | 88 | void _get_tri_normal(float *tri_normal, float *vertices, int *triangles, int ntri, bool norm_flg); 89 | 90 | void _get_ver_normal(float *ver_normal, float *tri_normal, int *triangles, int nver, int ntri); 91 | 92 | void _get_normal(float *ver_normal, float *vertices, int *triangles, int nver, int ntri); 93 | 94 | void _rasterize_triangles( 95 | float *vertices, int *triangles, float *depth_buffer, int *triangle_buffer, float *barycentric_weight, 96 | int ntri, int h, int w); 97 | 98 | void _rasterize( 99 | unsigned char *image, float *vertices, int *triangles, float *colors, 100 | float *depth_buffer, int ntri, int h, int w, int c, float alpha, bool reverse); 101 | 102 | void _render_texture_core( 103 | float *image, float *vertices, int *triangles, 104 | float *texture, float *tex_coords, int *tex_triangles, 105 | float *depth_buffer, 106 | int nver, int tex_nver, int ntri, 107 | int h, int w, int c, 108 | int tex_h, int tex_w, int tex_c, 109 | int mapping_type); 110 | 111 | void _write_obj_with_colors_texture(string filename, string mtl_name, 112 | float *vertices, int *triangles, float *colors, float *uv_coords, 113 | int nver, int ntri, int ntexver); 114 | 115 | #endif -------------------------------------------------------------------------------- /Sim3DR/lib/rasterize.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | cimport numpy as np 3 | # from libcpp.string 
cimport string 4 | cimport cython 5 | from libcpp cimport bool 6 | 7 | # from cpython import bool 8 | 9 | # use the Numpy-C-API from Cython 10 | np.import_array() 11 | 12 | # cdefine the signature of our c function 13 | cdef extern from "rasterize.h": 14 | void _rasterize_triangles( 15 | float*vertices, int*triangles, float*depth_buffer, int*triangle_buffer, float*barycentric_weight, 16 | int ntri, int h, int w 17 | ) 18 | 19 | void _rasterize( 20 | unsigned char*image, float*vertices, int*triangles, float*colors, float*depth_buffer, 21 | int ntri, int h, int w, int c, float alpha, bool reverse 22 | ) 23 | 24 | # void _render_texture_core( 25 | # float* image, float* vertices, int* triangles, 26 | # float* texture, float* tex_coords, int* tex_triangles, 27 | # float* depth_buffer, 28 | # int nver, int tex_nver, int ntri, 29 | # int h, int w, int c, 30 | # int tex_h, int tex_w, int tex_c, 31 | # int mapping_type) 32 | 33 | void _get_tri_normal(float *tri_normal, float *vertices, int *triangles, int nver, bool norm_flg) 34 | void _get_ver_normal(float *ver_normal, float*tri_normal, int*triangles, int nver, int ntri) 35 | void _get_normal(float *ver_normal, float *vertices, int *triangles, int nver, int ntri) 36 | 37 | 38 | # void _write_obj_with_colors_texture(string filename, string mtl_name, 39 | # float* vertices, int* triangles, float* colors, float* uv_coords, 40 | # int nver, int ntri, int ntexver) 41 | 42 | @cython.boundscheck(False) 43 | @cython.wraparound(False) 44 | def get_tri_normal(np.ndarray[float, ndim=2, mode="c"] tri_normal not None, 45 | np.ndarray[float, ndim=2, mode = "c"] vertices not None, 46 | np.ndarray[int, ndim=2, mode="c"] triangles not None, 47 | int ntri, bool norm_flg = False): 48 | _get_tri_normal( np.PyArray_DATA(tri_normal), np.PyArray_DATA(vertices), 49 | np.PyArray_DATA(triangles), ntri, norm_flg) 50 | 51 | @cython.boundscheck(False) # turn off bounds-checking for entire function 52 | @cython.wraparound(False) # turn off negative index wrapping for entire function 53 | def get_ver_normal(np.ndarray[float, ndim=2, mode = "c"] ver_normal not None, 54 | np.ndarray[float, ndim=2, mode = "c"] tri_normal not None, 55 | np.ndarray[int, ndim=2, mode="c"] triangles not None, 56 | int nver, int ntri): 57 | _get_ver_normal( 58 | np.PyArray_DATA(ver_normal), np.PyArray_DATA(tri_normal), np.PyArray_DATA(triangles), 59 | nver, ntri) 60 | 61 | @cython.boundscheck(False) # turn off bounds-checking for entire function 62 | @cython.wraparound(False) # turn off negative index wrapping for entire function 63 | def get_normal(np.ndarray[float, ndim=2, mode = "c"] ver_normal not None, 64 | np.ndarray[float, ndim=2, mode = "c"] vertices not None, 65 | np.ndarray[int, ndim=2, mode="c"] triangles not None, 66 | int nver, int ntri): 67 | _get_normal( 68 | np.PyArray_DATA(ver_normal), np.PyArray_DATA(vertices), np.PyArray_DATA(triangles), 69 | nver, ntri) 70 | 71 | 72 | @cython.boundscheck(False) # turn off bounds-checking for entire function 73 | @cython.wraparound(False) # turn off negative index wrapping for entire function 74 | def rasterize_triangles( 75 | np.ndarray[float, ndim=2, mode = "c"] vertices not None, 76 | np.ndarray[int, ndim=2, mode="c"] triangles not None, 77 | np.ndarray[float, ndim=2, mode = "c"] depth_buffer not None, 78 | np.ndarray[int, ndim=2, mode = "c"] triangle_buffer not None, 79 | np.ndarray[float, ndim=2, mode = "c"] barycentric_weight not None, 80 | int ntri, int h, int w 81 | ): 82 | _rasterize_triangles( 83 | np.PyArray_DATA(vertices), 
np.PyArray_DATA(triangles), 84 | np.PyArray_DATA(depth_buffer), np.PyArray_DATA(triangle_buffer), 85 | np.PyArray_DATA(barycentric_weight), 86 | ntri, h, w) 87 | 88 | @cython.boundscheck(False) # turn off bounds-checking for entire function 89 | @cython.wraparound(False) # turn off negative index wrapping for entire function 90 | def rasterize(np.ndarray[unsigned char, ndim=3, mode = "c"] image not None, 91 | np.ndarray[float, ndim=2, mode = "c"] vertices not None, 92 | np.ndarray[int, ndim=2, mode="c"] triangles not None, 93 | np.ndarray[float, ndim=2, mode = "c"] colors not None, 94 | np.ndarray[float, ndim=2, mode = "c"] depth_buffer not None, 95 | int ntri, int h, int w, int c, float alpha = 1, bool reverse = False 96 | ): 97 | _rasterize( 98 | np.PyArray_DATA(image), np.PyArray_DATA(vertices), 99 | np.PyArray_DATA(triangles), 100 | np.PyArray_DATA(colors), 101 | np.PyArray_DATA(depth_buffer), 102 | ntri, h, w, c, alpha, reverse) 103 | 104 | # def render_texture_core(np.ndarray[float, ndim=3, mode = "c"] image not None, 105 | # np.ndarray[float, ndim=2, mode = "c"] vertices not None, 106 | # np.ndarray[int, ndim=2, mode="c"] triangles not None, 107 | # np.ndarray[float, ndim=3, mode = "c"] texture not None, 108 | # np.ndarray[float, ndim=2, mode = "c"] tex_coords not None, 109 | # np.ndarray[int, ndim=2, mode="c"] tex_triangles not None, 110 | # np.ndarray[float, ndim=2, mode = "c"] depth_buffer not None, 111 | # int nver, int tex_nver, int ntri, 112 | # int h, int w, int c, 113 | # int tex_h, int tex_w, int tex_c, 114 | # int mapping_type 115 | # ): 116 | # _render_texture_core( 117 | # np.PyArray_DATA(image), np.PyArray_DATA(vertices), np.PyArray_DATA(triangles), 118 | # np.PyArray_DATA(texture), np.PyArray_DATA(tex_coords), np.PyArray_DATA(tex_triangles), 119 | # np.PyArray_DATA(depth_buffer), 120 | # nver, tex_nver, ntri, 121 | # h, w, c, 122 | # tex_h, tex_w, tex_c, 123 | # mapping_type) 124 | # 125 | # def write_obj_with_colors_texture_core(string filename, string mtl_name, 126 | # np.ndarray[float, ndim=2, mode = "c"] vertices not None, 127 | # np.ndarray[int, ndim=2, mode="c"] triangles not None, 128 | # np.ndarray[float, ndim=2, mode = "c"] colors not None, 129 | # np.ndarray[float, ndim=2, mode = "c"] uv_coords not None, 130 | # int nver, int ntri, int ntexver 131 | # ): 132 | # _write_obj_with_colors_texture(filename, mtl_name, 133 | # np.PyArray_DATA(vertices), np.PyArray_DATA(triangles), np.PyArray_DATA(colors), np.PyArray_DATA(uv_coords), 134 | # nver, ntri, ntexver) 135 | -------------------------------------------------------------------------------- /Sim3DR/lighting.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import numpy as np 4 | from .Sim3DR import get_normal, rasterize 5 | 6 | _norm = lambda arr: arr / np.sqrt(np.sum(arr ** 2, axis=1))[:, None] 7 | 8 | 9 | def norm_vertices(vertices): 10 | vertices -= vertices.min(0)[None, :] 11 | vertices /= vertices.max() 12 | vertices *= 2 13 | vertices -= vertices.max(0)[None, :] / 2 14 | return vertices 15 | 16 | 17 | def convert_type(obj): 18 | if isinstance(obj, tuple) or isinstance(obj, list): 19 | return np.array(obj, dtype=np.float32)[None, :] 20 | return obj 21 | 22 | 23 | class RenderPipeline(object): 24 | def __init__(self, **kwargs): 25 | self.intensity_ambient = convert_type(kwargs.get('intensity_ambient', 0.3)) 26 | self.intensity_directional = convert_type(kwargs.get('intensity_directional', 0.6)) 27 | self.intensity_specular = 
convert_type(kwargs.get('intensity_specular', 0.1)) 28 | self.specular_exp = kwargs.get('specular_exp', 5) 29 | self.color_ambient = convert_type(kwargs.get('color_ambient', (1, 1, 1))) 30 | self.color_directional = convert_type(kwargs.get('color_directional', (1, 1, 1))) 31 | self.light_pos = convert_type(kwargs.get('light_pos', (0, 0, 5))) 32 | self.view_pos = convert_type(kwargs.get('view_pos', (0, 0, 5))) 33 | 34 | def update_light_pos(self, light_pos): 35 | self.light_pos = convert_type(light_pos) 36 | 37 | def __call__(self, vertices, triangles, bg, texture=None): 38 | normal = get_normal(vertices, triangles) 39 | 40 | # 2. lighting 41 | light = np.zeros_like(vertices, dtype=np.float32) 42 | # ambient component 43 | if self.intensity_ambient > 0: 44 | light += self.intensity_ambient * self.color_ambient 45 | 46 | vertices_n = norm_vertices(vertices.copy()) 47 | if self.intensity_directional > 0: 48 | # diffuse component 49 | direction = _norm(self.light_pos - vertices_n) 50 | cos = np.sum(normal * direction, axis=1)[:, None] 51 | # cos = np.clip(cos, 0, 1) 52 | # todo: check below 53 | light += self.intensity_directional * (self.color_directional * np.clip(cos, 0, 1)) 54 | 55 | # specular component 56 | if self.intensity_specular > 0: 57 | v2v = _norm(self.view_pos - vertices_n) 58 | reflection = 2 * cos * normal - direction 59 | spe = np.sum((v2v * reflection) ** self.specular_exp, axis=1)[:, None] 60 | spe = np.where(cos != 0, np.clip(spe, 0, 1), np.zeros_like(spe)) 61 | light += self.intensity_specular * self.color_directional * np.clip(spe, 0, 1) 62 | light = np.clip(light, 0, 1) 63 | 64 | # 2. rasterization, [0, 1] 65 | if texture is None: 66 | render_img = rasterize(vertices, triangles, light, bg=bg) 67 | return render_img 68 | else: 69 | texture *= light 70 | render_img = rasterize(vertices, triangles, texture, bg=bg) 71 | return render_img 72 | 73 | 74 | def main(): 75 | pass 76 | 77 | 78 | if __name__ == '__main__': 79 | main() 80 | -------------------------------------------------------------------------------- /Sim3DR/readme.md: -------------------------------------------------------------------------------- 1 | ## Forked from https://github.com/cleardusk/3DDFA_V2/tree/master/Sim3DR 2 | 3 | ## Sim3DR 4 | This is a simple 3D render, written by c++ and cython. 
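A minimal usage sketch of the Python bindings (illustration only; the single-triangle mesh below is made up, and the extension must first be built as described in the next section). The `RenderPipeline` class in `lighting.py` builds on `rasterize` and `get_normal` to add simple ambient, diffuse, and specular lighting.

```python
import numpy as np
from Sim3DR import rasterize

# A made-up single-triangle mesh; real meshes are supplied by the caller.
vertices = np.array([[10.0, 10.0, 0.0],
                     [100.0, 10.0, 0.0],
                     [55.0, 90.0, 5.0]], dtype=np.float32)  # (n_vertices, 3), C-contiguous float32
triangles = np.array([[0, 1, 2]], dtype=np.int32)           # (n_triangles, 3) vertex indices
colors = np.full_like(vertices, 0.8)                        # per-vertex RGB in [0, 1], here a flat gray

bg = np.zeros((128, 128, 3), dtype=np.uint8)                # background image to draw onto
rendered = rasterize(vertices, triangles, colors, bg=bg)    # returns bg with the mesh rasterized
```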
5 | 6 | ### Build Sim3DR 7 | 8 | ```shell script 9 | python3 setup.py build_ext --inplace 10 | ``` 11 | 12 | -------------------------------------------------------------------------------- /Sim3DR/setup.py: -------------------------------------------------------------------------------- 1 | ''' 2 | python setup.py build_ext -i 3 | to compile 4 | ''' 5 | 6 | from distutils.core import setup, Extension 7 | from Cython.Build import cythonize 8 | from Cython.Distutils import build_ext 9 | import numpy 10 | 11 | setup( 12 | name='Sim3DR_Cython', # not the package name 13 | cmdclass={'build_ext': build_ext}, 14 | ext_modules=[Extension("Sim3DR_Cython", 15 | sources=["lib/rasterize.pyx", "lib/rasterize_kernel.cpp"], 16 | language='c++', 17 | include_dirs=[numpy.get_include()], 18 | extra_compile_args=["-std=c++11"])], 19 | ) 20 | -------------------------------------------------------------------------------- /Sim3DR/tests/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | -------------------------------------------------------------------------------- /Sim3DR/tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8) 2 | 3 | set(TARGET test) 4 | project(${TARGET}) 5 | 6 | #find_package( OpenCV REQUIRED ) 7 | #include_directories( ${OpenCV_INCLUDE_DIRS} ) 8 | 9 | #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC -O3") 10 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -std=c++11") 11 | add_executable(${TARGET} test.cpp rasterize_kernel.cpp io.cpp) 12 | target_include_directories(${TARGET} PRIVATE ${PROJECT_SOURCE_DIR}) 13 | -------------------------------------------------------------------------------- /Sim3DR/tests/io.cpp: -------------------------------------------------------------------------------- 1 | #include "io.h" 2 | 3 | //void load_obj(const string obj_fp, float* vertices, float* colors, float* triangles){ 4 | // string line; 5 | // ifstream in(obj_fp); 6 | // 7 | // if(in.is_open()){ 8 | // while (getline(in, line)){ 9 | // stringstream ss(line); 10 | // 11 | // char t; // type: v, f 12 | // ss >> t; 13 | // if (t == 'v'){ 14 | // 15 | // } 16 | // } 17 | // } 18 | //} 19 | 20 | void load_obj(const char *obj_fp, float *vertices, float *colors, int *triangles, int nver, int ntri) { 21 | FILE *fp; 22 | fp = fopen(obj_fp, "r"); 23 | 24 | char t; // type: v or f 25 | if (fp != nullptr) { 26 | for (int i = 0; i < nver; ++i) { 27 | fscanf(fp, "%c", &t); 28 | for (int j = 0; j < 3; ++j) 29 | fscanf(fp, " %f", &vertices[3 * i + j]); 30 | for (int j = 0; j < 3; ++j) 31 | fscanf(fp, " %f", &colors[3 * i + j]); 32 | fscanf(fp, "\n"); 33 | } 34 | // fscanf(fp, "%c", &t); 35 | for (int i = 0; i < ntri; ++i) { 36 | fscanf(fp, "%c", &t); 37 | for (int j = 0; j < 3; ++j) { 38 | fscanf(fp, " %d", &triangles[3 * i + j]); 39 | triangles[3 * i + j] -= 1; 40 | } 41 | fscanf(fp, "\n"); 42 | } 43 | 44 | fclose(fp); 45 | } 46 | } 47 | 48 | void load_ply(const char *ply_fp, float *vertices, int *triangles, int nver, int ntri) { 49 | FILE *fp; 50 | fp = fopen(ply_fp, "r"); 51 | 52 | // char s[256]; 53 | char t; 54 | if (fp != nullptr) { 55 | // for (int i = 0; i < 9; ++i) 56 | // fscanf(fp, "%s", s); 57 | for (int i = 0; i < nver; ++i) 58 | fscanf(fp, "%f %f %f\n", &vertices[3 * i], &vertices[3 * i + 1], &vertices[3 * i + 2]); 59 | 60 | for (int i = 0; i < ntri; ++i) 61 | fscanf(fp, "%c %d %d %d\n", &t, &triangles[3 * i], &triangles[3 * i + 1], &triangles[3 * i + 2]); 62 | 63 
| fclose(fp); 64 | } 65 | } 66 | 67 | void write_ppm(const char *filename, unsigned char *img, int h, int w, int c) { 68 | FILE *fp; 69 | //open file for output 70 | fp = fopen(filename, "wb"); 71 | if (!fp) { 72 | fprintf(stderr, "Unable to open file '%s'\n", filename); 73 | exit(1); 74 | } 75 | 76 | //write the header file 77 | //image format 78 | fprintf(fp, "P6\n"); 79 | 80 | //image size 81 | fprintf(fp, "%d %d\n", w, h); 82 | 83 | // rgb component depth 84 | fprintf(fp, "%d\n", MAX_PXL_VALUE); 85 | 86 | // pixel data 87 | fwrite(img, sizeof(unsigned char), size_t(h * w * c), fp); 88 | fclose(fp); 89 | } -------------------------------------------------------------------------------- /Sim3DR/tests/io.h: -------------------------------------------------------------------------------- 1 | #ifndef IO_H_ 2 | #define IO_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | using namespace std; 11 | 12 | #define MAX_PXL_VALUE 255 13 | 14 | void load_obj(const char* obj_fp, float* vertices, float* colors, int* triangles, int nver, int ntri); 15 | void load_ply(const char* ply_fp, float* vertices, int* triangles, int nver, int ntri); 16 | 17 | 18 | void write_ppm(const char *filename, unsigned char *img, int h, int w, int c); 19 | 20 | #endif -------------------------------------------------------------------------------- /Sim3DR/tests/test.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Tesing cases 3 | */ 4 | 5 | #include 6 | #include 7 | #include "rasterize.h" 8 | #include "io.h" 9 | 10 | void test_isPointInTri() { 11 | Point p0(0, 0); 12 | Point p1(1, 0); 13 | Point p2(1, 1); 14 | 15 | Point p(0.2, 0.2); 16 | 17 | if (is_point_in_tri(p, p0, p1, p2)) 18 | std::cout << "In"; 19 | else 20 | std::cout << "Out"; 21 | std::cout << std::endl; 22 | } 23 | 24 | void test_getPointWeight() { 25 | Point p0(0, 0); 26 | Point p1(1, 0); 27 | Point p2(1, 1); 28 | 29 | Point p(0.2, 0.2); 30 | 31 | float weight[3]; 32 | get_point_weight(weight, p, p0, p1, p2); 33 | std::cout << weight[0] << " " << weight[1] << " " << weight[2] << std::endl; 34 | } 35 | 36 | void test_get_tri_normal() { 37 | float tri_normal[3]; 38 | // float vertices[9] = {1, 0, 0, 0, 0, 0, 0, 1, 0}; 39 | float vertices[9] = {1, 1.1, 0, 0, 0, 0, 0, 0.6, 0.7}; 40 | int triangles[3] = {0, 1, 2}; 41 | int ntri = 1; 42 | 43 | _get_tri_normal(tri_normal, vertices, triangles, ntri); 44 | 45 | for (int i = 0; i < 3; ++i) 46 | std::cout << tri_normal[i] << ", "; 47 | std::cout << std::endl; 48 | } 49 | 50 | void test_load_obj() { 51 | const char *fp = "../data/vd005_mesh.obj"; 52 | int nver = 35709; 53 | int ntri = 70789; 54 | 55 | auto *vertices = new float[nver]; 56 | auto *colors = new float[nver]; 57 | auto *triangles = new int[ntri]; 58 | load_obj(fp, vertices, colors, triangles, nver, ntri); 59 | 60 | delete[] vertices; 61 | delete[] colors; 62 | delete[] triangles; 63 | } 64 | 65 | void test_render() { 66 | // 1. loading obj 67 | // const char *fp = "/Users/gjz/gjzprojects/Sim3DR/data/vd005_mesh.obj"; 68 | const char *fp = "/Users/gjz/gjzprojects/Sim3DR/data/face1.obj"; 69 | int nver = 35709; //53215; //35709; 70 | int ntri = 70789; //105840;//70789; 71 | 72 | auto *vertices = new float[3 * nver]; 73 | auto *colors = new float[3 * nver]; 74 | auto *triangles = new int[3 * ntri]; 75 | load_obj(fp, vertices, colors, triangles, nver, ntri); 76 | 77 | // 2. 
rendering 78 | int h = 224, w = 224, c = 3; 79 | 80 | // enlarging 81 | int scale = 4; 82 | h *= scale; 83 | w *= scale; 84 | for (int i = 0; i < nver * 3; ++i) vertices[i] *= scale; 85 | 86 | auto *image = new unsigned char[h * w * c](); 87 | auto *depth_buffer = new float[h * w](); 88 | 89 | for (int i = 0; i < h * w; ++i) depth_buffer[i] = -999999; 90 | 91 | clock_t t; 92 | t = clock(); 93 | 94 | _rasterize(image, vertices, triangles, colors, depth_buffer, ntri, h, w, c, true); 95 | t = clock() - t; 96 | double time_taken = ((double) t) / CLOCKS_PER_SEC; // in seconds 97 | printf("Render took %f seconds to execute \n", time_taken); 98 | 99 | 100 | // auto *image_char = new u_char[h * w * c](); 101 | // for (int i = 0; i < h * w * c; ++i) 102 | // image_char[i] = u_char(255 * image[i]); 103 | write_ppm("res.ppm", image, h, w, c); 104 | 105 | // delete[] image_char; 106 | delete[] vertices; 107 | delete[] colors; 108 | delete[] triangles; 109 | delete[] image; 110 | delete[] depth_buffer; 111 | } 112 | 113 | void test_light() { 114 | // 1. loading obj 115 | const char *fp = "/Users/gjz/gjzprojects/Sim3DR/data/emma_input_0_noheader.ply"; 116 | int nver = 53215; //35709; 117 | int ntri = 105840; //70789; 118 | 119 | auto *vertices = new float[3 * nver]; 120 | auto *colors = new float[3 * nver]; 121 | auto *triangles = new int[3 * ntri]; 122 | load_ply(fp, vertices, triangles, nver, ntri); 123 | 124 | // 2. rendering 125 | // int h = 1901, w = 3913, c = 3; 126 | int h = 2000, w = 4000, c = 3; 127 | 128 | // enlarging 129 | // int scale = 1; 130 | // h *= scale; 131 | // w *= scale; 132 | // for (int i = 0; i < nver * 3; ++i) vertices[i] *= scale; 133 | 134 | auto *image = new unsigned char[h * w * c](); 135 | auto *depth_buffer = new float[h * w](); 136 | 137 | for (int i = 0; i < h * w; ++i) depth_buffer[i] = -999999; 138 | for (int i = 0; i < 3 * nver; ++i) colors[i] = 0.8; 139 | 140 | clock_t t; 141 | t = clock(); 142 | 143 | _rasterize(image, vertices, triangles, colors, depth_buffer, ntri, h, w, c, true); 144 | t = clock() - t; 145 | double time_taken = ((double) t) / CLOCKS_PER_SEC; // in seconds 146 | printf("Render took %f seconds to execute \n", time_taken); 147 | 148 | 149 | // auto *image_char = new u_char[h * w * c](); 150 | // for (int i = 0; i < h * w * c; ++i) 151 | // image_char[i] = u_char(255 * image[i]); 152 | write_ppm("emma.ppm", image, h, w, c); 153 | 154 | // delete[] image_char; 155 | delete[] vertices; 156 | delete[] colors; 157 | delete[] triangles; 158 | delete[] image; 159 | delete[] depth_buffer; 160 | } 161 | 162 | int main(int argc, char *argv[]) { 163 | // std::cout << "Hello CMake!" 
<< std::endl; 164 | 165 | // test_isPointInTri(); 166 | // test_getPointWeight(); 167 | // test_get_tri_normal(); 168 | // test_load_obj(); 169 | // test_render(); 170 | test_light(); 171 | return 0; 172 | } -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from easydict import EasyDict 6 | from torch.nn import MSELoss 7 | 8 | 9 | class Config(EasyDict): 10 | def __init__(self, args): 11 | # workspace configuration 12 | self.prefix = args.prefix 13 | self.work_path = os.path.join(args.workspace, self.prefix) 14 | self.model_path = os.path.join(self.work_path, "models") 15 | try: 16 | self.create_path(self.model_path) 17 | except Exception as e: 18 | print(e) 19 | 20 | self.log_path = os.path.join(self.work_path, "log") 21 | try: 22 | self.create_path(self.log_path) 23 | except Exception as e: 24 | print(e) 25 | 26 | self.frequency_log = 20 27 | 28 | # training/validation configuration 29 | self.train_source = args.train_source 30 | self.val_source = args.val_source 31 | 32 | # network and training parameters 33 | self.pose_loss = MSELoss(reduction="sum") 34 | self.pose_mean = np.load(args.pose_mean) 35 | self.pose_stddev = np.load(args.pose_stddev) 36 | self.depth = args.depth 37 | self.lr = args.lr 38 | self.lr_plateau = args.lr_plateau 39 | self.early_stop = args.early_stop 40 | self.batch_size = args.batch_size 41 | self.workers = args.workers 42 | self.epochs = args.epochs 43 | self.min_size = args.min_size 44 | self.max_size = args.max_size 45 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 46 | self.weight_decay = 5e-4 47 | self.momentum = 0.9 48 | self.pin_memory = True 49 | 50 | # resume from or load pretrained weights 51 | self.pretrained_path = args.pretrained_path 52 | self.resume_path = args.resume_path 53 | 54 | # online augmentation 55 | self.noise_augmentation = args.noise_augmentation 56 | self.contrast_augmentation = args.contrast_augmentation 57 | self.random_flip = args.random_flip 58 | self.random_crop = args.random_crop 59 | 60 | # 3d reference points to compute pose 61 | self.threed_5_points = args.threed_5_points 62 | self.threed_68_points = args.threed_68_points 63 | 64 | # distributed 65 | self.distributed = args.distributed 66 | if not args.distributed: 67 | self.gpu = 0 68 | else: 69 | self.gpu = args.gpu 70 | 71 | self.num_gpus = args.world_size 72 | 73 | def create_path(self, file_path): 74 | if not os.path.exists(file_path): 75 | os.makedirs(file_path) 76 | -------------------------------------------------------------------------------- /convert_json_list_to_lmdb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import lmdb 5 | import msgpack 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from utils.json_loader import JsonLoader 10 | 11 | 12 | def json_list_to_lmdb(args): 13 | cpu_available = os.cpu_count() 14 | if args.num_workers > cpu_available: 15 | args.num_workers = cpu_available 16 | 17 | threed_5_points = np.load(args.threed_5_points) 18 | threed_68_points = np.load(args.threed_68_points) 19 | 20 | print("Loading dataset from %s" % args.json_list) 21 | data_loader = JsonLoader( 22 | args.num_workers, 23 | args.json_list, 24 | threed_5_points, 25 | threed_68_points, 26 | args.dataset_path, 27 | ) 28 | 29 | name = 
f"{os.path.split(args.json_list)[1][:-4]}.lmdb" 30 | lmdb_path = os.path.join(args.dest, name) 31 | isdir = os.path.isdir(lmdb_path) 32 | 33 | if os.path.isfile(lmdb_path): 34 | os.remove(lmdb_path) 35 | 36 | print(f"Generate LMDB to {lmdb_path}") 37 | 38 | size = len(data_loader) * 1200 * 1200 * 3 39 | print(f"LMDB max size: {size}") 40 | 41 | db = lmdb.open( 42 | lmdb_path, 43 | subdir=isdir, 44 | map_size=size * 2, 45 | readonly=False, 46 | meminit=False, 47 | map_async=True, 48 | ) 49 | 50 | print(f"Total number of samples: {len(data_loader)}") 51 | 52 | all_pose_labels = [] 53 | 54 | txn = db.begin(write=True) 55 | 56 | total_samples = 0 57 | 58 | for idx, data in tqdm(enumerate(data_loader)): 59 | image, global_pose_labels, bboxes, pose_labels, landmarks = data[0] 60 | 61 | if len(bboxes) == 0: 62 | continue 63 | 64 | has_pose = False 65 | for pose_label in pose_labels: 66 | if pose_label[0] != -9: 67 | all_pose_labels.append(pose_label) 68 | has_pose = True 69 | 70 | if not has_pose: 71 | continue 72 | 73 | txn.put( 74 | "{}".format(total_samples).encode("ascii"), 75 | msgpack.dumps((image, global_pose_labels, bboxes, pose_labels, landmarks)), 76 | ) 77 | if idx % args.write_frequency == 0: 78 | print(f"[{idx}/{len(data_loader)}]") 79 | txn.commit() 80 | txn = db.begin(write=True) 81 | 82 | total_samples += 1 83 | 84 | print(total_samples) 85 | 86 | txn.commit() 87 | keys = ["{}".format(k).encode("ascii") for k in range(total_samples)] 88 | with db.begin(write=True) as txn: 89 | txn.put(b"__keys__", msgpack.dumps(keys)) 90 | txn.put(b"__len__", msgpack.dumps(len(keys))) 91 | 92 | print("Flushing database ...") 93 | db.sync() 94 | db.close() 95 | 96 | if args.train: 97 | print("Saving pose mean and std dev.") 98 | all_pose_labels = np.asarray(all_pose_labels) 99 | pose_mean = np.mean(all_pose_labels, axis=0) 100 | pose_stddev = np.std(all_pose_labels, axis=0) 101 | 102 | save_file_path = os.path.join(args.dest, os.path.split(args.json_list)[1][:-4]) 103 | np.save(f"{save_file_path}_pose_mean.npy", pose_mean) 104 | np.save(f"{save_file_path}_pose_stddev.npy", pose_stddev) 105 | 106 | 107 | def parse_args(): 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument( 110 | "--json_list", 111 | type=str, 112 | required=True, 113 | help="List of json files that contain frames annotations", 114 | ) 115 | parser.add_argument( 116 | "--dataset_path", 117 | type=str, 118 | help="Path to the dataset images", 119 | ) 120 | parser.add_argument("--num_workers", default=16, type=int) 121 | parser.add_argument( 122 | "--write_frequency", help="Frequency to save to file.", type=int, default=5000 123 | ) 124 | parser.add_argument( 125 | "--dest", type=str, required=True, help="Path to save the lmdb file." 126 | ) 127 | parser.add_argument( 128 | "--train", action="store_true", help="Dataset will be used for training." 
129 | ) 130 | parser.add_argument( 131 | "--threed_5_points", 132 | type=str, 133 | help="Reference 3D points to compute pose.", 134 | default="./pose_references/reference_3d_5_points_trans.npy", 135 | ) 136 | 137 | parser.add_argument( 138 | "--threed_68_points", 139 | type=str, 140 | help="Reference 3D points to compute pose.", 141 | default="./pose_references/reference_3d_68_points_trans.npy", 142 | ) 143 | 144 | args = parser.parse_args() 145 | 146 | if not os.path.exists(args.dest): 147 | os.makedirs(args.dest) 148 | 149 | return args 150 | 151 | 152 | if __name__ == "__main__": 153 | args = parse_args() 154 | 155 | json_list_to_lmdb(args) 156 | -------------------------------------------------------------------------------- /data_loader_lmdb.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | import lmdb 4 | import msgpack 5 | import numpy as np 6 | import six 7 | import torch 8 | from PIL import Image 9 | from torch.utils.data import BatchSampler, DataLoader, Dataset 10 | from torch.utils.data.distributed import DistributedSampler 11 | from torchvision import transforms 12 | 13 | import utils.augmentation as augmentation 14 | from utils.image_operations import expand_bbox_rectangle 15 | from utils.pose_operations import plot_3d_landmark, pose_bbox_to_full_image 16 | 17 | 18 | class LMDB(Dataset): 19 | def __init__( 20 | self, 21 | config, 22 | db_path, 23 | transform=None, 24 | pose_label_transform=None, 25 | augmentation_methods=None, 26 | ): 27 | self.config = config 28 | self.env = lmdb.open( 29 | db_path, 30 | subdir=path.isdir(db_path), 31 | readonly=True, 32 | lock=False, 33 | readahead=False, 34 | meminit=False, 35 | ) 36 | 37 | with self.env.begin(write=False) as txn: 38 | self.length = msgpack.loads(txn.get(b"__len__")) 39 | self.keys = msgpack.loads(txn.get(b"__keys__")) 40 | 41 | self.transform = transform 42 | self.pose_label_transform = pose_label_transform 43 | self.augmentation_methods = augmentation_methods 44 | self.threed_68_points = np.load(self.config.threed_68_points) 45 | 46 | def __getitem__(self, index): 47 | img, target = None, None 48 | env = self.env 49 | with env.begin(write=False) as txn: 50 | byteflow = txn.get(self.keys[index]) 51 | data = msgpack.loads(byteflow) 52 | 53 | # load image 54 | imgbuf = data[0] 55 | buf = six.BytesIO() 56 | buf.write(imgbuf) 57 | buf.seek(0) 58 | img = Image.open(buf).convert("RGB") 59 | 60 | # load local pose label 61 | pose_labels = np.asarray(data[3]) 62 | 63 | # load bbox 64 | bbox_labels = np.asarray(data[2]) 65 | 66 | # load landmarks label 67 | landmark_labels = data[4] 68 | 69 | # apply augmentations that are provided from the parent class 70 | for augmentation_method in self.augmentation_methods: 71 | img, _, _ = augmentation_method(img, None, None) 72 | 73 | # create global intrinsics 74 | (w, h) = img.size 75 | global_intrinsics = np.array( 76 | [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]] 77 | ) 78 | 79 | img = np.array(img) 80 | projected_bbox_labels = [] 81 | new_pose_labels = [] 82 | # get projected bboxes 83 | for i in range(len(pose_labels)): 84 | pose_label = pose_labels[i] 85 | bbox = bbox_labels[i] 86 | 87 | lms = np.asarray(landmark_labels[i]) 88 | 89 | # black out faces that do not have pose annotation 90 | if -1 in lms: 91 | img[int(bbox[1]) : int(bbox[3]), int(bbox[0]) : int(bbox[2]), :] = 0 92 | continue 93 | 94 | # convert to global image 95 | pose_label = pose_bbox_to_full_image(pose_label, global_intrinsics, bbox) 96 | 97 | 
# project points and get bbox 98 | projected_lms, _ = plot_3d_landmark( 99 | self.threed_68_points, pose_label, global_intrinsics 100 | ) 101 | projected_bbox = expand_bbox_rectangle( 102 | w, h, 1.1, 1.1, projected_lms, roll=pose_label[2] 103 | ) 104 | 105 | projected_bbox_labels.append(projected_bbox) 106 | 107 | new_pose_labels.append(pose_label) 108 | 109 | img = Image.fromarray(img) 110 | 111 | if self.transform is not None: 112 | img = self.transform(img) 113 | 114 | target = { 115 | "dofs": torch.from_numpy(np.asarray(new_pose_labels)).float(), 116 | "boxes": torch.from_numpy(np.asarray(projected_bbox_labels)).float(), 117 | "labels": torch.ones((len(projected_bbox_labels),), dtype=torch.int64), 118 | } 119 | 120 | return img, target 121 | 122 | def __len__(self): 123 | return self.length 124 | 125 | 126 | class LMDBDataLoader(DataLoader): 127 | def __init__(self, config, lmdb_path, train=True): 128 | self.config = config 129 | 130 | transform = transforms.Compose([transforms.ToTensor()]) 131 | 132 | augmentation_methods = [] 133 | 134 | if train: 135 | if self.config.noise_augmentation: 136 | augmentation_methods.append(augmentation.add_noise) 137 | 138 | if self.config.contrast_augmentation: 139 | augmentation_methods.append(augmentation.change_contrast) 140 | 141 | if self.config.pose_mean is not None: 142 | pose_label_transform = self.normalize_pose_labels 143 | else: 144 | pose_label_transform = None 145 | 146 | self._dataset = LMDB( 147 | config, lmdb_path, transform, pose_label_transform, augmentation_methods 148 | ) 149 | 150 | if config.distributed: 151 | self._sampler = DistributedSampler(self._dataset, shuffle=False) 152 | 153 | if train: 154 | self._sampler = BatchSampler( 155 | self._sampler, config.batch_size, drop_last=True 156 | ) 157 | 158 | super(LMDBDataLoader, self).__init__( 159 | self._dataset, 160 | batch_sampler=self._sampler, 161 | pin_memory=config.pin_memory, 162 | num_workers=config.workers, 163 | collate_fn=collate_fn, 164 | ) 165 | else: 166 | super(LMDBDataLoader, self).__init__( 167 | self._dataset, 168 | config.batch_size, 169 | drop_last=False, 170 | sampler=self._sampler, 171 | pin_memory=config.pin_memory, 172 | num_workers=config.workers, 173 | collate_fn=collate_fn, 174 | ) 175 | 176 | else: 177 | super(LMDBDataLoader, self).__init__( 178 | self._dataset, 179 | batch_size=config.batch_size, 180 | shuffle=train, 181 | pin_memory=config.pin_memory, 182 | num_workers=config.workers, 183 | drop_last=True, 184 | collate_fn=collate_fn, 185 | ) 186 | 187 | def normalize_pose_labels(self, pose_labels): 188 | for i in range(len(pose_labels)): 189 | pose_labels[i] = ( 190 | pose_labels[i] - self.config.pose_mean 191 | ) / self.config.pose_stddev 192 | 193 | return pose_labels 194 | 195 | 196 | def collate_fn(batch): 197 | return tuple(zip(*batch)) 198 | -------------------------------------------------------------------------------- /data_loader_lmdb_augmenter.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | import lmdb 4 | import msgpack 5 | import numpy as np 6 | import six 7 | import torch 8 | from PIL import Image 9 | from torch.utils.data import BatchSampler, DataLoader, Dataset 10 | from torch.utils.data.distributed import DistributedSampler 11 | from torchvision import transforms 12 | 13 | import utils.augmentation as augmentation 14 | from utils.image_operations import expand_bbox_rectangle 15 | from utils.pose_operations import (get_pose, plot_3d_landmark, 16 | 
pose_bbox_to_full_image) 17 | 18 | 19 | class LMDB(Dataset): 20 | def __init__( 21 | self, 22 | config, 23 | db_path, 24 | transform=None, 25 | pose_label_transform=None, 26 | augmentation_methods=None, 27 | ): 28 | self.config = config 29 | 30 | self.env = lmdb.open( 31 | db_path, 32 | subdir=path.isdir(db_path), 33 | readonly=True, 34 | lock=False, 35 | readahead=False, 36 | meminit=False, 37 | ) 38 | 39 | with self.env.begin(write=False) as txn: 40 | self.length = msgpack.loads(txn.get(b"__len__")) 41 | self.keys = msgpack.loads(txn.get(b"__keys__")) 42 | 43 | self.transform = transform 44 | self.pose_label_transform = pose_label_transform 45 | self.augmentation_methods = augmentation_methods 46 | 47 | self.threed_5_points = np.load(self.config.threed_5_points) 48 | self.threed_68_points = np.load(self.config.threed_68_points) 49 | 50 | def __getitem__(self, index): 51 | img, target = None, None 52 | env = self.env 53 | with env.begin(write=False) as txn: 54 | byteflow = txn.get(self.keys[index]) 55 | data = msgpack.loads(byteflow) 56 | 57 | # load image 58 | imgbuf = data[0] 59 | buf = six.BytesIO() 60 | buf.write(imgbuf) 61 | buf.seek(0) 62 | img = Image.open(buf).convert("RGB") 63 | 64 | # load landmarks label 65 | landmark_labels = data[4] 66 | 67 | # load bbox 68 | bbox_labels = np.asarray(data[2]) 69 | 70 | # apply augmentations that are provided from the parent class in creation order 71 | for augmentation_method in self.augmentation_methods: 72 | img, bbox_labels, landmark_labels = augmentation_method( 73 | img, bbox_labels, landmark_labels 74 | ) 75 | 76 | # create global intrinsics 77 | (img_w, img_h) = img.size 78 | global_intrinsics = np.array( 79 | [[img_w + img_h, 0, img_w // 2], [0, img_w + img_h, img_h // 2], [0, 0, 1]] 80 | ) 81 | 82 | projected_bbox_labels = [] 83 | pose_labels = [] 84 | 85 | img = np.array(img) 86 | 87 | # get pose labels 88 | for i in range(len(bbox_labels)): 89 | bbox = bbox_labels[i] 90 | lms = np.asarray(landmark_labels[i]) 91 | 92 | # black out faces that do not have pose annotation 93 | if -1 in lms: 94 | img[int(bbox[1]) : int(bbox[3]), int(bbox[0]) : int(bbox[2]), :] = 0 95 | continue 96 | 97 | # convert landmarks to bbox 98 | bbox_lms = lms.copy() 99 | bbox_lms[:, 0] -= bbox[0] 100 | bbox_lms[:, 1] -= bbox[1] 101 | 102 | # create bbox intrinsincs 103 | w = int(bbox[2] - bbox[0]) 104 | h = int(bbox[3] - bbox[1]) 105 | 106 | bbox_intrinsics = np.array( 107 | [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]] 108 | ) 109 | 110 | # get pose between gt points and 3D reference 111 | if len(bbox_lms) == 5: 112 | P, pose = get_pose(self.threed_5_points, bbox_lms, bbox_intrinsics) 113 | else: 114 | P, pose = get_pose(self.threed_68_points, bbox_lms, bbox_intrinsics) 115 | 116 | # convert to global image 117 | pose_label = pose_bbox_to_full_image(pose, global_intrinsics, bbox) 118 | 119 | # project points and get bbox 120 | projected_lms, _ = plot_3d_landmark( 121 | self.threed_68_points, pose_label, global_intrinsics 122 | ) 123 | projected_bbox = expand_bbox_rectangle( 124 | img_w, img_h, 1.1, 1.1, projected_lms, roll=pose_label[2] 125 | ) 126 | 127 | pose_labels.append(pose_label) 128 | projected_bbox_labels.append(projected_bbox) 129 | 130 | pose_labels = np.asarray(pose_labels) 131 | 132 | img = Image.fromarray(img) 133 | 134 | if self.transform is not None: 135 | img = self.transform(img) 136 | 137 | target = { 138 | "dofs": torch.from_numpy(pose_labels).float(), 139 | "boxes": torch.from_numpy(np.asarray(projected_bbox_labels)).float(), 140 | 
"labels": torch.ones((len(bbox_labels),), dtype=torch.int64), 141 | } 142 | 143 | return img, target 144 | 145 | def __len__(self): 146 | return self.length 147 | 148 | 149 | class LMDBDataLoaderAugmenter(DataLoader): 150 | def __init__(self, config, lmdb_path, train=True): 151 | self.config = config 152 | 153 | transform = transforms.Compose([transforms.ToTensor()]) 154 | 155 | augmentation_methods = [] 156 | 157 | if train: 158 | if self.config.random_flip: 159 | augmentation_methods.append(augmentation.random_flip) 160 | 161 | if self.config.random_crop: 162 | augmentation_methods.append(augmentation.random_crop) 163 | 164 | if self.config.noise_augmentation: 165 | augmentation_methods.append(augmentation.add_noise) 166 | 167 | if self.config.contrast_augmentation: 168 | augmentation_methods.append(augmentation.change_contrast) 169 | 170 | if self.config.pose_mean is not None: 171 | pose_label_transform = self.normalize_pose_labels 172 | else: 173 | pose_label_transform = None 174 | 175 | self._dataset = LMDB( 176 | config, 177 | lmdb_path, 178 | transform, 179 | pose_label_transform, 180 | augmentation_methods, 181 | ) 182 | 183 | if config.distributed: 184 | self._sampler = DistributedSampler(self._dataset, shuffle=False) 185 | 186 | if train: 187 | self._sampler = BatchSampler( 188 | self._sampler, config.batch_size, drop_last=True 189 | ) 190 | 191 | super(LMDBDataLoaderAugmenter, self).__init__( 192 | self._dataset, 193 | batch_sampler=self._sampler, 194 | pin_memory=config.pin_memory, 195 | num_workers=config.workers, 196 | collate_fn=collate_fn, 197 | ) 198 | else: 199 | super(LMDBDataLoaderAugmenter, self).__init__( 200 | self._dataset, 201 | config.batch_size, 202 | drop_last=False, 203 | sampler=self._sampler, 204 | pin_memory=config.pin_memory, 205 | num_workers=config.workers, 206 | collate_fn=collate_fn, 207 | ) 208 | 209 | else: 210 | super(LMDBDataLoaderAugmenter, self).__init__( 211 | self._dataset, 212 | batch_size=config.batch_size, 213 | shuffle=train, 214 | pin_memory=config.pin_memory, 215 | num_workers=config.workers, 216 | drop_last=True, 217 | collate_fn=collate_fn, 218 | ) 219 | 220 | def normalize_pose_labels(self, pose_labels): 221 | for i in range(len(pose_labels)): 222 | pose_labels[i] = ( 223 | pose_labels[i] - self.config.pose_mean 224 | ) / self.config.pose_stddev 225 | 226 | return pose_labels 227 | 228 | 229 | def collate_fn(batch): 230 | return tuple(zip(*batch)) 231 | -------------------------------------------------------------------------------- /early_stop.py: -------------------------------------------------------------------------------- 1 | class EarlyStop: 2 | def __init__(self, patience=5, mode="max", threshold=0): 3 | self.patience = patience 4 | self.counter = 0 5 | self.best_score = None 6 | self.stop = False 7 | self.mode = mode 8 | self.threshold = threshold 9 | self.val_score = float("Inf") 10 | if mode == "max": 11 | self.val_score *= -1 12 | 13 | def __call__(self, val_score): 14 | if self.best_score is None: 15 | self.best_score = val_score 16 | 17 | # if val score did not improve, add to early stop counter 18 | elif (val_score < self.best_score + self.threshold and self.mode == "max") or ( 19 | val_score > self.best_score + self.threshold and self.mode == "min" 20 | ): 21 | self.counter += 1 22 | print(f"Early stop counter: {self.counter} out of {self.patience}") 23 | 24 | # if not improve for patience times, stop training earlier 25 | if self.counter >= self.patience: 26 | self.stop = True 27 | else: 28 | self.best_score = val_score 
29 | self.counter = 0 30 | -------------------------------------------------------------------------------- /evaluation/evaluate_wider.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import time 5 | 6 | import numpy as np 7 | from PIL import Image, ImageOps 8 | from torchvision import transforms 9 | from tqdm import tqdm 10 | 11 | sys.path.append("./") 12 | from img2pose import img2poseModel 13 | from model_loader import load_model 14 | 15 | 16 | class WIDER_Eval: 17 | def __init__(self, args): 18 | self.threed_68_points = np.load(args.threed_68_points) 19 | self.nms_threshold = args.nms_threshold 20 | self.pose_mean = np.load(args.pose_mean) 21 | self.pose_stddev = np.load(args.pose_stddev) 22 | 23 | self.test_dataset = self.get_dataset(args) 24 | self.dataset_path = args.dataset_path 25 | self.img2pose_model = self.create_model(args) 26 | 27 | self.transform = transforms.Compose([transforms.ToTensor()]) 28 | self.min_size = args.min_size 29 | self.max_size = args.max_size 30 | self.flip = len(args.min_size) > 1 31 | self.output_path = args.output_path 32 | 33 | def create_model(self, args): 34 | img2pose_model = img2poseModel( 35 | args.depth, 36 | args.min_size[-1], 37 | args.max_size, 38 | pose_mean=self.pose_mean, 39 | pose_stddev=self.pose_stddev, 40 | threed_68_points=self.threed_68_points, 41 | ) 42 | load_model( 43 | img2pose_model.fpn_model, 44 | args.pretrained_path, 45 | cpu_mode=str(img2pose_model.device) == "cpu", 46 | model_only=True, 47 | ) 48 | img2pose_model.evaluate() 49 | 50 | return img2pose_model 51 | 52 | def get_dataset(self, args): 53 | annotations = open(args.dataset_list) 54 | lines = annotations.readlines() 55 | 56 | test_dataset = [] 57 | 58 | for i in range(len(lines)): 59 | lines[i] = str(lines[i].rstrip("\n")) 60 | if "--" in lines[i]: 61 | test_dataset.append(lines[i]) 62 | 63 | return test_dataset 64 | 65 | def bbox_voting(self, bboxes, iou_thresh=0.6): 66 | # bboxes: a numpy array of N*5 size representing N boxes; 67 | # for each box, it is represented as [x1, y1, x2, y2, s] 68 | # iou_thresh: group bounding boxes if their overlap is > threshold. 
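        # Greedy grouping: repeatedly take the current top-scoring box, collect all
        # boxes whose IoU with it is >= iou_thresh, and fuse the group into one box.
        # The fused coordinates are a weighted average (weights = normalized score x IoU,
        # i.e. box voting); the fused score is the weighted sum of the group scores.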
69 | 70 | order = bboxes[:, 4].ravel().argsort()[::-1] 71 | bboxes = bboxes[order, :] 72 | areas = (bboxes[:, 2] - bboxes[:, 0] + 1) * (bboxes[:, 3] - bboxes[:, 1] + 1) 73 | voted_bboxes = np.zeros([0, 5]) 74 | while bboxes.shape[0] > 0: 75 | xx1 = np.maximum(bboxes[0, 0], bboxes[:, 0]) 76 | yy1 = np.maximum(bboxes[0, 1], bboxes[:, 1]) 77 | xx2 = np.minimum(bboxes[0, 2], bboxes[:, 2]) 78 | yy2 = np.minimum(bboxes[0, 3], bboxes[:, 3]) 79 | w = np.maximum(0.0, xx2 - xx1 + 1) 80 | h = np.maximum(0.0, yy2 - yy1 + 1) 81 | inter = w * h 82 | overlaps = inter / (areas[0] + areas[:] - inter) 83 | merge_indexs = np.where(overlaps >= iou_thresh)[0] 84 | if merge_indexs.shape[0] == 0: 85 | bboxes = np.delete(bboxes, np.array([0]), 0) 86 | areas = np.delete(areas, np.array([0]), 0) 87 | continue 88 | bboxes_accu = bboxes[merge_indexs, :] 89 | bboxes = np.delete(bboxes, merge_indexs, 0) 90 | areas = np.delete(areas, merge_indexs, 0) 91 | # generate a new box by score voting and box voting 92 | bbox = np.zeros((1, 5)) 93 | box_weights = (bboxes_accu[:, -1] / max(bboxes_accu[:, -1])) * overlaps[ 94 | merge_indexs 95 | ] 96 | bboxes_accu[:, 0:4] = bboxes_accu[:, 0:4] * np.tile( 97 | box_weights.reshape((-1, 1)), (1, 4) 98 | ) 99 | bbox[:, 0:4] = np.sum(bboxes_accu[:, 0:4], axis=0) / (np.sum(box_weights)) 100 | bbox[0, 4] = np.sum(bboxes_accu[:, 4] * box_weights) 101 | voted_bboxes = np.row_stack((voted_bboxes, bbox)) 102 | 103 | return voted_bboxes 104 | 105 | def get_scales(self, im): 106 | im_shape = im.size 107 | im_size_min = np.min(im_shape[0:2]) 108 | 109 | scales = [float(scale) / im_size_min for scale in self.min_size] 110 | 111 | return scales 112 | 113 | def test(self): 114 | times = [] 115 | 116 | for img_path in tqdm(self.test_dataset): 117 | img_full_path = os.path.join(self.dataset_path, img_path) 118 | img = Image.open(img_full_path).convert("RGB") 119 | 120 | bboxes = [] 121 | 122 | (w, h) = img.size 123 | scales = self.get_scales(img) 124 | 125 | if self.flip: 126 | flip_list = (False, True) 127 | else: 128 | flip_list = (False,) 129 | 130 | for flip in flip_list: 131 | for scale in scales: 132 | run_img = img.copy() 133 | 134 | if flip: 135 | run_img = ImageOps.mirror(run_img) 136 | 137 | new_w = int(run_img.size[0] * scale) 138 | new_h = int(run_img.size[1] * scale) 139 | 140 | min_size = min(new_w, new_h) 141 | max_size = max(new_w, new_h) 142 | 143 | if len(scales) > 1: 144 | self.img2pose_model.fpn_model.module.set_max_min_size( 145 | max_size, min_size 146 | ) 147 | 148 | time1 = time.time() 149 | 150 | res = self.img2pose_model.predict([self.transform(run_img)]) 151 | 152 | time2 = time.time() 153 | times.append(time2 - time1) 154 | 155 | res = res[0] 156 | 157 | for i in range(len(res["scores"])): 158 | bbox = res["boxes"].cpu().numpy()[i].astype("int") 159 | score = res["scores"].cpu().numpy()[i] 160 | pose = res["dofs"].cpu().numpy()[i] 161 | 162 | if flip: 163 | bbox_copy = bbox.copy() 164 | bbox[0] = w - bbox_copy[2] 165 | bbox[2] = w - bbox_copy[0] 166 | pose[1:4] *= -1 167 | 168 | bboxes.append(np.append(bbox, score)) 169 | 170 | bboxes = np.asarray(bboxes) 171 | 172 | if len(self.min_size) > 1: 173 | bboxes = self.bbox_voting(bboxes, self.nms_threshold) 174 | 175 | if np.ndim(bboxes) == 1 and len(bboxes) > 0: 176 | bboxes = bboxes[np.newaxis, :] 177 | 178 | output_path = os.path.join(self.output_path, os.path.split(img_path)[0]) 179 | 180 | if not os.path.exists(output_path): 181 | os.makedirs(output_path) 182 | 183 | file_name = os.path.split(img_path)[-1] 184 | f = 
open(os.path.join(output_path, file_name[:-4] + ".txt"), "w") 185 | f.write(file_name + "\n") 186 | f.write(str(len(bboxes)) + "\n") 187 | for i in range(len(bboxes)): 188 | bbox = bboxes[i] 189 | 190 | f.write( 191 | f"{bbox[0]} {bbox[1]} {bbox[2]-bbox[0]} {bbox[3]-bbox[1]} {bbox[4]}\n" 192 | ) 193 | f.close() 194 | 195 | print(f"Average time forward pass: {np.mean(np.asarray(times))}") 196 | 197 | 198 | def parse_args(): 199 | parser = argparse.ArgumentParser( 200 | description="Train a deep network to predict 3D expression and 6DOF pose." 201 | ) 202 | parser.add_argument( 203 | "--min_size", 204 | help="Min size", 205 | default="200, 300, 500, 800, 1100, 1400, 1700", 206 | type=str, 207 | ) 208 | parser.add_argument("--max_size", help="Max size", default=1400, type=int) 209 | parser.add_argument( 210 | "--depth", help="Number of layers [18, 50 or 101].", default=18, type=int 211 | ) 212 | parser.add_argument( 213 | "--pose_mean", 214 | help="Pose mean file path.", 215 | type=str, 216 | default="./models/WIDER_train_pose_mean_v1.npy", 217 | ) 218 | parser.add_argument( 219 | "--pose_stddev", 220 | help="Pose stddev file path.", 221 | type=str, 222 | default="./models/WIDER_train_pose_stddev_v1.npy", 223 | ) 224 | 225 | # training/validation configuration 226 | parser.add_argument("--output_path", help="Path to save predictions", required=True) 227 | parser.add_argument("--dataset_path", help="Path to the dataset", required=True) 228 | parser.add_argument("--dataset_list", help="Dataset list.") 229 | 230 | # resume from or load pretrained weights 231 | parser.add_argument( 232 | "--pretrained_path", help="Path to pretrained weights.", type=str 233 | ) 234 | parser.add_argument("--nms_threshold", default=0.6, type=float) 235 | 236 | parser.add_argument( 237 | "--threed_68_points", 238 | type=str, 239 | help="Reference 3D points to compute pose.", 240 | default="./pose_references/reference_3d_68_points_trans.npy", 241 | ) 242 | 243 | args = parser.parse_args() 244 | 245 | args.min_size = [int(item) for item in args.min_size.split(", ")] 246 | 247 | return args 248 | 249 | 250 | if __name__ == "__main__": 251 | args = parse_args() 252 | 253 | wider_eval = WIDER_Eval(args) 254 | wider_eval.test() 255 | -------------------------------------------------------------------------------- /evaluation/jupyter_notebooks/aflw_2000_3d_evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": { 14 | "ExecuteTime": { 15 | "end_time": "2021-03-18T19:35:06.774927Z", 16 | "start_time": "2021-03-18T19:35:05.706994Z" 17 | }, 18 | "scrolled": true 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "import sys\n", 23 | "sys.path.append('../../')\n", 24 | "import numpy as np\n", 25 | "import torch\n", 26 | "from torchvision import transforms\n", 27 | "from matplotlib import pyplot as plt\n", 28 | "from tqdm.notebook import tqdm\n", 29 | "from PIL import Image, ImageOps\n", 30 | "from scipy.spatial.transform import Rotation\n", 31 | "import pandas as pd\n", 32 | "from scipy.spatial import distance\n", 33 | "import time\n", 34 | "import os\n", 35 | "import math\n", 36 | "import scipy.io as sio\n", 37 | "from utils.renderer import Renderer\n", 38 | "from utils.image_operations import expand_bbox_rectangle\n", 39 | "from utils.pose_operations import get_pose\n", 40 | "from img2pose 
import img2poseModel\n", 41 | "from model_loader import load_model\n", 42 | "\n", 43 | "np.set_printoptions(suppress=True)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "## Load dataset annotations " 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": { 57 | "ExecuteTime": { 58 | "end_time": "2021-03-18T19:35:06.783194Z", 59 | "start_time": "2021-03-18T19:35:06.776810Z" 60 | }, 61 | "scrolled": true 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "dataset_path = \"AFLW2000_annotations.txt\"\n", 66 | "test_dataset = pd.read_csv(dataset_path, delimiter=\" \", header=None)\n", 67 | "test_dataset = np.asarray(test_dataset).squeeze() " 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "## Declare useful functions" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 3, 80 | "metadata": { 81 | "ExecuteTime": { 82 | "end_time": "2021-03-18T19:35:06.797090Z", 83 | "start_time": "2021-03-18T19:35:06.785309Z" 84 | } 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "def bb_intersection_over_union(boxA, boxB):\n", 89 | " xA = max(boxA[0], boxB[0])\n", 90 | " yA = max(boxA[1], boxB[1])\n", 91 | " xB = min(boxA[2], boxB[2])\n", 92 | " yB = min(boxA[3], boxB[3])\n", 93 | "\n", 94 | " interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)\n", 95 | " boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)\n", 96 | " boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)\n", 97 | " iou = interArea / float(boxAArea + boxBArea - interArea)\n", 98 | " return iou\n", 99 | "\n", 100 | "def render_plot(img, pose_pred):\n", 101 | " (w, h) = img.size\n", 102 | " image_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]])\n", 103 | "\n", 104 | " trans_vertices = renderer.transform_vertices(img, [pose_pred])\n", 105 | " img = renderer.render(img, trans_vertices, alpha=1) \n", 106 | "\n", 107 | " plt.figure(figsize=(8, 8)) \n", 108 | "\n", 109 | " plt.imshow(img) \n", 110 | " plt.show()\n", 111 | " \n", 112 | "def convert_to_aflw(rotvec, is_rotvec=True):\n", 113 | " if is_rotvec:\n", 114 | " rotvec = Rotation.from_rotvec(rotvec).as_matrix()\n", 115 | " rot_mat_2 = np.transpose(rotvec)\n", 116 | " angle = Rotation.from_matrix(rot_mat_2).as_euler('xyz', degrees=True)\n", 117 | " \n", 118 | " return np.array([angle[0], -angle[1], -angle[2]])" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "## Create the renderer for visualization" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 4, 131 | "metadata": { 132 | "ExecuteTime": { 133 | "end_time": "2021-03-18T19:35:06.816811Z", 134 | "start_time": "2021-03-18T19:35:06.798493Z" 135 | } 136 | }, 137 | "outputs": [], 138 | "source": [ 139 | "renderer = Renderer(\n", 140 | " vertices_path=\"../../pose_references/vertices_trans.npy\", \n", 141 | " triangles_path=\"../../pose_references/triangles.npy\"\n", 142 | ")\n", 143 | "\n", 144 | "threed_points = np.load('../../pose_references/reference_3d_68_points_trans.npy')" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "## Load model weights and pose mean and std deviation\n", 152 | "To test other models, change MODEL_PATH along the the POSE_MEAN and POSE_STDDEV used for training" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 5, 158 | "metadata": { 159 | "ExecuteTime": { 160 | 
"end_time": "2021-03-18T19:35:09.656797Z", 161 | "start_time": "2021-03-18T19:35:06.818766Z" 162 | } 163 | }, 164 | "outputs": [ 165 | { 166 | "name": "stdout", 167 | "output_type": "stream", 168 | "text": [ 169 | "Model will use 1 GPUs!\n" 170 | ] 171 | } 172 | ], 173 | "source": [ 174 | "transform = transforms.Compose([transforms.ToTensor()])\n", 175 | "\n", 176 | "DEPTH = 18\n", 177 | "MAX_SIZE = 1400\n", 178 | "MIN_SIZE = 400\n", 179 | "\n", 180 | "POSE_MEAN = \"../../models/WIDER_train_pose_mean_v1.npy\"\n", 181 | "POSE_STDDEV = \"../../models/WIDER_train_pose_stddev_v1.npy\"\n", 182 | "MODEL_PATH = \"../../models/img2pose_v1_ft_300w_lp.pth\"\n", 183 | "\n", 184 | "\n", 185 | "pose_mean = np.load(POSE_MEAN)\n", 186 | "pose_stddev = np.load(POSE_STDDEV)\n", 187 | "\n", 188 | "\n", 189 | "img2pose_model = img2poseModel(\n", 190 | " DEPTH, MIN_SIZE, MAX_SIZE, \n", 191 | " pose_mean=pose_mean, pose_stddev=pose_stddev,\n", 192 | " threed_68_points=threed_points,\n", 193 | " rpn_pre_nms_top_n_test=500,\n", 194 | " rpn_post_nms_top_n_test=10,\n", 195 | ")\n", 196 | "load_model(img2pose_model.fpn_model, MODEL_PATH, cpu_mode=str(img2pose_model.device) == \"cpu\", model_only=True)\n", 197 | "img2pose_model.evaluate()" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "## Run AFLW2000-3D evaluation\n", 205 | "To visualize the predictions, change visualize to True and change total_imgs to the amount of images desired." 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 7, 211 | "metadata": { 212 | "ExecuteTime": { 213 | "end_time": "2021-03-18T19:36:41.278410Z", 214 | "start_time": "2021-03-18T19:35:34.696305Z" 215 | }, 216 | "scrolled": false 217 | }, 218 | "outputs": [ 219 | { 220 | "data": { 221 | "application/vnd.jupyter.widget-view+json": { 222 | "model_id": "098a3413f3414f4db135d05411b36713", 223 | "version_major": 2, 224 | "version_minor": 0 225 | }, 226 | "text/plain": [ 227 | "HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))" 228 | ] 229 | }, 230 | "metadata": {}, 231 | "output_type": "display_data" 232 | }, 233 | { 234 | "name": "stdout", 235 | "output_type": "stream", 236 | "text": [ 237 | "\n", 238 | "Model failed on 0 images\n", 239 | "Yaw: 3.426 Pitch: 5.034 Roll: 3.278 MAE: 3.913\n", 240 | "H. Trans.: 0.028 V. 
Trans.: 0.038 Scale: 0.230 MAE: 0.099\n", 241 | "Average time 0.02748371853690392\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "visualize = False\n", 247 | "total_imgs = len(test_dataset)\n", 248 | "\n", 249 | "threshold = 0.0\n", 250 | "targets = []\n", 251 | "predictions = []\n", 252 | "\n", 253 | "total_failures = 0\n", 254 | "times = []\n", 255 | "\n", 256 | "for img_path in tqdm(test_dataset[:total_imgs]):\n", 257 | " img = Image.open(img_path).convert(\"RGB\")\n", 258 | "\n", 259 | " image_name = os.path.split(img_path)[1]\n", 260 | "\n", 261 | " ori_img = img.copy()\n", 262 | "\n", 263 | " (w, h) = ori_img.size\n", 264 | " image_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]])\n", 265 | "\n", 266 | " mat_contents = sio.loadmat(img_path[:-4] + \".mat\")\n", 267 | " target_points = np.asarray(mat_contents['pt3d_68']).T[:, :2]\n", 268 | "\n", 269 | " _, pose_target = get_pose(threed_points, target_points, image_intrinsics)\n", 270 | "\n", 271 | " target_bbox = expand_bbox_rectangle(w, h, 1.1, 1.1, target_points, roll=pose_target[2])\n", 272 | "\n", 273 | " pose_para = np.asarray(mat_contents['Pose_Para'])[0][:3]\n", 274 | " pose_para_degrees = pose_para[:3] * (180 / math.pi)\n", 275 | "\n", 276 | " if np.any(np.abs(pose_para_degrees) > 99):\n", 277 | " continue \n", 278 | "\n", 279 | " run_time = 0\n", 280 | " time1 = time.time()\n", 281 | " res = img2pose_model.predict([transform(img)])\n", 282 | " time2 = time.time()\n", 283 | " run_time += (time2 - time1)\n", 284 | "\n", 285 | " res = res[0]\n", 286 | "\n", 287 | " bboxes = res[\"boxes\"].cpu().numpy().astype('float')\n", 288 | " max_iou = 0\n", 289 | " best_index = -1\n", 290 | "\n", 291 | " for i in range(len(bboxes)):\n", 292 | " if res[\"scores\"][i] > threshold:\n", 293 | " bbox = bboxes[i]\n", 294 | " pose_pred = res[\"dofs\"].cpu().numpy()[i].astype('float')\n", 295 | " pose_pred = np.asarray(pose_pred.squeeze()) \n", 296 | " iou = bb_intersection_over_union(bbox, target_bbox)\n", 297 | "\n", 298 | " if iou > max_iou:\n", 299 | " max_iou = iou\n", 300 | " best_index = i \n", 301 | "\n", 302 | " if best_index >= 0:\n", 303 | " bbox = bboxes[best_index]\n", 304 | " pose_pred = res[\"dofs\"].cpu().numpy()[best_index].astype('float')\n", 305 | " pose_pred = np.asarray(pose_pred.squeeze()) \n", 306 | "\n", 307 | " \n", 308 | " if visualize and best_index >= 0: \n", 309 | " render_plot(ori_img.copy(), pose_pred)\n", 310 | "\n", 311 | " if len(bboxes) == 0:\n", 312 | " total_failures += 1\n", 313 | "\n", 314 | " continue\n", 315 | "\n", 316 | " times.append(run_time)\n", 317 | "\n", 318 | " pose_target[:3] = pose_para_degrees \n", 319 | " pose_pred[:3] = convert_to_aflw(pose_pred[:3])\n", 320 | "\n", 321 | " targets.append(pose_target)\n", 322 | " predictions.append(pose_pred)\n", 323 | "\n", 324 | "pose_mae = np.mean(abs(np.asarray(predictions) - np.asarray(targets)), axis=0)\n", 325 | "threed_pose = pose_mae[:3]\n", 326 | "trans_pose = pose_mae[3:]\n", 327 | "\n", 328 | "print(f\"Model failed on {total_failures} images\")\n", 329 | "print(f\"Yaw: {threed_pose[1]:.3f} Pitch: {threed_pose[0]:.3f} Roll: {threed_pose[2]:.3f} MAE: {np.mean(threed_pose):.3f}\")\n", 330 | "print(f\"H. Trans.: {trans_pose[0]:.3f} V. 
Trans.: {trans_pose[1]:.3f} Scale: {trans_pose[2]:.3f} MAE: {np.mean(trans_pose):.3f}\")\n", 331 | "print(f\"Average time {np.mean(np.asarray(times))}\")" 332 | ] 333 | } 334 | ], 335 | "metadata": { 336 | "bento_stylesheets": { 337 | "bento/extensions/flow/main.css": true, 338 | "bento/extensions/kernel_selector/main.css": true, 339 | "bento/extensions/kernel_ui/main.css": true, 340 | "bento/extensions/new_kernel/main.css": true, 341 | "bento/extensions/system_usage/main.css": true, 342 | "bento/extensions/theme/main.css": true 343 | }, 344 | "disseminate_notebook_id": { 345 | "notebook_id": "336305087387084" 346 | }, 347 | "disseminate_notebook_info": { 348 | "bento_version": "20200629-000305", 349 | "description": "Visualization of smaller model pose and expression qualitative results (trial 4).\nResNet-18 with sum of squared errors weighted equally for both pose and expression.\nFC layers for both pose and expression are fc1 512x512 and fc2 512 x output (output is either 6 or 72).\n", 350 | "hide_code": false, 351 | "hipster_group": "", 352 | "kernel_build_info": { 353 | "error": "The file located at '/data/users/valbiero/fbsource/fbcode/bento/kernels/local/deep_3d_face_modeling/TARGETS' could not be found." 354 | }, 355 | "no_uii": true, 356 | "notebook_number": "296232", 357 | "others_can_edit": false, 358 | "reviewers": "", 359 | "revision_id": "1126967097689153", 360 | "tags": "", 361 | "tasks": "", 362 | "title": "Updated Model Pose and Expression Qualitative Results" 363 | }, 364 | "kernelspec": { 365 | "display_name": "pytorch", 366 | "language": "python", 367 | "name": "pytorch" 368 | }, 369 | "language_info": { 370 | "codemirror_mode": { 371 | "name": "ipython", 372 | "version": 3 373 | }, 374 | "file_extension": ".py", 375 | "mimetype": "text/x-python", 376 | "name": "python", 377 | "nbconvert_exporter": "python", 378 | "pygments_lexer": "ipython3", 379 | "version": "3.8.1" 380 | }, 381 | "notify_time": "30" 382 | }, 383 | "nbformat": 4, 384 | "nbformat_minor": 2 385 | } 386 | -------------------------------------------------------------------------------- /evaluation/jupyter_notebooks/biwi_evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": { 14 | "ExecuteTime": { 15 | "end_time": "2021-03-18T19:09:02.700914Z", 16 | "start_time": "2021-03-18T19:09:01.585915Z" 17 | }, 18 | "scrolled": true 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "import sys\n", 23 | "sys.path.append('../../')\n", 24 | "import numpy as np\n", 25 | "import torch\n", 26 | "from torchvision import transforms\n", 27 | "from matplotlib import pyplot as plt\n", 28 | "from tqdm.notebook import tqdm\n", 29 | "from PIL import Image, ImageOps\n", 30 | "from scipy.spatial.transform import Rotation\n", 31 | "import pandas as pd\n", 32 | "from scipy.spatial import distance\n", 33 | "import time\n", 34 | "import os\n", 35 | "import math\n", 36 | "import scipy.io as sio\n", 37 | "from utils.renderer import Renderer\n", 38 | "from img2pose import img2poseModel\n", 39 | "from model_loader import load_model\n", 40 | "\n", 41 | "np.set_printoptions(suppress=True)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "## Declare useful functions" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": 
{ 55 | "ExecuteTime": { 56 | "end_time": "2021-03-18T19:09:02.712136Z", 57 | "start_time": "2021-03-18T19:09:02.702535Z" 58 | } 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "def bb_intersection_over_union(boxA, boxB):\n", 63 | " xA = max(boxA[0], boxB[0])\n", 64 | " yA = max(boxA[1], boxB[1])\n", 65 | " xB = min(boxA[2], boxB[2])\n", 66 | " yB = min(boxA[3], boxB[3])\n", 67 | "\n", 68 | " interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)\n", 69 | " boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)\n", 70 | " boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)\n", 71 | " iou = interArea / float(boxAArea + boxBArea - interArea)\n", 72 | " return iou\n", 73 | "\n", 74 | "def render_plot(img, pose_pred):\n", 75 | " (w, h) = img.size\n", 76 | " image_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]])\n", 77 | "\n", 78 | " trans_vertices = renderer.transform_vertices(img, [pose_pred])\n", 79 | " img = renderer.render(img, trans_vertices, alpha=1) \n", 80 | "\n", 81 | " plt.figure(figsize=(16, 16)) \n", 82 | "\n", 83 | " plt.imshow(img) \n", 84 | " plt.show()\n", 85 | " \n", 86 | "def convert_to_aflw(rotvec, is_rotvec=True):\n", 87 | " if is_rotvec:\n", 88 | " rotvec = Rotation.from_rotvec(rotvec).as_matrix()\n", 89 | " rot_mat_2 = np.transpose(rotvec)\n", 90 | " angle = Rotation.from_matrix(rot_mat_2).as_euler('xyz', degrees=True)\n", 91 | " \n", 92 | " return np.array([angle[0], -angle[1], -angle[2]])" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "## Load BIWI dataset annotations " 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 3, 105 | "metadata": { 106 | "ExecuteTime": { 107 | "end_time": "2021-03-18T19:09:09.822126Z", 108 | "start_time": "2021-03-18T19:09:02.715404Z" 109 | } 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "dataset_path = \"./BIWI_annotations.txt\"\n", 114 | "dataset = pd.read_csv(dataset_path, delimiter=\" \", header=None)\n", 115 | "dataset = np.asarray(dataset).squeeze()\n", 116 | "\n", 117 | "pose_targets = []\n", 118 | "test_dataset = []\n", 119 | "\n", 120 | "for sample in dataset:\n", 121 | " img_path = sample[0]\n", 122 | " \n", 123 | " annotations = open(img_path.replace(\"_rgb.png\", \"_pose.txt\"))\n", 124 | " lines = annotations.readlines()\n", 125 | " \n", 126 | " pose_target = []\n", 127 | " for i in range(3):\n", 128 | " lines[i] = str(lines[i].rstrip(\"\\n\")) \n", 129 | " pose_target.append(lines[i].split(\" \")[:3])\n", 130 | " \n", 131 | " pose_target = np.asarray(pose_target).astype(float) \n", 132 | " pose_target = convert_to_aflw(pose_target, False)\n", 133 | " pose_targets.append(pose_target)\n", 134 | " \n", 135 | " test_dataset.append(img_path)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "## Create the renderer for visualization" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 4, 148 | "metadata": { 149 | "ExecuteTime": { 150 | "end_time": "2021-03-18T19:09:09.833558Z", 151 | "start_time": "2021-03-18T19:09:09.825145Z" 152 | } 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "renderer = Renderer(\n", 157 | " vertices_path=\"../../pose_references/vertices_trans.npy\", \n", 158 | " triangles_path=\"../../pose_references/triangles.npy\"\n", 159 | ")\n", 160 | "\n", 161 | "threed_points = np.load('../../pose_references/reference_3d_68_points_trans.npy')" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": 
{}, 167 | "source": [ 168 | "## Load model weights and pose mean and std deviation\n", 169 | "To test other models, change MODEL_PATH along the the POSE_MEAN and POSE_STDDEV used for training" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 5, 175 | "metadata": { 176 | "ExecuteTime": { 177 | "end_time": "2021-03-18T19:09:12.607511Z", 178 | "start_time": "2021-03-18T19:09:09.835246Z" 179 | }, 180 | "scrolled": true 181 | }, 182 | "outputs": [ 183 | { 184 | "name": "stdout", 185 | "output_type": "stream", 186 | "text": [ 187 | "Model will use 1 GPUs!\n" 188 | ] 189 | } 190 | ], 191 | "source": [ 192 | "transform = transforms.Compose([transforms.ToTensor()])\n", 193 | "\n", 194 | "DEPTH = 18\n", 195 | "MAX_SIZE = 1400\n", 196 | "MIN_SIZE = 700\n", 197 | "\n", 198 | "POSE_MEAN = \"../../models/WIDER_train_pose_mean_v1.npy\"\n", 199 | "POSE_STDDEV = \"../../models/WIDER_train_pose_stddev_v1.npy\"\n", 200 | "MODEL_PATH = \"../../models/img2pose_v1_ft_300w_lp.pth\"\n", 201 | "pose_mean = np.load(POSE_MEAN)\n", 202 | "pose_stddev = np.load(POSE_STDDEV)\n", 203 | "\n", 204 | "img2pose_model = img2poseModel(\n", 205 | " DEPTH, MIN_SIZE, MAX_SIZE, \n", 206 | " pose_mean=pose_mean, pose_stddev=pose_stddev,\n", 207 | " threed_68_points=threed_points,\n", 208 | " rpn_pre_nms_top_n_test=500,\n", 209 | " rpn_post_nms_top_n_test=10,\n", 210 | ")\n", 211 | "load_model(img2pose_model.fpn_model, MODEL_PATH, cpu_mode=str(img2pose_model.device) == \"cpu\", model_only=True)\n", 212 | "img2pose_model.evaluate()" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": {}, 218 | "source": [ 219 | "## Run BIWI evaluation\n", 220 | "To visualize the predictions, change visualize to True and change total_imgs to the amount of images desired." 
221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 7, 226 | "metadata": { 227 | "ExecuteTime": { 228 | "end_time": "2021-03-18T19:20:56.102388Z", 229 | "start_time": "2021-03-18T19:09:16.285597Z" 230 | }, 231 | "scrolled": false 232 | }, 233 | "outputs": [ 234 | { 235 | "data": { 236 | "application/vnd.jupyter.widget-view+json": { 237 | "model_id": "097fef45c4534f7b818555e66dfad113", 238 | "version_major": 2, 239 | "version_minor": 0 240 | }, 241 | "text/plain": [ 242 | "HBox(children=(FloatProgress(value=0.0, max=13219.0), HTML(value='')))" 243 | ] 244 | }, 245 | "metadata": {}, 246 | "output_type": "display_data" 247 | }, 248 | { 249 | "name": "stdout", 250 | "output_type": "stream", 251 | "text": [ 252 | "\n", 253 | "Model failed on 0 images\n", 254 | "Yaw: 4.567 Pitch: 3.546 Roll: 3.244 MAE: 3.786\n", 255 | "Average time 0.034422722154255576\n" 256 | ] 257 | } 258 | ], 259 | "source": [ 260 | "visualize = False\n", 261 | "total_imgs = len(test_dataset)\n", 262 | "threshold = 0.9\n", 263 | "\n", 264 | "predictions = []\n", 265 | "targets = []\n", 266 | "\n", 267 | "total_failures = 0\n", 268 | "times = []\n", 269 | "\n", 270 | "for j in tqdm(range(total_imgs)):\n", 271 | " img = Image.open(test_dataset[j]).convert(\"RGB\")\n", 272 | " (w, h) = img.size\n", 273 | " pose_target = pose_targets[j]\n", 274 | " ori_img = img.copy()\n", 275 | " \n", 276 | " time1 = time.time()\n", 277 | " res = img2pose_model.predict([transform(img)])\n", 278 | " time2 = time.time()\n", 279 | " times.append(time2 - time1)\n", 280 | "\n", 281 | " res = res[0]\n", 282 | "\n", 283 | " bboxes = res[\"boxes\"].cpu().numpy().astype('float')\n", 284 | "\n", 285 | " min_dist_center = float(\"Inf\")\n", 286 | " best_index = 0\n", 287 | "\n", 288 | " if len(bboxes) == 0:\n", 289 | " total_failures += 1 \n", 290 | " continue\n", 291 | "\n", 292 | " for i in range(len(bboxes)):\n", 293 | " if res[\"scores\"][i] > threshold:\n", 294 | " bbox = bboxes[i]\n", 295 | " bbox_center_x = bbox[0] + ((bbox[2] - bbox[0]) // 2)\n", 296 | " bbox_center_y = bbox[1] + ((bbox[3] - bbox[1]) // 2)\n", 297 | "\n", 298 | " dist_center = distance.euclidean([bbox_center_x, bbox_center_y], [w // 2, h // 2])\n", 299 | "\n", 300 | " if dist_center < min_dist_center:\n", 301 | " min_dist_center = dist_center\n", 302 | " best_index = i \n", 303 | "\n", 304 | " bbox = bboxes[best_index]\n", 305 | " pose_pred = res[\"dofs\"].cpu().numpy()[best_index].astype('float')\n", 306 | " pose_pred = np.asarray(pose_pred.squeeze())\n", 307 | "\n", 308 | " if best_index >= 0:\n", 309 | " bbox = bboxes[best_index]\n", 310 | " pose_pred = res[\"dofs\"].cpu().numpy()[best_index].astype('float')\n", 311 | " pose_pred = np.asarray(pose_pred.squeeze())\n", 312 | " \n", 313 | "\n", 314 | " if visualize and best_index >= 0: \n", 315 | " render_plot(ori_img.copy(), pose_pred)\n", 316 | "\n", 317 | " if len(bboxes) == 0:\n", 318 | " total_failures += 1\n", 319 | "\n", 320 | " continue\n", 321 | " \n", 322 | " pose_pred = convert_to_aflw(pose_pred[:3])\n", 323 | " \n", 324 | " predictions.append(pose_pred[:3])\n", 325 | " targets.append(pose_target[:3])\n", 326 | "\n", 327 | "pose_mae = np.mean(abs(np.asarray(predictions) - np.asarray(targets)), axis=0)\n", 328 | "threed_pose = pose_mae[:3]\n", 329 | "\n", 330 | "print(f\"Model failed on {total_failures} images\")\n", 331 | "print(f\"Yaw: {threed_pose[1]:.3f} Pitch: {threed_pose[0]:.3f} Roll: {threed_pose[2]:.3f} MAE: {np.mean(threed_pose):.3f}\")\n", 332 | "print(f\"Average time 
{np.mean(np.asarray(times))}\")" 333 | ] 334 | } 335 | ], 336 | "metadata": { 337 | "bento_stylesheets": { 338 | "bento/extensions/flow/main.css": true, 339 | "bento/extensions/kernel_selector/main.css": true, 340 | "bento/extensions/kernel_ui/main.css": true, 341 | "bento/extensions/new_kernel/main.css": true, 342 | "bento/extensions/system_usage/main.css": true, 343 | "bento/extensions/theme/main.css": true 344 | }, 345 | "disseminate_notebook_id": { 346 | "notebook_id": "336305087387084" 347 | }, 348 | "disseminate_notebook_info": { 349 | "bento_version": "20200629-000305", 350 | "description": "Visualization of smaller model pose and expression qualitative results (trial 4).\nResNet-18 with sum of squared errors weighted equally for both pose and expression.\nFC layers for both pose and expression are fc1 512x512 and fc2 512 x output (output is either 6 or 72).\n", 351 | "hide_code": false, 352 | "hipster_group": "", 353 | "kernel_build_info": { 354 | "error": "The file located at '/data/users/valbiero/fbsource/fbcode/bento/kernels/local/deep_3d_face_modeling/TARGETS' could not be found." 355 | }, 356 | "no_uii": true, 357 | "notebook_number": "296232", 358 | "others_can_edit": false, 359 | "reviewers": "", 360 | "revision_id": "1126967097689153", 361 | "tags": "", 362 | "tasks": "", 363 | "title": "Updated Model Pose and Expression Qualitative Results" 364 | }, 365 | "kernelspec": { 366 | "display_name": "pytorch", 367 | "language": "python", 368 | "name": "pytorch" 369 | }, 370 | "language_info": { 371 | "codemirror_mode": { 372 | "name": "ipython", 373 | "version": 3 374 | }, 375 | "file_extension": ".py", 376 | "mimetype": "text/x-python", 377 | "name": "python", 378 | "nbconvert_exporter": "python", 379 | "pygments_lexer": "ipython3", 380 | "version": "3.8.1" 381 | }, 382 | "notify_time": "30" 383 | }, 384 | "nbformat": 4, 385 | "nbformat_minor": 2 386 | } 387 | -------------------------------------------------------------------------------- /evaluation/jupyter_notebooks/test_own_images.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "ExecuteTime": { 15 | "end_time": "2020-12-17T00:06:41.333846Z", 16 | "start_time": "2020-12-17T00:06:40.238883Z" 17 | }, 18 | "scrolled": true 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "import sys\n", 23 | "sys.path.append('../../')\n", 24 | "import numpy as np\n", 25 | "import torch\n", 26 | "from torchvision import transforms\n", 27 | "from matplotlib import pyplot as plt\n", 28 | "from tqdm.notebook import tqdm\n", 29 | "from PIL import Image, ImageOps\n", 30 | "import matplotlib.patches as patches\n", 31 | "from scipy.spatial.transform import Rotation\n", 32 | "import pandas as pd\n", 33 | "from scipy.spatial import distance\n", 34 | "import time\n", 35 | "import os\n", 36 | "import math\n", 37 | "import scipy.io as sio\n", 38 | "from utils.renderer import Renderer\n", 39 | "from utils.image_operations import expand_bbox_rectangle\n", 40 | "from utils.pose_operations import get_pose\n", 41 | "from img2pose import img2poseModel\n", 42 | "from model_loader import load_model\n", 43 | "\n", 44 | "np.set_printoptions(suppress=True)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "## Declare useful functions" 52 | ] 53 | }, 54 | { 55 | "cell_type": 
"code", 56 | "execution_count": null, 57 | "metadata": { 58 | "ExecuteTime": { 59 | "end_time": "2020-12-17T00:06:41.340777Z", 60 | "start_time": "2020-12-17T00:06:41.335946Z" 61 | } 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "def render_plot(img, poses, bboxes):\n", 66 | " (w, h) = img.size\n", 67 | " image_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]])\n", 68 | " \n", 69 | " trans_vertices = renderer.transform_vertices(img, poses)\n", 70 | " img = renderer.render(img, trans_vertices, alpha=1) \n", 71 | "\n", 72 | " plt.figure(figsize=(8, 8)) \n", 73 | " \n", 74 | " for bbox in bboxes:\n", 75 | " plt.gca().add_patch(patches.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],linewidth=3,edgecolor='b',facecolor='none')) \n", 76 | " \n", 77 | " plt.imshow(img) \n", 78 | " plt.show()" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "## Create the renderer for visualization" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "ExecuteTime": { 93 | "end_time": "2020-12-17T00:06:41.357319Z", 94 | "start_time": "2020-12-17T00:06:41.342763Z" 95 | } 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "renderer = Renderer(\n", 100 | " vertices_path=\"../../pose_references/vertices_trans.npy\", \n", 101 | " triangles_path=\"../../pose_references/triangles.npy\"\n", 102 | ")\n", 103 | "\n", 104 | "threed_points = np.load('../../pose_references/reference_3d_68_points_trans.npy')" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "## Load model weights and pose mean and std deviation\n", 112 | "To test other models, change MODEL_PATH along the the POSE_MEAN and POSE_STDDEV used for training" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": { 119 | "ExecuteTime": { 120 | "end_time": "2020-12-17T00:06:44.297174Z", 121 | "start_time": "2020-12-17T00:06:41.359153Z" 122 | }, 123 | "scrolled": true 124 | }, 125 | "outputs": [], 126 | "source": [ 127 | "transform = transforms.Compose([transforms.ToTensor()])\n", 128 | "\n", 129 | "DEPTH = 18\n", 130 | "MAX_SIZE = 1400\n", 131 | "MIN_SIZE = 600\n", 132 | "\n", 133 | "POSE_MEAN = \"../../models/WIDER_train_pose_mean_v1.npy\"\n", 134 | "POSE_STDDEV = \"../../models/WIDER_train_pose_stddev_v1.npy\"\n", 135 | "MODEL_PATH = \"../../models/img2pose_v1.pth\"\n", 136 | "\n", 137 | "pose_mean = np.load(POSE_MEAN)\n", 138 | "pose_stddev = np.load(POSE_STDDEV)\n", 139 | "\n", 140 | "img2pose_model = img2poseModel(\n", 141 | " DEPTH, MIN_SIZE, MAX_SIZE, \n", 142 | " pose_mean=pose_mean, pose_stddev=pose_stddev,\n", 143 | " threed_68_points=threed_points,\n", 144 | ")\n", 145 | "load_model(img2pose_model.fpn_model, MODEL_PATH, cpu_mode=str(img2pose_model.device) == \"cpu\", model_only=True)\n", 146 | "img2pose_model.evaluate()" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "## Run on a folder or an image list\n", 154 | "Give it a list with images paths, or a folder containing images" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": { 161 | "ExecuteTime": { 162 | "end_time": "2020-12-17T00:06:54.309068Z", 163 | "start_time": "2020-12-17T00:06:51.244841Z" 164 | }, 165 | "scrolled": false 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "# change to a folder with images, or another list containing image paths\n", 170 | 
"images_path = \"your_own_image_folder/list\"\n", 171 | "\n", 172 | "threshold = 0.9\n", 173 | "\n", 174 | "if os.path.isfile(images_path):\n", 175 | " img_paths = pd.read_csv(images_path, delimiter=\" \", header=None)\n", 176 | " img_paths = np.asarray(img_paths).squeeze()\n", 177 | "else:\n", 178 | " img_paths = [os.path.join(images_path, img_path) for img_path in os.listdir(images_path)]\n", 179 | "\n", 180 | "for img_path in tqdm(img_paths):\n", 181 | " img = Image.open(img_path).convert(\"RGB\")\n", 182 | " \n", 183 | " image_name = os.path.split(img_path)[1]\n", 184 | " \n", 185 | " (w, h) = img.size\n", 186 | " image_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]])\n", 187 | " \n", 188 | " res = img2pose_model.predict([transform(img)])[0]\n", 189 | "\n", 190 | " all_bboxes = res[\"boxes\"].cpu().numpy().astype('float')\n", 191 | "\n", 192 | " poses = []\n", 193 | " bboxes = []\n", 194 | " for i in range(len(all_bboxes)):\n", 195 | " if res[\"scores\"][i] > threshold:\n", 196 | " bbox = all_bboxes[i]\n", 197 | " pose_pred = res[\"dofs\"].cpu().numpy()[i].astype('float')\n", 198 | " pose_pred = pose_pred.squeeze()\n", 199 | "\n", 200 | " poses.append(pose_pred) \n", 201 | " bboxes.append(bbox)\n", 202 | "\n", 203 | " render_plot(img.copy(), poses, bboxes)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [] 212 | } 213 | ], 214 | "metadata": { 215 | "bento_stylesheets": { 216 | "bento/extensions/flow/main.css": true, 217 | "bento/extensions/kernel_selector/main.css": true, 218 | "bento/extensions/kernel_ui/main.css": true, 219 | "bento/extensions/new_kernel/main.css": true, 220 | "bento/extensions/system_usage/main.css": true, 221 | "bento/extensions/theme/main.css": true 222 | }, 223 | "disseminate_notebook_id": { 224 | "notebook_id": "336305087387084" 225 | }, 226 | "disseminate_notebook_info": { 227 | "bento_version": "20200629-000305", 228 | "description": "Visualization of smaller model pose and expression qualitative results (trial 4).\nResNet-18 with sum of squared errors weighted equally for both pose and expression.\nFC layers for both pose and expression are fc1 512x512 and fc2 512 x output (output is either 6 or 72).\n", 229 | "hide_code": false, 230 | "hipster_group": "", 231 | "kernel_build_info": { 232 | "error": "The file located at '/data/users/valbiero/fbsource/fbcode/bento/kernels/local/deep_3d_face_modeling/TARGETS' could not be found." 
233 | }, 234 | "no_uii": true, 235 | "notebook_number": "296232", 236 | "others_can_edit": false, 237 | "reviewers": "", 238 | "revision_id": "1126967097689153", 239 | "tags": "", 240 | "tasks": "", 241 | "title": "Updated Model Pose and Expression Qualitative Results" 242 | }, 243 | "kernelspec": { 244 | "display_name": "test", 245 | "language": "python", 246 | "name": "test" 247 | }, 248 | "language_info": { 249 | "codemirror_mode": { 250 | "name": "ipython", 251 | "version": 3 252 | }, 253 | "file_extension": ".py", 254 | "mimetype": "text/x-python", 255 | "name": "python", 256 | "nbconvert_exporter": "python", 257 | "pygments_lexer": "ipython3", 258 | "version": "3.9.1" 259 | }, 260 | "notify_time": "30" 261 | }, 262 | "nbformat": 4, 263 | "nbformat_minor": 2 264 | } 265 | -------------------------------------------------------------------------------- /generalized_rcnn.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | 4 | import torch 5 | from torch import Tensor, nn 6 | from torch.jit.annotations import Dict, List, Optional, Tuple 7 | 8 | 9 | class GeneralizedRCNN(nn.Module): 10 | """ 11 | Main class for Generalized R-CNN. 12 | 13 | Arguments: 14 | backbone (nn.Module): 15 | rpn (nn.Module): 16 | roi_heads (nn.Module): takes the features + the proposals from the RPN 17 | and computes detections / masks from it. 18 | transform (nn.Module): performs the data transformation from the inputs 19 | to feed into the model 20 | """ 21 | 22 | def __init__(self, backbone, rpn, roi_heads, transform): 23 | super(GeneralizedRCNN, self).__init__() 24 | self.transform = transform 25 | self.backbone = backbone 26 | self.rpn = rpn 27 | self.roi_heads = roi_heads 28 | # used only on torchscript mode 29 | self._has_warned = False 30 | 31 | @torch.jit.unused 32 | def eager_outputs(self, losses, detections, evaluating): 33 | # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) 34 | # -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] 35 | if evaluating: 36 | return losses 37 | 38 | return detections 39 | 40 | def forward(self, images, targets=None): 41 | # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) 42 | # -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] 43 | """ 44 | Arguments: 45 | images (list[Tensor]): images to be processed 46 | targets (list[Dict[Tensor]]): ground-truth (optional) 47 | 48 | Returns: 49 | result (list[BoxList] or dict[Tensor]): the output from the model. 50 | During training, it returns a dict[Tensor] which contains the losses. 51 | During testing, it returns list[BoxList] contains additional fields 52 | like `scores`, `labels` and `mask` (for Mask R-CNN models). 
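            Note: in this modified forward, losses are returned whenever `targets`
            is provided (even in eval mode), and detections are returned otherwise;
            in scripting mode both are returned as a (losses, detections) tuple.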
53 | 54 | """ 55 | if self.training and targets is None: 56 | raise ValueError("In training mode, targets should be passed") 57 | if self.training or targets is not None: 58 | assert targets is not None 59 | for target in targets: 60 | boxes = target["boxes"] 61 | if isinstance(boxes, torch.Tensor): 62 | if len(boxes.shape) != 2 or boxes.shape[-1] != 4: 63 | raise ValueError( 64 | "Expected target boxes to be a tensor" 65 | "of shape [N, 4], got {:}.".format(boxes.shape) 66 | ) 67 | else: 68 | raise ValueError( 69 | "Expected target boxes to be of type " 70 | "Tensor, got {:}.".format(type(boxes)) 71 | ) 72 | 73 | original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], []) 74 | for img in images: 75 | val = img.shape[-2:] 76 | assert len(val) == 2 77 | original_image_sizes.append((val[0], val[1])) 78 | 79 | images, targets = self.transform(images, targets) 80 | 81 | # Check for degenerate boxes 82 | # TODO: Move this to a function 83 | if targets is not None: 84 | for target_idx, target in enumerate(targets): 85 | boxes = target["boxes"] 86 | degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] 87 | if degenerate_boxes.any(): 88 | # print the first degenrate box 89 | bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0] 90 | degen_bb: List[float] = boxes[bb_idx].tolist() 91 | raise ValueError( 92 | "All bounding boxes should have positive height and width." 93 | " Found invaid box {} for target at index {}.".format( 94 | degen_bb, target_idx 95 | ) 96 | ) 97 | 98 | features = self.backbone(images.tensors) 99 | if isinstance(features, torch.Tensor): 100 | features = OrderedDict([("0", features)]) 101 | proposals, proposal_losses = self.rpn(images, features, targets) 102 | detections, detector_losses = self.roi_heads( 103 | features, proposals, images.image_sizes, targets 104 | ) 105 | detections = self.transform.postprocess( 106 | detections, images.image_sizes, original_image_sizes 107 | ) 108 | 109 | losses = {} 110 | losses.update(detector_losses) 111 | losses.update(proposal_losses) 112 | 113 | if torch.jit.is_scripting(): 114 | if not self._has_warned: 115 | warnings.warn( 116 | "RCNN always returns a (Losses, Detections) tuple in scripting" 117 | ) 118 | self._has_warned = True 119 | return (losses, detections) 120 | else: 121 | return self.eager_outputs(losses, detections, targets is not None) 122 | -------------------------------------------------------------------------------- /img2pose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import DataParallel, Module 3 | from torch.nn.parallel import DistributedDataParallel 4 | from torchvision.models.detection.backbone_utils import resnet_fpn_backbone 5 | 6 | from model_loader import load_model 7 | from models import FasterDoFRCNN 8 | 9 | 10 | class WrappedModel(Module): 11 | def __init__(self, module): 12 | super(WrappedModel, self).__init__() 13 | self.module = module 14 | 15 | def forward(self, images, targets=None): 16 | return self.module(images, targets) 17 | 18 | 19 | class img2poseModel: 20 | def __init__( 21 | self, 22 | depth, 23 | min_size, 24 | max_size, 25 | model_path=None, 26 | device=None, 27 | pose_mean=None, 28 | pose_stddev=None, 29 | distributed=False, 30 | gpu=0, 31 | threed_68_points=None, 32 | threed_5_points=None, 33 | rpn_pre_nms_top_n_test=6000, 34 | rpn_post_nms_top_n_test=1000, 35 | bbox_x_factor=1.1, 36 | bbox_y_factor=1.1, 37 | expand_forehead=0.3, 38 | ): 39 | self.depth = depth 40 | self.min_size = min_size 41 | self.max_size 
= max_size 42 | self.model_path = model_path 43 | self.distributed = distributed 44 | self.gpu = gpu 45 | 46 | if device is None: 47 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 48 | else: 49 | self.device = device 50 | 51 | # create network backbone 52 | backbone = resnet_fpn_backbone(f"resnet{self.depth}", pretrained=True) 53 | 54 | if pose_mean is not None: 55 | pose_mean = torch.tensor(pose_mean) 56 | pose_stddev = torch.tensor(pose_stddev) 57 | 58 | if threed_68_points is not None: 59 | threed_68_points = torch.tensor(threed_68_points) 60 | 61 | if threed_5_points is not None: 62 | threed_5_points = torch.tensor(threed_5_points) 63 | 64 | # create the feature pyramid network 65 | self.fpn_model = FasterDoFRCNN( 66 | backbone, 67 | 2, 68 | min_size=self.min_size, 69 | max_size=self.max_size, 70 | pose_mean=pose_mean, 71 | pose_stddev=pose_stddev, 72 | threed_68_points=threed_68_points, 73 | threed_5_points=threed_5_points, 74 | rpn_pre_nms_top_n_test=rpn_pre_nms_top_n_test, 75 | rpn_post_nms_top_n_test=rpn_post_nms_top_n_test, 76 | bbox_x_factor=bbox_x_factor, 77 | bbox_y_factor=bbox_y_factor, 78 | expand_forehead=expand_forehead, 79 | ) 80 | 81 | # if using cpu, remove the parallel modules from the saved model 82 | self.fpn_model_without_ddp = self.fpn_model 83 | 84 | if self.distributed: 85 | self.fpn_model = self.fpn_model.to(self.device) 86 | self.fpn_model = DistributedDataParallel( 87 | self.fpn_model, device_ids=[self.gpu] 88 | ) 89 | self.fpn_model_without_ddp = self.fpn_model.module 90 | 91 | print("Model will use distributed mode!") 92 | 93 | elif str(self.device) == "cpu": 94 | self.fpn_model = WrappedModel(self.fpn_model) 95 | self.fpn_model_without_ddp = self.fpn_model 96 | 97 | print("Model will run on CPU!") 98 | 99 | else: 100 | self.fpn_model = DataParallel(self.fpn_model) 101 | self.fpn_model = self.fpn_model.to(self.device) 102 | self.fpn_model_without_ddp = self.fpn_model 103 | 104 | print(f"Model will use {torch.cuda.device_count()} GPUs!") 105 | 106 | if self.model_path is not None: 107 | self.load_saved_model(self.model_path) 108 | self.evaluate() 109 | 110 | def load_saved_model(self, model_path): 111 | load_model( 112 | self.fpn_model_without_ddp, model_path, cpu_mode=str(self.device) == "cpu" 113 | ) 114 | 115 | def evaluate(self): 116 | self.fpn_model.eval() 117 | 118 | def train(self): 119 | self.fpn_model.train() 120 | 121 | def run_model(self, imgs, targets=None): 122 | outputs = self.fpn_model(imgs, targets) 123 | 124 | return outputs 125 | 126 | def forward(self, imgs, targets): 127 | losses = self.run_model(imgs, targets) 128 | 129 | return losses 130 | 131 | def predict(self, imgs): 132 | assert self.fpn_model.training is False 133 | 134 | with torch.no_grad(): 135 | predictions = self.run_model(imgs) 136 | 137 | return predictions 138 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | from itertools import chain, repeat 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from utils.pose_operations import plot_3d_landmark_torch, pose_full_image_to_bbox 7 | 8 | 9 | def fastrcnn_loss( 10 | class_logits, 11 | class_labels, 12 | dof_regression, 13 | labels, 14 | dof_regression_targets, 15 | proposals, 16 | image_shapes, 17 | pose_mean=None, 18 | pose_stddev=None, 19 | threed_points=None, 20 | ): 21 | # # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] 22 | 
""" 23 | Computes the loss for Faster R-CNN. 24 | 25 | Arguments: 26 | class_logits (Tensor) 27 | dof_regression (Tensor) 28 | labels (list[BoxList]) 29 | regression_targets (Tensor) 30 | 31 | Returns: 32 | classification_loss (Tensor) 33 | dof_loss (Tensor) 34 | points_loss (Tensor) 35 | """ 36 | img_size = [ 37 | (boxes_in_image.shape[0], image_shapes[i]) 38 | for i, boxes_in_image in enumerate(proposals) 39 | ] 40 | img_size = list(chain.from_iterable(repeat(j, i) for i, j in img_size)) 41 | 42 | labels = torch.cat(labels, dim=0) 43 | class_labels = torch.cat(class_labels, dim=0) 44 | dof_regression_targets = torch.cat(dof_regression_targets, dim=0) 45 | proposals = torch.cat(proposals, dim=0) 46 | classification_loss = F.cross_entropy(class_logits, class_labels) 47 | 48 | # get indices that correspond to the regression targets for 49 | # the corresponding ground truth labels, to be used with 50 | # advanced indexing 51 | sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1) 52 | labels_pos = labels[sampled_pos_inds_subset] 53 | N = dof_regression.shape[0] 54 | dof_regression = dof_regression.reshape(N, -1, 6) 55 | dof_regression = dof_regression[sampled_pos_inds_subset, labels_pos] 56 | prop_regression = proposals[sampled_pos_inds_subset] 57 | 58 | dof_regression_targets = dof_regression_targets[sampled_pos_inds_subset] 59 | 60 | all_target_calibration_points = None 61 | all_pred_calibration_points = None 62 | 63 | for i in range(prop_regression.shape[0]): 64 | (h, w) = img_size[i] 65 | global_intrinsics = torch.Tensor( 66 | [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]] 67 | ).to(proposals[0].device) 68 | 69 | threed_points = threed_points.to(proposals[0].device) 70 | 71 | h = prop_regression[i, 3] - prop_regression[i, 1] 72 | w = prop_regression[i, 2] - prop_regression[i, 0] 73 | local_intrinsics = torch.Tensor( 74 | [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]] 75 | ).to(proposals[0].device) 76 | 77 | # calibration points projection 78 | local_dof_regression = ( 79 | dof_regression[i, :] * pose_stddev.to(proposals[0].device) 80 | ) + pose_mean.to(proposals[0].device) 81 | 82 | pred_calibration_points = plot_3d_landmark_torch( 83 | threed_points, local_dof_regression.float(), local_intrinsics 84 | ).unsqueeze(0) 85 | 86 | # pose convertion for pose loss 87 | dof_regression_targets[i, :] = torch.from_numpy( 88 | pose_full_image_to_bbox( 89 | dof_regression_targets[i, :].cpu().numpy(), 90 | global_intrinsics.cpu().numpy(), 91 | prop_regression[i, :].cpu().numpy(), 92 | ) 93 | ).to(proposals[0].device) 94 | 95 | # target calibration points projection 96 | target_calibration_points = plot_3d_landmark_torch( 97 | threed_points, dof_regression_targets[i, :], local_intrinsics 98 | ).unsqueeze(0) 99 | 100 | if all_target_calibration_points is None: 101 | all_target_calibration_points = target_calibration_points 102 | else: 103 | all_target_calibration_points = torch.cat( 104 | (all_target_calibration_points, target_calibration_points) 105 | ) 106 | if all_pred_calibration_points is None: 107 | all_pred_calibration_points = pred_calibration_points 108 | else: 109 | all_pred_calibration_points = torch.cat( 110 | (all_pred_calibration_points, pred_calibration_points) 111 | ) 112 | 113 | if pose_mean is not None: 114 | dof_regression_targets[i, :] = ( 115 | dof_regression_targets[i, :] - pose_mean.to(proposals[0].device) 116 | ) / pose_stddev.to(proposals[0].device) 117 | 118 | points_loss = F.l1_loss(all_target_calibration_points, all_pred_calibration_points) 119 | 120 | 
dof_loss = ( 121 | F.mse_loss( 122 | dof_regression, 123 | dof_regression_targets, 124 | reduction="sum", 125 | ) 126 | / dof_regression.shape[0] 127 | ) 128 | 129 | return classification_loss, dof_loss, points_loss 130 | -------------------------------------------------------------------------------- /model_loader.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | import torch 4 | 5 | try: 6 | from utils.dist import is_main_process 7 | except Exception as e: 8 | print(e) 9 | 10 | 11 | def save_model(fpn_model, optimizer, config, val_loss=0, step=0, model_only=False): 12 | if is_main_process(): 13 | save_path = config.model_path 14 | 15 | if model_only: 16 | torch.save( 17 | {"fpn_model": fpn_model.state_dict()}, 18 | path.join(save_path, f"model_val_loss_{val_loss:.4f}_step_{step}.pth"), 19 | ) 20 | else: 21 | torch.save( 22 | { 23 | "fpn_model": fpn_model.state_dict(), 24 | "optimizer": optimizer.state_dict(), 25 | }, 26 | path.join(save_path, f"model_val_loss_{val_loss:.4f}_step_{step}.pth"), 27 | ) 28 | 29 | 30 | def load_model(fpn_model, model_path, model_only=True, optimizer=None, cpu_mode=False): 31 | if cpu_mode: 32 | checkpoint = torch.load(model_path, map_location="cpu") 33 | else: 34 | checkpoint = torch.load(model_path) 35 | 36 | fpn_model.load_state_dict(checkpoint["fpn_model"]) 37 | 38 | if not model_only: 39 | optimizer.load_state_dict(checkpoint["optimizer"]) 40 | -------------------------------------------------------------------------------- /pose_references/reference_3d_5_points_trans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vitoralbiero/img2pose/fd5473efb83d78c530afc7db12b7c2aa631ea5cb/pose_references/reference_3d_5_points_trans.npy -------------------------------------------------------------------------------- /pose_references/reference_3d_68_points_trans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vitoralbiero/img2pose/fd5473efb83d78c530afc7db12b7c2aa631ea5cb/pose_references/reference_3d_68_points_trans.npy -------------------------------------------------------------------------------- /pose_references/triangles.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vitoralbiero/img2pose/fd5473efb83d78c530afc7db12b7c2aa631ea5cb/pose_references/triangles.npy -------------------------------------------------------------------------------- /pose_references/vertices_trans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vitoralbiero/img2pose/fd5473efb83d78c530afc7db12b7c2aa631ea5cb/pose_references/vertices_trans.npy -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | argon2-cffi==20.1.0 3 | async-generator==1.10 4 | attrs==20.3.0 5 | backcall==0.2.0 6 | bleach==3.2.1 7 | cachetools==4.2.0 8 | certifi==2020.12.5 9 | cffi==1.14.4 10 | chardet==3.0.4 11 | cycler==0.10.0 12 | Cython==0.29.21 13 | dataclasses==0.6 14 | decorator==4.4.2 15 | defusedxml==0.6.0 16 | easydict==1.9 17 | entrypoints==0.3 18 | google-auth==1.24.0 19 | google-auth-oauthlib==0.4.2 20 | grpcio==1.34.0 21 | idna==2.10 22 | IProgress==0.4 23 | ipykernel==5.4.2 24 | ipython==7.19.0 25 | 
ipython-genutils==0.2.0 26 | ipywidgets==7.5.1 27 | jedi==0.17.2 28 | Jinja2==2.11.2 29 | jsonschema==3.2.0 30 | jupyter-client==6.1.7 31 | jupyter-core==4.7.0 32 | jupyterlab-pygments==0.1.2 33 | kiwisolver==1.3.1 34 | lmdb==1.0.0 35 | Markdown==3.3.3 36 | MarkupSafe==1.1.1 37 | matplotlib==3.3.3 38 | mistune==0.8.4 39 | msgpack==1.0.1 40 | nbclient==0.5.1 41 | nbconvert==6.0.7 42 | nbformat==5.0.8 43 | nest-asyncio==1.4.3 44 | notebook==6.1.5 45 | numpy==1.19.4 46 | oauthlib==3.1.0 47 | opencv-python==4.4.0.46 48 | packaging==20.8 49 | pandas==1.1.5 50 | pandocfilters==1.4.3 51 | parso==0.7.1 52 | pexpect==4.8.0 53 | pickleshare==0.7.5 54 | Pillow==8.0.1 55 | prometheus-client==0.9.0 56 | prompt-toolkit==3.0.8 57 | protobuf==3.14.0 58 | ptyprocess==0.6.0 59 | pyasn1==0.4.8 60 | pyasn1-modules==0.2.8 61 | pycparser==2.20 62 | Pygments==2.7.3 63 | pyparsing==2.4.7 64 | pyrsistent==0.17.3 65 | python-dateutil==2.8.1 66 | pytz==2020.4 67 | pyzmq==20.0.0 68 | requests==2.25.0 69 | requests-oauthlib==1.3.0 70 | rsa==4.6 71 | scikit-image==0.18.0 72 | scipy==1.5.4 73 | Send2Trash==1.5.0 74 | six==1.15.0 75 | tensorboard==2.4.0 76 | tensorboard-plugin-wit==1.7.0 77 | terminado==0.9.1 78 | testpath==0.4.4 79 | torch==1.7.1 80 | torchvision==0.8.2 81 | tornado==6.1 82 | tqdm==4.54.1 83 | traitlets==5.0.5 84 | typing-extensions==3.7.4.3 85 | urllib3==1.26.2 86 | wcwidth==0.2.5 87 | webencodings==0.5.1 88 | Werkzeug==1.0.1 89 | widgetsnbextension==3.5.1 90 | -------------------------------------------------------------------------------- /run_face_alignment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from PIL import Image 7 | from torchvision import transforms 8 | from tqdm import tqdm 9 | 10 | from img2pose import img2poseModel 11 | from model_loader import load_model 12 | from utils.pose_operations import align_faces 13 | 14 | 15 | class img2pose: 16 | def __init__(self, args): 17 | self.threed_5_points = np.load(args.threed_5_points) 18 | self.threed_68_points = np.load(args.threed_68_points) 19 | self.nms_threshold = args.nms_threshold 20 | 21 | self.pose_mean = np.load(args.pose_mean) 22 | self.pose_stddev = np.load(args.pose_stddev) 23 | self.model = self.create_model(args) 24 | 25 | self.transform = transforms.Compose([transforms.ToTensor()]) 26 | self.min_size = (args.min_size,) 27 | self.max_size = args.max_size 28 | 29 | self.max_faces = args.max_faces 30 | self.face_size = args.face_size 31 | self.order_method = args.order_method 32 | self.det_threshold = args.det_threshold 33 | 34 | images_path = args.images_path 35 | if os.path.isfile(images_path): 36 | self.image_list = pd.read_csv(images_path, delimiter=" ", header=None) 37 | self.image_list = np.asarray(self.image_list).squeeze() 38 | else: 39 | self.image_list = [ 40 | os.path.join(images_path, img_path) 41 | for img_path in os.listdir(images_path) 42 | ] 43 | 44 | self.output_path = args.output_path 45 | 46 | def create_model(self, args): 47 | img2pose_model = img2poseModel( 48 | args.depth, 49 | args.min_size, 50 | args.max_size, 51 | pose_mean=self.pose_mean, 52 | pose_stddev=self.pose_stddev, 53 | threed_68_points=self.threed_68_points, 54 | ) 55 | load_model( 56 | img2pose_model.fpn_model, 57 | args.pretrained_path, 58 | cpu_mode=str(img2pose_model.device) == "cpu", 59 | model_only=True, 60 | ) 61 | img2pose_model.evaluate() 62 | 63 | return img2pose_model 64 | 65 | def align(self): 66 | for img_path 
in tqdm(self.image_list): 67 | image_name = os.path.split(img_path)[-1] 68 | img = Image.open(img_path).convert("RGB") 69 | 70 | res = self.model.predict([self.transform(img)])[0] 71 | 72 | all_scores = res["scores"].cpu().numpy().astype("float") 73 | all_poses = res["dofs"].cpu().numpy().astype("float") 74 | 75 | all_poses = all_poses[all_scores > self.det_threshold] 76 | all_scores = all_scores[all_scores > self.det_threshold] 77 | 78 | if len(all_poses) > 0: 79 | if self.order_method == "confidence": 80 | order = np.argsort(all_scores)[::-1] 81 | 82 | elif self.order_method == "position": 83 | distance_center = np.sqrt( 84 | all_poses[:, 3] ** 2 85 | + all_poses[:, 4] ** 2 86 | ) 87 | 88 | order = np.argsort(distance_center) 89 | 90 | top_poses = all_poses[order][: self.max_faces] 91 | 92 | sub_folder = os.path.basename( 93 | os.path.normpath(os.path.split(img_path)[0]) 94 | ) 95 | output_path = os.path.join(args.output_path, sub_folder) 96 | if not os.path.exists(output_path): 97 | os.makedirs(output_path) 98 | 99 | for i in range(len(top_poses)): 100 | save_name = image_name 101 | if len(top_poses) > 1: 102 | name, ext = image_name.split(".") 103 | save_name = f"{name}_{i}.{ext}" 104 | 105 | aligned_face = align_faces(self.threed_5_points, img, top_poses[i])[ 106 | 0 107 | ] 108 | aligned_face = aligned_face.resize((self.face_size, self.face_size)) 109 | aligned_face.save(os.path.join(output_path, save_name)) 110 | else: 111 | print(f"No face detected above the threshold {self.det_threshold}!") 112 | 113 | 114 | def parse_args(): 115 | parser = argparse.ArgumentParser( 116 | description="Align top n faces ordering by score or distance to image center." 117 | ) 118 | parser.add_argument("--max_faces", help="Top n faces to save.", default=1, type=int) 119 | parser.add_argument( 120 | "--order_method", 121 | help="How to order faces [confidence, position].", 122 | default="position", 123 | type=str, 124 | ) 125 | parser.add_argument( 126 | "--face_size", 127 | help="Image size to save aligned faces [112 or 224].", 128 | default=224, 129 | type=int, 130 | ) 131 | parser.add_argument("--min_size", help="Image min size", default=400, type=int) 132 | parser.add_argument("--max_size", help="Image max size", default=1400, type=int) 133 | parser.add_argument( 134 | "--depth", help="Number of layers [18, 50 or 101].", default=18, type=int 135 | ) 136 | parser.add_argument( 137 | "--pose_mean", 138 | help="Pose mean file path.", 139 | type=str, 140 | default="./models/WIDER_train_pose_mean_v1.npy", 141 | ) 142 | parser.add_argument( 143 | "--pose_stddev", 144 | help="Pose stddev file path.", 145 | type=str, 146 | default="./models/WIDER_train_pose_stddev_v1.npy", 147 | ) 148 | 149 | parser.add_argument( 150 | "--pretrained_path", 151 | help="Path to pretrained weights.", 152 | type=str, 153 | default="./models/img2pose_v1.pth", 154 | ) 155 | 156 | parser.add_argument( 157 | "--threed_5_points", 158 | type=str, 159 | help="Reference 3D points to align the face.", 160 | default="./pose_references/reference_3d_5_points_trans.npy", 161 | ) 162 | 163 | parser.add_argument( 164 | "--threed_68_points", 165 | type=str, 166 | help="Reference 3D points to project bbox.", 167 | default="./pose_references/reference_3d_68_points_trans.npy", 168 | ) 169 | 170 | parser.add_argument("--nms_threshold", default=0.6, type=float) 171 | parser.add_argument( 172 | "--det_threshold", help="Detection threshold.", default=0.7, type=float 173 | ) 174 | parser.add_argument("--images_path", help="Image list, or folder.", 
required=True) 175 | parser.add_argument("--output_path", help="Path to save predictions", required=True) 176 | 177 | args = parser.parse_args() 178 | 179 | if not os.path.exists(args.output_path): 180 | os.makedirs(args.output_path) 181 | 182 | return args 183 | 184 | 185 | if __name__ == "__main__": 186 | args = parse_args() 187 | 188 | img2pose = img2pose(args) 189 | img2pose.align() 190 | -------------------------------------------------------------------------------- /teaser.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vitoralbiero/img2pose/fd5473efb83d78c530afc7db12b7c2aa631ea5cb/teaser.jpeg -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from os import path 4 | 5 | import numpy as np 6 | import torch 7 | from torch import optim 8 | from torch.optim.lr_scheduler import ReduceLROnPlateau 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from config import Config 12 | from data_loader_lmdb import LMDBDataLoader 13 | from data_loader_lmdb_augmenter import LMDBDataLoaderAugmenter 14 | from early_stop import EarlyStop 15 | from img2pose import img2poseModel 16 | from model_loader import load_model, save_model 17 | from train_logger import TrainLogger 18 | from utils.dist import init_distributed_mode, is_main_process, reduce_dict 19 | 20 | 21 | class Train: 22 | def __init__(self, config): 23 | self.config = config 24 | 25 | if is_main_process(): 26 | # start tensorboard summary writer 27 | self.writer = SummaryWriter(config.log_path) 28 | 29 | # load training dataset generator 30 | if self.config.random_flip or self.config.random_crop: 31 | self.train_loader = LMDBDataLoaderAugmenter( 32 | self.config, self.config.train_source 33 | ) 34 | else: 35 | self.train_loader = LMDBDataLoader(self.config, self.config.train_source) 36 | print(f"Training with {len(self.train_loader.dataset)} images.") 37 | 38 | # loads validation dataset generator if a validation dataset is given 39 | if self.config.val_source is not None: 40 | self.val_loader = LMDBDataLoader(self.config, self.config.val_source, False) 41 | 42 | # creates model 43 | self.img2pose_model = img2poseModel( 44 | depth=self.config.depth, 45 | min_size=self.config.min_size, 46 | max_size=self.config.max_size, 47 | device=self.config.device, 48 | pose_mean=self.config.pose_mean, 49 | pose_stddev=self.config.pose_stddev, 50 | distributed=self.config.distributed, 51 | gpu=self.config.gpu, 52 | threed_68_points=np.load(self.config.threed_68_points), 53 | threed_5_points=np.load(self.config.threed_5_points), 54 | ) 55 | # optimizer for the backbone and heads 56 | if args.optimizer == "Adam": 57 | self.optimizer = optim.Adam( 58 | self.img2pose_model.fpn_model.parameters(), 59 | lr=self.config.lr, 60 | weight_decay=self.config.weight_decay, 61 | ) 62 | elif args.optimizer == "SGD": 63 | self.optimizer = optim.SGD( 64 | self.img2pose_model.fpn_model.parameters(), 65 | lr=self.config.lr, 66 | weight_decay=self.config.weight_decay, 67 | momentum=self.config.momentum, 68 | ) 69 | else: 70 | raise Exception("No optimizer found, please select either SGD or Adam.") 71 | 72 | # loads a model with optimizer so that it can continue training where it stopped 73 | if self.config.resume_path: 74 | print(f"Resuming training from {self.config.resume_path}") 75 | load_model( 76 |
self.img2pose_model.fpn_model, 77 | self.config.resume_path, 78 | model_only=False, 79 | optimizer=self.optimizer, 80 | cpu_mode=str(self.config.device) == "cpu", 81 | ) 82 | 83 | # loads a pretrained model without loading the optimizer 84 | if self.config.pretrained_path: 85 | print(f"Loading pretrained weights from {self.config.pretrained_path}") 86 | load_model( 87 | self.img2pose_model.fpn_model, 88 | self.config.pretrained_path, 89 | model_only=True, 90 | cpu_mode=str(self.config.device) == "cpu", 91 | ) 92 | 93 | if is_main_process(): 94 | # saves configuration to file for easier retrieval later 95 | print(self.config) 96 | self.save_file(self.config, "config.txt") 97 | 98 | if is_main_process(): 99 | # saves optimizer config to file for easier retrieval later 100 | print(self.optimizer) 101 | self.save_file(self.optimizer, "optimizer.txt") 102 | 103 | self.tensorboard_loss_every = max(len(self.train_loader) // 100, 1) 104 | 105 | # reduce learning rate when the validation loss stops decreasing 106 | if self.config.lr_plateau: 107 | self.scheduler = ReduceLROnPlateau( 108 | self.optimizer, 109 | mode="min", 110 | factor=0.1, 111 | patience=3, 112 | verbose=True, 113 | threshold=0.001, 114 | cooldown=1, 115 | min_lr=0.00001, 116 | ) 117 | 118 | # stops training before the defined epochs if validation loss stops decreasing 119 | if self.config.early_stop: 120 | self.early_stop = EarlyStop(mode="min", patience=5) 121 | 122 | def run(self): 123 | self.img2pose_model.train() 124 | 125 | # accumulate running loss to log into tensorboard 126 | running_losses = {} 127 | running_losses["loss"] = 0 128 | 129 | step = 0 130 | 131 | # tracks the best step and loss, printed every time validation runs 132 | self.best_step = 0 133 | self.best_val_loss = float("Inf") 134 | 135 | for epoch in range(self.config.epochs): 136 | train_logger = TrainLogger( 137 | self.config.batch_size, self.config.frequency_log, self.config.num_gpus 138 | ) 139 | idx = 0 140 | for idx, data in enumerate(self.train_loader): 141 | imgs, targets = data 142 | 143 | imgs = [image.to(self.config.device) for image in imgs] 144 | targets = [ 145 | {k: v.to(self.config.device) for k, v in t.items()} for t in targets 146 | ] 147 | 148 | self.optimizer.zero_grad() 149 | 150 | # forward pass 151 | losses = self.img2pose_model.forward(imgs, targets) 152 | 153 | loss = sum(loss for loss in losses.values()) 154 | 155 | # backpropagates through the network 156 | loss.backward() 157 | 158 | torch.nn.utils.clip_grad_norm_( 159 | self.img2pose_model.fpn_model.parameters(), 10 160 | ) 161 | 162 | self.optimizer.step() 163 | 164 | if self.config.distributed: 165 | losses = reduce_dict(losses) 166 | loss = sum(loss for loss in losses.values()) 167 | 168 | for loss_name in losses.keys(): 169 | if loss_name in running_losses: 170 | running_losses[loss_name] += losses[loss_name].item() 171 | else: 172 | running_losses[loss_name] = losses[loss_name].item() 173 | 174 | running_losses["loss"] += loss.item() 175 | 176 | # saves loss into tensorboard 177 | if step % self.tensorboard_loss_every == 0 and step != 0: 178 | for loss_name in running_losses.keys(): 179 | if is_main_process(): 180 | self.writer.add_scalar( 181 | f"train_{loss_name}", 182 | running_losses[loss_name] / self.tensorboard_loss_every, 183 | step, 184 | ) 185 | 186 | running_losses[loss_name] = 0 187 | 188 | train_logger( 189 | epoch, self.config.epochs, idx, len(self.train_loader), loss.item() 190 | ) 191 | step += 1 192 | 193 | # evaluate model using validation set (if
set) 194 | if self.config.val_source is not None: 195 | val_loss = self.evaluate(step) 196 | 197 | else: 198 | # otherwise just save the model 199 | save_model( 200 | self.img2pose_model.fpn_model_without_ddp, 201 | self.optimizer, 202 | self.config, 203 | step=step, 204 | ) 205 | 206 | # if validation loss stops decreasing, decrease lr 207 | if self.config.lr_plateau and self.config.val_source is not None: 208 | self.scheduler.step(val_loss) 209 | 210 | # early stop model to prevent overfitting 211 | if self.config.early_stop and self.config.val_source is not None: 212 | self.early_stop(val_loss) 213 | if self.early_stop.stop: 214 | print("Early stopping model...") 215 | break 216 | 217 | if self.config.val_source is not None: 218 | val_loss = self.evaluate(step) 219 | 220 | def checkpoint(self, val_loss, step): 221 | if val_loss < self.best_val_loss: 222 | self.best_val_loss = val_loss 223 | self.best_step = step 224 | 225 | save_model( 226 | self.img2pose_model.fpn_model_without_ddp, 227 | self.optimizer, 228 | self.config, 229 | val_loss, 230 | step, 231 | ) 232 | 233 | def reduce_lr(self): 234 | for params in self.optimizer.param_groups: 235 | params["lr"] /= 10 236 | 237 | print("Reducing learning rate...") 238 | print(self.optimizer) 239 | 240 | def evaluate(self, step): 241 | val_losses = {} 242 | val_losses["loss"] = 0 243 | 244 | print("Evaluating model...") 245 | with torch.no_grad(): 246 | for data in iter(self.val_loader): 247 | imgs, targets = data 248 | 249 | imgs = [image.to(self.config.device) for image in imgs] 250 | targets = [ 251 | {k: v.to(self.config.device) for k, v in t.items()} for t in targets 252 | ] 253 | 254 | if self.config.distributed: 255 | torch.cuda.synchronize() 256 | 257 | losses = self.img2pose_model.forward(imgs, targets) 258 | 259 | if self.config.distributed: 260 | losses = reduce_dict(losses) 261 | 262 | loss = sum(loss for loss in losses.values()) 263 | 264 | for loss_name in losses.keys(): 265 | if loss_name in val_losses: 266 | val_losses[loss_name] += losses[loss_name].item() 267 | else: 268 | val_losses[loss_name] = losses[loss_name].item() 269 | 270 | val_losses["loss"] += loss.item() 271 | 272 | for loss_name in val_losses.keys(): 273 | if is_main_process(): 274 | self.writer.add_scalar( 275 | f"val_{loss_name}", 276 | round(val_losses[loss_name] / len(self.val_loader), 6), 277 | step, 278 | ) 279 | 280 | val_loss = round(val_losses["loss"] / len(self.val_loader), 6) 281 | self.checkpoint(val_loss, step) 282 | 283 | print( 284 | "Current validation loss: " 285 | + f"{val_loss:.6f} at step {step}" 286 | + " - Best validation loss: " 287 | + f"{self.best_val_loss:.6f} at step {self.best_step}" 288 | ) 289 | 290 | self.img2pose_model.train() 291 | 292 | return val_loss 293 | 294 | def save_file(self, string, file_name): 295 | with open(path.join(self.config.work_path, file_name), "w") as file: 296 | file.write(str(string)) 297 | file.close() 298 | 299 | 300 | def parse_args(): 301 | parser = argparse.ArgumentParser( 302 | description="Train a deep network to predict 3D expression and 6DOF pose." 
303 | ) 304 | # network and training parameters 305 | parser.add_argument( 306 | "--min_size", help="Min size", default="640, 672, 704, 736, 768, 800", type=str 307 | ) 308 | parser.add_argument("--max_size", help="Max size", default=1400, type=int) 309 | parser.add_argument("--epochs", help="Number of epochs.", default=100, type=int) 310 | parser.add_argument( 311 | "--depth", help="Number of layers [18, 50 or 101].", default=18, type=int 312 | ) 313 | parser.add_argument("--lr", help="Learning rate.", default=0.001, type=float) 314 | parser.add_argument( 315 | "--optimizer", help="Optimizer (SGD or Adam).", default="SGD", type=str 316 | ) 317 | parser.add_argument("--batch_size", help="Batch size.", default=2, type=int) 318 | parser.add_argument( 319 | "--lr_plateau", help="Reduce lr on plateau.", action="store_true" 320 | ) 321 | parser.add_argument("--early_stop", help="Use early stop.", action="store_true") 322 | parser.add_argument("--workers", help="Number of workers.", default=4, type=int) 323 | parser.add_argument( 324 | "--pose_mean", help="Pose mean file path.", type=str, required=True 325 | ) 326 | parser.add_argument( 327 | "--pose_stddev", help="Pose stddev file path.", type=str, required=True 328 | ) 329 | 330 | # training/validation configuration 331 | parser.add_argument( 332 | "--workspace", help="Workspace path to save models and logs.", required=True 333 | ) 334 | parser.add_argument( 335 | "--train_source", help="Path to the dataset train LMDB file.", required=True 336 | ) 337 | parser.add_argument( 338 | "--val_source", help="Path to the dataset validation LMDB file." 339 | ) 340 | 341 | parser.add_argument( 342 | "--prefix", help="Prefix to save the model.", type=str, required=True 343 | ) 344 | 345 | # resume from or load pretrained weights 346 | parser.add_argument( 347 | "--pretrained_path", help="Path to pretrained weights.", type=str 348 | ) 349 | parser.add_argument( 350 | "--resume_path", help="Path to load model to resume training.", type=str 351 | ) 352 | 353 | # online augmentation 354 | parser.add_argument("--noise_augmentation", action="store_true") 355 | parser.add_argument("--contrast_augmentation", action="store_true") 356 | parser.add_argument("--random_flip", action="store_true") 357 | parser.add_argument("--random_crop", action="store_true") 358 | 359 | # distributed training parameters 360 | parser.add_argument( 361 | "--world-size", default=1, type=int, help="number of distributed processes" 362 | ) 363 | parser.add_argument( 364 | "--dist-url", default="env://", help="url used to set up distributed training" 365 | ) 366 | parser.add_argument( 367 | "--distributed", help="Use distributed training", action="store_true" 368 | ) 369 | 370 | # reference points to create pose labels 371 | parser.add_argument( 372 | "--threed_5_points", 373 | type=str, 374 | help="Reference 3D points to compute pose.", 375 | default="./pose_references/reference_3d_5_points_trans.npy", 376 | ) 377 | 378 | parser.add_argument( 379 | "--threed_68_points", 380 | type=str, 381 | help="Reference 3D points to compute pose.", 382 | default="./pose_references/reference_3d_68_points_trans.npy", 383 | ) 384 | 385 | args = parser.parse_args() 386 | 387 | args.min_size = [int(item) for item in args.min_size.split(",")] 388 | 389 | return args 390 | 391 | 392 | if __name__ == "__main__": 393 | args = parse_args() 394 | 395 | if args.distributed: 396 | init_distributed_mode(args) 397 | 398 | config = Config(args) 399 | 400 | torch.manual_seed(42) 401 | np.random.seed(42) 402 |
random.seed(42) 403 | 404 | train = Train(config) 405 | train.run() 406 | -------------------------------------------------------------------------------- /train_logger.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class TrainLogger(object): 5 | def __init__(self, batch_size, frequency=50, num_gpus=1): 6 | self.num_gpus = num_gpus 7 | self.batch_size = batch_size * num_gpus 8 | self.frequency = frequency 9 | self.init = False 10 | self.tic = 0 11 | self.last_batch = 0 12 | self.running_loss = 0 13 | 14 | def __call__(self, epoch, total_epochs, batch, total, loss): 15 | if self.last_batch > batch: 16 | self.init = False 17 | self.last_batch = batch 18 | 19 | if self.init: 20 | self.running_loss += loss 21 | if batch % self.frequency == 0: 22 | speed = self.frequency * self.batch_size / (time.time() - self.tic) 23 | self.running_loss = self.running_loss / self.frequency 24 | 25 | batch, total = batch * self.num_gpus, total * self.num_gpus 26 | log = ( 27 | f"Epoch: [{epoch + 1}-{total_epochs}] Batch: [{batch}-{total}] " 28 | + f"Speed: {speed:.2f} samples/sec Loss: {self.running_loss:.5f}" 29 | ) 30 | print(log) 31 | 32 | self.running_loss = 0 33 | self.tic = time.time() 34 | else: 35 | self.init = True 36 | self.tic = time.time() 37 | -------------------------------------------------------------------------------- /utils/annotate_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import sys 4 | 5 | import cv2 6 | import numpy as np 7 | import pandas as pd 8 | from tqdm import tqdm 9 | 10 | # modify to your RetinaFace path 11 | sys.path.append("../insightface/RetinaFace/") 12 | import os 13 | 14 | from retinaface import RetinaFace 15 | 16 | 17 | def face_landmark_detection(detector, args): 18 | image_paths = pd.read_csv(args.image_list, delimiter=" ", header=None) 19 | image_paths = np.asarray(image_paths).squeeze() 20 | 21 | for i in tqdm(range(len(image_paths))): 22 | img_json = {} 23 | 24 | img_path = image_paths[i] 25 | img_name = os.path.split(img_path)[-1] 26 | 27 | output_path = os.path.join(args.output_path, os.path.split(img_name)[0]) 28 | file_output_path = os.path.join( 29 | output_path, f"{os.path.split(img_name)[-1][:-4]}.json" 30 | ) 31 | 32 | if os.path.isfile(file_output_path): 33 | print(f"Skipping file {img_name} as it was already processed.") 34 | continue 35 | 36 | im = cv2.imread(img_path) 37 | 38 | pyramid = True 39 | do_flip = False 40 | 41 | if not pyramid: 42 | target_size = 1200 43 | max_size = 1600 44 | target_size = 1504 45 | max_size = 2000 46 | target_size = 1600 47 | max_size = 2150 48 | im_shape = im.shape 49 | im_size_min = np.min(im_shape[0:2]) 50 | im_size_max = np.max(im_shape[0:2]) 51 | im_scale = float(target_size) / float(im_size_min) 52 | # prevent bigger axis from being more than max_size: 53 | if np.round(im_scale * im_size_max) > max_size: 54 | im_scale = float(max_size) / float(im_size_max) 55 | scales = [im_scale] 56 | else: 57 | do_flip = True 58 | TEST_SCALES = [500, 800, 1100, 1400, 1700] 59 | target_size = 800 60 | max_size = 1200 61 | im_shape = im.shape 62 | im_size_min = np.min(im_shape[0:2]) 63 | im_size_max = np.max(im_shape[0:2]) 64 | im_scale = float(target_size) / float(im_size_min) 65 | # prevent bigger axis from being more than max_size: 66 | if np.round(im_scale * im_size_max) > max_size: 67 | im_scale = float(max_size) / float(im_size_max) 68 | scales = [float(scale) / target_size * 
im_scale for scale in TEST_SCALES] 69 | 70 | faces, landmarks = detector.detect( 71 | im, threshold=args.thresh, scales=scales, do_flip=do_flip 72 | ) 73 | 74 | bboxes_list = [] 75 | landmarks_list = [] 76 | 77 | if faces is not None: 78 | for i in range(faces.shape[0]): 79 | bbox = faces[i].astype(np.float32) 80 | if landmarks is not None: 81 | landmark5 = landmarks[i].astype(np.float32) 82 | 83 | bboxes_list.append(bbox.tolist()) 84 | landmarks_list.append(landmark5.tolist()) 85 | 86 | if len(landmarks_list) > 0: 87 | img_json["image_path"] = img_path 88 | img_json["bboxes"] = bboxes_list 89 | img_json["landmarks"] = landmarks_list 90 | 91 | if not os.path.isdir(output_path): 92 | os.makedirs(output_path) 93 | 94 | with open( 95 | file_output_path, 96 | "w", 97 | ) as output_file: 98 | json.dump(img_json, output_file) 99 | 100 | 101 | def parse_args(): 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument( 104 | "--image_list", 105 | type=str, 106 | required=True, 107 | help="List with path to images.", 108 | ) 109 | parser.add_argument( 110 | "--output_path", type=str, required=True, help="Path to save the json files." 111 | ) 112 | parser.add_argument( 113 | "--thresh", type=float, default=0.8, help="Face detection threshold." 114 | ) 115 | parser.add_argument( 116 | "--model_path", 117 | type=str, 118 | default="../insightface/RetinaFace/models/R50/R50", 119 | help="Model path for detector.", 120 | ) 121 | 122 | args = parser.parse_args() 123 | 124 | if not os.path.exists(args.output_path): 125 | os.makedirs(args.output_path) 126 | 127 | return args 128 | 129 | 130 | if __name__ == "__main__": 131 | args = parse_args() 132 | 133 | detector = RetinaFace(args.model_path, 0, 0, "net3", vote=False) 134 | 135 | face_landmark_detection(detector, args) 136 | -------------------------------------------------------------------------------- /utils/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import cv2 4 | import numpy as np 5 | from PIL import Image, ImageEnhance, ImageOps 6 | 7 | 8 | def random_crop(img, bboxes, landmarks): 9 | useable_landmarks = False 10 | for landmark in landmarks: 11 | if -1 not in landmark: 12 | useable_landmarks = True 13 | break 14 | 15 | total_attempts = 10 16 | searching = True 17 | attempt = 0 18 | 19 | while searching: 20 | crop_img = img.copy() 21 | (w, h) = crop_img.size 22 | crop_size = random.uniform(0.7, 1) 23 | 24 | if attempt == total_attempts: 25 | return img, bboxes, landmarks 26 | 27 | crop_x = int(w * crop_size) 28 | crop_y = int(h * crop_size) 29 | 30 | start_x = random.randint(0, w - crop_x) 31 | start_y = (start_x // w) * h 32 | 33 | crop_bbox = [start_x, start_y, start_x + crop_x, start_y + crop_y] 34 | 35 | new_bboxes, new_lms = _adjust_bboxes_landmarks( 36 | bboxes.copy(), landmarks.copy(), crop_bbox 37 | ) 38 | 39 | if len(new_bboxes) > 0: 40 | if useable_landmarks: 41 | for lms in new_lms: 42 | if -1 not in lms: 43 | searching = False 44 | break 45 | else: 46 | searching = False 47 | 48 | if not searching: 49 | crop_img = crop_img.crop( 50 | (crop_bbox[0], crop_bbox[1], crop_bbox[2], crop_bbox[3]) 51 | ) 52 | 53 | attempt += 1 54 | 55 | return crop_img, new_bboxes, new_lms 56 | 57 | 58 | def _adjust_bboxes_landmarks(bboxes, landmarks, crop_bbox): 59 | new_bboxes = [] 60 | new_lms = [] 61 | for i in range(len(bboxes)): 62 | bbox = bboxes[i] 63 | lms = np.asarray(landmarks[i]) 64 | 65 | bbox_center_x = bbox[0] + ((bbox[2] - bbox[0]) // 2) 66 | bbox_center_y 
= bbox[1] + ((bbox[3] - bbox[1]) // 2) 67 | 68 | if ( 69 | bbox_center_x > crop_bbox[0] 70 | and bbox_center_x < crop_bbox[2] 71 | and bbox_center_y > crop_bbox[1] 72 | and bbox_center_y < crop_bbox[3] 73 | ): 74 | bbox[[0, 2]] -= crop_bbox[0] 75 | bbox[[1, 3]] -= crop_bbox[1] 76 | 77 | bbox[0] = max(bbox[0], 0) 78 | bbox[1] = max(bbox[1], 0) 79 | bbox[2] = min(bbox[2], crop_bbox[2] - crop_bbox[0]) 80 | bbox[3] = min(bbox[3], crop_bbox[3] - crop_bbox[1]) 81 | 82 | add_lm = True 83 | for lm in lms: 84 | if ( 85 | lm[0] < crop_bbox[0] 86 | or lm[0] > crop_bbox[2] 87 | or lm[1] < crop_bbox[1] 88 | or lm[1] > crop_bbox[3] 89 | ): 90 | add_lm = False 91 | break 92 | 93 | if add_lm: 94 | lms[:, 0] -= crop_bbox[0] 95 | lms[:, 1] -= crop_bbox[1] 96 | 97 | new_lms.append(lms.tolist()) 98 | new_bboxes.append(bbox) 99 | 100 | return new_bboxes, new_lms 101 | 102 | 103 | def random_flip(img, bboxes, all_landmarks): 104 | flip = random.randint(0, 1) 105 | 106 | if flip == 1: 107 | # flip image 108 | img = ImageOps.mirror(img) 109 | 110 | # flip bboxes 111 | old_bboxes = bboxes.copy() 112 | (w, h) = img.size 113 | bboxes[:, 0] = w - old_bboxes[:, 2] 114 | bboxes[:, 2] = w - old_bboxes[:, 0] 115 | 116 | for i in range(len(all_landmarks)): 117 | landmarks = np.asarray(all_landmarks[i]) 118 | 119 | if -1 in landmarks: 120 | continue 121 | 122 | all_landmarks[i] = flip_landmarks(landmarks, w).tolist() 123 | 124 | return img, bboxes, all_landmarks 125 | 126 | 127 | def flip_landmarks(landmarks, w): 128 | if len(landmarks) == 5: 129 | order = [1, 0, 2, 4, 3] 130 | else: 131 | order = [ 132 | 16, 133 | 15, 134 | 14, 135 | 13, 136 | 12, 137 | 11, 138 | 10, 139 | 9, 140 | 8, 141 | 7, 142 | 6, 143 | 5, 144 | 4, 145 | 3, 146 | 2, 147 | 1, 148 | 0, 149 | 26, 150 | 25, 151 | 24, 152 | 23, 153 | 22, 154 | 21, 155 | 20, 156 | 19, 157 | 18, 158 | 17, 159 | 27, 160 | 28, 161 | 29, 162 | 30, 163 | 35, 164 | 34, 165 | 33, 166 | 32, 167 | 31, 168 | 45, 169 | 44, 170 | 43, 171 | 42, 172 | 47, 173 | 46, 174 | 39, 175 | 38, 176 | 37, 177 | 36, 178 | 41, 179 | 40, 180 | 54, 181 | 53, 182 | 52, 183 | 51, 184 | 50, 185 | 49, 186 | 48, 187 | 59, 188 | 58, 189 | 57, 190 | 56, 191 | 55, 192 | 64, 193 | 63, 194 | 62, 195 | 61, 196 | 60, 197 | 67, 198 | 66, 199 | 65, 200 | ] 201 | 202 | # flip landmarks 203 | landmarks[:, 0] = w - landmarks[:, 0] 204 | flandmarks = landmarks.copy() 205 | for idx, a in enumerate(order): 206 | flandmarks[idx, :] = landmarks[a, :] 207 | 208 | return flandmarks 209 | 210 | 211 | def rotate(img, landmarks, bbox): 212 | angle = random.gauss(0, 1) * 30 213 | 214 | (h, w) = img.shape[:2] 215 | (cX, cY) = (w // 2, h // 2) 216 | 217 | # Transform Image 218 | new_img = _rotate_img(img, angle, cX, cY, h, w) 219 | 220 | # Transform Landmarks 221 | new_landmarks = _rotate_landmarks(landmarks, angle, cX, cY, h, w) 222 | 223 | # Transform Bounding Box 224 | x1, y1, x2, y2 = bbox["left"], bbox["top"], bbox["right"], bbox["bottom"] 225 | bounding_box_8pts = np.array([x1, y1, x2, y1, x2, y2, x1, y2]) 226 | 227 | rot_bounding_box_8pts = _rotate_bbox(bounding_box_8pts, angle, cX, cY, h, w) 228 | rot_bounding_box = _get_enclosing_bbox(rot_bounding_box_8pts, y2 - y1, x2 - x1)[ 229 | 0 230 | ].astype("float") 231 | 232 | new_bbox = { 233 | "left": rot_bounding_box[0], 234 | "top": rot_bounding_box[1], 235 | "right": rot_bounding_box[2], 236 | "bottom": rot_bounding_box[3], 237 | } 238 | 239 | # validate that bbox boundaries are within the image otherwise do not apply rotation 240 | if ( 241 | new_bbox["top"] > 0 242 | and 
new_bbox["bottom"] < h 243 | and bbox["left"] > 0 244 | and bbox["right"] < w 245 | ): 246 | img = new_img 247 | landmarks = new_landmarks 248 | bbox = new_bbox 249 | 250 | return img, landmarks, bbox 251 | 252 | 253 | def scale(img, bbox): 254 | scale = random.uniform(0.75, 1.25) 255 | bbox = _scale_bbox(img, bbox, scale) 256 | 257 | return bbox 258 | 259 | 260 | def translate_vertical(img, bbox): 261 | bbox_height = bbox["bottom"] - bbox["top"] 262 | vtrans = random.uniform(-0.1, 0.1) * bbox_height 263 | 264 | # check if bbox boundaries are within image, otherwise do not move bbox 265 | if bbox["top"] + vtrans > 0 and bbox["bottom"] + vtrans < img.shape[0]: 266 | bbox["top"] += vtrans 267 | bbox["bottom"] += vtrans 268 | 269 | return bbox 270 | 271 | 272 | def translate_horizontal(img, bbox): 273 | bbox_width = bbox["right"] - bbox["left"] 274 | htrans = random.uniform(-0.1, 0.1) * bbox_width 275 | 276 | # check if bbox boundaries are within image, otherwise do not move bbox 277 | if bbox["left"] + htrans > 0 and bbox["right"] + htrans < img.shape[1]: 278 | bbox["left"] += htrans 279 | bbox["right"] += htrans 280 | 281 | return bbox 282 | 283 | 284 | def change_contrast(img, bboxes, landmarks): 285 | change = random.randint(0, 1) 286 | if change == 1: 287 | factor = random.uniform(0.5, 1.5) 288 | enhancer = ImageEnhance.Contrast(img) 289 | img = enhancer.enhance(factor) 290 | 291 | return img, bboxes, landmarks 292 | 293 | 294 | def add_noise(img, bboxes, landmarks): 295 | add_noise = random.randint(0, 4) 296 | if add_noise == 4: 297 | noise_types = ["gauss", "s&p", "poisson"] 298 | noise_idx = random.randint(0, 2) 299 | noise_type = noise_types[noise_idx] 300 | 301 | img = np.array(img) 302 | 303 | if noise_type == "gauss": 304 | row, col, ch = img.shape 305 | mean = 0 306 | var = 0.1 307 | sigma = var ** 0.5 308 | gauss = np.random.normal(mean, sigma, (row, col, ch)) 309 | gauss = gauss.reshape(row, col, ch) 310 | img = img + gauss 311 | 312 | elif noise_type == "s&p": 313 | row, col, ch = img.shape 314 | s_vs_p = 0.5 315 | amount = 0.004 316 | # Salt mode 317 | num_salt = np.ceil(amount * (img.size / ch) * s_vs_p) 318 | coords = [ 319 | np.random.randint(0, i - 1, int(num_salt)) for i in img.shape[0:2] 320 | ] 321 | img[tuple(coords)] = 255 322 | # Pepper mode 323 | num_pepper = np.ceil(amount * (img.size / ch) * (1.0 - s_vs_p)) 324 | coords = [ 325 | np.random.randint(0, i - 1, int(num_pepper)) for i in img.shape[0:2] 326 | ] 327 | img[tuple(coords)] = 0 328 | 329 | elif noise_type == "poisson": 330 | vals = len(np.unique(img)) 331 | vals = 2 ** np.ceil(np.log2(vals)) 332 | img = np.random.poisson(img * vals) / float(vals) 333 | 334 | img = Image.fromarray(img.astype("uint8")) 335 | 336 | return img, bboxes, landmarks 337 | 338 | 339 | def _rotate_img(img, angle, cX, cY, h, w): 340 | M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0) 341 | cos = np.abs(M[0, 0]) 342 | sin = np.abs(M[0, 1]) 343 | 344 | nW = int((h * sin) + (w * cos)) 345 | nH = int((h * cos) + (w * sin)) 346 | 347 | M[0, 2] += (nW / 2) - cX 348 | M[1, 2] += (nH / 2) - cY 349 | 350 | img = cv2.warpAffine(img, M, (nW, nH)) 351 | 352 | return img 353 | 354 | 355 | def _rotate_landmarks(landmarks, angle, cx, cy, h, w): 356 | M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) 357 | cos = np.abs(M[0, 0]) 358 | sin = np.abs(M[0, 1]) 359 | 360 | nW = int((h * sin) + (w * cos)) 361 | nH = int((h * cos) + (w * sin)) 362 | 363 | M[0, 2] += (nW / 2) - cx 364 | M[1, 2] += (nH / 2) - cy 365 | 366 | landmarks = np.append(landmarks, 
np.ones((landmarks.shape[0], 1)), axis=1) 367 | calculated = (np.dot(M, landmarks.T)).T 368 | return calculated.astype("int") 369 | 370 | 371 | def _rotate_bbox(corners, angle, cx, cy, h, w): 372 | corners = corners.reshape(-1, 2) 373 | corners = np.hstack( 374 | (corners, np.ones((corners.shape[0], 1), dtype=type(corners[0][0]))) 375 | ) 376 | M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) 377 | 378 | cos = np.abs(M[0, 0]) 379 | sin = np.abs(M[0, 1]) 380 | 381 | nW = int((h * sin) + (w * cos)) 382 | nH = int((h * cos) + (w * sin)) 383 | # adjust the rotation matrix to take into account translation 384 | M[0, 2] += (nW / 2) - cx 385 | M[1, 2] += (nH / 2) - cy 386 | # Prepare the vector to be transformed 387 | calculated = np.dot(M, corners.T).T 388 | calculated = calculated.reshape(-1, 8) 389 | 390 | return calculated 391 | 392 | 393 | def _get_enclosing_bbox(corners, original_height, original_width): 394 | x = corners[:, [0, 2, 4, 6]] 395 | y = corners[:, [1, 3, 5, 7]] 396 | xmin = np.min(x, 1).reshape(-1, 1) 397 | ymin = np.min(y, 1).reshape(-1, 1) 398 | xmax = np.max(x, 1).reshape(-1, 1) 399 | ymax = np.max(y, 1).reshape(-1, 1) 400 | 401 | height = ymax - ymin 402 | width = xmax - xmin 403 | 404 | diff_height = height - original_height 405 | diff_width = width - original_width 406 | 407 | ymax -= diff_height // 2 408 | ymin += diff_height // 2 409 | xmax -= diff_width // 2 410 | xmin += diff_width // 2 411 | 412 | final = np.hstack((xmin, ymin, xmax, ymax, corners[:, 8:])) 413 | 414 | return final.astype("int") 415 | 416 | 417 | def _scale_bbox(img, bbox, scale): 418 | height = bbox["bottom"] - bbox["top"] 419 | width = bbox["right"] - bbox["left"] 420 | 421 | new_height = height * scale 422 | new_width = width * scale 423 | 424 | diff_height = height - new_height 425 | diff_width = width - new_width 426 | 427 | # check if bbox boundaries are within image, otherwise do not scale bbox 428 | if ( 429 | bbox["top"] + (diff_height // 2) > 0 430 | and bbox["bottom"] - (diff_height // 2) < img.shape[0] 431 | and bbox["left"] + (diff_width // 2) > 0 432 | and bbox["right"] - (diff_width // 2) < img.shape[1] 433 | ): 434 | bbox["bottom"] -= diff_height // 2 435 | bbox["top"] += diff_height // 2 436 | bbox["right"] -= diff_width // 2 437 | bbox["left"] += diff_width // 2 438 | 439 | return bbox 440 | -------------------------------------------------------------------------------- /utils/dist.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | def init_distributed_mode(args): 8 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 9 | args.rank = int(os.environ["RANK"]) 10 | args.world_size = int(os.environ["WORLD_SIZE"]) 11 | args.gpu = int(os.environ["LOCAL_RANK"]) 12 | elif "SLURM_PROCID" in os.environ: 13 | args.rank = int(os.environ["SLURM_PROCID"]) 14 | args.gpu = args.rank % torch.cuda.device_count() 15 | else: 16 | print("Not using distributed mode") 17 | args.distributed = False 18 | 19 | return 20 | 21 | args.distributed = True 22 | 23 | torch.cuda.set_device(args.gpu) 24 | args.dist_backend = "nccl" 25 | print( 26 | "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True 27 | ) 28 | torch.distributed.init_process_group( 29 | backend=args.dist_backend, 30 | init_method=args.dist_url, 31 | world_size=args.world_size, 32 | rank=args.rank, 33 | ) 34 | torch.distributed.barrier() 35 | setup_for_distributed(args.rank == 0) 36 | 37 | print(args) 
38 | 39 | 40 | def setup_for_distributed(is_master): 41 | """ 42 | This function disables printing when not in master process 43 | """ 44 | import builtins as __builtin__ 45 | 46 | builtin_print = __builtin__.print 47 | 48 | def print(*args, **kwargs): 49 | force = kwargs.pop("force", False) 50 | if is_master or force: 51 | builtin_print(*args, **kwargs) 52 | 53 | __builtin__.print = print 54 | 55 | 56 | def reduce_dict(input_dict, average=True): 57 | """ 58 | Args: 59 | input_dict (dict): all the values will be reduced 60 | average (bool): whether to do average or sum 61 | Reduce the values in the dictionary from all processes so that all processes 62 | have the averaged results. Returns a dict with the same fields as 63 | input_dict, after reduction. 64 | """ 65 | world_size = get_world_size() 66 | if world_size < 2: 67 | return input_dict 68 | with torch.no_grad(): 69 | names = [] 70 | values = [] 71 | # sort the keys so that they are consistent across processes 72 | for k in sorted(input_dict.keys()): 73 | names.append(k) 74 | values.append(input_dict[k]) 75 | values = torch.stack(values, dim=0) 76 | 77 | dist.all_reduce(values) 78 | if average: 79 | values /= world_size 80 | reduced_dict = {k: v for k, v in zip(names, values)} 81 | return reduced_dict 82 | 83 | 84 | def is_dist_avail_and_initialized(): 85 | if not dist.is_available(): 86 | return False 87 | if not dist.is_initialized(): 88 | return False 89 | 90 | return True 91 | 92 | 93 | def get_world_size(): 94 | if not is_dist_avail_and_initialized(): 95 | return 1 96 | return dist.get_world_size() 97 | 98 | 99 | def get_rank(): 100 | if not is_dist_avail_and_initialized(): 101 | return 0 102 | return dist.get_rank() 103 | 104 | 105 | def is_main_process(): 106 | return get_rank() == 0 107 | 108 | 109 | def save_on_master(*args, **kwargs): 110 | if is_main_process(): 111 | torch.save(*args, **kwargs) 112 | -------------------------------------------------------------------------------- /utils/face_align.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from 3 | https://github.com/deepinsight/insightface/blob/master/recognition/common/face_align.py 4 | """ 5 | 6 | import cv2 7 | import numpy as np 8 | from skimage import transform as trans 9 | 10 | # <-- left profile 11 | src1 = np.array( 12 | [ 13 | [51.642, 50.115], 14 | [57.617, 49.990], 15 | [35.740, 69.007], 16 | [51.157, 89.050], 17 | [57.025, 89.702], 18 | ], 19 | dtype=np.float32, 20 | ) 21 | # <--left 22 | src2 = np.array( 23 | [ 24 | [45.031, 50.118], 25 | [65.568, 50.872], 26 | [39.677, 68.111], 27 | [45.177, 86.190], 28 | [64.246, 86.758], 29 | ], 30 | dtype=np.float32, 31 | ) 32 | 33 | # ---frontal 34 | src3 = np.array( 35 | [ 36 | [39.730, 51.138], 37 | [72.270, 51.138], 38 | [56.000, 68.493], 39 | [42.463, 87.010], 40 | [69.537, 87.010], 41 | ], 42 | dtype=np.float32, 43 | ) 44 | 45 | # -->right 46 | src4 = np.array( 47 | [ 48 | [46.845, 50.872], 49 | [67.382, 50.118], 50 | [72.737, 68.111], 51 | [48.167, 86.758], 52 | [67.236, 86.190], 53 | ], 54 | dtype=np.float32, 55 | ) 56 | 57 | # -->right profile 58 | src5 = np.array( 59 | [ 60 | [54.796, 49.990], 61 | [60.771, 50.115], 62 | [76.673, 69.007], 63 | [55.388, 89.702], 64 | [61.257, 89.050], 65 | ], 66 | dtype=np.float32, 67 | ) 68 | 69 | src = np.array([src1, src2, src3, src4, src5]) 70 | src_map = {112: src, 224: src * 2} 71 | 72 | 73 | def estimate_norm(lmk, image_size=224): 74 | assert lmk.shape == (5, 2) 75 | tform = trans.SimilarityTransform() 76 | 
lmk_tran = np.insert(lmk, 2, values=np.ones(5), axis=1) 77 | min_M = [] 78 | min_index = [] 79 | min_error = float("inf") 80 | 81 | src = src_map[image_size] 82 | for i in np.arange(src.shape[0]): 83 | tform.estimate(lmk, src[i]) 84 | M = tform.params[0:2, :] 85 | results = np.dot(M, lmk_tran.T) 86 | results = results.T 87 | error = np.sum(np.sqrt(np.sum((results - src[i]) ** 2, axis=1))) 88 | 89 | # find the src that is most close to the projected points (predicted) 90 | if error < min_error: 91 | min_error = error 92 | min_M = M 93 | min_index = i 94 | 95 | return min_M, min_index 96 | 97 | 98 | def norm_crop(img, landmark, image_size=224): 99 | M, pose_index = estimate_norm(landmark, image_size) 100 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) 101 | 102 | return warped 103 | -------------------------------------------------------------------------------- /utils/image_operations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | 5 | 6 | def expand_bbox_rectangle( 7 | w, h, bbox_x_factor=2.0, bbox_y_factor=2.0, lms=None, expand_forehead=0.3, roll=0 8 | ): 9 | # get a good bbox for the facial landmarks 10 | min_pt_x = np.min(lms[:, 0], axis=0) 11 | max_pt_x = np.max(lms[:, 0], axis=0) 12 | 13 | min_pt_y = np.min(lms[:, 1], axis=0) 14 | max_pt_y = np.max(lms[:, 1], axis=0) 15 | 16 | # find out the bbox of the crop region 17 | bbox_size_x = int(np.max(max_pt_x - min_pt_x) * bbox_x_factor) 18 | center_pt_x = 0.5 * min_pt_x + 0.5 * max_pt_x 19 | 20 | bbox_size_y = int(np.max(max_pt_y - min_pt_y) * bbox_y_factor) 21 | center_pt_y = 0.5 * min_pt_y + 0.5 * max_pt_y 22 | 23 | bbox_min_x, bbox_max_x = ( 24 | center_pt_x - bbox_size_x * 0.5, 25 | center_pt_x + bbox_size_x * 0.5, 26 | ) 27 | 28 | bbox_min_y, bbox_max_y = ( 29 | center_pt_y - bbox_size_y * 0.5, 30 | center_pt_y + bbox_size_y * 0.5, 31 | ) 32 | 33 | if abs(roll) > 2.5: 34 | expand_forehead_size = expand_forehead * np.max(max_pt_y - min_pt_y) 35 | bbox_max_y += expand_forehead_size 36 | 37 | elif roll > 1: 38 | expand_forehead_size = expand_forehead * np.max(max_pt_x - min_pt_x) 39 | bbox_max_x += expand_forehead_size 40 | 41 | elif roll < -1: 42 | expand_forehead_size = expand_forehead * np.max(max_pt_x - min_pt_x) 43 | bbox_min_x -= expand_forehead_size 44 | 45 | else: 46 | expand_forehead_size = expand_forehead * np.max(max_pt_y - min_pt_y) 47 | bbox_min_y -= expand_forehead_size 48 | 49 | bbox_min_x = bbox_min_x.astype(np.int32) 50 | bbox_max_x = bbox_max_x.astype(np.int32) 51 | bbox_min_y = bbox_min_y.astype(np.int32) 52 | bbox_max_y = bbox_max_y.astype(np.int32) 53 | 54 | # compute necessary padding 55 | padding_left = abs(min(bbox_min_x, 0)) 56 | padding_top = abs(min(bbox_min_y, 0)) 57 | padding_right = max(bbox_max_x - w, 0) 58 | padding_bottom = max(bbox_max_y - h, 0) 59 | 60 | # crop the image properly by computing proper crop bounds 61 | crop_left = 0 if padding_left > 0 else bbox_min_x 62 | crop_top = 0 if padding_top > 0 else bbox_min_y 63 | crop_right = w if padding_right > 0 else bbox_max_x 64 | crop_bottom = h if padding_bottom > 0 else bbox_max_y 65 | 66 | return np.array([crop_left, crop_top, crop_right, crop_bottom]) 67 | 68 | 69 | def expand_bbox_rectangle_tensor( 70 | w, h, bbox_x_factor=2.0, bbox_y_factor=2.0, lms=None, expand_forehead=0.3, roll=0 71 | ): 72 | # get a good bbox for the facial landmarks 73 | min_pt_x = torch.min(lms[:, 0], axis=0)[0] 74 | max_pt_x = 
torch.max(lms[:, 0], axis=0)[0] 75 | 76 | min_pt_y = torch.min(lms[:, 1], axis=0)[0] 77 | max_pt_y = torch.max(lms[:, 1], axis=0)[0] 78 | 79 | # find out the bbox of the crop region 80 | bbox_size_x = int(torch.max(max_pt_x - min_pt_x) * bbox_x_factor) 81 | center_pt_x = 0.5 * min_pt_x + 0.5 * max_pt_x 82 | 83 | bbox_size_y = int(torch.max(max_pt_y - min_pt_y) * bbox_y_factor) 84 | center_pt_y = 0.5 * min_pt_y + 0.5 * max_pt_y 85 | 86 | bbox_min_x, bbox_max_x = ( 87 | center_pt_x - bbox_size_x * 0.5, 88 | center_pt_x + bbox_size_x * 0.5, 89 | ) 90 | 91 | bbox_min_y, bbox_max_y = ( 92 | center_pt_y - bbox_size_y * 0.5, 93 | center_pt_y + bbox_size_y * 0.5, 94 | ) 95 | 96 | if abs(roll) > 2.5: 97 | expand_forehead_size = expand_forehead * torch.max(max_pt_y - min_pt_y) 98 | bbox_max_y += expand_forehead_size 99 | 100 | elif roll > 1: 101 | expand_forehead_size = expand_forehead * torch.max(max_pt_x - min_pt_x) 102 | bbox_max_x += expand_forehead_size 103 | 104 | elif roll < -1: 105 | expand_forehead_size = expand_forehead * torch.max(max_pt_x - min_pt_x) 106 | bbox_min_x -= expand_forehead_size 107 | 108 | else: 109 | expand_forehead_size = expand_forehead * torch.max(max_pt_y - min_pt_y) 110 | bbox_min_y -= expand_forehead_size 111 | 112 | bbox_min_x = bbox_min_x.int() 113 | bbox_max_x = bbox_max_x.int() 114 | bbox_min_y = bbox_min_y.int() 115 | bbox_max_y = bbox_max_y.int() 116 | 117 | # compute necessary padding 118 | padding_left = abs(min(bbox_min_x, 0)) 119 | padding_top = abs(min(bbox_min_y, 0)) 120 | padding_right = max(bbox_max_x - w, 0) 121 | padding_bottom = max(bbox_max_y - h, 0) 122 | 123 | # crop the image properly by computing proper crop bounds 124 | crop_left = 0 if padding_left > 0 else bbox_min_x 125 | crop_top = 0 if padding_top > 0 else bbox_min_y 126 | crop_right = w if padding_right > 0 else bbox_max_x 127 | crop_bottom = h if padding_bottom > 0 else bbox_max_y 128 | 129 | return ( 130 | torch.tensor([crop_left, crop_top, crop_right, crop_bottom]) 131 | .float() 132 | .to(lms.device) 133 | ) 134 | 135 | 136 | def preprocess_image_wrt_face(img, bbox_size_factor=2.0, lms=None): 137 | w, h = img.size 138 | 139 | if lms is None: 140 | return None, None 141 | 142 | # get a good bbox for the facial landmarks 143 | min_pt = np.min(lms, axis=0) 144 | max_pt = np.max(lms, axis=0) 145 | 146 | # find out the bbox of the crop region 147 | bbox_size = int(np.max(max_pt - min_pt) * bbox_size_factor) 148 | center_pt = 0.5 * min_pt + 0.5 * max_pt 149 | bbox_min, bbox_max = center_pt - bbox_size * 0.5, center_pt + bbox_size * 0.5 150 | bbox_min = bbox_min.astype(np.int32) 151 | bbox_max = bbox_max.astype(np.int32) 152 | 153 | # compute necessary padding 154 | padding_left = abs(min(bbox_min[0], 0)) 155 | padding_top = abs(min(bbox_min[1], 0)) 156 | padding_right = max(bbox_max[0] - w, 0) 157 | padding_bottom = max(bbox_max[1] - h, 0) 158 | 159 | # crop the image properly by computing proper crop bounds 160 | crop_left = 0 if padding_left > 0 else bbox_min[0] 161 | crop_top = 0 if padding_top > 0 else bbox_min[1] 162 | crop_right = w if padding_right > 0 else bbox_max[0] 163 | crop_bottom = h if padding_bottom > 0 else bbox_max[1] 164 | 165 | cropped_image = img.crop((crop_left, crop_top, crop_right, crop_bottom)) 166 | 167 | # copy the cropped image to padded image 168 | padded_image = Image.new(img.mode, (bbox_size, bbox_size)) 169 | padded_image.paste( 170 | cropped_image, 171 | ( 172 | padding_left, 173 | padding_top, 174 | padding_left + crop_right - crop_left, 175 | padding_top + 
crop_bottom - crop_top, 176 | ), 177 | ) 178 | 179 | bbox = {} 180 | bbox["left"] = crop_left - padding_left 181 | bbox["right"] = crop_right + padding_right 182 | bbox["top"] = crop_top - padding_top 183 | bbox["bottom"] = crop_bottom + padding_bottom 184 | 185 | return padded_image, lms, bbox 186 | 187 | 188 | def bbox_is_dict(bbox): 189 | # check if the bbox is a not dict and convert it if needed 190 | if not isinstance(bbox, dict): 191 | temp_bbox = {} 192 | temp_bbox["left"] = bbox[0] 193 | temp_bbox["top"] = bbox[1] 194 | temp_bbox["right"] = bbox[2] 195 | temp_bbox["bottom"] = bbox[3] 196 | bbox = temp_bbox 197 | 198 | return bbox 199 | 200 | 201 | def pad_image_no_crop(img, bbox): 202 | bbox = bbox_is_dict(bbox) 203 | 204 | w, h = img.size 205 | 206 | # checks if the bbox is going outside of the image and if so expands the image 207 | if bbox["left"] < 0 or bbox["top"] < 0 or bbox["right"] > w or bbox["bottom"] > h: 208 | padding_left = abs(min(bbox["left"], 0)) 209 | padding_top = abs(min(bbox["top"], 0)) 210 | padding_right = max(bbox["right"] - w, 0) 211 | padding_bottom = max(bbox["bottom"] - h, 0) 212 | 213 | height = h + padding_top + padding_bottom 214 | width = w + padding_left + padding_right 215 | 216 | padded_image = Image.new(img.mode, (width, height)) 217 | padded_image.paste( 218 | img, (padding_left, padding_top, padding_left + w, padding_top + h) 219 | ) 220 | 221 | img = padded_image 222 | bbox["left"] += padding_left 223 | bbox["top"] += padding_top 224 | bbox["right"] += padding_left 225 | bbox["bottom"] += padding_top 226 | 227 | return img, bbox 228 | 229 | 230 | def crop_face_bbox_expanded(img, bbox, bbox_size_factor=2.0): 231 | bbox = bbox_is_dict(bbox) 232 | 233 | # get image size 234 | w, h = img.size 235 | 236 | # transform bounding box to 8 points (4,2) 237 | x1, y1, x2, y2 = bbox["left"], bbox["top"], bbox["right"], bbox["bottom"] 238 | bounding_box_8pts = np.array([x1, y1, x2, y1, x2, y2, x1, y2]) 239 | bounding_box_8pts = bounding_box_8pts.reshape(-1, 2) 240 | 241 | # get a good bbox for the facial landmarks 242 | min_pt = np.min(bounding_box_8pts, axis=0) 243 | max_pt = np.max(bounding_box_8pts, axis=0) 244 | 245 | # find out the bbox of the crop region 246 | bbox_size = int(np.max(max_pt - min_pt) * bbox_size_factor) 247 | center_pt = 0.5 * min_pt + 0.5 * max_pt 248 | bbox_min, bbox_max = center_pt - bbox_size * 0.5, center_pt + bbox_size * 0.5 249 | bbox_min = bbox_min.astype(np.int32) 250 | bbox_max = bbox_max.astype(np.int32) 251 | 252 | # compute necessary padding 253 | padding_left = abs(min(bbox_min[0], 0)) 254 | padding_top = abs(min(bbox_min[1], 0)) 255 | padding_right = max(bbox_max[0] - w, 0) 256 | padding_bottom = max(bbox_max[1] - h, 0) 257 | 258 | # crop the image properly by computing proper crop bounds 259 | crop_left = 0 if padding_left > 0 else bbox_min[0] 260 | crop_top = 0 if padding_top > 0 else bbox_min[1] 261 | crop_right = w if padding_right > 0 else bbox_max[0] 262 | crop_bottom = h if padding_bottom > 0 else bbox_max[1] 263 | 264 | cropped_image = img.crop((crop_left, crop_top, crop_right, crop_bottom)) 265 | 266 | # copy the cropped image to padded image 267 | padded_image = Image.new(img.mode, (bbox_size, bbox_size)) 268 | padded_image.paste( 269 | cropped_image, 270 | ( 271 | padding_left, 272 | padding_top, 273 | padding_left + crop_right - crop_left, 274 | padding_top + crop_bottom - crop_top, 275 | ), 276 | ) 277 | 278 | bbox_padded = {} 279 | bbox_padded["left"] = crop_left - padding_left 280 | bbox_padded["right"] = 
crop_right + padding_right 281 | bbox_padded["top"] = crop_top - padding_top 282 | bbox_padded["bottom"] = crop_bottom + padding_bottom 283 | 284 | bbox = {} 285 | bbox["left"] = crop_left 286 | bbox["right"] = crop_right 287 | bbox["top"] = crop_top 288 | bbox["bottom"] = crop_bottom 289 | 290 | return padded_image, bbox_padded, bbox 291 | 292 | 293 | def resize_image(img, min_size=600, max_size=1000): 294 | width = img.width 295 | height = img.height 296 | w, h, scale = width, height, 1.0 297 | 298 | if width < height: 299 | if width < min_size: 300 | w = min_size 301 | h = int(height * min_size / width) 302 | scale = float(min_size) / float(width) 303 | elif width > max_size: 304 | w = max_size 305 | h = int(height * max_size / width) 306 | scale = float(max_size) / float(width) 307 | else: 308 | if height < min_size: 309 | w = int(width * min_size / height) 310 | h = min_size 311 | scale = float(min_size) / float(height) 312 | elif height > max_size: 313 | w = int(width * max_size / height) 314 | h = max_size 315 | scale = float(max_size) / float(height) 316 | 317 | img_resized = img.resize((w, h)) 318 | img_resize_info = [h, w, scale] 319 | 320 | return img_resized, img_resize_info 321 | -------------------------------------------------------------------------------- /utils/json_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from PIL import Image 8 | from torch.utils.data import DataLoader 9 | from torchvision.datasets import ImageFolder 10 | from tqdm import tqdm 11 | 12 | from .image_operations import bbox_is_dict 13 | from .pose_operations import get_pose, pose_bbox_to_full_image 14 | 15 | 16 | class FramesJsonList(ImageFolder): 17 | def __init__(self, threed_5_points, threed_68_points, json_list, dataset_path=None): 18 | self.samples = [] 19 | self.bboxes = [] 20 | self.landmarks = [] 21 | self.threed_5_points = threed_5_points 22 | self.threed_68_points = threed_68_points 23 | self.dataset_path = dataset_path 24 | 25 | image_paths = pd.read_csv(json_list, delimiter=" ", header=None) 26 | image_paths = np.asarray(image_paths).squeeze() 27 | 28 | print("Loading frames paths...") 29 | for image_path in tqdm(image_paths): 30 | with open(image_path) as f: 31 | image_json = json.load(f) 32 | 33 | # path to the image 34 | img_path = image_json["image_path"] 35 | # if not absolute path, append the dataset path 36 | if self.dataset_path is not None: 37 | img_path = os.path.join(self.dataset_path, img_path) 38 | self.samples.append(img_path) 39 | 40 | # landmarks used to create pose labels 41 | self.landmarks.append(image_json["landmarks"]) 42 | 43 | # load bboxes 44 | self.bboxes.append(image_json["bboxes"]) 45 | 46 | def __len__(self): 47 | return len(self.samples) 48 | 49 | def __getitem__(self, index): 50 | image_path = Path(self.samples[index]) 51 | 52 | img = Image.open(image_path) 53 | 54 | (w, h) = img.size 55 | global_intrinsics = np.array( 56 | [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]] 57 | ) 58 | bboxes = self.bboxes[index] 59 | landmarks = self.landmarks[index] 60 | 61 | bbox_labels = [] 62 | landmark_labels = [] 63 | pose_labels = [] 64 | global_pose_labels = [] 65 | 66 | for i in range(len(bboxes)): 67 | bbox = np.asarray(bboxes[i])[:4].astype(int) 68 | landmark = np.asarray(landmarks[i])[:, :2].astype(float) 69 | 70 | # remove samples that do not have height ot width or are negative 71 | if bbox[0] >= bbox[2] 
or bbox[1] >= bbox[3]: 72 | continue 73 | 74 | if -1 in landmark: 75 | global_pose_labels.append([-9, -9, -9, -9, -9, -9]) 76 | pose_labels.append([-9, -9, -9, -9, -9, -9]) 77 | 78 | else: 79 | landmark[:, 0] -= bbox[0] 80 | landmark[:, 1] -= bbox[1] 81 | 82 | w = int(bbox[2] - bbox[0]) 83 | h = int(bbox[3] - bbox[1]) 84 | 85 | bbox_intrinsics = np.array( 86 | [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]] 87 | ) 88 | 89 | if len(landmark) == 5: 90 | P, pose = get_pose(self.threed_5_points, landmark, bbox_intrinsics) 91 | else: 92 | P, pose = get_pose( 93 | self.threed_68_points, 94 | landmark, 95 | bbox_intrinsics, 96 | ) 97 | 98 | pose_labels.append(pose.tolist()) 99 | 100 | global_pose = pose_bbox_to_full_image( 101 | pose, global_intrinsics, bbox_is_dict(bbox) 102 | ) 103 | 104 | global_pose_labels.append(global_pose.tolist()) 105 | 106 | bbox_labels.append(bbox.tolist()) 107 | landmark_labels.append(self.landmarks[index][i]) 108 | 109 | with open(image_path, "rb") as f: 110 | raw_img = f.read() 111 | 112 | return ( 113 | raw_img, 114 | global_pose_labels, 115 | bbox_labels, 116 | pose_labels, 117 | landmark_labels, 118 | ) 119 | 120 | 121 | class JsonLoader(DataLoader): 122 | def __init__( 123 | self, workers, json_list, threed_5_points, threed_68_points, dataset_path=None 124 | ): 125 | self._dataset = FramesJsonList( 126 | threed_5_points, threed_68_points, json_list, dataset_path 127 | ) 128 | 129 | super(JsonLoader, self).__init__( 130 | self._dataset, num_workers=workers, collate_fn=lambda x: x 131 | ) 132 | -------------------------------------------------------------------------------- /utils/json_loader_300wlp.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from PIL import Image 8 | from scipy.spatial.transform import Rotation 9 | from torch.utils.data import DataLoader 10 | from torchvision.datasets import ImageFolder 11 | from tqdm import tqdm 12 | 13 | from .image_operations import bbox_is_dict, expand_bbox_rectangle 14 | from .pose_operations import get_pose, pose_full_image_to_bbox 15 | 16 | 17 | class FramesJsonList(ImageFolder): 18 | def __init__(self, threed_5_points, threed_68_points, json_list, dataset_path=None): 19 | self.samples = [] 20 | self.bboxes = [] 21 | self.landmarks = [] 22 | self.poses_para = [] 23 | self.threed_5_points = threed_5_points 24 | self.threed_68_points = threed_68_points 25 | self.dataset_path = dataset_path 26 | 27 | image_paths = pd.read_csv(json_list, delimiter=",", header=None) 28 | image_paths = np.asarray(image_paths).squeeze() 29 | 30 | print("Loading frames paths...") 31 | for image_path in tqdm(image_paths): 32 | with open(image_path) as f: 33 | image_json = json.load(f) 34 | 35 | # path to the image 36 | img_path = image_json["image_path"] 37 | # if not absolute path, append the dataset path 38 | if self.dataset_path is not None: 39 | img_path = os.path.join(self.dataset_path, img_path) 40 | self.samples.append(img_path) 41 | 42 | # landmarks used to create pose labels 43 | self.landmarks.append(image_json["landmarks"]) 44 | 45 | # load bboxes 46 | self.bboxes.append(image_json["bboxes"]) 47 | 48 | # load bboxes 49 | self.poses_para.append(image_json["pose_para"]) 50 | 51 | def __len__(self): 52 | return len(self.samples) 53 | 54 | def __getitem__(self, index): 55 | image_path = Path(self.samples[index]) 56 | 57 | img = Image.open(image_path) 58 | 59 | (img_w, img_h) = img.size 
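# The 3x3 matrix built next is the ad-hoc full-image pinhole intrinsics used
# throughout this repo when no real camera calibration is available: the focal
# length is approximated as img_w + img_h and the principal point is the image
# center (e.g. a 450x450 image gives f = 900 and principal point (225, 225)).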
60 | global_intrinsics = np.array( 61 | [[img_w + img_h, 0, img_w // 2], [0, img_w + img_h, img_h // 2], [0, 0, 1]] 62 | ) 63 | bboxes = self.bboxes[index] 64 | landmarks = self.landmarks[index] 65 | pose_para = self.poses_para[index] 66 | 67 | bbox_labels = [] 68 | landmark_labels = [] 69 | pose_labels = [] 70 | global_pose_labels = [] 71 | 72 | for i in range(len(bboxes)): 73 | bbox = np.asarray(bboxes[i])[:4].astype(int) 74 | landmark = np.asarray(landmarks[i])[:, :2].astype(float) 75 | pose_para = np.asarray(pose_para[i])[:3].astype(float) 76 | 77 | # remove samples that do not have height ot width or are negative 78 | if bbox[0] >= bbox[2] or bbox[1] >= bbox[3]: 79 | continue 80 | 81 | if -1 in landmark: 82 | global_pose_labels.append([-9, -9, -9, -9, -9, -9]) 83 | pose_labels.append([-9, -9, -9, -9, -9, -9]) 84 | 85 | else: 86 | P, pose = get_pose( 87 | self.threed_68_points, 88 | landmark, 89 | global_intrinsics, 90 | ) 91 | 92 | pose[:3] = self.convert_aflw(pose_para) 93 | global_pose_labels.append(pose.tolist()) 94 | 95 | projected_bbox = expand_bbox_rectangle( 96 | img_w, img_h, 1.1, 1.1, landmark, roll=pose[2] 97 | ) 98 | 99 | local_pose = pose_full_image_to_bbox( 100 | pose, 101 | global_intrinsics, 102 | bbox_is_dict(projected_bbox), 103 | ) 104 | 105 | pose_labels.append(local_pose.tolist()) 106 | 107 | bbox_labels.append(projected_bbox.tolist()) 108 | landmark_labels.append(self.landmarks[index][i]) 109 | 110 | with open(image_path, "rb") as f: 111 | raw_img = f.read() 112 | 113 | return ( 114 | raw_img, 115 | global_pose_labels, 116 | bbox_labels, 117 | pose_labels, 118 | landmark_labels, 119 | ) 120 | 121 | def convert_aflw(self, angle): 122 | rot_mat_1 = Rotation.from_euler( 123 | "xyz", [angle[0], -angle[1], -angle[2]], degrees=False 124 | ).as_matrix() 125 | rot_mat_2 = np.transpose(rot_mat_1) 126 | return Rotation.from_matrix(rot_mat_2).as_rotvec() 127 | 128 | 129 | class JsonLoader(DataLoader): 130 | def __init__( 131 | self, workers, json_list, threed_5_points, threed_68_points, dataset_path=None 132 | ): 133 | self._dataset = FramesJsonList( 134 | threed_5_points, threed_68_points, json_list, dataset_path 135 | ) 136 | 137 | super(JsonLoader, self).__init__( 138 | self._dataset, num_workers=workers, collate_fn=lambda x: x 139 | ) 140 | -------------------------------------------------------------------------------- /utils/pose_operations.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | from scipy.spatial.transform import Rotation 6 | 7 | from .face_align import norm_crop 8 | from .image_operations import bbox_is_dict, expand_bbox_rectangle 9 | 10 | 11 | def bbox_dict_to_np(bbox): 12 | bbox_np = np.zeros(shape=4) 13 | bbox_np[0] = bbox["left"] 14 | bbox_np[1] = bbox["top"] 15 | bbox_np[2] = bbox["right"] 16 | bbox_np[3] = bbox["bottom"] 17 | 18 | return bbox_np 19 | 20 | 21 | def quat_to_rotation_mat_tensor(quat): 22 | x = quat[0] 23 | y = quat[1] 24 | z = quat[2] 25 | w = quat[3] 26 | x2 = x * x 27 | y2 = y * y 28 | z2 = z * z 29 | w2 = w * w 30 | xy = x * y 31 | zw = z * w 32 | xz = x * z 33 | yw = y * w 34 | yz = y * z 35 | xw = x * w 36 | matrix = torch.zeros(3, 3).to(quat.device) 37 | matrix[0, 0] = x2 - y2 - z2 + w2 38 | matrix[1, 0] = 2 * (xy + zw) 39 | matrix[2, 0] = 2 * (xz - yw) 40 | matrix[0, 1] = 2 * (xy - zw) 41 | matrix[1, 1] = -x2 + y2 - z2 + w2 42 | matrix[2, 1] = 2 * (yz + xw) 43 | matrix[0, 2] = 2 * (xz + yw) 44 | matrix[1, 2] = 2 
* (yz - xw) 45 | matrix[2, 2] = -x2 - y2 + z2 + w2 46 | return matrix 47 | 48 | 49 | def from_rotvec_tensor(rotvec): 50 | norm = torch.norm(rotvec) 51 | small_angle = norm <= 1e-3 52 | scale = 0 53 | if small_angle: 54 | scale = 0.5 - norm ** 2 / 48 + norm ** 4 / 3840 55 | else: 56 | scale = torch.sin(norm / 2) / norm 57 | quat = torch.zeros(4).to(rotvec.device) 58 | quat[0:3] = scale * rotvec 59 | quat[3] = torch.cos(norm / 2) 60 | 61 | return quat_to_rotation_mat_tensor(quat) 62 | 63 | 64 | def transform_points_tensor(points, pose): 65 | return torch.matmul(points, from_rotvec_tensor(pose[:3]).T) + pose[3:] 66 | 67 | 68 | def get_bbox_intrinsics(image_intrinsics, bbox): 69 | # crop principle point of view 70 | bbox_center_x = bbox["left"] + ((bbox["right"] - bbox["left"]) // 2) 71 | bbox_center_y = bbox["top"] + ((bbox["bottom"] - bbox["top"]) // 2) 72 | 73 | # create a camera intrinsics from the bbox center 74 | bbox_intrinsics = image_intrinsics.copy() 75 | bbox_intrinsics[0, 2] = bbox_center_x 76 | bbox_intrinsics[1, 2] = bbox_center_y 77 | 78 | return bbox_intrinsics 79 | 80 | 81 | def get_bbox_intrinsics_np(image_intrinsics, bbox): 82 | # crop principle point of view 83 | bbox_center_x = bbox[0] + ((bbox[2] - bbox[0]) // 2) 84 | bbox_center_y = bbox[1] + ((bbox[3] - bbox[1]) // 2) 85 | 86 | # create a camera intrinsics from the bbox center 87 | bbox_intrinsics = image_intrinsics.copy() 88 | bbox_intrinsics[0, 2] = bbox_center_x 89 | bbox_intrinsics[1, 2] = bbox_center_y 90 | 91 | return bbox_intrinsics 92 | 93 | 94 | def pose_full_image_to_bbox(pose, image_intrinsics, bbox): 95 | # check if bbox is np or dict 96 | bbox = bbox_is_dict(bbox) 97 | 98 | # rotation vector 99 | rvec = pose[:3].copy() 100 | 101 | # translation and scale vector 102 | tvec = pose[3:].copy() 103 | 104 | # get camera intrinsics using bbox 105 | bbox_intrinsics = get_bbox_intrinsics(image_intrinsics, bbox) 106 | 107 | # focal length 108 | focal_length = image_intrinsics[0, 0] 109 | 110 | # bbox_size 111 | bbox_width = bbox["right"] - bbox["left"] 112 | bbox_height = bbox["bottom"] - bbox["top"] 113 | bbox_size = bbox_width + bbox_height 114 | 115 | # project crop points using the full image camera intrinsics 116 | projected_point = image_intrinsics.dot(tvec.T) 117 | 118 | # reverse the projected points using the crop camera intrinsics 119 | tvec = projected_point.dot(np.linalg.inv(bbox_intrinsics.T)) 120 | 121 | # adjust scale 122 | tvec[2] /= focal_length / bbox_size 123 | 124 | # same for rotation 125 | rmat = Rotation.from_rotvec(rvec).as_matrix() 126 | # project crop points using the crop camera intrinsics 127 | projected_point = image_intrinsics.dot(rmat) 128 | # reverse the projected points using the full image camera intrinsics 129 | rmat = np.linalg.inv(bbox_intrinsics).dot(projected_point) 130 | rvec = Rotation.from_matrix(rmat).as_rotvec() 131 | 132 | return np.concatenate([rvec, tvec]) 133 | 134 | 135 | def pose_bbox_to_full_image(pose, image_intrinsics, bbox): 136 | # check if bbox is np or dict 137 | bbox = bbox_is_dict(bbox) 138 | 139 | # rotation vector 140 | rvec = pose[:3].copy() 141 | 142 | # translation and scale vector 143 | tvec = pose[3:].copy() 144 | 145 | # get camera intrinsics using bbox 146 | bbox_intrinsics = get_bbox_intrinsics(image_intrinsics, bbox) 147 | 148 | # focal length 149 | focal_length = image_intrinsics[0, 0] 150 | 151 | # bbox_size 152 | bbox_width = bbox["right"] - bbox["left"] 153 | bbox_height = bbox["bottom"] - bbox["top"] 154 | bbox_size = bbox_width + bbox_height 
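# The crop-level pose was estimated with a camera whose focal length follows the
# width + height convention of the box, while the full-image camera uses image
# width + height; the depth component tvec[2] is therefore rescaled by
# focal_length / bbox_size below, and the principal-point shift is handled through
# bbox_intrinsics. This is the inverse of the adjustment in pose_full_image_to_bbox above.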
155 | 156 | # adjust scale 157 | tvec[2] *= focal_length / bbox_size 158 | 159 | # project crop points using the crop camera intrinsics 160 | projected_point = bbox_intrinsics.dot(tvec.T) 161 | 162 | # reverse the projected points using the full image camera intrinsics 163 | tvec = projected_point.dot(np.linalg.inv(image_intrinsics.T)) 164 | 165 | # same for rotation 166 | rmat = Rotation.from_rotvec(rvec).as_matrix() 167 | # project crop points using the crop camera intrinsics 168 | projected_point = bbox_intrinsics.dot(rmat) 169 | # reverse the projected points using the full image camera intrinsics 170 | rmat = np.linalg.inv(image_intrinsics).dot(projected_point) 171 | rvec = Rotation.from_matrix(rmat).as_rotvec() 172 | 173 | return np.concatenate([rvec, tvec]) 174 | 175 | 176 | def plot_3d_landmark(verts, campose, intrinsics): 177 | lm_3d_trans = transform_points(verts, campose) 178 | 179 | # project to image plane 180 | lms_3d_trans_proj = intrinsics.dot(lm_3d_trans.T).T 181 | lms_projected = ( 182 | lms_3d_trans_proj[:, :2] / np.tile(lms_3d_trans_proj[:, 2], (2, 1)).T 183 | ) 184 | 185 | return lms_projected, lms_3d_trans_proj 186 | 187 | 188 | def plot_3d_landmark_torch(verts, campose, intrinsics): 189 | lm_3d_trans = transform_points_tensor(verts, campose) 190 | 191 | # project to image plane 192 | lms_3d_trans_proj = torch.matmul(intrinsics, lm_3d_trans.T).T 193 | lms_projected = lms_3d_trans_proj[:, :2] / lms_3d_trans_proj[:, 2].repeat(2, 1).T 194 | 195 | return lms_projected 196 | 197 | 198 | def transform_points(points, pose): 199 | return points.dot(Rotation.from_rotvec(pose[:3]).as_matrix().T) + pose[3:] 200 | 201 | 202 | def get_pose(vertices, twod_landmarks, camera_intrinsics, initial_pose=None): 203 | threed_landmarks = vertices 204 | twod_landmarks = np.asarray(twod_landmarks).astype("float32") 205 | 206 | # if initial_pose is provided, use it as a guess to solve new pose 207 | if initial_pose is not None: 208 | initial_pose = np.asarray(initial_pose) 209 | retval, rvecs, tvecs = cv2.solvePnP( 210 | threed_landmarks, 211 | twod_landmarks, 212 | camera_intrinsics, 213 | None, 214 | rvec=initial_pose[:3], 215 | tvec=initial_pose[3:], 216 | flags=cv2.SOLVEPNP_EPNP, 217 | useExtrinsicGuess=True, 218 | ) 219 | else: 220 | retval, rvecs, tvecs = cv2.solvePnP( 221 | threed_landmarks, 222 | twod_landmarks, 223 | camera_intrinsics, 224 | None, 225 | flags=cv2.SOLVEPNP_EPNP, 226 | ) 227 | 228 | rotation_mat = np.zeros(shape=(3, 3)) 229 | R = cv2.Rodrigues(rvecs, rotation_mat)[0] 230 | 231 | RT = np.column_stack((R, tvecs)) 232 | P = np.matmul(camera_intrinsics, RT) 233 | dof = np.append(rvecs, tvecs) 234 | 235 | return P, dof 236 | 237 | 238 | def transform_pose_global_project_bbox( 239 | boxes, 240 | dofs, 241 | pose_mean, 242 | pose_stddev, 243 | image_shape, 244 | threed_68_points=None, 245 | bbox_x_factor=1.1, 246 | bbox_y_factor=1.1, 247 | expand_forehead=0.3, 248 | ): 249 | if len(dofs) == 0: 250 | return boxes, dofs 251 | 252 | device = dofs.device 253 | 254 | boxes = boxes.cpu().numpy() 255 | dofs = dofs.cpu().numpy() 256 | 257 | threed_68_points = threed_68_points.numpy() 258 | 259 | (h, w) = image_shape 260 | global_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]]) 261 | 262 | if threed_68_points is not None: 263 | threed_68_points = threed_68_points 264 | 265 | pose_mean = pose_mean.numpy() 266 | pose_stddev = pose_stddev.numpy() 267 | 268 | dof_mean = pose_mean 269 | dof_std = pose_stddev 270 | dofs = dofs * dof_std + dof_mean 271 | 272 | 
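# At this point dofs holds de-normalized 6DoF poses (the predictions are assumed to be
# normalized by the training-set pose mean/std, hence the multiply-add above). The loop
# below converts each proposal's pose from its box-local camera to the full-image camera
# and, when the 68 3D reference points are provided, projects them to build a
# pose-aligned, expanded face box.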
projected_boxes = [] 273 | global_dofs = [] 274 | 275 | for i in range(dofs.shape[0]): 276 | global_dof = pose_bbox_to_full_image(dofs[i], global_intrinsics, boxes[i]) 277 | global_dofs.append(global_dof) 278 | 279 | if threed_68_points is not None: 280 | # project points and get bbox 281 | projected_lms, _ = plot_3d_landmark( 282 | threed_68_points, global_dof, global_intrinsics 283 | ) 284 | projected_bbox = expand_bbox_rectangle( 285 | w, 286 | h, 287 | bbox_x_factor=bbox_x_factor, 288 | bbox_y_factor=bbox_y_factor, 289 | lms=projected_lms, 290 | roll=global_dof[2], 291 | expand_forehead=expand_forehead, 292 | ) 293 | else: 294 | projected_bbox = boxes[i] 295 | 296 | projected_boxes.append(projected_bbox) 297 | 298 | global_dofs = torch.from_numpy(np.asarray(global_dofs)).float() 299 | projected_boxes = torch.from_numpy(np.asarray(projected_boxes)).float() 300 | 301 | return projected_boxes.to(device), global_dofs.to(device) 302 | 303 | 304 | def align_faces(threed_5_points, img, poses, face_size=224): 305 | if len(poses) == 0: 306 | return None 307 | elif np.ndim(poses) == 1: 308 | poses = poses[np.newaxis, :] 309 | 310 | (w, h) = img.size 311 | global_intrinsics = np.array([[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]]) 312 | 313 | faces_aligned = [] 314 | 315 | for pose in poses: 316 | proj_lms, _ = plot_3d_landmark( 317 | threed_5_points, np.asarray(pose), global_intrinsics 318 | ) 319 | face_aligned = norm_crop(np.asarray(img).copy(), proj_lms, face_size) 320 | faces_aligned.append(Image.fromarray(face_aligned)) 321 | 322 | return faces_aligned 323 | -------------------------------------------------------------------------------- /utils/renderer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from Sim3DR import RenderPipeline 4 | 5 | from .pose_operations import plot_3d_landmark 6 | 7 | 8 | def _to_ctype(arr): 9 | if not arr.flags.c_contiguous: 10 | return arr.copy(order="C") 11 | return arr 12 | 13 | 14 | def get_colors(img, ver): 15 | h, w, _ = img.shape 16 | ver[0, :] = np.minimum(np.maximum(ver[0, :], 0), w - 1) # x 17 | ver[1, :] = np.minimum(np.maximum(ver[1, :], 0), h - 1) # y 18 | ind = np.round(ver).astype(np.int32) 19 | colors = img[ind[1, :], ind[0, :], :] / 255.0 # n x 3 20 | 21 | return colors.copy() 22 | 23 | 24 | class Renderer: 25 | def __init__( 26 | self, 27 | vertices_path="../pose_references/vertices_trans.npy", 28 | triangles_path="../pose_references/triangles.npy", 29 | ): 30 | self.vertices = np.load(vertices_path) 31 | self.triangles = _to_ctype(np.load(triangles_path).T) 32 | self.vertices[:, 0] *= -1 33 | 34 | self.cfg = { 35 | "intensity_ambient": 0.3, 36 | "color_ambient": (1, 1, 1), 37 | "intensity_directional": 0.6, 38 | "color_directional": (1, 1, 1), 39 | "intensity_specular": 0.1, 40 | "specular_exp": 5, 41 | "light_pos": (0, 0, 5), 42 | "view_pos": (0, 0, 5), 43 | } 44 | 45 | self.render_app = RenderPipeline(**self.cfg) 46 | 47 | def transform_vertices(self, img, poses, global_intrinsics=None): 48 | (w, h) = img.size 49 | if global_intrinsics is None: 50 | global_intrinsics = np.array( 51 | [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]] 52 | ) 53 | 54 | transformed_vertices = [] 55 | for pose in poses: 56 | projected_lms = np.zeros_like(self.vertices) 57 | projected_lms[:, :2], lms_3d_trans_proj = plot_3d_landmark( 58 | self.vertices, pose, global_intrinsics 59 | ) 60 | projected_lms[:, 2] = lms_3d_trans_proj[:, 2] * -1 61 | 62 | range_x = 
np.max(projected_lms[:, 0]) - np.min(projected_lms[:, 0]) 63 | range_y = np.max(projected_lms[:, 1]) - np.min(projected_lms[:, 1]) 64 | 65 | s = (h + w) / pose[5] 66 | projected_lms[:, 2] *= s 67 | projected_lms[:, 2] += (range_x + range_y) * 3 68 | 69 | transformed_vertices.append(projected_lms) 70 | 71 | return transformed_vertices 72 | 73 | def render(self, img, transformed_vertices, alpha=0.9, save_path=None): 74 | img = np.asarray(img) 75 | overlap = img.copy() 76 | 77 | for vertices in transformed_vertices: 78 | vertices = _to_ctype(vertices) # ensure C-contiguous layout for the rasterizer 79 | overlap = self.render_app(vertices, self.triangles, overlap) 80 | 81 | res = cv2.addWeighted(img, 1 - alpha, overlap, alpha, 0) 82 | 83 | if save_path is not None: 84 | cv2.imwrite(save_path, res) 85 | print(f"Save visualization result to {save_path}") 86 | 87 | return res 88 | 89 | def save_to_obj(self, img, ver_lst, height, save_path): 90 | n_obj = len(ver_lst) # number of meshes to write 91 | 92 | if n_obj <= 0: 93 | return 94 | 95 | n_vertex = ver_lst[0].T.shape[1] 96 | n_face = self.triangles.shape[0] 97 | 98 | with open(save_path, "w") as f: 99 | for i in range(n_obj): 100 | ver = ver_lst[i].T 101 | colors = get_colors(img, ver) 102 | 103 | for j in range(n_vertex): 104 | x, y, z = ver[:, j] 105 | f.write( 106 | f"v {x:.2f} {height - y:.2f} {z:.2f} {colors[j, 2]:.2f} " 107 | f"{colors[j, 1]:.2f} {colors[j, 0]:.2f}\n" 108 | ) 109 | 110 | for i in range(n_obj): 111 | offset = i * n_vertex 112 | for j in range(n_face): 113 | idx1, idx2, idx3 = self.triangles[j] # m x 3 114 | f.write( 115 | f"f {idx3 + 1 + offset} {idx2 + 1 + offset} " 116 | f"{idx1 + 1 + offset}\n" 117 | ) 118 | 119 | print(f"Dump to {save_path}") 120 | --------------------------------------------------------------------------------
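To make the relationship between these utilities concrete, the following is a minimal usage sketch, not a file from the repository. It assumes it is run from the repository root, that `face.jpg` and a (68, 2) landmark array in `landmarks_68.npy` are placeholder inputs supplied by the reader, and that the files under `pose_references/` are the ones shipped with the repo.

import numpy as np
from PIL import Image

from utils.pose_operations import get_pose
from utils.renderer import Renderer

# Hypothetical inputs: an image and its 68 2D facial landmarks (x, y).
img = Image.open("face.jpg")
landmarks_2d = np.load("landmarks_68.npy")  # shape (68, 2)

# Reference 3D landmarks shipped in pose_references/.
threed_68_points = np.load("pose_references/reference_3d_68_points_trans.npy")

# Same full-image intrinsics convention used above: focal length = w + h,
# principal point at the image center (float matrix for cv2.solvePnP).
w, h = img.size
global_intrinsics = np.array(
    [[w + h, 0, w // 2], [0, w + h, h // 2], [0, 0, 1]], dtype=np.float64
)

# 6DoF pose (rotation vector + translation) from the 2D-3D correspondences.
P, pose_6dof = get_pose(threed_68_points, landmarks_2d, global_intrinsics)

# Overlay the reference 3D face mesh at the estimated pose.
renderer = Renderer(
    vertices_path="pose_references/vertices_trans.npy",
    triangles_path="pose_references/triangles.npy",
)
trans_vertices = renderer.transform_vertices(img, [pose_6dof])
renderer.render(img, trans_vertices, alpha=0.9, save_path="overlay.jpg")

Because get_pose is solved here directly with the full-image intrinsics, the resulting pose is already in the global frame that transform_vertices expects; a pose estimated on a crop would first need to go through pose_bbox_to_full_image.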