├── BFM ├── .gitkeep ├── BFM_exp_idx.mat ├── BFM_front_idx.mat ├── facemodel_info.mat ├── select_vertex_id.mat ├── similarity_Lm3D_all.mat └── std_exp.txt ├── LICENSE ├── README.md ├── data ├── __init__.py ├── base_dataset.py ├── flist_dataset.py ├── image_folder.py └── template_dataset.py ├── data_preparation.py ├── datasets └── examples │ ├── 000002.jpg │ ├── 000006.jpg │ ├── 000007.jpg │ ├── 000031.jpg │ ├── 000033.jpg │ ├── 000037.jpg │ ├── 000050.jpg │ ├── 000055.jpg │ ├── 000114.jpg │ ├── 000125.jpg │ ├── 000126.jpg │ ├── 015259.jpg │ ├── 015270.jpg │ ├── 015309.jpg │ ├── 015310.jpg │ ├── 015316.jpg │ ├── 015384.jpg │ ├── detections │ ├── 000002.txt │ ├── 000006.txt │ ├── 000007.txt │ ├── 000031.txt │ ├── 000033.txt │ ├── 000037.txt │ ├── 000050.txt │ ├── 000055.txt │ ├── 000114.txt │ ├── 000125.txt │ ├── 000126.txt │ ├── 015259.txt │ ├── 015270.txt │ ├── 015309.txt │ ├── 015310.txt │ ├── 015316.txt │ ├── 015384.txt │ ├── vd006.txt │ ├── vd025.txt │ ├── vd026.txt │ ├── vd034.txt │ ├── vd051.txt │ ├── vd070.txt │ ├── vd092.txt │ └── vd102.txt │ ├── vd006.png │ ├── vd025.png │ ├── vd026.png │ ├── vd034.png │ ├── vd051.png │ ├── vd070.png │ ├── vd092.png │ └── vd102.png ├── environment.yml ├── images ├── 20230425_compare.png ├── compare.png └── example.gif ├── models ├── __init__.py ├── base_model.py ├── bfm.py ├── facerecon_model.py ├── losses.py ├── networks.py └── template_model.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── test.py ├── train.py └── util ├── BBRegressorParam_r.mat ├── __init__.py ├── detect_lm68.py ├── generate_list.py ├── html.py ├── load_mats.py ├── nvdiffrast.py ├── preprocess.py ├── skin_mask.py ├── test_mean_face.txt ├── util.py └── visualizer.py /BFM/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/BFM/.gitkeep -------------------------------------------------------------------------------- /BFM/BFM_exp_idx.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/BFM/BFM_exp_idx.mat -------------------------------------------------------------------------------- /BFM/BFM_front_idx.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/BFM/BFM_front_idx.mat -------------------------------------------------------------------------------- /BFM/facemodel_info.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/BFM/facemodel_info.mat -------------------------------------------------------------------------------- /BFM/select_vertex_id.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/BFM/select_vertex_id.mat -------------------------------------------------------------------------------- /BFM/similarity_Lm3D_all.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/BFM/similarity_Lm3D_all.mat 
-------------------------------------------------------------------------------- /BFM/std_exp.txt: -------------------------------------------------------------------------------- 1 | 453980 257264 263068 211890 135873 184721 47055.6 72732 62787.4 106226 56708.5 51439.8 34887.1 44378.7 51813.4 31030.7 23354.9 23128.1 19400 21827.6 22767.7 22057.4 19894.3 16172.8 17142.7 10035.3 14727.5 12972.5 10763.8 8953.93 8682.62 8941.81 6342.3 5205.3 7065.65 6083.35 6678.88 4666.63 5082.89 5134.76 4908.16 3964.93 3739.95 3180.09 2470.45 1866.62 1624.71 2423.74 1668.53 1471.65 1194.52 782.102 815.044 835.782 834.937 744.496 575.146 633.76 705.685 753.409 620.306 673.326 766.189 619.866 559.93 357.264 396.472 556.849 455.048 460.592 400.735 326.702 279.428 291.535 326.584 305.664 287.816 283.642 276.19 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Sicheng Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Accurate 3D Face Reconstruction with Weakly-Supervised Learning: From Single Image to Image Set —— PyTorch implementation ## 2 | 3 |

4 | ![Example reconstruction results](images/example.gif) 5 |

6 | 7 | This is an unofficial official PyTorch implementation of the following paper: 8 | 9 | Y. Deng, J. Yang, S. Xu, D. Chen, Y. Jia, and X. Tong, [Accurate 3D Face Reconstruction with Weakly-Supervised Learning: From Single Image to Image Set](https://arxiv.org/abs/1903.08527), IEEE Computer Vision and Pattern Recognition Workshop (CVPRW) on Analysis and Modeling of Faces and Gestures (AMFG), 2019. (**_Best Paper Award!_**) 10 | 11 | The method enforces hybrid-level weakly-supervised training for CNN-based 3D face reconstruction. It is fast, accurate, and robust to pose and occlusions. It achieves state-of-the-art performance on multiple datasets such as FaceWarehouse, MICC Florence, and the NoW Challenge. 12 | 13 | 14 | For the original TensorFlow implementation, check this [repo](https://github.com/microsoft/Deep3DFaceReconstruction). 15 | 16 | This implementation is written by S. Xu. 17 | ## 04/25/2023 Update 18 | We released an updated model to improve the results on "closed eye" images. We collected ~2K facial images with closed eyes and included them in the training data. The updated model has reconstruction accuracy similar to the previous one on the benchmarks, but produces better results for faces with closed eyes (see below). Here is the [link (google drive)](https://drive.google.com/drive/folders/1grs8J4vu7gOhEClyKjWU-SNxfonGue5F?usp=share_link) to the new model. 19 | ### ● Reconstruction accuracy 20 | 21 | |Method|FaceWarehouse|MICC Florence| 22 | |:----:|:-----------:|:-----------:| 23 | |Deep3DFace_PyTorch_20230425|1.60±0.44|1.54±0.49| 24 | 25 | ### ● Visual quality 26 |

27 | ![Comparison on images with closed eyes](images/20230425_compare.png) 28 |

29 | 30 | ## Performance 31 | 32 | ### ● Reconstruction accuracy 33 | 34 | The PyTorch implementation achieves a lower shape reconstruction error (a 9% improvement) compared with the [original TensorFlow implementation](https://github.com/microsoft/Deep3DFaceReconstruction). Quantitative evaluation (average shape errors in mm) on several benchmarks is as follows: 35 | 36 | |Method|FaceWarehouse|MICC Florence | NoW Challenge | 37 | |:----:|:-----------:|:-----------:|:-----------:| 38 | |Deep3DFace Tensorflow | 1.81±0.50 | 1.67±0.50 | 1.54±1.29 | 39 | |**Deep3DFace PyTorch** |**1.64±0.50**|**1.53±0.45**| **1.41±1.21** | 40 | 41 | A comparison with state-of-the-art public 3D face reconstruction methods on the NoW face benchmark is as follows: 42 | |Rank|Method|Median(mm) | Mean(mm) | Std(mm) | 43 | |:----:|:-----------:|:-----------:|:-----------:|:-----------:| 44 | | 1. | [DECA\[Feng et al., SIGGRAPH 2021\]](https://github.com/YadiraF/DECA)|1.09|1.38|1.18| 45 | | **2.** | **Deep3DFace PyTorch**|**1.11**|**1.41**|**1.21**| 46 | | 3. | [RingNet [Sanyal et al., CVPR 2019]](https://github.com/soubhiksanyal/RingNet) | 1.21 | 1.53 | 1.31 | 47 | | 4. | [Deep3DFace [Deng et al., CVPRW 2019]](https://github.com/microsoft/Deep3DFaceReconstruction) | 1.23 | 1.54 | 1.29 | 48 | | 5. | [3DDFA-V2 [Guo et al., ECCV 2020]](https://github.com/cleardusk/3DDFA_V2) | 1.23 | 1.57 | 1.39 | 49 | | 6. | [MGCNet [Shang et al., ECCV 2020]](https://github.com/jiaxiangshang/MGCNet) | 1.31 | 1.87 | 2.63 | 50 | | 7. | [PRNet [Feng et al., ECCV 2018]](https://github.com/YadiraF/PRNet) | 1.50 | 1.98 | 1.88 | 51 | | 8. | [3DMM-CNN [Tran et al., CVPR 2017]](https://github.com/anhttran/3dmm_cnn) | 1.84 | 2.33 | 2.05 | 52 | 53 | For more details about the evaluation, check the [NoW Challenge](https://ringnet.is.tue.mpg.de/challenge.html) website. 54 | 55 | **_A recent benchmark, [REALY](https://www.realy3dface.com/), indicates that our method still achieves state-of-the-art performance. Check their paper and website for more details._** 56 | 57 | ### ● Visual quality 58 | The PyTorch implementation achieves better visual consistency with the input images compared with the original TensorFlow version. 59 | 60 |

61 | ![Visual comparison with the original TensorFlow implementation](images/compare.png) 62 |

63 | 64 | ### ● Speed 65 | The training speed is on par with the original TensorFlow implementation. For more information, see [here](https://github.com/sicxu/Deep3DFaceRecon_pytorch#train-the-face-reconstruction-network). 66 | 67 | ## Major changes 68 | 69 | ### ● Differentiable renderer 70 | 71 | We use [Nvdiffrast](https://nvlabs.github.io/nvdiffrast/), a PyTorch library that provides high-performance primitive operations for rasterization-based differentiable rendering. The original TensorFlow implementation used [tf_mesh_renderer](https://github.com/google/tf_mesh_renderer) instead. 72 | 73 | ### ● Face recognition model 74 | 75 | We use [Arcface](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch), a state-of-the-art face recognition model, for perceptual loss computation. By contrast, the original TensorFlow implementation used [Facenet](https://github.com/davidsandberg/facenet). 76 | 77 | ### ● Training configuration 78 | 79 | Data augmentation is used during training and includes random image shifting, scaling, rotation, and flipping. We also enlarge the training batch size from 5 to 32 to stabilize the training process. 80 | 81 | ### ● Training data 82 | 83 | We use an additional high-quality face image dataset, [FFHQ](https://github.com/NVlabs/ffhq-dataset), to increase the diversity of the training data. 84 | 85 | ## Requirements 86 | **This implementation has only been tested on Ubuntu with Nvidia GPUs and CUDA installed.** It should also work on Windows with a proper library configuration. 87 | 88 | ## Installation 89 | 1. Clone the repository and set up a conda environment with all dependencies as follows: 90 | ``` 91 | git clone https://github.com/sicxu/Deep3DFaceRecon_pytorch.git 92 | cd Deep3DFaceRecon_pytorch 93 | conda env create -f environment.yml 94 | source activate deep3d_pytorch 95 | ``` 96 | 97 | 2. Install the Nvdiffrast library: 98 | ``` 99 | git clone -b 0.3.0 https://github.com/NVlabs/nvdiffrast 100 | cd nvdiffrast # ./Deep3DFaceRecon_pytorch/nvdiffrast 101 | pip install . 102 | ``` 103 | 104 | 3. Install Arcface PyTorch: 105 | ``` 106 | cd .. # ./Deep3DFaceRecon_pytorch 107 | git clone https://github.com/deepinsight/insightface.git 108 | cp -r ./insightface/recognition/arcface_torch ./models/ 109 | ```
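Optionally, you can sanity-check the environment before moving on. This is a minimal sketch, not a script shipped with this repo; `RasterizeCudaContext` assumes nvdiffrast >= 0.3.0 and a CUDA-capable GPU:
```
# Minimal environment check; run inside the deep3d_pytorch conda env.
import torch
import nvdiffrast.torch as dr

print('torch:', torch.__version__, '| CUDA available:', torch.cuda.is_available())

# Creating a rasterization context verifies that nvdiffrast was built correctly.
# Use dr.RasterizeGLContext() instead if you rely on an OpenGL setup.
if torch.cuda.is_available():
    ctx = dr.RasterizeCudaContext()
    print('nvdiffrast context:', type(ctx).__name__)
```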
110 | ## Inference with a pre-trained model 111 | 112 | ### Prepare prerequisite models 113 | 1. Our method uses the [Basel Face Model 2009 (BFM09)](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-0&id=basel_face_model) to represent 3D faces. Request access to BFM09 using this [link](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-2&id=downloads). After getting access, download "01_MorphableModel.mat". In addition, we use an Expression Basis provided by [Guo et al.](https://github.com/Juyong/3DFace). Download the Expression Basis (Exp_Pca.bin) using this [link (google drive)](https://drive.google.com/file/d/1bw5Xf8C12pWmcMhNEu6PtsYVZkVucEN6/view?usp=sharing). Organize all files into the following structure: 114 | ``` 115 | Deep3DFaceRecon_pytorch 116 | │ 117 | └─── BFM 118 | │ 119 | └─── 01_MorphableModel.mat 120 | │ 121 | └─── Exp_Pca.bin 122 | | 123 | └─── ... 124 | ``` 125 | 2. We provide a model trained on a combination of the [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html), 126 | [LFW](http://vis-www.cs.umass.edu/lfw/), [300WLP](http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/main.htm), 127 | [IJB-A](https://www.nist.gov/programs-projects/face-challenges), [LS3D-W](https://www.adrianbulat.com/face-alignment), and [FFHQ](https://github.com/NVlabs/ffhq-dataset) datasets. Download the pre-trained model using this [link (google drive)](https://drive.google.com/drive/folders/1liaIxn9smpudjjqMaWWRpP0mXRW_qRPP?usp=sharing) and organize the directory into the following structure: 128 | ``` 129 | Deep3DFaceRecon_pytorch 130 | │ 131 | └─── checkpoints 132 | │ 133 | └─── <model_name> 134 | │ 135 | └─── epoch_20.pth 136 | 137 | ``` 138 | 139 | ### Test with custom images 140 | To reconstruct 3D faces from test images, organize the test image folder as follows: 141 | ``` 142 | Deep3DFaceRecon_pytorch 143 | │ 144 | └─── <folder_to_test_images> 145 | │ 146 | └─── *.jpg/*.png 147 | | 148 | └─── detections 149 | | 150 | └─── *.txt 151 | ``` 152 | The \*.jpg/\*.png files are test images. Each \*.txt file contains the 5 detected facial landmarks (a 5x2 array) and has the same name as its corresponding image. Check [./datasets/examples](datasets/examples) for a reference. 153 | 154 | Then, run the test script: 155 | ``` 156 | # get reconstruction results of your custom images 157 | python test.py --name=<model_name> --epoch=20 --img_folder=<folder_to_test_images> 158 | 159 | # get reconstruction results of example images 160 | python test.py --name=<model_name> --epoch=20 --img_folder=./datasets/examples 161 | ``` 162 | **_Following [#108](https://github.com/sicxu/Deep3DFaceRecon_pytorch/issues/108), if you don't have an OpenGL environment, you can simply add "--use_opengl False" to use a CUDA context instead. Make sure you have updated nvdiffrast to the latest version._** 163 | 164 | Results will be saved into ./checkpoints/<model_name>/results/<folder_to_test_images>, which contains the following files: 165 | | \*.png | A combination of the cropped input image, the reconstructed image, and a visualization of the projected landmarks. | 166 | |:----|:-----------| 167 | | \*.obj | Reconstructed 3D face mesh with predicted color (texture+illumination) in the world coordinate space. Best viewed in MeshLab. | 168 | | \*.mat | Predicted 257-dimensional coefficients and 68 projected 2D facial landmarks. Best viewed in MATLAB. | 169 |
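To inspect the saved results outside MATLAB, a short Python sketch like the following can be used. This is a hypothetical snippet, not part of this repo; the exact keys stored in the \*.mat file depend on the model's save routine, so print them instead of hard-coding names:
```
# Hypothetical inspection of saved results; run inside the deep3d_pytorch env.
import glob
from scipy.io import loadmat
import trimesh  # already listed in environment.yml

mat_files = sorted(glob.glob('./checkpoints/<model_name>/results/**/*.mat', recursive=True))
coeffs = loadmat(mat_files[0])
# Print every stored array and its shape; key names depend on the save code.
for key, value in coeffs.items():
    if not key.startswith('__'):
        print(key, getattr(value, 'shape', type(value)))

# The .obj mesh saved next to it can be loaded with trimesh for a quick check.
mesh = trimesh.load(mat_files[0].replace('.mat', '.obj'), process=False)
print('vertices:', mesh.vertices.shape, 'faces:', mesh.faces.shape)
```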
170 | ## Training a model from scratch 171 | ### Prepare prerequisite models 172 | 1. We rely on [Arcface](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch) to extract identity features for loss computation. Download the pre-trained model from Arcface using this [link](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch#ms1mv3). By default, we use the resnet50 backbone ([ms1mv3_arcface_r50_fp16](https://onedrive.live.com/?authkey=%21AFZjr283nwZHqbA&id=4A83B6B633B029CC%215583&cid=4A83B6B633B029CC)). Organize the downloaded files into the following structure: 173 | ``` 174 | Deep3DFaceRecon_pytorch 175 | │ 176 | └─── checkpoints 177 | │ 178 | └─── recog_model 179 | │ 180 | └─── ms1mv3_arcface_r50_fp16 181 | | 182 | └─── backbone.pth 183 | ``` 184 | 2. We initialize R-Net using the weights trained on [ImageNet](https://image-net.org/). Download the weights provided by PyTorch using this [link](https://download.pytorch.org/models/resnet50-0676ba61.pth), and organize the file into the following structure: 185 | ``` 186 | Deep3DFaceRecon_pytorch 187 | │ 188 | └─── checkpoints 189 | │ 190 | └─── init_model 191 | │ 192 | └─── resnet50-0676ba61.pth 193 | ``` 194 | 3. We provide a landmark detector (a TensorFlow model) to extract 68 facial landmarks for loss computation. The detector is trained on the [300WLP](http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/main.htm), [LFW](http://vis-www.cs.umass.edu/lfw/), and [LS3D-W](https://www.adrianbulat.com/face-alignment) datasets. Download the trained model using this [link (google drive)](https://drive.google.com/file/d/1Jl1yy2v7lIJLTRVIpgg2wvxYITI8Dkmw/view?usp=sharing) and organize the file as follows: 195 | ``` 196 | Deep3DFaceRecon_pytorch 197 | │ 198 | └─── checkpoints 199 | │ 200 | └─── lm_model 201 | │ 202 | └─── 68lm_detector.pb 203 | ``` 204 | ### Data preparation 205 | 1. To train a model with custom images, 5 facial landmarks for each image are needed in advance for an image pre-alignment process. We recommend using [dlib](http://dlib.net/) or [MTCNN](https://github.com/ipazc/mtcnn) to detect these landmarks (a minimal MTCNN sketch is included below, after the Contact section). Then, organize all files into the following structure: 206 | ``` 207 | Deep3DFaceRecon_pytorch 208 | │ 209 | └─── datasets 210 | │ 211 | └─── <folder_to_training_images> 212 | │ 213 | └─── *.png/*.jpg 214 | | 215 | └─── detections 216 | | 217 | └─── *.txt 218 | ``` 219 | The \*.txt files contain 5 facial landmarks with a shape of 5x2, and should have the same name as their corresponding images. 220 | 221 | 2. Generate 68 landmarks and skin attention masks for the images using the following script: 222 | ``` 223 | # preprocess training images 224 | python data_preparation.py --img_folder <folder_to_training_images> 225 | 226 | # alternatively, you can preprocess multiple image folders simultaneously 227 | python data_preparation.py --img_folder <folder_to_training_images_1> <folder_to_training_images_2> 228 | 229 | # preprocess validation images 230 | python data_preparation.py --img_folder <folder_to_validation_images> --mode=val 231 | ``` 232 | The script will generate files of landmarks and skin masks, and save them into ./datasets/<folder_to_training_images>. In addition, it generates a file containing the paths of all training data in ./datalist, which is then used by the training script. 233 | 234 | ### Train the face reconstruction network 235 | Run the following script to train a face reconstruction model using the pre-processed data: 236 | ``` 237 | # train with single GPU 238 | python train.py --name=<custom_experiment_name> --gpu_ids=0 239 | 240 | # train with multiple GPUs 241 | python train.py --name=<custom_experiment_name> --gpu_ids=0,1 242 | 243 | # train with other custom settings 244 | python train.py --name=<custom_experiment_name> --gpu_ids=0 --batch_size=32 --n_epochs=20 245 | ``` 246 | Training logs and model parameters will be saved into ./checkpoints/<custom_experiment_name>. 247 | 248 | By default, the script uses a batch size of 32 and trains the model for 20 epochs. For reference, the pre-trained model in this repo was trained with the default settings on an image collection of 300k images. A single iteration takes 0.8~0.9s on a single Tesla M40 GPU. The total training process takes around two days. 249 | 250 | To use a trained model, see the [Inference](https://github.com/sicxu/Deep3DFaceRecon_pytorch#inference-with-a-pre-trained-model) section. 251 | ## Contact 252 | If you have any questions, please contact the paper authors.
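As noted in the data preparation section above, the 5-point `detections/*.txt` files can be produced with an off-the-shelf detector. Below is a minimal, hypothetical sketch using the third-party `mtcnn` package (`pip install mtcnn`); it is not part of this repo, so verify the landmark order (left eye, right eye, nose, left mouth corner, right mouth corner) against the files in ./datasets/examples/detections before using it:
```
# Hypothetical helper: write 5-point landmark files in the layout this repo expects.
# Assumes the third-party `mtcnn` package and OpenCV; not part of this repository.
import os
import cv2
from mtcnn import MTCNN

def write_detections(img_folder):
    detector = MTCNN()
    out_dir = os.path.join(img_folder, 'detections')
    os.makedirs(out_dir, exist_ok=True)
    for name in sorted(os.listdir(img_folder)):
        if not name.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue
        img = cv2.cvtColor(cv2.imread(os.path.join(img_folder, name)), cv2.COLOR_BGR2RGB)
        faces = detector.detect_faces(img)
        if not faces:
            print('no face found, skipping', name)
            continue
        kp = faces[0]['keypoints']
        lm = [kp['left_eye'], kp['right_eye'], kp['nose'], kp['mouth_left'], kp['mouth_right']]
        with open(os.path.join(out_dir, os.path.splitext(name)[0] + '.txt'), 'w') as f:
            for x, y in lm:
                f.write('%.2f %.2f\n' % (x, y))

# Example: write_detections('./datasets/examples')
```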
253 | 254 | ## Citation 255 | 256 | Please cite the following paper if this model helps your research: 257 | 258 | @inproceedings{deng2019accurate, 259 | title={Accurate 3D Face Reconstruction with Weakly-Supervised Learning: From Single Image to Image Set}, 260 | author={Yu Deng and Jiaolong Yang and Sicheng Xu and Dong Chen and Yunde Jia and Xin Tong}, 261 | booktitle={IEEE Computer Vision and Pattern Recognition Workshops}, 262 | year={2019} 263 | } 264 | ## 265 | The face images on this page are from the public [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset released by MMLab, CUHK. 266 | 267 | Part of the code in this implementation takes [CUT](https://github.com/taesungp/contrastive-unpaired-translation) as a reference. 268 | 269 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- : (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """ 13 | import numpy as np 14 | import importlib 15 | import torch.utils.data 16 | from data.base_dataset import BaseDataset 17 | 18 | 19 | def find_dataset_using_name(dataset_name): 20 | """Import the module "data/[dataset_name]_dataset.py". 21 | 22 | In the file, the class called DatasetNameDataset() will 23 | be instantiated. It has to be a subclass of BaseDataset, 24 | and it is case-insensitive. 25 | """ 26 | dataset_filename = "data." + dataset_name + "_dataset" 27 | datasetlib = importlib.import_module(dataset_filename) 28 | 29 | dataset = None 30 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 31 | for name, cls in datasetlib.__dict__.items(): 32 | if name.lower() == target_dataset_name.lower() \ 33 | and issubclass(cls, BaseDataset): 34 | dataset = cls 35 | 36 | if dataset is None: 37 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 38 | 39 | return dataset 40 | 41 | 42 | def get_option_setter(dataset_name): 43 | """Return the static method of the dataset class.""" 44 | dataset_class = find_dataset_using_name(dataset_name) 45 | return dataset_class.modify_commandline_options 46 | 47 | 48 | def create_dataset(opt, rank=0): 49 | """Create a dataset given the option. 50 | 51 | This function wraps the class CustomDatasetDataLoader. 
52 | This is the main interface between this package and 'train.py'/'test.py' 53 | 54 | Example: 55 | >>> from data import create_dataset 56 | >>> dataset = create_dataset(opt) 57 | """ 58 | data_loader = CustomDatasetDataLoader(opt, rank=rank) 59 | dataset = data_loader.load_data() 60 | return dataset 61 | 62 | class CustomDatasetDataLoader(): 63 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 64 | 65 | def __init__(self, opt, rank=0): 66 | """Initialize this class 67 | 68 | Step 1: create a dataset instance given the name [dataset_mode] 69 | Step 2: create a multi-threaded data loader. 70 | """ 71 | self.opt = opt 72 | dataset_class = find_dataset_using_name(opt.dataset_mode) 73 | self.dataset = dataset_class(opt) 74 | self.sampler = None 75 | print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__)) 76 | if opt.use_ddp and opt.isTrain: 77 | world_size = opt.world_size 78 | self.sampler = torch.utils.data.distributed.DistributedSampler( 79 | self.dataset, 80 | num_replicas=world_size, 81 | rank=rank, 82 | shuffle=not opt.serial_batches 83 | ) 84 | self.dataloader = torch.utils.data.DataLoader( 85 | self.dataset, 86 | sampler=self.sampler, 87 | num_workers=int(opt.num_threads / world_size), 88 | batch_size=int(opt.batch_size / world_size), 89 | drop_last=True) 90 | else: 91 | self.dataloader = torch.utils.data.DataLoader( 92 | self.dataset, 93 | batch_size=opt.batch_size, 94 | shuffle=(not opt.serial_batches) and opt.isTrain, 95 | num_workers=int(opt.num_threads), 96 | drop_last=True 97 | ) 98 | 99 | def set_epoch(self, epoch): 100 | self.dataset.current_epoch = epoch 101 | if self.sampler is not None: 102 | self.sampler.set_epoch(epoch) 103 | 104 | def load_data(self): 105 | return self 106 | 107 | def __len__(self): 108 | """Return the number of data in the dataset""" 109 | return min(len(self.dataset), self.opt.max_dataset_size) 110 | 111 | def __iter__(self): 112 | """Return a batch of data""" 113 | for i, data in enumerate(self.dataloader): 114 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 115 | break 116 | yield data 117 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | try: 10 | from PIL.Image import Resampling 11 | RESAMPLING_METHOD = Resampling.BICUBIC 12 | except ImportError: 13 | from PIL.Image import BICUBIC 14 | RESAMPLING_METHOD = BICUBIC 15 | import torchvision.transforms as transforms 16 | from abc import ABC, abstractmethod 17 | 18 | 19 | class BaseDataset(data.Dataset, ABC): 20 | """This class is an abstract base class (ABC) for datasets. 21 | 22 | To create a subclass, you need to implement the following four functions: 23 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 24 | -- <__len__>: return the size of dataset. 25 | -- <__getitem__>: get a data point. 26 | -- : (optionally) add dataset-specific options and set default options. 
27 | """ 28 | 29 | def __init__(self, opt): 30 | """Initialize the class; save the options in the class 31 | 32 | Parameters: 33 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 34 | """ 35 | self.opt = opt 36 | # self.root = opt.dataroot 37 | self.current_epoch = 0 38 | 39 | @staticmethod 40 | def modify_commandline_options(parser, is_train): 41 | """Add new dataset-specific options, and rewrite default values for existing options. 42 | 43 | Parameters: 44 | parser -- original option parser 45 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 46 | 47 | Returns: 48 | the modified parser. 49 | """ 50 | return parser 51 | 52 | @abstractmethod 53 | def __len__(self): 54 | """Return the total number of images in the dataset.""" 55 | return 0 56 | 57 | @abstractmethod 58 | def __getitem__(self, index): 59 | """Return a data point and its metadata information. 60 | 61 | Parameters: 62 | index - - a random integer for data indexing 63 | 64 | Returns: 65 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 66 | """ 67 | pass 68 | 69 | 70 | def get_transform(grayscale=False): 71 | transform_list = [] 72 | if grayscale: 73 | transform_list.append(transforms.Grayscale(1)) 74 | transform_list += [transforms.ToTensor()] 75 | return transforms.Compose(transform_list) 76 | 77 | def get_affine_mat(opt, size): 78 | shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False 79 | w, h = size 80 | 81 | if 'shift' in opt.preprocess: 82 | shift_pixs = int(opt.shift_pixs) 83 | shift_x = random.randint(-shift_pixs, shift_pixs) 84 | shift_y = random.randint(-shift_pixs, shift_pixs) 85 | if 'scale' in opt.preprocess: 86 | scale = 1 + opt.scale_delta * (2 * random.random() - 1) 87 | if 'rot' in opt.preprocess: 88 | rot_angle = opt.rot_angle * (2 * random.random() - 1) 89 | rot_rad = -rot_angle * np.pi/180 90 | if 'flip' in opt.preprocess: 91 | flip = random.random() > 0.5 92 | 93 | shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3]) 94 | flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3]) 95 | shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3]) 96 | rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3]) 97 | scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3]) 98 | shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3]) 99 | 100 | affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin 101 | affine_inv = np.linalg.inv(affine) 102 | return affine, affine_inv, flip 103 | 104 | def apply_img_affine(img, affine_inv, method=RESAMPLING_METHOD): 105 | return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=RESAMPLING_METHOD) 106 | 107 | def apply_lm_affine(landmark, affine, flip, size): 108 | _, h = size 109 | lm = landmark.copy() 110 | lm[:, 1] = h - 1 - lm[:, 1] 111 | lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1) 112 | lm = lm @ np.transpose(affine) 113 | lm[:, :2] = lm[:, :2] / lm[:, 2:] 114 | lm = lm[:, :2] 115 | lm[:, 1] = h - 1 - lm[:, 1] 116 | if flip: 117 | lm_ = lm.copy() 118 | lm_[:17] = lm[16::-1] 119 | lm_[17:22] = lm[26:21:-1] 120 | lm_[22:27] = lm[21:16:-1] 121 | lm_[31:36] = lm[35:30:-1] 122 | lm_[36:40] = lm[45:41:-1] 123 | lm_[40:42] = lm[47:45:-1] 124 | 
lm_[42:46] = lm[39:35:-1] 125 | lm_[46:48] = lm[41:39:-1] 126 | lm_[48:55] = lm[54:47:-1] 127 | lm_[55:60] = lm[59:54:-1] 128 | lm_[60:65] = lm[64:59:-1] 129 | lm_[65:68] = lm[67:64:-1] 130 | lm = lm_ 131 | return lm 132 | -------------------------------------------------------------------------------- /data/flist_dataset.py: -------------------------------------------------------------------------------- 1 | """This script defines the custom dataset for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import os.path 5 | from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine 6 | from data.image_folder import make_dataset 7 | from PIL import Image 8 | import random 9 | import util.util as util 10 | import numpy as np 11 | import json 12 | import torch 13 | from scipy.io import loadmat, savemat 14 | import pickle 15 | from util.preprocess import align_img, estimate_norm 16 | from util.load_mats import load_lm3d 17 | 18 | 19 | def default_flist_reader(flist): 20 | """ 21 | flist format: impath label\nimpath label\n ...(same to caffe's filelist) 22 | """ 23 | imlist = [] 24 | with open(flist, 'r') as rf: 25 | for line in rf.readlines(): 26 | impath = line.strip() 27 | imlist.append(impath) 28 | 29 | return imlist 30 | 31 | def jason_flist_reader(flist): 32 | with open(flist, 'r') as fp: 33 | info = json.load(fp) 34 | return info 35 | 36 | def parse_label(label): 37 | return torch.tensor(np.array(label).astype(np.float32)) 38 | 39 | 40 | class FlistDataset(BaseDataset): 41 | """ 42 | It requires one directories to host training images '/path/to/data/train' 43 | You can train the model with the dataset flag '--dataroot /path/to/data'. 44 | """ 45 | 46 | def __init__(self, opt): 47 | """Initialize this dataset class. 48 | 49 | Parameters: 50 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 51 | """ 52 | BaseDataset.__init__(self, opt) 53 | 54 | self.lm3d_std = load_lm3d(opt.bfm_folder) 55 | 56 | msk_names = default_flist_reader(opt.flist) 57 | self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names] 58 | 59 | self.size = len(self.msk_paths) 60 | self.opt = opt 61 | 62 | self.name = 'train' if opt.isTrain else 'val' 63 | if '_' in opt.flist: 64 | self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0] 65 | 66 | 67 | def __getitem__(self, index): 68 | """Return a data point and its metadata information. 
69 | 70 | Parameters: 71 | index (int) -- a random integer for data indexing 72 | 73 | Returns a dictionary that contains A, B, A_paths and B_paths 74 | img (tensor) -- an image in the input domain 75 | msk (tensor) -- its corresponding attention mask 76 | lm (tensor) -- its corresponding 3d landmarks 77 | im_paths (str) -- image paths 78 | aug_flag (bool) -- a flag used to tell whether its raw or augmented 79 | """ 80 | msk_path = self.msk_paths[index % self.size] # make sure index is within then range 81 | img_path = msk_path.replace('mask/', '') 82 | lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt' 83 | 84 | raw_img = Image.open(img_path).convert('RGB') 85 | raw_msk = Image.open(msk_path).convert('RGB') 86 | raw_lm = np.loadtxt(lm_path).astype(np.float32) 87 | 88 | _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk) 89 | 90 | aug_flag = self.opt.use_aug and self.opt.isTrain 91 | if aug_flag: 92 | img, lm, msk = self._augmentation(img, lm, self.opt, msk) 93 | 94 | _, H = img.size 95 | M = estimate_norm(lm, H) 96 | transform = get_transform() 97 | img_tensor = transform(img) 98 | msk_tensor = transform(msk)[:1, ...] 99 | lm_tensor = parse_label(lm) 100 | M_tensor = parse_label(M) 101 | 102 | 103 | return {'imgs': img_tensor, 104 | 'lms': lm_tensor, 105 | 'msks': msk_tensor, 106 | 'M': M_tensor, 107 | 'im_paths': img_path, 108 | 'aug_flag': aug_flag, 109 | 'dataset': self.name} 110 | 111 | def _augmentation(self, img, lm, opt, msk=None): 112 | affine, affine_inv, flip = get_affine_mat(opt, img.size) 113 | img = apply_img_affine(img, affine_inv) 114 | lm = apply_lm_affine(lm, affine, flip, img.size) 115 | if msk is not None: 116 | msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR) 117 | return img, lm, msk 118 | 119 | 120 | 121 | 122 | def __len__(self): 123 | """Return the total number of images in the dataset. 124 | """ 125 | return self.size 126 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 
5 | """ 6 | import numpy as np 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | import os.path 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | '.tif', '.TIF', '.tiff', '.TIFF', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir, max_dataset_size=float("inf")): 25 | images = [] 26 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir, followlinks=True)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | return images[:min(max_dataset_size, len(images))] 34 | 35 | 36 | def default_loader(path): 37 | return Image.open(path).convert('RGB') 38 | 39 | 40 | class ImageFolder(data.Dataset): 41 | 42 | def __init__(self, root, transform=None, return_paths=False, 43 | loader=default_loader): 44 | imgs = make_dataset(root) 45 | if len(imgs) == 0: 46 | raise(RuntimeError("Found 0 images in: " + root + "\n" 47 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 48 | 49 | self.root = root 50 | self.imgs = imgs 51 | self.transform = transform 52 | self.return_paths = return_paths 53 | self.loader = loader 54 | 55 | def __getitem__(self, index): 56 | path = self.imgs[index] 57 | img = self.loader(path) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | if self.return_paths: 61 | return img, path 62 | else: 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.imgs) 67 | -------------------------------------------------------------------------------- /data/template_dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset class template 2 | 3 | This module provides a template for users to implement custom datasets. 4 | You can specify '--dataset_mode template' to use this dataset. 5 | The class name should be consistent with both the filename and its dataset_mode option. 6 | The filename should be _dataset.py 7 | The class name should be Dataset.py 8 | You need to implement the following functions: 9 | -- : Add dataset-specific options and rewrite default values for existing options. 10 | -- <__init__>: Initialize this dataset class. 11 | -- <__getitem__>: Return a data point and its metadata information. 12 | -- <__len__>: Return the number of images. 13 | """ 14 | from data.base_dataset import BaseDataset, get_transform 15 | # from data.image_folder import make_dataset 16 | # from PIL import Image 17 | 18 | 19 | class TemplateDataset(BaseDataset): 20 | """A template dataset class for you to implement custom datasets.""" 21 | @staticmethod 22 | def modify_commandline_options(parser, is_train): 23 | """Add new dataset-specific options, and rewrite default values for existing options. 24 | 25 | Parameters: 26 | parser -- original option parser 27 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 28 | 29 | Returns: 30 | the modified parser. 31 | """ 32 | parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option') 33 | parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values 34 | return parser 35 | 36 | def __init__(self, opt): 37 | """Initialize this dataset class. 
38 | 39 | Parameters: 40 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 41 | 42 | A few things can be done here. 43 | - save the options (have been done in BaseDataset) 44 | - get image paths and meta information of the dataset. 45 | - define the image transformation. 46 | """ 47 | # save the option and dataset root 48 | BaseDataset.__init__(self, opt) 49 | # get the image paths of your dataset; 50 | self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root 51 | # define the default transform function. You can use ; You can also define your custom transform function 52 | self.transform = get_transform(opt) 53 | 54 | def __getitem__(self, index): 55 | """Return a data point and its metadata information. 56 | 57 | Parameters: 58 | index -- a random integer for data indexing 59 | 60 | Returns: 61 | a dictionary of data with their names. It usually contains the data itself and its metadata information. 62 | 63 | Step 1: get a random image path: e.g., path = self.image_paths[index] 64 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). 65 | Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image) 66 | Step 4: return a data point as a dictionary. 67 | """ 68 | path = 'temp' # needs to be a string 69 | data_A = None # needs to be a tensor 70 | data_B = None # needs to be a tensor 71 | return {'data_A': data_A, 'data_B': data_B, 'path': path} 72 | 73 | def __len__(self): 74 | """Return the total number of images.""" 75 | return len(self.image_paths) 76 | -------------------------------------------------------------------------------- /data_preparation.py: -------------------------------------------------------------------------------- 1 | """This script is the data preparation script for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import os 5 | import numpy as np 6 | import argparse 7 | from util.detect_lm68 import detect_68p,load_lm_graph 8 | from util.skin_mask import get_skin_mask 9 | from util.generate_list import check_list, write_list 10 | import warnings 11 | warnings.filterwarnings("ignore") 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--data_root', type=str, default='datasets', help='root directory for training data') 15 | parser.add_argument('--img_folder', nargs="+", required=True, help='folders of training images') 16 | parser.add_argument('--mode', type=str, default='train', help='train or val') 17 | opt = parser.parse_args() 18 | 19 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 20 | 21 | def data_prepare(folder_list,mode): 22 | 23 | lm_sess,input_op,output_op = load_lm_graph('./checkpoints/lm_model/68lm_detector.pb') # load a tensorflow version 68-landmark detector 24 | 25 | for img_folder in folder_list: 26 | detect_68p(img_folder,lm_sess,input_op,output_op) # detect landmarks for images 27 | get_skin_mask(img_folder) # generate skin attention mask for images 28 | 29 | # create files that record path to all training data 30 | msks_list = [] 31 | for img_folder in folder_list: 32 | path = os.path.join(img_folder, 'mask') 33 | msks_list += ['/'.join([img_folder, 'mask', i]) for i in sorted(os.listdir(path)) if 'jpg' in i or 34 | 'png' in i or 'jpeg' in i or 'PNG' in i] 35 | 36 | imgs_list = [i.replace('mask/', '') for i in msks_list] 37 | lms_list = [i.replace('mask', 'landmarks') for i in msks_list] 38 | lms_list = 
['.'.join(i.split('.')[:-1]) + '.txt' for i in lms_list] 39 | 40 | lms_list_final, imgs_list_final, msks_list_final = check_list(lms_list, imgs_list, msks_list) # check if the path is valid 41 | write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode) # save files 42 | 43 | if __name__ == '__main__': 44 | print('Datasets:',opt.img_folder) 45 | data_prepare([os.path.join(opt.data_root,folder) for folder in opt.img_folder],opt.mode) 46 | -------------------------------------------------------------------------------- /datasets/examples/000002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000002.jpg -------------------------------------------------------------------------------- /datasets/examples/000006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000006.jpg -------------------------------------------------------------------------------- /datasets/examples/000007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000007.jpg -------------------------------------------------------------------------------- /datasets/examples/000031.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000031.jpg -------------------------------------------------------------------------------- /datasets/examples/000033.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000033.jpg -------------------------------------------------------------------------------- /datasets/examples/000037.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000037.jpg -------------------------------------------------------------------------------- /datasets/examples/000050.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000050.jpg -------------------------------------------------------------------------------- /datasets/examples/000055.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000055.jpg -------------------------------------------------------------------------------- /datasets/examples/000114.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000114.jpg -------------------------------------------------------------------------------- /datasets/examples/000125.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000125.jpg -------------------------------------------------------------------------------- /datasets/examples/000126.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000126.jpg -------------------------------------------------------------------------------- /datasets/examples/015259.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/015259.jpg -------------------------------------------------------------------------------- /datasets/examples/015270.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/015270.jpg -------------------------------------------------------------------------------- /datasets/examples/015309.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/015309.jpg -------------------------------------------------------------------------------- /datasets/examples/015310.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/015310.jpg -------------------------------------------------------------------------------- /datasets/examples/015316.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/015316.jpg -------------------------------------------------------------------------------- /datasets/examples/015384.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/015384.jpg -------------------------------------------------------------------------------- /datasets/examples/detections/000002.txt: -------------------------------------------------------------------------------- 1 | 142.84 207.18 2 | 222.02 203.9 3 | 159.24 253.57 4 | 146.59 290.93 5 | 227.52 284.74 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/000006.txt: -------------------------------------------------------------------------------- 1 | 199.93 158.28 2 | 255.34 166.54 3 | 236.08 198.92 4 | 198.83 229.24 5 | 245.23 234.52 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/000007.txt: -------------------------------------------------------------------------------- 1 | 129.36 198.28 2 | 204.47 191.47 3 | 164.42 240.51 4 | 140.74 277.77 5 | 205.4 270.9 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/000031.txt: 
-------------------------------------------------------------------------------- 1 | 151.23 240.71 2 | 274.05 235.52 3 | 217.37 305.99 4 | 158.03 346.06 5 | 272.17 341.09 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/000033.txt: -------------------------------------------------------------------------------- 1 | 119.09 94.291 2 | 158.31 96.472 3 | 136.76 121.4 4 | 119.33 134.49 5 | 154.66 136.68 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/000037.txt: -------------------------------------------------------------------------------- 1 | 147.37 159.39 2 | 196.94 163.26 3 | 190.68 194.36 4 | 153.72 228.44 5 | 193.94 229.7 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/000050.txt: -------------------------------------------------------------------------------- 1 | 150.4 94.799 2 | 205.14 102.07 3 | 179.54 131.16 4 | 144.45 147.42 5 | 193.39 154.14 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/000055.txt: -------------------------------------------------------------------------------- 1 | 114.26 193.42 2 | 205.8 190.27 3 | 154.15 244.02 4 | 124.69 295.22 5 | 200.88 292.69 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/000114.txt: -------------------------------------------------------------------------------- 1 | 217.52 152.95 2 | 281.48 147.14 3 | 253.02 196.03 4 | 225.79 221.6 5 | 288.25 214.44 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/000125.txt: -------------------------------------------------------------------------------- 1 | 90.928 99.858 2 | 146.87 100.33 3 | 114.22 130.36 4 | 91.579 153.32 5 | 143.63 153.56 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/000126.txt: -------------------------------------------------------------------------------- 1 | 307.56 166.54 2 | 387.06 159.62 3 | 335.52 222.26 4 | 319.3 248.85 5 | 397.71 239.14 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/015259.txt: -------------------------------------------------------------------------------- 1 | 226.38 193.65 2 | 319.12 208.97 3 | 279.99 245.88 4 | 213.79 290.55 5 | 303.03 302.1 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/015270.txt: -------------------------------------------------------------------------------- 1 | 208.4 410.08 2 | 364.41 388.68 3 | 291.6 503.57 4 | 244.82 572.86 5 | 383.18 553.49 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/015309.txt: -------------------------------------------------------------------------------- 1 | 284.61 496.57 2 | 562.77 550.78 3 | 395.85 712.84 4 | 238.92 786.8 5 | 495.61 827.22 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/015310.txt: -------------------------------------------------------------------------------- 1 | 153.95 153.43 2 | 211.13 161.54 3 | 197.28 190.26 4 | 150.82 215.98 5 | 202.32 223.12 6 | -------------------------------------------------------------------------------- 
/datasets/examples/detections/015316.txt: -------------------------------------------------------------------------------- 1 | 481.31 396.88 2 | 667.75 392.43 3 | 557.81 440.55 4 | 490.44 586.28 5 | 640.56 583.2 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/015384.txt: -------------------------------------------------------------------------------- 1 | 191.79 143.97 2 | 271.86 151.23 3 | 191.25 210.29 4 | 187.82 257.12 5 | 258.82 261.96 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/vd006.txt: -------------------------------------------------------------------------------- 1 | 123.12 117.58 2 | 176.59 122.09 3 | 126.99 144.68 4 | 117.61 183.43 5 | 163.94 186.41 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/vd025.txt: -------------------------------------------------------------------------------- 1 | 180.12 116.13 2 | 263.18 98.397 3 | 230.48 154.72 4 | 201.37 199.01 5 | 279.18 182.56 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/vd026.txt: -------------------------------------------------------------------------------- 1 | 171.27 263.54 2 | 286.58 263.88 3 | 203.35 333.02 4 | 170.6 389.42 5 | 281.73 386.84 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/vd034.txt: -------------------------------------------------------------------------------- 1 | 136.01 167.83 2 | 195.25 151.71 3 | 152.89 191.45 4 | 149.85 235.5 5 | 201.16 222.8 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/vd051.txt: -------------------------------------------------------------------------------- 1 | 161.92 292.04 2 | 254.21 283.81 3 | 212.75 342.06 4 | 170.78 387.28 5 | 254.6 379.82 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/vd070.txt: -------------------------------------------------------------------------------- 1 | 276.53 290.35 2 | 383.38 294.75 3 | 314.48 354.66 4 | 275.08 407.72 5 | 364.94 411.48 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/vd092.txt: -------------------------------------------------------------------------------- 1 | 108.59 149.07 2 | 157.35 143.85 3 | 134.4 173.2 4 | 117.88 200.79 5 | 159.56 196.36 6 | -------------------------------------------------------------------------------- /datasets/examples/detections/vd102.txt: -------------------------------------------------------------------------------- 1 | 121.62 225.96 2 | 186.73 223.07 3 | 162.99 269.82 4 | 132.12 302.62 5 | 186.42 299.21 6 | -------------------------------------------------------------------------------- /datasets/examples/vd006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd006.png -------------------------------------------------------------------------------- /datasets/examples/vd025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd025.png 
-------------------------------------------------------------------------------- /datasets/examples/vd026.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd026.png -------------------------------------------------------------------------------- /datasets/examples/vd034.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd034.png -------------------------------------------------------------------------------- /datasets/examples/vd051.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd051.png -------------------------------------------------------------------------------- /datasets/examples/vd070.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd070.png -------------------------------------------------------------------------------- /datasets/examples/vd092.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd092.png -------------------------------------------------------------------------------- /datasets/examples/vd102.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd102.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: deep3d_pytorch 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - python=3.6 8 | - pytorch=1.6.0 9 | - torchvision=0.7.0 10 | - numpy=1.18.1 11 | - scikit-image=0.16.2 12 | - scipy=1.4.1 13 | - pillow=6.2.1 14 | - pip=20.0.2 15 | - ipython=7.13.0 16 | - yaml=0.1.7 17 | - pip: 18 | - matplotlib==2.2.5 19 | - opencv-python==3.4.9.33 20 | - tensorboard==1.15.0 21 | - tensorflow==1.15.0 22 | - kornia==0.5.5 23 | - dominate==2.6.0 24 | - trimesh==3.9.20 -------------------------------------------------------------------------------- /images/20230425_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/images/20230425_compare.png -------------------------------------------------------------------------------- /images/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/images/compare.png -------------------------------------------------------------------------------- /images/example.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/images/example.gif -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from models.base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "models." + model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 43 | exit(0) 44 | 45 | return model 46 | 47 | 48 | def get_option_setter(model_name): 49 | """Return the static method of the model class.""" 50 | model_class = find_model_using_name(model_name) 51 | return model_class.modify_commandline_options 52 | 53 | 54 | def create_model(opt): 55 | """Create a model given the option. 56 | 57 | This function warps the class CustomDatasetDataLoader. 
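# Tracing the lookup above by hand for the default '--model facerecon' flag:
# the module becomes models/facerecon_model.py and the matching class is
# FaceReconModel. The snippet is only an illustration of the naming rule, not
# part of the loader itself.
import importlib
name = 'facerecon'                                     # value of opt.model
module = importlib.import_module('models.%s_model' % name)
target = name.replace('_', '') + 'model'               # 'facereconmodel'
cls = [c for k, c in module.__dict__.items() if k.lower() == target][0]
print(cls.__name__)                                    # FaceReconModel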
58 | This is the main interface between this package and 'train.py'/'test.py' 59 | 60 | Example: 61 | >>> from models import create_model 62 | >>> model = create_model(opt) 63 | """ 64 | model = find_model_using_name(opt.model) 65 | instance = model(opt) 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance 68 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | """This script defines the base network model for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import os 5 | import numpy as np 6 | import torch 7 | from collections import OrderedDict 8 | from abc import ABC, abstractmethod 9 | from . import networks 10 | 11 | 12 | class BaseModel(ABC): 13 | """This class is an abstract base class (ABC) for models. 14 | To create a subclass, you need to implement the following five functions: 15 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 16 | -- : unpack data from dataset and apply preprocessing. 17 | -- : produce intermediate results. 18 | -- : calculate losses, gradients, and update network weights. 19 | -- : (optionally) add model-specific options and set default options. 20 | """ 21 | 22 | def __init__(self, opt): 23 | """Initialize the BaseModel class. 24 | 25 | Parameters: 26 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 27 | 28 | When creating your custom class, you need to implement your own initialization. 29 | In this fucntion, you should first call 30 | Then, you need to define four lists: 31 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 32 | -- self.model_names (str list): specify the images that you want to display and save. 33 | -- self.visual_names (str list): define networks used in our training. 34 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 35 | """ 36 | self.opt = opt 37 | self.isTrain = opt.isTrain 38 | self.device = torch.device('cpu') 39 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir 40 | self.loss_names = [] 41 | self.model_names = [] 42 | self.visual_names = [] 43 | self.parallel_names = [] 44 | self.optimizers = [] 45 | self.image_paths = [] 46 | self.metric = 0 # used for learning rate policy 'plateau' 47 | 48 | @staticmethod 49 | def dict_grad_hook_factory(add_func=lambda x: x): 50 | saved_dict = dict() 51 | 52 | def hook_gen(name): 53 | def grad_hook(grad): 54 | saved_vals = add_func(grad) 55 | saved_dict[name] = saved_vals 56 | return grad_hook 57 | return hook_gen, saved_dict 58 | 59 | @staticmethod 60 | def modify_commandline_options(parser, is_train): 61 | """Add new model-specific options, and rewrite default values for existing options. 62 | 63 | Parameters: 64 | parser -- original option parser 65 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 66 | 67 | Returns: 68 | the modified parser. 69 | """ 70 | return parser 71 | 72 | @abstractmethod 73 | def set_input(self, input): 74 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 
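# A minimal, hypothetical subclass showing how the four lists are consumed:
# entries of loss_names are read back as self.loss_<name> by get_current_losses(),
# model_names as self.<name> by save_networks()/train()/eval(), and
# visual_names as self.<name> by get_current_visuals().
class ToyModel(BaseModel):
    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['color']          # expects self.loss_color to be set
        self.model_names = ['net_recon']     # expects self.net_recon (an nn.Module)
        self.visual_names = ['output_vis']   # expects self.output_vis (an image tensor)
        self.optimizers = []                 # one optimizer per trainable network

    def set_input(self, input):
        self.input_img = input['imgs'].to(self.device)

    def forward(self):
        self.output_vis = self.input_img

    def optimize_parameters(self):
        pass                                 # nothing to optimize in this toy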
75 | 76 | Parameters: 77 | input (dict): includes the data itself and its metadata information. 78 | """ 79 | pass 80 | 81 | @abstractmethod 82 | def forward(self): 83 | """Run forward pass; called by both functions and .""" 84 | pass 85 | 86 | @abstractmethod 87 | def optimize_parameters(self): 88 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 89 | pass 90 | 91 | def setup(self, opt): 92 | """Load and print networks; create schedulers 93 | 94 | Parameters: 95 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 96 | """ 97 | if self.isTrain: 98 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 99 | 100 | if not self.isTrain or opt.continue_train: 101 | load_suffix = opt.epoch 102 | self.load_networks(load_suffix) 103 | 104 | 105 | # self.print_networks(opt.verbose) 106 | 107 | def parallelize(self, convert_sync_batchnorm=True): 108 | if not self.opt.use_ddp: 109 | for name in self.parallel_names: 110 | if isinstance(name, str): 111 | module = getattr(self, name) 112 | setattr(self, name, module.to(self.device)) 113 | else: 114 | for name in self.model_names: 115 | if isinstance(name, str): 116 | module = getattr(self, name) 117 | if convert_sync_batchnorm: 118 | module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module) 119 | setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device), 120 | device_ids=[self.device.index], 121 | find_unused_parameters=True, broadcast_buffers=True)) 122 | 123 | # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient. 124 | for name in self.parallel_names: 125 | if isinstance(name, str) and name not in self.model_names: 126 | module = getattr(self, name) 127 | setattr(self, name, module.to(self.device)) 128 | 129 | # put state_dict of optimizer to gpu device 130 | if self.opt.phase != 'test': 131 | if self.opt.continue_train: 132 | for optim in self.optimizers: 133 | for state in optim.state.values(): 134 | for k, v in state.items(): 135 | if isinstance(v, torch.Tensor): 136 | state[k] = v.to(self.device) 137 | 138 | def data_dependent_initialize(self, data): 139 | pass 140 | 141 | def train(self): 142 | """Make models train mode""" 143 | for name in self.model_names: 144 | if isinstance(name, str): 145 | net = getattr(self, name) 146 | net.train() 147 | 148 | def eval(self): 149 | """Make models eval mode""" 150 | for name in self.model_names: 151 | if isinstance(name, str): 152 | net = getattr(self, name) 153 | net.eval() 154 | 155 | def test(self): 156 | """Forward function used in test time. 
157 | 158 | This function wraps function in no_grad() so we don't save intermediate steps for backprop 159 | It also calls to produce additional visualization results 160 | """ 161 | with torch.no_grad(): 162 | self.forward() 163 | self.compute_visuals() 164 | 165 | def compute_visuals(self): 166 | """Calculate additional output images for visdom and HTML visualization""" 167 | pass 168 | 169 | def get_image_paths(self, name='A'): 170 | """ Return image paths that are used to load current data""" 171 | return self.image_paths if name =='A' else self.image_paths_B 172 | 173 | def update_learning_rate(self): 174 | """Update learning rates for all the networks; called at the end of every epoch""" 175 | for scheduler in self.schedulers: 176 | if self.opt.lr_policy == 'plateau': 177 | scheduler.step(self.metric) 178 | else: 179 | scheduler.step() 180 | 181 | lr = self.optimizers[0].param_groups[0]['lr'] 182 | print('learning rate = %.7f' % lr) 183 | 184 | def get_current_visuals(self): 185 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" 186 | visual_ret = OrderedDict() 187 | for name in self.visual_names: 188 | if isinstance(name, str): 189 | visual_ret[name] = getattr(self, name)[:, :3, ...] 190 | return visual_ret 191 | 192 | def get_current_losses(self): 193 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" 194 | errors_ret = OrderedDict() 195 | for name in self.loss_names: 196 | if isinstance(name, str): 197 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number 198 | return errors_ret 199 | 200 | def save_networks(self, epoch): 201 | """Save all the networks to the disk. 202 | 203 | Parameters: 204 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 205 | """ 206 | if not os.path.isdir(self.save_dir): 207 | os.makedirs(self.save_dir) 208 | 209 | save_filename = 'epoch_%s.pth' % (epoch) 210 | save_path = os.path.join(self.save_dir, save_filename) 211 | 212 | save_dict = {} 213 | for name in self.model_names: 214 | if isinstance(name, str): 215 | net = getattr(self, name) 216 | if isinstance(net, torch.nn.DataParallel) or isinstance(net, 217 | torch.nn.parallel.DistributedDataParallel): 218 | net = net.module 219 | save_dict[name] = net.state_dict() 220 | 221 | 222 | for i, optim in enumerate(self.optimizers): 223 | save_dict['opt_%02d'%i] = optim.state_dict() 224 | 225 | for i, sched in enumerate(self.schedulers): 226 | save_dict['sched_%02d'%i] = sched.state_dict() 227 | 228 | torch.save(save_dict, save_path) 229 | 230 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 231 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 232 | key = keys[i] 233 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 234 | if module.__class__.__name__.startswith('InstanceNorm') and \ 235 | (key == 'running_mean' or key == 'running_var'): 236 | if getattr(module, key) is None: 237 | state_dict.pop('.'.join(keys)) 238 | if module.__class__.__name__.startswith('InstanceNorm') and \ 239 | (key == 'num_batches_tracked'): 240 | state_dict.pop('.'.join(keys)) 241 | else: 242 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 243 | 244 | def load_networks(self, epoch): 245 | """Load all the networks from the disk. 
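# save_networks above writes one flat dict per epoch, e.g. for FaceReconModel:
# {'net_recon': <state_dict>, 'opt_00': <optimizer state>, 'sched_00': <scheduler state>},
# stored as <checkpoints_dir>/<name>/epoch_<E>.pth. A short sketch for inspecting
# such a file (the path below is illustrative):
import torch
ckpt = torch.load('checkpoints/face_recon/epoch_20.pth', map_location='cpu')
print(list(ckpt.keys()))        # e.g. ['net_recon', 'opt_00', 'sched_00']
print(len(ckpt['net_recon']))   # number of parameter/buffer entries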
246 | 247 | Parameters: 248 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 249 | """ 250 | if self.opt.isTrain and self.opt.pretrained_name is not None: 251 | load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name) 252 | else: 253 | load_dir = self.save_dir 254 | load_filename = 'epoch_%s.pth' % (epoch) 255 | load_path = os.path.join(load_dir, load_filename) 256 | state_dict = torch.load(load_path, map_location=self.device) 257 | print('loading the model from %s' % load_path) 258 | 259 | for name in self.model_names: 260 | if isinstance(name, str): 261 | net = getattr(self, name) 262 | if isinstance(net, torch.nn.DataParallel): 263 | net = net.module 264 | net.load_state_dict(state_dict[name]) 265 | 266 | if self.opt.phase != 'test': 267 | if self.opt.continue_train: 268 | print('loading the optim from %s' % load_path) 269 | for i, optim in enumerate(self.optimizers): 270 | optim.load_state_dict(state_dict['opt_%02d'%i]) 271 | 272 | try: 273 | print('loading the sched from %s' % load_path) 274 | for i, sched in enumerate(self.schedulers): 275 | sched.load_state_dict(state_dict['sched_%02d'%i]) 276 | except: 277 | print('Failed to load schedulers, set schedulers according to epoch count manually') 278 | for i, sched in enumerate(self.schedulers): 279 | sched.last_epoch = self.opt.epoch_count - 1 280 | 281 | 282 | 283 | 284 | def print_networks(self, verbose): 285 | """Print the total number of parameters in the network and (if verbose) network architecture 286 | 287 | Parameters: 288 | verbose (bool) -- if verbose: print the network architecture 289 | """ 290 | print('---------- Networks initialized -------------') 291 | for name in self.model_names: 292 | if isinstance(name, str): 293 | net = getattr(self, name) 294 | num_params = 0 295 | for param in net.parameters(): 296 | num_params += param.numel() 297 | if verbose: 298 | print(net) 299 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 300 | print('-----------------------------------------------') 301 | 302 | def set_requires_grad(self, nets, requires_grad=False): 303 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 304 | Parameters: 305 | nets (network list) -- a list of networks 306 | requires_grad (bool) -- whether the networks require gradients or not 307 | """ 308 | if not isinstance(nets, list): 309 | nets = [nets] 310 | for net in nets: 311 | if net is not None: 312 | for param in net.parameters(): 313 | param.requires_grad = requires_grad 314 | 315 | def generate_visuals_for_evaluation(self, data, mode): 316 | return {} 317 | -------------------------------------------------------------------------------- /models/bfm.py: -------------------------------------------------------------------------------- 1 | """This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from scipy.io import loadmat 8 | from util.load_mats import transferBFM09 9 | import os 10 | 11 | def perspective_projection(focal, center): 12 | # return p.T (N, 3) @ (3, 3) 13 | return np.array([ 14 | focal, 0, center, 15 | 0, focal, center, 16 | 0, 0, 1 17 | ]).reshape([3, 3]).astype(np.float32).transpose() 18 | 19 | class SH: 20 | def __init__(self): 21 | self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)] 22 | self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) 
/ np.sqrt(12 * np.pi)] 23 | 24 | 25 | 26 | class ParametricFaceModel: 27 | def __init__(self, 28 | bfm_folder='./BFM', 29 | recenter=True, 30 | camera_distance=10., 31 | init_lit=np.array([ 32 | 0.8, 0, 0, 0, 0, 0, 0, 0, 0 33 | ]), 34 | focal=1015., 35 | center=112., 36 | is_train=True, 37 | default_name='BFM_model_front.mat'): 38 | 39 | if not os.path.isfile(os.path.join(bfm_folder, default_name)): 40 | transferBFM09(bfm_folder) 41 | model = loadmat(os.path.join(bfm_folder, default_name)) 42 | # mean face shape. [3*N,1] 43 | self.mean_shape = model['meanshape'].astype(np.float32) 44 | # identity basis. [3*N,80] 45 | self.id_base = model['idBase'].astype(np.float32) 46 | # expression basis. [3*N,64] 47 | self.exp_base = model['exBase'].astype(np.float32) 48 | # mean face texture. [3*N,1] (0-255) 49 | self.mean_tex = model['meantex'].astype(np.float32) 50 | # texture basis. [3*N,80] 51 | self.tex_base = model['texBase'].astype(np.float32) 52 | # face indices for each vertex that lies in. starts from 0. [N,8] 53 | self.point_buf = model['point_buf'].astype(np.int64) - 1 54 | # vertex indices for each face. starts from 0. [F,3] 55 | self.face_buf = model['tri'].astype(np.int64) - 1 56 | # vertex indices for 68 landmarks. starts from 0. [68,1] 57 | self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1 58 | 59 | if is_train: 60 | # vertex indices for small face region to compute photometric error. starts from 0. 61 | self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1 62 | # vertex indices for each face from small face region. starts from 0. [f,3] 63 | self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1 64 | # vertex indices for pre-defined skin region to compute reflectance loss 65 | self.skin_mask = np.squeeze(model['skinmask']) 66 | 67 | if recenter: 68 | mean_shape = self.mean_shape.reshape([-1, 3]) 69 | mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True) 70 | self.mean_shape = mean_shape.reshape([-1, 1]) 71 | 72 | self.persc_proj = perspective_projection(focal, center) 73 | self.device = 'cpu' 74 | self.camera_distance = camera_distance 75 | self.SH = SH() 76 | self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32) 77 | 78 | 79 | def to(self, device): 80 | self.device = device 81 | for key, value in self.__dict__.items(): 82 | if type(value).__module__ == np.__name__: 83 | setattr(self, key, torch.tensor(value).to(device)) 84 | 85 | 86 | def compute_shape(self, id_coeff, exp_coeff): 87 | """ 88 | Return: 89 | face_shape -- torch.tensor, size (B, N, 3) 90 | 91 | Parameters: 92 | id_coeff -- torch.tensor, size (B, 80), identity coeffs 93 | exp_coeff -- torch.tensor, size (B, 64), expression coeffs 94 | """ 95 | batch_size = id_coeff.shape[0] 96 | id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff) 97 | exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff) 98 | face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1]) 99 | return face_shape.reshape([batch_size, -1, 3]) 100 | 101 | 102 | def compute_texture(self, tex_coeff, normalize=True): 103 | """ 104 | Return: 105 | face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.) 106 | 107 | Parameters: 108 | tex_coeff -- torch.tensor, size (B, 80) 109 | """ 110 | batch_size = tex_coeff.shape[0] 111 | face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex 112 | if normalize: 113 | face_texture = face_texture / 255. 
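# The einsum in compute_shape above is a batched form of the linear 3DMM
# S = S_mean + B_id @ alpha_id + B_exp @ alpha_exp, with B_id of shape [3N, 80]
# and B_exp of shape [3N, 64]; compute_texture follows the same pattern for
# albedo. A shape-only sketch with random bases (the real ones come from
# BFM_model_front.mat):
import torch
N = 5                                                         # toy vertex count
B_id, B_exp = torch.rand(3 * N, 80), torch.rand(3 * N, 64)
S_mean = torch.rand(3 * N, 1)
alpha_id, alpha_exp = torch.zeros(2, 80), torch.zeros(2, 64)  # batch of 2
S = (B_id @ alpha_id.T + B_exp @ alpha_exp.T).T + S_mean.reshape(1, -1)
print(S.reshape(2, -1, 3).shape)                              # torch.Size([2, 5, 3])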
114 | return face_texture.reshape([batch_size, -1, 3]) 115 | 116 | 117 | def compute_norm(self, face_shape): 118 | """ 119 | Return: 120 | vertex_norm -- torch.tensor, size (B, N, 3) 121 | 122 | Parameters: 123 | face_shape -- torch.tensor, size (B, N, 3) 124 | """ 125 | 126 | v1 = face_shape[:, self.face_buf[:, 0]] 127 | v2 = face_shape[:, self.face_buf[:, 1]] 128 | v3 = face_shape[:, self.face_buf[:, 2]] 129 | e1 = v1 - v2 130 | e2 = v2 - v3 131 | face_norm = torch.cross(e1, e2, dim=-1) 132 | face_norm = F.normalize(face_norm, dim=-1, p=2) 133 | face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1) 134 | 135 | vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2) 136 | vertex_norm = F.normalize(vertex_norm, dim=-1, p=2) 137 | return vertex_norm 138 | 139 | 140 | def compute_color(self, face_texture, face_norm, gamma): 141 | """ 142 | Return: 143 | face_color -- torch.tensor, size (B, N, 3), range (0, 1.) 144 | 145 | Parameters: 146 | face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.) 147 | face_norm -- torch.tensor, size (B, N, 3), rotated face normal 148 | gamma -- torch.tensor, size (B, 27), SH coeffs 149 | """ 150 | batch_size = gamma.shape[0] 151 | v_num = face_texture.shape[1] 152 | a, c = self.SH.a, self.SH.c 153 | gamma = gamma.reshape([batch_size, 3, 9]) 154 | gamma = gamma + self.init_lit 155 | gamma = gamma.permute(0, 2, 1) 156 | Y = torch.cat([ 157 | a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device), 158 | -a[1] * c[1] * face_norm[..., 1:2], 159 | a[1] * c[1] * face_norm[..., 2:], 160 | -a[1] * c[1] * face_norm[..., :1], 161 | a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2], 162 | -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:], 163 | 0.5 * a[2] * c[2] / np.sqrt(3.) 
* (3 * face_norm[..., 2:] ** 2 - 1), 164 | -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:], 165 | 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2) 166 | ], dim=-1) 167 | r = Y @ gamma[..., :1] 168 | g = Y @ gamma[..., 1:2] 169 | b = Y @ gamma[..., 2:] 170 | face_color = torch.cat([r, g, b], dim=-1) * face_texture 171 | return face_color 172 | 173 | 174 | def compute_rotation(self, angles): 175 | """ 176 | Return: 177 | rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat 178 | 179 | Parameters: 180 | angles -- torch.tensor, size (B, 3), radian 181 | """ 182 | 183 | batch_size = angles.shape[0] 184 | ones = torch.ones([batch_size, 1]).to(self.device) 185 | zeros = torch.zeros([batch_size, 1]).to(self.device) 186 | x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:], 187 | 188 | rot_x = torch.cat([ 189 | ones, zeros, zeros, 190 | zeros, torch.cos(x), -torch.sin(x), 191 | zeros, torch.sin(x), torch.cos(x) 192 | ], dim=1).reshape([batch_size, 3, 3]) 193 | 194 | rot_y = torch.cat([ 195 | torch.cos(y), zeros, torch.sin(y), 196 | zeros, ones, zeros, 197 | -torch.sin(y), zeros, torch.cos(y) 198 | ], dim=1).reshape([batch_size, 3, 3]) 199 | 200 | rot_z = torch.cat([ 201 | torch.cos(z), -torch.sin(z), zeros, 202 | torch.sin(z), torch.cos(z), zeros, 203 | zeros, zeros, ones 204 | ], dim=1).reshape([batch_size, 3, 3]) 205 | 206 | rot = rot_z @ rot_y @ rot_x 207 | return rot.permute(0, 2, 1) 208 | 209 | 210 | def to_camera(self, face_shape): 211 | face_shape[..., -1] = self.camera_distance - face_shape[..., -1] 212 | return face_shape 213 | 214 | def to_image(self, face_shape): 215 | """ 216 | Return: 217 | face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction 218 | 219 | Parameters: 220 | face_shape -- torch.tensor, size (B, N, 3) 221 | """ 222 | # to image_plane 223 | face_proj = face_shape @ self.persc_proj 224 | face_proj = face_proj[..., :2] / face_proj[..., 2:] 225 | 226 | return face_proj 227 | 228 | 229 | def transform(self, face_shape, rot, trans): 230 | """ 231 | Return: 232 | face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans 233 | 234 | Parameters: 235 | face_shape -- torch.tensor, size (B, N, 3) 236 | rot -- torch.tensor, size (B, 3, 3) 237 | trans -- torch.tensor, size (B, 3) 238 | """ 239 | return face_shape @ rot + trans.unsqueeze(1) 240 | 241 | 242 | def get_landmarks(self, face_proj): 243 | """ 244 | Return: 245 | face_lms -- torch.tensor, size (B, 68, 2) 246 | 247 | Parameters: 248 | face_proj -- torch.tensor, size (B, N, 2) 249 | """ 250 | return face_proj[:, self.keypoints] 251 | 252 | def split_coeff(self, coeffs): 253 | """ 254 | Return: 255 | coeffs_dict -- a dict of torch.tensors 256 | 257 | Parameters: 258 | coeffs -- torch.tensor, size (B, 256) 259 | """ 260 | id_coeffs = coeffs[:, :80] 261 | exp_coeffs = coeffs[:, 80: 144] 262 | tex_coeffs = coeffs[:, 144: 224] 263 | angles = coeffs[:, 224: 227] 264 | gammas = coeffs[:, 227: 254] 265 | translations = coeffs[:, 254:] 266 | return { 267 | 'id': id_coeffs, 268 | 'exp': exp_coeffs, 269 | 'tex': tex_coeffs, 270 | 'angle': angles, 271 | 'gamma': gammas, 272 | 'trans': translations 273 | } 274 | def compute_for_render(self, coeffs): 275 | """ 276 | Return: 277 | face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate 278 | face_color -- torch.tensor, size (B, N, 3), in RGB order 279 | landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction 280 | Parameters: 281 | coeffs -- torch.tensor, size (B, 257) 282 | """ 283 | coef_dict 
= self.split_coeff(coeffs) 284 | face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp']) 285 | rotation = self.compute_rotation(coef_dict['angle']) 286 | 287 | 288 | face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans']) 289 | face_vertex = self.to_camera(face_shape_transformed) 290 | 291 | face_proj = self.to_image(face_vertex) 292 | landmark = self.get_landmarks(face_proj) 293 | 294 | face_texture = self.compute_texture(coef_dict['tex']) 295 | face_norm = self.compute_norm(face_shape) 296 | face_norm_roted = face_norm @ rotation 297 | face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma']) 298 | 299 | return face_vertex, face_texture, face_color, landmark 300 | -------------------------------------------------------------------------------- /models/facerecon_model.py: -------------------------------------------------------------------------------- 1 | """This script defines the face reconstruction model for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import numpy as np 5 | import torch 6 | from .base_model import BaseModel 7 | from . import networks 8 | from .bfm import ParametricFaceModel 9 | from .losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss 10 | from util import util 11 | from util.nvdiffrast import MeshRenderer 12 | from util.preprocess import estimate_norm_torch 13 | 14 | import trimesh 15 | from scipy.io import savemat 16 | 17 | class FaceReconModel(BaseModel): 18 | 19 | @staticmethod 20 | def modify_commandline_options(parser, is_train=True): 21 | """ Configures options specific for CUT model 22 | """ 23 | # net structure and parameters 24 | parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure') 25 | parser.add_argument('--init_path', type=str, default='checkpoints/init_model/resnet50-0676ba61.pth') 26 | parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc') 27 | parser.add_argument('--bfm_folder', type=str, default='BFM') 28 | parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model') 29 | 30 | # renderer parameters 31 | parser.add_argument('--focal', type=float, default=1015.) 32 | parser.add_argument('--center', type=float, default=112.) 33 | parser.add_argument('--camera_d', type=float, default=10.) 34 | parser.add_argument('--z_near', type=float, default=5.) 35 | parser.add_argument('--z_far', type=float, default=15.) 
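# The renderer defaults above describe a narrow perspective camera: a 224x224
# crop (2 * center), principal point at 112 px, focal length 1015 px, and a
# face placed camera_d = 10 units away between z_near = 5 and z_far = 15.
# The field of view handed to MeshRenderer in __init__ follows directly:
import numpy as np
focal, center = 1015., 112.
fov = 2 * np.arctan(center / focal) * 180 / np.pi
print(round(fov, 2))   # ~12.59 degrees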
36 | parser.add_argument('--use_opengl', type=util.str2bool, nargs='?', const=True, default=True, help='use opengl context or not') 37 | 38 | if is_train: 39 | # training parameters 40 | parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r43', 'r50'], help='face recog network structure') 41 | parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth') 42 | parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss') 43 | parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face') 44 | 45 | 46 | # augmentation parameters 47 | parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels') 48 | parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor') 49 | parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree') 50 | 51 | # loss weights 52 | parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss') 53 | parser.add_argument('--w_color', type=float, default=1.92, help='weight for loss loss') 54 | parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss') 55 | parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss') 56 | parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss') 57 | parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss') 58 | parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss') 59 | parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss') 60 | parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss') 61 | 62 | 63 | 64 | opt, _ = parser.parse_known_args() 65 | parser.set_defaults( 66 | focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15. 67 | ) 68 | if is_train: 69 | parser.set_defaults( 70 | use_crop_face=True, use_predef_M=False 71 | ) 72 | return parser 73 | 74 | def __init__(self, opt): 75 | """Initialize this model class. 76 | 77 | Parameters: 78 | opt -- training/test options 79 | 80 | A few things can be done here. 
81 | - (required) call the initialization function of BaseModel 82 | - define loss function, visualization images, model names, and optimizers 83 | """ 84 | BaseModel.__init__(self, opt) # call the initialization method of BaseModel 85 | 86 | self.visual_names = ['output_vis'] 87 | self.model_names = ['net_recon'] 88 | self.parallel_names = self.model_names + ['renderer'] 89 | 90 | self.net_recon = networks.define_net_recon( 91 | net_recon=opt.net_recon, use_last_fc=opt.use_last_fc, init_path=opt.init_path 92 | ) 93 | 94 | self.facemodel = ParametricFaceModel( 95 | bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center, 96 | is_train=self.isTrain, default_name=opt.bfm_model 97 | ) 98 | 99 | fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi 100 | self.renderer = MeshRenderer( 101 | rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center), use_opengl=opt.use_opengl 102 | ) 103 | 104 | if self.isTrain: 105 | self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc'] 106 | 107 | self.net_recog = networks.define_net_recog( 108 | net_recog=opt.net_recog, pretrained_path=opt.net_recog_path 109 | ) 110 | # loss func name: (compute_%s_loss) % loss_name 111 | self.compute_feat_loss = perceptual_loss 112 | self.comupte_color_loss = photo_loss 113 | self.compute_lm_loss = landmark_loss 114 | self.compute_reg_loss = reg_loss 115 | self.compute_reflc_loss = reflectance_loss 116 | 117 | self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr) 118 | self.optimizers = [self.optimizer] 119 | self.parallel_names += ['net_recog'] 120 | # Our program will automatically call to define schedulers, load networks, and print networks 121 | 122 | def set_input(self, input): 123 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 124 | 125 | Parameters: 126 | input: a dictionary that contains the data itself and its metadata information. 
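# A hypothetical batch as consumed by set_input below: only 'imgs' is required,
# the remaining keys fall back to None when absent. Shapes are illustrative,
# following the 224x224 crops and 68-point landmarks used in this repo.
import torch
batch = {
    'imgs': torch.rand(4, 3, 224, 224),   # RGB images in [0, 1]
    'msks': torch.ones(4, 1, 224, 224),   # skin attention masks (channel count assumed)
    'lms': torch.rand(4, 68, 2) * 224,    # ground-truth landmarks in image coordinates
    'M': torch.rand(4, 2, 3),             # affine matrices for the recognition crop
    'im_paths': ['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg'],
}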
127 | """ 128 | self.input_img = input['imgs'].to(self.device) 129 | self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None 130 | self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None 131 | self.trans_m = input['M'].to(self.device) if 'M' in input else None 132 | self.image_paths = input['im_paths'] if 'im_paths' in input else None 133 | 134 | def forward(self): 135 | output_coeff = self.net_recon(self.input_img) 136 | self.facemodel.to(self.device) 137 | self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \ 138 | self.facemodel.compute_for_render(output_coeff) 139 | self.pred_mask, _, self.pred_face = self.renderer( 140 | self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color) 141 | 142 | self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff) 143 | 144 | 145 | def compute_losses(self): 146 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 147 | 148 | assert self.net_recog.training == False 149 | trans_m = self.trans_m 150 | if not self.opt.use_predef_M: 151 | trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2]) 152 | 153 | pred_feat = self.net_recog(self.pred_face, trans_m) 154 | gt_feat = self.net_recog(self.input_img, self.trans_m) 155 | self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat) 156 | 157 | face_mask = self.pred_mask 158 | if self.opt.use_crop_face: 159 | face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf) 160 | 161 | face_mask = face_mask.detach() 162 | self.loss_color = self.opt.w_color * self.comupte_color_loss( 163 | self.pred_face, self.input_img, self.atten_mask * face_mask) 164 | 165 | loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt) 166 | self.loss_reg = self.opt.w_reg * loss_reg 167 | self.loss_gamma = self.opt.w_gamma * loss_gamma 168 | 169 | self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm) 170 | 171 | self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask) 172 | 173 | self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \ 174 | + self.loss_lm + self.loss_reflc 175 | 176 | 177 | def optimize_parameters(self, isTrain=True): 178 | self.forward() 179 | self.compute_losses() 180 | """Update network weights; it will be called in every training iteration.""" 181 | if isTrain: 182 | self.optimizer.zero_grad() 183 | self.loss_all.backward() 184 | self.optimizer.step() 185 | 186 | def compute_visuals(self): 187 | with torch.no_grad(): 188 | input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy() 189 | output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img 190 | output_vis_numpy_raw = 255. 
* output_vis.detach().cpu().permute(0, 2, 3, 1).numpy() 191 | 192 | if self.gt_lm is not None: 193 | gt_lm_numpy = self.gt_lm.cpu().numpy() 194 | pred_lm_numpy = self.pred_lm.detach().cpu().numpy() 195 | output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b') 196 | output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r') 197 | 198 | output_vis_numpy = np.concatenate((input_img_numpy, 199 | output_vis_numpy_raw, output_vis_numpy), axis=-2) 200 | else: 201 | output_vis_numpy = np.concatenate((input_img_numpy, 202 | output_vis_numpy_raw), axis=-2) 203 | 204 | self.output_vis = torch.tensor( 205 | output_vis_numpy / 255., dtype=torch.float32 206 | ).permute(0, 3, 1, 2).to(self.device) 207 | 208 | def save_mesh(self, name): 209 | 210 | recon_shape = self.pred_vertex # get reconstructed shape 211 | recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space 212 | recon_shape = recon_shape.cpu().numpy()[0] 213 | recon_color = self.pred_color 214 | recon_color = recon_color.cpu().numpy()[0] 215 | tri = self.facemodel.face_buf.cpu().numpy() 216 | mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8), process=False) 217 | mesh.export(name) 218 | 219 | def save_coeff(self,name): 220 | 221 | pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict} 222 | pred_lm = self.pred_lm.cpu().numpy() 223 | pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate 224 | pred_coeffs['lm68'] = pred_lm 225 | savemat(name,pred_coeffs) 226 | 227 | 228 | 229 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from kornia.geometry import warp_affine 5 | import torch.nn.functional as F 6 | 7 | def resize_n_crop(image, M, dsize=112): 8 | # image: (b, c, h, w) 9 | # M : (b, 2, 3) 10 | return warp_affine(image, M, dsize=(dsize, dsize)) 11 | 12 | ### perceptual level loss 13 | class PerceptualLoss(nn.Module): 14 | def __init__(self, recog_net, input_size=112): 15 | super(PerceptualLoss, self).__init__() 16 | self.recog_net = recog_net 17 | self.preprocess = lambda x: 2 * x - 1 18 | self.input_size=input_size 19 | def forward(imageA, imageB, M): 20 | """ 21 | 1 - cosine distance 22 | Parameters: 23 | imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order 24 | imageB --same as imageA 25 | """ 26 | 27 | imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size)) 28 | imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size)) 29 | 30 | # freeze bn 31 | self.recog_net.eval() 32 | 33 | id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2) 34 | id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2) 35 | cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) 36 | # assert torch.sum((cosine_d > 1).float()) == 0 37 | return torch.sum(1 - cosine_d) / cosine_d.shape[0] 38 | 39 | def perceptual_loss(id_featureA, id_featureB): 40 | cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) 41 | # assert torch.sum((cosine_d > 1).float()) == 0 42 | return torch.sum(1 - cosine_d) / cosine_d.shape[0] 43 | 44 | ### image level loss 45 | def photo_loss(imageA, imageB, mask, eps=1e-6): 46 | """ 47 | l2 norm (with sqrt, to ensure backward stabililty, use eps, otherwise Nan may occur) 
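# A toy call of the masked photometric loss defined here: per pixel it is
# sqrt(eps + sum_c (A - B)^2), kept where the mask is 1 and averaged over the
# number of masked pixels (the tensors below are random placeholders).
import torch
A, B = torch.rand(2, 3, 8, 8), torch.rand(2, 3, 8, 8)
mask = torch.zeros(2, 1, 8, 8)
mask[..., 2:6, 2:6] = 1.            # keep only a central patch
print(photo_loss(A, B, mask))       # scalar tensor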
48 | Parameters: 49 | imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order 50 | imageB --same as imageA 51 | """ 52 | loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask 53 | loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device)) 54 | return loss 55 | 56 | def landmark_loss(predict_lm, gt_lm, weight=None): 57 | """ 58 | weighted mse loss 59 | Parameters: 60 | predict_lm --torch.tensor (B, 68, 2) 61 | gt_lm --torch.tensor (B, 68, 2) 62 | weight --numpy.array (1, 68) 63 | """ 64 | if not weight: 65 | weight = np.ones([68]) 66 | weight[28:31] = 20 67 | weight[-8:] = 20 68 | weight = np.expand_dims(weight, 0) 69 | weight = torch.tensor(weight).to(predict_lm.device) 70 | loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight 71 | loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1]) 72 | return loss 73 | 74 | 75 | ### regulization 76 | def reg_loss(coeffs_dict, opt=None): 77 | """ 78 | l2 norm without the sqrt, from yu's implementation (mse) 79 | tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss 80 | Parameters: 81 | coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans 82 | 83 | """ 84 | # coefficient regularization to ensure plausible 3d faces 85 | if opt: 86 | w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex 87 | else: 88 | w_id, w_exp, w_tex = 1, 1, 1, 1 89 | creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \ 90 | w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \ 91 | w_tex * torch.sum(coeffs_dict['tex'] ** 2) 92 | creg_loss = creg_loss / coeffs_dict['id'].shape[0] 93 | 94 | # gamma regularization to ensure a nearly-monochromatic light 95 | gamma = coeffs_dict['gamma'].reshape([-1, 3, 9]) 96 | gamma_mean = torch.mean(gamma, dim=1, keepdims=True) 97 | gamma_loss = torch.mean((gamma - gamma_mean) ** 2) 98 | 99 | return creg_loss, gamma_loss 100 | 101 | def reflectance_loss(texture, mask): 102 | """ 103 | minimize texture variance (mse), albedo regularization to ensure an uniform skin albedo 104 | Parameters: 105 | texture --torch.tensor, (B, N, 3) 106 | mask --torch.tensor, (N), 1 or 0 107 | 108 | """ 109 | mask = mask.reshape([1, mask.shape[0], 1]) 110 | texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask) 111 | loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask)) 112 | return loss 113 | 114 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | """This script defines deep neural networks for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import os 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from torch.nn import init 8 | import functools 9 | from torch.optim import lr_scheduler 10 | import torch 11 | from torch import Tensor 12 | import torch.nn as nn 13 | try: 14 | from torch.hub import load_state_dict_from_url 15 | except ImportError: 16 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 17 | from typing import Type, Any, Callable, Union, List, Optional 18 | from .arcface_torch.backbones import get_model 19 | from kornia.geometry import warp_affine 20 | 21 | def resize_n_crop(image, M, dsize=112): 22 | # image: (b, c, h, w) 23 | # M : (b, 2, 3) 24 | return warp_affine(image, M, dsize=(dsize, dsize)) 25 | 26 | def filter_state_dict(state_dict, remove_name='fc'): 27 | new_state_dict = {} 28 | for key in 
state_dict: 29 | if remove_name in key: 30 | continue 31 | new_state_dict[key] = state_dict[key] 32 | return new_state_dict 33 | 34 | def get_scheduler(optimizer, opt): 35 | """Return a learning rate scheduler 36 | 37 | Parameters: 38 | optimizer -- the optimizer of the network 39 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  40 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 41 | 42 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 43 | See https://pytorch.org/docs/stable/optim.html for more details. 44 | """ 45 | if opt.lr_policy == 'linear': 46 | def lambda_rule(epoch): 47 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1) 48 | return lr_l 49 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 50 | elif opt.lr_policy == 'step': 51 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2) 52 | elif opt.lr_policy == 'plateau': 53 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 54 | elif opt.lr_policy == 'cosine': 55 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) 56 | else: 57 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 58 | return scheduler 59 | 60 | 61 | def define_net_recon(net_recon, use_last_fc=False, init_path=None): 62 | return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path) 63 | 64 | def define_net_recog(net_recog, pretrained_path=None): 65 | net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path) 66 | net.eval() 67 | return net 68 | 69 | class ReconNetWrapper(nn.Module): 70 | fc_dim=257 71 | def __init__(self, net_recon, use_last_fc=False, init_path=None): 72 | super(ReconNetWrapper, self).__init__() 73 | self.use_last_fc = use_last_fc 74 | if net_recon not in func_dict: 75 | return NotImplementedError('network [%s] is not implemented', net_recon) 76 | func, last_dim = func_dict[net_recon] 77 | backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim) 78 | if init_path and os.path.isfile(init_path): 79 | state_dict = filter_state_dict(torch.load(init_path, map_location='cpu')) 80 | backbone.load_state_dict(state_dict) 81 | print("loading init net_recon %s from %s" %(net_recon, init_path)) 82 | self.backbone = backbone 83 | if not use_last_fc: 84 | self.final_layers = nn.ModuleList([ 85 | conv1x1(last_dim, 80, bias=True), # id layer 86 | conv1x1(last_dim, 64, bias=True), # exp layer 87 | conv1x1(last_dim, 80, bias=True), # tex layer 88 | conv1x1(last_dim, 3, bias=True), # angle layer 89 | conv1x1(last_dim, 27, bias=True), # gamma layer 90 | conv1x1(last_dim, 2, bias=True), # tx, ty 91 | conv1x1(last_dim, 1, bias=True) # tz 92 | ]) 93 | for m in self.final_layers: 94 | nn.init.constant_(m.weight, 0.) 95 | nn.init.constant_(m.bias, 0.) 
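# After concatenation and flattening, the seven 1x1 heads above emit the
# 257-dim coefficient vector that ParametricFaceModel.split_coeff slices up:
#   [0:80]    identity (80)   [80:144]  expression (64)   [144:224] texture (80)
#   [224:227] angles (3)      [227:254] SH gamma (27)     [254:257] translation (3)
# (translation comes from the separate tx,ty and tz heads). Quick sanity check:
dims = [80, 64, 80, 3, 27, 2, 1]
assert sum(dims) == 257 == ReconNetWrapper.fc_dim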
96 | 97 | def forward(self, x): 98 | x = self.backbone(x) 99 | if not self.use_last_fc: 100 | output = [] 101 | for layer in self.final_layers: 102 | output.append(layer(x)) 103 | x = torch.flatten(torch.cat(output, dim=1), 1) 104 | return x 105 | 106 | 107 | class RecogNetWrapper(nn.Module): 108 | def __init__(self, net_recog, pretrained_path=None, input_size=112): 109 | super(RecogNetWrapper, self).__init__() 110 | net = get_model(name=net_recog, fp16=False) 111 | if pretrained_path: 112 | state_dict = torch.load(pretrained_path, map_location='cpu') 113 | net.load_state_dict(state_dict) 114 | print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path)) 115 | for param in net.parameters(): 116 | param.requires_grad = False 117 | self.net = net 118 | self.preprocess = lambda x: 2 * x - 1 119 | self.input_size=input_size 120 | 121 | def forward(self, image, M): 122 | image = self.preprocess(resize_n_crop(image, M, self.input_size)) 123 | id_feature = F.normalize(self.net(image), dim=-1, p=2) 124 | return id_feature 125 | 126 | 127 | # adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py 128 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 129 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 130 | 'wide_resnet50_2', 'wide_resnet101_2'] 131 | 132 | 133 | model_urls = { 134 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', 135 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', 136 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', 137 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', 138 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', 139 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 140 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 141 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 142 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 143 | } 144 | 145 | 146 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 147 | """3x3 convolution with padding""" 148 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 149 | padding=dilation, groups=groups, bias=False, dilation=dilation) 150 | 151 | 152 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d: 153 | """1x1 convolution""" 154 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias) 155 | 156 | 157 | class BasicBlock(nn.Module): 158 | expansion: int = 1 159 | 160 | def __init__( 161 | self, 162 | inplanes: int, 163 | planes: int, 164 | stride: int = 1, 165 | downsample: Optional[nn.Module] = None, 166 | groups: int = 1, 167 | base_width: int = 64, 168 | dilation: int = 1, 169 | norm_layer: Optional[Callable[..., nn.Module]] = None 170 | ) -> None: 171 | super(BasicBlock, self).__init__() 172 | if norm_layer is None: 173 | norm_layer = nn.BatchNorm2d 174 | if groups != 1 or base_width != 64: 175 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 176 | if dilation > 1: 177 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 178 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 179 | self.conv1 = conv3x3(inplanes, planes, stride) 180 | self.bn1 = 
norm_layer(planes) 181 | self.relu = nn.ReLU(inplace=True) 182 | self.conv2 = conv3x3(planes, planes) 183 | self.bn2 = norm_layer(planes) 184 | self.downsample = downsample 185 | self.stride = stride 186 | 187 | def forward(self, x: Tensor) -> Tensor: 188 | identity = x 189 | 190 | out = self.conv1(x) 191 | out = self.bn1(out) 192 | out = self.relu(out) 193 | 194 | out = self.conv2(out) 195 | out = self.bn2(out) 196 | 197 | if self.downsample is not None: 198 | identity = self.downsample(x) 199 | 200 | out += identity 201 | out = self.relu(out) 202 | 203 | return out 204 | 205 | 206 | class Bottleneck(nn.Module): 207 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 208 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 209 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 210 | # This variant is also known as ResNet V1.5 and improves accuracy according to 211 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 212 | 213 | expansion: int = 4 214 | 215 | def __init__( 216 | self, 217 | inplanes: int, 218 | planes: int, 219 | stride: int = 1, 220 | downsample: Optional[nn.Module] = None, 221 | groups: int = 1, 222 | base_width: int = 64, 223 | dilation: int = 1, 224 | norm_layer: Optional[Callable[..., nn.Module]] = None 225 | ) -> None: 226 | super(Bottleneck, self).__init__() 227 | if norm_layer is None: 228 | norm_layer = nn.BatchNorm2d 229 | width = int(planes * (base_width / 64.)) * groups 230 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 231 | self.conv1 = conv1x1(inplanes, width) 232 | self.bn1 = norm_layer(width) 233 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 234 | self.bn2 = norm_layer(width) 235 | self.conv3 = conv1x1(width, planes * self.expansion) 236 | self.bn3 = norm_layer(planes * self.expansion) 237 | self.relu = nn.ReLU(inplace=True) 238 | self.downsample = downsample 239 | self.stride = stride 240 | 241 | def forward(self, x: Tensor) -> Tensor: 242 | identity = x 243 | 244 | out = self.conv1(x) 245 | out = self.bn1(out) 246 | out = self.relu(out) 247 | 248 | out = self.conv2(out) 249 | out = self.bn2(out) 250 | out = self.relu(out) 251 | 252 | out = self.conv3(out) 253 | out = self.bn3(out) 254 | 255 | if self.downsample is not None: 256 | identity = self.downsample(x) 257 | 258 | out += identity 259 | out = self.relu(out) 260 | 261 | return out 262 | 263 | 264 | class ResNet(nn.Module): 265 | 266 | def __init__( 267 | self, 268 | block: Type[Union[BasicBlock, Bottleneck]], 269 | layers: List[int], 270 | num_classes: int = 1000, 271 | zero_init_residual: bool = False, 272 | use_last_fc: bool = False, 273 | groups: int = 1, 274 | width_per_group: int = 64, 275 | replace_stride_with_dilation: Optional[List[bool]] = None, 276 | norm_layer: Optional[Callable[..., nn.Module]] = None 277 | ) -> None: 278 | super(ResNet, self).__init__() 279 | if norm_layer is None: 280 | norm_layer = nn.BatchNorm2d 281 | self._norm_layer = norm_layer 282 | 283 | self.inplanes = 64 284 | self.dilation = 1 285 | if replace_stride_with_dilation is None: 286 | # each element in the tuple indicates if we should replace 287 | # the 2x2 stride with a dilated convolution instead 288 | replace_stride_with_dilation = [False, False, False] 289 | if len(replace_stride_with_dilation) != 3: 290 | raise ValueError("replace_stride_with_dilation should be None " 291 | "or a 3-element tuple, 
got {}".format(replace_stride_with_dilation)) 292 | self.use_last_fc = use_last_fc 293 | self.groups = groups 294 | self.base_width = width_per_group 295 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 296 | bias=False) 297 | self.bn1 = norm_layer(self.inplanes) 298 | self.relu = nn.ReLU(inplace=True) 299 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 300 | self.layer1 = self._make_layer(block, 64, layers[0]) 301 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 302 | dilate=replace_stride_with_dilation[0]) 303 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 304 | dilate=replace_stride_with_dilation[1]) 305 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 306 | dilate=replace_stride_with_dilation[2]) 307 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 308 | 309 | if self.use_last_fc: 310 | self.fc = nn.Linear(512 * block.expansion, num_classes) 311 | 312 | for m in self.modules(): 313 | if isinstance(m, nn.Conv2d): 314 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 315 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 316 | nn.init.constant_(m.weight, 1) 317 | nn.init.constant_(m.bias, 0) 318 | 319 | 320 | 321 | # Zero-initialize the last BN in each residual branch, 322 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 323 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 324 | if zero_init_residual: 325 | for m in self.modules(): 326 | if isinstance(m, Bottleneck): 327 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 328 | elif isinstance(m, BasicBlock): 329 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 330 | 331 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 332 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 333 | norm_layer = self._norm_layer 334 | downsample = None 335 | previous_dilation = self.dilation 336 | if dilate: 337 | self.dilation *= stride 338 | stride = 1 339 | if stride != 1 or self.inplanes != planes * block.expansion: 340 | downsample = nn.Sequential( 341 | conv1x1(self.inplanes, planes * block.expansion, stride), 342 | norm_layer(planes * block.expansion), 343 | ) 344 | 345 | layers = [] 346 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 347 | self.base_width, previous_dilation, norm_layer)) 348 | self.inplanes = planes * block.expansion 349 | for _ in range(1, blocks): 350 | layers.append(block(self.inplanes, planes, groups=self.groups, 351 | base_width=self.base_width, dilation=self.dilation, 352 | norm_layer=norm_layer)) 353 | 354 | return nn.Sequential(*layers) 355 | 356 | def _forward_impl(self, x: Tensor) -> Tensor: 357 | # See note [TorchScript super()] 358 | x = self.conv1(x) 359 | x = self.bn1(x) 360 | x = self.relu(x) 361 | x = self.maxpool(x) 362 | 363 | x = self.layer1(x) 364 | x = self.layer2(x) 365 | x = self.layer3(x) 366 | x = self.layer4(x) 367 | 368 | x = self.avgpool(x) 369 | if self.use_last_fc: 370 | x = torch.flatten(x, 1) 371 | x = self.fc(x) 372 | return x 373 | 374 | def forward(self, x: Tensor) -> Tensor: 375 | return self._forward_impl(x) 376 | 377 | 378 | def _resnet( 379 | arch: str, 380 | block: Type[Union[BasicBlock, Bottleneck]], 381 | layers: List[int], 382 | pretrained: bool, 383 | progress: bool, 384 | **kwargs: Any 385 | ) -> ResNet: 386 | model = ResNet(block, layers, **kwargs) 387 | if pretrained: 388 | 
state_dict = load_state_dict_from_url(model_urls[arch], 389 | progress=progress) 390 | model.load_state_dict(state_dict) 391 | return model 392 | 393 | 394 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 395 | r"""ResNet-18 model from 396 | `"Deep Residual Learning for Image Recognition" `_. 397 | 398 | Args: 399 | pretrained (bool): If True, returns a model pre-trained on ImageNet 400 | progress (bool): If True, displays a progress bar of the download to stderr 401 | """ 402 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 403 | **kwargs) 404 | 405 | 406 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 407 | r"""ResNet-34 model from 408 | `"Deep Residual Learning for Image Recognition" `_. 409 | 410 | Args: 411 | pretrained (bool): If True, returns a model pre-trained on ImageNet 412 | progress (bool): If True, displays a progress bar of the download to stderr 413 | """ 414 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 415 | **kwargs) 416 | 417 | 418 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 419 | r"""ResNet-50 model from 420 | `"Deep Residual Learning for Image Recognition" `_. 421 | 422 | Args: 423 | pretrained (bool): If True, returns a model pre-trained on ImageNet 424 | progress (bool): If True, displays a progress bar of the download to stderr 425 | """ 426 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 427 | **kwargs) 428 | 429 | 430 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 431 | r"""ResNet-101 model from 432 | `"Deep Residual Learning for Image Recognition" `_. 433 | 434 | Args: 435 | pretrained (bool): If True, returns a model pre-trained on ImageNet 436 | progress (bool): If True, displays a progress bar of the download to stderr 437 | """ 438 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 439 | **kwargs) 440 | 441 | 442 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 443 | r"""ResNet-152 model from 444 | `"Deep Residual Learning for Image Recognition" `_. 445 | 446 | Args: 447 | pretrained (bool): If True, returns a model pre-trained on ImageNet 448 | progress (bool): If True, displays a progress bar of the download to stderr 449 | """ 450 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 451 | **kwargs) 452 | 453 | 454 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 455 | r"""ResNeXt-50 32x4d model from 456 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 457 | 458 | Args: 459 | pretrained (bool): If True, returns a model pre-trained on ImageNet 460 | progress (bool): If True, displays a progress bar of the download to stderr 461 | """ 462 | kwargs['groups'] = 32 463 | kwargs['width_per_group'] = 4 464 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 465 | pretrained, progress, **kwargs) 466 | 467 | 468 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 469 | r"""ResNeXt-101 32x8d model from 470 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 
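# Only resnet18 and resnet50 are registered in func_dict at the bottom of this
# file, so '--net_recon resnet50' resolves to (resnet50, last_dim=2048); the
# ImageNet fc weights are dropped by filter_state_dict before loading. A sketch
# of building the reconstruction net by hand (the checkpoint path is simply the
# repo's default --init_path and is skipped if the file is missing):
import torch
net = define_net_recon('resnet50', use_last_fc=False,
                       init_path='checkpoints/init_model/resnet50-0676ba61.pth')
print(net(torch.rand(1, 3, 224, 224)).shape)   # torch.Size([1, 257])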
471 | 472 | Args: 473 | pretrained (bool): If True, returns a model pre-trained on ImageNet 474 | progress (bool): If True, displays a progress bar of the download to stderr 475 | """ 476 | kwargs['groups'] = 32 477 | kwargs['width_per_group'] = 8 478 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 479 | pretrained, progress, **kwargs) 480 | 481 | 482 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 483 | r"""Wide ResNet-50-2 model from 484 | `"Wide Residual Networks" `_. 485 | 486 | The model is the same as ResNet except for the bottleneck number of channels 487 | which is twice larger in every block. The number of channels in outer 1x1 488 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 489 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 490 | 491 | Args: 492 | pretrained (bool): If True, returns a model pre-trained on ImageNet 493 | progress (bool): If True, displays a progress bar of the download to stderr 494 | """ 495 | kwargs['width_per_group'] = 64 * 2 496 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 497 | pretrained, progress, **kwargs) 498 | 499 | 500 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 501 | r"""Wide ResNet-101-2 model from 502 | `"Wide Residual Networks" `_. 503 | 504 | The model is the same as ResNet except for the bottleneck number of channels 505 | which is twice larger in every block. The number of channels in outer 1x1 506 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 507 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 508 | 509 | Args: 510 | pretrained (bool): If True, returns a model pre-trained on ImageNet 511 | progress (bool): If True, displays a progress bar of the download to stderr 512 | """ 513 | kwargs['width_per_group'] = 64 * 2 514 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 515 | pretrained, progress, **kwargs) 516 | 517 | 518 | func_dict = { 519 | 'resnet18': (resnet18, 512), 520 | 'resnet50': (resnet50, 2048) 521 | } 522 | -------------------------------------------------------------------------------- /models/template_model.py: -------------------------------------------------------------------------------- 1 | """Model class template 2 | 3 | This module provides a template for users to implement custom models. 4 | You can specify '--model template' to use this model. 5 | The class name should be consistent with both the filename and its model option. 6 | The filename should be _dataset.py 7 | The class name should be Dataset.py 8 | It implements a simple image-to-image translation baseline based on regression loss. 9 | Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss: 10 | min_ ||netG(data_A) - data_B||_1 11 | You need to implement the following functions: 12 | : Add model-specific options and rewrite default values for existing options. 13 | <__init__>: Initialize this model class. 14 | : Unpack input data and perform data pre-processing. 15 | : Run forward pass. This will be called by both and . 16 | : Update network weights; it will be called in every training iteration. 17 | """ 18 | import numpy as np 19 | import torch 20 | from .base_model import BaseModel 21 | from . 
import networks 22 | 23 | 24 | class TemplateModel(BaseModel): 25 | @staticmethod 26 | def modify_commandline_options(parser, is_train=True): 27 | """Add new model-specific options and rewrite default values for existing options. 28 | 29 | Parameters: 30 | parser -- the option parser 31 | is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options. 32 | 33 | Returns: 34 | the modified parser. 35 | """ 36 | parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset. 37 | if is_train: 38 | parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model. 39 | 40 | return parser 41 | 42 | def __init__(self, opt): 43 | """Initialize this model class. 44 | 45 | Parameters: 46 | opt -- training/test options 47 | 48 | A few things can be done here. 49 | - (required) call the initialization function of BaseModel 50 | - define loss function, visualization images, model names, and optimizers 51 | """ 52 | BaseModel.__init__(self, opt) # call the initialization method of BaseModel 53 | # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk. 54 | self.loss_names = ['loss_G'] 55 | # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images. 56 | self.visual_names = ['data_A', 'data_B', 'output'] 57 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks. 58 | # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them. 59 | self.model_names = ['G'] 60 | # define networks; you can use opt.isTrain to specify different behaviors for training and test. 61 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids) 62 | if self.isTrain: # only defined during training time 63 | # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. 64 | # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device) 65 | self.criterionLoss = torch.nn.L1Loss() 66 | # define and initialize optimizers. You can define one optimizer for each network. 67 | # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 68 | self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 69 | self.optimizers = [self.optimizer] 70 | 71 | # Our program will automatically call to define schedulers, load networks, and print networks 72 | 73 | def set_input(self, input): 74 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 75 | 76 | Parameters: 77 | input: a dictionary that contains the data itself and its metadata information. 
78 | """ 79 | AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B 80 | self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A 81 | self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B 82 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths 83 | 84 | def forward(self): 85 | """Run forward pass. This will be called by both functions and .""" 86 | self.output = self.netG(self.data_A) # generate output image given the input data_A 87 | 88 | def backward(self): 89 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 90 | # caculate the intermediate results if necessary; here self.output has been computed during function 91 | # calculate loss given the input and intermediate results 92 | self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression 93 | self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G 94 | 95 | def optimize_parameters(self): 96 | """Update network weights; it will be called in every training iteration.""" 97 | self.forward() # first call forward to calculate intermediate results 98 | self.optimizer.zero_grad() # clear network G's existing gradients 99 | self.backward() # calculate gradients for network G 100 | self.optimizer.step() # update gradients for network G 101 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test).""" 2 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | """This script contains base options for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import argparse 5 | import os 6 | from util import util 7 | import numpy as np 8 | import torch 9 | import models 10 | import data 11 | 12 | 13 | class BaseOptions(): 14 | """This class defines options used during both training and test time. 15 | 16 | It also implements several helper functions such as parsing, printing, and saving the options. 17 | It also gathers additional options defined in functions in both dataset class and model class. 18 | """ 19 | 20 | def __init__(self, cmd_line=None): 21 | """Reset the class; indicates the class hasn't been initailized""" 22 | self.initialized = False 23 | self.cmd_line = None 24 | if cmd_line is not None: 25 | self.cmd_line = cmd_line.split() 26 | 27 | def initialize(self, parser): 28 | """Define the common options that are used in both training and test.""" 29 | # basic parameters 30 | parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models') 31 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. 
use -1 for CPU') 32 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 33 | parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visulization') 34 | parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation') 35 | parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel') 36 | parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port') 37 | parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses') 38 | parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard') 39 | parser.add_argument('--world_size', type=int, default=1, help='batch nums of images for evaluation') 40 | 41 | # model parameters 42 | parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.') 43 | 44 | # additional parameters 45 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 46 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 47 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') 48 | 49 | self.initialized = True 50 | return parser 51 | 52 | def gather_options(self): 53 | """Initialize our parser with basic options(only once). 54 | Add additional model-specific and dataset-specific options. 55 | These options are defined in the function 56 | in model and dataset classes. 57 | """ 58 | if not self.initialized: # check if it has been initialized 59 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 60 | parser = self.initialize(parser) 61 | 62 | # get the basic options 63 | if self.cmd_line is None: 64 | opt, _ = parser.parse_known_args() 65 | else: 66 | opt, _ = parser.parse_known_args(self.cmd_line) 67 | 68 | # set cuda visible devices 69 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids 70 | 71 | # modify model-related parser options 72 | model_name = opt.model 73 | model_option_setter = models.get_option_setter(model_name) 74 | parser = model_option_setter(parser, self.isTrain) 75 | if self.cmd_line is None: 76 | opt, _ = parser.parse_known_args() # parse again with new defaults 77 | else: 78 | opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults 79 | 80 | # modify dataset-related parser options 81 | if opt.dataset_mode: 82 | dataset_name = opt.dataset_mode 83 | dataset_option_setter = data.get_option_setter(dataset_name) 84 | parser = dataset_option_setter(parser, self.isTrain) 85 | 86 | # save and return the parser 87 | self.parser = parser 88 | if self.cmd_line is None: 89 | return parser.parse_args() 90 | else: 91 | return parser.parse_args(self.cmd_line) 92 | 93 | def print_options(self, opt): 94 | """Print and save options 95 | 96 | It will print both current options and default values(if different). 
97 | It will save options into a text file / [checkpoints_dir] / opt.txt 98 | """ 99 | message = '' 100 | message += '----------------- Options ---------------\n' 101 | for k, v in sorted(vars(opt).items()): 102 | comment = '' 103 | default = self.parser.get_default(k) 104 | if v != default: 105 | comment = '\t[default: %s]' % str(default) 106 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 107 | message += '----------------- End -------------------' 108 | print(message) 109 | 110 | # save to the disk 111 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 112 | util.mkdirs(expr_dir) 113 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) 114 | try: 115 | with open(file_name, 'wt') as opt_file: 116 | opt_file.write(message) 117 | opt_file.write('\n') 118 | except PermissionError as error: 119 | print("permission error {}".format(error)) 120 | pass 121 | 122 | def parse(self): 123 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 124 | opt = self.gather_options() 125 | opt.isTrain = self.isTrain # train or test 126 | 127 | # process opt.suffix 128 | if opt.suffix: 129 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 130 | opt.name = opt.name + suffix 131 | 132 | 133 | # set gpu ids 134 | str_ids = opt.gpu_ids.split(',') 135 | gpu_ids = [] 136 | for str_id in str_ids: 137 | id = int(str_id) 138 | if id >= 0: 139 | gpu_ids.append(id) 140 | opt.world_size = len(gpu_ids) 141 | # if len(opt.gpu_ids) > 0: 142 | # torch.cuda.set_device(gpu_ids[0]) 143 | if opt.world_size == 1: 144 | opt.use_ddp = False 145 | 146 | if opt.phase != 'test': 147 | # set continue_train automatically 148 | if opt.pretrained_name is None: 149 | model_dir = os.path.join(opt.checkpoints_dir, opt.name) 150 | else: 151 | model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name) 152 | if os.path.isdir(model_dir): 153 | model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')] 154 | if os.path.isdir(model_dir) and len(model_pths) != 0: 155 | opt.continue_train= True 156 | 157 | # update the latest epoch count 158 | if opt.continue_train: 159 | if opt.epoch == 'latest': 160 | epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i] 161 | if len(epoch_counts) != 0: 162 | opt.epoch_count = max(epoch_counts) + 1 163 | else: 164 | opt.epoch_count = int(opt.epoch) + 1 165 | 166 | 167 | self.print_options(opt) 168 | self.opt = opt 169 | return self.opt 170 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | """This script contains the test options for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | from .base_options import BaseOptions 5 | 6 | 7 | class TestOptions(BaseOptions): 8 | """This class includes test options. 9 | 10 | It also includes shared options defined in BaseOptions. 11 | """ 12 | 13 | def initialize(self, parser): 14 | parser = BaseOptions.initialize(self, parser) # define shared options 15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 16 | parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') 17 | parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.') 18 | 19 | # Dropout and Batchnorm has different behavior during training and test. 
20 | self.isTrain = False 21 | return parser 22 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | """This script contains the training options for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | from .base_options import BaseOptions 5 | from util import util 6 | 7 | class TrainOptions(BaseOptions): 8 | """This class includes training options. 9 | 10 | It also includes shared options defined in BaseOptions. 11 | """ 12 | 13 | def initialize(self, parser): 14 | parser = BaseOptions.initialize(self, parser) 15 | # dataset parameters 16 | # for train 17 | parser.add_argument('--data_root', type=str, default='./', help='dataset root') 18 | parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set') 19 | parser.add_argument('--batch_size', type=int, default=32) 20 | parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]') 21 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 22 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 23 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 24 | parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]') 25 | parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation') 26 | 27 | # for val 28 | parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set') 29 | parser.add_argument('--batch_size_val', type=int, default=32) 30 | 31 | 32 | # visualization parameters 33 | parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen') 34 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 35 | 36 | # network saving and loading parameters 37 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 38 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') 39 | parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq') 40 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') 41 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 42 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 43 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 44 | parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') 45 | 46 | # training parameters 47 | parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate') 48 | parser.add_argument('--lr', type=float, default=0.0001, 
help='initial learning rate for adam') 49 | parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]') 50 | parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches') 51 | 52 | self.isTrain = True 53 | return parser 54 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """This script is the test script for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import os 5 | from options.test_options import TestOptions 6 | from data import create_dataset 7 | from models import create_model 8 | from util.visualizer import MyVisualizer 9 | from util.preprocess import align_img 10 | from PIL import Image 11 | import numpy as np 12 | from util.load_mats import load_lm3d 13 | import torch 14 | from data.flist_dataset import default_flist_reader 15 | from scipy.io import loadmat, savemat 16 | 17 | def get_data_path(root='examples'): 18 | 19 | im_path = [os.path.join(root, i) for i in sorted(os.listdir(root)) if i.endswith('png') or i.endswith('jpg')] 20 | lm_path = [i.replace('png', 'txt').replace('jpg', 'txt') for i in im_path] 21 | lm_path = [os.path.join(i.replace(i.split(os.path.sep)[-1],''),'detections',i.split(os.path.sep)[-1]) for i in lm_path] 22 | 23 | return im_path, lm_path 24 | 25 | def read_data(im_path, lm_path, lm3d_std, to_tensor=True): 26 | # to RGB 27 | im = Image.open(im_path).convert('RGB') 28 | W,H = im.size 29 | lm = np.loadtxt(lm_path).astype(np.float32) 30 | lm = lm.reshape([-1, 2]) 31 | lm[:, -1] = H - 1 - lm[:, -1] 32 | _, im, lm, _ = align_img(im, lm, lm3d_std) 33 | if to_tensor: 34 | im = torch.tensor(np.array(im)/255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) 35 | lm = torch.tensor(lm).unsqueeze(0) 36 | return im, lm 37 | 38 | def main(rank, opt, name='examples'): 39 | device = torch.device(rank) 40 | torch.cuda.set_device(device) 41 | model = create_model(opt) 42 | model.setup(opt) 43 | model.device = device 44 | model.parallelize() 45 | model.eval() 46 | visualizer = MyVisualizer(opt) 47 | 48 | im_path, lm_path = get_data_path(name) 49 | lm3d_std = load_lm3d(opt.bfm_folder) 50 | 51 | for i in range(len(im_path)): 52 | print(i, im_path[i]) 53 | img_name = im_path[i].split(os.path.sep)[-1].replace('.png','').replace('.jpg','') 54 | if not os.path.isfile(lm_path[i]): 55 | print("%s is not found !!!"%lm_path[i]) 56 | continue 57 | im_tensor, lm_tensor = read_data(im_path[i], lm_path[i], lm3d_std) 58 | data = { 59 | 'imgs': im_tensor, 60 | 'lms': lm_tensor 61 | } 62 | model.set_input(data) # unpack data from data loader 63 | model.test() # run inference 64 | visuals = model.get_current_visuals() # get image results 65 | visualizer.display_current_results(visuals, 0, opt.epoch, dataset=name.split(os.path.sep)[-1], 66 | save_results=True, count=i, name=img_name, add_image=False) 67 | 68 | model.save_mesh(os.path.join(visualizer.img_dir, name.split(os.path.sep)[-1], 'epoch_%s_%06d'%(opt.epoch, 0),img_name+'.obj')) # save reconstruction meshes 69 | model.save_coeff(os.path.join(visualizer.img_dir, name.split(os.path.sep)[-1], 'epoch_%s_%06d'%(opt.epoch, 0),img_name+'.mat')) # save predicted coefficients 70 | 71 | if __name__ == '__main__': 72 | opt = TestOptions().parse() # get test options 73 | main(0, opt,opt.img_folder) 74 | 75 | -------------------------------------------------------------------------------- /train.py: 
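A note on inputs for test.py above: get_data_path/read_data assume that every image in --img_folder has a matching five-point landmark file with the same base name under a detections/ sub-folder. A minimal sketch of writing one such file is below; the image name and coordinate values are purely illustrative:

import os
import numpy as np

img_folder = 'datasets/examples'   # hypothetical --img_folder
os.makedirs(os.path.join(img_folder, 'detections'), exist_ok=True)

# five (x, y) pixel coordinates, typically ordered: left eye, right eye, nose tip, left/right mouth corner
five_points = np.array([[70., 112.], [110., 110.], [91., 135.], [74., 160.], [108., 158.]])
np.savetxt(os.path.join(img_folder, 'detections', '000002.txt'), five_points)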
-------------------------------------------------------------------------------- 1 | """This script is the training script for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import os 5 | import time 6 | import numpy as np 7 | import torch 8 | from options.train_options import TrainOptions 9 | from data import create_dataset 10 | from models import create_model 11 | from util.visualizer import MyVisualizer 12 | from util.util import genvalconf 13 | import torch.multiprocessing as mp 14 | import torch.distributed as dist 15 | 16 | 17 | def setup(rank, world_size, port): 18 | os.environ['MASTER_ADDR'] = 'localhost' 19 | os.environ['MASTER_PORT'] = port 20 | 21 | # initialize the process group 22 | dist.init_process_group("gloo", rank=rank, world_size=world_size) 23 | 24 | def cleanup(): 25 | dist.destroy_process_group() 26 | 27 | def main(rank, world_size, train_opt): 28 | val_opt = genvalconf(train_opt, isTrain=False) 29 | 30 | device = torch.device(rank) 31 | torch.cuda.set_device(device) 32 | use_ddp = train_opt.use_ddp 33 | 34 | if use_ddp: 35 | setup(rank, world_size, train_opt.ddp_port) 36 | 37 | train_dataset, val_dataset = create_dataset(train_opt, rank=rank), create_dataset(val_opt, rank=rank) 38 | train_dataset_batches, val_dataset_batches = \ 39 | len(train_dataset) // train_opt.batch_size, len(val_dataset) // val_opt.batch_size 40 | 41 | model = create_model(train_opt) # create a model given train_opt.model and other options 42 | model.setup(train_opt) 43 | model.device = device 44 | model.parallelize() 45 | 46 | if rank == 0: 47 | print('The batch number of training images = %d\n, \ 48 | the batch number of validation images = %d'% (train_dataset_batches, val_dataset_batches)) 49 | model.print_networks(train_opt.verbose) 50 | visualizer = MyVisualizer(train_opt) # create a visualizer that display/save images and plots 51 | 52 | total_iters = train_dataset_batches * (train_opt.epoch_count - 1) # the total number of training iterations 53 | t_data = 0 54 | t_val = 0 55 | optimize_time = 0.1 56 | batch_size = 1 if train_opt.display_per_batch else train_opt.batch_size 57 | 58 | if use_ddp: 59 | dist.barrier() 60 | 61 | times = [] 62 | for epoch in range(train_opt.epoch_count, train_opt.n_epochs + 1): # outer loop for different epochs; we save the model by , + 63 | epoch_start_time = time.time() # timer for entire epoch 64 | iter_data_time = time.time() # timer for train_data loading per iteration 65 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch 66 | 67 | train_dataset.set_epoch(epoch) 68 | for i, train_data in enumerate(train_dataset): # inner loop within one epoch 69 | iter_start_time = time.time() # timer for computation per iteration 70 | if total_iters % train_opt.print_freq == 0: 71 | t_data = iter_start_time - iter_data_time 72 | total_iters += batch_size 73 | epoch_iter += batch_size 74 | 75 | torch.cuda.synchronize() 76 | optimize_start_time = time.time() 77 | 78 | model.set_input(train_data) # unpack train_data from dataset and apply preprocessing 79 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights 80 | 81 | torch.cuda.synchronize() 82 | optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time 83 | 84 | if use_ddp: 85 | dist.barrier() 86 | 87 | if rank == 0 and (total_iters == batch_size or total_iters % train_opt.display_freq == 0): # display images on visdom and save images to a HTML file 88 | model.compute_visuals() 89 | 
visualizer.display_current_results(model.get_current_visuals(), total_iters, epoch, 90 | save_results=True, 91 | add_image=train_opt.add_image) 92 | # (total_iters == batch_size or total_iters % train_opt.evaluation_freq == 0) 93 | 94 | if rank == 0 and (total_iters == batch_size or total_iters % train_opt.print_freq == 0): # print training losses and save logging information to the disk 95 | losses = model.get_current_losses() 96 | visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data) 97 | visualizer.plot_current_losses(total_iters, losses) 98 | 99 | if total_iters == batch_size or total_iters % train_opt.evaluation_freq == 0: 100 | with torch.no_grad(): 101 | torch.cuda.synchronize() 102 | val_start_time = time.time() 103 | losses_avg = {} 104 | model.eval() 105 | for j, val_data in enumerate(val_dataset): 106 | model.set_input(val_data) 107 | model.optimize_parameters(isTrain=False) 108 | if rank == 0 and j < train_opt.vis_batch_nums: 109 | model.compute_visuals() 110 | visualizer.display_current_results(model.get_current_visuals(), total_iters, epoch, 111 | dataset='val', save_results=True, count=j * val_opt.batch_size, 112 | add_image=train_opt.add_image) 113 | 114 | if j < train_opt.eval_batch_nums: 115 | losses = model.get_current_losses() 116 | for key, value in losses.items(): 117 | losses_avg[key] = losses_avg.get(key, 0) + value 118 | 119 | for key, value in losses_avg.items(): 120 | losses_avg[key] = value / min(train_opt.eval_batch_nums, val_dataset_batches) 121 | 122 | torch.cuda.synchronize() 123 | eval_time = time.time() - val_start_time 124 | 125 | if rank == 0: 126 | visualizer.print_current_losses(epoch, epoch_iter, losses_avg, eval_time, t_data, dataset='val') # visualize training results 127 | visualizer.plot_current_losses(total_iters, losses_avg, dataset='val') 128 | model.train() 129 | 130 | if use_ddp: 131 | dist.barrier() 132 | 133 | if rank == 0 and (total_iters == batch_size or total_iters % train_opt.save_latest_freq == 0): # cache our latest model every iterations 134 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) 135 | print(train_opt.name) # it's useful to occasionally show the experiment name on console 136 | save_suffix = 'iter_%d' % total_iters if train_opt.save_by_iter else 'latest' 137 | model.save_networks(save_suffix) 138 | 139 | if use_ddp: 140 | dist.barrier() 141 | 142 | iter_data_time = time.time() 143 | 144 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, train_opt.n_epochs, time.time() - epoch_start_time)) 145 | model.update_learning_rate() # update learning rates at the end of every epoch. 
146 | 147 | if rank == 0 and epoch % train_opt.save_epoch_freq == 0: # cache our model every epochs 148 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) 149 | model.save_networks('latest') 150 | model.save_networks(epoch) 151 | 152 | if use_ddp: 153 | dist.barrier() 154 | 155 | if __name__ == '__main__': 156 | 157 | import warnings 158 | warnings.filterwarnings("ignore") 159 | 160 | train_opt = TrainOptions().parse() # get training options 161 | world_size = train_opt.world_size 162 | 163 | if train_opt.use_ddp: 164 | mp.spawn(main, args=(world_size, train_opt), nprocs=world_size, join=True) 165 | else: 166 | main(0, world_size, train_opt) 167 | -------------------------------------------------------------------------------- /util/BBRegressorParam_r.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/util/BBRegressorParam_r.mat -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | from util import * 3 | -------------------------------------------------------------------------------- /util/detect_lm68.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from scipy.io import loadmat 5 | import tensorflow as tf 6 | from util.preprocess import align_for_lm 7 | from shutil import move 8 | 9 | mean_face = np.loadtxt('util/test_mean_face.txt') 10 | mean_face = mean_face.reshape([68, 2]) 11 | 12 | def save_label(labels, save_path): 13 | np.savetxt(save_path, labels) 14 | 15 | def draw_landmarks(img, landmark, save_name): 16 | landmark = landmark 17 | lm_img = np.zeros([img.shape[0], img.shape[1], 3]) 18 | lm_img[:] = img.astype(np.float32) 19 | landmark = np.round(landmark).astype(np.int32) 20 | 21 | for i in range(len(landmark)): 22 | for j in range(-1, 1): 23 | for k in range(-1, 1): 24 | if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \ 25 | img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \ 26 | landmark[i, 0]+k > 0 and \ 27 | landmark[i, 0]+k < img.shape[1]: 28 | lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k, 29 | :] = np.array([0, 0, 255]) 30 | lm_img = lm_img.astype(np.uint8) 31 | 32 | cv2.imwrite(save_name, lm_img) 33 | 34 | 35 | def load_data(img_name, txt_name): 36 | return cv2.imread(img_name), np.loadtxt(txt_name) 37 | 38 | # create tensorflow graph for landmark detector 39 | def load_lm_graph(graph_filename): 40 | with tf.gfile.GFile(graph_filename, 'rb') as f: 41 | graph_def = tf.GraphDef() 42 | graph_def.ParseFromString(f.read()) 43 | 44 | with tf.Graph().as_default() as graph: 45 | tf.import_graph_def(graph_def, name='net') 46 | img_224 = graph.get_tensor_by_name('net/input_imgs:0') 47 | output_lm = graph.get_tensor_by_name('net/lm:0') 48 | lm_sess = tf.Session(graph=graph) 49 | 50 | return lm_sess,img_224,output_lm 51 | 52 | # landmark detection 53 | def detect_68p(img_path,sess,input_op,output_op): 54 | print('detecting landmarks......') 55 | names = [i for i in sorted(os.listdir( 56 | img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] 57 | vis_path = os.path.join(img_path, 'vis') 58 | remove_path = os.path.join(img_path, 'remove') 59 | save_path = os.path.join(img_path, 
'landmarks') 60 | if not os.path.isdir(vis_path): 61 | os.makedirs(vis_path) 62 | if not os.path.isdir(remove_path): 63 | os.makedirs(remove_path) 64 | if not os.path.isdir(save_path): 65 | os.makedirs(save_path) 66 | 67 | for i in range(0, len(names)): 68 | name = names[i] 69 | print('%05d' % (i), ' ', name) 70 | full_image_name = os.path.join(img_path, name) 71 | txt_name = '.'.join(name.split('.')[:-1]) + '.txt' 72 | full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image 73 | 74 | # if an image does not have detected 5 facial landmarks, remove it from the training list 75 | if not os.path.isfile(full_txt_name): 76 | move(full_image_name, os.path.join(remove_path, name)) 77 | continue 78 | 79 | # load data 80 | img, five_points = load_data(full_image_name, full_txt_name) 81 | input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection 82 | 83 | # if the alignment fails, remove corresponding image from the training list 84 | if scale == 0: 85 | move(full_txt_name, os.path.join( 86 | remove_path, txt_name)) 87 | move(full_image_name, os.path.join(remove_path, name)) 88 | continue 89 | 90 | # detect landmarks 91 | input_img = np.reshape( 92 | input_img, [1, 224, 224, 3]).astype(np.float32) 93 | landmark = sess.run( 94 | output_op, feed_dict={input_op: input_img}) 95 | 96 | # transform back to original image coordinate 97 | landmark = landmark.reshape([68, 2]) + mean_face 98 | landmark[:, 1] = 223 - landmark[:, 1] 99 | landmark = landmark / scale 100 | landmark[:, 0] = landmark[:, 0] + bbox[0] 101 | landmark[:, 1] = landmark[:, 1] + bbox[1] 102 | landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1] 103 | 104 | if i % 100 == 0: 105 | draw_landmarks(img, landmark, os.path.join(vis_path, name)) 106 | save_label(landmark, os.path.join(save_path, txt_name)) 107 | -------------------------------------------------------------------------------- /util/generate_list.py: -------------------------------------------------------------------------------- 1 | """This script is to generate training list files for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import os 5 | 6 | # save path to training data 7 | def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''): 8 | save_path = os.path.join(save_folder, mode) 9 | if not os.path.isdir(save_path): 10 | os.makedirs(save_path) 11 | with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd: 12 | fd.writelines([i + '\n' for i in lms_list]) 13 | 14 | with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd: 15 | fd.writelines([i + '\n' for i in imgs_list]) 16 | 17 | with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd: 18 | fd.writelines([i + '\n' for i in msks_list]) 19 | 20 | # check if the path is valid 21 | def check_list(rlms_list, rimgs_list, rmsks_list): 22 | lms_list, imgs_list, msks_list = [], [], [] 23 | for i in range(len(rlms_list)): 24 | flag = 'false' 25 | lm_path = rlms_list[i] 26 | im_path = rimgs_list[i] 27 | msk_path = rmsks_list[i] 28 | if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path): 29 | flag = 'true' 30 | lms_list.append(rlms_list[i]) 31 | imgs_list.append(rimgs_list[i]) 32 | msks_list.append(rmsks_list[i]) 33 | print(i, rlms_list[i], flag) 34 | return lms_list, imgs_list, msks_list 35 | -------------------------------------------------------------------------------- /util/html.py: 
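check_list and write_list in util/generate_list.py above produce the datalist/<mode>/{landmarks,images,masks}.txt files that the training options (--flist, --flist_val) later point to. A minimal sketch of building a training list, assuming landmarks and skin masks were saved next to the images by the landmark/skin-mask utilities in util/ (all paths are examples only):

from util.generate_list import check_list, write_list

imgs = ['datasets/train/000001.jpg']
lms  = ['datasets/train/landmarks/000001.txt']
msks = ['datasets/train/mask/000001.png']

# keep only samples whose image, landmark and mask files all exist ...
lms, imgs, msks = check_list(lms, imgs, msks)
# ... then write landmarks.txt / images.txt / masks.txt under datalist/train/
write_list(lms, imgs, msks, mode='train', save_folder='datalist')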
-------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | 9 | It consists of functions such as (add a text header to the HTML file), 10 | (add a row of images to the HTML file), and (save the HTML to the disk). 11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 12 | """ 13 | 14 | def __init__(self, web_dir, title, refresh=0): 15 | """Initialize the HTML classes 16 | 17 | Parameters: 18 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: 32 | with self.doc.head: 33 | meta(http_equiv="refresh", content=str(refresh)) 34 | 35 | def get_image_dir(self): 36 | """Return the directory that stores images""" 37 | return self.img_dir 38 | 39 | def add_header(self, text): 40 | """Insert a header to the HTML file 41 | 42 | Parameters: 43 | text (str) -- the header text 44 | """ 45 | with self.doc: 46 | h3(text) 47 | 48 | def add_images(self, ims, txts, links, width=400): 49 | """add images to the HTML file 50 | 51 | Parameters: 52 | ims (str list) -- a list of image paths 53 | txts (str list) -- a list of image names shown on the website 54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 55 | """ 56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 57 | self.doc.add(self.t) 58 | with self.t: 59 | with tr(): 60 | for im, txt, link in zip(ims, txts, links): 61 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 62 | with p(): 63 | with a(href=os.path.join('images', link)): 64 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 65 | br() 66 | p(txt) 67 | 68 | def save(self): 69 | """save the current content to the HMTL file""" 70 | html_file = '%s/index.html' % self.web_dir 71 | f = open(html_file, 'wt') 72 | f.write(self.doc.render()) 73 | f.close() 74 | 75 | 76 | if __name__ == '__main__': # we show an example usage here. 
77 | html = HTML('web/', 'test_html') 78 | html.add_header('hello world') 79 | 80 | ims, txts, links = [], [], [] 81 | for n in range(4): 82 | ims.append('image_%d.png' % n) 83 | txts.append('text_%d' % n) 84 | links.append('image_%d.png' % n) 85 | html.add_images(ims, txts, links) 86 | html.save() 87 | -------------------------------------------------------------------------------- /util/load_mats.py: -------------------------------------------------------------------------------- 1 | """This script is to load 3D face model for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import numpy as np 5 | from PIL import Image 6 | from scipy.io import loadmat, savemat 7 | from array import array 8 | import os.path as osp 9 | 10 | # load expression basis 11 | def LoadExpBasis(bfm_folder='BFM'): 12 | n_vertex = 53215 13 | Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb') 14 | exp_dim = array('i') 15 | exp_dim.fromfile(Expbin, 1) 16 | expMU = array('f') 17 | expPC = array('f') 18 | expMU.fromfile(Expbin, 3*n_vertex) 19 | expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex) 20 | Expbin.close() 21 | 22 | expPC = np.array(expPC) 23 | expPC = np.reshape(expPC, [exp_dim[0], -1]) 24 | expPC = np.transpose(expPC) 25 | 26 | expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt')) 27 | 28 | return expPC, expEV 29 | 30 | 31 | # transfer original BFM09 to our face model 32 | def transferBFM09(bfm_folder='BFM'): 33 | print('Transfer BFM09 to BFM_model_front......') 34 | original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat')) 35 | shapePC = original_BFM['shapePC'] # shape basis 36 | shapeEV = original_BFM['shapeEV'] # corresponding eigen value 37 | shapeMU = original_BFM['shapeMU'] # mean face 38 | texPC = original_BFM['texPC'] # texture basis 39 | texEV = original_BFM['texEV'] # eigen value 40 | texMU = original_BFM['texMU'] # mean texture 41 | 42 | expPC, expEV = LoadExpBasis() 43 | 44 | # transfer BFM09 to our face model 45 | 46 | idBase = shapePC*np.reshape(shapeEV, [-1, 199]) 47 | idBase = idBase/1e5 # unify the scale to decimeter 48 | idBase = idBase[:, :80] # use only first 80 basis 49 | 50 | exBase = expPC*np.reshape(expEV, [-1, 79]) 51 | exBase = exBase/1e5 # unify the scale to decimeter 52 | exBase = exBase[:, :64] # use only first 64 basis 53 | 54 | texBase = texPC*np.reshape(texEV, [-1, 199]) 55 | texBase = texBase[:, :80] # use only first 80 basis 56 | 57 | # our face model is cropped along face landmarks and contains only 35709 vertex. 58 | # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex. 59 | # thus we select corresponding vertex to get our face model. 
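# Concretely: BFM_front_idx.mat ('idx') selects the 35709 frontal vertices from the 53215-vertex
# expression mesh, BFM_exp_idx.mat ('trimIndex') maps each of those 53215 vertices to its index in
# the 53490-vertex BFM09 mesh, and index_shape[index_exp] below composes the two selections.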
60 | 61 | index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat')) 62 | index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215) 63 | 64 | index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat')) 65 | index_shape = index_shape['trimIndex'].astype( 66 | np.int32) - 1 # starts from 0 (to 53490) 67 | index_shape = index_shape[index_exp] 68 | 69 | idBase = np.reshape(idBase, [-1, 3, 80]) 70 | idBase = idBase[index_shape, :, :] 71 | idBase = np.reshape(idBase, [-1, 80]) 72 | 73 | texBase = np.reshape(texBase, [-1, 3, 80]) 74 | texBase = texBase[index_shape, :, :] 75 | texBase = np.reshape(texBase, [-1, 80]) 76 | 77 | exBase = np.reshape(exBase, [-1, 3, 64]) 78 | exBase = exBase[index_exp, :, :] 79 | exBase = np.reshape(exBase, [-1, 64]) 80 | 81 | meanshape = np.reshape(shapeMU, [-1, 3])/1e5 82 | meanshape = meanshape[index_shape, :] 83 | meanshape = np.reshape(meanshape, [1, -1]) 84 | 85 | meantex = np.reshape(texMU, [-1, 3]) 86 | meantex = meantex[index_shape, :] 87 | meantex = np.reshape(meantex, [1, -1]) 88 | 89 | # other info contains triangles, region used for computing photometric loss, 90 | # region used for skin texture regularization, and 68 landmarks index etc. 91 | other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat')) 92 | frontmask2_idx = other_info['frontmask2_idx'] 93 | skinmask = other_info['skinmask'] 94 | keypoints = other_info['keypoints'] 95 | point_buf = other_info['point_buf'] 96 | tri = other_info['tri'] 97 | tri_mask2 = other_info['tri_mask2'] 98 | 99 | # save our face model 100 | savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase, 101 | 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask}) 102 | 103 | 104 | # load landmarks for standard face, which is used for image preprocessing 105 | def load_lm3d(bfm_folder): 106 | 107 | Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat')) 108 | Lm3D = Lm3D['lm'] 109 | 110 | # calculate 5 facial landmarks using 68 landmarks 111 | lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 112 | Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean( 113 | Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0) 114 | Lm3D = Lm3D[[1, 2, 0, 3, 4], :] 115 | 116 | return Lm3D 117 | 118 | -------------------------------------------------------------------------------- /util/nvdiffrast.py: -------------------------------------------------------------------------------- 1 | """This script is the differentiable renderer for Deep3DFaceRecon_pytorch 2 | Attention, antialiasing step is missing in current version. 
3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import kornia 8 | from kornia.geometry.camera import pixel2cam 9 | import numpy as np 10 | from typing import List 11 | import nvdiffrast.torch as dr 12 | from scipy.io import loadmat 13 | from torch import nn 14 | 15 | def ndc_projection(x=0.1, n=1.0, f=50.0): 16 | return np.array([[n/x, 0, 0, 0], 17 | [ 0, n/-x, 0, 0], 18 | [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], 19 | [ 0, 0, -1, 0]]).astype(np.float32) 20 | 21 | class MeshRenderer(nn.Module): 22 | def __init__(self, 23 | rasterize_fov, 24 | znear=0.1, 25 | zfar=10, 26 | rasterize_size=224, 27 | use_opengl=True): 28 | super(MeshRenderer, self).__init__() 29 | 30 | x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear 31 | self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul( 32 | torch.diag(torch.tensor([1., -1, -1, 1]))) 33 | self.rasterize_size = rasterize_size 34 | self.use_opengl = use_opengl 35 | self.ctx = None 36 | 37 | def forward(self, vertex, tri, feat=None): 38 | """ 39 | Return: 40 | mask -- torch.tensor, size (B, 1, H, W) 41 | depth -- torch.tensor, size (B, 1, H, W) 42 | features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None 43 | 44 | Parameters: 45 | vertex -- torch.tensor, size (B, N, 3) 46 | tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles 47 | feat(optional) -- torch.tensor, size (B, C), features 48 | """ 49 | device = vertex.device 50 | rsize = int(self.rasterize_size) 51 | ndc_proj = self.ndc_proj.to(device) 52 | # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v 53 | if vertex.shape[-1] == 3: 54 | vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1) 55 | vertex[..., 1] = -vertex[..., 1] 56 | 57 | 58 | vertex_ndc = vertex @ ndc_proj.t() 59 | if self.ctx is None: 60 | if self.use_opengl: 61 | self.ctx = dr.RasterizeGLContext(device=device) 62 | ctx_str = "opengl" 63 | else: 64 | self.ctx = dr.RasterizeCudaContext(device=device) 65 | ctx_str = "cuda" 66 | print("create %s ctx on device cuda:%d"%(ctx_str, device.index)) 67 | 68 | ranges = None 69 | if isinstance(tri, List) or len(tri.shape) == 3: 70 | vum = vertex_ndc.shape[1] 71 | fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device) 72 | fstartidx = torch.cumsum(fnum, dim=0) - fnum 73 | ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu() 74 | for i in range(tri.shape[0]): 75 | tri[i] = tri[i] + i*vum 76 | vertex_ndc = torch.cat(vertex_ndc, dim=0) 77 | tri = torch.cat(tri, dim=0) 78 | 79 | # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3] 80 | tri = tri.type(torch.int32).contiguous() 81 | rast_out, _ = dr.rasterize(self.ctx, vertex_ndc.contiguous(), tri, resolution=[rsize, rsize], ranges=ranges) 82 | 83 | depth, _ = dr.interpolate(vertex.reshape([-1,4])[...,2].unsqueeze(1).contiguous(), rast_out, tri) 84 | depth = depth.permute(0, 3, 1, 2) 85 | mask = (rast_out[..., 3] > 0).float().unsqueeze(1) 86 | depth = mask * depth 87 | 88 | 89 | image = None 90 | if feat is not None: 91 | image, _ = dr.interpolate(feat, rast_out, tri) 92 | image = image.permute(0, 3, 1, 2) 93 | image = mask * image 94 | 95 | return mask, depth, image 96 | 97 | -------------------------------------------------------------------------------- /util/preprocess.py: -------------------------------------------------------------------------------- 1 | """This script contains the image preprocessing code for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | 
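The MeshRenderer defined in util/nvdiffrast.py above rasterizes camera-space meshes into mask, depth and feature maps. A minimal sketch of calling it follows; the field of view, clipping planes and random mesh are illustrative only, and a CUDA device with nvdiffrast installed is assumed:

import torch
from util.nvdiffrast import MeshRenderer

renderer = MeshRenderer(rasterize_fov=12.6, znear=5., zfar=15., rasterize_size=224, use_opengl=False)

verts  = torch.rand(1, 35709, 3, device='cuda')              # (B, N, 3) camera-space vertices
tris   = torch.randint(0, 35709, (70000, 3), device='cuda')  # (M, 3) triangle indices (dummy connectivity)
albedo = torch.rand(1, 35709, 3, device='cuda')              # (B, N, C) per-vertex features

mask, depth, image = renderer(verts, tris, feat=albedo)      # (B,1,H,W), (B,1,H,W), (B,C,H,W)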
import numpy as np 5 | from scipy.io import loadmat 6 | 7 | try: 8 | from PIL.Image import Resampling 9 | RESAMPLING_METHOD = Resampling.BICUBIC 10 | except ImportError: 11 | from PIL.Image import BICUBIC 12 | RESAMPLING_METHOD = BICUBIC 13 | 14 | import cv2 15 | import os 16 | from skimage import transform as trans 17 | import torch 18 | import warnings 19 | warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) 20 | warnings.filterwarnings("ignore", category=FutureWarning) 21 | 22 | 23 | # calculating least square problem for image alignment 24 | def POS(xp, x): 25 | npts = xp.shape[1] 26 | 27 | A = np.zeros([2*npts, 8]) 28 | 29 | A[0:2*npts-1:2, 0:3] = x.transpose() 30 | A[0:2*npts-1:2, 3] = 1 31 | 32 | A[1:2*npts:2, 4:7] = x.transpose() 33 | A[1:2*npts:2, 7] = 1 34 | 35 | b = np.reshape(xp.transpose(), [2*npts, 1]) 36 | 37 | k, _, _, _ = np.linalg.lstsq(A, b) 38 | 39 | R1 = k[0:3] 40 | R2 = k[4:7] 41 | sTx = k[3] 42 | sTy = k[7] 43 | s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2 44 | t = np.stack([sTx, sTy], axis=0) 45 | 46 | return t, s 47 | 48 | # bounding box for 68 landmark detection 49 | def BBRegression(points, params): 50 | 51 | w1 = params['W1'] 52 | b1 = params['B1'] 53 | w2 = params['W2'] 54 | b2 = params['B2'] 55 | data = points.copy() 56 | data = data.reshape([5, 2]) 57 | data_mean = np.mean(data, axis=0) 58 | x_mean = data_mean[0] 59 | y_mean = data_mean[1] 60 | data[:, 0] = data[:, 0] - x_mean 61 | data[:, 1] = data[:, 1] - y_mean 62 | 63 | rms = np.sqrt(np.sum(data ** 2)/5) 64 | data = data / rms 65 | data = data.reshape([1, 10]) 66 | data = np.transpose(data) 67 | inputs = np.matmul(w1, data) + b1 68 | inputs = 2 / (1 + np.exp(-2 * inputs)) - 1 69 | inputs = np.matmul(w2, inputs) + b2 70 | inputs = np.transpose(inputs) 71 | x = inputs[:, 0] * rms + x_mean 72 | y = inputs[:, 1] * rms + y_mean 73 | w = 224/inputs[:, 2] * rms 74 | rects = [x, y, w, w] 75 | return np.array(rects).reshape([4]) 76 | 77 | # utils for landmark detection 78 | def img_padding(img, box): 79 | success = True 80 | bbox = box.copy() 81 | res = np.zeros([2*img.shape[0], 2*img.shape[1], 3]) 82 | res[img.shape[0] // 2: img.shape[0] + img.shape[0] // 83 | 2, img.shape[1] // 2: img.shape[1] + img.shape[1]//2] = img 84 | 85 | bbox[0] = bbox[0] + img.shape[1] // 2 86 | bbox[1] = bbox[1] + img.shape[0] // 2 87 | if bbox[0] < 0 or bbox[1] < 0: 88 | success = False 89 | return res, bbox, success 90 | 91 | # utils for landmark detection 92 | def crop(img, bbox): 93 | padded_img, padded_bbox, flag = img_padding(img, bbox) 94 | if flag: 95 | crop_img = padded_img[padded_bbox[1]: padded_bbox[1] + 96 | padded_bbox[3], padded_bbox[0]: padded_bbox[0] + padded_bbox[2]] 97 | crop_img = cv2.resize(crop_img.astype(np.uint8), 98 | (224, 224), interpolation=cv2.INTER_CUBIC) 99 | scale = 224 / padded_bbox[3] 100 | return crop_img, scale 101 | else: 102 | return padded_img, 0 103 | 104 | # utils for landmark detection 105 | def scale_trans(img, lm, t, s): 106 | imgw = img.shape[1] 107 | imgh = img.shape[0] 108 | M_s = np.array([[1, 0, -t[0] + imgw//2 + 0.5], [0, 1, -imgh//2 + t[1]]], 109 | dtype=np.float32) 110 | img = cv2.warpAffine(img, M_s, (imgw, imgh)) 111 | w = int(imgw / s * 100) 112 | h = int(imgh / s * 100) 113 | img = cv2.resize(img, (w, h)) 114 | lm = np.stack([lm[:, 0] - t[0] + imgw // 2, lm[:, 1] - 115 | t[1] + imgh // 2], axis=1) / s * 100 116 | 117 | left = w//2 - 112 118 | up = h//2 - 112 119 | bbox = [left, up, 224, 224] 120 | cropped_img, scale2 = crop(img, bbox) 121 | assert(scale2!=0) 
122 | t1 = np.array([bbox[0], bbox[1]]) 123 | 124 | # back to raw img s * crop + s * t1 + t2 125 | t1 = np.array([w//2 - 112, h//2 - 112]) 126 | scale = s / 100 127 | t2 = np.array([t[0] - imgw/2, t[1] - imgh / 2]) 128 | inv = (scale/scale2, scale * t1 + t2.reshape([2])) 129 | return cropped_img, inv 130 | 131 | # utils for landmark detection 132 | def align_for_lm(img, five_points): 133 | five_points = np.array(five_points).reshape([1, 10]) 134 | params = loadmat('util/BBRegressorParam_r.mat') 135 | bbox = BBRegression(five_points, params) 136 | assert(bbox[2] != 0) 137 | bbox = np.round(bbox).astype(np.int32) 138 | crop_img, scale = crop(img, bbox) 139 | return crop_img, scale, bbox 140 | 141 | 142 | # resize and crop images for face reconstruction 143 | def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None): 144 | w0, h0 = img.size 145 | w = (w0*s).astype(np.int32) 146 | h = (h0*s).astype(np.int32) 147 | left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32) 148 | right = left + target_size 149 | up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32) 150 | below = up + target_size 151 | 152 | img = img.resize((w, h), resample=RESAMPLING_METHOD) 153 | img = img.crop((left, up, right, below)) 154 | 155 | if mask is not None: 156 | mask = mask.resize((w, h), resample=RESAMPLING_METHOD) 157 | mask = mask.crop((left, up, right, below)) 158 | 159 | lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] - 160 | t[1] + h0/2], axis=1)*s 161 | lm = lm - np.reshape( 162 | np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2]) 163 | 164 | return img, lm, mask 165 | 166 | # utils for face reconstruction 167 | def extract_5p(lm): 168 | lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 169 | lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean( 170 | lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0) 171 | lm5p = lm5p[[1, 2, 0, 3, 4], :] 172 | return lm5p 173 | 174 | # utils for face reconstruction 175 | def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.): 176 | """ 177 | Return: 178 | transparams --numpy.array (raw_W, raw_H, scale, tx, ty) 179 | img_new --PIL.Image (target_size, target_size, 3) 180 | lm_new --numpy.array (68, 2), y direction is opposite to v direction 181 | mask_new --PIL.Image (target_size, target_size) 182 | 183 | Parameters: 184 | img --PIL.Image (raw_H, raw_W, 3) 185 | lm --numpy.array (68, 2), y direction is opposite to v direction 186 | lm3D --numpy.array (5, 3) 187 | mask --PIL.Image (raw_H, raw_W, 3) 188 | """ 189 | 190 | w0, h0 = img.size 191 | if lm.shape[0] != 5: 192 | lm5p = extract_5p(lm) 193 | else: 194 | lm5p = lm 195 | 196 | # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face 197 | t, s = POS(lm5p.transpose(), lm3D.transpose()) 198 | s = rescale_factor/s 199 | 200 | # processing the image 201 | img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask) 202 | trans_params = np.array([w0, h0, s, t[0], t[1]]) 203 | 204 | return trans_params, img_new, lm_new, mask_new 205 | 206 | # utils for face recognition model 207 | def estimate_norm(lm_68p, H): 208 | # from https://github.com/deepinsight/insightface/blob/c61d3cd208a603dfa4a338bd743b320ce3e94730/recognition/common/face_align.py#L68 209 | """ 210 | Return: 211 | trans_m --numpy.array (2, 3) 212 | Parameters: 213 | lm --numpy.array (68, 2), y direction is opposite to v direction 214 | H --int/float , image height 215 
| """ 216 | lm = extract_5p(lm_68p) 217 | lm[:, -1] = H - 1 - lm[:, -1] 218 | tform = trans.SimilarityTransform() 219 | src = np.array( 220 | [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], 221 | [41.5493, 92.3655], [70.7299, 92.2041]], 222 | dtype=np.float32) 223 | tform.estimate(lm, src) 224 | M = tform.params 225 | if np.linalg.det(M) == 0: 226 | M = np.eye(3) 227 | 228 | return M[0:2, :] 229 | 230 | def estimate_norm_torch(lm_68p, H): 231 | lm_68p_ = lm_68p.detach().cpu().numpy() 232 | M = [] 233 | for i in range(lm_68p_.shape[0]): 234 | M.append(estimate_norm(lm_68p_[i], H)) 235 | M = torch.tensor(np.array(M), dtype=torch.float32).to(lm_68p.device) 236 | return M 237 | -------------------------------------------------------------------------------- /util/skin_mask.py: -------------------------------------------------------------------------------- 1 | """This script is to generate skin attention mask for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import math 5 | import numpy as np 6 | import os 7 | import cv2 8 | 9 | class GMM: 10 | def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv): 11 | self.dim = dim # feature dimension 12 | self.num = num # number of Gaussian components 13 | self.w = w # weights of Gaussian components (a list of scalars) 14 | self.mu= mu # mean of Gaussian components (a list of 1xdim vectors) 15 | self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices) 16 | self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars) 17 | self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices) 18 | 19 | self.factor = [0]*num 20 | for i in range(self.num): 21 | self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5 22 | 23 | def likelihood(self, data): 24 | assert(data.shape[1] == self.dim) 25 | N = data.shape[0] 26 | lh = np.zeros(N) 27 | 28 | for i in range(self.num): 29 | data_ = data - self.mu[i] 30 | 31 | tmp = np.matmul(data_,self.cov_inv[i]) * data_ 32 | tmp = np.sum(tmp,axis=1) 33 | power = -0.5 * tmp 34 | 35 | p = np.array([math.exp(power[j]) for j in range(N)]) 36 | p = p/self.factor[i] 37 | lh += p*self.w[i] 38 | 39 | return lh 40 | 41 | 42 | def _rgb2ycbcr(rgb): 43 | m = np.array([[65.481, 128.553, 24.966], 44 | [-37.797, -74.203, 112], 45 | [112, -93.786, -18.214]]) 46 | shape = rgb.shape 47 | rgb = rgb.reshape((shape[0] * shape[1], 3)) 48 | ycbcr = np.dot(rgb, m.transpose() / 255.) 49 | ycbcr[:, 0] += 16. 50 | ycbcr[:, 1:] += 128. 51 | return ycbcr.reshape(shape) 52 | 53 | 54 | def _bgr2ycbcr(bgr): 55 | rgb = bgr[..., ::-1] 56 | return _rgb2ycbcr(rgb) 57 | 58 | 59 | gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415] 60 | gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]), 61 | np.array([150.19858, 105.18467, 155.51428]), 62 | np.array([183.92976, 107.62468, 152.71820]), 63 | np.array([114.90524, 113.59782, 151.38217])] 64 | gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.] 
65 | gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]), 66 | np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]), 67 | np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]), 68 | np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])] 69 | 70 | gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv) 71 | 72 | gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393] 73 | gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]), 74 | np.array([110.91392, 125.52969, 130.19237]), 75 | np.array([129.75864, 129.96107, 126.96808]), 76 | np.array([112.29587, 128.85121, 129.05431])] 77 | gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63] 78 | gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]), 79 | np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]), 80 | np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]), 81 | np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])] 82 | 83 | gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv) 84 | 85 | prior_skin = 0.8 86 | prior_nonskin = 1 - prior_skin 87 | 88 | 89 | # calculate skin attention mask 90 | def skinmask(imbgr): 91 | im = _bgr2ycbcr(imbgr) 92 | 93 | data = im.reshape((-1,3)) 94 | 95 | lh_skin = gmm_skin.likelihood(data) 96 | lh_nonskin = gmm_nonskin.likelihood(data) 97 | 98 | tmp1 = prior_skin * lh_skin 99 | tmp2 = prior_nonskin * lh_nonskin 100 | post_skin = tmp1 / (tmp1+tmp2) # posterior probability 101 | 102 | post_skin = post_skin.reshape((im.shape[0],im.shape[1])) 103 | 104 | post_skin = np.round(post_skin*255) 105 | post_skin = post_skin.astype(np.uint8) 106 | post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3 107 | 108 | return post_skin 109 | 110 | 111 | def get_skin_mask(img_path): 112 | print('generating skin masks......') 113 | names = [i for i in sorted(os.listdir( 114 | img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] 115 | save_path = os.path.join(img_path, 'mask') 116 | if not os.path.isdir(save_path): 117 | os.makedirs(save_path) 118 | 119 | for i in range(0, len(names)): 120 | name = names[i] 121 | print('%05d' % (i), ' ', name) 122 | full_image_name = os.path.join(img_path, name) 123 | img = cv2.imread(full_image_name).astype(np.float32) 124 | skin_img = skinmask(img) 125 | cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8)) 126 | -------------------------------------------------------------------------------- /util/test_mean_face.txt: -------------------------------------------------------------------------------- 1 | -5.228591537475585938e+01 2 | 2.078247070312500000e-01 3 | -5.064269638061523438e+01 4 | -1.315765380859375000e+01 5 | -4.952939224243164062e+01 6 | -2.592591094970703125e+01 7 | -4.793047332763671875e+01 8 
| -3.832135772705078125e+01 9 | -4.512159729003906250e+01 10 | -5.059623336791992188e+01 11 | -3.917720794677734375e+01 12 | -6.043736648559570312e+01 13 | -2.929953765869140625e+01 14 | -6.861183166503906250e+01 15 | -1.719801330566406250e+01 16 | -7.572736358642578125e+01 17 | -1.961936950683593750e+00 18 | -7.862001037597656250e+01 19 | 1.467941284179687500e+01 20 | -7.607844543457031250e+01 21 | 2.744073486328125000e+01 22 | -6.915261840820312500e+01 23 | 3.855677795410156250e+01 24 | -5.950350570678710938e+01 25 | 4.478240966796875000e+01 26 | -4.867547225952148438e+01 27 | 4.714337158203125000e+01 28 | -3.800830078125000000e+01 29 | 4.940315246582031250e+01 30 | -2.496297454833984375e+01 31 | 5.117234802246093750e+01 32 | -1.241538238525390625e+01 33 | 5.190507507324218750e+01 34 | 8.244247436523437500e-01 35 | -4.150688934326171875e+01 36 | 2.386329650878906250e+01 37 | -3.570307159423828125e+01 38 | 3.017010498046875000e+01 39 | -2.790358734130859375e+01 40 | 3.212951660156250000e+01 41 | -1.941773223876953125e+01 42 | 3.156523132324218750e+01 43 | -1.138106536865234375e+01 44 | 2.841992187500000000e+01 45 | 5.993263244628906250e+00 46 | 2.895182800292968750e+01 47 | 1.343590545654296875e+01 48 | 3.189880371093750000e+01 49 | 2.203153991699218750e+01 50 | 3.302221679687500000e+01 51 | 2.992478942871093750e+01 52 | 3.099150085449218750e+01 53 | 3.628388977050781250e+01 54 | 2.765748596191406250e+01 55 | -1.933914184570312500e+00 56 | 1.405374145507812500e+01 57 | -2.153038024902343750e+00 58 | 5.772636413574218750e+00 59 | -2.270050048828125000e+00 60 | -2.121643066406250000e+00 61 | -2.218330383300781250e+00 62 | -1.068978118896484375e+01 63 | -1.187252044677734375e+01 64 | -1.997912597656250000e+01 65 | -6.879402160644531250e+00 66 | -2.143579864501953125e+01 67 | -1.227821350097656250e+00 68 | -2.193494415283203125e+01 69 | 4.623237609863281250e+00 70 | -2.152721405029296875e+01 71 | 9.721397399902343750e+00 72 | -1.953671264648437500e+01 73 | -3.648714447021484375e+01 74 | 9.811126708984375000e+00 75 | -3.130242919921875000e+01 76 | 1.422447967529296875e+01 77 | -2.212834930419921875e+01 78 | 1.493019866943359375e+01 79 | -1.500880432128906250e+01 80 | 1.073588562011718750e+01 81 | -2.095037078857421875e+01 82 | 9.054298400878906250e+00 83 | -3.050099182128906250e+01 84 | 8.704177856445312500e+00 85 | 1.173237609863281250e+01 86 | 1.054329681396484375e+01 87 | 1.856353759765625000e+01 88 | 1.535009765625000000e+01 89 | 2.893331909179687500e+01 90 | 1.451992797851562500e+01 91 | 3.452944946289062500e+01 92 | 1.065280151367187500e+01 93 | 2.875990295410156250e+01 94 | 8.654792785644531250e+00 95 | 1.942100524902343750e+01 96 | 9.422447204589843750e+00 97 | -2.204488372802734375e+01 98 | -3.983994293212890625e+01 99 | -1.324458312988281250e+01 100 | -3.467377471923828125e+01 101 | -6.749649047851562500e+00 102 | -3.092894744873046875e+01 103 | -9.183349609375000000e-01 104 | -3.196458435058593750e+01 105 | 4.220649719238281250e+00 106 | -3.090406036376953125e+01 107 | 1.089889526367187500e+01 108 | -3.497008514404296875e+01 109 | 1.874589538574218750e+01 110 | -4.065438079833984375e+01 111 | 1.124106597900390625e+01 112 | -4.438417816162109375e+01 113 | 5.181709289550781250e+00 114 | -4.649170684814453125e+01 115 | -1.158607482910156250e+00 116 | -4.680406951904296875e+01 117 | -7.918922424316406250e+00 118 | -4.671575164794921875e+01 119 | -1.452505493164062500e+01 120 | -4.416526031494140625e+01 121 | -2.005007171630859375e+01 122 | -3.997841644287109375e+01 123 | 
-1.054919433593750000e+01 124 | -3.849683380126953125e+01 125 | -1.051826477050781250e+00 126 | -3.794863128662109375e+01 127 | 6.412681579589843750e+00 128 | -3.804645538330078125e+01 129 | 1.627674865722656250e+01 130 | -4.039697265625000000e+01 131 | 6.373878479003906250e+00 132 | -4.087213897705078125e+01 133 | -8.551712036132812500e-01 134 | -4.157129669189453125e+01 135 | -1.014953613281250000e+01 136 | -4.128469085693359375e+01 137 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """This script contains basic utilities for Deep3DFaceRecon_pytorch 2 | """ 3 | from __future__ import print_function 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | try: 8 | from PIL.Image import Resampling 9 | RESAMPLING_METHOD = Resampling.BICUBIC 10 | except ImportError: 11 | from PIL.Image import BICUBIC 12 | RESAMPLING_METHOD = BICUBIC 13 | import os 14 | import importlib 15 | import argparse 16 | from argparse import Namespace 17 | import torchvision 18 | 19 | 20 | def str2bool(v): 21 | if isinstance(v, bool): 22 | return v 23 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 24 | return True 25 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 26 | return False 27 | else: 28 | raise argparse.ArgumentTypeError('Boolean value expected.') 29 | 30 | 31 | def copyconf(default_opt, **kwargs): 32 | conf = Namespace(**vars(default_opt)) 33 | for key in kwargs: 34 | setattr(conf, key, kwargs[key]) 35 | return conf 36 | 37 | def genvalconf(train_opt, **kwargs): 38 | conf = Namespace(**vars(train_opt)) 39 | attr_dict = train_opt.__dict__ 40 | for key, value in attr_dict.items(): 41 | if 'val' in key and key.split('_')[0] in attr_dict: 42 | setattr(conf, key.split('_')[0], value) 43 | 44 | for key in kwargs: 45 | setattr(conf, key, kwargs[key]) 46 | 47 | return conf 48 | 49 | def find_class_in_module(target_cls_name, module): 50 | target_cls_name = target_cls_name.replace('_', '').lower() 51 | clslib = importlib.import_module(module) 52 | cls = None 53 | for name, clsobj in clslib.__dict__.items(): 54 | if name.lower() == target_cls_name: 55 | cls = clsobj 56 | 57 | assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name) 58 | 59 | return cls 60 | 61 | 62 | def tensor2im(input_image, imtype=np.uint8): 63 | """"Converts a Tensor array into a numpy image array. 
64 | 65 | Parameters: 66 | input_image (tensor) -- the input image tensor array, range(0, 1) 67 | imtype (type) -- the desired type of the converted numpy array 68 | """ 69 | if not isinstance(input_image, np.ndarray): 70 | if isinstance(input_image, torch.Tensor): # get the data from a variable 71 | image_tensor = input_image.data 72 | else: 73 | return input_image 74 | image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array 75 | if image_numpy.shape[0] == 1: # grayscale to RGB 76 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 77 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: transpose and scaling 78 | else: # if it is a numpy array, do nothing 79 | image_numpy = input_image 80 | return image_numpy.astype(imtype) 81 | 82 | 83 | def diagnose_network(net, name='network'): 84 | """Calculate and print the mean of average absolute(gradients) 85 | 86 | Parameters: 87 | net (torch network) -- Torch network 88 | name (str) -- the name of the network 89 | """ 90 | mean = 0.0 91 | count = 0 92 | for param in net.parameters(): 93 | if param.grad is not None: 94 | mean += torch.mean(torch.abs(param.grad.data)) 95 | count += 1 96 | if count > 0: 97 | mean = mean / count 98 | print(name) 99 | print(mean) 100 | 101 | 102 | def save_image(image_numpy, image_path, aspect_ratio=1.0): 103 | """Save a numpy image to the disk 104 | 105 | Parameters: 106 | image_numpy (numpy array) -- input numpy array 107 | image_path (str) -- the path of the image 108 | """ 109 | 110 | image_pil = Image.fromarray(image_numpy) 111 | h, w, _ = image_numpy.shape 112 | 113 | if aspect_ratio is None: 114 | pass 115 | elif aspect_ratio > 1.0: 116 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), RESAMPLING_METHOD) 117 | elif aspect_ratio < 1.0: 118 | image_pil = image_pil.resize((int(h / aspect_ratio), w), RESAMPLING_METHOD) 119 | image_pil.save(image_path) 120 | 121 | 122 | def print_numpy(x, val=True, shp=False): 123 | """Print the mean, min, max, median, std, and size of a numpy array 124 | 125 | Parameters: 126 | val (bool) -- if print the values of the numpy array 127 | shp (bool) -- if print the shape of the numpy array 128 | """ 129 | x = x.astype(np.float64) 130 | if shp: 131 | print('shape,', x.shape) 132 | if val: 133 | x = x.flatten() 134 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 135 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 136 | 137 | 138 | def mkdirs(paths): 139 | """create empty directories if they don't exist 140 | 141 | Parameters: 142 | paths (str list) -- a list of directory paths 143 | """ 144 | if isinstance(paths, list) and not isinstance(paths, str): 145 | for path in paths: 146 | mkdir(path) 147 | else: 148 | mkdir(paths) 149 | 150 | 151 | def mkdir(path): 152 | """create a single empty directory if it didn't exist 153 | 154 | Parameters: 155 | path (str) -- a single directory path 156 | """ 157 | if not os.path.exists(path): 158 | os.makedirs(path) 159 | 160 | 161 | def correct_resize_label(t, size): 162 | device = t.device 163 | t = t.detach().cpu() 164 | resized = [] 165 | for i in range(t.size(0)): 166 | one_t = t[i, :1] 167 | one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0)) 168 | one_np = one_np[:, :, 0] 169 | one_image = Image.fromarray(one_np).resize(size, Image.NEAREST) 170 | resized_t = torch.from_numpy(np.array(one_image)).long() 171 | resized.append(resized_t) 172 | return torch.stack(resized, dim=0).to(device) 173 | 174 | 175 | def
correct_resize(t, size, mode=RESAMPLING_METHOD): 176 | device = t.device 177 | t = t.detach().cpu() 178 | resized = [] 179 | for i in range(t.size(0)): 180 | one_t = t[i:i + 1] 181 | one_image = Image.fromarray(tensor2im(one_t)).resize(size, mode) 182 | resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0 183 | resized.append(resized_t) 184 | return torch.stack(resized, dim=0).to(device) 185 | 186 | def draw_landmarks(img, landmark, color='r', step=2): 187 | """ 188 | Return: 189 | img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255) 190 | 191 | 192 | Parameters: 193 | img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255) 194 | landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction 195 | color -- str, 'r' or 'b' (red or blue) 196 | """ 197 | if color =='r': 198 | c = np.array([255., 0, 0]) 199 | else: 200 | c = np.array([0, 0, 255.]) 201 | 202 | _, H, W, _ = img.shape 203 | img, landmark = img.copy(), landmark.copy() 204 | landmark[..., 1] = H - 1 - landmark[..., 1] 205 | landmark = np.round(landmark).astype(np.int32) 206 | for i in range(landmark.shape[1]): 207 | x, y = landmark[:, i, 0], landmark[:, i, 1] 208 | for j in range(-step, step): 209 | for k in range(-step, step): 210 | u = np.clip(x + j, 0, W - 1) 211 | v = np.clip(y + k, 0, H - 1) 212 | for m in range(landmark.shape[0]): 213 | img[m, v[m], u[m]] = c 214 | return img 215 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | """This script defines the visualizer for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | import ntpath 8 | import time 9 | from . import util, html 10 | from subprocess import Popen, PIPE 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): 14 | """Save images to the disk. 15 | 16 | Parameters: 17 | webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details) 18 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs 19 | image_path (str) -- the string is used to create image paths 20 | aspect_ratio (float) -- the aspect ratio of saved images 21 | width (int) -- the images will be resized to width x width 22 | 23 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 24 | """ 25 | image_dir = webpage.get_image_dir() 26 | short_path = ntpath.basename(image_path[0]) 27 | name = os.path.splitext(short_path)[0] 28 | 29 | webpage.add_header(name) 30 | ims, txts, links = [], [], [] 31 | 32 | for label, im_data in visuals.items(): 33 | im = util.tensor2im(im_data) 34 | image_name = '%s/%s.png' % (label, name) 35 | os.makedirs(os.path.join(image_dir, label), exist_ok=True) 36 | save_path = os.path.join(image_dir, image_name) 37 | util.save_image(im, save_path, aspect_ratio=aspect_ratio) 38 | ims.append(image_name) 39 | txts.append(label) 40 | links.append(image_name) 41 | webpage.add_images(ims, txts, links, width=width) 42 | 43 | 44 | class Visualizer(): 45 | """This class includes several functions that can display/save images and print/save logging information. 46 | 47 | It uses torch.utils.tensorboard for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
48 | """ 49 | 50 | def __init__(self, opt): 51 | """Initialize the Visualizer class 52 | 53 | Parameters: 54 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions 55 | Step 1: Cache the training/test options 56 | Step 2: create a tensorboard writer 57 | Step 3: create an HTML object for saving HTML files 58 | Step 4: create a logging file to store training losses 59 | """ 60 | self.opt = opt # cache the option 61 | self.use_html = opt.isTrain and not opt.no_html 62 | self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name)) 63 | self.win_size = opt.display_winsize 64 | self.name = opt.name 65 | self.saved = False 66 | if self.use_html: # create an HTML object at [checkpoints_dir]/[name]/web/; images will be saved under [checkpoints_dir]/[name]/web/images/ 67 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 68 | self.img_dir = os.path.join(self.web_dir, 'images') 69 | print('create web directory %s...' % self.web_dir) 70 | util.mkdirs([self.web_dir, self.img_dir]) 71 | # create a logging file to store training losses 72 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 73 | with open(self.log_name, "a") as log_file: 74 | now = time.strftime("%c") 75 | log_file.write('================ Training Loss (%s) ================\n' % now) 76 | 77 | def reset(self): 78 | """Reset the self.saved status""" 79 | self.saved = False 80 | 81 | 82 | def display_current_results(self, visuals, total_iters, epoch, save_result): 83 | """Display current results on tensorboard; save current results to an HTML file. 84 | 85 | Parameters: 86 | visuals (OrderedDict) - - dictionary of images to display or save 87 | total_iters (int) -- total iterations 88 | epoch (int) - - the current epoch 89 | save_result (bool) - - if save the current results to an HTML file 90 | """ 91 | for label, image in visuals.items(): 92 | self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC') 93 | 94 | if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
95 | self.saved = True 96 | # save images to the disk 97 | for label, image in visuals.items(): 98 | image_numpy = util.tensor2im(image) 99 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 100 | util.save_image(image_numpy, img_path) 101 | 102 | # update website 103 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0) 104 | for n in range(epoch, 0, -1): 105 | webpage.add_header('epoch [%d]' % n) 106 | ims, txts, links = [], [], [] 107 | 108 | for label, image_numpy in visuals.items(): 109 | image_numpy = util.tensor2im(image_numpy) 110 | img_path = 'epoch%.3d_%s.png' % (n, label) 111 | ims.append(img_path) 112 | txts.append(label) 113 | links.append(img_path) 114 | webpage.add_images(ims, txts, links, width=self.win_size) 115 | webpage.save() 116 | 117 | def plot_current_losses(self, total_iters, losses): 118 | # G_loss_collection = {} 119 | # D_loss_collection = {} 120 | # for name, value in losses.items(): 121 | # if 'G' in name or 'NCE' in name or 'idt' in name: 122 | # G_loss_collection[name] = value 123 | # else: 124 | # D_loss_collection[name] = value 125 | # self.writer.add_scalars('G_collec', G_loss_collection, total_iters) 126 | # self.writer.add_scalars('D_collec', D_loss_collection, total_iters) 127 | for name, value in losses.items(): 128 | self.writer.add_scalar(name, value, total_iters) 129 | 130 | # losses: same format as |losses| of plot_current_losses 131 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data): 132 | """print current losses on console; also save the losses to the disk 133 | 134 | Parameters: 135 | epoch (int) -- current epoch 136 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) 137 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 138 | t_comp (float) -- computational time per data point (normalized by batch_size) 139 | t_data (float) -- data loading time per data point (normalized by batch_size) 140 | """ 141 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) 142 | for k, v in losses.items(): 143 | message += '%s: %.3f ' % (k, v) 144 | 145 | print(message) # print the message 146 | with open(self.log_name, "a") as log_file: 147 | log_file.write('%s\n' % message) # save the message 148 | 149 | 150 | class MyVisualizer: 151 | def __init__(self, opt): 152 | """Initialize the Visualizer class 153 | 154 | Parameters: 155 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions 156 | Step 1: Cache the training/test options 157 | Step 2: create a tensorboard writer 158 | Step 3: create an HTML object for saving HTML files 159 | Step 4: create a logging file to store training losses 160 | """ 161 | self.opt = opt # cache the option 162 | self.name = opt.name 163 | self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results') 164 | 165 | if opt.phase != 'test': 166 | self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs')) 167 | # create a logging file to store training losses 168 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 169 | with open(self.log_name, "a") as log_file: 170 | now = time.strftime("%c") 171 | log_file.write('================ Training Loss (%s) ================\n' % now) 172 | 173 | 174 | def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None, 175 | add_image=True): 176 | """Display current results on
tensorboard; save current results to an HTML file. 177 | 178 | Parameters: 179 | visuals (OrderedDict) - - dictionary of images to display or save 180 | total_iters (int) -- total iterations 181 | epoch (int) - - the current epoch 182 | dataset (str) - - 'train' or 'val' or 'test' 183 | """ 184 | # if (not add_image) and (not save_results): return 185 | 186 | for label, image in visuals.items(): 187 | for i in range(image.shape[0]): 188 | image_numpy = util.tensor2im(image[i]) 189 | if add_image: 190 | self.writer.add_image(label + '%s_%02d'%(dataset, i + count), 191 | image_numpy, total_iters, dataformats='HWC') 192 | 193 | if save_results: 194 | save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters)) 195 | if not os.path.isdir(save_path): 196 | os.makedirs(save_path) 197 | 198 | if name is not None: 199 | img_path = os.path.join(save_path, '%s.png' % name) 200 | else: 201 | img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count)) 202 | util.save_image(image_numpy, img_path) 203 | 204 | 205 | def plot_current_losses(self, total_iters, losses, dataset='train'): 206 | for name, value in losses.items(): 207 | self.writer.add_scalar(name + '/%s'%dataset, value, total_iters) 208 | 209 | # losses: same format as |losses| of plot_current_losses 210 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'): 211 | """print current losses on console; also save the losses to the disk 212 | 213 | Parameters: 214 | epoch (int) -- current epoch 215 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) 216 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 217 | t_comp (float) -- computational time per data point (normalized by batch_size) 218 | t_data (float) -- data loading time per data point (normalized by batch_size) 219 | """ 220 | message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % ( 221 | dataset, epoch, iters, t_comp, t_data) 222 | for k, v in losses.items(): 223 | message += '%s: %.3f ' % (k, v) 224 | 225 | print(message) # print the message 226 | with open(self.log_name, "a") as log_file: 227 | log_file.write('%s\n' % message) # save the message 228 | --------------------------------------------------------------------------------
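
Usage note (added for clarity, not part of the repository files above): a minimal sketch of how the preprocessing utilities shown in util/preprocess.py fit together for a single example image. align_img and the detection-file layout come from the code above; the helper that loads the standard 5-point 3D landmarks is assumed to be util.load_mats.load_lm3d, which lives outside this excerpt, so treat its name and signature as an assumption.

import numpy as np
from PIL import Image

from util.preprocess import align_img
from util.load_mats import load_lm3d  # assumed helper; adjust if the actual API differs

img = Image.open('datasets/examples/000002.jpg').convert('RGB')
_, H = img.size

# each row of detections/*.txt is "x y"; flip y so it points up, matching align_img's convention
lm = np.loadtxt('datasets/examples/detections/000002.txt').reshape(-1, 2)
lm[:, 1] = H - 1 - lm[:, 1]

lm3d_std = load_lm3d('BFM/')  # (5, 3) standard landmarks of the 3D face model
trans_params, img_224, lm_224, _ = align_img(img, lm, lm3d_std)
img_224.save('000002_aligned.png')  # 224x224 crop consumed by the reconstruction network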
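
Similarly, a minimal sketch (assumed usage, not taken verbatim from the repository) of the skin-attention utility in util/skin_mask.py applied to a single image; get_skin_mask above wraps the same call over a whole folder.

import cv2
from util.skin_mask import skinmask

img = cv2.imread('datasets/examples/000002.jpg').astype('float32')  # BGR, H x W x 3
mask = skinmask(img)  # uint8, H x W x 3; values near 255 mark pixels with high skin posterior
cv2.imwrite('000002_skin_mask.png', mask)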