├── BFM
│   ├── .gitkeep
│   ├── BFM_exp_idx.mat
│   ├── BFM_front_idx.mat
│   ├── facemodel_info.mat
│   ├── select_vertex_id.mat
│   ├── similarity_Lm3D_all.mat
│   └── std_exp.txt
├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── flist_dataset.py
│   ├── image_folder.py
│   └── template_dataset.py
├── data_preparation.py
├── datasets
│   └── examples
│       ├── 000002.jpg
│       ├── 000006.jpg
│       ├── 000007.jpg
│       ├── 000031.jpg
│       ├── 000033.jpg
│       ├── 000037.jpg
│       ├── 000050.jpg
│       ├── 000055.jpg
│       ├── 000114.jpg
│       ├── 000125.jpg
│       ├── 000126.jpg
│       ├── 015259.jpg
│       ├── 015270.jpg
│       ├── 015309.jpg
│       ├── 015310.jpg
│       ├── 015316.jpg
│       ├── 015384.jpg
│       ├── detections
│       │   ├── 000002.txt
│       │   ├── 000006.txt
│       │   ├── 000007.txt
│       │   ├── 000031.txt
│       │   ├── 000033.txt
│       │   ├── 000037.txt
│       │   ├── 000050.txt
│       │   ├── 000055.txt
│       │   ├── 000114.txt
│       │   ├── 000125.txt
│       │   ├── 000126.txt
│       │   ├── 015259.txt
│       │   ├── 015270.txt
│       │   ├── 015309.txt
│       │   ├── 015310.txt
│       │   ├── 015316.txt
│       │   ├── 015384.txt
│       │   ├── vd006.txt
│       │   ├── vd025.txt
│       │   ├── vd026.txt
│       │   ├── vd034.txt
│       │   ├── vd051.txt
│       │   ├── vd070.txt
│       │   ├── vd092.txt
│       │   └── vd102.txt
│       ├── vd006.png
│       ├── vd025.png
│       ├── vd026.png
│       ├── vd034.png
│       ├── vd051.png
│       ├── vd070.png
│       ├── vd092.png
│       └── vd102.png
├── environment.yml
├── images
│   ├── 20230425_compare.png
│   ├── compare.png
│   └── example.gif
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── bfm.py
│   ├── facerecon_model.py
│   ├── losses.py
│   ├── networks.py
│   └── template_model.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── test.py
├── train.py
└── util
    ├── BBRegressorParam_r.mat
    ├── __init__.py
    ├── detect_lm68.py
    ├── generate_list.py
    ├── html.py
    ├── load_mats.py
    ├── nvdiffrast.py
    ├── preprocess.py
    ├── skin_mask.py
    ├── test_mean_face.txt
    ├── util.py
    └── visualizer.py
/BFM/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/BFM/.gitkeep
--------------------------------------------------------------------------------
/BFM/BFM_exp_idx.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/BFM/BFM_exp_idx.mat
--------------------------------------------------------------------------------
/BFM/BFM_front_idx.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/BFM/BFM_front_idx.mat
--------------------------------------------------------------------------------
/BFM/facemodel_info.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/BFM/facemodel_info.mat
--------------------------------------------------------------------------------
/BFM/select_vertex_id.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/BFM/select_vertex_id.mat
--------------------------------------------------------------------------------
/BFM/similarity_Lm3D_all.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/BFM/similarity_Lm3D_all.mat
--------------------------------------------------------------------------------
/BFM/std_exp.txt:
--------------------------------------------------------------------------------
1 | 453980 257264 263068 211890 135873 184721 47055.6 72732 62787.4 106226 56708.5 51439.8 34887.1 44378.7 51813.4 31030.7 23354.9 23128.1 19400 21827.6 22767.7 22057.4 19894.3 16172.8 17142.7 10035.3 14727.5 12972.5 10763.8 8953.93 8682.62 8941.81 6342.3 5205.3 7065.65 6083.35 6678.88 4666.63 5082.89 5134.76 4908.16 3964.93 3739.95 3180.09 2470.45 1866.62 1624.71 2423.74 1668.53 1471.65 1194.52 782.102 815.044 835.782 834.937 744.496 575.146 633.76 705.685 753.409 620.306 673.326 766.189 619.866 559.93 357.264 396.472 556.849 455.048 460.592 400.735 326.702 279.428 291.535 326.584 305.664 287.816 283.642 276.19
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Sicheng Xu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Accurate 3D Face Reconstruction with Weakly-Supervised Learning: From Single Image to Image Set —— PyTorch implementation ##
2 |
3 |
4 |
5 |
6 |
7 | This is an unofficial official pytorch implementation of the following paper:
8 |
9 | Y. Deng, J. Yang, S. Xu, D. Chen, Y. Jia, and X. Tong, [Accurate 3D Face Reconstruction with Weakly-Supervised Learning: From Single Image to Image Set](https://arxiv.org/abs/1903.08527), IEEE Computer Vision and Pattern Recognition Workshop (CVPRW) on Analysis and Modeling of Faces and Gestures (AMFG), 2019. (**_Best Paper Award!_**)
10 |
11 | The method enforces a hybrid-level weakly-supervised training for CNN-based 3D face reconstruction. It is fast, accurate, and robust to pose and occlusions. It achieves state-of-the-art performance on multiple datasets such as FaceWarehouse, MICC Florence and NoW Challenge.
12 |
13 |
14 | For the original tensorflow implementation, check this [repo](https://github.com/microsoft/Deep3DFaceReconstruction).
15 |
16 | This implementation is written by S. Xu.
17 | ## 04/25/2023 Update
18 | We released an updated model that improves the results on "closed eye" images. We collected ~2K facial images with closed eyes and included them in the training data. The updated model has similar reconstruction accuracy to the previous one on the benchmarks, but produces better results for faces with closed eyes (see below). Here's the [link (google drive)](https://drive.google.com/drive/folders/1grs8J4vu7gOhEClyKjWU-SNxfonGue5F?usp=share_link) to the new model.
19 | ### ● Reconstruction accuracy
20 |
21 | |Method|FaceWareHouse|MICC Florence|
22 | |:----:|:-----------:|:-----------:|
23 | |Deep3DFace_PyTorch_20230425|1.60±0.44|1.54±0.49|
24 |
25 | ### ● Visual quality
26 |
27 |
28 |
29 |
30 | ## Performance
31 |
32 | ### ● Reconstruction accuracy
33 |
34 | The pytorch implementation achieves lower shape reconstruction error (9% improvement) compared to the [original tensorflow implementation](https://github.com/microsoft/Deep3DFaceReconstruction). Quantitative evaluation (average shape errors in mm) on several benchmarks is as follows:
35 |
36 | |Method|FaceWareHouse|MICC Florence | NoW Challenge |
37 | |:----:|:-----------:|:-----------:|:-----------:|
38 | |Deep3DFace Tensorflow | 1.81±0.50 | 1.67±0.50 | 1.54±1.29 |
39 | |**Deep3DFace PyTorch** |**1.64±0.50**|**1.53±0.45**| **1.41±1.21** |
40 |
41 | The comparison result with state-of-the-art public 3D face reconstruction methods on the NoW face benchmark is as follows:
42 | |Rank|Method|Median(mm) | Mean(mm) | Std(mm) |
43 | |:----:|:-----------:|:-----------:|:-----------:|:-----------:|
44 | | 1. | [DECA\[Feng et al., SIGGRAPH 2021\]](https://github.com/YadiraF/DECA)|1.09|1.38|1.18|
45 | | **2.** | **Deep3DFace PyTorch**|**1.11**|**1.41**|**1.21**|
46 | | 3. | [RingNet [Sanyal et al., CVPR 2019]](https://github.com/soubhiksanyal/RingNet) | 1.21 | 1.53 | 1.31 |
47 | | 4. | [Deep3DFace [Deng et al., CVPRW 2019]](https://github.com/microsoft/Deep3DFaceReconstruction) | 1.23 | 1.54 | 1.29 |
48 | | 5. | [3DDFA-V2 [Guo et al., ECCV 2020]](https://github.com/cleardusk/3DDFA_V2) | 1.23 | 1.57 | 1.39 |
49 | | 6. | [MGCNet [Shang et al., ECCV 2020]](https://github.com/jiaxiangshang/MGCNet) | 1.31 | 1.87 | 2.63 |
50 | | 7. | [PRNet [Feng et al., ECCV 2018]](https://github.com/YadiraF/PRNet) | 1.50 | 1.98 | 1.88 |
51 | | 8. | [3DMM-CNN [Tran et al., CVPR 2017]](https://github.com/anhttran/3dmm_cnn) | 1.84 | 2.33 | 2.05 |
52 |
53 | For more details about the evaluation, check the [NoW Challenge](https://ringnet.is.tue.mpg.de/challenge.html) website.
54 |
55 | **_A recent benchmark [REALY](https://www.realy3dface.com/) indicates that our method still achieves state-of-the-art performance! Check their paper and website for more details._**
56 |
57 | ### ● Visual quality
58 | The pytorch implementation achieves better visual consistency with the input images compared to the original tensorflow version.
59 |
60 |
61 |
62 |
63 |
64 | ### ● Speed
65 | The training speed is on par with the original tensorflow implementation. For more information, see [here](https://github.com/sicxu/Deep3DFaceRecon_pytorch#train-the-face-reconstruction-network).
66 |
67 | ## Major changes
68 |
69 | ### ● Differentiable renderer
70 |
71 | We use [Nvdiffrast](https://nvlabs.github.io/nvdiffrast/), a PyTorch library that provides high-performance primitive operations for rasterization-based differentiable rendering. The original tensorflow implementation used [tf_mesh_renderer](https://github.com/google/tf_mesh_renderer) instead.
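
For orientation, here is a minimal sketch of the rasterize-then-interpolate pattern such a renderer builds on (an editor's sketch, not code from this repo; the mesh and image sizes are placeholders, and the API usage follows the Nvdiffrast documentation):

```python
import torch
import nvdiffrast.torch as dr

# Placeholder mesh and image sizes; the real values come from the BFM mesh used by this repo.
V, F, H, W = 1024, 2048, 224, 224
glctx = dr.RasterizeCudaContext()  # CUDA rasterization context; dr.RasterizeGLContext() needs OpenGL
pos = torch.rand(1, V, 4, device='cuda')  # clip-space vertex positions, shape [batch, V, 4]
tri = torch.randint(0, V, (F, 3), dtype=torch.int32, device='cuda')  # triangle indices, shape [F, 3]
col = torch.rand(1, V, 3, device='cuda', requires_grad=True)  # per-vertex attributes (e.g., albedo)

rast, _ = dr.rasterize(glctx, pos, tri, resolution=[H, W])  # per-pixel triangle ids + barycentrics
color, _ = dr.interpolate(col, rast, tri)  # differentiable attribute interpolation, [batch, H, W, 3]
color.sum().backward()  # gradients flow back to the vertex attributes
```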
72 |
73 | ### ● Face recognition model
74 |
75 | We use [Arcface](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch), a state-of-the-art face recognition model, for perceptual loss computation. By contrast, the original tensorflow implementation used [Facenet](https://github.com/davidsandberg/facenet).
76 |
77 | ### ● Training configuration
78 |
79 | Data augmentation is used during training, including random image shifting, scaling, rotation, and flipping. We also enlarge the training batch size from 5 to 32 to stabilize the training process.
80 |
81 | ### ● Training data
82 |
83 | We use an additional high-quality face image dataset, [FFHQ](https://github.com/NVlabs/ffhq-dataset), to increase the diversity of the training data.
84 |
85 | ## Requirements
86 | **This implementation is only tested under Ubuntu environment with Nvidia GPUs and CUDA installed.** It should also work on Windows with proper library configuration.
87 |
88 | ## Installation
89 | 1. Clone the repository and set up a conda environment with all dependencies as follows:
90 | ```
91 | git clone https://github.com/sicxu/Deep3DFaceRecon_pytorch.git
92 | cd Deep3DFaceRecon_pytorch
93 | conda env create -f environment.yml
94 | source activate deep3d_pytorch
95 | ```
96 |
97 | 2. Install Nvdiffrast library:
98 | ```
99 | git clone -b 0.3.0 https://github.com/NVlabs/nvdiffrast
100 | cd nvdiffrast # ./Deep3DFaceRecon_pytorch/nvdiffrast
101 | pip install .
102 | ```
103 |
104 | 3. Install Arcface Pytorch:
105 | ```
106 | cd .. # ./Deep3DFaceRecon_pytorch
107 | git clone https://github.com/deepinsight/insightface.git
108 | cp -r ./insightface/recognition/arcface_torch ./models/
109 | ```
110 | ## Inference with a pre-trained model
111 |
112 | ### Prepare prerequisite models
113 | 1. Our method uses [Basel Face Model 2009 (BFM09)](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-0&id=basel_face_model) to represent 3d faces. Get access to BFM09 using this [link](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-2&id=downloads). After getting access, download "01_MorphableModel.mat". In addition, we use an Expression Basis provided by [Guo et al.](https://github.com/Juyong/3DFace). Download the Expression Basis (Exp_Pca.bin) using this [link (google drive)](https://drive.google.com/file/d/1bw5Xf8C12pWmcMhNEu6PtsYVZkVucEN6/view?usp=sharing). Organize all files into the following structure (a quick sanity check follows the tree):
114 | ```
115 | Deep3DFaceRecon_pytorch
116 | │
117 | └─── BFM
118 |     │
119 |     └─── 01_MorphableModel.mat
120 |     │
121 |     └─── Exp_Pca.bin
122 |     |
123 |     └─── ...
124 | ```
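
As a quick sanity check that the downloads are in place (an editor's sketch; Exp_Pca.bin is a raw binary parsed by the repo's own loader, so only the .mat file is inspected here):

```python
import os
from scipy.io import loadmat

assert os.path.exists('BFM/Exp_Pca.bin'), 'missing expression basis'
bfm = loadmat('BFM/01_MorphableModel.mat')
print([k for k in bfm if not k.startswith('__')])  # list the arrays stored in the model file
```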
125 | 2. We provide a model trained on a combination of [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html),
126 | [LFW](http://vis-www.cs.umass.edu/lfw/), [300WLP](http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/main.htm),
127 | [IJB-A](https://www.nist.gov/programs-projects/face-challenges), [LS3D-W](https://www.adrianbulat.com/face-alignment), and [FFHQ](https://github.com/NVlabs/ffhq-dataset) datasets. Download the pre-trained model using this [link (google drive)](https://drive.google.com/drive/folders/1liaIxn9smpudjjqMaWWRpP0mXRW_qRPP?usp=sharing) and organize the directory into the following structure:
128 | ```
129 | Deep3DFaceRecon_pytorch
130 | │
131 | └─── checkpoints
132 |     │
133 |     └─── <model_name>
134 |         │
135 |         └─── epoch_20.pth
136 |
137 | ```
138 |
139 | ### Test with custom images
140 | To reconstruct 3d faces from test images, organize the test image folder as follows:
141 | ```
142 | Deep3DFaceRecon_pytorch
143 | │
144 | └─── <folder_to_test_images>
145 |     │
146 |     └─── *.jpg/*.png
147 |     |
148 |     └─── detections
149 |         |
150 |         └─── *.txt
151 | ```
152 | The \*.jpg/\*.png files are test images. The \*.txt files contain the 5 detected facial landmarks, one per row with a shape of 5x2, and have the same name as their corresponding images. Check [./datasets/examples](datasets/examples) for a reference.
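
Concretely, a detection file is just five whitespace-separated x y pairs. A minimal sketch that writes one with numpy, using the values from datasets/examples/detections/000002.txt (the landmark order shown is the usual 5-point convention; verify against the provided examples):

```python
import numpy as np

# left eye, right eye, nose tip, left mouth corner, right mouth corner (conventional order)
lm5 = np.array([[142.84, 207.18],
                [222.02, 203.90],
                [159.24, 253.57],
                [146.59, 290.93],
                [227.52, 284.74]])
np.savetxt('datasets/examples/detections/000002.txt', lm5, fmt='%.2f')  # 5x2 text file
lm5_back = np.loadtxt('datasets/examples/detections/000002.txt')        # round-trips to (5, 2)
```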
153 |
154 | Then, run the test script:
155 | ```
156 | # get reconstruction results of your custom images
157 | python test.py --name=<model_name> --epoch=20 --img_folder=<folder_to_test_images>
158 |
159 | # get reconstruction results of example images
160 | python test.py --name=<model_name> --epoch=20 --img_folder=./datasets/examples
161 | ```
162 | **_Following [#108](https://github.com/sicxu/Deep3DFaceRecon_pytorch/issues/108), if you don't have an OpenGL environment, you can simply add "--use_opengl False" to use the CUDA context instead. Make sure you have updated nvdiffrast to the latest version._**
163 |
164 | Results will be saved into ./checkpoints/<model_name>/results/<folder_to_test_images>, which contains the following files:
165 | | \*.png | A combination of the cropped input image, the reconstructed image, and a visualization of the projected landmarks. |
166 | |:----|:-----------|
167 | | \*.obj | Reconstructed 3d face mesh with predicted color (texture+illumination) in the world coordinate space. Best viewed in Meshlab. |
168 | | \*.mat | Predicted 257-dimensional coefficients and 68 projected 2d facial landmarks. Best viewed in Matlab. |
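
The saved .mat files can also be inspected from Python; a hedged sketch (the result path below is hypothetical, and the coefficient key names depend on the model code, so list the keys rather than assuming them):

```python
from scipy.io import loadmat

pred = loadmat('./checkpoints/<model_name>/results/<folder_to_test_images>/000002.mat')  # hypothetical path
print([k for k in pred if not k.startswith('__')])  # coefficient arrays and the 68 projected 2d landmarks
```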
169 |
170 | ## Training a model from scratch
171 | ### Prepare prerequisite models
172 | 1. We rely on [Arcface](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch) to extract identity features for loss computation. Download the pre-trained model from Arcface using this [link](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch#ms1mv3). By default, we use the resnet50 backbone ([ms1mv3_arcface_r50_fp16](https://onedrive.live.com/?authkey=%21AFZjr283nwZHqbA&id=4A83B6B633B029CC%215583&cid=4A83B6B633B029CC)). Organize the downloaded files into the following structure:
173 | ```
174 | Deep3DFaceRecon_pytorch
175 | │
176 | └─── checkpoints
177 |     │
178 |     └─── recog_model
179 |         │
180 |         └─── ms1mv3_arcface_r50_fp16
181 |             |
182 |             └─── backbone.pth
183 | ```
184 | 2. We initialize R-Net using the weights trained on [ImageNet](https://image-net.org/). Download the weights provided by PyTorch using this [link](https://download.pytorch.org/models/resnet50-0676ba61.pth), and organize the file into the following structure:
185 | ```
186 | Deep3DFaceRecon_pytorch
187 | │
188 | └─── checkpoints
189 |     │
190 |     └─── init_model
191 |         │
192 |         └─── resnet50-0676ba61.pth
193 | ```
194 | 3. We provide a landmark detector (tensorflow model) to extract 68 facial landmarks for loss computation. The detector is trained on [300WLP](http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/main.htm), [LFW](http://vis-www.cs.umass.edu/lfw/), and [LS3D-W](https://www.adrianbulat.com/face-alignment) datasets. Download the trained model using this [link (google drive)](https://drive.google.com/file/d/1Jl1yy2v7lIJLTRVIpgg2wvxYITI8Dkmw/view?usp=sharing) and organize the file as follows:
195 | ```
196 | Deep3DFaceRecon_pytorch
197 | │
198 | └─── checkpoints
199 |     │
200 |     └─── lm_model
201 |         │
202 |         └─── 68lm_detector.pb
203 | ```
204 | ### Data preparation
205 | 1. To train a model with custom images, 5 facial landmarks for each image are needed in advance for an image pre-alignment process. We recommend using [dlib](http://dlib.net/) or [MTCNN](https://github.com/ipazc/mtcnn) to detect these landmarks (see the sketch below). Then, organize all files into the following structure:
206 | ```
207 | Deep3DFaceRecon_pytorch
208 | │
209 | └─── datasets
210 |     │
211 |     └─── <folder_to_training_images>
212 |         │
213 |         └─── *.png/*.jpg
214 |         |
215 |         └─── detections
216 |             |
217 |             └─── *.txt
218 | ```
219 | The \*.txt files contain 5 facial landmarks with a shape of 5x2, and should have the same name as their corresponding images.
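
Here is the sketch referenced above: producing the detections/\*.txt files with MTCNN (an editor's sketch; the keypoint names follow the mtcnn package, and the write_detection helper is hypothetical):

```python
import os
import numpy as np
from PIL import Image
from mtcnn import MTCNN

detector = MTCNN()  # reuse one detector across all images

def write_detection(img_path):
    """Detect 5 facial landmarks and save them as a 5x2 .txt under detections/."""
    img = np.array(Image.open(img_path).convert('RGB'))
    face = detector.detect_faces(img)[0]  # take the first detected face
    kp = face['keypoints']
    lm5 = np.array([kp['left_eye'], kp['right_eye'], kp['nose'],
                    kp['mouth_left'], kp['mouth_right']], dtype=np.float32)
    out_dir = os.path.join(os.path.dirname(img_path), 'detections')
    os.makedirs(out_dir, exist_ok=True)
    name = os.path.splitext(os.path.basename(img_path))[0] + '.txt'
    np.savetxt(os.path.join(out_dir, name), lm5, fmt='%.2f')
```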
220 |
221 | 2. Generate 68 landmarks and skin attention mask for images using the following script:
222 | ```
223 | # preprocess training images
224 | python data_preparation.py --img_folder <folder_to_training_images>
225 |
226 | # alternatively, you can preprocess multiple image folders simultaneously
227 | python data_preparation.py --img_folder <folder_to_training_images_1> <folder_to_training_images_2>
228 |
229 | # preprocess validation images
230 | python data_preparation.py --img_folder <folder_to_validation_images> --mode=val
231 | ```
232 | The script will generate files of landmarks and skin masks, and save them into ./datasets/<folder_to_training_images>. In addition, it also generates a file containing the paths of all training data into ./datalist, which will then be used by the training script.
233 |
234 | ### Train the face reconstruction network
235 | Run the following script to train a face reconstruction model using the pre-processed data:
236 | ```
237 | # train with single GPU
238 | python train.py --name=<custom_experiment_name> --gpu_ids=0
239 |
240 | # train with multiple GPUs
241 | python train.py --name=<custom_experiment_name> --gpu_ids=0,1
242 |
243 | # train with other custom settings
244 | python train.py --name=<custom_experiment_name> --gpu_ids=0 --batch_size=32 --n_epochs=20
245 | ```
246 | Training logs and model parameters will be saved into ./checkpoints/<custom_experiment_name>/.
247 |
248 | By default, the script uses a batch size of 32 and trains the model for 20 epochs. For reference, the pre-trained model in this repo was trained with the default setting on an image collection of 300k images. A single iteration takes 0.8~0.9s on a single Tesla M40 GPU. The total training process takes around two days.
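
For scale: 300k images at batch size 32 is roughly 9.4k iterations per epoch; at ~0.85s per iteration, 20 epochs come to about 44 hours, consistent with the quoted two days.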
249 |
250 | To use a trained model, see the [Inference](https://github.com/sicxu/Deep3DFaceRecon_pytorch#inference-with-a-pre-trained-model) section.
251 | ## Contact
252 | If you have any questions, please contact the paper authors.
253 |
254 | ## Citation
255 |
256 | Please cite the following paper if this model helps your research:
257 |
258 | @inproceedings{deng2019accurate,
259 | title={Accurate 3D Face Reconstruction with Weakly-Supervised Learning: From Single Image to Image Set},
260 | author={Yu Deng and Jiaolong Yang and Sicheng Xu and Dong Chen and Yunde Jia and Xin Tong},
261 | booktitle={IEEE Computer Vision and Pattern Recognition Workshops},
262 | year={2019}
263 | }
264 | ##
265 | The face images on this page are from the public [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset released by MMLab, CUHK.
266 |
267 | Part of the code in this implementation takes [CUT](https://github.com/taesungp/contrastive-unpaired-translation) as a reference.
268 |
269 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 |
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of dataset.
7 | -- <__getitem__>: get a data point from data loader.
8 |     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 |
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import numpy as np
14 | import importlib
15 | import torch.utils.data
16 | from data.base_dataset import BaseDataset
17 |
18 |
19 | def find_dataset_using_name(dataset_name):
20 | """Import the module "data/[dataset_name]_dataset.py".
21 |
22 | In the file, the class called DatasetNameDataset() will
23 | be instantiated. It has to be a subclass of BaseDataset,
24 | and it is case-insensitive.
25 | """
26 | dataset_filename = "data." + dataset_name + "_dataset"
27 | datasetlib = importlib.import_module(dataset_filename)
28 |
29 | dataset = None
30 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
31 | for name, cls in datasetlib.__dict__.items():
32 | if name.lower() == target_dataset_name.lower() \
33 | and issubclass(cls, BaseDataset):
34 | dataset = cls
35 |
36 | if dataset is None:
37 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
38 |
39 | return dataset
40 |
41 |
42 | def get_option_setter(dataset_name):
43 | """Return the static method of the dataset class."""
44 | dataset_class = find_dataset_using_name(dataset_name)
45 | return dataset_class.modify_commandline_options
46 |
47 |
48 | def create_dataset(opt, rank=0):
49 | """Create a dataset given the option.
50 |
51 | This function wraps the class CustomDatasetDataLoader.
52 | This is the main interface between this package and 'train.py'/'test.py'
53 |
54 | Example:
55 | >>> from data import create_dataset
56 | >>> dataset = create_dataset(opt)
57 | """
58 | data_loader = CustomDatasetDataLoader(opt, rank=rank)
59 | dataset = data_loader.load_data()
60 | return dataset
61 |
62 | class CustomDatasetDataLoader():
63 | """Wrapper class of Dataset class that performs multi-threaded data loading"""
64 |
65 | def __init__(self, opt, rank=0):
66 | """Initialize this class
67 |
68 | Step 1: create a dataset instance given the name [dataset_mode]
69 | Step 2: create a multi-threaded data loader.
70 | """
71 | self.opt = opt
72 | dataset_class = find_dataset_using_name(opt.dataset_mode)
73 | self.dataset = dataset_class(opt)
74 | self.sampler = None
75 | print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
76 | if opt.use_ddp and opt.isTrain:
77 | world_size = opt.world_size
78 | self.sampler = torch.utils.data.distributed.DistributedSampler(
79 | self.dataset,
80 | num_replicas=world_size,
81 | rank=rank,
82 | shuffle=not opt.serial_batches
83 | )
84 | self.dataloader = torch.utils.data.DataLoader(
85 | self.dataset,
86 | sampler=self.sampler,
87 | num_workers=int(opt.num_threads / world_size),
88 | batch_size=int(opt.batch_size / world_size),
89 | drop_last=True)
90 | else:
91 | self.dataloader = torch.utils.data.DataLoader(
92 | self.dataset,
93 | batch_size=opt.batch_size,
94 | shuffle=(not opt.serial_batches) and opt.isTrain,
95 | num_workers=int(opt.num_threads),
96 | drop_last=True
97 | )
98 |
99 | def set_epoch(self, epoch):
100 | self.dataset.current_epoch = epoch
101 | if self.sampler is not None:
102 | self.sampler.set_epoch(epoch)
103 |
104 | def load_data(self):
105 | return self
106 |
107 | def __len__(self):
108 | """Return the number of data in the dataset"""
109 | return min(len(self.dataset), self.opt.max_dataset_size)
110 |
111 | def __iter__(self):
112 | """Return a batch of data"""
113 | for i, data in enumerate(self.dataloader):
114 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
115 | break
116 | yield data
117 |
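
To make the lookup convention above concrete, here is a hypothetical data/dummy_dataset.py that find_dataset_using_name('dummy') would resolve (an editor's sketch, not part of the repo):

```python
# data/dummy_dataset.py -- selected via '--dataset_mode dummy'
from data.base_dataset import BaseDataset

class DummyDataset(BaseDataset):  # 'dummy'.replace('_', '') + 'dataset' matches, case-insensitively
    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.items = list(range(100))  # stand-in for real sample paths
        self.name = 'dummy'            # CustomDatasetDataLoader prints self.dataset.name

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return {'data': self.items[index]}
```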
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 |
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | try:
10 | from PIL.Image import Resampling
11 | RESAMPLING_METHOD = Resampling.BICUBIC
12 | except ImportError:
13 | from PIL.Image import BICUBIC
14 | RESAMPLING_METHOD = BICUBIC
15 | import torchvision.transforms as transforms
16 | from abc import ABC, abstractmethod
17 |
18 |
19 | class BaseDataset(data.Dataset, ABC):
20 | """This class is an abstract base class (ABC) for datasets.
21 |
22 | To create a subclass, you need to implement the following four functions:
23 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
24 | -- <__len__>: return the size of dataset.
25 | -- <__getitem__>: get a data point.
26 |     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
27 | """
28 |
29 | def __init__(self, opt):
30 | """Initialize the class; save the options in the class
31 |
32 | Parameters:
33 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
34 | """
35 | self.opt = opt
36 | # self.root = opt.dataroot
37 | self.current_epoch = 0
38 |
39 | @staticmethod
40 | def modify_commandline_options(parser, is_train):
41 | """Add new dataset-specific options, and rewrite default values for existing options.
42 |
43 | Parameters:
44 | parser -- original option parser
45 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
46 |
47 | Returns:
48 | the modified parser.
49 | """
50 | return parser
51 |
52 | @abstractmethod
53 | def __len__(self):
54 | """Return the total number of images in the dataset."""
55 | return 0
56 |
57 | @abstractmethod
58 | def __getitem__(self, index):
59 | """Return a data point and its metadata information.
60 |
61 | Parameters:
62 | index - - a random integer for data indexing
63 |
64 | Returns:
65 |             a dictionary of data with their names. It usually contains the data itself and its metadata information.
66 | """
67 | pass
68 |
69 |
70 | def get_transform(grayscale=False):
71 | transform_list = []
72 | if grayscale:
73 | transform_list.append(transforms.Grayscale(1))
74 | transform_list += [transforms.ToTensor()]
75 | return transforms.Compose(transform_list)
76 |
77 | def get_affine_mat(opt, size):
78 |     shift_x, shift_y, scale, rot_angle, rot_rad, flip = 0., 0., 1., 0., 0., False  # rot_rad defaults to 0 so rot_mat below is the identity when 'rot' is not requested
79 | w, h = size
80 |
81 | if 'shift' in opt.preprocess:
82 | shift_pixs = int(opt.shift_pixs)
83 | shift_x = random.randint(-shift_pixs, shift_pixs)
84 | shift_y = random.randint(-shift_pixs, shift_pixs)
85 | if 'scale' in opt.preprocess:
86 | scale = 1 + opt.scale_delta * (2 * random.random() - 1)
87 | if 'rot' in opt.preprocess:
88 | rot_angle = opt.rot_angle * (2 * random.random() - 1)
89 | rot_rad = -rot_angle * np.pi/180
90 | if 'flip' in opt.preprocess:
91 | flip = random.random() > 0.5
92 |
93 | shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
94 | flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
95 | shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
96 | rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
97 | scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
98 | shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
99 |
100 | affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
101 | affine_inv = np.linalg.inv(affine)
102 | return affine, affine_inv, flip
103 |
104 | def apply_img_affine(img, affine_inv, method=RESAMPLING_METHOD):
105 |     return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=method)  # honor the requested resampling method (e.g., BILINEAR for masks)
106 |
107 | def apply_lm_affine(landmark, affine, flip, size):
108 | _, h = size
109 | lm = landmark.copy()
110 | lm[:, 1] = h - 1 - lm[:, 1]
111 | lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
112 | lm = lm @ np.transpose(affine)
113 | lm[:, :2] = lm[:, :2] / lm[:, 2:]
114 | lm = lm[:, :2]
115 | lm[:, 1] = h - 1 - lm[:, 1]
116 | if flip:
117 | lm_ = lm.copy()
118 | lm_[:17] = lm[16::-1]
119 | lm_[17:22] = lm[26:21:-1]
120 | lm_[22:27] = lm[21:16:-1]
121 | lm_[31:36] = lm[35:30:-1]
122 | lm_[36:40] = lm[45:41:-1]
123 | lm_[40:42] = lm[47:45:-1]
124 | lm_[42:46] = lm[39:35:-1]
125 | lm_[46:48] = lm[41:39:-1]
126 | lm_[48:55] = lm[54:47:-1]
127 | lm_[55:60] = lm[59:54:-1]
128 | lm_[60:65] = lm[64:59:-1]
129 | lm_[65:68] = lm[67:64:-1]
130 | lm = lm_
131 | return lm
132 |
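
A minimal sketch of how these helpers compose during augmentation (mirroring _augmentation in flist_dataset.py below; the opt fields are exactly the ones get_affine_mat reads, and the values here are made up):

```python
from types import SimpleNamespace
import numpy as np
from PIL import Image
from data.base_dataset import get_affine_mat, apply_img_affine, apply_lm_affine

opt = SimpleNamespace(preprocess='shift,scale,rot,flip',
                      shift_pixs=10, scale_delta=0.1, rot_angle=10)
img = Image.new('RGB', (224, 224))
lm = np.zeros((68, 2), dtype=np.float32)  # 68 landmarks in image coordinates

affine, affine_inv, flip = get_affine_mat(opt, img.size)
img = apply_img_affine(img, affine_inv)            # warp the image with the inverse map
lm = apply_lm_affine(lm, affine, flip, img.size)   # move landmarks with the forward map
```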
--------------------------------------------------------------------------------
/data/flist_dataset.py:
--------------------------------------------------------------------------------
1 | """This script defines the custom dataset for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import os.path
5 | from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
6 | from data.image_folder import make_dataset
7 | from PIL import Image
8 | import random
9 | import util.util as util
10 | import numpy as np
11 | import json
12 | import torch
13 | from scipy.io import loadmat, savemat
14 | import pickle
15 | from util.preprocess import align_img, estimate_norm
16 | from util.load_mats import load_lm3d
17 |
18 |
19 | def default_flist_reader(flist):
20 | """
21 |     flist format: impath label\nimpath label\n ...(same as caffe's filelist)
22 | """
23 | imlist = []
24 | with open(flist, 'r') as rf:
25 | for line in rf.readlines():
26 | impath = line.strip()
27 | imlist.append(impath)
28 |
29 | return imlist
30 |
31 | def jason_flist_reader(flist):
32 | with open(flist, 'r') as fp:
33 | info = json.load(fp)
34 | return info
35 |
36 | def parse_label(label):
37 | return torch.tensor(np.array(label).astype(np.float32))
38 |
39 |
40 | class FlistDataset(BaseDataset):
41 | """
42 |     It requires one directory to host training images, e.g., '/path/to/data/train'.
43 | You can train the model with the dataset flag '--dataroot /path/to/data'.
44 | """
45 |
46 | def __init__(self, opt):
47 | """Initialize this dataset class.
48 |
49 | Parameters:
50 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
51 | """
52 | BaseDataset.__init__(self, opt)
53 |
54 | self.lm3d_std = load_lm3d(opt.bfm_folder)
55 |
56 | msk_names = default_flist_reader(opt.flist)
57 | self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
58 |
59 | self.size = len(self.msk_paths)
60 | self.opt = opt
61 |
62 | self.name = 'train' if opt.isTrain else 'val'
63 | if '_' in opt.flist:
64 | self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
65 |
66 |
67 | def __getitem__(self, index):
68 | """Return a data point and its metadata information.
69 |
70 | Parameters:
71 | index (int) -- a random integer for data indexing
72 |
73 |         Returns a dictionary that contains imgs, lms, msks, M, im_paths, aug_flag and dataset
74 |             img (tensor)       -- an image in the input domain
75 |             msk (tensor)       -- its corresponding attention mask
76 |             lm (tensor)        -- its corresponding 3d landmarks
77 |             im_paths (str)     -- image paths
78 |             aug_flag (bool)    -- a flag that tells whether the sample is raw or augmented
79 | """
80 |         msk_path = self.msk_paths[index % self.size]  # make sure index is within the range
81 | img_path = msk_path.replace('mask/', '')
82 | lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
83 |
84 | raw_img = Image.open(img_path).convert('RGB')
85 | raw_msk = Image.open(msk_path).convert('RGB')
86 | raw_lm = np.loadtxt(lm_path).astype(np.float32)
87 |
88 | _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
89 |
90 | aug_flag = self.opt.use_aug and self.opt.isTrain
91 | if aug_flag:
92 | img, lm, msk = self._augmentation(img, lm, self.opt, msk)
93 |
94 | _, H = img.size
95 | M = estimate_norm(lm, H)
96 | transform = get_transform()
97 | img_tensor = transform(img)
98 | msk_tensor = transform(msk)[:1, ...]
99 | lm_tensor = parse_label(lm)
100 | M_tensor = parse_label(M)
101 |
102 |
103 | return {'imgs': img_tensor,
104 | 'lms': lm_tensor,
105 | 'msks': msk_tensor,
106 | 'M': M_tensor,
107 | 'im_paths': img_path,
108 | 'aug_flag': aug_flag,
109 | 'dataset': self.name}
110 |
111 | def _augmentation(self, img, lm, opt, msk=None):
112 | affine, affine_inv, flip = get_affine_mat(opt, img.size)
113 | img = apply_img_affine(img, affine_inv)
114 | lm = apply_lm_affine(lm, affine, flip, img.size)
115 | if msk is not None:
116 | msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
117 | return img, lm, msk
118 |
119 |
120 |
121 |
122 | def __len__(self):
123 | """Return the total number of images in the dataset.
124 | """
125 | return self.size
126 |
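
Spelled out, the path convention this dataset relies on: each mask path listed in opt.flist maps to its image and its landmark file by string substitution (the paths below are hypothetical):

```python
msk_path = 'datasets/train_imgs/mask/000002.png'  # an entry read from opt.flist
img_path = msk_path.replace('mask/', '')
# -> 'datasets/train_imgs/000002.png'
lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
# -> 'datasets/train_imgs/landmarks/000002.txt'
```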
--------------------------------------------------------------------------------
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | """A modified image folder class
2 |
3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4 | so that this class can load images from both current directory and its subdirectories.
5 | """
6 | import numpy as np
7 | import torch.utils.data as data
8 |
9 | from PIL import Image
10 | import os
11 | import os.path
12 |
13 | IMG_EXTENSIONS = [
14 | '.jpg', '.JPG', '.jpeg', '.JPEG',
15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16 | '.tif', '.TIF', '.tiff', '.TIFF',
17 | ]
18 |
19 |
20 | def is_image_file(filename):
21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22 |
23 |
24 | def make_dataset(dir, max_dataset_size=float("inf")):
25 | images = []
26 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
27 |
28 | for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
29 | for fname in fnames:
30 | if is_image_file(fname):
31 | path = os.path.join(root, fname)
32 | images.append(path)
33 | return images[:min(max_dataset_size, len(images))]
34 |
35 |
36 | def default_loader(path):
37 | return Image.open(path).convert('RGB')
38 |
39 |
40 | class ImageFolder(data.Dataset):
41 |
42 | def __init__(self, root, transform=None, return_paths=False,
43 | loader=default_loader):
44 | imgs = make_dataset(root)
45 | if len(imgs) == 0:
46 | raise(RuntimeError("Found 0 images in: " + root + "\n"
47 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
48 |
49 | self.root = root
50 | self.imgs = imgs
51 | self.transform = transform
52 | self.return_paths = return_paths
53 | self.loader = loader
54 |
55 | def __getitem__(self, index):
56 | path = self.imgs[index]
57 | img = self.loader(path)
58 | if self.transform is not None:
59 | img = self.transform(img)
60 | if self.return_paths:
61 | return img, path
62 | else:
63 | return img
64 |
65 | def __len__(self):
66 | return len(self.imgs)
67 |
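
A short usage sketch for the helpers above (an editor's sketch; the folder path is a placeholder):

```python
import torchvision.transforms as transforms
from data.image_folder import make_dataset, ImageFolder

paths = make_dataset('datasets/examples')  # recursive scan for files with known image extensions
folder = ImageFolder('datasets/examples', transform=transforms.ToTensor(), return_paths=True)
img, path = folder[0]  # CHW float tensor plus the originating file path
```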
--------------------------------------------------------------------------------
/data/template_dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset class template
2 |
3 | This module provides a template for users to implement custom datasets.
4 | You can specify '--dataset_mode template' to use this dataset.
5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py
7 | The class name should be <Dataset_mode>Dataset.py
8 | You need to implement the following functions:
9 | -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10 | -- <__init__>: Initialize this dataset class.
11 | -- <__getitem__>: Return a data point and its metadata information.
12 | -- <__len__>: Return the number of images.
13 | """
14 | from data.base_dataset import BaseDataset, get_transform
15 | # from data.image_folder import make_dataset
16 | # from PIL import Image
17 |
18 |
19 | class TemplateDataset(BaseDataset):
20 | """A template dataset class for you to implement custom datasets."""
21 | @staticmethod
22 | def modify_commandline_options(parser, is_train):
23 | """Add new dataset-specific options, and rewrite default values for existing options.
24 |
25 | Parameters:
26 | parser -- original option parser
27 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28 |
29 | Returns:
30 | the modified parser.
31 | """
32 | parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33 | parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
34 | return parser
35 |
36 | def __init__(self, opt):
37 | """Initialize this dataset class.
38 |
39 | Parameters:
40 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41 |
42 | A few things can be done here.
43 | - save the options (have been done in BaseDataset)
44 | - get image paths and meta information of the dataset.
45 | - define the image transformation.
46 | """
47 | # save the option and dataset root
48 | BaseDataset.__init__(self, opt)
49 | # get the image paths of your dataset;
50 | self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51 |         # define the default transform function. You can use <base_dataset.get_transform>; you can also define your custom transform function
52 |         self.transform = get_transform()  # note: get_transform in this repo takes grayscale=False, not opt
53 |
54 | def __getitem__(self, index):
55 | """Return a data point and its metadata information.
56 |
57 | Parameters:
58 | index -- a random integer for data indexing
59 |
60 | Returns:
61 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
62 |
63 | Step 1: get a random image path: e.g., path = self.image_paths[index]
64 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65 |         Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
66 | Step 4: return a data point as a dictionary.
67 | """
68 | path = 'temp' # needs to be a string
69 | data_A = None # needs to be a tensor
70 | data_B = None # needs to be a tensor
71 | return {'data_A': data_A, 'data_B': data_B, 'path': path}
72 |
73 | def __len__(self):
74 | """Return the total number of images."""
75 | return len(self.image_paths)
76 |
--------------------------------------------------------------------------------
/data_preparation.py:
--------------------------------------------------------------------------------
1 | """This script is the data preparation script for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import os
5 | import numpy as np
6 | import argparse
7 | from util.detect_lm68 import detect_68p,load_lm_graph
8 | from util.skin_mask import get_skin_mask
9 | from util.generate_list import check_list, write_list
10 | import warnings
11 | warnings.filterwarnings("ignore")
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--data_root', type=str, default='datasets', help='root directory for training data')
15 | parser.add_argument('--img_folder', nargs="+", required=True, help='folders of training images')
16 | parser.add_argument('--mode', type=str, default='train', help='train or val')
17 | opt = parser.parse_args()
18 |
19 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
20 |
21 | def data_prepare(folder_list,mode):
22 |
23 | lm_sess,input_op,output_op = load_lm_graph('./checkpoints/lm_model/68lm_detector.pb') # load a tensorflow version 68-landmark detector
24 |
25 | for img_folder in folder_list:
26 | detect_68p(img_folder,lm_sess,input_op,output_op) # detect landmarks for images
27 | get_skin_mask(img_folder) # generate skin attention mask for images
28 |
29 | # create files that record path to all training data
30 | msks_list = []
31 | for img_folder in folder_list:
32 | path = os.path.join(img_folder, 'mask')
33 | msks_list += ['/'.join([img_folder, 'mask', i]) for i in sorted(os.listdir(path)) if 'jpg' in i or
34 | 'png' in i or 'jpeg' in i or 'PNG' in i]
35 |
36 | imgs_list = [i.replace('mask/', '') for i in msks_list]
37 | lms_list = [i.replace('mask', 'landmarks') for i in msks_list]
38 | lms_list = ['.'.join(i.split('.')[:-1]) + '.txt' for i in lms_list]
39 |
40 | lms_list_final, imgs_list_final, msks_list_final = check_list(lms_list, imgs_list, msks_list) # check if the path is valid
41 | write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode) # save files
42 |
43 | if __name__ == '__main__':
44 | print('Datasets:',opt.img_folder)
45 | data_prepare([os.path.join(opt.data_root,folder) for folder in opt.img_folder],opt.mode)
46 |
--------------------------------------------------------------------------------
/datasets/examples/000002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000002.jpg
--------------------------------------------------------------------------------
/datasets/examples/000006.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000006.jpg
--------------------------------------------------------------------------------
/datasets/examples/000007.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000007.jpg
--------------------------------------------------------------------------------
/datasets/examples/000031.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000031.jpg
--------------------------------------------------------------------------------
/datasets/examples/000033.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000033.jpg
--------------------------------------------------------------------------------
/datasets/examples/000037.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000037.jpg
--------------------------------------------------------------------------------
/datasets/examples/000050.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000050.jpg
--------------------------------------------------------------------------------
/datasets/examples/000055.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000055.jpg
--------------------------------------------------------------------------------
/datasets/examples/000114.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000114.jpg
--------------------------------------------------------------------------------
/datasets/examples/000125.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000125.jpg
--------------------------------------------------------------------------------
/datasets/examples/000126.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/000126.jpg
--------------------------------------------------------------------------------
/datasets/examples/015259.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/015259.jpg
--------------------------------------------------------------------------------
/datasets/examples/015270.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/015270.jpg
--------------------------------------------------------------------------------
/datasets/examples/015309.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/015309.jpg
--------------------------------------------------------------------------------
/datasets/examples/015310.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/015310.jpg
--------------------------------------------------------------------------------
/datasets/examples/015316.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/015316.jpg
--------------------------------------------------------------------------------
/datasets/examples/015384.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/015384.jpg
--------------------------------------------------------------------------------
/datasets/examples/detections/000002.txt:
--------------------------------------------------------------------------------
1 | 142.84 207.18
2 | 222.02 203.9
3 | 159.24 253.57
4 | 146.59 290.93
5 | 227.52 284.74
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/000006.txt:
--------------------------------------------------------------------------------
1 | 199.93 158.28
2 | 255.34 166.54
3 | 236.08 198.92
4 | 198.83 229.24
5 | 245.23 234.52
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/000007.txt:
--------------------------------------------------------------------------------
1 | 129.36 198.28
2 | 204.47 191.47
3 | 164.42 240.51
4 | 140.74 277.77
5 | 205.4 270.9
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/000031.txt:
--------------------------------------------------------------------------------
1 | 151.23 240.71
2 | 274.05 235.52
3 | 217.37 305.99
4 | 158.03 346.06
5 | 272.17 341.09
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/000033.txt:
--------------------------------------------------------------------------------
1 | 119.09 94.291
2 | 158.31 96.472
3 | 136.76 121.4
4 | 119.33 134.49
5 | 154.66 136.68
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/000037.txt:
--------------------------------------------------------------------------------
1 | 147.37 159.39
2 | 196.94 163.26
3 | 190.68 194.36
4 | 153.72 228.44
5 | 193.94 229.7
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/000050.txt:
--------------------------------------------------------------------------------
1 | 150.4 94.799
2 | 205.14 102.07
3 | 179.54 131.16
4 | 144.45 147.42
5 | 193.39 154.14
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/000055.txt:
--------------------------------------------------------------------------------
1 | 114.26 193.42
2 | 205.8 190.27
3 | 154.15 244.02
4 | 124.69 295.22
5 | 200.88 292.69
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/000114.txt:
--------------------------------------------------------------------------------
1 | 217.52 152.95
2 | 281.48 147.14
3 | 253.02 196.03
4 | 225.79 221.6
5 | 288.25 214.44
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/000125.txt:
--------------------------------------------------------------------------------
1 | 90.928 99.858
2 | 146.87 100.33
3 | 114.22 130.36
4 | 91.579 153.32
5 | 143.63 153.56
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/000126.txt:
--------------------------------------------------------------------------------
1 | 307.56 166.54
2 | 387.06 159.62
3 | 335.52 222.26
4 | 319.3 248.85
5 | 397.71 239.14
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/015259.txt:
--------------------------------------------------------------------------------
1 | 226.38 193.65
2 | 319.12 208.97
3 | 279.99 245.88
4 | 213.79 290.55
5 | 303.03 302.1
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/015270.txt:
--------------------------------------------------------------------------------
1 | 208.4 410.08
2 | 364.41 388.68
3 | 291.6 503.57
4 | 244.82 572.86
5 | 383.18 553.49
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/015309.txt:
--------------------------------------------------------------------------------
1 | 284.61 496.57
2 | 562.77 550.78
3 | 395.85 712.84
4 | 238.92 786.8
5 | 495.61 827.22
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/015310.txt:
--------------------------------------------------------------------------------
1 | 153.95 153.43
2 | 211.13 161.54
3 | 197.28 190.26
4 | 150.82 215.98
5 | 202.32 223.12
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/015316.txt:
--------------------------------------------------------------------------------
1 | 481.31 396.88
2 | 667.75 392.43
3 | 557.81 440.55
4 | 490.44 586.28
5 | 640.56 583.2
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/015384.txt:
--------------------------------------------------------------------------------
1 | 191.79 143.97
2 | 271.86 151.23
3 | 191.25 210.29
4 | 187.82 257.12
5 | 258.82 261.96
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/vd006.txt:
--------------------------------------------------------------------------------
1 | 123.12 117.58
2 | 176.59 122.09
3 | 126.99 144.68
4 | 117.61 183.43
5 | 163.94 186.41
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/vd025.txt:
--------------------------------------------------------------------------------
1 | 180.12 116.13
2 | 263.18 98.397
3 | 230.48 154.72
4 | 201.37 199.01
5 | 279.18 182.56
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/vd026.txt:
--------------------------------------------------------------------------------
1 | 171.27 263.54
2 | 286.58 263.88
3 | 203.35 333.02
4 | 170.6 389.42
5 | 281.73 386.84
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/vd034.txt:
--------------------------------------------------------------------------------
1 | 136.01 167.83
2 | 195.25 151.71
3 | 152.89 191.45
4 | 149.85 235.5
5 | 201.16 222.8
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/vd051.txt:
--------------------------------------------------------------------------------
1 | 161.92 292.04
2 | 254.21 283.81
3 | 212.75 342.06
4 | 170.78 387.28
5 | 254.6 379.82
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/vd070.txt:
--------------------------------------------------------------------------------
1 | 276.53 290.35
2 | 383.38 294.75
3 | 314.48 354.66
4 | 275.08 407.72
5 | 364.94 411.48
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/vd092.txt:
--------------------------------------------------------------------------------
1 | 108.59 149.07
2 | 157.35 143.85
3 | 134.4 173.2
4 | 117.88 200.79
5 | 159.56 196.36
6 |
--------------------------------------------------------------------------------
/datasets/examples/detections/vd102.txt:
--------------------------------------------------------------------------------
1 | 121.62 225.96
2 | 186.73 223.07
3 | 162.99 269.82
4 | 132.12 302.62
5 | 186.42 299.21
6 |
--------------------------------------------------------------------------------
/datasets/examples/vd006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd006.png
--------------------------------------------------------------------------------
/datasets/examples/vd025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd025.png
--------------------------------------------------------------------------------
/datasets/examples/vd026.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd026.png
--------------------------------------------------------------------------------
/datasets/examples/vd034.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd034.png
--------------------------------------------------------------------------------
/datasets/examples/vd051.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd051.png
--------------------------------------------------------------------------------
/datasets/examples/vd070.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd070.png
--------------------------------------------------------------------------------
/datasets/examples/vd092.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd092.png
--------------------------------------------------------------------------------
/datasets/examples/vd102.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/datasets/examples/vd102.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: deep3d_pytorch
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - python=3.6
8 | - pytorch=1.6.0
9 | - torchvision=0.7.0
10 | - numpy=1.18.1
11 | - scikit-image=0.16.2
12 | - scipy=1.4.1
13 | - pillow=6.2.1
14 | - pip=20.0.2
15 | - ipython=7.13.0
16 | - yaml=0.1.7
17 | - pip:
18 | - matplotlib==2.2.5
19 | - opencv-python==3.4.9.33
20 | - tensorboard==1.15.0
21 | - tensorflow==1.15.0
22 | - kornia==0.5.5
23 | - dominate==2.6.0
24 | - trimesh==3.9.20
--------------------------------------------------------------------------------
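With this file at the repository root, the environment can typically be created and activated with "conda env create -f environment.yml" followed by "conda activate deep3d_pytorch". Note that the stack is pinned to Python 3.6 / PyTorch 1.6.0, so these exact versions may not resolve on very recent conda installations.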
/images/20230425_compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/images/20230425_compare.png
--------------------------------------------------------------------------------
/images/compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/images/compare.png
--------------------------------------------------------------------------------
/images/example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/images/example.gif
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 | -- <set_input>: unpack data from dataset and apply preprocessing.
7 | -- <forward>: produce intermediate results.
8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 | from models.base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 | """Import the module "models/[model_name]_model.py".
27 |
28 | In the file, the class called [ModelName]Model() will
29 | be instantiated. It has to be a subclass of BaseModel,
30 | and it is case-insensitive.
31 | """
32 | model_filename = "models." + model_name + "_model"
33 | modellib = importlib.import_module(model_filename)
34 | model = None
35 | target_model_name = model_name.replace('_', '') + 'model'
36 | for name, cls in modellib.__dict__.items():
37 | if name.lower() == target_model_name.lower() \
38 | and issubclass(cls, BaseModel):
39 | model = cls
40 |
41 | if model is None:
42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 | exit(0)
44 |
45 | return model
46 |
47 |
48 | def get_option_setter(model_name):
49 | """Return the static method of the model class."""
50 | model_class = find_model_using_name(model_name)
51 | return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 | """Create a model given the option.
56 |
57 | This function wraps the model class.
58 | This is the main interface between this package and 'train.py'/'test.py'
59 |
60 | Example:
61 | >>> from models import create_model
62 | >>> model = create_model(opt)
63 | """
64 | model = find_model_using_name(opt.model)
65 | instance = model(opt)
66 | print("model [%s] was created" % type(instance).__name__)
67 | return instance
68 |
--------------------------------------------------------------------------------
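To make the 'dummy' convention from the package docstring concrete, a minimal hypothetical models/dummy_model.py could look like the sketch below. The class name DummyModel is what find_model_using_name resolves from '--model dummy' (underscores stripped, case-insensitive); the network, option, and loss inside are placeholders, not part of this repository:

    # models/dummy_model.py -- hypothetical illustration only
    import torch
    from .base_model import BaseModel

    class DummyModel(BaseModel):
        """Minimal subclass satisfying the five-function interface expected by create_model."""

        @staticmethod
        def modify_commandline_options(parser, is_train=True):
            parser.add_argument('--dummy_weight', type=float, default=1.0)  # model-specific option
            return parser

        def __init__(self, opt):
            BaseModel.__init__(self, opt)  # required first call
            self.loss_names = ['dummy']    # losses to plot/save (read back as self.loss_dummy)
            self.model_names = ['net']     # networks to save/load
            self.visual_names = []         # images to display/save
            self.net = torch.nn.Linear(1, 1)
            self.optimizers = [torch.optim.Adam(self.net.parameters(), lr=1e-4)]

        def set_input(self, input):
            self.x = input['x'].to(self.device)

        def forward(self):
            self.pred = self.net(self.x)

        def optimize_parameters(self):
            self.forward()
            self.loss_dummy = self.opt.dummy_weight * self.pred.pow(2).mean()
            self.optimizers[0].zero_grad()
            self.loss_dummy.backward()
            self.optimizers[0].step()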
/models/base_model.py:
--------------------------------------------------------------------------------
1 | """This script defines the base network model for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import os
5 | import numpy as np
6 | import torch
7 | from collections import OrderedDict
8 | from abc import ABC, abstractmethod
9 | from . import networks
10 |
11 |
12 | class BaseModel(ABC):
13 | """This class is an abstract base class (ABC) for models.
14 | To create a subclass, you need to implement the following five functions:
15 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
16 | -- <set_input>: unpack data from dataset and apply preprocessing.
17 | -- <forward>: produce intermediate results.
18 | -- <optimize_parameters>: calculate losses, gradients, and update network weights.
19 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
20 | """
21 |
22 | def __init__(self, opt):
23 | """Initialize the BaseModel class.
24 |
25 | Parameters:
26 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
27 |
28 | When creating your custom class, you need to implement your own initialization.
29 | In this function, you should first call <BaseModel.__init__(self, opt)>
30 | Then, you need to define four lists:
31 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
32 | -- self.model_names (str list): define networks used in our training.
33 | -- self.visual_names (str list): specify the images that you want to display and save.
34 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
35 | """
36 | self.opt = opt
37 | self.isTrain = opt.isTrain
38 | self.device = torch.device('cpu')
39 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
40 | self.loss_names = []
41 | self.model_names = []
42 | self.visual_names = []
43 | self.parallel_names = []
44 | self.optimizers = []
45 | self.image_paths = []
46 | self.metric = 0 # used for learning rate policy 'plateau'
47 |
48 | @staticmethod
49 | def dict_grad_hook_factory(add_func=lambda x: x):
50 | saved_dict = dict()
51 |
52 | def hook_gen(name):
53 | def grad_hook(grad):
54 | saved_vals = add_func(grad)
55 | saved_dict[name] = saved_vals
56 | return grad_hook
57 | return hook_gen, saved_dict
58 |
59 | @staticmethod
60 | def modify_commandline_options(parser, is_train):
61 | """Add new model-specific options, and rewrite default values for existing options.
62 |
63 | Parameters:
64 | parser -- original option parser
65 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
66 |
67 | Returns:
68 | the modified parser.
69 | """
70 | return parser
71 |
72 | @abstractmethod
73 | def set_input(self, input):
74 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
75 |
76 | Parameters:
77 | input (dict): includes the data itself and its metadata information.
78 | """
79 | pass
80 |
81 | @abstractmethod
82 | def forward(self):
83 | """Run forward pass; called by both functions and ."""
84 | pass
85 |
86 | @abstractmethod
87 | def optimize_parameters(self):
88 | """Calculate losses, gradients, and update network weights; called in every training iteration"""
89 | pass
90 |
91 | def setup(self, opt):
92 | """Load and print networks; create schedulers
93 |
94 | Parameters:
95 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
96 | """
97 | if self.isTrain:
98 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
99 |
100 | if not self.isTrain or opt.continue_train:
101 | load_suffix = opt.epoch
102 | self.load_networks(load_suffix)
103 |
104 |
105 | # self.print_networks(opt.verbose)
106 |
107 | def parallelize(self, convert_sync_batchnorm=True):
108 | if not self.opt.use_ddp:
109 | for name in self.parallel_names:
110 | if isinstance(name, str):
111 | module = getattr(self, name)
112 | setattr(self, name, module.to(self.device))
113 | else:
114 | for name in self.model_names:
115 | if isinstance(name, str):
116 | module = getattr(self, name)
117 | if convert_sync_batchnorm:
118 | module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
119 | setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device),
120 | device_ids=[self.device.index],
121 | find_unused_parameters=True, broadcast_buffers=True))
122 |
123 | # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
124 | for name in self.parallel_names:
125 | if isinstance(name, str) and name not in self.model_names:
126 | module = getattr(self, name)
127 | setattr(self, name, module.to(self.device))
128 |
129 | # put state_dict of optimizer to gpu device
130 | if self.opt.phase != 'test':
131 | if self.opt.continue_train:
132 | for optim in self.optimizers:
133 | for state in optim.state.values():
134 | for k, v in state.items():
135 | if isinstance(v, torch.Tensor):
136 | state[k] = v.to(self.device)
137 |
138 | def data_dependent_initialize(self, data):
139 | pass
140 |
141 | def train(self):
142 | """Make models train mode"""
143 | for name in self.model_names:
144 | if isinstance(name, str):
145 | net = getattr(self, name)
146 | net.train()
147 |
148 | def eval(self):
149 | """Make models eval mode"""
150 | for name in self.model_names:
151 | if isinstance(name, str):
152 | net = getattr(self, name)
153 | net.eval()
154 |
155 | def test(self):
156 | """Forward function used in test time.
157 |
158 | This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
159 | It also calls <compute_visuals> to produce additional visualization results
160 | """
161 | with torch.no_grad():
162 | self.forward()
163 | self.compute_visuals()
164 |
165 | def compute_visuals(self):
166 | """Calculate additional output images for visdom and HTML visualization"""
167 | pass
168 |
169 | def get_image_paths(self, name='A'):
170 | """ Return image paths that are used to load current data"""
171 | return self.image_paths if name =='A' else self.image_paths_B
172 |
173 | def update_learning_rate(self):
174 | """Update learning rates for all the networks; called at the end of every epoch"""
175 | for scheduler in self.schedulers:
176 | if self.opt.lr_policy == 'plateau':
177 | scheduler.step(self.metric)
178 | else:
179 | scheduler.step()
180 |
181 | lr = self.optimizers[0].param_groups[0]['lr']
182 | print('learning rate = %.7f' % lr)
183 |
184 | def get_current_visuals(self):
185 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
186 | visual_ret = OrderedDict()
187 | for name in self.visual_names:
188 | if isinstance(name, str):
189 | visual_ret[name] = getattr(self, name)[:, :3, ...]
190 | return visual_ret
191 |
192 | def get_current_losses(self):
193 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
194 | errors_ret = OrderedDict()
195 | for name in self.loss_names:
196 | if isinstance(name, str):
197 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
198 | return errors_ret
199 |
200 | def save_networks(self, epoch):
201 | """Save all the networks to the disk.
202 |
203 | Parameters:
204 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
205 | """
206 | if not os.path.isdir(self.save_dir):
207 | os.makedirs(self.save_dir)
208 |
209 | save_filename = 'epoch_%s.pth' % (epoch)
210 | save_path = os.path.join(self.save_dir, save_filename)
211 |
212 | save_dict = {}
213 | for name in self.model_names:
214 | if isinstance(name, str):
215 | net = getattr(self, name)
216 | if isinstance(net, torch.nn.DataParallel) or isinstance(net,
217 | torch.nn.parallel.DistributedDataParallel):
218 | net = net.module
219 | save_dict[name] = net.state_dict()
220 |
221 |
222 | for i, optim in enumerate(self.optimizers):
223 | save_dict['opt_%02d'%i] = optim.state_dict()
224 |
225 | for i, sched in enumerate(self.schedulers):
226 | save_dict['sched_%02d'%i] = sched.state_dict()
227 |
228 | torch.save(save_dict, save_path)
229 |
230 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
231 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
232 | key = keys[i]
233 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
234 | if module.__class__.__name__.startswith('InstanceNorm') and \
235 | (key == 'running_mean' or key == 'running_var'):
236 | if getattr(module, key) is None:
237 | state_dict.pop('.'.join(keys))
238 | if module.__class__.__name__.startswith('InstanceNorm') and \
239 | (key == 'num_batches_tracked'):
240 | state_dict.pop('.'.join(keys))
241 | else:
242 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
243 |
244 | def load_networks(self, epoch):
245 | """Load all the networks from the disk.
246 |
247 | Parameters:
248 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
249 | """
250 | if self.opt.isTrain and self.opt.pretrained_name is not None:
251 | load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
252 | else:
253 | load_dir = self.save_dir
254 | load_filename = 'epoch_%s.pth' % (epoch)
255 | load_path = os.path.join(load_dir, load_filename)
256 | state_dict = torch.load(load_path, map_location=self.device)
257 | print('loading the model from %s' % load_path)
258 |
259 | for name in self.model_names:
260 | if isinstance(name, str):
261 | net = getattr(self, name)
262 | if isinstance(net, torch.nn.DataParallel):
263 | net = net.module
264 | net.load_state_dict(state_dict[name])
265 |
266 | if self.opt.phase != 'test':
267 | if self.opt.continue_train:
268 | print('loading the optim from %s' % load_path)
269 | for i, optim in enumerate(self.optimizers):
270 | optim.load_state_dict(state_dict['opt_%02d'%i])
271 |
272 | try:
273 | print('loading the sched from %s' % load_path)
274 | for i, sched in enumerate(self.schedulers):
275 | sched.load_state_dict(state_dict['sched_%02d'%i])
276 | except:
277 | print('Failed to load schedulers, set schedulers according to epoch count manually')
278 | for i, sched in enumerate(self.schedulers):
279 | sched.last_epoch = self.opt.epoch_count - 1
280 |
281 |
282 |
283 |
284 | def print_networks(self, verbose):
285 | """Print the total number of parameters in the network and (if verbose) network architecture
286 |
287 | Parameters:
288 | verbose (bool) -- if verbose: print the network architecture
289 | """
290 | print('---------- Networks initialized -------------')
291 | for name in self.model_names:
292 | if isinstance(name, str):
293 | net = getattr(self, name)
294 | num_params = 0
295 | for param in net.parameters():
296 | num_params += param.numel()
297 | if verbose:
298 | print(net)
299 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
300 | print('-----------------------------------------------')
301 |
302 | def set_requires_grad(self, nets, requires_grad=False):
303 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
304 | Parameters:
305 | nets (network list) -- a list of networks
306 | requires_grad (bool) -- whether the networks require gradients or not
307 | """
308 | if not isinstance(nets, list):
309 | nets = [nets]
310 | for net in nets:
311 | if net is not None:
312 | for param in net.parameters():
313 | param.requires_grad = requires_grad
314 |
315 | def generate_visuals_for_evaluation(self, data, mode):
316 | return {}
317 |
--------------------------------------------------------------------------------
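For orientation, save_networks above writes everything into one epoch_<E>.pth file: a state_dict per name in self.model_names plus 'opt_%02d' and 'sched_%02d' entries for the optimizers and schedulers, which is exactly what load_networks unpacks. A small inspection sketch (the checkpoint path is a placeholder):

    import torch

    # hypothetical path following the 'epoch_%s.pth' convention used above
    ckpt = torch.load('checkpoints/face_recon/epoch_20.pth', map_location='cpu')

    for key, value in ckpt.items():
        if key.startswith(('opt_', 'sched_')):
            print(key, '-> optimizer/scheduler state')
        else:
            print(key, '->', len(value), 'tensors')  # a network listed in model_names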
/models/bfm.py:
--------------------------------------------------------------------------------
1 | """This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from scipy.io import loadmat
8 | from util.load_mats import transferBFM09
9 | import os
10 |
11 | def perspective_projection(focal, center):
12 | # return p.T (N, 3) @ (3, 3)
13 | return np.array([
14 | focal, 0, center,
15 | 0, focal, center,
16 | 0, 0, 1
17 | ]).reshape([3, 3]).astype(np.float32).transpose()
18 |
19 | class SH:
20 | def __init__(self):
21 | self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
22 | self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]
23 |
24 |
25 |
26 | class ParametricFaceModel:
27 | def __init__(self,
28 | bfm_folder='./BFM',
29 | recenter=True,
30 | camera_distance=10.,
31 | init_lit=np.array([
32 | 0.8, 0, 0, 0, 0, 0, 0, 0, 0
33 | ]),
34 | focal=1015.,
35 | center=112.,
36 | is_train=True,
37 | default_name='BFM_model_front.mat'):
38 |
39 | if not os.path.isfile(os.path.join(bfm_folder, default_name)):
40 | transferBFM09(bfm_folder)
41 | model = loadmat(os.path.join(bfm_folder, default_name))
42 | # mean face shape. [3*N,1]
43 | self.mean_shape = model['meanshape'].astype(np.float32)
44 | # identity basis. [3*N,80]
45 | self.id_base = model['idBase'].astype(np.float32)
46 | # expression basis. [3*N,64]
47 | self.exp_base = model['exBase'].astype(np.float32)
48 | # mean face texture. [3*N,1] (0-255)
49 | self.mean_tex = model['meantex'].astype(np.float32)
50 | # texture basis. [3*N,80]
51 | self.tex_base = model['texBase'].astype(np.float32)
52 | # indices of the faces that each vertex belongs to. starts from 0. [N,8]
53 | self.point_buf = model['point_buf'].astype(np.int64) - 1
54 | # vertex indices for each face. starts from 0. [F,3]
55 | self.face_buf = model['tri'].astype(np.int64) - 1
56 | # vertex indices for 68 landmarks. starts from 0. [68,1]
57 | self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1
58 |
59 | if is_train:
60 | # vertex indices for small face region to compute photometric error. starts from 0.
61 | self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1
62 | # vertex indices for each face from small face region. starts from 0. [f,3]
63 | self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
64 | # vertex indices for pre-defined skin region to compute reflectance loss
65 | self.skin_mask = np.squeeze(model['skinmask'])
66 |
67 | if recenter:
68 | mean_shape = self.mean_shape.reshape([-1, 3])
69 | mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True)
70 | self.mean_shape = mean_shape.reshape([-1, 1])
71 |
72 | self.persc_proj = perspective_projection(focal, center)
73 | self.device = 'cpu'
74 | self.camera_distance = camera_distance
75 | self.SH = SH()
76 | self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
77 |
78 |
79 | def to(self, device):
80 | self.device = device
81 | for key, value in self.__dict__.items():
82 | if type(value).__module__ == np.__name__:
83 | setattr(self, key, torch.tensor(value).to(device))
84 |
85 |
86 | def compute_shape(self, id_coeff, exp_coeff):
87 | """
88 | Return:
89 | face_shape -- torch.tensor, size (B, N, 3)
90 |
91 | Parameters:
92 | id_coeff -- torch.tensor, size (B, 80), identity coeffs
93 | exp_coeff -- torch.tensor, size (B, 64), expression coeffs
94 | """
95 | batch_size = id_coeff.shape[0]
96 | id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff)
97 | exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)
98 | face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])
99 | return face_shape.reshape([batch_size, -1, 3])
100 |
101 |
102 | def compute_texture(self, tex_coeff, normalize=True):
103 | """
104 | Return:
105 | face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
106 |
107 | Parameters:
108 | tex_coeff -- torch.tensor, size (B, 80)
109 | """
110 | batch_size = tex_coeff.shape[0]
111 | face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex
112 | if normalize:
113 | face_texture = face_texture / 255.
114 | return face_texture.reshape([batch_size, -1, 3])
115 |
116 |
117 | def compute_norm(self, face_shape):
118 | """
119 | Return:
120 | vertex_norm -- torch.tensor, size (B, N, 3)
121 |
122 | Parameters:
123 | face_shape -- torch.tensor, size (B, N, 3)
124 | """
125 |
126 | v1 = face_shape[:, self.face_buf[:, 0]]
127 | v2 = face_shape[:, self.face_buf[:, 1]]
128 | v3 = face_shape[:, self.face_buf[:, 2]]
129 | e1 = v1 - v2
130 | e2 = v2 - v3
131 | face_norm = torch.cross(e1, e2, dim=-1)
132 | face_norm = F.normalize(face_norm, dim=-1, p=2)
133 | face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1)
134 |
135 | vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
136 | vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
137 | return vertex_norm
138 |
139 |
140 | def compute_color(self, face_texture, face_norm, gamma):
141 | """
142 | Return:
143 | face_color -- torch.tensor, size (B, N, 3), range (0, 1.)
144 |
145 | Parameters:
146 | face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
147 | face_norm -- torch.tensor, size (B, N, 3), rotated face normal
148 | gamma -- torch.tensor, size (B, 27), SH coeffs
149 | """
150 | batch_size = gamma.shape[0]
151 | v_num = face_texture.shape[1]
152 | a, c = self.SH.a, self.SH.c
153 | gamma = gamma.reshape([batch_size, 3, 9])
154 | gamma = gamma + self.init_lit
155 | gamma = gamma.permute(0, 2, 1)
156 | Y = torch.cat([
157 | a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device),
158 | -a[1] * c[1] * face_norm[..., 1:2],
159 | a[1] * c[1] * face_norm[..., 2:],
160 | -a[1] * c[1] * face_norm[..., :1],
161 | a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],
162 | -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],
163 | 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1),
164 | -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],
165 | 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2)
166 | ], dim=-1)
167 | r = Y @ gamma[..., :1]
168 | g = Y @ gamma[..., 1:2]
169 | b = Y @ gamma[..., 2:]
170 | face_color = torch.cat([r, g, b], dim=-1) * face_texture
171 | return face_color
172 |
173 |
174 | def compute_rotation(self, angles):
175 | """
176 | Return:
177 | rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
178 |
179 | Parameters:
180 | angles -- torch.tensor, size (B, 3), radian
181 | """
182 |
183 | batch_size = angles.shape[0]
184 | ones = torch.ones([batch_size, 1]).to(self.device)
185 | zeros = torch.zeros([batch_size, 1]).to(self.device)
186 | x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],
187 |
188 | rot_x = torch.cat([
189 | ones, zeros, zeros,
190 | zeros, torch.cos(x), -torch.sin(x),
191 | zeros, torch.sin(x), torch.cos(x)
192 | ], dim=1).reshape([batch_size, 3, 3])
193 |
194 | rot_y = torch.cat([
195 | torch.cos(y), zeros, torch.sin(y),
196 | zeros, ones, zeros,
197 | -torch.sin(y), zeros, torch.cos(y)
198 | ], dim=1).reshape([batch_size, 3, 3])
199 |
200 | rot_z = torch.cat([
201 | torch.cos(z), -torch.sin(z), zeros,
202 | torch.sin(z), torch.cos(z), zeros,
203 | zeros, zeros, ones
204 | ], dim=1).reshape([batch_size, 3, 3])
205 |
206 | rot = rot_z @ rot_y @ rot_x
207 | return rot.permute(0, 2, 1)
208 |
209 |
210 | def to_camera(self, face_shape):
211 | face_shape[..., -1] = self.camera_distance - face_shape[..., -1]
212 | return face_shape
213 |
214 | def to_image(self, face_shape):
215 | """
216 | Return:
217 | face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
218 |
219 | Parameters:
220 | face_shape -- torch.tensor, size (B, N, 3)
221 | """
222 | # to image_plane
223 | face_proj = face_shape @ self.persc_proj
224 | face_proj = face_proj[..., :2] / face_proj[..., 2:]
225 |
226 | return face_proj
227 |
228 |
229 | def transform(self, face_shape, rot, trans):
230 | """
231 | Return:
232 | face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans
233 |
234 | Parameters:
235 | face_shape -- torch.tensor, size (B, N, 3)
236 | rot -- torch.tensor, size (B, 3, 3)
237 | trans -- torch.tensor, size (B, 3)
238 | """
239 | return face_shape @ rot + trans.unsqueeze(1)
240 |
241 |
242 | def get_landmarks(self, face_proj):
243 | """
244 | Return:
245 | face_lms -- torch.tensor, size (B, 68, 2)
246 |
247 | Parameters:
248 | face_proj -- torch.tensor, size (B, N, 2)
249 | """
250 | return face_proj[:, self.keypoints]
251 |
252 | def split_coeff(self, coeffs):
253 | """
254 | Return:
255 | coeffs_dict -- a dict of torch.tensors
256 |
257 | Parameters:
258 | coeffs -- torch.tensor, size (B, 257)
259 | """
260 | id_coeffs = coeffs[:, :80]
261 | exp_coeffs = coeffs[:, 80: 144]
262 | tex_coeffs = coeffs[:, 144: 224]
263 | angles = coeffs[:, 224: 227]
264 | gammas = coeffs[:, 227: 254]
265 | translations = coeffs[:, 254:]
266 | return {
267 | 'id': id_coeffs,
268 | 'exp': exp_coeffs,
269 | 'tex': tex_coeffs,
270 | 'angle': angles,
271 | 'gamma': gammas,
272 | 'trans': translations
273 | }
274 | def compute_for_render(self, coeffs):
275 | """
276 | Return:
277 | face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
278 | face_texture, face_color -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
279 | landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
280 | Parameters:
281 | coeffs -- torch.tensor, size (B, 257)
282 | """
283 | coef_dict = self.split_coeff(coeffs)
284 | face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
285 | rotation = self.compute_rotation(coef_dict['angle'])
286 |
287 |
288 | face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
289 | face_vertex = self.to_camera(face_shape_transformed)
290 |
291 | face_proj = self.to_image(face_vertex)
292 | landmark = self.get_landmarks(face_proj)
293 |
294 | face_texture = self.compute_texture(coef_dict['tex'])
295 | face_norm = self.compute_norm(face_shape)
296 | face_norm_roted = face_norm @ rotation
297 | face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
298 |
299 | return face_vertex, face_texture, face_color, landmark
300 |
--------------------------------------------------------------------------------
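A quick aside on the layout consumed by split_coeff above: the 257 coefficients decompose as 80 identity + 64 expression + 80 texture + 3 rotation angles + 27 SH lighting (3 color channels x 9 bands) + 3 translation. A self-contained check of those slice boundaries:

    import torch

    coeffs = torch.randn(2, 257)  # a batch of two coefficient vectors
    dims = {'id': 80, 'exp': 64, 'tex': 80, 'angle': 3, 'gamma': 27, 'trans': 3}
    assert sum(dims.values()) == 257

    offset = 0
    for name, d in dims.items():
        print(name, tuple(coeffs[:, offset:offset + d].shape))  # mirrors the slices in split_coeff
        offset += d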
/models/facerecon_model.py:
--------------------------------------------------------------------------------
1 | """This script defines the face reconstruction model for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import numpy as np
5 | import torch
6 | from .base_model import BaseModel
7 | from . import networks
8 | from .bfm import ParametricFaceModel
9 | from .losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss
10 | from util import util
11 | from util.nvdiffrast import MeshRenderer
12 | from util.preprocess import estimate_norm_torch
13 |
14 | import trimesh
15 | from scipy.io import savemat
16 |
17 | class FaceReconModel(BaseModel):
18 |
19 | @staticmethod
20 | def modify_commandline_options(parser, is_train=True):
21 | """ Configures options specific for CUT model
22 | """
23 | # net structure and parameters
24 | parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure')
25 | parser.add_argument('--init_path', type=str, default='checkpoints/init_model/resnet50-0676ba61.pth')
26 | parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc')
27 | parser.add_argument('--bfm_folder', type=str, default='BFM')
28 | parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
29 |
30 | # renderer parameters
31 | parser.add_argument('--focal', type=float, default=1015.)
32 | parser.add_argument('--center', type=float, default=112.)
33 | parser.add_argument('--camera_d', type=float, default=10.)
34 | parser.add_argument('--z_near', type=float, default=5.)
35 | parser.add_argument('--z_far', type=float, default=15.)
36 | parser.add_argument('--use_opengl', type=util.str2bool, nargs='?', const=True, default=True, help='use opengl context or not')
37 |
38 | if is_train:
39 | # training parameters
40 | parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r43', 'r50'], help='face recog network structure')
41 | parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth')
42 | parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss')
43 | parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face')
44 |
45 |
46 | # augmentation parameters
47 | parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels')
48 | parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor')
49 | parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree')
50 |
51 | # loss weights
52 | parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss')
53 | parser.add_argument('--w_color', type=float, default=1.92, help='weight for color loss')
54 | parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss')
55 | parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss')
56 | parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss')
57 | parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss')
58 | parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss')
59 | parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss')
60 | parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss')
61 |
62 |
63 |
64 | opt, _ = parser.parse_known_args()
65 | parser.set_defaults(
66 | focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15.
67 | )
68 | if is_train:
69 | parser.set_defaults(
70 | use_crop_face=True, use_predef_M=False
71 | )
72 | return parser
73 |
74 | def __init__(self, opt):
75 | """Initialize this model class.
76 |
77 | Parameters:
78 | opt -- training/test options
79 |
80 | A few things can be done here.
81 | - (required) call the initialization function of BaseModel
82 | - define loss function, visualization images, model names, and optimizers
83 | """
84 | BaseModel.__init__(self, opt) # call the initialization method of BaseModel
85 |
86 | self.visual_names = ['output_vis']
87 | self.model_names = ['net_recon']
88 | self.parallel_names = self.model_names + ['renderer']
89 |
90 | self.net_recon = networks.define_net_recon(
91 | net_recon=opt.net_recon, use_last_fc=opt.use_last_fc, init_path=opt.init_path
92 | )
93 |
94 | self.facemodel = ParametricFaceModel(
95 | bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center,
96 | is_train=self.isTrain, default_name=opt.bfm_model
97 | )
98 |
99 | fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
100 | self.renderer = MeshRenderer(
101 | rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center), use_opengl=opt.use_opengl
102 | )
103 |
104 | if self.isTrain:
105 | self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc']
106 |
107 | self.net_recog = networks.define_net_recog(
108 | net_recog=opt.net_recog, pretrained_path=opt.net_recog_path
109 | )
110 | # loss func name: (compute_%s_loss) % loss_name
111 | self.compute_feat_loss = perceptual_loss
112 | self.compute_color_loss = photo_loss
113 | self.compute_lm_loss = landmark_loss
114 | self.compute_reg_loss = reg_loss
115 | self.compute_reflc_loss = reflectance_loss
116 |
117 | self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr)
118 | self.optimizers = [self.optimizer]
119 | self.parallel_names += ['net_recog']
120 | # Our program will automatically call <model.setup(opt)> to define schedulers, load networks, and print networks
121 |
122 | def set_input(self, input):
123 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
124 |
125 | Parameters:
126 | input: a dictionary that contains the data itself and its metadata information.
127 | """
128 | self.input_img = input['imgs'].to(self.device)
129 | self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None
130 | self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None
131 | self.trans_m = input['M'].to(self.device) if 'M' in input else None
132 | self.image_paths = input['im_paths'] if 'im_paths' in input else None
133 |
134 | def forward(self):
135 | output_coeff = self.net_recon(self.input_img)
136 | self.facemodel.to(self.device)
137 | self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \
138 | self.facemodel.compute_for_render(output_coeff)
139 | self.pred_mask, _, self.pred_face = self.renderer(
140 | self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color)
141 |
142 | self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff)
143 |
144 |
145 | def compute_losses(self):
146 | """Calculate losses, gradients, and update network weights; called in every training iteration"""
147 |
148 | assert self.net_recog.training == False
149 | trans_m = self.trans_m
150 | if not self.opt.use_predef_M:
151 | trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2])
152 |
153 | pred_feat = self.net_recog(self.pred_face, trans_m)
154 | gt_feat = self.net_recog(self.input_img, self.trans_m)
155 | self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat)
156 |
157 | face_mask = self.pred_mask
158 | if self.opt.use_crop_face:
159 | face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf)
160 |
161 | face_mask = face_mask.detach()
162 | self.loss_color = self.opt.w_color * self.compute_color_loss(
163 | self.pred_face, self.input_img, self.atten_mask * face_mask)
164 |
165 | loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt)
166 | self.loss_reg = self.opt.w_reg * loss_reg
167 | self.loss_gamma = self.opt.w_gamma * loss_gamma
168 |
169 | self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm)
170 |
171 | self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask)
172 |
173 | self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \
174 | + self.loss_lm + self.loss_reflc
175 |
176 |
177 | def optimize_parameters(self, isTrain=True):
178 | """Update network weights; it will be called in every training iteration."""
179 | self.forward()
180 | self.compute_losses()
181 | if isTrain:
182 | self.optimizer.zero_grad()
183 | self.loss_all.backward()
184 | self.optimizer.step()
185 |
186 | def compute_visuals(self):
187 | with torch.no_grad():
188 | input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy()
189 | output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img
190 | output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy()
191 |
192 | if self.gt_lm is not None:
193 | gt_lm_numpy = self.gt_lm.cpu().numpy()
194 | pred_lm_numpy = self.pred_lm.detach().cpu().numpy()
195 | output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b')
196 | output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r')
197 |
198 | output_vis_numpy = np.concatenate((input_img_numpy,
199 | output_vis_numpy_raw, output_vis_numpy), axis=-2)
200 | else:
201 | output_vis_numpy = np.concatenate((input_img_numpy,
202 | output_vis_numpy_raw), axis=-2)
203 |
204 | self.output_vis = torch.tensor(
205 | output_vis_numpy / 255., dtype=torch.float32
206 | ).permute(0, 3, 1, 2).to(self.device)
207 |
208 | def save_mesh(self, name):
209 |
210 | recon_shape = self.pred_vertex # get reconstructed shape
211 | recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space
212 | recon_shape = recon_shape.cpu().numpy()[0]
213 | recon_color = self.pred_color
214 | recon_color = recon_color.cpu().numpy()[0]
215 | tri = self.facemodel.face_buf.cpu().numpy()
216 | mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8), process=False)
217 | mesh.export(name)
218 |
219 | def save_coeff(self,name):
220 |
221 | pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict}
222 | pred_lm = self.pred_lm.cpu().numpy()
223 | pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate
224 | pred_coeffs['lm68'] = pred_lm
225 | savemat(name,pred_coeffs)
226 |
227 |
228 |
229 |
--------------------------------------------------------------------------------
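One numeric detail from __init__ above worth spelling out: the renderer's field of view is derived from the pinhole intrinsics as fov = 2 * arctan(center / focal) * 180 / pi, so the defaults center=112 and focal=1015 give roughly a 12.6-degree fov rasterized at 2 * center = 224 pixels:

    import numpy as np

    focal, center = 1015., 112.  # the default intrinsics above
    fov = 2 * np.arctan(center / focal) * 180 / np.pi
    print('fov = %.2f deg, raster = %d px' % (fov, int(2 * center)))  # fov = 12.59 deg, raster = 224 px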
/models/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from kornia.geometry import warp_affine
5 | import torch.nn.functional as F
6 |
7 | def resize_n_crop(image, M, dsize=112):
8 | # image: (b, c, h, w)
9 | # M : (b, 2, 3)
10 | return warp_affine(image, M, dsize=(dsize, dsize))
11 |
12 | ### perceptual level loss
13 | class PerceptualLoss(nn.Module):
14 | def __init__(self, recog_net, input_size=112):
15 | super(PerceptualLoss, self).__init__()
16 | self.recog_net = recog_net
17 | self.preprocess = lambda x: 2 * x - 1
18 | self.input_size=input_size
19 | def forward(self, imageA, imageB, M):
20 | """
21 | 1 - cosine distance
22 | Parameters:
23 | imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order
24 | imageB --same as imageA
25 | """
26 |
27 | imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
28 | imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))
29 |
30 | # freeze bn
31 | self.recog_net.eval()
32 |
33 | id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
34 | id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)
35 | cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
36 | # assert torch.sum((cosine_d > 1).float()) == 0
37 | return torch.sum(1 - cosine_d) / cosine_d.shape[0]
38 |
39 | def perceptual_loss(id_featureA, id_featureB):
40 | cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
41 | # assert torch.sum((cosine_d > 1).float()) == 0
42 | return torch.sum(1 - cosine_d) / cosine_d.shape[0]
43 |
44 | ### image level loss
45 | def photo_loss(imageA, imageB, mask, eps=1e-6):
46 | """
47 | l2 norm (with sqrt; use eps to ensure backward stability, otherwise NaN may occur)
48 | Parameters:
49 | imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
50 | imageB --same as imageA
51 | """
52 | loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask
53 | loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device))
54 | return loss
55 |
56 | def landmark_loss(predict_lm, gt_lm, weight=None):
57 | """
58 | weighted mse loss
59 | Parameters:
60 | predict_lm --torch.tensor (B, 68, 2)
61 | gt_lm --torch.tensor (B, 68, 2)
62 | weight --numpy.array (1, 68)
63 | """
64 | if weight is None:
65 | weight = np.ones([68])
66 | weight[28:31] = 20
67 | weight[-8:] = 20
68 | weight = np.expand_dims(weight, 0)
69 | weight = torch.tensor(weight).to(predict_lm.device)
70 | loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight
71 | loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
72 | return loss
73 |
74 |
75 | ### regulization
76 | def reg_loss(coeffs_dict, opt=None):
77 | """
78 | l2 norm without the sqrt, from yu's implementation (mse)
79 | tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
80 | Parameters:
81 | coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans
82 |
83 | """
84 | # coefficient regularization to ensure plausible 3d faces
85 | if opt:
86 | w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex
87 | else:
88 | w_id, w_exp, w_tex = 1, 1, 1
89 | creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \
90 | w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \
91 | w_tex * torch.sum(coeffs_dict['tex'] ** 2)
92 | creg_loss = creg_loss / coeffs_dict['id'].shape[0]
93 |
94 | # gamma regularization to ensure a nearly-monochromatic light
95 | gamma = coeffs_dict['gamma'].reshape([-1, 3, 9])
96 | gamma_mean = torch.mean(gamma, dim=1, keepdims=True)
97 | gamma_loss = torch.mean((gamma - gamma_mean) ** 2)
98 |
99 | return creg_loss, gamma_loss
100 |
101 | def reflectance_loss(texture, mask):
102 | """
103 | minimize texture variance (mse), albedo regularization to ensure a uniform skin albedo
104 | Parameters:
105 | texture --torch.tensor, (B, N, 3)
106 | mask --torch.tensor, (N), 1 or 0
107 |
108 | """
109 | mask = mask.reshape([1, mask.shape[0], 1])
110 | texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask)
111 | loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask))
112 | return loss
113 |
114 |
--------------------------------------------------------------------------------
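As a sanity check on the masked photometric loss above: for identical images the per-pixel term collapses to sqrt(eps) = 1e-3, and an all-zero mask cannot divide by zero because the denominator is clamped to at least 1. A toy verification, restated standalone (using PyTorch's keepdim spelling):

    import torch

    def photo_loss(imageA, imageB, mask, eps=1e-6):
        # same computation as in models/losses.py above
        diff = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdim=True)) * mask
        return torch.sum(diff) / torch.max(torch.sum(mask), torch.tensor(1.0))

    img = torch.rand(1, 3, 4, 4)
    mask = torch.ones(1, 1, 4, 4)
    print(photo_loss(img, img, mask))                    # tensor(0.0010) = sqrt(eps)
    print(photo_loss(img, img, torch.zeros_like(mask)))  # tensor(0.) with an empty mask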
/models/networks.py:
--------------------------------------------------------------------------------
1 | """This script defines deep neural networks for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import os
5 | import numpy as np
6 | import torch.nn.functional as F
7 | from torch.nn import init
8 | import functools
9 | from torch.optim import lr_scheduler
10 | import torch
11 | from torch import Tensor
12 | import torch.nn as nn
13 | try:
14 | from torch.hub import load_state_dict_from_url
15 | except ImportError:
16 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
17 | from typing import Type, Any, Callable, Union, List, Optional
18 | from .arcface_torch.backbones import get_model
19 | from kornia.geometry import warp_affine
20 |
21 | def resize_n_crop(image, M, dsize=112):
22 | # image: (b, c, h, w)
23 | # M : (b, 2, 3)
24 | return warp_affine(image, M, dsize=(dsize, dsize))
25 |
26 | def filter_state_dict(state_dict, remove_name='fc'):
27 | new_state_dict = {}
28 | for key in state_dict:
29 | if remove_name in key:
30 | continue
31 | new_state_dict[key] = state_dict[key]
32 | return new_state_dict
33 |
34 | def get_scheduler(optimizer, opt):
35 | """Return a learning rate scheduler
36 |
37 | Parameters:
38 | optimizer -- the optimizer of the network
39 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
40 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
41 |
42 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
43 | See https://pytorch.org/docs/stable/optim.html for more details.
44 | """
45 | if opt.lr_policy == 'linear':
46 | def lambda_rule(epoch):
47 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1)
48 | return lr_l
49 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
50 | elif opt.lr_policy == 'step':
51 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2)
52 | elif opt.lr_policy == 'plateau':
53 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
54 | elif opt.lr_policy == 'cosine':
55 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
56 | else:
57 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
58 | return scheduler
59 |
60 |
61 | def define_net_recon(net_recon, use_last_fc=False, init_path=None):
62 | return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path)
63 |
64 | def define_net_recog(net_recog, pretrained_path=None):
65 | net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path)
66 | net.eval()
67 | return net
68 |
69 | class ReconNetWrapper(nn.Module):
70 | fc_dim=257
71 | def __init__(self, net_recon, use_last_fc=False, init_path=None):
72 | super(ReconNetWrapper, self).__init__()
73 | self.use_last_fc = use_last_fc
74 | if net_recon not in func_dict:
75 | raise NotImplementedError('network [%s] is not implemented' % net_recon)
76 | func, last_dim = func_dict[net_recon]
77 | backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim)
78 | if init_path and os.path.isfile(init_path):
79 | state_dict = filter_state_dict(torch.load(init_path, map_location='cpu'))
80 | backbone.load_state_dict(state_dict)
81 | print("loading init net_recon %s from %s" %(net_recon, init_path))
82 | self.backbone = backbone
83 | if not use_last_fc:
84 | self.final_layers = nn.ModuleList([
85 | conv1x1(last_dim, 80, bias=True), # id layer
86 | conv1x1(last_dim, 64, bias=True), # exp layer
87 | conv1x1(last_dim, 80, bias=True), # tex layer
88 | conv1x1(last_dim, 3, bias=True), # angle layer
89 | conv1x1(last_dim, 27, bias=True), # gamma layer
90 | conv1x1(last_dim, 2, bias=True), # tx, ty
91 | conv1x1(last_dim, 1, bias=True) # tz
92 | ])
93 | for m in self.final_layers:
94 | nn.init.constant_(m.weight, 0.)
95 | nn.init.constant_(m.bias, 0.)
96 |
97 | def forward(self, x):
98 | x = self.backbone(x)
99 | if not self.use_last_fc:
100 | output = []
101 | for layer in self.final_layers:
102 | output.append(layer(x))
103 | x = torch.flatten(torch.cat(output, dim=1), 1)
104 | return x
105 |
106 |
107 | class RecogNetWrapper(nn.Module):
108 | def __init__(self, net_recog, pretrained_path=None, input_size=112):
109 | super(RecogNetWrapper, self).__init__()
110 | net = get_model(name=net_recog, fp16=False)
111 | if pretrained_path:
112 | state_dict = torch.load(pretrained_path, map_location='cpu')
113 | net.load_state_dict(state_dict)
114 | print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path))
115 | for param in net.parameters():
116 | param.requires_grad = False
117 | self.net = net
118 | self.preprocess = lambda x: 2 * x - 1
119 | self.input_size=input_size
120 |
121 | def forward(self, image, M):
122 | image = self.preprocess(resize_n_crop(image, M, self.input_size))
123 | id_feature = F.normalize(self.net(image), dim=-1, p=2)
124 | return id_feature
125 |
126 |
127 | # adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py
128 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
129 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
130 | 'wide_resnet50_2', 'wide_resnet101_2']
131 |
132 |
133 | model_urls = {
134 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
135 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
136 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
137 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
138 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
139 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
140 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
141 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
142 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
143 | }
144 |
145 |
146 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
147 | """3x3 convolution with padding"""
148 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
149 | padding=dilation, groups=groups, bias=False, dilation=dilation)
150 |
151 |
152 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d:
153 | """1x1 convolution"""
154 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)
155 |
156 |
157 | class BasicBlock(nn.Module):
158 | expansion: int = 1
159 |
160 | def __init__(
161 | self,
162 | inplanes: int,
163 | planes: int,
164 | stride: int = 1,
165 | downsample: Optional[nn.Module] = None,
166 | groups: int = 1,
167 | base_width: int = 64,
168 | dilation: int = 1,
169 | norm_layer: Optional[Callable[..., nn.Module]] = None
170 | ) -> None:
171 | super(BasicBlock, self).__init__()
172 | if norm_layer is None:
173 | norm_layer = nn.BatchNorm2d
174 | if groups != 1 or base_width != 64:
175 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
176 | if dilation > 1:
177 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
178 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
179 | self.conv1 = conv3x3(inplanes, planes, stride)
180 | self.bn1 = norm_layer(planes)
181 | self.relu = nn.ReLU(inplace=True)
182 | self.conv2 = conv3x3(planes, planes)
183 | self.bn2 = norm_layer(planes)
184 | self.downsample = downsample
185 | self.stride = stride
186 |
187 | def forward(self, x: Tensor) -> Tensor:
188 | identity = x
189 |
190 | out = self.conv1(x)
191 | out = self.bn1(out)
192 | out = self.relu(out)
193 |
194 | out = self.conv2(out)
195 | out = self.bn2(out)
196 |
197 | if self.downsample is not None:
198 | identity = self.downsample(x)
199 |
200 | out += identity
201 | out = self.relu(out)
202 |
203 | return out
204 |
205 |
206 | class Bottleneck(nn.Module):
207 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
208 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
209 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
210 | # This variant is also known as ResNet V1.5 and improves accuracy according to
211 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
212 |
213 | expansion: int = 4
214 |
215 | def __init__(
216 | self,
217 | inplanes: int,
218 | planes: int,
219 | stride: int = 1,
220 | downsample: Optional[nn.Module] = None,
221 | groups: int = 1,
222 | base_width: int = 64,
223 | dilation: int = 1,
224 | norm_layer: Optional[Callable[..., nn.Module]] = None
225 | ) -> None:
226 | super(Bottleneck, self).__init__()
227 | if norm_layer is None:
228 | norm_layer = nn.BatchNorm2d
229 | width = int(planes * (base_width / 64.)) * groups
230 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
231 | self.conv1 = conv1x1(inplanes, width)
232 | self.bn1 = norm_layer(width)
233 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
234 | self.bn2 = norm_layer(width)
235 | self.conv3 = conv1x1(width, planes * self.expansion)
236 | self.bn3 = norm_layer(planes * self.expansion)
237 | self.relu = nn.ReLU(inplace=True)
238 | self.downsample = downsample
239 | self.stride = stride
240 |
241 | def forward(self, x: Tensor) -> Tensor:
242 | identity = x
243 |
244 | out = self.conv1(x)
245 | out = self.bn1(out)
246 | out = self.relu(out)
247 |
248 | out = self.conv2(out)
249 | out = self.bn2(out)
250 | out = self.relu(out)
251 |
252 | out = self.conv3(out)
253 | out = self.bn3(out)
254 |
255 | if self.downsample is not None:
256 | identity = self.downsample(x)
257 |
258 | out += identity
259 | out = self.relu(out)
260 |
261 | return out
262 |
263 |
264 | class ResNet(nn.Module):
265 |
266 | def __init__(
267 | self,
268 | block: Type[Union[BasicBlock, Bottleneck]],
269 | layers: List[int],
270 | num_classes: int = 1000,
271 | zero_init_residual: bool = False,
272 | use_last_fc: bool = False,
273 | groups: int = 1,
274 | width_per_group: int = 64,
275 | replace_stride_with_dilation: Optional[List[bool]] = None,
276 | norm_layer: Optional[Callable[..., nn.Module]] = None
277 | ) -> None:
278 | super(ResNet, self).__init__()
279 | if norm_layer is None:
280 | norm_layer = nn.BatchNorm2d
281 | self._norm_layer = norm_layer
282 |
283 | self.inplanes = 64
284 | self.dilation = 1
285 | if replace_stride_with_dilation is None:
286 | # each element in the tuple indicates if we should replace
287 | # the 2x2 stride with a dilated convolution instead
288 | replace_stride_with_dilation = [False, False, False]
289 | if len(replace_stride_with_dilation) != 3:
290 | raise ValueError("replace_stride_with_dilation should be None "
291 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
292 | self.use_last_fc = use_last_fc
293 | self.groups = groups
294 | self.base_width = width_per_group
295 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
296 | bias=False)
297 | self.bn1 = norm_layer(self.inplanes)
298 | self.relu = nn.ReLU(inplace=True)
299 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
300 | self.layer1 = self._make_layer(block, 64, layers[0])
301 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
302 | dilate=replace_stride_with_dilation[0])
303 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
304 | dilate=replace_stride_with_dilation[1])
305 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
306 | dilate=replace_stride_with_dilation[2])
307 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
308 |
309 | if self.use_last_fc:
310 | self.fc = nn.Linear(512 * block.expansion, num_classes)
311 |
312 | for m in self.modules():
313 | if isinstance(m, nn.Conv2d):
314 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
315 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
316 | nn.init.constant_(m.weight, 1)
317 | nn.init.constant_(m.bias, 0)
318 |
319 |
320 |
321 | # Zero-initialize the last BN in each residual branch,
322 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
323 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
324 | if zero_init_residual:
325 | for m in self.modules():
326 | if isinstance(m, Bottleneck):
327 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
328 | elif isinstance(m, BasicBlock):
329 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
330 |
331 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
332 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
333 | norm_layer = self._norm_layer
334 | downsample = None
335 | previous_dilation = self.dilation
336 | if dilate:
337 | self.dilation *= stride
338 | stride = 1
339 | if stride != 1 or self.inplanes != planes * block.expansion:
340 | downsample = nn.Sequential(
341 | conv1x1(self.inplanes, planes * block.expansion, stride),
342 | norm_layer(planes * block.expansion),
343 | )
344 |
345 | layers = []
346 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
347 | self.base_width, previous_dilation, norm_layer))
348 | self.inplanes = planes * block.expansion
349 | for _ in range(1, blocks):
350 | layers.append(block(self.inplanes, planes, groups=self.groups,
351 | base_width=self.base_width, dilation=self.dilation,
352 | norm_layer=norm_layer))
353 |
354 | return nn.Sequential(*layers)
355 |
356 | def _forward_impl(self, x: Tensor) -> Tensor:
357 | # See note [TorchScript super()]
358 | x = self.conv1(x)
359 | x = self.bn1(x)
360 | x = self.relu(x)
361 | x = self.maxpool(x)
362 |
363 | x = self.layer1(x)
364 | x = self.layer2(x)
365 | x = self.layer3(x)
366 | x = self.layer4(x)
367 |
368 | x = self.avgpool(x)
369 | if self.use_last_fc:
370 | x = torch.flatten(x, 1)
371 | x = self.fc(x)
372 | return x
373 |
374 | def forward(self, x: Tensor) -> Tensor:
375 | return self._forward_impl(x)
376 |
377 |
378 | def _resnet(
379 | arch: str,
380 | block: Type[Union[BasicBlock, Bottleneck]],
381 | layers: List[int],
382 | pretrained: bool,
383 | progress: bool,
384 | **kwargs: Any
385 | ) -> ResNet:
386 | model = ResNet(block, layers, **kwargs)
387 | if pretrained:
388 | state_dict = load_state_dict_from_url(model_urls[arch],
389 | progress=progress)
390 | model.load_state_dict(state_dict)
391 | return model
392 |
393 |
394 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
395 | r"""ResNet-18 model from
396 | `"Deep Residual Learning for Image Recognition" `_.
397 |
398 | Args:
399 | pretrained (bool): If True, returns a model pre-trained on ImageNet
400 | progress (bool): If True, displays a progress bar of the download to stderr
401 | """
402 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
403 | **kwargs)
404 |
405 |
406 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
407 | r"""ResNet-34 model from
408 | `"Deep Residual Learning for Image Recognition" `_.
409 |
410 | Args:
411 | pretrained (bool): If True, returns a model pre-trained on ImageNet
412 | progress (bool): If True, displays a progress bar of the download to stderr
413 | """
414 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
415 | **kwargs)
416 |
417 |
418 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
419 | r"""ResNet-50 model from
420 | `"Deep Residual Learning for Image Recognition" `_.
421 |
422 | Args:
423 | pretrained (bool): If True, returns a model pre-trained on ImageNet
424 | progress (bool): If True, displays a progress bar of the download to stderr
425 | """
426 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
427 | **kwargs)
428 |
429 |
430 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
431 | r"""ResNet-101 model from
432 | `"Deep Residual Learning for Image Recognition" `_.
433 |
434 | Args:
435 | pretrained (bool): If True, returns a model pre-trained on ImageNet
436 | progress (bool): If True, displays a progress bar of the download to stderr
437 | """
438 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
439 | **kwargs)
440 |
441 |
442 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
443 | r"""ResNet-152 model from
444 | `"Deep Residual Learning for Image Recognition" `_.
445 |
446 | Args:
447 | pretrained (bool): If True, returns a model pre-trained on ImageNet
448 | progress (bool): If True, displays a progress bar of the download to stderr
449 | """
450 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
451 | **kwargs)
452 |
453 |
454 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
455 | r"""ResNeXt-50 32x4d model from
456 | `"Aggregated Residual Transformation for Deep Neural Networks" `_.
457 |
458 | Args:
459 | pretrained (bool): If True, returns a model pre-trained on ImageNet
460 | progress (bool): If True, displays a progress bar of the download to stderr
461 | """
462 | kwargs['groups'] = 32
463 | kwargs['width_per_group'] = 4
464 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
465 | pretrained, progress, **kwargs)
466 |
467 |
468 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
469 | r"""ResNeXt-101 32x8d model from
470 | `"Aggregated Residual Transformation for Deep Neural Networks" `_.
471 |
472 | Args:
473 | pretrained (bool): If True, returns a model pre-trained on ImageNet
474 | progress (bool): If True, displays a progress bar of the download to stderr
475 | """
476 | kwargs['groups'] = 32
477 | kwargs['width_per_group'] = 8
478 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
479 | pretrained, progress, **kwargs)
480 |
481 |
482 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
483 | r"""Wide ResNet-50-2 model from
484 | `"Wide Residual Networks" `_.
485 |
486 | The model is the same as ResNet except for the bottleneck number of channels
487 | which is twice larger in every block. The number of channels in outer 1x1
488 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
489 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
490 |
491 | Args:
492 | pretrained (bool): If True, returns a model pre-trained on ImageNet
493 | progress (bool): If True, displays a progress bar of the download to stderr
494 | """
495 | kwargs['width_per_group'] = 64 * 2
496 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
497 | pretrained, progress, **kwargs)
498 |
499 |
500 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
501 | r"""Wide ResNet-101-2 model from
502 | `"Wide Residual Networks" `_.
503 |
504 | The model is the same as ResNet except for the bottleneck number of channels
505 | which is twice larger in every block. The number of channels in outer 1x1
506 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
507 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
508 |
509 | Args:
510 | pretrained (bool): If True, returns a model pre-trained on ImageNet
511 | progress (bool): If True, displays a progress bar of the download to stderr
512 | """
513 | kwargs['width_per_group'] = 64 * 2
514 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
515 | pretrained, progress, **kwargs)
516 |
517 |
518 | func_dict = {
519 | 'resnet18': (resnet18, 512),
520 | 'resnet50': (resnet50, 2048)
521 | }
522 |
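func_dict at the end of this file maps a backbone name to its constructor and output feature width. A minimal sketch of how a caller might build a headless backbone from it (illustrative wiring, not the repo's actual usage):

    import torch
    from models.networks import func_dict

    net_ctor, feat_dim = func_dict['resnet50']      # (resnet50, 2048)
    backbone = net_ctor(use_last_fc=False)          # skip the ImageNet fc head
    out = backbone(torch.rand(1, 3, 224, 224))      # pooled features, shape (1, 2048, 1, 1)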
--------------------------------------------------------------------------------
/models/template_model.py:
--------------------------------------------------------------------------------
1 | """Model class template
2 |
3 | This module provides a template for users to implement custom models.
4 | You can specify '--model template' to use this model.
5 | The class name should be consistent with both the filename and its model option.
6 | The filename should be <model>_model.py
7 | The class name should be <Model>Model
8 | It implements a simple image-to-image translation baseline based on regression loss.
9 | Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
10 |     min_<netG> ||netG(data_A) - data_B||_1
11 | You need to implement the following functions:
12 |     <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
13 |     <__init__>: Initialize this model class.
14 |     <set_input>: Unpack input data and perform data pre-processing.
15 |     <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
16 |     <optimize_parameters>: Update network weights; it will be called in every training iteration.
17 | """
18 | import numpy as np
19 | import torch
20 | from .base_model import BaseModel
21 | from . import networks
22 |
23 |
24 | class TemplateModel(BaseModel):
25 | @staticmethod
26 | def modify_commandline_options(parser, is_train=True):
27 | """Add new model-specific options and rewrite default values for existing options.
28 |
29 | Parameters:
30 | parser -- the option parser
31 | is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
32 |
33 | Returns:
34 | the modified parser.
35 | """
36 | parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
37 | if is_train:
38 | parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
39 |
40 | return parser
41 |
42 | def __init__(self, opt):
43 | """Initialize this model class.
44 |
45 | Parameters:
46 | opt -- training/test options
47 |
48 | A few things can be done here.
49 | - (required) call the initialization function of BaseModel
50 | - define loss function, visualization images, model names, and optimizers
51 | """
52 | BaseModel.__init__(self, opt) # call the initialization method of BaseModel
53 | # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
54 | self.loss_names = ['loss_G']
55 | # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
56 | self.visual_names = ['data_A', 'data_B', 'output']
57 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
58 | # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
59 | self.model_names = ['G']
60 | # define networks; you can use opt.isTrain to specify different behaviors for training and test.
61 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
62 | if self.isTrain: # only defined during training time
63 | # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
64 | # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
65 | self.criterionLoss = torch.nn.L1Loss()
66 | # define and initialize optimizers. You can define one optimizer for each network.
67 | # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
68 | self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
69 | self.optimizers = [self.optimizer]
70 |
71 |         # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
72 |
73 | def set_input(self, input):
74 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
75 |
76 | Parameters:
77 | input: a dictionary that contains the data itself and its metadata information.
78 | """
79 |         AtoB = self.opt.direction == 'AtoB'  # use <direction> to swap data_A and data_B
80 | self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
81 | self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
82 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
83 |
84 | def forward(self):
85 | """Run forward pass. This will be called by both functions and ."""
86 | self.output = self.netG(self.data_A) # generate output image given the input data_A
87 |
88 | def backward(self):
89 | """Calculate losses, gradients, and update network weights; called in every training iteration"""
90 |         # calculate the intermediate results if necessary; here self.output has been computed during the <forward> function
91 | # calculate loss given the input and intermediate results
92 | self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
93 | self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
94 |
95 | def optimize_parameters(self):
96 | """Update network weights; it will be called in every training iteration."""
97 | self.forward() # first call forward to calculate intermediate results
98 | self.optimizer.zero_grad() # clear network G's existing gradients
99 | self.backward() # calculate gradients for network G
100 |         self.optimizer.step()       # update network G's parameters
101 |
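Taken together, the methods above are the full contract a training loop needs. A minimal sketch of one iteration, assuming `opt` comes from the repo's option parser and `dataset` from its data factory:

    from models.template_model import TemplateModel

    model = TemplateModel(opt)          # opt from TrainOptions().parse() (assumed)
    model.setup(opt)                    # BaseModel: schedulers, loading, printing
    for data in dataset:                # dataset from data.create_dataset(opt) (assumed)
        model.set_input(data)           # unpack {'A': ..., 'B': ..., 'A_paths': ...}
        model.optimize_parameters()     # forward -> zero_grad -> backward -> step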
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
2 |
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | """This script contains base options for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import argparse
5 | import os
6 | from util import util
7 | import numpy as np
8 | import torch
9 | import models
10 | import data
11 |
12 |
13 | class BaseOptions():
14 | """This class defines options used during both training and test time.
15 |
16 | It also implements several helper functions such as parsing, printing, and saving the options.
17 |     It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
18 | """
19 |
20 | def __init__(self, cmd_line=None):
21 | """Reset the class; indicates the class hasn't been initailized"""
22 | self.initialized = False
23 | self.cmd_line = None
24 | if cmd_line is not None:
25 | self.cmd_line = cmd_line.split()
26 |
27 | def initialize(self, parser):
28 | """Define the common options that are used in both training and test."""
29 | # basic parameters
30 | parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models')
31 |         parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0; 0,1,2; 0,2. use -1 for CPU')
32 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
33 |         parser.add_argument('--vis_batch_nums', type=float, default=1, help='number of image batches for visualization')
34 |         parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='number of image batches for evaluation')
35 |         parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether to use distributed data parallel')
36 |         parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port')
37 |         parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether to count iterations per batch (rather than per image) when displaying losses')
38 |         parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether to add images to tensorboard')
39 |         parser.add_argument('--world_size', type=int, default=1, help='number of processes for distributed training; set from gpu_ids in parse()')
40 |
41 | # model parameters
42 | parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.')
43 |
44 | # additional parameters
45 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
46 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
47 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
48 |
49 | self.initialized = True
50 | return parser
51 |
52 | def gather_options(self):
53 | """Initialize our parser with basic options(only once).
54 | Add additional model-specific and dataset-specific options.
55 | These options are defined in the function
56 | in model and dataset classes.
57 | """
58 | if not self.initialized: # check if it has been initialized
59 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
60 | parser = self.initialize(parser)
61 |
62 | # get the basic options
63 | if self.cmd_line is None:
64 | opt, _ = parser.parse_known_args()
65 | else:
66 | opt, _ = parser.parse_known_args(self.cmd_line)
67 |
68 | # set cuda visible devices
69 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
70 |
71 | # modify model-related parser options
72 | model_name = opt.model
73 | model_option_setter = models.get_option_setter(model_name)
74 | parser = model_option_setter(parser, self.isTrain)
75 | if self.cmd_line is None:
76 | opt, _ = parser.parse_known_args() # parse again with new defaults
77 | else:
78 | opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults
79 |
80 | # modify dataset-related parser options
81 | if opt.dataset_mode:
82 | dataset_name = opt.dataset_mode
83 | dataset_option_setter = data.get_option_setter(dataset_name)
84 | parser = dataset_option_setter(parser, self.isTrain)
85 |
86 | # save and return the parser
87 | self.parser = parser
88 | if self.cmd_line is None:
89 | return parser.parse_args()
90 | else:
91 | return parser.parse_args(self.cmd_line)
92 |
93 | def print_options(self, opt):
94 | """Print and save options
95 |
96 |         It will print both current options and default values (if different).
97 |         It will save options into a text file: [checkpoints_dir]/[name]/[phase]_opt.txt
98 | """
99 | message = ''
100 | message += '----------------- Options ---------------\n'
101 | for k, v in sorted(vars(opt).items()):
102 | comment = ''
103 | default = self.parser.get_default(k)
104 | if v != default:
105 | comment = '\t[default: %s]' % str(default)
106 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
107 | message += '----------------- End -------------------'
108 | print(message)
109 |
110 | # save to the disk
111 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
112 | util.mkdirs(expr_dir)
113 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
114 | try:
115 | with open(file_name, 'wt') as opt_file:
116 | opt_file.write(message)
117 | opt_file.write('\n')
118 | except PermissionError as error:
119 | print("permission error {}".format(error))
120 | pass
121 |
122 | def parse(self):
123 | """Parse our options, create checkpoints directory suffix, and set up gpu device."""
124 | opt = self.gather_options()
125 | opt.isTrain = self.isTrain # train or test
126 |
127 | # process opt.suffix
128 | if opt.suffix:
129 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
130 | opt.name = opt.name + suffix
131 |
132 |
133 | # set gpu ids
134 | str_ids = opt.gpu_ids.split(',')
135 | gpu_ids = []
136 | for str_id in str_ids:
137 | id = int(str_id)
138 | if id >= 0:
139 | gpu_ids.append(id)
140 | opt.world_size = len(gpu_ids)
141 | # if len(opt.gpu_ids) > 0:
142 | # torch.cuda.set_device(gpu_ids[0])
143 | if opt.world_size == 1:
144 | opt.use_ddp = False
145 |
146 | if opt.phase != 'test':
147 | # set continue_train automatically
148 | if opt.pretrained_name is None:
149 | model_dir = os.path.join(opt.checkpoints_dir, opt.name)
150 | else:
151 | model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name)
152 | if os.path.isdir(model_dir):
153 | model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')]
154 | if os.path.isdir(model_dir) and len(model_pths) != 0:
155 | opt.continue_train= True
156 |
157 | # update the latest epoch count
158 | if opt.continue_train:
159 | if opt.epoch == 'latest':
160 | epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i]
161 | if len(epoch_counts) != 0:
162 | opt.epoch_count = max(epoch_counts) + 1
163 | else:
164 | opt.epoch_count = int(opt.epoch) + 1
165 |
166 |
167 | self.print_options(opt)
168 | self.opt = opt
169 | return self.opt
170 |
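Because __init__ accepts an optional cmd_line string, options can also be built programmatically (handy in notebooks and tests) instead of from sys.argv. A minimal sketch; the flag values are illustrative, and note that parse() has side effects (it prints the options and writes an opt.txt file):

    from options.train_options import TrainOptions

    opt = TrainOptions(cmd_line='--name run1 --gpu_ids 0 --batch_size 16').parse()
    print(opt.isTrain, opt.world_size, opt.use_ddp)   # True 1 False -- a single GPU disables DDP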
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | """This script contains the test options for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | from .base_options import BaseOptions
5 |
6 |
7 | class TestOptions(BaseOptions):
8 | """This class includes test options.
9 |
10 | It also includes shared options defined in BaseOptions.
11 | """
12 |
13 | def initialize(self, parser):
14 | parser = BaseOptions.initialize(self, parser) # define shared options
15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 | parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
17 | parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.')
18 |
19 |         # Dropout and BatchNorm have different behaviors during training and test.
20 | self.isTrain = False
21 | return parser
22 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | """This script contains the training options for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | from .base_options import BaseOptions
5 | from util import util
6 |
7 | class TrainOptions(BaseOptions):
8 | """This class includes training options.
9 |
10 | It also includes shared options defined in BaseOptions.
11 | """
12 |
13 | def initialize(self, parser):
14 | parser = BaseOptions.initialize(self, parser)
15 | # dataset parameters
16 | # for train
17 | parser.add_argument('--data_root', type=str, default='./', help='dataset root')
18 | parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set')
19 | parser.add_argument('--batch_size', type=int, default=32)
20 | parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]')
21 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
22 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
23 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
24 | parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]')
25 | parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation')
26 |
27 | # for val
28 | parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set')
29 | parser.add_argument('--batch_size_val', type=int, default=32)
30 |
31 |
32 | # visualization parameters
33 | parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen')
34 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
35 |
36 | # network saving and loading parameters
37 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
38 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
39 | parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq')
40 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
41 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
42 |         parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
43 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
44 | parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
45 |
46 | # training parameters
47 | parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate')
48 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
49 | parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
50 |         parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply the learning rate by a gamma every lr_decay_epochs epochs')
51 |
52 | self.isTrain = True
53 | return parser
54 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """This script is the test script for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import os
5 | from options.test_options import TestOptions
6 | from data import create_dataset
7 | from models import create_model
8 | from util.visualizer import MyVisualizer
9 | from util.preprocess import align_img
10 | from PIL import Image
11 | import numpy as np
12 | from util.load_mats import load_lm3d
13 | import torch
14 | from data.flist_dataset import default_flist_reader
15 | from scipy.io import loadmat, savemat
16 |
17 | def get_data_path(root='examples'):
18 |
19 | im_path = [os.path.join(root, i) for i in sorted(os.listdir(root)) if i.endswith('png') or i.endswith('jpg')]
20 | lm_path = [i.replace('png', 'txt').replace('jpg', 'txt') for i in im_path]
21 | lm_path = [os.path.join(i.replace(i.split(os.path.sep)[-1],''),'detections',i.split(os.path.sep)[-1]) for i in lm_path]
22 |
23 | return im_path, lm_path
24 |
25 | def read_data(im_path, lm_path, lm3d_std, to_tensor=True):
26 | # to RGB
27 | im = Image.open(im_path).convert('RGB')
28 | W,H = im.size
29 | lm = np.loadtxt(lm_path).astype(np.float32)
30 | lm = lm.reshape([-1, 2])
31 | lm[:, -1] = H - 1 - lm[:, -1]
32 | _, im, lm, _ = align_img(im, lm, lm3d_std)
33 | if to_tensor:
34 | im = torch.tensor(np.array(im)/255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
35 | lm = torch.tensor(lm).unsqueeze(0)
36 | return im, lm
37 |
38 | def main(rank, opt, name='examples'):
39 | device = torch.device(rank)
40 | torch.cuda.set_device(device)
41 | model = create_model(opt)
42 | model.setup(opt)
43 | model.device = device
44 | model.parallelize()
45 | model.eval()
46 | visualizer = MyVisualizer(opt)
47 |
48 | im_path, lm_path = get_data_path(name)
49 | lm3d_std = load_lm3d(opt.bfm_folder)
50 |
51 | for i in range(len(im_path)):
52 | print(i, im_path[i])
53 | img_name = im_path[i].split(os.path.sep)[-1].replace('.png','').replace('.jpg','')
54 | if not os.path.isfile(lm_path[i]):
55 | print("%s is not found !!!"%lm_path[i])
56 | continue
57 | im_tensor, lm_tensor = read_data(im_path[i], lm_path[i], lm3d_std)
58 | data = {
59 | 'imgs': im_tensor,
60 | 'lms': lm_tensor
61 | }
62 | model.set_input(data) # unpack data from data loader
63 | model.test() # run inference
64 | visuals = model.get_current_visuals() # get image results
65 | visualizer.display_current_results(visuals, 0, opt.epoch, dataset=name.split(os.path.sep)[-1],
66 | save_results=True, count=i, name=img_name, add_image=False)
67 |
68 | model.save_mesh(os.path.join(visualizer.img_dir, name.split(os.path.sep)[-1], 'epoch_%s_%06d'%(opt.epoch, 0),img_name+'.obj')) # save reconstruction meshes
69 | model.save_coeff(os.path.join(visualizer.img_dir, name.split(os.path.sep)[-1], 'epoch_%s_%06d'%(opt.epoch, 0),img_name+'.mat')) # save predicted coefficients
70 |
71 | if __name__ == '__main__':
72 | opt = TestOptions().parse() # get test options
73 |     main(0, opt, opt.img_folder)
74 |
75 |
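get_data_path above fixes the expected layout: images in <folder>, plus one 5-landmark .txt per image under <folder>/detections. A minimal sketch of driving the script programmatically; the flag values and the existence of a trained checkpoint under checkpoints/face_recon/ are assumptions:

    from options.test_options import TestOptions
    from test import main

    opt = TestOptions(cmd_line='--name face_recon --epoch 20 --img_folder datasets/examples').parse()
    main(0, opt, opt.img_folder)   # writes visuals, .obj meshes, and .mat coefficients per image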
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """This script is the training script for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import os
5 | import time
6 | import numpy as np
7 | import torch
8 | from options.train_options import TrainOptions
9 | from data import create_dataset
10 | from models import create_model
11 | from util.visualizer import MyVisualizer
12 | from util.util import genvalconf
13 | import torch.multiprocessing as mp
14 | import torch.distributed as dist
15 |
16 |
17 | def setup(rank, world_size, port):
18 | os.environ['MASTER_ADDR'] = 'localhost'
19 | os.environ['MASTER_PORT'] = port
20 |
21 | # initialize the process group
22 | dist.init_process_group("gloo", rank=rank, world_size=world_size)
23 |
24 | def cleanup():
25 | dist.destroy_process_group()
26 |
27 | def main(rank, world_size, train_opt):
28 | val_opt = genvalconf(train_opt, isTrain=False)
29 |
30 | device = torch.device(rank)
31 | torch.cuda.set_device(device)
32 | use_ddp = train_opt.use_ddp
33 |
34 | if use_ddp:
35 | setup(rank, world_size, train_opt.ddp_port)
36 |
37 | train_dataset, val_dataset = create_dataset(train_opt, rank=rank), create_dataset(val_opt, rank=rank)
38 | train_dataset_batches, val_dataset_batches = \
39 | len(train_dataset) // train_opt.batch_size, len(val_dataset) // val_opt.batch_size
40 |
41 | model = create_model(train_opt) # create a model given train_opt.model and other options
42 | model.setup(train_opt)
43 | model.device = device
44 | model.parallelize()
45 |
46 | if rank == 0:
47 |         print('The batch number of training images = %d\n'
48 |               'The batch number of validation images = %d' % (train_dataset_batches, val_dataset_batches))
49 | model.print_networks(train_opt.verbose)
50 | visualizer = MyVisualizer(train_opt) # create a visualizer that display/save images and plots
51 |
52 | total_iters = train_dataset_batches * (train_opt.epoch_count - 1) # the total number of training iterations
53 | t_data = 0
54 | t_val = 0
55 | optimize_time = 0.1
56 | batch_size = 1 if train_opt.display_per_batch else train_opt.batch_size
57 |
58 | if use_ddp:
59 | dist.barrier()
60 |
61 | times = []
62 |     for epoch in range(train_opt.epoch_count, train_opt.n_epochs + 1):    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
63 | epoch_start_time = time.time() # timer for entire epoch
64 | iter_data_time = time.time() # timer for train_data loading per iteration
65 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
66 |
67 | train_dataset.set_epoch(epoch)
68 | for i, train_data in enumerate(train_dataset): # inner loop within one epoch
69 | iter_start_time = time.time() # timer for computation per iteration
70 | if total_iters % train_opt.print_freq == 0:
71 | t_data = iter_start_time - iter_data_time
72 | total_iters += batch_size
73 | epoch_iter += batch_size
74 |
75 | torch.cuda.synchronize()
76 | optimize_start_time = time.time()
77 |
78 | model.set_input(train_data) # unpack train_data from dataset and apply preprocessing
79 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights
80 |
81 | torch.cuda.synchronize()
82 | optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time
83 |
84 | if use_ddp:
85 | dist.barrier()
86 |
87 | if rank == 0 and (total_iters == batch_size or total_iters % train_opt.display_freq == 0): # display images on visdom and save images to a HTML file
88 | model.compute_visuals()
89 | visualizer.display_current_results(model.get_current_visuals(), total_iters, epoch,
90 | save_results=True,
91 | add_image=train_opt.add_image)
92 | # (total_iters == batch_size or total_iters % train_opt.evaluation_freq == 0)
93 |
94 | if rank == 0 and (total_iters == batch_size or total_iters % train_opt.print_freq == 0): # print training losses and save logging information to the disk
95 | losses = model.get_current_losses()
96 | visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data)
97 | visualizer.plot_current_losses(total_iters, losses)
98 |
99 | if total_iters == batch_size or total_iters % train_opt.evaluation_freq == 0:
100 | with torch.no_grad():
101 | torch.cuda.synchronize()
102 | val_start_time = time.time()
103 | losses_avg = {}
104 | model.eval()
105 | for j, val_data in enumerate(val_dataset):
106 | model.set_input(val_data)
107 | model.optimize_parameters(isTrain=False)
108 | if rank == 0 and j < train_opt.vis_batch_nums:
109 | model.compute_visuals()
110 | visualizer.display_current_results(model.get_current_visuals(), total_iters, epoch,
111 | dataset='val', save_results=True, count=j * val_opt.batch_size,
112 | add_image=train_opt.add_image)
113 |
114 | if j < train_opt.eval_batch_nums:
115 | losses = model.get_current_losses()
116 | for key, value in losses.items():
117 | losses_avg[key] = losses_avg.get(key, 0) + value
118 |
119 | for key, value in losses_avg.items():
120 | losses_avg[key] = value / min(train_opt.eval_batch_nums, val_dataset_batches)
121 |
122 | torch.cuda.synchronize()
123 | eval_time = time.time() - val_start_time
124 |
125 | if rank == 0:
126 | visualizer.print_current_losses(epoch, epoch_iter, losses_avg, eval_time, t_data, dataset='val') # visualize training results
127 | visualizer.plot_current_losses(total_iters, losses_avg, dataset='val')
128 | model.train()
129 |
130 | if use_ddp:
131 | dist.barrier()
132 |
133 |             if rank == 0 and (total_iters == batch_size or total_iters % train_opt.save_latest_freq == 0):   # cache our latest model every <save_latest_freq> iterations
134 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
135 | print(train_opt.name) # it's useful to occasionally show the experiment name on console
136 | save_suffix = 'iter_%d' % total_iters if train_opt.save_by_iter else 'latest'
137 | model.save_networks(save_suffix)
138 |
139 | if use_ddp:
140 | dist.barrier()
141 |
142 | iter_data_time = time.time()
143 |
144 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, train_opt.n_epochs, time.time() - epoch_start_time))
145 | model.update_learning_rate() # update learning rates at the end of every epoch.
146 |
147 |         if rank == 0 and epoch % train_opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
148 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
149 | model.save_networks('latest')
150 | model.save_networks(epoch)
151 |
152 | if use_ddp:
153 | dist.barrier()
154 |
155 | if __name__ == '__main__':
156 |
157 | import warnings
158 | warnings.filterwarnings("ignore")
159 |
160 | train_opt = TrainOptions().parse() # get training options
161 | world_size = train_opt.world_size
162 |
163 | if train_opt.use_ddp:
164 | mp.spawn(main, args=(world_size, train_opt), nprocs=world_size, join=True)
165 | else:
166 | main(0, world_size, train_opt)
167 |
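One detail worth making explicit from the inner loop above: the per-image optimization time handed to the visualizer is an exponential moving average,

    \bar{t}_k = 0.995\,\bar{t}_{k-1} + 0.005\,\frac{\Delta t_k}{\text{batch size}},

so a single slow iteration barely moves the reported figure; the 0.995 decay corresponds to an effective window of roughly 200 iterations.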
--------------------------------------------------------------------------------
/util/BBRegressorParam_r.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicxu/Deep3DFaceRecon_pytorch/9167c7136aebef0c54d3ac74f4b0396222a491c1/util/BBRegressorParam_r.mat
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 | from util.util import *
3 |
--------------------------------------------------------------------------------
/util/detect_lm68.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from scipy.io import loadmat
5 | import tensorflow as tf
6 | from util.preprocess import align_for_lm
7 | from shutil import move
8 |
9 | mean_face = np.loadtxt('util/test_mean_face.txt')
10 | mean_face = mean_face.reshape([68, 2])
11 |
12 | def save_label(labels, save_path):
13 | np.savetxt(save_path, labels)
14 |
15 | def draw_landmarks(img, landmark, save_name):
16 |     # paint a small red dot at each landmark (y is flipped to image coordinates)
17 | lm_img = np.zeros([img.shape[0], img.shape[1], 3])
18 | lm_img[:] = img.astype(np.float32)
19 | landmark = np.round(landmark).astype(np.int32)
20 |
21 | for i in range(len(landmark)):
22 | for j in range(-1, 1):
23 | for k in range(-1, 1):
24 | if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \
25 | img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \
26 | landmark[i, 0]+k > 0 and \
27 | landmark[i, 0]+k < img.shape[1]:
28 | lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k,
29 | :] = np.array([0, 0, 255])
30 | lm_img = lm_img.astype(np.uint8)
31 |
32 | cv2.imwrite(save_name, lm_img)
33 |
34 |
35 | def load_data(img_name, txt_name):
36 | return cv2.imread(img_name), np.loadtxt(txt_name)
37 |
38 | # create tensorflow graph for landmark detector
39 | def load_lm_graph(graph_filename):
40 | with tf.gfile.GFile(graph_filename, 'rb') as f:
41 | graph_def = tf.GraphDef()
42 | graph_def.ParseFromString(f.read())
43 |
44 | with tf.Graph().as_default() as graph:
45 | tf.import_graph_def(graph_def, name='net')
46 | img_224 = graph.get_tensor_by_name('net/input_imgs:0')
47 | output_lm = graph.get_tensor_by_name('net/lm:0')
48 | lm_sess = tf.Session(graph=graph)
49 |
50 | return lm_sess,img_224,output_lm
51 |
52 | # landmark detection
53 | def detect_68p(img_path,sess,input_op,output_op):
54 | print('detecting landmarks......')
55 | names = [i for i in sorted(os.listdir(
56 | img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
57 | vis_path = os.path.join(img_path, 'vis')
58 | remove_path = os.path.join(img_path, 'remove')
59 | save_path = os.path.join(img_path, 'landmarks')
60 | if not os.path.isdir(vis_path):
61 | os.makedirs(vis_path)
62 | if not os.path.isdir(remove_path):
63 | os.makedirs(remove_path)
64 | if not os.path.isdir(save_path):
65 | os.makedirs(save_path)
66 |
67 | for i in range(0, len(names)):
68 | name = names[i]
69 | print('%05d' % (i), ' ', name)
70 | full_image_name = os.path.join(img_path, name)
71 | txt_name = '.'.join(name.split('.')[:-1]) + '.txt'
72 | full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image
73 |
74 | # if an image does not have detected 5 facial landmarks, remove it from the training list
75 | if not os.path.isfile(full_txt_name):
76 | move(full_image_name, os.path.join(remove_path, name))
77 | continue
78 |
79 | # load data
80 | img, five_points = load_data(full_image_name, full_txt_name)
81 | input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection
82 |
83 | # if the alignment fails, remove corresponding image from the training list
84 | if scale == 0:
85 | move(full_txt_name, os.path.join(
86 | remove_path, txt_name))
87 | move(full_image_name, os.path.join(remove_path, name))
88 | continue
89 |
90 | # detect landmarks
91 | input_img = np.reshape(
92 | input_img, [1, 224, 224, 3]).astype(np.float32)
93 | landmark = sess.run(
94 | output_op, feed_dict={input_op: input_img})
95 |
96 | # transform back to original image coordinate
97 | landmark = landmark.reshape([68, 2]) + mean_face
98 | landmark[:, 1] = 223 - landmark[:, 1]
99 | landmark = landmark / scale
100 | landmark[:, 0] = landmark[:, 0] + bbox[0]
101 | landmark[:, 1] = landmark[:, 1] + bbox[1]
102 | landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1]
103 |
104 | if i % 100 == 0:
105 | draw_landmarks(img, landmark, os.path.join(vis_path, name))
106 | save_label(landmark, os.path.join(save_path, txt_name))
107 |
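A minimal sketch of wiring these helpers together; the frozen-graph path below is illustrative (the repo's data preparation script supplies the real checkpoint location):

    from util.detect_lm68 import load_lm_graph, detect_68p

    sess, input_op, output_op = load_lm_graph('checkpoints/lm_model/68lm_detector.pb')  # assumed path
    detect_68p('datasets/examples', sess, input_op, output_op)   # writes landmarks/, vis/, remove/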
--------------------------------------------------------------------------------
/util/generate_list.py:
--------------------------------------------------------------------------------
1 | """This script is to generate training list files for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import os
5 |
6 | # save path to training data
7 | def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''):
8 | save_path = os.path.join(save_folder, mode)
9 | if not os.path.isdir(save_path):
10 | os.makedirs(save_path)
11 | with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd:
12 | fd.writelines([i + '\n' for i in lms_list])
13 |
14 | with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd:
15 | fd.writelines([i + '\n' for i in imgs_list])
16 |
17 | with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd:
18 | fd.writelines([i + '\n' for i in msks_list])
19 |
20 | # check if the path is valid
21 | def check_list(rlms_list, rimgs_list, rmsks_list):
22 | lms_list, imgs_list, msks_list = [], [], []
23 | for i in range(len(rlms_list)):
24 | flag = 'false'
25 | lm_path = rlms_list[i]
26 | im_path = rimgs_list[i]
27 | msk_path = rmsks_list[i]
28 | if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path):
29 | flag = 'true'
30 | lms_list.append(rlms_list[i])
31 | imgs_list.append(rimgs_list[i])
32 | msks_list.append(rmsks_list[i])
33 | print(i, rlms_list[i], flag)
34 | return lms_list, imgs_list, msks_list
35 |
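A minimal sketch of producing the datalist files that train_options.py points at by default; the raw path lists are illustrative:

    from util.generate_list import check_list, write_list

    raw_imgs = ['data/000001.png']                    # illustrative paths
    raw_lms  = ['data/landmarks/000001.txt']
    raw_msks = ['data/mask/000001.png']
    lms, imgs, msks = check_list(raw_lms, raw_imgs, raw_msks)   # keep triples whose files all exist
    write_list(lms, imgs, msks, mode='train')                   # writes datalist/train/{landmarks,images,masks}.txt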
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 |
5 |
6 | class HTML:
7 | """This HTML class allows us to save images and write texts into a single HTML file.
8 |
9 |     It consists of functions such as <add_header> (add a text header to the HTML file),
10 |     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 | """
13 |
14 | def __init__(self, web_dir, title, refresh=0):
15 | """Initialize the HTML classes
16 |
17 | Parameters:
18 |             web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 |             title (str)   -- the webpage name
20 |             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 |         """
22 |         self.title = title
23 |         self.web_dir = web_dir
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         if not os.path.exists(self.web_dir):
26 |             os.makedirs(self.web_dir)
27 |         if not os.path.exists(self.img_dir):
28 |             os.makedirs(self.img_dir)
29 | 
30 |         self.doc = dominate.document(title=title)
31 |         if refresh > 0:
32 | with self.doc.head:
33 | meta(http_equiv="refresh", content=str(refresh))
34 |
35 | def get_image_dir(self):
36 | """Return the directory that stores images"""
37 | return self.img_dir
38 |
39 | def add_header(self, text):
40 | """Insert a header to the HTML file
41 |
42 | Parameters:
43 | text (str) -- the header text
44 | """
45 | with self.doc:
46 | h3(text)
47 |
48 | def add_images(self, ims, txts, links, width=400):
49 | """add images to the HTML file
50 |
51 | Parameters:
52 | ims (str list) -- a list of image paths
53 | txts (str list) -- a list of image names shown on the website
54 |             links (str list) -- a list of hyperlinks; when you click an image, it will redirect you to a new page
55 | """
56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57 | self.doc.add(self.t)
58 | with self.t:
59 | with tr():
60 | for im, txt, link in zip(ims, txts, links):
61 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 | with p():
63 | with a(href=os.path.join('images', link)):
64 | img(style="width:%dpx" % width, src=os.path.join('images', im))
65 | br()
66 | p(txt)
67 |
68 | def save(self):
69 | """save the current content to the HMTL file"""
70 | html_file = '%s/index.html' % self.web_dir
71 | f = open(html_file, 'wt')
72 | f.write(self.doc.render())
73 | f.close()
74 |
75 |
76 | if __name__ == '__main__': # we show an example usage here.
77 | html = HTML('web/', 'test_html')
78 | html.add_header('hello world')
79 |
80 | ims, txts, links = [], [], []
81 | for n in range(4):
82 | ims.append('image_%d.png' % n)
83 | txts.append('text_%d' % n)
84 | links.append('image_%d.png' % n)
85 | html.add_images(ims, txts, links)
86 | html.save()
87 |
--------------------------------------------------------------------------------
/util/load_mats.py:
--------------------------------------------------------------------------------
1 | """This script is to load 3D face model for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import numpy as np
5 | from PIL import Image
6 | from scipy.io import loadmat, savemat
7 | from array import array
8 | import os.path as osp
9 |
10 | # load expression basis
11 | def LoadExpBasis(bfm_folder='BFM'):
12 | n_vertex = 53215
13 | Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb')
14 | exp_dim = array('i')
15 | exp_dim.fromfile(Expbin, 1)
16 | expMU = array('f')
17 | expPC = array('f')
18 | expMU.fromfile(Expbin, 3*n_vertex)
19 | expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex)
20 | Expbin.close()
21 |
22 | expPC = np.array(expPC)
23 | expPC = np.reshape(expPC, [exp_dim[0], -1])
24 | expPC = np.transpose(expPC)
25 |
26 | expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt'))
27 |
28 | return expPC, expEV
29 |
30 |
31 | # transfer original BFM09 to our face model
32 | def transferBFM09(bfm_folder='BFM'):
33 | print('Transfer BFM09 to BFM_model_front......')
34 | original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat'))
35 | shapePC = original_BFM['shapePC'] # shape basis
36 | shapeEV = original_BFM['shapeEV'] # corresponding eigen value
37 | shapeMU = original_BFM['shapeMU'] # mean face
38 | texPC = original_BFM['texPC'] # texture basis
39 | texEV = original_BFM['texEV'] # eigen value
40 | texMU = original_BFM['texMU'] # mean texture
41 |
42 | expPC, expEV = LoadExpBasis()
43 |
44 | # transfer BFM09 to our face model
45 |
46 | idBase = shapePC*np.reshape(shapeEV, [-1, 199])
47 | idBase = idBase/1e5 # unify the scale to decimeter
48 |     idBase = idBase[:, :80]  # use only the first 80 basis vectors
49 |
50 | exBase = expPC*np.reshape(expEV, [-1, 79])
51 | exBase = exBase/1e5 # unify the scale to decimeter
52 |     exBase = exBase[:, :64]  # use only the first 64 basis vectors
53 |
54 | texBase = texPC*np.reshape(texEV, [-1, 199])
55 |     texBase = texBase[:, :80]  # use only the first 80 basis vectors
56 |
57 |     # our face model is cropped along face landmarks and contains only 35709 vertices.
58 |     # original BFM09 contains 53490 vertices, and the expression basis provided by Guo et al. contains 53215 vertices.
59 |     # thus we select the corresponding vertices to get our face model.
60 |
61 | index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat'))
62 | index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215)
63 |
64 | index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat'))
65 | index_shape = index_shape['trimIndex'].astype(
66 | np.int32) - 1 # starts from 0 (to 53490)
67 | index_shape = index_shape[index_exp]
68 |
69 | idBase = np.reshape(idBase, [-1, 3, 80])
70 | idBase = idBase[index_shape, :, :]
71 | idBase = np.reshape(idBase, [-1, 80])
72 |
73 | texBase = np.reshape(texBase, [-1, 3, 80])
74 | texBase = texBase[index_shape, :, :]
75 | texBase = np.reshape(texBase, [-1, 80])
76 |
77 | exBase = np.reshape(exBase, [-1, 3, 64])
78 | exBase = exBase[index_exp, :, :]
79 | exBase = np.reshape(exBase, [-1, 64])
80 |
81 | meanshape = np.reshape(shapeMU, [-1, 3])/1e5
82 | meanshape = meanshape[index_shape, :]
83 | meanshape = np.reshape(meanshape, [1, -1])
84 |
85 | meantex = np.reshape(texMU, [-1, 3])
86 | meantex = meantex[index_shape, :]
87 | meantex = np.reshape(meantex, [1, -1])
88 |
89 | # other info contains triangles, region used for computing photometric loss,
90 | # region used for skin texture regularization, and 68 landmarks index etc.
91 | other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat'))
92 | frontmask2_idx = other_info['frontmask2_idx']
93 | skinmask = other_info['skinmask']
94 | keypoints = other_info['keypoints']
95 | point_buf = other_info['point_buf']
96 | tri = other_info['tri']
97 | tri_mask2 = other_info['tri_mask2']
98 |
99 | # save our face model
100 | savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase,
101 | 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask})
102 |
103 |
104 | # load landmarks for standard face, which is used for image preprocessing
105 | def load_lm3d(bfm_folder):
106 |
107 | Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat'))
108 | Lm3D = Lm3D['lm']
109 |
110 | # calculate 5 facial landmarks using 68 landmarks
111 | lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
112 | Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean(
113 | Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0)
114 | Lm3D = Lm3D[[1, 2, 0, 3, 4], :]
115 |
116 | return Lm3D
117 |
118 |
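transferBFM09 only needs to run once to produce BFM_model_front.mat; afterwards the train/test paths only call load_lm3d. A minimal sketch, assuming 01_MorphableModel.mat and Exp_Pca.bin have already been placed in ./BFM:

    from util.load_mats import transferBFM09, load_lm3d

    transferBFM09('BFM')           # writes BFM/BFM_model_front.mat (one-time)
    lm3d_std = load_lm3d('BFM')    # (5, 3) standard landmarks consumed by align_img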
--------------------------------------------------------------------------------
/util/nvdiffrast.py:
--------------------------------------------------------------------------------
1 | """This script is the differentiable renderer for Deep3DFaceRecon_pytorch
2 | Attention, antialiasing step is missing in current version.
3 | """
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import kornia
8 | from kornia.geometry.camera import pixel2cam
9 | import numpy as np
10 | from typing import List
11 | import nvdiffrast.torch as dr
12 | from scipy.io import loadmat
13 | from torch import nn
14 |
15 | def ndc_projection(x=0.1, n=1.0, f=50.0):
16 | return np.array([[n/x, 0, 0, 0],
17 | [ 0, n/-x, 0, 0],
18 | [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
19 | [ 0, 0, -1, 0]]).astype(np.float32)
20 |
21 | class MeshRenderer(nn.Module):
22 | def __init__(self,
23 | rasterize_fov,
24 | znear=0.1,
25 | zfar=10,
26 | rasterize_size=224,
27 | use_opengl=True):
28 | super(MeshRenderer, self).__init__()
29 |
30 | x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
31 | self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
32 | torch.diag(torch.tensor([1., -1, -1, 1])))
33 | self.rasterize_size = rasterize_size
34 | self.use_opengl = use_opengl
35 | self.ctx = None
36 |
37 | def forward(self, vertex, tri, feat=None):
38 | """
39 | Return:
40 | mask -- torch.tensor, size (B, 1, H, W)
41 | depth -- torch.tensor, size (B, 1, H, W)
42 |             features (optional) -- torch.tensor, size (B, C, H, W), returned when feat is not None
43 |
44 | Parameters:
45 | vertex -- torch.tensor, size (B, N, 3)
46 | tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
47 |             feat (optional) -- torch.tensor, size (B, N, C), per-vertex features
48 | """
49 | device = vertex.device
50 | rsize = int(self.rasterize_size)
51 | ndc_proj = self.ndc_proj.to(device)
52 |         # transform to homogeneous coordinates of 3D vertices; the direction of y is the same as v
53 | if vertex.shape[-1] == 3:
54 | vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
55 | vertex[..., 1] = -vertex[..., 1]
56 |
57 |
58 | vertex_ndc = vertex @ ndc_proj.t()
59 | if self.ctx is None:
60 | if self.use_opengl:
61 | self.ctx = dr.RasterizeGLContext(device=device)
62 | ctx_str = "opengl"
63 | else:
64 | self.ctx = dr.RasterizeCudaContext(device=device)
65 | ctx_str = "cuda"
66 | print("create %s ctx on device cuda:%d"%(ctx_str, device.index))
67 |
68 | ranges = None
69 | if isinstance(tri, List) or len(tri.shape) == 3:
70 | vum = vertex_ndc.shape[1]
71 | fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)
72 | fstartidx = torch.cumsum(fnum, dim=0) - fnum
73 | ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()
74 |             for i in range(len(tri)):  # len() works for both a list of tensors and a batched tensor
75 | tri[i] = tri[i] + i*vum
76 | vertex_ndc = torch.cat(vertex_ndc, dim=0)
77 | tri = torch.cat(tri, dim=0)
78 |
79 |         # for range_mode vertex: [B*N, 4], tri: [B*M, 3]; for instance_mode vertex: [B, N, 4], tri: [M, 3]
80 | tri = tri.type(torch.int32).contiguous()
81 | rast_out, _ = dr.rasterize(self.ctx, vertex_ndc.contiguous(), tri, resolution=[rsize, rsize], ranges=ranges)
82 |
83 | depth, _ = dr.interpolate(vertex.reshape([-1,4])[...,2].unsqueeze(1).contiguous(), rast_out, tri)
84 | depth = depth.permute(0, 3, 1, 2)
85 | mask = (rast_out[..., 3] > 0).float().unsqueeze(1)
86 | depth = mask * depth
87 |
88 |
89 | image = None
90 | if feat is not None:
91 | image, _ = dr.interpolate(feat, rast_out, tri)
92 | image = image.permute(0, 3, 1, 2)
93 | image = mask * image
94 |
95 | return mask, depth, image
96 |
97 |
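Reading off the forward() docstring, a render call takes per-batch vertices, shared triangle indices, and optional per-vertex features. A minimal sketch; the fov/near/far values and mesh sizes are illustrative, and a CUDA device plus an installed nvdiffrast are assumed:

    import torch
    from util.nvdiffrast import MeshRenderer

    renderer = MeshRenderer(rasterize_fov=12.6, znear=5., zfar=15., rasterize_size=224, use_opengl=False)
    vertex = torch.rand(1, 35709, 3, device='cuda')
    vertex[..., 2] += 10.                                  # push the mesh inside the [znear, zfar] frustum
    tri = torch.randint(0, 35709, (70000, 3), device='cuda')
    feat = torch.rand(1, 35709, 3, device='cuda')          # per-vertex colors
    mask, depth, image = renderer(vertex, tri, feat=feat)  # (1,1,224,224), (1,1,224,224), (1,3,224,224)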
--------------------------------------------------------------------------------
/util/preprocess.py:
--------------------------------------------------------------------------------
1 | """This script contains the image preprocessing code for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import numpy as np
5 | from scipy.io import loadmat
6 |
7 | try:
8 | from PIL.Image import Resampling
9 | RESAMPLING_METHOD = Resampling.BICUBIC
10 | except ImportError:
11 | from PIL.Image import BICUBIC
12 | RESAMPLING_METHOD = BICUBIC
13 |
14 | import cv2
15 | import os
16 | from skimage import transform as trans
17 | import torch
18 | import warnings
19 | warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
20 | warnings.filterwarnings("ignore", category=FutureWarning)
21 |
22 |
23 | # solve the least-squares problem for image alignment (scaled orthographic projection)
24 | def POS(xp, x):
25 | npts = xp.shape[1]
26 |
27 | A = np.zeros([2*npts, 8])
28 |
29 | A[0:2*npts-1:2, 0:3] = x.transpose()
30 | A[0:2*npts-1:2, 3] = 1
31 |
32 | A[1:2*npts:2, 4:7] = x.transpose()
33 | A[1:2*npts:2, 7] = 1
34 |
35 | b = np.reshape(xp.transpose(), [2*npts, 1])
36 |
37 | k, _, _, _ = np.linalg.lstsq(A, b)
38 |
39 | R1 = k[0:3]
40 | R2 = k[4:7]
41 | sTx = k[3]
42 | sTy = k[7]
43 | s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2
44 | t = np.stack([sTx, sTy], axis=0)
45 |
46 | return t, s
47 |
48 | # bounding box for 68 landmark detection
49 | def BBRegression(points, params):
50 |
51 | w1 = params['W1']
52 | b1 = params['B1']
53 | w2 = params['W2']
54 | b2 = params['B2']
55 | data = points.copy()
56 | data = data.reshape([5, 2])
57 | data_mean = np.mean(data, axis=0)
58 | x_mean = data_mean[0]
59 | y_mean = data_mean[1]
60 | data[:, 0] = data[:, 0] - x_mean
61 | data[:, 1] = data[:, 1] - y_mean
62 |
63 | rms = np.sqrt(np.sum(data ** 2)/5)
64 | data = data / rms
65 | data = data.reshape([1, 10])
66 | data = np.transpose(data)
67 | inputs = np.matmul(w1, data) + b1
68 | inputs = 2 / (1 + np.exp(-2 * inputs)) - 1
69 | inputs = np.matmul(w2, inputs) + b2
70 | inputs = np.transpose(inputs)
71 | x = inputs[:, 0] * rms + x_mean
72 | y = inputs[:, 1] * rms + y_mean
73 | w = 224/inputs[:, 2] * rms
74 | rects = [x, y, w, w]
75 | return np.array(rects).reshape([4])
76 |
77 | # utils for landmark detection
78 | def img_padding(img, box):
79 | success = True
80 | bbox = box.copy()
81 | res = np.zeros([2*img.shape[0], 2*img.shape[1], 3])
82 | res[img.shape[0] // 2: img.shape[0] + img.shape[0] //
83 | 2, img.shape[1] // 2: img.shape[1] + img.shape[1]//2] = img
84 |
85 | bbox[0] = bbox[0] + img.shape[1] // 2
86 | bbox[1] = bbox[1] + img.shape[0] // 2
87 | if bbox[0] < 0 or bbox[1] < 0:
88 | success = False
89 | return res, bbox, success
90 |
91 | # utils for landmark detection
92 | def crop(img, bbox):
93 | padded_img, padded_bbox, flag = img_padding(img, bbox)
94 | if flag:
95 | crop_img = padded_img[padded_bbox[1]: padded_bbox[1] +
96 | padded_bbox[3], padded_bbox[0]: padded_bbox[0] + padded_bbox[2]]
97 | crop_img = cv2.resize(crop_img.astype(np.uint8),
98 | (224, 224), interpolation=cv2.INTER_CUBIC)
99 | scale = 224 / padded_bbox[3]
100 | return crop_img, scale
101 | else:
102 | return padded_img, 0
103 |
104 | # utils for landmark detection
105 | def scale_trans(img, lm, t, s):
106 | imgw = img.shape[1]
107 | imgh = img.shape[0]
108 | M_s = np.array([[1, 0, -t[0] + imgw//2 + 0.5], [0, 1, -imgh//2 + t[1]]],
109 | dtype=np.float32)
110 | img = cv2.warpAffine(img, M_s, (imgw, imgh))
111 | w = int(imgw / s * 100)
112 | h = int(imgh / s * 100)
113 | img = cv2.resize(img, (w, h))
114 | lm = np.stack([lm[:, 0] - t[0] + imgw // 2, lm[:, 1] -
115 | t[1] + imgh // 2], axis=1) / s * 100
116 |
117 | left = w//2 - 112
118 | up = h//2 - 112
119 | bbox = [left, up, 224, 224]
120 | cropped_img, scale2 = crop(img, bbox)
121 |     assert scale2 != 0
122 | t1 = np.array([bbox[0], bbox[1]])
123 |
124 | # back to raw img s * crop + s * t1 + t2
125 | t1 = np.array([w//2 - 112, h//2 - 112])
126 | scale = s / 100
127 | t2 = np.array([t[0] - imgw/2, t[1] - imgh / 2])
128 | inv = (scale/scale2, scale * t1 + t2.reshape([2]))
129 | return cropped_img, inv
130 |
131 | # utils for landmark detection
132 | def align_for_lm(img, five_points):
133 | five_points = np.array(five_points).reshape([1, 10])
134 | params = loadmat('util/BBRegressorParam_r.mat')
135 | bbox = BBRegression(five_points, params)
136 |     assert bbox[2] != 0
137 | bbox = np.round(bbox).astype(np.int32)
138 | crop_img, scale = crop(img, bbox)
139 | return crop_img, scale, bbox
140 |
141 |
142 | # resize and crop images for face reconstruction
143 | def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):
144 | w0, h0 = img.size
145 | w = (w0*s).astype(np.int32)
146 | h = (h0*s).astype(np.int32)
147 | left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32)
148 | right = left + target_size
149 | up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32)
150 | below = up + target_size
151 |
152 | img = img.resize((w, h), resample=RESAMPLING_METHOD)
153 | img = img.crop((left, up, right, below))
154 |
155 | if mask is not None:
156 | mask = mask.resize((w, h), resample=RESAMPLING_METHOD)
157 | mask = mask.crop((left, up, right, below))
158 |
159 | lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] -
160 | t[1] + h0/2], axis=1)*s
161 | lm = lm - np.reshape(
162 | np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2])
163 |
164 | return img, lm, mask
165 |
166 | # utils for face reconstruction
167 | def extract_5p(lm):
168 | lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
169 | lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean(
170 | lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0)
171 | lm5p = lm5p[[1, 2, 0, 3, 4], :]
172 | return lm5p
173 |
174 | # utils for face reconstruction
175 | def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.):
176 | """
177 | Return:
178 |         trans_params --numpy.array (raw_W, raw_H, scale, tx, ty)
179 | img_new --PIL.Image (target_size, target_size, 3)
180 | lm_new --numpy.array (68, 2), y direction is opposite to v direction
181 | mask_new --PIL.Image (target_size, target_size)
182 |
183 | Parameters:
184 | img --PIL.Image (raw_H, raw_W, 3)
185 | lm --numpy.array (68, 2), y direction is opposite to v direction
186 | lm3D --numpy.array (5, 3)
187 | mask --PIL.Image (raw_H, raw_W, 3)
188 | """
189 |
190 | w0, h0 = img.size
191 | if lm.shape[0] != 5:
192 | lm5p = extract_5p(lm)
193 | else:
194 | lm5p = lm
195 |
196 | # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face
197 | t, s = POS(lm5p.transpose(), lm3D.transpose())
198 | s = rescale_factor/s
199 |
200 | # processing the image
201 | img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask)
202 | trans_params = np.array([w0, h0, s, t[0], t[1]])
203 |
204 | return trans_params, img_new, lm_new, mask_new
205 |
206 | # utils for face recognition model
207 | def estimate_norm(lm_68p, H):
208 | # from https://github.com/deepinsight/insightface/blob/c61d3cd208a603dfa4a338bd743b320ce3e94730/recognition/common/face_align.py#L68
209 | """
210 | Return:
211 | trans_m --numpy.array (2, 3)
212 | Parameters:
213 |         lm_68p --numpy.array (68, 2), y direction is opposite to v direction
214 | H --int/float , image height
215 | """
216 | lm = extract_5p(lm_68p)
217 | lm[:, -1] = H - 1 - lm[:, -1]
218 | tform = trans.SimilarityTransform()
219 | src = np.array(
220 | [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
221 | [41.5493, 92.3655], [70.7299, 92.2041]],
222 | dtype=np.float32)
223 | tform.estimate(lm, src)
224 | M = tform.params
225 | if np.linalg.det(M) == 0:
226 | M = np.eye(3)
227 |
228 | return M[0:2, :]
229 |
230 | def estimate_norm_torch(lm_68p, H):
231 | lm_68p_ = lm_68p.detach().cpu().numpy()
232 | M = []
233 | for i in range(lm_68p_.shape[0]):
234 | M.append(estimate_norm(lm_68p_[i], H))
235 | M = torch.tensor(np.array(M), dtype=torch.float32).to(lm_68p.device)
236 | return M
237 |
--------------------------------------------------------------------------------
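Note: a minimal usage sketch for align_img above (not part of the repo files). It assumes load_lm3d from util/load_mats.py with an assumed signature, and a detections file holding one "x y" landmark pair per line; the v coordinate is flipped first because align_img expects y pointing up (see its docstring).

import numpy as np
from PIL import Image
from util.preprocess import align_img
from util.load_mats import load_lm3d

img = Image.open('datasets/examples/000002.jpg').convert('RGB')
lm = np.loadtxt('datasets/examples/detections/000002.txt').reshape(-1, 2)
lm[:, -1] = img.size[1] - 1 - lm[:, -1]  # flip image v (down) to y (up)
lm3d_std = load_lm3d('BFM/')             # 5x3 standard landmarks (assumed signature)
trans_params, img_new, lm_new, mask_new = align_img(img, lm, lm3d_std)
img_new.save('aligned_224.png')          # 224x224 crop for the reconstruction network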
/util/skin_mask.py:
--------------------------------------------------------------------------------
1 | """This script is to generate skin attention mask for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import math
5 | import numpy as np
6 | import os
7 | import cv2
8 |
9 | class GMM:
10 | def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv):
11 | self.dim = dim # feature dimension
12 | self.num = num # number of Gaussian components
13 | self.w = w # weights of Gaussian components (a list of scalars)
14 |         self.mu = mu # mean of Gaussian components (a list of 1xdim vectors)
15 |         self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices)
16 |         self.cov_det = cov_det # pre-computed determinant of covariance matrices (a list of scalars)
17 | self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices)
18 |
19 | self.factor = [0]*num
20 | for i in range(self.num):
21 | self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5
22 |
23 | def likelihood(self, data):
24 | assert(data.shape[1] == self.dim)
25 | N = data.shape[0]
26 | lh = np.zeros(N)
27 |
28 | for i in range(self.num):
29 | data_ = data - self.mu[i]
30 |
31 | tmp = np.matmul(data_,self.cov_inv[i]) * data_
32 | tmp = np.sum(tmp,axis=1)
33 | power = -0.5 * tmp
34 |
35 | p = np.array([math.exp(power[j]) for j in range(N)])
36 | p = p/self.factor[i]
37 | lh += p*self.w[i]
38 |
39 | return lh
40 |
41 |
42 | def _rgb2ycbcr(rgb):
43 | m = np.array([[65.481, 128.553, 24.966],
44 | [-37.797, -74.203, 112],
45 | [112, -93.786, -18.214]])
46 | shape = rgb.shape
47 | rgb = rgb.reshape((shape[0] * shape[1], 3))
48 | ycbcr = np.dot(rgb, m.transpose() / 255.)
49 | ycbcr[:, 0] += 16.
50 | ycbcr[:, 1:] += 128.
51 | return ycbcr.reshape(shape)
52 |
53 |
54 | def _bgr2ycbcr(bgr):
55 | rgb = bgr[..., ::-1]
56 | return _rgb2ycbcr(rgb)
57 |
58 |
59 | gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415]
60 | gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]),
61 | np.array([150.19858, 105.18467, 155.51428]),
62 | np.array([183.92976, 107.62468, 152.71820]),
63 | np.array([114.90524, 113.59782, 151.38217])]
64 | gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.]
65 | gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]),
66 | np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]),
67 | np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]),
68 | np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])]
69 |
70 | gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv)
71 |
72 | gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393]
73 | gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]),
74 | np.array([110.91392, 125.52969, 130.19237]),
75 | np.array([129.75864, 129.96107, 126.96808]),
76 | np.array([112.29587, 128.85121, 129.05431])]
77 | gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63]
78 | gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]),
79 | np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]),
80 | np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]),
81 | np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])]
82 |
83 | gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv)
84 |
85 | prior_skin = 0.8
86 | prior_nonskin = 1 - prior_skin
87 |
88 |
89 | # calculate skin attention mask
90 | def skinmask(imbgr):
91 | im = _bgr2ycbcr(imbgr)
92 |
93 | data = im.reshape((-1,3))
94 |
95 | lh_skin = gmm_skin.likelihood(data)
96 | lh_nonskin = gmm_nonskin.likelihood(data)
97 |
98 | tmp1 = prior_skin * lh_skin
99 | tmp2 = prior_nonskin * lh_nonskin
100 | post_skin = tmp1 / (tmp1+tmp2) # posterior probability
101 |
102 | post_skin = post_skin.reshape((im.shape[0],im.shape[1]))
103 |
104 | post_skin = np.round(post_skin*255)
105 | post_skin = post_skin.astype(np.uint8)
106 |     post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # tile to H*W*3
107 |
108 | return post_skin
109 |
110 |
111 | def get_skin_mask(img_path):
112 | print('generating skin masks......')
113 | names = [i for i in sorted(os.listdir(
114 | img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
115 | save_path = os.path.join(img_path, 'mask')
116 | if not os.path.isdir(save_path):
117 | os.makedirs(save_path)
118 |
119 | for i in range(0, len(names)):
120 | name = names[i]
121 | print('%05d' % (i), ' ', name)
122 | full_image_name = os.path.join(img_path, name)
123 | img = cv2.imread(full_image_name).astype(np.float32)
124 | skin_img = skinmask(img)
125 | cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8))
126 |
--------------------------------------------------------------------------------
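Note: the posterior in skinmask() is Bayes' rule applied per pixel in YCbCr space, post = P(c|skin)P(skin) / (P(c|skin)P(skin) + P(c|nonskin)P(nonskin)). A minimal sketch of calling it directly on one image (paths are illustrative, not prescribed by the repo):

import cv2
import numpy as np
from util.skin_mask import skinmask

img = cv2.imread('datasets/examples/000002.jpg').astype(np.float32)  # BGR float, as in get_skin_mask
mask = skinmask(img)                # HxWx3 uint8; 255 = high skin posterior
cv2.imwrite('skin_mask.png', mask)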
/util/test_mean_face.txt:
--------------------------------------------------------------------------------
1 | -5.228591537475585938e+01
2 | 2.078247070312500000e-01
3 | -5.064269638061523438e+01
4 | -1.315765380859375000e+01
5 | -4.952939224243164062e+01
6 | -2.592591094970703125e+01
7 | -4.793047332763671875e+01
8 | -3.832135772705078125e+01
9 | -4.512159729003906250e+01
10 | -5.059623336791992188e+01
11 | -3.917720794677734375e+01
12 | -6.043736648559570312e+01
13 | -2.929953765869140625e+01
14 | -6.861183166503906250e+01
15 | -1.719801330566406250e+01
16 | -7.572736358642578125e+01
17 | -1.961936950683593750e+00
18 | -7.862001037597656250e+01
19 | 1.467941284179687500e+01
20 | -7.607844543457031250e+01
21 | 2.744073486328125000e+01
22 | -6.915261840820312500e+01
23 | 3.855677795410156250e+01
24 | -5.950350570678710938e+01
25 | 4.478240966796875000e+01
26 | -4.867547225952148438e+01
27 | 4.714337158203125000e+01
28 | -3.800830078125000000e+01
29 | 4.940315246582031250e+01
30 | -2.496297454833984375e+01
31 | 5.117234802246093750e+01
32 | -1.241538238525390625e+01
33 | 5.190507507324218750e+01
34 | 8.244247436523437500e-01
35 | -4.150688934326171875e+01
36 | 2.386329650878906250e+01
37 | -3.570307159423828125e+01
38 | 3.017010498046875000e+01
39 | -2.790358734130859375e+01
40 | 3.212951660156250000e+01
41 | -1.941773223876953125e+01
42 | 3.156523132324218750e+01
43 | -1.138106536865234375e+01
44 | 2.841992187500000000e+01
45 | 5.993263244628906250e+00
46 | 2.895182800292968750e+01
47 | 1.343590545654296875e+01
48 | 3.189880371093750000e+01
49 | 2.203153991699218750e+01
50 | 3.302221679687500000e+01
51 | 2.992478942871093750e+01
52 | 3.099150085449218750e+01
53 | 3.628388977050781250e+01
54 | 2.765748596191406250e+01
55 | -1.933914184570312500e+00
56 | 1.405374145507812500e+01
57 | -2.153038024902343750e+00
58 | 5.772636413574218750e+00
59 | -2.270050048828125000e+00
60 | -2.121643066406250000e+00
61 | -2.218330383300781250e+00
62 | -1.068978118896484375e+01
63 | -1.187252044677734375e+01
64 | -1.997912597656250000e+01
65 | -6.879402160644531250e+00
66 | -2.143579864501953125e+01
67 | -1.227821350097656250e+00
68 | -2.193494415283203125e+01
69 | 4.623237609863281250e+00
70 | -2.152721405029296875e+01
71 | 9.721397399902343750e+00
72 | -1.953671264648437500e+01
73 | -3.648714447021484375e+01
74 | 9.811126708984375000e+00
75 | -3.130242919921875000e+01
76 | 1.422447967529296875e+01
77 | -2.212834930419921875e+01
78 | 1.493019866943359375e+01
79 | -1.500880432128906250e+01
80 | 1.073588562011718750e+01
81 | -2.095037078857421875e+01
82 | 9.054298400878906250e+00
83 | -3.050099182128906250e+01
84 | 8.704177856445312500e+00
85 | 1.173237609863281250e+01
86 | 1.054329681396484375e+01
87 | 1.856353759765625000e+01
88 | 1.535009765625000000e+01
89 | 2.893331909179687500e+01
90 | 1.451992797851562500e+01
91 | 3.452944946289062500e+01
92 | 1.065280151367187500e+01
93 | 2.875990295410156250e+01
94 | 8.654792785644531250e+00
95 | 1.942100524902343750e+01
96 | 9.422447204589843750e+00
97 | -2.204488372802734375e+01
98 | -3.983994293212890625e+01
99 | -1.324458312988281250e+01
100 | -3.467377471923828125e+01
101 | -6.749649047851562500e+00
102 | -3.092894744873046875e+01
103 | -9.183349609375000000e-01
104 | -3.196458435058593750e+01
105 | 4.220649719238281250e+00
106 | -3.090406036376953125e+01
107 | 1.089889526367187500e+01
108 | -3.497008514404296875e+01
109 | 1.874589538574218750e+01
110 | -4.065438079833984375e+01
111 | 1.124106597900390625e+01
112 | -4.438417816162109375e+01
113 | 5.181709289550781250e+00
114 | -4.649170684814453125e+01
115 | -1.158607482910156250e+00
116 | -4.680406951904296875e+01
117 | -7.918922424316406250e+00
118 | -4.671575164794921875e+01
119 | -1.452505493164062500e+01
120 | -4.416526031494140625e+01
121 | -2.005007171630859375e+01
122 | -3.997841644287109375e+01
123 | -1.054919433593750000e+01
124 | -3.849683380126953125e+01
125 | -1.051826477050781250e+00
126 | -3.794863128662109375e+01
127 | 6.412681579589843750e+00
128 | -3.804645538330078125e+01
129 | 1.627674865722656250e+01
130 | -4.039697265625000000e+01
131 | 6.373878479003906250e+00
132 | -4.087213897705078125e+01
133 | -8.551712036132812500e-01
134 | -4.157129669189453125e+01
135 | -1.014953613281250000e+01
136 | -4.128469085693359375e+01
137 |
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | """This script contains basic utilities for Deep3DFaceRecon_pytorch
2 | """
3 | from __future__ import print_function
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | try:
8 | from PIL.Image import Resampling
9 | RESAMPLING_METHOD = Resampling.BICUBIC
10 | except ImportError:
11 | from PIL.Image import BICUBIC
12 | RESAMPLING_METHOD = BICUBIC
13 | import os
14 | import importlib
15 | import argparse
16 | from argparse import Namespace
17 | import torchvision
18 |
19 |
20 | def str2bool(v):
21 | if isinstance(v, bool):
22 | return v
23 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
24 | return True
25 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
26 | return False
27 | else:
28 | raise argparse.ArgumentTypeError('Boolean value expected.')
29 |
30 |
31 | def copyconf(default_opt, **kwargs):
32 | conf = Namespace(**vars(default_opt))
33 | for key in kwargs:
34 | setattr(conf, key, kwargs[key])
35 | return conf
36 |
37 | def genvalconf(train_opt, **kwargs):
38 | conf = Namespace(**vars(train_opt))
39 | attr_dict = train_opt.__dict__
40 | for key, value in attr_dict.items():
41 | if 'val' in key and key.split('_')[0] in attr_dict:
42 | setattr(conf, key.split('_')[0], value)
43 |
44 | for key in kwargs:
45 | setattr(conf, key, kwargs[key])
46 |
47 | return conf
48 |
49 | def find_class_in_module(target_cls_name, module):
50 | target_cls_name = target_cls_name.replace('_', '').lower()
51 | clslib = importlib.import_module(module)
52 | cls = None
53 | for name, clsobj in clslib.__dict__.items():
54 | if name.lower() == target_cls_name:
55 | cls = clsobj
56 |
57 | assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)
58 |
59 | return cls
60 |
61 |
62 | def tensor2im(input_image, imtype=np.uint8):
63 |     """Converts a Tensor array into a numpy image array.
64 |
65 | Parameters:
66 | input_image (tensor) -- the input image tensor array, range(0, 1)
67 | imtype (type) -- the desired type of the converted numpy array
68 | """
69 | if not isinstance(input_image, np.ndarray):
70 | if isinstance(input_image, torch.Tensor): # get the data from a variable
71 | image_tensor = input_image.data
72 | else:
73 | return input_image
74 | image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array
75 | if image_numpy.shape[0] == 1: # grayscale to RGB
76 | image_numpy = np.tile(image_numpy, (3, 1, 1))
77 |             image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0  # post-processing: transpose and scaling
78 | else: # if it is a numpy array, do nothing
79 | image_numpy = input_image
80 | return image_numpy.astype(imtype)
81 |
82 |
83 | def diagnose_network(net, name='network'):
84 | """Calculate and print the mean of average absolute(gradients)
85 |
86 | Parameters:
87 | net (torch network) -- Torch network
88 | name (str) -- the name of the network
89 | """
90 | mean = 0.0
91 | count = 0
92 | for param in net.parameters():
93 | if param.grad is not None:
94 | mean += torch.mean(torch.abs(param.grad.data))
95 | count += 1
96 | if count > 0:
97 | mean = mean / count
98 | print(name)
99 | print(mean)
100 |
101 |
102 | def save_image(image_numpy, image_path, aspect_ratio=1.0):
103 | """Save a numpy image to the disk
104 |
105 | Parameters:
106 | image_numpy (numpy array) -- input numpy array
107 | image_path (str) -- the path of the image
108 | """
109 |
110 | image_pil = Image.fromarray(image_numpy)
111 | h, w, _ = image_numpy.shape
112 |
113 | if aspect_ratio is None:
114 | pass
115 | elif aspect_ratio > 1.0:
116 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), RESAMPLING_METHOD)
117 | elif aspect_ratio < 1.0:
118 | image_pil = image_pil.resize((int(h / aspect_ratio), w), RESAMPLING_METHOD)
119 | image_pil.save(image_path)
120 |
121 |
122 | def print_numpy(x, val=True, shp=False):
123 | """Print the mean, min, max, median, std, and size of a numpy array
124 |
125 | Parameters:
126 |         val (bool) -- whether to print the values of the numpy array
127 |         shp (bool) -- whether to print the shape of the numpy array
128 | """
129 | x = x.astype(np.float64)
130 | if shp:
131 | print('shape,', x.shape)
132 | if val:
133 | x = x.flatten()
134 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
135 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
136 |
137 |
138 | def mkdirs(paths):
139 | """create empty directories if they don't exist
140 |
141 | Parameters:
142 | paths (str list) -- a list of directory paths
143 | """
144 | if isinstance(paths, list) and not isinstance(paths, str):
145 | for path in paths:
146 | mkdir(path)
147 | else:
148 | mkdir(paths)
149 |
150 |
151 | def mkdir(path):
152 | """create a single empty directory if it didn't exist
153 |
154 | Parameters:
155 | path (str) -- a single directory path
156 | """
157 | if not os.path.exists(path):
158 | os.makedirs(path)
159 |
160 |
161 | def correct_resize_label(t, size):
162 | device = t.device
163 | t = t.detach().cpu()
164 | resized = []
165 | for i in range(t.size(0)):
166 | one_t = t[i, :1]
167 | one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
168 | one_np = one_np[:, :, 0]
169 | one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
170 | resized_t = torch.from_numpy(np.array(one_image)).long()
171 | resized.append(resized_t)
172 | return torch.stack(resized, dim=0).to(device)
173 |
174 |
175 | def correct_resize(t, size, mode=RESAMPLING_METHOD):
176 | device = t.device
177 | t = t.detach().cpu()
178 | resized = []
179 | for i in range(t.size(0)):
180 | one_t = t[i:i + 1]
181 | one_image = Image.fromarray(tensor2im(one_t)).resize(size, RESAMPLING_METHOD)
182 | resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
183 | resized.append(resized_t)
184 | return torch.stack(resized, dim=0).to(device)
185 |
186 | def draw_landmarks(img, landmark, color='r', step=2):
187 | """
188 | Return:
189 | img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255)
190 |
191 |
192 | Parameters:
193 | img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255)
194 | landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction
195 | color -- str, 'r' or 'b' (red or blue)
196 | """
197 | if color =='r':
198 | c = np.array([255., 0, 0])
199 | else:
200 | c = np.array([0, 0, 255.])
201 |
202 | _, H, W, _ = img.shape
203 | img, landmark = img.copy(), landmark.copy()
204 | landmark[..., 1] = H - 1 - landmark[..., 1]
205 | landmark = np.round(landmark).astype(np.int32)
206 | for i in range(landmark.shape[1]):
207 | x, y = landmark[:, i, 0], landmark[:, i, 1]
208 | for j in range(-step, step):
209 | for k in range(-step, step):
210 | u = np.clip(x + j, 0, W - 1)
211 | v = np.clip(y + k, 0, H - 1)
212 | for m in range(landmark.shape[0]):
213 | img[m, v[m], u[m]] = c
214 | return img
215 |
--------------------------------------------------------------------------------
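Note: a quick round-trip sketch for the helpers above (not part of the repo files). tensor2im expects a CHW tensor in range (0, 1) and returns an HWC uint8 array that save_image can write:

import torch
from util.util import tensor2im, save_image

t = torch.rand(3, 224, 224)            # stand-in for a network output in [0, 1]
save_image(tensor2im(t), 'check.png')  # tensor2im -> HWC uint8, save_image writes it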
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | """This script defines the visualizer for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import numpy as np
5 | import os
6 | import sys
7 | import ntpath
8 | import time
9 | from . import util, html
10 | from subprocess import Popen, PIPE
11 | from torch.utils.tensorboard import SummaryWriter
12 |
13 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
14 | """Save images to the disk.
15 |
16 | Parameters:
17 |         webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
18 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
19 | image_path (str) -- the string is used to create image paths
20 | aspect_ratio (float) -- the aspect ratio of saved images
21 | width (int) -- the images will be resized to width x width
22 |
23 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
24 | """
25 | image_dir = webpage.get_image_dir()
26 | short_path = ntpath.basename(image_path[0])
27 | name = os.path.splitext(short_path)[0]
28 |
29 | webpage.add_header(name)
30 | ims, txts, links = [], [], []
31 |
32 | for label, im_data in visuals.items():
33 | im = util.tensor2im(im_data)
34 | image_name = '%s/%s.png' % (label, name)
35 | os.makedirs(os.path.join(image_dir, label), exist_ok=True)
36 | save_path = os.path.join(image_dir, image_name)
37 | util.save_image(im, save_path, aspect_ratio=aspect_ratio)
38 | ims.append(image_name)
39 | txts.append(label)
40 | links.append(image_name)
41 | webpage.add_images(ims, txts, links, width=width)
42 |
43 |
44 | class Visualizer():
45 | """This class includes several functions that can display/save images and print/save logging information.
46 |
47 |     It uses PyTorch's built-in TensorBoard writer for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
48 | """
49 |
50 | def __init__(self, opt):
51 | """Initialize the Visualizer class
52 |
53 | Parameters:
54 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
55 | Step 1: Cache the training/test options
56 | Step 2: create a tensorboard writer
57 |         Step 3: create an HTML object for saving HTML files
58 | Step 4: create a logging file to store training losses
59 | """
60 | self.opt = opt # cache the option
61 | self.use_html = opt.isTrain and not opt.no_html
62 | self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name))
63 | self.win_size = opt.display_winsize
64 | self.name = opt.name
65 | self.saved = False
66 | if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/
67 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
68 | self.img_dir = os.path.join(self.web_dir, 'images')
69 | print('create web directory %s...' % self.web_dir)
70 | util.mkdirs([self.web_dir, self.img_dir])
71 | # create a logging file to store training losses
72 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
73 | with open(self.log_name, "a") as log_file:
74 | now = time.strftime("%c")
75 | log_file.write('================ Training Loss (%s) ================\n' % now)
76 |
77 | def reset(self):
78 | """Reset the self.saved status"""
79 | self.saved = False
80 |
81 |
82 | def display_current_results(self, visuals, total_iters, epoch, save_result):
83 | """Display current results on tensorboad; save current results to an HTML file.
84 |
85 | Parameters:
86 | visuals (OrderedDict) - - dictionary of images to display or save
87 | total_iters (int) -- total iterations
88 | epoch (int) - - the current epoch
89 | save_result (bool) - - if save the current results to an HTML file
90 | """
91 | for label, image in visuals.items():
92 | self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC')
93 |
94 | if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
95 | self.saved = True
96 | # save images to the disk
97 | for label, image in visuals.items():
98 | image_numpy = util.tensor2im(image)
99 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
100 | util.save_image(image_numpy, img_path)
101 |
102 | # update website
103 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
104 | for n in range(epoch, 0, -1):
105 | webpage.add_header('epoch [%d]' % n)
106 | ims, txts, links = [], [], []
107 |
108 |             for label, image in visuals.items():
109 |                 image_numpy = util.tensor2im(image)
110 | img_path = 'epoch%.3d_%s.png' % (n, label)
111 | ims.append(img_path)
112 | txts.append(label)
113 | links.append(img_path)
114 | webpage.add_images(ims, txts, links, width=self.win_size)
115 | webpage.save()
116 |
117 | def plot_current_losses(self, total_iters, losses):
118 | # G_loss_collection = {}
119 | # D_loss_collection = {}
120 | # for name, value in losses.items():
121 | # if 'G' in name or 'NCE' in name or 'idt' in name:
122 | # G_loss_collection[name] = value
123 | # else:
124 | # D_loss_collection[name] = value
125 | # self.writer.add_scalars('G_collec', G_loss_collection, total_iters)
126 | # self.writer.add_scalars('D_collec', D_loss_collection, total_iters)
127 | for name, value in losses.items():
128 | self.writer.add_scalar(name, value, total_iters)
129 |
130 | # losses: same format as |losses| of plot_current_losses
131 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
132 | """print current losses on console; also save the losses to the disk
133 |
134 | Parameters:
135 | epoch (int) -- current epoch
136 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
137 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
138 | t_comp (float) -- computational time per data point (normalized by batch_size)
139 | t_data (float) -- data loading time per data point (normalized by batch_size)
140 | """
141 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
142 | for k, v in losses.items():
143 | message += '%s: %.3f ' % (k, v)
144 |
145 | print(message) # print the message
146 | with open(self.log_name, "a") as log_file:
147 | log_file.write('%s\n' % message) # save the message
148 |
149 |
150 | class MyVisualizer:
151 | def __init__(self, opt):
152 | """Initialize the Visualizer class
153 |
154 | Parameters:
155 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
156 | Step 1: Cache the training/test options
157 | Step 2: create a tensorboard writer
158 |         Step 3: create a directory for saving result images
159 | Step 4: create a logging file to store training losses
160 | """
161 |         self.opt = opt # cache the option
162 | self.name = opt.name
163 | self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results')
164 |
165 | if opt.phase != 'test':
166 | self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs'))
167 | # create a logging file to store training losses
168 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
169 | with open(self.log_name, "a") as log_file:
170 | now = time.strftime("%c")
171 | log_file.write('================ Training Loss (%s) ================\n' % now)
172 |
173 |
174 | def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None,
175 | add_image=True):
176 | """Display current results on tensorboad; save current results to an HTML file.
177 |
178 | Parameters:
179 | visuals (OrderedDict) - - dictionary of images to display or save
180 | total_iters (int) -- total iterations
181 | epoch (int) - - the current epoch
182 | dataset (str) - - 'train' or 'val' or 'test'
183 | """
184 | # if (not add_image) and (not save_results): return
185 |
186 | for label, image in visuals.items():
187 | for i in range(image.shape[0]):
188 | image_numpy = util.tensor2im(image[i])
189 | if add_image:
190 | self.writer.add_image(label + '%s_%02d'%(dataset, i + count),
191 | image_numpy, total_iters, dataformats='HWC')
192 |
193 | if save_results:
194 | save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters))
195 | if not os.path.isdir(save_path):
196 | os.makedirs(save_path)
197 |
198 | if name is not None:
199 | img_path = os.path.join(save_path, '%s.png' % name)
200 | else:
201 | img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count))
202 | util.save_image(image_numpy, img_path)
203 |
204 |
205 | def plot_current_losses(self, total_iters, losses, dataset='train'):
206 | for name, value in losses.items():
207 | self.writer.add_scalar(name + '/%s'%dataset, value, total_iters)
208 |
209 | # losses: same format as |losses| of plot_current_losses
210 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'):
211 | """print current losses on console; also save the losses to the disk
212 |
213 | Parameters:
214 | epoch (int) -- current epoch
215 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
216 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
217 | t_comp (float) -- computational time per data point (normalized by batch_size)
218 | t_data (float) -- data loading time per data point (normalized by batch_size)
219 | """
220 | message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (
221 | dataset, epoch, iters, t_comp, t_data)
222 | for k, v in losses.items():
223 | message += '%s: %.3f ' % (k, v)
224 |
225 | print(message) # print the message
226 | with open(self.log_name, "a") as log_file:
227 | log_file.write('%s\n' % message) # save the message
228 |
--------------------------------------------------------------------------------
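Note: a hedged sketch of MyVisualizer in a training loop (not part of the repo files). opt is normally built by options/train_options.py; here a bare Namespace carrying only the three attributes __init__ reads stands in for it.

import os
from argparse import Namespace
from util.visualizer import MyVisualizer

os.makedirs('./checkpoints/demo', exist_ok=True)  # log file's directory must exist
opt = Namespace(name='demo', checkpoints_dir='./checkpoints', phase='train')
vis = MyVisualizer(opt)
vis.plot_current_losses(100, {'loss_color': 0.12, 'loss_lm': 0.03})
vis.print_current_losses(epoch=1, iters=100, losses={'loss_color': 0.12},
                         t_comp=0.4, t_data=0.1)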