├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── demo.png
├── decalib
│   ├── __init__.py
│   ├── datasets
│   │   ├── aflw2000.py
│   │   ├── build_datasets.py
│   │   ├── datasets.py
│   │   ├── detectors.py
│   │   ├── ethnicity.py
│   │   ├── now.py
│   │   ├── train_datasets.py
│   │   ├── vggface.py
│   │   └── vox.py
│   ├── deca.py
│   ├── models
│   │   ├── FLAME.py
│   │   ├── decoders.py
│   │   ├── encoders.py
│   │   ├── frnet.py
│   │   ├── lbs.py
│   │   └── resnet.py
│   ├── trainer.py
│   └── utils
│       ├── config.py
│       ├── lossfunc.py
│       ├── rasterizer
│       │   ├── INSTALL.md
│       │   ├── __init__.py
│       │   ├── setup.py
│       │   ├── standard_rasterize_cuda.cpp
│       │   └── standard_rasterize_cuda_kernel.cu
│       ├── renderer.py
│       ├── rotation_converter.py
│       ├── tensor_cropper.py
│       ├── trainer.py
│       └── util.py
├── inference_rigface.py
├── requirements.txt
├── rigface
│   └── models
│       ├── attention_ID.py
│       ├── attention_denoising.py
│       ├── pipelineRigFace.py
│       ├── transformer_ID_2d.py
│       ├── transformer_denoising_2d.py
│       ├── unet_ID_2d_blocks.py
│       ├── unet_ID_2d_condition.py
│       ├── unet_denoising_2d_blocks.py
│       └── unet_denoising_2d_condition.py
└── utils
    ├── compute_renders.py
    ├── data_utils.py
    ├── datasets_faceswap.py
    ├── make_bgs.py
    ├── model.py
    ├── preprocess.py
    ├── resnet.py
    ├── save_exp_coeffs.py
    └── test_images
        ├── id1
        │   ├── bg_pose+exp.png
        │   ├── exp_pose+exp.npy
        │   ├── render.png
        │   ├── render_exp+pose.png
        │   ├── render_exp.png
        │   ├── sor.png
        │   └── tar.png
        └── id2
            ├── bg_light.png
            ├── exp_light.npy
            ├── render_light.png
            ├── sor.png
            └── tar.png
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Towards Consistent and Controllable Image Synthesis for Face Editing
4 |
5 |
6 | Mengting Wei, Tuomas Varanka, Yante Li, Xingxun Jiang, Huai-Qian Khor, Guoying Zhao
7 |
8 |
9 | University of Oulu
10 |
11 |
12 | ### [Paper](https://arxiv.org/abs/2502.02465)
13 |
14 |
15 |
16 | ## :mega: Updates
17 |
18 | [6/2/2025] Inference code and pre-trained models are released.
19 |
20 |
21 | ## Introduction
22 |
23 | We present RigFace, an efficient approach for editing the expression, pose and lighting of a given
24 | face image while keeping the identity and other attributes consistent.
25 |
26 |
27 |
28 |
29 |
30 |
31 | ## Installation
32 |
33 | To set up the environment for RigFace, run the following commands:
34 | ```
35 | conda create -n rigface python=3.9
36 | conda activate rigface
37 | pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116
38 | pip install -r requirements.txt
39 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath
40 | pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py39_cu116_pyt1130/download.html
41 | ```
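
Optionally, a quick sanity check that PyTorch and PyTorch3D are importable in the new environment:

```
python -c "import torch, pytorch3d; print(torch.__version__, torch.cuda.is_available())"
```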
42 |
43 | ## Download models and data
44 |
45 | - Download the utility files for pre-processing from [Huggingface](https://huggingface.co/mengtingwei/rigface/tree/main).
46 |   Put them in the `utils` directory.
47 | - Download the pre-trained weights of our model from the same page.
48 | - Make a new directory `data` inside `utils`.
49 | - Download the pre-trained DECA model `deca_model.tar` from [here](https://github.com/yfeng95/DECA).
50 |   Put it in `data`.
51 | - Download `generic_model.pkl` from [FLAME2020](https://flame.is.tue.mpg.de/download.php).
52 |   Also put it under `data`.
53 | - Download `FLAME_texture.npz` from [FLAME texture space](https://flame.is.tue.mpg.de/download.php).
54 |   Put it in `data`.
55 | - Download the other files from the [DECA page](https://github.com/yfeng95/DECA/tree/master/data). Put them all under `data`.
56 |
57 | ```
58 | ...
59 | pre_trained
60 | -unet_denoise
61 | ...
62 | -unet_id
63 | ...
64 | utils
65 | -checkpoints
66 | ...
67 | -third_party
68 | ...
69 | -third_party_files
70 | ...
71 | -data
72 | ...
73 | ...
74 | ```
75 |
76 | ## Test with our examples
77 |
78 |
79 |
80 | ```
81 | python inference_rigface.py --id_path utils/test_images/id1/sor.png --bg_path utils/test_images/id1/bg_pose+exp.png --exp_path utils/test_images/id1/exp_pose+exp.npy --render_path utils/test_images/id1/render_exp+pose.png --save_path ./res
82 | ```
83 |
84 | You will find the edited image under the `res` directory.
85 |
86 |
87 | ## Test your own data
88 |
89 | 1. First, ensure that both the source and target images have a resolution of 512x512:
90 |
91 | ```
92 | cd utils
93 | python preprocess.py --img_path <path/to/input/image> --save_path <path/to/resized/image>
94 | ```
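
For example, with a hypothetical input image `./my_images/face.jpg` (both paths are illustrative):

```
python preprocess.py --img_path ./my_images/face.jpg --save_path ./my_images/face_512.png
```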
95 |
96 | 2. Parse the background using the resized source and target images. We provide an example here:
97 |
98 | ```
99 | python make_bgs.py --sor_path ./test_images/id1/sor.png --tar_path ./test_images/id1/tar.png --modes pose+exp
100 | ```
101 |
102 | If you want to edit only one mode, provide just that mode here (for example, `--modes exp`).
103 |
104 | 3. Compute the expression coefficients. When editing lighting or pose,
105 | the coefficients are computed from the source image; when editing expression,
106 | they are computed from the target image.
107 |
108 | ```
109 | python save_exp_coeffs.py --sor_path ./test_images/id1/sor.png --tar_path ./test_images/id1/tar.png --mode pose+exp
110 | ```
111 |
112 | 4. Compute the rendering according to the edit modes.
113 |
114 | ```
115 | python compute_renders.py --sor_path ./test_images/id1/sor.png --tar_path ./test_images/id1/tar.png --mode pose+exp
116 | ```
117 |
118 | 5. Finally, run the inference with all the generated conditions (see the example below).
119 |
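This mirrors the example command above; replace the paths with the background, expression-coefficient and render files produced in steps 2-4 (the paths below point to the bundled `id1` example):

```
python inference_rigface.py --id_path utils/test_images/id1/sor.png --bg_path utils/test_images/id1/bg_pose+exp.png --exp_path utils/test_images/id1/exp_pose+exp.npy --render_path utils/test_images/id1/render_exp+pose.png --save_path ./res
```
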
120 | ## Acknowledgements
121 |
122 | This project is built on source codes shared by [DECA](https://github.com/yfeng95/DECA),
123 | [Deep3DRecon](https://github.com/sicxu/Deep3DFaceRecon_pytorch),
124 | [faceParsing](https://github.com/zllrunning/face-parsing.PyTorch) and
125 | [ControlNet](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py).
126 |
--------------------------------------------------------------------------------
/assets/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/assets/demo.png
--------------------------------------------------------------------------------
/decalib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/decalib/__init__.py
--------------------------------------------------------------------------------
/decalib/datasets/aflw2000.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import cv2
6 | import scipy
7 | from skimage.io import imread, imsave
8 | from skimage.transform import estimate_transform, warp, resize, rescale
9 | from glob import glob
10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset
11 | import scipy.io
12 |
13 | class AFLW2000(Dataset):
14 | def __init__(self, testpath='/ps/scratch/yfeng/Data/AFLW2000/GT', crop_size=224):
15 | '''
16 | data class for loading AFLW2000 dataset
17 |         make sure each image has a corresponding .mat file, which provides cropping information
18 | '''
19 | if os.path.isdir(testpath):
20 | self.imagepath_list = glob(testpath + '/*.jpg') + glob(testpath + '/*.png')
21 | elif isinstance(testpath, list):
22 | self.imagepath_list = testpath
23 | elif os.path.isfile(testpath) and (testpath[-3:] in ['jpg', 'png']):
24 | self.imagepath_list = [testpath]
25 | else:
26 | print('please check the input path')
27 | exit()
28 | print('total {} images'.format(len(self.imagepath_list)))
29 | self.imagepath_list = sorted(self.imagepath_list)
30 | self.crop_size = crop_size
31 | self.scale = 1.6
32 | self.resolution_inp = crop_size
33 |
34 | def __len__(self):
35 | return len(self.imagepath_list)
36 |
37 | def __getitem__(self, index):
38 | imagepath = self.imagepath_list[index]
39 | imagename = imagepath.split('/')[-1].split('.')[0]
40 | image = imread(imagepath)[:,:,:3]
41 | kpt = scipy.io.loadmat(imagepath.replace('jpg', 'mat'))['pt3d_68'].T
42 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]);
43 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])
44 |
45 | h, w, _ = image.shape
46 | old_size = (right - left + bottom - top)/2
47 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1])
48 | size = int(old_size*self.scale)
49 |
50 | # crop image
51 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
52 | DST_PTS = np.array([[0,0], [0,self.resolution_inp - 1], [self.resolution_inp - 1, 0]])
53 | tform = estimate_transform('similarity', src_pts, DST_PTS)
54 |
55 | image = image/255.
56 | dst_image = warp(image, tform.inverse, output_shape=(self.resolution_inp, self.resolution_inp))
57 | dst_image = dst_image.transpose(2,0,1)
58 | return {'image': torch.tensor(dst_image).float(),
59 | 'imagename': imagename,
60 | # 'tform': tform,
61 | # 'original_image': torch.tensor(image.transpose(2,0,1)).float(),
62 | }
--------------------------------------------------------------------------------
/decalib/datasets/build_datasets.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | from torch.utils.data import Dataset, ConcatDataset
4 | import torchvision.transforms as transforms
5 | import numpy as np
6 | import cv2
7 | import scipy
8 | from skimage.io import imread, imsave
9 | from skimage.transform import estimate_transform, warp, resize, rescale
10 | from glob import glob
11 |
12 | from .vggface import VGGFace2Dataset, VGGFace2HQDataset
13 | from .ethnicity import EthnicityDataset
14 | from .aflw2000 import AFLW2000
15 | from .now import NoWDataset
16 | from .vox import VoxelDataset
17 |
18 | def build_train(config, is_train=True):
19 | data_list = []
20 | if 'vox2' in config.training_data:
21 | data_list.append(VoxelDataset(dataname='vox2', K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle))
22 | if 'vggface2' in config.training_data:
23 | data_list.append(VGGFace2Dataset(K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle))
24 | if 'vggface2hq' in config.training_data:
25 | data_list.append(VGGFace2HQDataset(K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle))
26 | if 'ethnicity' in config.training_data:
27 | data_list.append(EthnicityDataset(K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle))
28 | if 'coco' in config.training_data:
29 | data_list.append(COCODataset(image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale))
30 | if 'celebahq' in config.training_data:
31 | data_list.append(CelebAHQDataset(image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale))
32 | dataset = ConcatDataset(data_list)
33 |
34 | return dataset
35 |
36 | def build_val(config, is_train=True):
37 | data_list = []
38 | if 'vggface2' in config.eval_data:
39 | data_list.append(VGGFace2Dataset(isEval=True, K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle))
40 | if 'now' in config.eval_data:
41 | data_list.append(NoWDataset())
42 | if 'aflw2000' in config.eval_data:
43 | data_list.append(AFLW2000())
44 | dataset = ConcatDataset(data_list)
45 |
46 | return dataset
47 |
--------------------------------------------------------------------------------
/decalib/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import os, sys
17 | import torch
18 | from torch.utils.data import Dataset, DataLoader
19 | import torchvision.transforms as transforms
20 | import numpy as np
21 | import cv2
22 | import scipy
23 | from skimage.io import imread, imsave
24 | from skimage.transform import estimate_transform, warp, resize, rescale
25 | from glob import glob
26 | import scipy.io
27 | from decalib.datasets import datasets
28 | from torchvision.utils import save_image
29 |
30 | from . import detectors
31 |
32 | def video2sequence(video_path, sample_step=10):
33 | videofolder = os.path.splitext(video_path)[0]
34 | os.makedirs(videofolder, exist_ok=True)
35 | video_name = os.path.splitext(os.path.split(video_path)[-1])[0]
36 | vidcap = cv2.VideoCapture(video_path)
37 | success,image = vidcap.read()
38 | count = 0
39 | imagepath_list = []
40 | while success:
41 | # if count%sample_step == 0:
42 | imagepath = os.path.join(videofolder, f'{video_name}_frame{count:04d}.jpg')
43 | cv2.imwrite(imagepath, image) # save frame as JPEG file
44 | success,image = vidcap.read()
45 | count += 1
46 | imagepath_list.append(imagepath)
47 | print('video frames are stored in {}'.format(videofolder))
48 | return imagepath_list
49 |
50 | class TestData(Dataset):
51 |     # called with iscrop=True, size=256, sort=True
52 | def __init__(self, testpath, iscrop=True, crop_size=224, scale=1.25, face_detector='fan',
53 | sample_step=10, size=256, sort=False):
54 |
55 | if isinstance(testpath, list):
56 | self.imagepath_list = testpath
57 | elif os.path.isdir(testpath):
58 | self.imagepath_list = glob(testpath + '/*.jpg') + glob(testpath + '/*.png') + glob(testpath + '/*.bmp')
59 | elif os.path.isfile(testpath) and (testpath[-3:] in ['jpg', 'png', 'bmp']):
60 | self.imagepath_list = [testpath]
61 | elif os.path.isfile(testpath) and (testpath[-3:] in ['mp4', 'csv', 'vid', 'ebm']):
62 | self.imagepath_list = video2sequence(testpath, sample_step)
63 |
64 | if sort:
65 | self.imagepath_list = sorted(self.imagepath_list)
66 | self.crop_size = crop_size
67 | self.scale = scale
68 | self.iscrop = iscrop
69 | self.resolution_inp = crop_size
70 | self.size = size
71 |         # use the face_alignment landmark detector
72 | if face_detector == 'fan':
73 | self.face_detector = detectors.FAN()
74 | else:
75 | print(f'please check the detector: {face_detector}')
76 | exit()
77 |
78 | def __len__(self):
79 | return len(self.imagepath_list)
80 |
81 | def bbox2point(self, left, right, top, bottom, type='bbox'):
82 |
83 | if type=='kpt68':
84 | old_size = (right - left + bottom - top) / 2 * 1.1
85 |             # center of the face
86 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])
87 | elif type=='bbox':
88 | old_size = (right - left + bottom - top)/2
89 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size*0.12])
90 | else:
91 | raise NotImplementedError
92 | return old_size, center
93 |
94 | def get_image(self, image):
95 | h, w, _ = image.shape
96 | bbox, bbox_type = self.face_detector.run(image)
97 | if len(bbox) < 4:
98 | print('no face detected! run original image')
99 | left = 0; right = h-1; top=0; bottom=w-1
100 | else:
101 | left = bbox[0]; right=bbox[2]
102 | top = bbox[1]; bottom=bbox[3]
103 | old_size, center = self.bbox2point(left, right, top, bottom, type=bbox_type)
104 | size = int(old_size*self.scale)
105 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
106 |
107 | DST_PTS = np.array([[0,0], [0,self.resolution_inp - 1], [self.resolution_inp - 1, 0]])
108 | tform = estimate_transform('similarity', src_pts, DST_PTS)
109 |
110 | image = image / 255.
111 |
112 | dst_image = warp(image, tform.inverse, output_shape=(self.resolution_inp, self.resolution_inp))
113 | dst_image = dst_image.transpose(2,0,1)
114 | return {'image': torch.tensor(dst_image).float(),
115 | 'tform': torch.tensor(tform.params).float(),
116 | 'original_image': torch.tensor(image.transpose(2,0,1)).float(),
117 | }
118 |
119 |
120 | def __getitem__(self, index):
121 |
122 | imagepath = self.imagepath_list[index]
123 | imagename = os.path.splitext(os.path.split(imagepath)[-1])[0]
124 | im = imread(imagepath)
125 |
126 | if self.size is not None: # size = 256
127 | im = (resize(im, (self.size, self.size), anti_aliasing=True) * 255.).astype(np.uint8)
128 |
129 | # (256, 256, 3)
130 | image = np.array(im)
131 |
132 | if len(image.shape) == 2:
133 | image = image[:, :, None].repeat(1,1,3)
134 | if len(image.shape) == 3 and image.shape[2] > 3:
135 | image = image[:, :, :3]
136 |
137 | h, w, _ = image.shape
138 | if self.iscrop: # true
139 | # provide kpt as txt file, or mat file (for AFLW2000)
140 |             # check whether a landmark file exists; if not, run the face detector
141 | kpt_matpath = os.path.splitext(imagepath)[0]+'.mat'
142 | kpt_txtpath = os.path.splitext(imagepath)[0]+'.txt'
143 | if os.path.exists(kpt_matpath):
144 | kpt = scipy.io.loadmat(kpt_matpath)['pt3d_68'].T
145 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0])
146 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])
147 | old_size, center = self.bbox2point(left, right, top, bottom, type='kpt68')
148 | elif os.path.exists(kpt_txtpath):
149 | kpt = np.loadtxt(kpt_txtpath)
150 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0])
151 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])
152 | old_size, center = self.bbox2point(left, right, top, bottom, type='kpt68')
153 | else:
154 | bbox, bbox_type = self.face_detector.run(image)
155 | if len(bbox) < 4:
156 | print('no face detected! run original image')
157 | left = 0; right = h-1; top=0; bottom=w-1
158 | else:
159 | left = bbox[0]; right=bbox[2]
160 | top = bbox[1]; bottom=bbox[3]
161 | old_size, center = self.bbox2point(left, right, top, bottom, type=bbox_type)
162 | size = int(old_size * self.scale)
163 | src_pts = np.array([[center[0] - size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
164 | else:
165 | src_pts = np.array([[0, 0], [0, h-1], [w-1, 0]])
166 | # DST_PTS = np.array([[0, 0], [0, h-1], [w-1, 0]])
167 |         # self.resolution_inp = 224, the target image size
168 | DST_PTS = np.array([[0, 0], [0, self.resolution_inp - 1], [self.resolution_inp - 1, 0]])
169 |         # estimate the similarity transform that maps the source points to the target image
170 | tform = estimate_transform('similarity', src_pts, DST_PTS)
171 |
172 |         image = image / 255.  # scale to [0, 1]
173 |
174 | dst_image = warp(image, tform.inverse, output_shape=(self.resolution_inp, self.resolution_inp))
175 | dst_image = dst_image.transpose(2,0,1)
176 |
177 |         return {'image': torch.tensor(dst_image).float(),  # only the cropped face region is kept
178 | 'imagename': imagename,
179 | 'tform': torch.tensor(tform.params).float(),
180 | 'original_image': torch.tensor(image.transpose(2, 0, 1)).float(),
181 | }
182 |
183 |
184 |
185 |
186 |
187 | if __name__ == '__main__':
188 |     # 'source' is a placeholder path: a directory of images or a single image file
189 |     source = './test_images'
190 |     testdata_source = datasets.TestData(source, iscrop=True, size=512, sort=True)
--------------------------------------------------------------------------------
/decalib/datasets/detectors.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import numpy as np
17 | import torch
18 |
19 | class FAN(object):
20 | def __init__(self):
21 | import face_alignment
22 | self.model = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)
23 |
24 | def run(self, image):
25 | '''
26 | image: 0-255, uint8, rgb, [h, w, 3]
27 | return: detected box list
28 | '''
29 | out = self.model.get_landmarks(image)
30 | if out is None:
31 | return [0], 'kpt68'
32 | else:
33 | kpt = out[0].squeeze()
34 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0])
35 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])
36 | bbox = [left, top, right, bottom]
37 | return bbox, 'kpt68'
38 |
39 | class MTCNN(object):
40 | def __init__(self, device = 'cpu'):
41 | '''
42 | https://github.com/timesler/facenet-pytorch/blob/master/examples/infer.ipynb
43 | '''
44 | from facenet_pytorch import MTCNN as mtcnn
45 | self.device = device
46 | self.model = mtcnn(keep_all=True)
47 | def run(self, input):
48 | '''
49 | image: 0-255, uint8, rgb, [h, w, 3]
50 | return: detected box
51 | '''
52 | out = self.model.detect(input[None,...])
53 | if out[0][0] is None:
54 | return [0]
55 | else:
56 | bbox = out[0][0].squeeze()
57 | return bbox, 'bbox'
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/decalib/datasets/ethnicity.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import cv2
6 | import scipy
7 | from skimage.io import imread, imsave
8 | from skimage.transform import estimate_transform, warp, resize, rescale
9 | from glob import glob
10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset
11 |
12 | class EthnicityDataset(Dataset):
13 | def __init__(self, K, image_size, scale, trans_scale = 0, isTemporal=False, isEval=False, isSingle=False):
14 | '''
15 | K must be less than 6
16 | '''
17 | self.K = K
18 | self.image_size = image_size
19 | self.imagefolder = '/ps/scratch/face2d3d/train'
20 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7/'
21 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch/'
22 | # hq:
23 | # datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy'
24 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_and_race_per_7000_african_asian_2d_train_list_max_normal_100_ring_5_1_serial.npy'
25 | self.data_lines = np.load(datafile).astype('str')
26 |
27 | self.isTemporal = isTemporal
28 | self.scale = scale #[scale_min, scale_max]
29 | self.trans_scale = trans_scale #[dx, dy]
30 | self.isSingle = isSingle
31 | if isSingle:
32 | self.K = 1
33 |
34 | def __len__(self):
35 | return len(self.data_lines)
36 |
37 | def __getitem__(self, idx):
38 | images_list = []; kpt_list = []; mask_list = []
39 | for i in range(self.K):
40 | name = self.data_lines[idx, i]
41 | if name[0]=='n':
42 | self.imagefolder = '/ps/scratch/face2d3d/train/'
43 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7/'
44 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch/'
45 | elif name[0]=='A':
46 | self.imagefolder = '/ps/scratch/face2d3d/race_per_7000/'
47 | self.kptfolder = '/ps/scratch/face2d3d/race_per_7000_annotated_torch7_new/'
48 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/race7000_seg/test_crop_size_400_batch/'
49 |
50 | image_path = os.path.join(self.imagefolder, name + '.jpg')
51 | seg_path = os.path.join(self.segfolder, name + '.npy')
52 | kpt_path = os.path.join(self.kptfolder, name + '.npy')
53 |
54 | image = imread(image_path)/255.
55 | kpt = np.load(kpt_path)[:,:2]
56 | mask = self.load_mask(seg_path, image.shape[0], image.shape[1])
57 |
58 | ### crop information
59 | tform = self.crop(image, kpt)
60 | ## crop
61 | cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size))
62 | cropped_mask = warp(mask, tform.inverse, output_shape=(self.image_size, self.image_size))
63 | cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params)
64 |
65 | # normalized kpt
66 | cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2 - 1
67 |
68 | images_list.append(cropped_image.transpose(2,0,1))
69 | kpt_list.append(cropped_kpt)
70 | mask_list.append(cropped_mask)
71 |
72 | ###
73 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32) #K,224,224,3
74 | kpt_array = torch.from_numpy(np.array(kpt_list)).type(dtype = torch.float32) #K,224,224,3
75 | mask_array = torch.from_numpy(np.array(mask_list)).type(dtype = torch.float32) #K,224,224,3
76 |
77 | if self.isSingle:
78 | images_array = images_array.squeeze()
79 | kpt_array = kpt_array.squeeze()
80 | mask_array = mask_array.squeeze()
81 |
82 | data_dict = {
83 | 'image': images_array,
84 | 'landmark': kpt_array,
85 | 'mask': mask_array
86 | }
87 |
88 | return data_dict
89 |
90 | def crop(self, image, kpt):
91 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]);
92 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])
93 |
94 | h, w, _ = image.shape
95 | old_size = (right - left + bottom - top)/2
96 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1])
97 | # translate center
98 | trans_scale = (np.random.rand(2)*2 -1) * self.trans_scale
99 | center = center + trans_scale*old_size # 0.5
100 |
101 | scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0]
102 | size = int(old_size*scale)
103 |
104 | # crop image
105 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
106 | DST_PTS = np.array([[0,0], [0,self.image_size - 1], [self.image_size - 1, 0]])
107 | tform = estimate_transform('similarity', src_pts, DST_PTS)
108 |
109 | # cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size))
110 | # # change kpt accordingly
111 | # cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params)
112 | return tform
113 |
114 | def load_mask(self, maskpath, h, w):
115 | # print(maskpath)
116 | if os.path.isfile(maskpath):
117 | vis_parsing_anno = np.load(maskpath)
118 | # atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
119 | # 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
120 | mask = np.zeros_like(vis_parsing_anno)
121 | # for i in range(1, 16):
122 | mask[vis_parsing_anno>0.5] = 1.
123 | else:
124 | mask = np.ones((h, w))
125 | return mask
126 |
127 |
--------------------------------------------------------------------------------
/decalib/datasets/now.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import cv2
6 | import scipy
7 | from skimage.io import imread, imsave
8 | from skimage.transform import estimate_transform, warp, resize, rescale
9 | from glob import glob
10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset
11 |
12 | class NoWDataset(Dataset):
13 | def __init__(self, ring_elements=6, crop_size=224, scale=1.6):
14 | folder = '/ps/scratch/yfeng/other-github/now_evaluation/data/NoW_Dataset'
15 | self.data_path = os.path.join(folder, 'imagepathsvalidation.txt')
16 | with open(self.data_path) as f:
17 | self.data_lines = f.readlines()
18 |
19 | self.imagefolder = os.path.join(folder, 'final_release_version', 'iphone_pictures')
20 | self.bbxfolder = os.path.join(folder, 'final_release_version', 'detected_face')
21 |
22 | # self.data_path = '/ps/scratch/face2d3d/ringnetpp/eccv/test_data/evaluation/NoW_Dataset/final_release_version/test_image_paths_ring_6_elements.npy'
23 | # self.imagepath = '/ps/scratch/face2d3d/ringnetpp/eccv/test_data/evaluation/NoW_Dataset/final_release_version/iphone_pictures/'
24 | # self.bbxpath = '/ps/scratch/face2d3d/ringnetpp/eccv/test_data/evaluation/NoW_Dataset/final_release_version/detected_face/'
25 | self.crop_size = crop_size
26 | self.scale = scale
27 |
28 | def __len__(self):
29 | return len(self.data_lines)
30 |
31 | def __getitem__(self, index):
32 | imagepath = os.path.join(self.imagefolder, self.data_lines[index].strip()) #+ '.jpg'
33 | bbx_path = os.path.join(self.bbxfolder, self.data_lines[index].strip().replace('.jpg', '.npy'))
34 | bbx_data = np.load(bbx_path, allow_pickle=True, encoding='latin1').item()
35 | # box = np.array([[bbx_data['left'], bbx_data['top']], [bbx_data['right'], bbx_data['bottom']]]).astype('float32')
36 | left = bbx_data['left']; right = bbx_data['right']
37 | top = bbx_data['top']; bottom = bbx_data['bottom']
38 |
39 | imagename = imagepath.split('/')[-1].split('.')[0]
40 | image = imread(imagepath)[:,:,:3]
41 |
42 | h, w, _ = image.shape
43 | old_size = (right - left + bottom - top)/2
44 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])
45 | size = int(old_size*self.scale)
46 |
47 | # crop image
48 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
49 | DST_PTS = np.array([[0,0], [0,self.crop_size - 1], [self.crop_size - 1, 0]])
50 | tform = estimate_transform('similarity', src_pts, DST_PTS)
51 |
52 | image = image/255.
53 | dst_image = warp(image, tform.inverse, output_shape=(self.crop_size, self.crop_size))
54 | dst_image = dst_image.transpose(2,0,1)
55 | return {'image': torch.tensor(dst_image).float(),
56 | 'imagename': self.data_lines[index].strip().replace('.jpg', ''),
57 | # 'tform': tform,
58 | # 'original_image': torch.tensor(image.transpose(2,0,1)).float(),
59 | }
--------------------------------------------------------------------------------
/decalib/datasets/vggface.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import cv2
6 | import scipy
7 | from skimage.io import imread, imsave
8 | from skimage.transform import estimate_transform, warp, resize, rescale
9 | from glob import glob
10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset
11 |
12 | class VGGFace2Dataset(Dataset):
13 | def __init__(self, K, image_size, scale, trans_scale = 0, isTemporal=False, isEval=False, isSingle=False):
14 | '''
15 | K must be less than 6
16 | '''
17 | self.K = K
18 | self.image_size = image_size
19 | self.imagefolder = '/ps/scratch/face2d3d/train'
20 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7'
21 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch'
22 | # hq:
23 | # datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy'
24 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_train_list_max_normal_100_ring_5_1_serial.npy'
25 | if isEval:
26 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_val_list_max_normal_100_ring_5_1_serial.npy'
27 | self.data_lines = np.load(datafile).astype('str')
28 |
29 | self.isTemporal = isTemporal
30 | self.scale = scale #[scale_min, scale_max]
31 | self.trans_scale = trans_scale #[dx, dy]
32 | self.isSingle = isSingle
33 | if isSingle:
34 | self.K = 1
35 |
36 | def __len__(self):
37 | return len(self.data_lines)
38 |
39 | def __getitem__(self, idx):
40 | images_list = []; kpt_list = []; mask_list = []
41 |
42 | random_ind = np.random.permutation(5)[:self.K]
43 | for i in random_ind:
44 | name = self.data_lines[idx, i]
45 | image_path = os.path.join(self.imagefolder, name + '.jpg')
46 | seg_path = os.path.join(self.segfolder, name + '.npy')
47 | kpt_path = os.path.join(self.kptfolder, name + '.npy')
48 |
49 | image = imread(image_path)/255.
50 | kpt = np.load(kpt_path)[:,:2]
51 | mask = self.load_mask(seg_path, image.shape[0], image.shape[1])
52 |
53 | ### crop information
54 | tform = self.crop(image, kpt)
55 | ## crop
56 | cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size))
57 | cropped_mask = warp(mask, tform.inverse, output_shape=(self.image_size, self.image_size))
58 | cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params)
59 |
60 | # normalized kpt
61 | cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2 - 1
62 |
63 | images_list.append(cropped_image.transpose(2,0,1))
64 | kpt_list.append(cropped_kpt)
65 | mask_list.append(cropped_mask)
66 |
67 | ###
68 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32) #K,224,224,3
69 | kpt_array = torch.from_numpy(np.array(kpt_list)).type(dtype = torch.float32) #K,224,224,3
70 | mask_array = torch.from_numpy(np.array(mask_list)).type(dtype = torch.float32) #K,224,224,3
71 |
72 | if self.isSingle:
73 | images_array = images_array.squeeze()
74 | kpt_array = kpt_array.squeeze()
75 | mask_array = mask_array.squeeze()
76 |
77 | data_dict = {
78 | 'image': images_array,
79 | 'landmark': kpt_array,
80 | 'mask': mask_array
81 | }
82 |
83 | return data_dict
84 |
85 | def crop(self, image, kpt):
86 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]);
87 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])
88 |
89 | h, w, _ = image.shape
90 | old_size = (right - left + bottom - top)/2
91 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1])
92 | # translate center
93 | trans_scale = (np.random.rand(2)*2 -1) * self.trans_scale
94 | center = center + trans_scale*old_size # 0.5
95 |
96 | scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0]
97 | size = int(old_size*scale)
98 |
99 | # crop image
100 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
101 | DST_PTS = np.array([[0,0], [0,self.image_size - 1], [self.image_size - 1, 0]])
102 | tform = estimate_transform('similarity', src_pts, DST_PTS)
103 |
104 | # cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size))
105 | # # change kpt accordingly
106 | # cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params)
107 | return tform
108 |
109 | def load_mask(self, maskpath, h, w):
110 | # print(maskpath)
111 | if os.path.isfile(maskpath):
112 | vis_parsing_anno = np.load(maskpath)
113 | # atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
114 | # 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
115 | mask = np.zeros_like(vis_parsing_anno)
116 | # for i in range(1, 16):
117 | mask[vis_parsing_anno>0.5] = 1.
118 | else:
119 | mask = np.ones((h, w))
120 | return mask
121 |
122 |
123 |
124 | class VGGFace2HQDataset(Dataset):
125 | def __init__(self, K, image_size, scale, trans_scale = 0, isTemporal=False, isEval=False, isSingle=False):
126 | '''
127 | K must be less than 6
128 | '''
129 | self.K = K
130 | self.image_size = image_size
131 | self.imagefolder = '/ps/scratch/face2d3d/train'
132 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7'
133 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch'
134 | # hq:
135 | # datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy'
136 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy'
137 | self.data_lines = np.load(datafile).astype('str')
138 |
139 | self.isTemporal = isTemporal
140 | self.scale = scale #[scale_min, scale_max]
141 | self.trans_scale = trans_scale #[dx, dy]
142 | self.isSingle = isSingle
143 | if isSingle:
144 | self.K = 1
145 |
146 | def __len__(self):
147 | return len(self.data_lines)
148 |
149 | def __getitem__(self, idx):
150 | images_list = []; kpt_list = []; mask_list = []
151 |
152 | for i in range(self.K):
153 | name = self.data_lines[idx, i]
154 | image_path = os.path.join(self.imagefolder, name + '.jpg')
155 | seg_path = os.path.join(self.segfolder, name + '.npy')
156 | kpt_path = os.path.join(self.kptfolder, name + '.npy')
157 |
158 | image = imread(image_path)/255.
159 | kpt = np.load(kpt_path)[:,:2]
160 | mask = self.load_mask(seg_path, image.shape[0], image.shape[1])
161 |
162 | ### crop information
163 | tform = self.crop(image, kpt)
164 | ## crop
165 | cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size))
166 | cropped_mask = warp(mask, tform.inverse, output_shape=(self.image_size, self.image_size))
167 | cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params)
168 |
169 | # normalized kpt
170 | cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2 - 1
171 |
172 | images_list.append(cropped_image.transpose(2,0,1))
173 | kpt_list.append(cropped_kpt)
174 | mask_list.append(cropped_mask)
175 |
176 | ###
177 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32) #K,224,224,3
178 | kpt_array = torch.from_numpy(np.array(kpt_list)).type(dtype = torch.float32) #K,224,224,3
179 | mask_array = torch.from_numpy(np.array(mask_list)).type(dtype = torch.float32) #K,224,224,3
180 |
181 | if self.isSingle:
182 | images_array = images_array.squeeze()
183 | kpt_array = kpt_array.squeeze()
184 | mask_array = mask_array.squeeze()
185 |
186 | data_dict = {
187 | 'image': images_array,
188 | 'landmark': kpt_array,
189 | 'mask': mask_array
190 | }
191 |
192 | return data_dict
193 |
194 | def crop(self, image, kpt):
195 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]);
196 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])
197 |
198 | h, w, _ = image.shape
199 | old_size = (right - left + bottom - top)/2
200 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1])
201 | # translate center
202 | trans_scale = (np.random.rand(2)*2 -1) * self.trans_scale
203 | center = center + trans_scale*old_size # 0.5
204 |
205 | scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0]
206 | size = int(old_size*scale)
207 |
208 | # crop image
209 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
210 | DST_PTS = np.array([[0,0], [0,self.image_size - 1], [self.image_size - 1, 0]])
211 | tform = estimate_transform('similarity', src_pts, DST_PTS)
212 |
213 | # cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size))
214 | # # change kpt accordingly
215 | # cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params)
216 | return tform
217 |
218 | def load_mask(self, maskpath, h, w):
219 | # print(maskpath)
220 | if os.path.isfile(maskpath):
221 | vis_parsing_anno = np.load(maskpath)
222 | # atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
223 | # 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
224 | mask = np.zeros_like(vis_parsing_anno)
225 | # for i in range(1, 16):
226 | mask[vis_parsing_anno>0.5] = 1.
227 | else:
228 | mask = np.ones((h, w))
229 | return mask
--------------------------------------------------------------------------------
/decalib/datasets/vox.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import cv2
6 | import scipy
7 | from skimage.io import imread, imsave
8 | from skimage.transform import estimate_transform, warp, resize, rescale
9 | from glob import glob
10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset
11 |
12 | class VoxelDataset(Dataset):
13 | def __init__(self, K, image_size, scale, trans_scale = 0, dataname='vox2', n_train=100000, isTemporal=False, isEval=False, isSingle=False):
14 | self.K = K
15 | self.image_size = image_size
16 | if dataname == 'vox1':
17 | self.kpt_suffix = '.txt'
18 | self.imagefolder = '/ps/project/face2d3d/VoxCeleb/vox1/dev/images_cropped'
19 | self.kptfolder = '/ps/scratch/yfeng/Data/VoxCeleb/vox1/landmark_2d'
20 |
21 | self.face_dict = {}
22 | for person_id in sorted(os.listdir(self.kptfolder)):
23 | for video_id in os.listdir(os.path.join(self.kptfolder, person_id)):
24 | for face_id in os.listdir(os.path.join(self.kptfolder, person_id, video_id)):
25 | if 'txt' in face_id:
26 | continue
27 | key = person_id + '/' + video_id + '/' + face_id
28 | # if key not in self.face_dict.keys():
29 | # self.face_dict[key] = []
30 | name_list = os.listdir(os.path.join(self.kptfolder, person_id, video_id, face_id))
31 | name_list = [name.split('.')[0] for name in name_list]
32 | if len(name_list)0.5] = 1.
162 | else:
163 | mask = np.ones((h, w))
164 | return mask
165 |
--------------------------------------------------------------------------------
/decalib/deca.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import os, sys
17 | import torch
18 | import torchvision
19 | import torch.nn.functional as F
20 | import torch.nn as nn
21 |
22 | import numpy as np
23 | from time import time
24 | from skimage.io import imread
25 | import cv2
26 | import pickle
27 | from .utils.renderer import SRenderY, set_rasterizer
28 | from .models.encoders import ResnetEncoder
29 | from .models.FLAME import FLAME, FLAMETex
30 | from .models.decoders import Generator
31 | from .utils import util
32 | from .utils.rotation_converter import batch_euler2axis
33 | from .utils.tensor_cropper import transform_points
34 | from .datasets import datasets
35 | from .utils.config import cfg
36 | torch.backends.cudnn.benchmark = True
37 |
38 | class DECA(nn.Module):
39 | def __init__(self, config=None, device='cuda'):
40 | super(DECA, self).__init__()
41 | if config is None:
42 | self.cfg = cfg
43 | else:
44 | self.cfg = config
45 | self.device = device
46 | self.image_size = self.cfg.dataset.image_size
47 | self.uv_size = self.cfg.model.uv_size
48 |
49 | self._create_model(self.cfg.model)
50 | self._setup_renderer(self.cfg.model)
51 |
52 | def _setup_renderer(self, model_cfg):
53 | set_rasterizer(self.cfg.rasterizer_type)
54 | self.render = SRenderY(self.image_size, obj_filename=model_cfg.topology_path, uv_size=model_cfg.uv_size, rasterizer_type=self.cfg.rasterizer_type).to(self.device)
55 | # face mask for rendering details
56 | mask = imread(model_cfg.face_eye_mask_path).astype(np.float32)/255.; mask = torch.from_numpy(mask[:,:,0])[None,None,:,:].contiguous()
57 | self.uv_face_eye_mask = F.interpolate(mask, [model_cfg.uv_size, model_cfg.uv_size]).to(self.device)
58 | mask = imread(model_cfg.face_mask_path).astype(np.float32)/255.; mask = torch.from_numpy(mask[:,:,0])[None,None,:,:].contiguous()
59 | self.uv_face_mask = F.interpolate(mask, [model_cfg.uv_size, model_cfg.uv_size]).to(self.device)
60 | # displacement correction
61 | fixed_dis = np.load(model_cfg.fixed_displacement_path)
62 | self.fixed_uv_dis = torch.tensor(fixed_dis).float().to(self.device)
63 | # mean texture
64 | mean_texture = imread(model_cfg.mean_tex_path).astype(np.float32)/255.; mean_texture = torch.from_numpy(mean_texture.transpose(2,0,1))[None,:,:,:].contiguous()
65 | self.mean_texture = F.interpolate(mean_texture, [model_cfg.uv_size, model_cfg.uv_size]).to(self.device)
66 | # dense mesh template, used when saving the detail mesh
67 | self.dense_template = np.load(model_cfg.dense_template_path, allow_pickle=True, encoding='latin1').item()
68 |
69 | def _create_model(self, model_cfg):
70 | # set up parameters
71 | self.n_param = model_cfg.n_shape+model_cfg.n_tex+model_cfg.n_exp+model_cfg.n_pose+model_cfg.n_cam+model_cfg.n_light
72 | self.n_detail = model_cfg.n_detail
73 | self.n_cond = model_cfg.n_exp + 3 # exp + jaw pose
74 | self.num_list = [model_cfg.n_shape, model_cfg.n_tex, model_cfg.n_exp, model_cfg.n_pose, model_cfg.n_cam, model_cfg.n_light]
75 | self.param_dict = {i:model_cfg.get('n_' + i) for i in model_cfg.param_list}
76 |
77 | # encoders
78 | self.E_flame = ResnetEncoder(outsize=self.n_param).to(self.device)
79 | self.E_detail = ResnetEncoder(outsize=self.n_detail).to(self.device)
80 | # decoders
81 | self.flame = FLAME(model_cfg).to(self.device)
82 | if model_cfg.use_tex:
83 | self.flametex = FLAMETex(model_cfg).to(self.device)
84 | self.D_detail = Generator(latent_dim=self.n_detail+self.n_cond, out_channels=1, out_scale=model_cfg.max_z, sample_mode = 'bilinear').to(self.device)
85 | # resume model
86 | model_path = self.cfg.pretrained_modelpath
87 | if os.path.exists(model_path):
88 | print(f'trained model found. load {model_path}')
89 | checkpoint = torch.load(model_path)
90 | self.checkpoint = checkpoint
91 | util.copy_state_dict(self.E_flame.state_dict(), checkpoint['E_flame'])
92 | util.copy_state_dict(self.E_detail.state_dict(), checkpoint['E_detail'])
93 | util.copy_state_dict(self.D_detail.state_dict(), checkpoint['D_detail'])
94 | else:
95 | print(f'please check model path: {model_path}')
96 | # eval mode
97 | self.E_flame.eval()
98 | self.E_detail.eval()
99 | self.D_detail.eval()
100 |
101 | def decompose_code(self, code, num_dict):
102 | ''' Convert a flattened parameter vector to a dictionary of parameters
103 | code_dict.keys() = ['shape', 'tex', 'exp', 'pose', 'cam', 'light']
104 | '''
105 | code_dict = {}
106 | start = 0
107 | for key in num_dict:
108 | end = start+int(num_dict[key])
109 | code_dict[key] = code[:, start:end]
110 | start = end
111 | if key == 'light':
112 | code_dict[key] = code_dict[key].reshape(code_dict[key].shape[0], 9, 3)
113 | return code_dict
114 |
115 | def displacement2normal(self, uv_z, coarse_verts, coarse_normals):
116 | ''' Convert displacement map into detail normal map
117 | '''
118 | batch_size = uv_z.shape[0]
119 | uv_coarse_vertices = self.render.world2uv(coarse_verts).detach()
120 | uv_coarse_normals = self.render.world2uv(coarse_normals).detach()
121 |
122 | uv_z = uv_z*self.uv_face_eye_mask
123 | uv_detail_vertices = uv_coarse_vertices + uv_z*uv_coarse_normals + self.fixed_uv_dis[None,None,:,:]*uv_coarse_normals.detach()
124 | dense_vertices = uv_detail_vertices.permute(0,2,3,1).reshape([batch_size, -1, 3])
125 | uv_detail_normals = util.vertex_normals(dense_vertices, self.render.dense_faces.expand(batch_size, -1, -1))
126 | uv_detail_normals = uv_detail_normals.reshape([batch_size, uv_coarse_vertices.shape[2], uv_coarse_vertices.shape[3], 3]).permute(0,3,1,2)
127 | uv_detail_normals = uv_detail_normals*self.uv_face_eye_mask + uv_coarse_normals*(1.-self.uv_face_eye_mask)
128 | return uv_detail_normals
129 |
130 | def visofp(self, normals):
131 | ''' visibility of keypoints, based on the normal direction
132 | '''
133 | normals68 = self.flame.seletec_3d68(normals)
134 | vis68 = (normals68[:,:,2:] < 0.1).float()
135 | return vis68
136 |
137 | # @torch.no_grad()
138 | def encode(self, images, use_detail=True):
139 | if use_detail:
140 | # use_detail means the detail model is being trained, so keep the coarse model frozen (no gradients)
141 | with torch.no_grad():
142 | parameters = self.E_flame(images)
143 | else:
144 | parameters = self.E_flame(images)
145 | codedict = self.decompose_code(parameters, self.param_dict)
146 | codedict['images'] = images
147 | if use_detail:
148 | detailcode = self.E_detail(images)
149 | codedict['detail'] = detailcode
150 | if self.cfg.model.jaw_type == 'euler':
151 | posecode = codedict['pose']
152 | euler_jaw_pose = posecode[:,3:].clone() # x for yaw (open mouth), y for pitch (left and right), z for roll
153 | posecode[:,3:] = batch_euler2axis(euler_jaw_pose)
154 | codedict['pose'] = posecode
155 | codedict['euler_jaw_pose'] = euler_jaw_pose
156 | return codedict
157 |
158 | # @torch.no_grad()
159 | def decode(self, codedict, rendering=True, iddict=None, vis_lmk=True, return_vis=True, use_detail=True,
160 | render_orig=False, original_image=None, tform=None, add_light=True, th=0,
161 | align_ffhq=False, return_ffhq_center=False, ffhq_center=None,
162 | light_type='point', render_norm=False):
163 | # these are still the cropped input images
164 | images = codedict['images']
165 | batch_size = images.shape[0]
166 |
167 | ## decode
168 | verts, landmarks2d, landmarks3d = self.flame(shape_params=codedict['shape'],
169 | expression_params=codedict['exp'],
170 | pose_params=codedict['pose'])
171 |
172 | if (align_ffhq and ffhq_center is not None) or return_ffhq_center:
173 | lm_eye_left = landmarks2d[:, 36:42] # left-clockwise
174 | lm_eye_right = landmarks2d[:, 42:48] # left-clockwise
175 | lm_mouth_outer = landmarks2d[:, 48:60] # left-clockwise
176 |
177 | eye_left = torch.mean(lm_eye_left, dim=1)
178 | eye_right = torch.mean(lm_eye_right, dim=1)
179 | eye_avg = (eye_left + eye_right) * 0.5
180 | # eye_to_eye = eye_right - eye_left
181 | mouth_left = lm_mouth_outer[:, 0]
182 | mouth_right = lm_mouth_outer[:, 6]
183 | mouth_avg = (mouth_left + mouth_right) * 0.5
184 | eye_to_mouth = mouth_avg - eye_avg
185 |
186 | center = eye_avg + eye_to_mouth * 0.1
187 |
188 | if return_ffhq_center:
189 | return center
190 |
191 | if align_ffhq:
192 | delta = ffhq_center - center
193 | verts = verts + delta
194 |
195 | if self.cfg.model.use_tex:
196 | albedo = self.flametex(codedict['tex'])
197 | else:
198 | albedo = torch.zeros([batch_size, 3, self.uv_size, self.uv_size], device=images.device)
199 | landmarks3d_world = landmarks3d.clone()
200 |
201 | ## projection
202 | landmarks2d = util.batch_orth_proj(landmarks2d, codedict['cam'])[:,:,:2]; landmarks2d[:,:,1:] = -landmarks2d[:,:,1:]#; landmarks2d = landmarks2d*self.image_size/2 + self.image_size/2
203 | landmarks3d = util.batch_orth_proj(landmarks3d, codedict['cam']); landmarks3d[:,:,1:] = -landmarks3d[:,:,1:] #; landmarks3d = landmarks3d*self.image_size/2 + self.image_size/2
204 |
205 | trans_verts = util.batch_orth_proj(verts, codedict['cam']); trans_verts[:,:,1:] = -trans_verts[:,:,1:]
206 | opdict = {
207 | 'verts': verts,
208 | 'trans_verts': trans_verts,
209 | 'landmarks2d': landmarks2d,
210 | 'landmarks3d': landmarks3d,
211 | 'landmarks3d_world': landmarks3d_world,
212 | }
213 |
214 | ## rendering
215 | if return_vis and render_orig and original_image is not None and tform is not None:
216 | points_scale = [self.image_size, self.image_size]
217 | _, _, h, w = original_image.shape
218 | # import ipdb; ipdb.set_trace()
219 | trans_verts = transform_points(trans_verts, tform, points_scale, [h, w])
220 | landmarks2d = transform_points(landmarks2d, tform, points_scale, [h, w])
221 | landmarks3d = transform_points(landmarks3d, tform, points_scale, [h, w])
222 | images = original_image
223 | else:
224 | h, w = self.image_size, self.image_size
225 |
226 | if rendering:
227 | ops = self.render(verts, trans_verts, albedo, codedict['light'], h=h, w=w,
228 | add_light=add_light, th=th, light_type=light_type, render_norm=render_norm)
229 | ## output
230 | opdict['grid'] = ops['grid']
231 | opdict['rendered_images'] = ops['images']
232 | opdict['alpha_images'] = ops['alpha_images']
233 | opdict['normal_images'] = ops['normal_images']
234 | opdict['albedo_images'] = ops['albedo_images']
235 |
236 | if self.cfg.model.use_tex:
237 | opdict['albedo'] = albedo
238 |
239 | return opdict, None  # no visualization dict is built in this decode
240 |
241 | def visualize(self, visdict, size=224, dim=2):
242 | '''
243 | image range should be [0,1]
244 | dim: 2 for horizontal. 1 for vertical
245 | '''
246 | assert dim == 1 or dim==2
247 | grids = {}
248 | for key in visdict:
249 | _,_,h,w = visdict[key].shape
250 | if dim == 2:
251 | new_h = size; new_w = int(w*size/h)
252 | elif dim == 1:
253 | new_h = int(h*size/w); new_w = size
254 | grids[key] = torchvision.utils.make_grid(F.interpolate(visdict[key], [new_h, new_w]).detach().cpu())
255 | grid = torch.cat(list(grids.values()), dim)
256 | grid_image = (grid.numpy().transpose(1,2,0).copy()*255)[:,:,[2,1,0]]
257 | grid_image = np.minimum(np.maximum(grid_image, 0), 255).astype(np.uint8)
258 | return grid_image
259 |
260 | def save_obj(self, filename, opdict):
261 | '''
262 | vertices: [nv, 3], tensor
263 | texture: [3, h, w], tensor
264 | '''
265 | i = 0
266 | vertices = opdict['verts'][i].cpu().numpy()
267 | faces = self.render.faces[0].cpu().numpy()
268 | texture = util.tensor2image(opdict['uv_texture_gt'][i])
269 | uvcoords = self.render.raw_uvcoords[0].cpu().numpy()
270 | uvfaces = self.render.uvfaces[0].cpu().numpy()
271 | # save coarse mesh, with texture and normal map
272 | normal_map = util.tensor2image(opdict['uv_detail_normals'][i]*0.5 + 0.5)
273 | util.write_obj(filename, vertices, faces,
274 | texture=texture,
275 | uvcoords=uvcoords,
276 | uvfaces=uvfaces,
277 | normal_map=normal_map)
278 | # upsample mesh, save detailed mesh
279 | texture = texture[:,:,[2,1,0]]
280 | normals = opdict['normals'][i].cpu().numpy()
281 | displacement_map = opdict['displacement_map'][i].cpu().numpy().squeeze()
282 | dense_vertices, dense_colors, dense_faces = util.upsample_mesh(vertices, normals, faces, displacement_map, texture, self.dense_template)
283 | util.write_obj(filename.replace('.obj', '_detail.obj'),
284 | dense_vertices,
285 | dense_faces,
286 | colors = dense_colors,
287 | inverse_face_order=True)
288 |
289 | def run(self, imagepath, iscrop=True):
290 | ''' An api for running deca given an image path
291 | '''
292 | testdata = datasets.TestData(imagepath)
293 | images = testdata[0]['image'].to(self.device)[None,...]
294 | codedict = self.encode(images)
295 | opdict, visdict = self.decode(codedict)
296 | return codedict, opdict, visdict
297 |
298 | def model_dict(self):
299 | return {
300 | 'E_flame': self.E_flame.state_dict(),
301 | 'E_detail': self.E_detail.state_dict(),
302 | 'D_detail': self.D_detail.state_dict()
303 | }
304 |
--------------------------------------------------------------------------------
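Assuming the pretrained DECA checkpoint and FLAME assets referenced in decalib/utils/config.py are in place, a minimal usage sketch for the DECA class above looks like the following; the image path is a placeholder, and the 224 crop size is the usual DECA input size rather than something asserted by this file.

# Hedged usage sketch for DECA above; requires CUDA plus the checkpoint and
# FLAME data files configured in decalib/utils/config.py.
import torch
from decalib.deca import DECA
from decalib.utils.config import cfg

deca = DECA(config=cfg, device='cuda')                     # loads E_flame / E_detail / D_detail
codedict, opdict, visdict = deca.run('path/to/face.jpg')   # crop, encode and decode in one call

# or step by step, on an already-cropped [B, 3, 224, 224] batch in [0, 1]:
images = torch.rand(1, 3, 224, 224, device='cuda')         # placeholder input
codedict = deca.encode(images)                             # FLAME / tex / cam / light codes
opdict, _ = deca.decode(codedict, rendering=True)
print(opdict['verts'].shape, opdict['landmarks2d'].shape)
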
/decalib/models/FLAME.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import torch
17 | import torch.nn as nn
18 | import numpy as np
19 | import pickle
20 | import torch.nn.functional as F
21 |
22 | from .lbs import lbs, batch_rodrigues, vertices2landmarks, rot_mat_to_euler
23 |
24 | def to_tensor(array, dtype=torch.float32):
25 | if 'torch.tensor' not in str(type(array)):
26 | return torch.tensor(array, dtype=dtype)
27 | def to_np(array, dtype=np.float32):
28 | if 'scipy.sparse' in str(type(array)):
29 | array = array.todense()
30 | return np.array(array, dtype=dtype)
31 |
32 | class Struct(object):
33 | def __init__(self, **kwargs):
34 | for key, val in kwargs.items():
35 | setattr(self, key, val)
36 |
37 | class FLAME(nn.Module):
38 | """
39 | borrowed from https://github.com/soubhiksanyal/FLAME_PyTorch/blob/master/FLAME.py
40 | Given flame parameters this class generates a differentiable FLAME function
41 | which outputs a mesh and 2D/3D facial landmarks
42 | """
43 | def __init__(self, config):
44 | super(FLAME, self).__init__()
45 | print("creating the FLAME Decoder")
46 | with open(config.flame_model_path, 'rb') as f:
47 | ss = pickle.load(f, encoding='latin1')
48 | flame_model = Struct(**ss)
49 |
50 | self.dtype = torch.float32
51 | self.register_buffer('faces_tensor', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long))
52 | # The vertices of the template model
53 | self.register_buffer('v_template', to_tensor(to_np(flame_model.v_template), dtype=self.dtype))
54 | # The shape components and expression
55 | shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype)
56 | shapedirs = torch.cat([shapedirs[:,:,:config.n_shape], shapedirs[:,:,300:300+config.n_exp]], 2)
57 | self.register_buffer('shapedirs', shapedirs)
58 | # The pose components
59 | num_pose_basis = flame_model.posedirs.shape[-1]
60 | posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T
61 | self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype))
62 | #
63 | self.register_buffer('J_regressor', to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype))
64 | parents = to_tensor(to_np(flame_model.kintree_table[0])).long(); parents[0] = -1
65 | self.register_buffer('parents', parents)
66 | self.register_buffer('lbs_weights', to_tensor(to_np(flame_model.weights), dtype=self.dtype))
67 |
68 | # Fixing Eyeball and neck rotation
69 | default_eyball_pose = torch.zeros([1, 6], dtype=self.dtype, requires_grad=False)
70 | self.register_parameter('eye_pose', nn.Parameter(default_eyball_pose,
71 | requires_grad=False))
72 | default_neck_pose = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False)
73 | self.register_parameter('neck_pose', nn.Parameter(default_neck_pose,
74 | requires_grad=False))
75 |
76 | # Static and Dynamic Landmark embeddings for FLAME
77 | lmk_embeddings = np.load(config.flame_lmk_embedding_path, allow_pickle=True, encoding='latin1')
78 | lmk_embeddings = lmk_embeddings[()]
79 | self.register_buffer('lmk_faces_idx', torch.from_numpy(lmk_embeddings['static_lmk_faces_idx']).long())
80 | self.register_buffer('lmk_bary_coords', torch.from_numpy(lmk_embeddings['static_lmk_bary_coords']).to(self.dtype))
81 | self.register_buffer('dynamic_lmk_faces_idx', lmk_embeddings['dynamic_lmk_faces_idx'].long())
82 | self.register_buffer('dynamic_lmk_bary_coords', lmk_embeddings['dynamic_lmk_bary_coords'].to(self.dtype))
83 | self.register_buffer('full_lmk_faces_idx', torch.from_numpy(lmk_embeddings['full_lmk_faces_idx']).long())
84 | self.register_buffer('full_lmk_bary_coords', torch.from_numpy(lmk_embeddings['full_lmk_bary_coords']).to(self.dtype))
85 |
86 | neck_kin_chain = []; NECK_IDX=1
87 | curr_idx = torch.tensor(NECK_IDX, dtype=torch.long)
88 | while curr_idx != -1:
89 | neck_kin_chain.append(curr_idx)
90 | curr_idx = self.parents[curr_idx]
91 | self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain))
92 |
93 | def _find_dynamic_lmk_idx_and_bcoords(self, pose, dynamic_lmk_faces_idx,
94 | dynamic_lmk_b_coords,
95 | neck_kin_chain, dtype=torch.float32):
96 | """
97 | Selects the face contour depending on the relative position of the head
98 | Input:
99 | vertices: N X num_of_vertices X 3
100 | pose: N X full pose
101 | dynamic_lmk_faces_idx: The list of contour face indexes
102 | dynamic_lmk_b_coords: The list of contour barycentric weights
103 | neck_kin_chain: The tree to consider for the relative rotation
104 | dtype: Data type
105 | return:
106 | The contour face indexes and the corresponding barycentric weights
107 | """
108 |
109 | batch_size = pose.shape[0]
110 |
111 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
112 | neck_kin_chain)
113 | rot_mats = batch_rodrigues(
114 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
115 |
116 | rel_rot_mat = torch.eye(3, device=pose.device,
117 | dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1)
118 | for idx in range(len(neck_kin_chain)):
119 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
120 |
121 | y_rot_angle = torch.round(
122 | torch.clamp(rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
123 | max=39)).to(dtype=torch.long)
124 |
125 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
126 | mask = y_rot_angle.lt(-39).to(dtype=torch.long)
127 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
128 | y_rot_angle = (neg_mask * neg_vals +
129 | (1 - neg_mask) * y_rot_angle)
130 |
131 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
132 | 0, y_rot_angle)
133 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
134 | 0, y_rot_angle)
135 | return dyn_lmk_faces_idx, dyn_lmk_b_coords
136 |
137 | def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords):
138 | """
139 | Calculates landmarks by barycentric interpolation
140 | Input:
141 | vertices: torch.tensor NxVx3, dtype = torch.float32
142 | The tensor of input vertices
143 | faces: torch.tensor (N*F)x3, dtype = torch.long
144 | The faces of the mesh
145 | lmk_faces_idx: torch.tensor N X L, dtype = torch.long
146 | The tensor with the indices of the faces used to calculate the
147 | landmarks.
148 | lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32
149 | The tensor of barycentric coordinates that are used to interpolate
150 | the landmarks
151 |
152 | Returns:
153 | landmarks: torch.tensor NxLx3, dtype = torch.float32
154 | The coordinates of the landmarks for each mesh in the batch
155 | """
156 | # Extract the indices of the vertices for each face
157 | # NxLx3
158 | batch_size, num_verts = vertices.shape[:2]
159 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
160 | 1, -1, 3).view(batch_size, lmk_faces_idx.shape[1], -1)
161 |
162 | lmk_faces += torch.arange(batch_size, dtype=torch.long).view(-1, 1, 1).to(
163 | device=vertices.device) * num_verts
164 |
165 | lmk_vertices = vertices.view(-1, 3)[lmk_faces]
166 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
167 | return landmarks
168 |
169 | def seletec_3d68(self, vertices):
170 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
171 | self.full_lmk_faces_idx.repeat(vertices.shape[0], 1),
172 | self.full_lmk_bary_coords.repeat(vertices.shape[0], 1, 1))
173 | return landmarks3d
174 |
175 | def forward(self, shape_params=None, expression_params=None, pose_params=None, eye_pose_params=None):
176 | """
177 | Input:
178 | shape_params: N X number of shape parameters
179 | expression_params: N X number of expression parameters
180 | pose_params: N X number of pose parameters (6)
181 | return:
182 | vertices: N X V X 3
183 | landmarks: N X number of landmarks X 3
184 | """
185 | batch_size = shape_params.shape[0]
186 | if pose_params is None:
187 | pose_params = self.eye_pose.expand(batch_size, -1)
188 | if eye_pose_params is None:
189 | eye_pose_params = self.eye_pose.expand(batch_size, -1)
190 | betas = torch.cat([shape_params, expression_params], dim=1)
191 | full_pose = torch.cat([pose_params[:, :3], self.neck_pose.expand(batch_size, -1), pose_params[:, 3:], eye_pose_params], dim=1)
192 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
193 |
194 | vertices, _ = lbs(betas, full_pose, template_vertices,
195 | self.shapedirs, self.posedirs,
196 | self.J_regressor, self.parents,
197 | self.lbs_weights, dtype=self.dtype)
198 |
199 | lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
200 | lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
201 |
202 | dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords(
203 | full_pose, self.dynamic_lmk_faces_idx,
204 | self.dynamic_lmk_bary_coords,
205 | self.neck_kin_chain, dtype=self.dtype)
206 | lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1)
207 | lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1)
208 |
209 | landmarks2d = vertices2landmarks(vertices, self.faces_tensor,
210 | lmk_faces_idx,
211 | lmk_bary_coords)
212 | bz = vertices.shape[0]
213 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
214 | self.full_lmk_faces_idx.repeat(bz, 1),
215 | self.full_lmk_bary_coords.repeat(bz, 1, 1))
216 | return vertices, landmarks2d, landmarks3d
217 |
218 | class FLAMETex(nn.Module):
219 | """
220 | FLAME texture:
221 | https://github.com/TimoBolkart/TF_FLAME/blob/ade0ab152300ec5f0e8555d6765411555c5ed43d/sample_texture.py#L64
222 | FLAME texture converted from BFM:
223 | https://github.com/TimoBolkart/BFM_to_FLAME
224 | """
225 | def __init__(self, config):
226 | super(FLAMETex, self).__init__()
227 | if config.tex_type == 'BFM':
228 | mu_key = 'MU'
229 | pc_key = 'PC'
230 | n_pc = 199
231 | tex_path = config.tex_path
232 | tex_space = np.load(tex_path)
233 | texture_mean = tex_space[mu_key].reshape(1, -1)
234 | texture_basis = tex_space[pc_key].reshape(-1, n_pc)
235 |
236 | elif config.tex_type == 'FLAME':
237 | mu_key = 'mean'
238 | pc_key = 'tex_dir'
239 | n_pc = 200
240 | tex_path = config.tex_path # config.flame_tex_path
241 | tex_space = np.load(tex_path)
242 | texture_mean = tex_space[mu_key].reshape(1, -1)/255.
243 | texture_basis = tex_space[pc_key].reshape(-1, n_pc)/255.
244 | else:
245 | print('texture type', config.tex_type, 'does not exist!')
246 | raise NotImplementedError
247 |
248 | n_tex = config.n_tex
249 | num_components = texture_basis.shape[1]
250 | texture_mean = torch.from_numpy(texture_mean).float()[None,...]
251 | texture_basis = torch.from_numpy(texture_basis[:,:n_tex]).float()[None,...]
252 | self.register_buffer('texture_mean', texture_mean)
253 | self.register_buffer('texture_basis', texture_basis)
254 |
255 | def forward(self, texcode):
256 | '''
257 | texcode: [batchsize, n_tex]
258 | texture: [bz, 3, 256, 256], range: 0-1
259 | '''
260 | texture = self.texture_mean + (self.texture_basis*texcode[:,None,:]).sum(-1)
261 | texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0,3,1,2)
262 | texture = F.interpolate(texture, [256, 256])
263 | texture = texture[:,[2,1,0], :,:]
264 | return texture
265 |
--------------------------------------------------------------------------------
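For reference, a shape-level sketch of driving the FLAME decoder above with neutral codes; the 100/50 shape and expression sizes are common DECA defaults and assumptions here, and cfg.model must point at valid FLAME model and landmark-embedding files.

# Hedged sketch: calling the FLAME layer above with zero parameters.
import torch
from decalib.models.FLAME import FLAME
from decalib.utils.config import cfg

flame = FLAME(cfg.model).cuda()
B = 2
shape = torch.zeros(B, 100).cuda()   # identity coefficients (size assumed)
exp = torch.zeros(B, 50).cuda()      # expression coefficients (size assumed)
pose = torch.zeros(B, 6).cuda()      # [:3] global rotation, [3:] jaw rotation

verts, lmk2d, lmk3d = flame(shape_params=shape, expression_params=exp, pose_params=pose)
print(verts.shape)   # (B, V, 3): the mean-shape mesh
print(lmk3d.shape)   # (B, 68, 3): barycentric 3D landmarks
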
/decalib/models/decoders.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import torch
17 | import torch.nn as nn
18 |
19 | class Generator(nn.Module):
20 | def __init__(self, latent_dim=100, out_channels=1, out_scale=0.01, sample_mode = 'bilinear'):
21 | super(Generator, self).__init__()
22 | self.out_scale = out_scale
23 |
24 | self.init_size = 32 // 4 # Initial size before upsampling
25 | self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))
26 | self.conv_blocks = nn.Sequential(
27 | nn.BatchNorm2d(128),
28 | nn.Upsample(scale_factor=2, mode=sample_mode), #16
29 | nn.Conv2d(128, 128, 3, stride=1, padding=1),
30 | nn.BatchNorm2d(128, 0.8),
31 | nn.LeakyReLU(0.2, inplace=True),
32 | nn.Upsample(scale_factor=2, mode=sample_mode), #32
33 | nn.Conv2d(128, 64, 3, stride=1, padding=1),
34 | nn.BatchNorm2d(64, 0.8),
35 | nn.LeakyReLU(0.2, inplace=True),
36 | nn.Upsample(scale_factor=2, mode=sample_mode), #64
37 | nn.Conv2d(64, 64, 3, stride=1, padding=1),
38 | nn.BatchNorm2d(64, 0.8),
39 | nn.LeakyReLU(0.2, inplace=True),
40 | nn.Upsample(scale_factor=2, mode=sample_mode), #128
41 | nn.Conv2d(64, 32, 3, stride=1, padding=1),
42 | nn.BatchNorm2d(32, 0.8),
43 | nn.LeakyReLU(0.2, inplace=True),
44 | nn.Upsample(scale_factor=2, mode=sample_mode), #256
45 | nn.Conv2d(32, 16, 3, stride=1, padding=1),
46 | nn.BatchNorm2d(16, 0.8),
47 | nn.LeakyReLU(0.2, inplace=True),
48 | nn.Conv2d(16, out_channels, 3, stride=1, padding=1),
49 | nn.Tanh(),
50 | )
51 |
52 | def forward(self, noise):
53 | out = self.l1(noise)
54 | out = out.view(out.shape[0], 128, self.init_size, self.init_size)
55 | img = self.conv_blocks(out)
56 | return img*self.out_scale
--------------------------------------------------------------------------------
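The Generator above maps a latent code to a single-channel map through five upsample-conv blocks, so the 8x8 seed grows to 256x256. A quick shape check follows; splitting the latent into a 128-dim detail code plus a 53-dim expression/jaw condition mirrors how deca.py builds D_detail, but the exact numbers are only examples.

# Quick shape check for the detail Generator above.
import torch
from decalib.models.decoders import Generator

G = Generator(latent_dim=128 + 53, out_channels=1, out_scale=0.01)
z = torch.randn(4, 128 + 53)   # detail code + condition, sizes illustrative
uv_z = G(z)
print(uv_z.shape)              # torch.Size([4, 1, 256, 256]), values in [-0.01, 0.01]
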
/decalib/models/encoders.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import numpy as np
17 | import torch.nn as nn
18 | import torch
19 | import torch.nn.functional as F
20 | from . import resnet
21 |
22 | class ResnetEncoder(nn.Module):
23 | def __init__(self, outsize, last_op=None):
24 | super(ResnetEncoder, self).__init__()
25 | feature_size = 2048
26 | self.encoder = resnet.load_ResNet50Model() #out: 2048
27 | ### regressor
28 | self.layers = nn.Sequential(
29 | nn.Linear(feature_size, 1024),
30 | nn.ReLU(),
31 | nn.Linear(1024, outsize)
32 | )
33 | self.last_op = last_op
34 |
35 | def forward(self, inputs):
36 | features = self.encoder(inputs)
37 | parameters = self.layers(features)
38 | if self.last_op:
39 | parameters = self.last_op(parameters)
40 | return parameters
41 |
--------------------------------------------------------------------------------
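A minimal shape check for the ResnetEncoder above; outsize=236 matches the sum of the FLAME code sizes that deca.py regresses (shape + tex + exp + pose + cam + light) under the usual DECA defaults, but treat it as an example value.

# Shape check for ResnetEncoder above (outsize is an example value).
import torch
from decalib.models.encoders import ResnetEncoder

E = ResnetEncoder(outsize=236)
images = torch.rand(2, 3, 224, 224)   # cropped face batch in [0, 1]
params = E(images)
print(params.shape)                   # torch.Size([2, 236])
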
/decalib/models/frnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import torch
4 | # from pro_gan_pytorch.PRO_GAN import ProGAN, Generator, Discriminator
5 | import torch.nn.functional as F
6 | import cv2
7 | from torch.autograd import Variable
8 | import math
9 |
10 | def conv3x3(in_planes, out_planes, stride=1):
11 | """3x3 convolution with padding"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 |
15 | class BasicBlock(nn.Module):
16 | expansion = 1
17 |
18 | def __init__(self, inplanes, planes, stride=1, downsample=None):
19 | super(BasicBlock, self).__init__()
20 | self.conv1 = conv3x3(inplanes, planes, stride)
21 | self.bn1 = nn.BatchNorm2d(planes)
22 | self.relu = nn.ReLU(inplace=True)
23 | self.conv2 = conv3x3(planes, planes)
24 | self.bn2 = nn.BatchNorm2d(planes)
25 | self.downsample = downsample
26 | self.stride = stride
27 |
28 | def forward(self, x):
29 | residual = x
30 |
31 | out = self.conv1(x)
32 | out = self.bn1(out)
33 | out = self.relu(out)
34 |
35 | out = self.conv2(out)
36 | out = self.bn2(out)
37 |
38 | if self.downsample is not None:
39 | residual = self.downsample(x)
40 |
41 | out += residual
42 | out = self.relu(out)
43 |
44 | return out
45 |
46 |
47 | class Bottleneck(nn.Module):
48 | expansion = 4
49 |
50 | def __init__(self, inplanes, planes, stride=1, downsample=None):
51 | super(Bottleneck, self).__init__()
52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
53 | self.bn1 = nn.BatchNorm2d(planes)
54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
55 | self.bn2 = nn.BatchNorm2d(planes)
56 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
57 | self.bn3 = nn.BatchNorm2d(planes * 4)
58 | self.relu = nn.ReLU(inplace=True)
59 | self.downsample = downsample
60 | self.stride = stride
61 |
62 | def forward(self, x):
63 | residual = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | out = self.conv2(out)
70 | out = self.bn2(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv3(out)
74 | out = self.bn3(out)
75 |
76 | if self.downsample is not None:
77 | residual = self.downsample(x)
78 |
79 | out += residual
80 | out = self.relu(out)
81 |
82 | return out
83 |
84 |
85 | class ResNet(nn.Module):
86 |
87 | def __init__(self, block, layers, num_classes=1000, include_top=True):
88 | self.inplanes = 64
89 | super(ResNet, self).__init__()
90 | self.include_top = include_top
91 |
92 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
93 | self.bn1 = nn.BatchNorm2d(64)
94 | self.relu = nn.ReLU(inplace=True)
95 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
96 |
97 | self.layer1 = self._make_layer(block, 64, layers[0])
98 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
99 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
100 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
101 | self.avgpool = nn.AvgPool2d(7, stride=1)
102 | self.fc = nn.Linear(512 * block.expansion, num_classes)
103 |
104 | for m in self.modules():
105 | if isinstance(m, nn.Conv2d):
106 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
107 | m.weight.data.normal_(0, math.sqrt(2. / n))
108 | elif isinstance(m, nn.BatchNorm2d):
109 | m.weight.data.fill_(1)
110 | m.bias.data.zero_()
111 |
112 | def _make_layer(self, block, planes, blocks, stride=1):
113 | downsample = None
114 | if stride != 1 or self.inplanes != planes * block.expansion:
115 | downsample = nn.Sequential(
116 | nn.Conv2d(self.inplanes, planes * block.expansion,
117 | kernel_size=1, stride=stride, bias=False),
118 | nn.BatchNorm2d(planes * block.expansion),
119 | )
120 |
121 | layers = []
122 | layers.append(block(self.inplanes, planes, stride, downsample))
123 | self.inplanes = planes * block.expansion
124 | for i in range(1, blocks):
125 | layers.append(block(self.inplanes, planes))
126 |
127 | return nn.Sequential(*layers)
128 |
129 | def forward(self, x):
130 | x = self.conv1(x)
131 | x = self.bn1(x)
132 | x = self.relu(x)
133 | x = self.maxpool(x)
134 |
135 | x = self.layer1(x)
136 | x = self.layer2(x)
137 | x = self.layer3(x)
138 | x = self.layer4(x)
139 |
140 | x = self.avgpool(x)
141 |
142 | if not self.include_top:
143 | return x
144 |
145 | x = x.view(x.size(0), -1)
146 | x = self.fc(x)
147 | return x
148 |
149 | def resnet50(**kwargs):
150 | """Constructs a ResNet-50 model.
151 | """
152 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
153 | return model
154 |
155 | import pickle
156 | def load_state_dict(model, fname):
157 | """
158 | Set parameters converted from the Caffe models that the authors of VGGFace2 provide.
159 | See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/.
160 | Arguments:
161 | model: model
162 | fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle.
163 | """
164 | with open(fname, 'rb') as f:
165 | weights = pickle.load(f, encoding='latin1')
166 |
167 | own_state = model.state_dict()
168 | for name, param in weights.items():
169 | if name in own_state:
170 | try:
171 | own_state[name].copy_(torch.from_numpy(param))
172 | except Exception:
173 | raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\
174 | 'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size()))
175 | else:
176 | raise KeyError('unexpected key "{}" in state_dict'.format(name))
177 |
178 |
--------------------------------------------------------------------------------
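The resnet50 constructor above builds the VGGFace2-style backbone; a dummy forward pass shows its two output modes (the class count is illustrative).

# Dummy forward pass through the frnet ResNet-50 above.
import torch
from decalib.models.frnet import resnet50

x = torch.rand(2, 3, 224, 224)
logits = resnet50(num_classes=1000, include_top=True)(x)
print(logits.shape)                        # torch.Size([2, 1000])

features = resnet50(include_top=False)(x)  # pooled conv features instead of logits
print(features.shape)                      # torch.Size([2, 2048, 1, 1])
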
/decalib/models/lbs.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems. All rights reserved.
14 | #
15 | # Contact: ps-license@tuebingen.mpg.de
16 |
17 | from __future__ import absolute_import
18 | from __future__ import print_function
19 | from __future__ import division
20 |
21 | import numpy as np
22 |
23 | import torch
24 | import torch.nn.functional as F
25 |
26 | def rot_mat_to_euler(rot_mats):
27 | # Converts a rotation matrix to Euler angles
28 | # Careful with extreme cases of Euler angles like [0.0, pi, 0.0]
29 |
30 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
31 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
32 | return torch.atan2(-rot_mats[:, 2, 0], sy)
33 |
34 | def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx,
35 | dynamic_lmk_b_coords,
36 | neck_kin_chain, dtype=torch.float32):
37 | ''' Compute the faces, barycentric coordinates for the dynamic landmarks
38 |
39 |
40 | To do so, we first compute the rotation of the neck around the y-axis
41 | and then use a pre-computed look-up table to find the faces and the
42 | barycentric coordinates that will be used.
43 |
44 | Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de)
45 | for providing the original TensorFlow implementation and for the LUT.
46 |
47 | Parameters
48 | ----------
49 | vertices: torch.tensor BxVx3, dtype = torch.float32
50 | The tensor of input vertices
51 | pose: torch.tensor Bx(Jx3), dtype = torch.float32
52 | The current pose of the body model
53 | dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long
54 | The look-up table from neck rotation to faces
55 | dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32
56 | The look-up table from neck rotation to barycentric coordinates
57 | neck_kin_chain: list
58 | A python list that contains the indices of the joints that form the
59 | kinematic chain of the neck.
60 | dtype: torch.dtype, optional
61 |
62 | Returns
63 | -------
64 | dyn_lmk_faces_idx: torch.tensor, dtype = torch.long
65 | A tensor of size BxL that contains the indices of the faces that
66 | will be used to compute the current dynamic landmarks.
67 | dyn_lmk_b_coords: torch.tensor, dtype = torch.float32
68 | A tensor of size BxLx3 that contains the barycentric coordinates
69 | that will be used to compute the current dynamic landmarks.
70 | '''
71 |
72 | batch_size = vertices.shape[0]
73 |
74 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
75 | neck_kin_chain)
76 | rot_mats = batch_rodrigues(
77 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
78 |
79 | rel_rot_mat = torch.eye(3, device=vertices.device,
80 | dtype=dtype).unsqueeze_(dim=0)
81 | for idx in range(len(neck_kin_chain)):
82 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
83 |
84 | y_rot_angle = torch.round(
85 | torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
86 | max=39)).to(dtype=torch.long)
87 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
88 | mask = y_rot_angle.lt(-39).to(dtype=torch.long)
89 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
90 | y_rot_angle = (neg_mask * neg_vals +
91 | (1 - neg_mask) * y_rot_angle)
92 |
93 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
94 | 0, y_rot_angle)
95 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
96 | 0, y_rot_angle)
97 |
98 | return dyn_lmk_faces_idx, dyn_lmk_b_coords
99 |
100 |
101 | def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
102 | ''' Calculates landmarks by barycentric interpolation
103 |
104 | Parameters
105 | ----------
106 | vertices: torch.tensor BxVx3, dtype = torch.float32
107 | The tensor of input vertices
108 | faces: torch.tensor Fx3, dtype = torch.long
109 | The faces of the mesh
110 | lmk_faces_idx: torch.tensor L, dtype = torch.long
111 | The tensor with the indices of the faces used to calculate the
112 | landmarks.
113 | lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32
114 | The tensor of barycentric coordinates that are used to interpolate
115 | the landmarks
116 |
117 | Returns
118 | -------
119 | landmarks: torch.tensor BxLx3, dtype = torch.float32
120 | The coordinates of the landmarks for each mesh in the batch
121 | '''
122 | # Extract the indices of the vertices for each face
123 | # BxLx3
124 | batch_size, num_verts = vertices.shape[:2]
125 | device = vertices.device
126 |
127 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
128 | batch_size, -1, 3)
129 |
130 | lmk_faces += torch.arange(
131 | batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
132 |
133 | lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(
134 | batch_size, -1, 3, 3)
135 |
136 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
137 | return landmarks
138 |
139 |
140 | def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
141 | lbs_weights, pose2rot=True, dtype=torch.float32):
142 | ''' Performs Linear Blend Skinning with the given shape and pose parameters
143 |
144 | Parameters
145 | ----------
146 | betas : torch.tensor BxNB
147 | The tensor of shape parameters
148 | pose : torch.tensor Bx(J + 1) * 3
149 | The pose parameters in axis-angle format
150 | v_template torch.tensor BxVx3
151 | The template mesh that will be deformed
152 | shapedirs : torch.tensor 1xNB
153 | The tensor of PCA shape displacements
154 | posedirs : torch.tensor Px(V * 3)
155 | The pose PCA coefficients
156 | J_regressor : torch.tensor JxV
157 | The regressor array that is used to calculate the joints from
158 | the position of the vertices
159 | parents: torch.tensor J
160 | The array that describes the kinematic tree for the model
161 | lbs_weights: torch.tensor N x V x (J + 1)
162 | The linear blend skinning weights that represent how much the
163 | rotation matrix of each part affects each vertex
164 | pose2rot: bool, optional
165 | Flag on whether to convert the input pose tensor to rotation
166 | matrices. The default value is True. If False, then the pose tensor
167 | should already contain rotation matrices and have a size of
168 | Bx(J + 1)x9
169 | dtype: torch.dtype, optional
170 |
171 | Returns
172 | -------
173 | verts: torch.tensor BxVx3
174 | The vertices of the mesh after applying the shape and pose
175 | displacements.
176 | joints: torch.tensor BxJx3
177 | The joints of the model
178 | '''
179 |
180 | batch_size = max(betas.shape[0], pose.shape[0])
181 | device = betas.device
182 |
183 | # Add shape contribution
184 | v_shaped = v_template + blend_shapes(betas, shapedirs)
185 |
186 | # Get the joints
187 | # NxJx3 array
188 | J = vertices2joints(J_regressor, v_shaped)
189 |
190 | # 3. Add pose blend shapes
191 | # N x J x 3 x 3
192 | ident = torch.eye(3, dtype=dtype, device=device)
193 | if pose2rot:
194 | rot_mats = batch_rodrigues(
195 | pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])
196 |
197 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
198 | # (N x P) x (P, V * 3) -> N x V x 3
199 | pose_offsets = torch.matmul(pose_feature, posedirs) \
200 | .view(batch_size, -1, 3)
201 | else:
202 | pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
203 | rot_mats = pose.view(batch_size, -1, 3, 3)
204 |
205 | pose_offsets = torch.matmul(pose_feature.view(batch_size, -1),
206 | posedirs).view(batch_size, -1, 3)
207 |
208 | v_posed = pose_offsets + v_shaped
209 | # 4. Get the global joint location
210 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
211 |
212 | # 5. Do skinning:
213 | # W is N x V x (J + 1)
214 | W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
215 | # (N x V x (J + 1)) x (N x (J + 1) x 16)
216 | num_joints = J_regressor.shape[0]
217 | T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \
218 | .view(batch_size, -1, 4, 4)
219 |
220 | homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
221 | dtype=dtype, device=device)
222 | v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
223 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
224 |
225 | verts = v_homo[:, :, :3, 0]
226 |
227 | return verts, J_transformed
228 |
229 |
230 | def vertices2joints(J_regressor, vertices):
231 | ''' Calculates the 3D joint locations from the vertices
232 |
233 | Parameters
234 | ----------
235 | J_regressor : torch.tensor JxV
236 | The regressor array that is used to calculate the joints from the
237 | position of the vertices
238 | vertices : torch.tensor BxVx3
239 | The tensor of mesh vertices
240 |
241 | Returns
242 | -------
243 | torch.tensor BxJx3
244 | The location of the joints
245 | '''
246 |
247 | return torch.einsum('bik,ji->bjk', [vertices, J_regressor])
248 |
249 |
250 | def blend_shapes(betas, shape_disps):
251 | ''' Calculates the per vertex displacement due to the blend shapes
252 |
253 |
254 | Parameters
255 | ----------
256 | betas : torch.tensor Bx(num_betas)
257 | Blend shape coefficients
258 | shape_disps: torch.tensor Vx3x(num_betas)
259 | Blend shapes
260 |
261 | Returns
262 | -------
263 | torch.tensor BxVx3
264 | The per-vertex displacement due to shape deformation
265 | '''
266 |
267 | # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
268 | # i.e. Multiply each shape displacement by its corresponding beta and
269 | # then sum them.
270 | blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
271 | return blend_shape
272 |
273 |
274 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
275 | ''' Calculates the rotation matrices for a batch of rotation vectors
276 | Parameters
277 | ----------
278 | rot_vecs: torch.tensor Nx3
279 | array of N axis-angle vectors
280 | Returns
281 | -------
282 | R: torch.tensor Nx3x3
283 | The rotation matrices for the given axis-angle parameters
284 | '''
285 |
286 | batch_size = rot_vecs.shape[0]
287 | device = rot_vecs.device
288 |
289 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
290 | rot_dir = rot_vecs / angle
291 |
292 | cos = torch.unsqueeze(torch.cos(angle), dim=1)
293 | sin = torch.unsqueeze(torch.sin(angle), dim=1)
294 |
295 | # Bx1 arrays
296 | rx, ry, rz = torch.split(rot_dir, 1, dim=1)
297 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
298 |
299 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
300 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
301 | .view((batch_size, 3, 3))
302 |
303 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
304 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
305 | return rot_mat
306 |
307 |
308 | def transform_mat(R, t):
309 | ''' Creates a batch of transformation matrices
310 | Args:
311 | - R: Bx3x3 array of a batch of rotation matrices
312 | - t: Bx3x1 array of a batch of translation vectors
313 | Returns:
314 | - T: Bx4x4 Transformation matrix
315 | '''
316 | # No padding left or right, only add an extra row
317 | return torch.cat([F.pad(R, [0, 0, 0, 1]),
318 | F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
319 |
320 |
321 | def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
322 | """
323 | Applies a batch of rigid transformations to the joints
324 |
325 | Parameters
326 | ----------
327 | rot_mats : torch.tensor BxNx3x3
328 | Tensor of rotation matrices
329 | joints : torch.tensor BxNx3
330 | Locations of joints
331 | parents : torch.tensor BxN
332 | The kinematic tree of each object
333 | dtype : torch.dtype, optional:
334 | The data type of the created tensors, the default is torch.float32
335 |
336 | Returns
337 | -------
338 | posed_joints : torch.tensor BxNx3
339 | The locations of the joints after applying the pose rotations
340 | rel_transforms : torch.tensor BxNx4x4
341 | The relative (with respect to the root joint) rigid transformations
342 | for all the joints
343 | """
344 |
345 | joints = torch.unsqueeze(joints, dim=-1)
346 |
347 | rel_joints = joints.clone()
348 | rel_joints[:, 1:] -= joints[:, parents[1:]]
349 |
350 | # transforms_mat = transform_mat(
351 | # rot_mats.view(-1, 3, 3),
352 | # rel_joints.view(-1, 3, 1)).view(-1, joints.shape[1], 4, 4)
353 | transforms_mat = transform_mat(
354 | rot_mats.view(-1, 3, 3),
355 | rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
356 |
357 | transform_chain = [transforms_mat[:, 0]]
358 | for i in range(1, parents.shape[0]):
359 | # Subtract the joint location at the rest pose
360 | # No need for rotation, since it's identity when at rest
361 | curr_res = torch.matmul(transform_chain[parents[i]],
362 | transforms_mat[:, i])
363 | transform_chain.append(curr_res)
364 |
365 | transforms = torch.stack(transform_chain, dim=1)
366 |
367 | # The last column of the transformations contains the posed joints
368 | posed_joints = transforms[:, :, :3, 3]
369 |
373 | joints_homogen = F.pad(joints, [0, 0, 0, 1])
374 |
375 | rel_transforms = transforms - F.pad(
376 | torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])
377 |
378 | return posed_joints, rel_transforms
--------------------------------------------------------------------------------
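As a small sanity check on batch_rodrigues above, a 90-degree rotation about the z axis should reproduce the familiar rotation matrix; this is a self-contained example, not part of the library.

# Sanity check for batch_rodrigues: axis-angle -> rotation matrix.
import numpy as np
import torch
from decalib.models.lbs import batch_rodrigues

aa = torch.tensor([[0.0, 0.0, np.pi / 2]])   # axis * angle, 90 degrees about z
R = batch_rodrigues(aa)                      # shape (1, 3, 3)
print(R[0])
# expected, up to numerical noise:
# [[ 0, -1,  0],
#  [ 1,  0,  0],
#  [ 0,  0,  1]]
print(torch.allclose(R[0] @ R[0].T, torch.eye(3), atol=1e-5))   # rotations are orthonormal
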
/decalib/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Soubhik Sanyal
3 | Copyright (c) 2019, Soubhik Sanyal
4 | All rights reserved.
5 | Loads different resnet models
6 | """
7 | '''
8 | file: Resnet.py
9 | date: 2018_05_02
10 | author: zhangxiong(1025679612@qq.com)
11 | mark: copied from pytorch source code
12 | '''
13 |
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import torch
17 | from torch.nn.parameter import Parameter
18 | import torch.optim as optim
19 | import numpy as np
20 | import math
21 | import torchvision
22 |
23 | class ResNet(nn.Module):
24 | def __init__(self, block, layers, num_classes=1000):
25 | self.inplanes = 64
26 | super(ResNet, self).__init__()
27 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
28 | bias=False)
29 | self.bn1 = nn.BatchNorm2d(64)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
32 | self.layer1 = self._make_layer(block, 64, layers[0])
33 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
34 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
35 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
36 | self.avgpool = nn.AvgPool2d(7, stride=1)
37 | # self.fc = nn.Linear(512 * block.expansion, num_classes)
38 |
39 | for m in self.modules():
40 | if isinstance(m, nn.Conv2d):
41 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
42 | m.weight.data.normal_(0, math.sqrt(2. / n))
43 | elif isinstance(m, nn.BatchNorm2d):
44 | m.weight.data.fill_(1)
45 | m.bias.data.zero_()
46 |
47 | def _make_layer(self, block, planes, blocks, stride=1):
48 | downsample = None
49 | if stride != 1 or self.inplanes != planes * block.expansion:
50 | downsample = nn.Sequential(
51 | nn.Conv2d(self.inplanes, planes * block.expansion,
52 | kernel_size=1, stride=stride, bias=False),
53 | nn.BatchNorm2d(planes * block.expansion),
54 | )
55 |
56 | layers = []
57 | layers.append(block(self.inplanes, planes, stride, downsample))
58 | self.inplanes = planes * block.expansion
59 | for i in range(1, blocks):
60 | layers.append(block(self.inplanes, planes))
61 |
62 | return nn.Sequential(*layers)
63 |
64 | def forward(self, x):
65 | x = self.conv1(x)
66 | x = self.bn1(x)
67 | x = self.relu(x)
68 | x = self.maxpool(x)
69 |
70 | x = self.layer1(x)
71 | x = self.layer2(x)
72 | x = self.layer3(x)
73 | x1 = self.layer4(x)
74 |
75 | x2 = self.avgpool(x1)
76 | x2 = x2.view(x2.size(0), -1)
77 | # x = self.fc(x)
78 | ## x2: [bz, 2048] for shape
79 | ## x1: [bz, 2048, 7, 7] for texture
80 | return x2
81 |
82 | class Bottleneck(nn.Module):
83 | expansion = 4
84 |
85 | def __init__(self, inplanes, planes, stride=1, downsample=None):
86 | super(Bottleneck, self).__init__()
87 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
88 | self.bn1 = nn.BatchNorm2d(planes)
89 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
90 | padding=1, bias=False)
91 | self.bn2 = nn.BatchNorm2d(planes)
92 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
93 | self.bn3 = nn.BatchNorm2d(planes * 4)
94 | self.relu = nn.ReLU(inplace=True)
95 | self.downsample = downsample
96 | self.stride = stride
97 |
98 | def forward(self, x):
99 | residual = x
100 |
101 | out = self.conv1(x)
102 | out = self.bn1(out)
103 | out = self.relu(out)
104 |
105 | out = self.conv2(out)
106 | out = self.bn2(out)
107 | out = self.relu(out)
108 |
109 | out = self.conv3(out)
110 | out = self.bn3(out)
111 |
112 | if self.downsample is not None:
113 | residual = self.downsample(x)
114 |
115 | out += residual
116 | out = self.relu(out)
117 |
118 | return out
119 |
120 | def conv3x3(in_planes, out_planes, stride=1):
121 | """3x3 convolution with padding"""
122 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
123 | padding=1, bias=False)
124 |
125 | class BasicBlock(nn.Module):
126 | expansion = 1
127 |
128 | def __init__(self, inplanes, planes, stride=1, downsample=None):
129 | super(BasicBlock, self).__init__()
130 | self.conv1 = conv3x3(inplanes, planes, stride)
131 | self.bn1 = nn.BatchNorm2d(planes)
132 | self.relu = nn.ReLU(inplace=True)
133 | self.conv2 = conv3x3(planes, planes)
134 | self.bn2 = nn.BatchNorm2d(planes)
135 | self.downsample = downsample
136 | self.stride = stride
137 |
138 | def forward(self, x):
139 | residual = x
140 |
141 | out = self.conv1(x)
142 | out = self.bn1(out)
143 | out = self.relu(out)
144 |
145 | out = self.conv2(out)
146 | out = self.bn2(out)
147 |
148 | if self.downsample is not None:
149 | residual = self.downsample(x)
150 |
151 | out += residual
152 | out = self.relu(out)
153 |
154 | return out
155 |
156 | def copy_parameter_from_resnet(model, resnet_dict):
157 | cur_state_dict = model.state_dict()
158 | # import ipdb; ipdb.set_trace()
159 | for name, param in list(resnet_dict.items())[0:None]:
160 | if name not in cur_state_dict:
161 | # print(name, ' not available in reconstructed resnet')
162 | continue
163 | if isinstance(param, Parameter):
164 | param = param.data
165 | try:
166 | cur_state_dict[name].copy_(param)
167 | except Exception:
168 | # print(name, ' is inconsistent!')
169 | continue
170 | # print('copy resnet state dict finished!')
171 | # import ipdb; ipdb.set_trace()
172 |
173 | def load_ResNet50Model():
174 | model = ResNet(Bottleneck, [3, 4, 6, 3])
175 | copy_parameter_from_resnet(model, torchvision.models.resnet50(pretrained = False).state_dict())
176 | return model
177 |
178 | def load_ResNet101Model():
179 | model = ResNet(Bottleneck, [3, 4, 23, 3])
180 | copy_parameter_from_resnet(model, torchvision.models.resnet101(pretrained = True).state_dict())
181 | return model
182 |
183 | def load_ResNet152Model():
184 | model = ResNet(Bottleneck, [3, 8, 36, 3])
185 | copy_parameter_from_resnet(model, torchvision.models.resnet152(pretrained = True).state_dict())
186 | return model
187 |
188 | # model.load_state_dict(checkpoint['model_state_dict'])
189 |
190 |
191 | ######## Unet
192 |
193 | class DoubleConv(nn.Module):
194 | """(convolution => [BN] => ReLU) * 2"""
195 |
196 | def __init__(self, in_channels, out_channels):
197 | super().__init__()
198 | self.double_conv = nn.Sequential(
199 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
200 | nn.BatchNorm2d(out_channels),
201 | nn.ReLU(inplace=True),
202 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
203 | nn.BatchNorm2d(out_channels),
204 | nn.ReLU(inplace=True)
205 | )
206 |
207 | def forward(self, x):
208 | return self.double_conv(x)
209 |
210 |
211 | class Down(nn.Module):
212 | """Downscaling with maxpool then double conv"""
213 |
214 | def __init__(self, in_channels, out_channels):
215 | super().__init__()
216 | self.maxpool_conv = nn.Sequential(
217 | nn.MaxPool2d(2),
218 | DoubleConv(in_channels, out_channels)
219 | )
220 |
221 | def forward(self, x):
222 | return self.maxpool_conv(x)
223 |
224 |
225 | class Up(nn.Module):
226 | """Upscaling then double conv"""
227 |
228 | def __init__(self, in_channels, out_channels, bilinear=True):
229 | super().__init__()
230 |
231 | # if bilinear, use the normal convolutions to reduce the number of channels
232 | if bilinear:
233 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
234 | else:
235 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
236 |
237 | self.conv = DoubleConv(in_channels, out_channels)
238 |
239 | def forward(self, x1, x2):
240 | x1 = self.up(x1)
241 | # input is CHW
242 | diffY = x2.size()[2] - x1.size()[2]
243 | diffX = x2.size()[3] - x1.size()[3]
244 |
245 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
246 | diffY // 2, diffY - diffY // 2])
247 | # if you have padding issues, see
248 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
249 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
250 | x = torch.cat([x2, x1], dim=1)
251 | return self.conv(x)
252 |
253 |
254 | class OutConv(nn.Module):
255 | def __init__(self, in_channels, out_channels):
256 | super(OutConv, self).__init__()
257 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
258 |
259 | def forward(self, x):
260 | return self.conv(x)
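For orientation, the blocks above (DoubleConv, Down, Up, OutConv) are the standard U-Net building pieces. Below is a minimal sketch of how they typically compose into an encoder-decoder; it is illustrative only and not necessarily how this repo wires them:

```python
import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    """Toy composition of the blocks defined above (assumes the default bilinear upsampling)."""
    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 128)         # keep 128 channels so the Up concatenations line up
        self.up1 = Up(256, 64)              # cat(128 skip, 128 upsampled) -> 256 in, 64 out
        self.up2 = Up(128, 64)              # cat(64 skip, 64 upsampled) -> 128 in, 64 out
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)                    # [B, 64, H, W]
        x2 = self.down1(x1)                 # [B, 128, H/2, W/2]
        x3 = self.down2(x2)                 # [B, 128, H/4, W/4]
        y = self.up1(x3, x2)                # [B, 64, H/2, W/2]
        y = self.up2(y, x1)                 # [B, 64, H, W]
        return self.outc(y)                 # [B, n_classes, H, W]

# Example: TinyUNet()(torch.randn(1, 3, 64, 64)).shape == (1, 1, 64, 64)
```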
--------------------------------------------------------------------------------
/decalib/utils/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | Default config for DECA
3 | '''
4 | from yacs.config import CfgNode as CN
5 | import argparse
6 | import yaml
7 | import os
8 |
9 | cfg = CN()
10 |
11 | abs_deca_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'utils'))
12 |
13 | cfg.deca_dir = abs_deca_dir
14 | cfg.device = 'cuda'
15 | cfg.device_id = '0'
16 |
17 | cfg.pretrained_modelpath = os.path.join(cfg.deca_dir, 'data', 'deca_model.tar')
18 | cfg.output_dir = ''
19 | cfg.rasterizer_type = 'pytorch3d'
20 | # ---------------------------------------------------------------------------- #
21 | # Options for Face model
22 | # ---------------------------------------------------------------------------- #
23 | cfg.model = CN()
24 | cfg.model.topology_path = os.path.join(cfg.deca_dir, 'data', 'head_template.obj')
25 | # texture data original from http://files.is.tue.mpg.de/tbolkart/FLAME/FLAME_texture_data.zip
26 | cfg.model.dense_template_path = os.path.join(cfg.deca_dir, 'data', 'texture_data_256.npy')
27 | cfg.model.fixed_displacement_path = os.path.join(cfg.deca_dir, 'data', 'fixed_displacement_256.npy')
28 | cfg.model.flame_model_path = os.path.join(cfg.deca_dir, 'data', 'generic_model.pkl')
29 | cfg.model.flame_lmk_embedding_path = os.path.join(cfg.deca_dir, 'data', 'landmark_embedding.npy')
30 | cfg.model.face_mask_path = os.path.join(cfg.deca_dir, 'data', 'uv_face_mask.png')
31 | cfg.model.face_eye_mask_path = os.path.join(cfg.deca_dir, 'data', 'uv_face_eye_mask.png')
32 | cfg.model.mean_tex_path = os.path.join(cfg.deca_dir, 'data', 'mean_texture.jpg')
33 | cfg.model.tex_path = os.path.join(cfg.deca_dir, 'data', 'FLAME_albedo_from_BFM.npz')
34 | cfg.model.tex_type = 'BFM' # BFM, FLAME, albedoMM
35 | cfg.model.uv_size = 256
36 | cfg.model.param_list = ['shape', 'tex', 'exp', 'pose', 'cam', 'light']
37 | cfg.model.n_shape = 100
38 | cfg.model.n_tex = 50
39 | cfg.model.n_exp = 50
40 | cfg.model.n_cam = 3
41 | cfg.model.n_pose = 6
42 | cfg.model.n_light = 27
43 | cfg.model.use_tex = True
44 | cfg.model.jaw_type = 'aa' # default: axis-angle ('aa'); alternative: 'euler'. Note that 'aa' is unstable early in training
45 | # face recognition model
46 | cfg.model.fr_model_path = os.path.join(cfg.deca_dir, 'data', 'resnet50_ft_weight.pkl')
47 |
48 | ## details
49 | cfg.model.n_detail = 128
50 | cfg.model.max_z = 0.01
51 |
52 | # ---------------------------------------------------------------------------- #
53 | # Options for Dataset
54 | # ---------------------------------------------------------------------------- #
55 | cfg.dataset = CN()
56 | cfg.dataset.training_data = ['vggface2', 'ethnicity']
57 | # cfg.dataset.training_data = ['ethnicity']
58 | cfg.dataset.eval_data = ['aflw2000']
59 | cfg.dataset.test_data = ['']
60 | cfg.dataset.batch_size = 2
61 | cfg.dataset.K = 4
62 | cfg.dataset.isSingle = False
63 | cfg.dataset.num_workers = 2
64 | cfg.dataset.image_size = 224
65 | cfg.dataset.scale_min = 1.4
66 | cfg.dataset.scale_max = 1.8
67 | cfg.dataset.trans_scale = 0.
68 |
69 | # ---------------------------------------------------------------------------- #
70 | # Options for training
71 | # ---------------------------------------------------------------------------- #
72 | cfg.train = CN()
73 | cfg.train.train_detail = False
74 | cfg.train.max_epochs = 500
75 | cfg.train.max_steps = 1000000
76 | cfg.train.lr = 1e-4
77 | cfg.train.log_dir = 'logs'
78 | cfg.train.log_steps = 10
79 | cfg.train.vis_dir = 'train_images'
80 | cfg.train.vis_steps = 200
81 | cfg.train.write_summary = True
82 | cfg.train.checkpoint_steps = 500
83 | cfg.train.val_steps = 500
84 | cfg.train.val_vis_dir = 'val_images'
85 | cfg.train.eval_steps = 5000
86 | cfg.train.resume = True
87 |
88 | # ---------------------------------------------------------------------------- #
89 | # Options for Losses
90 | # ---------------------------------------------------------------------------- #
91 | cfg.loss = CN()
92 | cfg.loss.lmk = 1.0
93 | cfg.loss.useWlmk = True
94 | cfg.loss.eyed = 1.0
95 | cfg.loss.lipd = 0.5
96 | cfg.loss.photo = 2.0
97 | cfg.loss.useSeg = True
98 | cfg.loss.id = 0.2
99 | cfg.loss.id_shape_only = True
100 | cfg.loss.reg_shape = 1e-04
101 | cfg.loss.reg_exp = 1e-04
102 | cfg.loss.reg_tex = 1e-04
103 | cfg.loss.reg_light = 1.
104 | cfg.loss.reg_jaw_pose = 0. #1.
105 | cfg.loss.use_gender_prior = False
106 | cfg.loss.shape_consistency = True
107 | # loss for detail
108 | cfg.loss.detail_consistency = True
109 | cfg.loss.useConstraint = True
110 | cfg.loss.mrf = 5e-2
111 | cfg.loss.photo_D = 2.
112 | cfg.loss.reg_sym = 0.005
113 | cfg.loss.reg_z = 0.005
114 | cfg.loss.reg_diff = 0.005
115 |
116 |
117 | def get_cfg_defaults():
118 | """Get a yacs CfgNode object with default values for my_project."""
119 | # Return a clone so that the defaults will not be altered
120 | # This is for the "local variable" use pattern
121 | return cfg.clone()
122 |
123 | def update_cfg(cfg, cfg_file):
124 | cfg.merge_from_file(cfg_file)
125 | return cfg.clone()
126 |
127 | def parse_args():
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument('--cfg', type=str, help='cfg file path')
130 | parser.add_argument('--mode', type=str, default = 'train', help='deca mode')
131 |
132 | args = parser.parse_args()
133 | print(args, end='\n\n')
134 |
135 | cfg = get_cfg_defaults()
136 | cfg.cfg_file = None
137 | cfg.mode = args.mode
138 | # import ipdb; ipdb.set_trace()
139 | if args.cfg is not None:
140 | cfg_file = args.cfg
141 | cfg = update_cfg(cfg, args.cfg)
142 | cfg.cfg_file = cfg_file
143 |
144 | return cfg
145 |
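For reference, a minimal usage sketch of this config module; the YAML path below is hypothetical:

```python
from decalib.utils.config import get_cfg_defaults, update_cfg

cfg = get_cfg_defaults()                    # clone of the defaults defined above
cfg.rasterizer_type = 'standard'            # override a single field in code
# cfg = update_cfg(cfg, 'my_deca_config.yml')   # or merge overrides from a YAML file (hypothetical path)
print(cfg.model.n_shape, cfg.dataset.image_size)  # 100 224
```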
--------------------------------------------------------------------------------
/decalib/utils/rasterizer/INSTALL.md:
--------------------------------------------------------------------------------
1 | ## Install
2 | To use it, import `from standard_rasterize_cuda import standard_rasterize`
3 | (or `from .rasterizer.standard_rasterize_cuda import standard_rasterize` inside the package).
4 |
5 | In this folder, run
6 | ```python setup.py build_ext -i```
7 |
8 | then remember to set `--rasterizer_type=standard` when running demos :)
9 |
10 | ## Algorithm
11 | https://www.scratchapixel.com/lessons/3d-basic-rendering/rasterization-practical-implementation
12 |
13 | ## Speed Comparison
14 | Runtime for rasterization only (in PIXIE, number of faces in SMPLX: 20908):
15 |
16 | | image size | pytorch3d | standard |
17 | |-----------:|----------:|---------:|
18 | |       1024 |   0.031 s |   0.01 s |
19 | |        224 |  0.0035 s | 0.0014 s |
24 |
25 | Why is the standard rasterizer faster than pytorch3d?
26 | Ref: https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu
27 | pytorch3d: for each pixel in image space (pixels are parallel in CUDA), loop over the faces, check whether the pixel lies inside a face's projected bounding box, then sort the faces by z and record the face ids of the closest K faces.
28 | standard rasterization: for each face in the mesh (faces are parallel in CUDA), loop over the pixels inside its projected bounding box (normally a very small number), compare z, and record the face id for each covered pixel.
29 |
30 |
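For intuition, a rough CPU-side sketch of the per-face strategy described above (illustrative Python only; the real implementation is the CUDA kernel in this folder):

```python
import numpy as np

def barycentric(p, tri):
    # Same construction as the CUDA kernel's barycentric_weight.
    v0, v1, v2 = tri[2] - tri[0], tri[1] - tri[0], p - tri[0]
    d00, d01, d02, d11, d12 = v0 @ v0, v0 @ v1, v0 @ v2, v1 @ v1, v1 @ v2
    denom = d00 * d11 - d01 * d01
    inv = 0.0 if denom == 0 else 1.0 / denom
    u = (d11 * d02 - d01 * d12) * inv
    v = (d00 * d12 - d01 * d02) * inv
    return np.array([1 - u - v, v, u])

def rasterize_per_face(tri_xy, tri_z, h, w):
    """Loop over faces, then over pixels inside each face's bounding box (the 'standard' strategy)."""
    depth = np.full((h, w), np.inf)
    face_id = np.full((h, w), -1, dtype=np.int64)
    for f, (pts, z) in enumerate(zip(tri_xy, tri_z)):        # pts: (3, 2) screen coords, z: (3,)
        x0, y0 = np.ceil(pts.min(axis=0)).astype(int)
        x1, y1 = np.floor(pts.max(axis=0)).astype(int)
        for y in range(max(y0, 0), min(y1, h - 1) + 1):
            for x in range(max(x0, 0), min(x1, w - 1) + 1):
                bw = barycentric(np.array([x, y], dtype=float), pts)
                if (bw >= 0).all():
                    zp = 1.0 / (bw / z).sum()                # perspective-correct depth, as in the kernel
                    if zp < depth[y, x]:
                        depth[y, x], face_id[y, x] = zp, f
    return depth, face_id
```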
--------------------------------------------------------------------------------
/decalib/utils/rasterizer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/decalib/utils/rasterizer/__init__.py
--------------------------------------------------------------------------------
/decalib/utils/rasterizer/setup.py:
--------------------------------------------------------------------------------
1 | # To install, run
2 | # python setup.py build_ext -i
3 | # Ref: https://github.com/pytorch/pytorch/blob/11a40410e755b1fe74efe9eaa635e7ba5712846b/test/cpp_extensions/setup.py#L62
4 |
5 | from setuptools import setup
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 | import os
8 |
9 | # USE_NINJA = os.getenv('USE_NINJA') == '1'
10 | os.environ["CC"] = "gcc-7"
11 | os.environ["CXX"] = "gcc-7"
12 |
13 | USE_NINJA = os.getenv('USE_NINJA') == '1'
14 |
15 | setup(
16 | name='standard_rasterize_cuda',
17 | ext_modules=[
18 | CUDAExtension('standard_rasterize_cuda', [
19 | 'standard_rasterize_cuda.cpp',
20 | 'standard_rasterize_cuda_kernel.cu',
21 | ])
22 | ],
23 | cmdclass={'build_ext': BuildExtension.with_options(use_ninja=USE_NINJA)}
24 | )
25 |
--------------------------------------------------------------------------------
/decalib/utils/rasterizer/standard_rasterize_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 | #include <iostream>
4 |
5 | std::vector<at::Tensor> forward_rasterize_cuda(
6 | at::Tensor face_vertices,
7 | at::Tensor depth_buffer,
8 | at::Tensor triangle_buffer,
9 | at::Tensor baryw_buffer,
10 | int h,
11 | int w);
12 |
13 | std::vector<at::Tensor> standard_rasterize(
14 | at::Tensor face_vertices,
15 | at::Tensor depth_buffer,
16 | at::Tensor triangle_buffer,
17 | at::Tensor baryw_buffer,
18 | int height, int width
19 | ) {
20 | return forward_rasterize_cuda(face_vertices, depth_buffer, triangle_buffer, baryw_buffer, height, width);
21 | }
22 |
23 | std::vector<at::Tensor> forward_rasterize_colors_cuda(
24 | at::Tensor face_vertices,
25 | at::Tensor face_colors,
26 | at::Tensor depth_buffer,
27 | at::Tensor triangle_buffer,
28 | at::Tensor images,
29 | int h,
30 | int w);
31 |
32 | std::vector<at::Tensor> standard_rasterize_colors(
33 | at::Tensor face_vertices,
34 | at::Tensor face_colors,
35 | at::Tensor depth_buffer,
36 | at::Tensor triangle_buffer,
37 | at::Tensor images,
38 | int height, int width
39 | ) {
40 | return forward_rasterize_colors_cuda(face_vertices, face_colors, depth_buffer, triangle_buffer, images, height, width);
41 | }
42 |
43 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
44 | m.def("standard_rasterize", &standard_rasterize, "RASTERIZE (CUDA)");
45 | m.def("standard_rasterize_colors", &standard_rasterize_colors, "RASTERIZE COLORS (CUDA)");
46 | }
47 |
48 | // TODO: backward
--------------------------------------------------------------------------------
/decalib/utils/rasterizer/standard_rasterize_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | // Ref: https://github.com/daniilidis-group/neural_renderer/blob/master/neural_renderer/cuda/rasterize_cuda_kernel.cu
2 | // https://github.com/YadiraF/face3d/blob/master/face3d/mesh/cython/mesh_core.cpp
3 |
4 | #include <ATen/ATen.h>
5 |
6 | #include <cuda.h>
7 | #include <cuda_runtime.h>
8 |
9 | namespace{
10 | __device__ __forceinline__ float atomicMin(float* address, float val)
11 | {
12 | int* address_as_i = (int*) address;
13 | int old = *address_as_i, assumed;
14 | do {
15 | assumed = old;
16 | old = atomicCAS(address_as_i, assumed,
17 | __float_as_int(fminf(val, __int_as_float(assumed))));
18 | } while (assumed != old);
19 | return __int_as_float(old);
20 | }
21 | __device__ __forceinline__ double atomicMin(double* address, double val)
22 | {
23 | unsigned long long int* address_as_i = (unsigned long long int*) address;
24 | unsigned long long int old = *address_as_i, assumed;
25 | do {
26 | assumed = old;
27 | old = atomicCAS(address_as_i, assumed,
28 | __double_as_longlong(fminf(val, __longlong_as_double(assumed))));
29 | } while (assumed != old);
30 | return __longlong_as_double(old);
31 | }
32 |
33 | template <typename scalar_t>
34 | __device__ __forceinline__ bool check_face_frontside(const scalar_t *face) {
35 | return (face[7] - face[1]) * (face[3] - face[0]) < (face[4] - face[1]) * (face[6] - face[0]);
36 | }
37 |
38 |
39 | template <typename scalar_t> struct point
40 | {
41 | public:
42 | scalar_t x;
43 | scalar_t y;
44 |
45 | __host__ __device__ scalar_t dot(point p)
46 | {
47 | return this->x * p.x + this->y * p.y;
48 | };
49 |
50 | __host__ __device__ point operator-(point& p)
51 | {
52 | point np;
53 | np.x = this->x - p.x;
54 | np.y = this->y - p.y;
55 | return np;
56 | };
57 |
58 | __host__ __device__ point operator+(point& p)
59 | {
60 | point np;
61 | np.x = this->x + p.x;
62 | np.y = this->y + p.y;
63 | return np;
64 | };
65 |
66 | __host__ __device__ point operator*(scalar_t s)
67 | {
68 | point np;
69 | np.x = s * this->x;
70 | np.y = s * this->y;
71 | return np;
72 | };
73 | };
74 |
75 | template <typename scalar_t>
76 | __device__ __forceinline__ bool check_pixel_inside(const scalar_t *w) {
77 | return w[0] <= 1 && w[0] >= 0 && w[1] <= 1 && w[1] >= 0 && w[2] <= 1 && w[2] >= 0;
78 | }
79 |
80 | template <typename scalar_t>
81 | __device__ __forceinline__ void barycentric_weight(scalar_t *w, point<scalar_t> p, point<scalar_t> p0, point<scalar_t> p1, point<scalar_t> p2) {
82 |
83 | // vectors
84 | point<scalar_t> v0, v1, v2;
85 | scalar_t s = p.dot(p);
86 | v0 = p2 - p0;
87 | v1 = p1 - p0;
88 | v2 = p - p0;
89 |
90 | // dot products
91 | scalar_t dot00 = v0.dot(v0); //v0.x * v0.x + v0.y * v0.y //np.dot(v0.T, v0)
92 | scalar_t dot01 = v0.dot(v1); //v0.x * v1.x + v0.y * v1.y //np.dot(v0.T, v1)
93 | scalar_t dot02 = v0.dot(v2); //v0.x * v2.x + v0.y * v2.y //np.dot(v0.T, v2)
94 | scalar_t dot11 = v1.dot(v1); //v1.x * v1.x + v1.y * v1.y //np.dot(v1.T, v1)
95 | scalar_t dot12 = v1.dot(v2); //v1.x * v2.x + v1.y * v2.y//np.dot(v1.T, v2)
96 |
97 | // barycentric coordinates
98 | scalar_t inverDeno;
99 | if(dot00*dot11 - dot01*dot01 == 0)
100 | inverDeno = 0;
101 | else
102 | inverDeno = 1/(dot00*dot11 - dot01*dot01);
103 |
104 | scalar_t u = (dot11*dot02 - dot01*dot12)*inverDeno;
105 | scalar_t v = (dot00*dot12 - dot01*dot02)*inverDeno;
106 |
107 | // weight
108 | w[0] = 1 - u - v;
109 | w[1] = v;
110 | w[2] = u;
111 | }
112 |
113 | // Ref: https://www.scratchapixel.com/lessons/3d-basic-rendering/rasterization-practical-implementation/overview-rasterization-algorithm
114 | template <typename scalar_t>
115 | __global__ void forward_rasterize_cuda_kernel(
116 | const scalar_t* __restrict__ face_vertices, //[bz, nf, 3, 3]
117 | scalar_t* depth_buffer,
118 | int* triangle_buffer,
119 | scalar_t* baryw_buffer,
120 | int batch_size, int h, int w,
121 | int ntri) {
122 |
123 | const int i = blockIdx.x * blockDim.x + threadIdx.x;
124 | if (i >= batch_size * ntri) {
125 | return;
126 | }
127 | int bn = i/ntri;
128 | const scalar_t* face = &face_vertices[i * 9];
129 | scalar_t bw[3];
130 | point<scalar_t> p0, p1, p2, p;
131 |
132 | p0.x = face[0]; p0.y=face[1];
133 | p1.x = face[3]; p1.y=face[4];
134 | p2.x = face[6]; p2.y=face[7];
135 |
136 | int x_min = max((int)ceil(min(p0.x, min(p1.x, p2.x))), 0);
137 | int x_max = min((int)floor(max(p0.x, max(p1.x, p2.x))), w - 1);
138 | int y_min = max((int)ceil(min(p0.y, min(p1.y, p2.y))), 0);
139 | int y_max = min((int)floor(max(p0.y, max(p1.y, p2.y))), h - 1);
140 |
141 | for(int y = y_min; y <= y_max; y++) //h
142 | {
143 | for(int x = x_min; x <= x_max; x++) //w
144 | {
145 | p.x = x; p.y = y;
146 | barycentric_weight(bw, p, p0, p1, p2);
147 | // if(((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0)) && check_face_frontside(face))
148 | if((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0))
149 | {
150 | // perspective correct: https://www.scratchapixel.com/lessons/3d-basic-rendering/rasterization-practical-implementation/perspective-correct-interpolation-vertex-attributes
151 | scalar_t zp = 1. / (bw[0] / face[2] + bw[1] / face[5] + bw[2] / face[8]);
152 | // printf("%f %f %f \n", (float)zp, (float)face[2], (float)bw[2]);
153 | atomicMin(&depth_buffer[bn*h*w + y*w + x], zp);
154 | if(depth_buffer[bn*h*w + y*w + x] == zp)
155 | {
156 | triangle_buffer[bn*h*w + y*w + x] = (int)(i%ntri);
157 | for(int k=0; k<3; k++){
158 | baryw_buffer[bn*h*w*3 + y*w*3 + x*3 + k] = bw[k];
159 | }
160 | }
161 | }
162 | }
163 | }
164 |
165 | }
166 |
167 | template <typename scalar_t>
168 | __global__ void forward_rasterize_colors_cuda_kernel(
169 | const scalar_t* __restrict__ face_vertices, //[bz, nf, 3, 3]
170 | const scalar_t* __restrict__ face_colors, //[bz, nf, 3, 3]
171 | scalar_t* depth_buffer,
172 | int* triangle_buffer,
173 | scalar_t* images,
174 | int batch_size, int h, int w,
175 | int ntri) {
176 | const int i = blockIdx.x * blockDim.x + threadIdx.x;
177 | if (i >= batch_size * ntri) {
178 | return;
179 | }
180 | int bn = i/ntri;
181 | const scalar_t* face = &face_vertices[i * 9];
182 | const scalar_t* color = &face_colors[i * 9];
183 | scalar_t bw[3];
184 | point<scalar_t> p0, p1, p2, p;
185 |
186 | p0.x = face[0]; p0.y=face[1];
187 | p1.x = face[3]; p1.y=face[4];
188 | p2.x = face[6]; p2.y=face[7];
189 | scalar_t cl[3][3];
190 | for (int num = 0; num < 3; num++) {
191 | for (int dim = 0; dim < 3; dim++) {
192 | cl[num][dim] = color[3 * num + dim]; //[3p,3rgb]
193 | }
194 | }
195 | int x_min = max((int)ceil(min(p0.x, min(p1.x, p2.x))), 0);
196 | int x_max = min((int)floor(max(p0.x, max(p1.x, p2.x))), w - 1);
197 | int y_min = max((int)ceil(min(p0.y, min(p1.y, p2.y))), 0);
198 | int y_max = min((int)floor(max(p0.y, max(p1.y, p2.y))), h - 1);
199 |
200 | for(int y = y_min; y <= y_max; y++) //h
201 | {
202 | for(int x = x_min; x <= x_max; x++) //w
203 | {
204 | p.x = x; p.y = y;
205 | barycentric_weight(bw, p, p0, p1, p2);
206 | if(((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0)) && check_face_frontside(face))
207 | // if((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0))
208 | {
209 | scalar_t zp = 1. / (bw[0] / face[2] + bw[1] / face[5] + bw[2] / face[8]);
210 |
211 | atomicMin(&depth_buffer[bn*h*w + y*w + x], zp);
212 | if(depth_buffer[bn*h*w + y*w + x] == zp)
213 | {
214 | triangle_buffer[bn*h*w + y*w + x] = (int)(i%ntri);
215 | for(int k=0; k<3; k++){
216 | // baryw_buffer[bn*h*w*3 + y*w*3 + x*3 + k] = bw[k];
217 | images[bn*h*w*3 + y*w*3 + x*3 + k] = bw[0]*cl[0][k] + bw[1]*cl[1][k] + bw[2]*cl[2][k];
218 | }
219 | // buffers[bn*h*w*2 + y*w*2 + x*2 + 1] = p_depth;
220 | }
221 | }
222 | }
223 | }
224 |
225 | }
226 |
227 | }
228 |
229 | std::vector<at::Tensor> forward_rasterize_cuda(
230 | at::Tensor face_vertices,
231 | at::Tensor depth_buffer,
232 | at::Tensor triangle_buffer,
233 | at::Tensor baryw_buffer,
234 | int h,
235 | int w){
236 |
237 | const auto batch_size = face_vertices.size(0);
238 | const auto ntri = face_vertices.size(1);
239 |
240 | // print(channel_size)
241 | const int threads = 512;
242 | const dim3 blocks_1 ((batch_size * ntri - 1) / threads +1);
243 |
244 | AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_cuda1", ([&] {
245 | forward_rasterize_cuda_kernel<scalar_t><<<blocks_1, threads>>>(
246 | face_vertices.data<scalar_t>(),
247 | depth_buffer.data<scalar_t>(),
248 | triangle_buffer.data<int>(),
249 | baryw_buffer.data<scalar_t>(),
250 | batch_size, h, w,
251 | ntri);
252 | }));
253 |
254 | // better to do it twice (or there will be black spots in the rendering)
255 | AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_cuda2", ([&] {
256 | forward_rasterize_cuda_kernel<scalar_t><<<blocks_1, threads>>>(
257 | face_vertices.data<scalar_t>(),
258 | depth_buffer.data<scalar_t>(),
259 | triangle_buffer.data<int>(),
260 | baryw_buffer.data<scalar_t>(),
261 | batch_size, h, w,
262 | ntri);
263 | }));
264 | cudaError_t err = cudaGetLastError();
265 | if (err != cudaSuccess)
266 | printf("Error in forward_rasterize_cuda_kernel: %s\n", cudaGetErrorString(err));
267 |
268 | return {depth_buffer, triangle_buffer, baryw_buffer};
269 | }
270 |
271 |
272 | std::vector<at::Tensor> forward_rasterize_colors_cuda(
273 | at::Tensor face_vertices,
274 | at::Tensor face_colors,
275 | at::Tensor depth_buffer,
276 | at::Tensor triangle_buffer,
277 | at::Tensor images,
278 | int h,
279 | int w){
280 |
281 | const auto batch_size = face_vertices.size(0);
282 | const auto ntri = face_vertices.size(1);
283 |
284 | // print(channel_size)
285 | const int threads = 512;
286 | const dim3 blocks_1 ((batch_size * ntri - 1) / threads +1);
287 | //initial
288 |
289 | AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_colors_cuda", ([&] {
290 | forward_rasterize_colors_cuda_kernel<scalar_t><<<blocks_1, threads>>>(
291 | face_vertices.data<scalar_t>(),
292 | face_colors.data<scalar_t>(),
293 | depth_buffer.data<scalar_t>(),
294 | triangle_buffer.data<int>(),
295 | images.data<scalar_t>(),
296 | batch_size, h, w,
297 | ntri);
298 | }));
299 | // better to do it twice
300 | // AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_colors_cuda", ([&] {
301 | // forward_rasterize_colors_cuda_kernel<<>>(
302 | // face_vertices.data(),
303 | // face_colors.data(),
304 | // depth_buffer.data(),
305 | // triangle_buffer.data(),
306 | // images.data(),
307 | // batch_size, h, w,
308 | // ntri);
309 | // }));
310 | cudaError_t err = cudaGetLastError();
311 | if (err != cudaSuccess)
312 | printf("Error in forward_rasterize_cuda_kernel: %s\n", cudaGetErrorString(err));
313 |
314 | return {depth_buffer, triangle_buffer, images};
315 | }
316 |
317 |
318 |
319 |
320 |
--------------------------------------------------------------------------------
/decalib/utils/rotation_converter.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | ''' Rotation Converter
4 | Repre: euler angle(3), angle axis(3), rotation matrix(3x3), quaternion(4)
5 | ref: https://kornia.readthedocs.io/en/v0.1.2/_modules/torchgeometry/core/conversions.html#
6 | "pi",
7 | "rad2deg",
8 | "deg2rad",
9 | # "angle_axis_to_rotation_matrix", batch_rodrigues
10 | "rotation_matrix_to_angle_axis",
11 | "rotation_matrix_to_quaternion",
12 | "quaternion_to_angle_axis",
13 | # "angle_axis_to_quaternion",
14 |
15 | euler2quat_conversion_sanity_batch
16 |
17 | ref: smplx/lbs
18 | batch_rodrigues: axis angle -> matrix
19 | #
20 | '''
21 | pi = torch.Tensor([3.14159265358979323846])
22 |
23 | def rad2deg(tensor):
24 | """Function that converts angles from radians to degrees.
25 |
26 | See :class:`~torchgeometry.RadToDeg` for details.
27 |
28 | Args:
29 | tensor (Tensor): Tensor of arbitrary shape.
30 |
31 | Returns:
32 | Tensor: Tensor with same shape as input.
33 |
34 | Example:
35 | >>> input = tgm.pi * torch.rand(1, 3, 3)
36 | >>> output = tgm.rad2deg(input)
37 | """
38 | if not torch.is_tensor(tensor):
39 | raise TypeError("Input type is not a torch.Tensor. Got {}"
40 | .format(type(tensor)))
41 |
42 | return 180. * tensor / pi.to(tensor.device).type(tensor.dtype)
43 |
44 | def deg2rad(tensor):
45 | """Function that converts angles from degrees to radians.
46 |
47 | See :class:`~torchgeometry.DegToRad` for details.
48 |
49 | Args:
50 | tensor (Tensor): Tensor of arbitrary shape.
51 |
52 | Returns:
53 | Tensor: Tensor with same shape as input.
54 |
55 | Examples::
56 |
57 | >>> input = 360. * torch.rand(1, 3, 3)
58 | >>> output = tgm.deg2rad(input)
59 | """
60 | if not torch.is_tensor(tensor):
61 | raise TypeError("Input type is not a torch.Tensor. Got {}"
62 | .format(type(tensor)))
63 |
64 | return tensor * pi.to(tensor.device).type(tensor.dtype) / 180.
65 |
66 | ######### to quaternion
67 | def euler_to_quaternion(r):
68 | x = r[..., 0]
69 | y = r[..., 1]
70 | z = r[..., 2]
71 |
72 | z = z/2.0
73 | y = y/2.0
74 | x = x/2.0
75 | cz = torch.cos(z)
76 | sz = torch.sin(z)
77 | cy = torch.cos(y)
78 | sy = torch.sin(y)
79 | cx = torch.cos(x)
80 | sx = torch.sin(x)
81 | quaternion = torch.zeros_like(r.repeat(1,2))[..., :4].to(r.device)
82 | quaternion[..., 0] += cx*cy*cz - sx*sy*sz
83 | quaternion[..., 1] += cx*sy*sz + cy*cz*sx
84 | quaternion[..., 2] += cx*cz*sy - sx*cy*sz
85 | quaternion[..., 3] += cx*cy*sz + sx*cz*sy
86 | return quaternion
87 |
88 | def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
89 | """Convert 3x4 rotation matrix to 4d quaternion vector
90 |
91 | This algorithm is based on algorithm described in
92 | https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
93 |
94 | Args:
95 | rotation_matrix (Tensor): the rotation matrix to convert.
96 |
97 | Return:
98 | Tensor: the rotation in quaternion
99 |
100 | Shape:
101 | - Input: :math:`(N, 3, 4)`
102 | - Output: :math:`(N, 4)`
103 |
104 | Example:
105 | >>> input = torch.rand(4, 3, 4) # Nx3x4
106 | >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
107 | """
108 | if not torch.is_tensor(rotation_matrix):
109 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
110 | type(rotation_matrix)))
111 |
112 | if len(rotation_matrix.shape) > 3:
113 | raise ValueError(
114 | "Input size must be a three dimensional tensor. Got {}".format(
115 | rotation_matrix.shape))
116 | # if not rotation_matrix.shape[-2:] == (3, 4):
117 | # raise ValueError(
118 | # "Input size must be a N x 3 x 4 tensor. Got {}".format(
119 | # rotation_matrix.shape))
120 |
121 | rmat_t = torch.transpose(rotation_matrix, 1, 2)
122 |
123 | mask_d2 = rmat_t[:, 2, 2] < eps
124 |
125 | mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
126 | mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
127 |
128 | t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
129 | q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
130 | t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
131 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
132 | t0_rep = t0.repeat(4, 1).t()
133 |
134 | t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
135 | q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
136 | rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
137 | t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
138 | t1_rep = t1.repeat(4, 1).t()
139 |
140 | t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
141 | q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
142 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
143 | rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
144 | t2_rep = t2.repeat(4, 1).t()
145 |
146 | t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
147 | q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
148 | rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
149 | rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
150 | t3_rep = t3.repeat(4, 1).t()
151 |
152 | mask_c0 = mask_d2 * mask_d0_d1.float()
153 | mask_c1 = mask_d2 * (1 - mask_d0_d1.float())
154 | mask_c2 = (1 - mask_d2.float()) * mask_d0_nd1
155 | mask_c3 = (1 - mask_d2.float()) * (1 - mask_d0_nd1.float())
156 | mask_c0 = mask_c0.view(-1, 1).type_as(q0)
157 | mask_c1 = mask_c1.view(-1, 1).type_as(q1)
158 | mask_c2 = mask_c2.view(-1, 1).type_as(q2)
159 | mask_c3 = mask_c3.view(-1, 1).type_as(q3)
160 |
161 | q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
162 | q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
163 | t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
164 | q *= 0.5
165 | return q
166 |
167 | # def angle_axis_to_quaternion(theta):
168 | # batch_size = theta.shape[0]
169 | # l1norm = torch.norm(theta + 1e-8, p=2, dim=1)
170 | # angle = torch.unsqueeze(l1norm, -1)
171 | # normalized = torch.div(theta, angle)
172 | # angle = angle * 0.5
173 | # v_cos = torch.cos(angle)
174 | # v_sin = torch.sin(angle)
175 | # quat = torch.cat([v_cos, v_sin * normalized], dim=1)
176 | # return quat
177 |
178 | def angle_axis_to_quaternion(angle_axis: torch.Tensor) -> torch.Tensor:
179 | """Convert an angle axis to a quaternion.
180 |
181 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
182 |
183 | Args:
184 | angle_axis (torch.Tensor): tensor with angle axis.
185 |
186 | Return:
187 | torch.Tensor: tensor with quaternion.
188 |
189 | Shape:
190 | - Input: :math:`(*, 3)` where `*` means, any number of dimensions
191 | - Output: :math:`(*, 4)`
192 |
193 | Example:
194 | >>> angle_axis = torch.rand(2, 3) # Nx3
195 | >>> quaternion = tgm.angle_axis_to_quaternion(angle_axis) # Nx4
196 | """
197 | if not torch.is_tensor(angle_axis):
198 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
199 | type(angle_axis)))
200 |
201 | if not angle_axis.shape[-1] == 3:
202 | raise ValueError("Input must be a tensor of shape Nx3 or 3. Got {}"
203 | .format(angle_axis.shape))
204 | # unpack input and compute conversion
205 | a0: torch.Tensor = angle_axis[..., 0:1]
206 | a1: torch.Tensor = angle_axis[..., 1:2]
207 | a2: torch.Tensor = angle_axis[..., 2:3]
208 | theta_squared: torch.Tensor = a0 * a0 + a1 * a1 + a2 * a2
209 |
210 | theta: torch.Tensor = torch.sqrt(theta_squared)
211 | half_theta: torch.Tensor = theta * 0.5
212 |
213 | mask: torch.Tensor = theta_squared > 0.0
214 | ones: torch.Tensor = torch.ones_like(half_theta)
215 |
216 | k_neg: torch.Tensor = 0.5 * ones
217 | k_pos: torch.Tensor = torch.sin(half_theta) / theta
218 | k: torch.Tensor = torch.where(mask, k_pos, k_neg)
219 | w: torch.Tensor = torch.where(mask, torch.cos(half_theta), ones)
220 |
221 | quaternion: torch.Tensor = torch.zeros_like(angle_axis)
222 | quaternion[..., 0:1] += a0 * k
223 | quaternion[..., 1:2] += a1 * k
224 | quaternion[..., 2:3] += a2 * k
225 | return torch.cat([w, quaternion], dim=-1)
226 |
227 | #### quaternion to
228 | def quaternion_to_rotation_matrix(quat):
229 | """Convert quaternion coefficients to rotation matrix.
230 | Args:
231 | quat: size = [B, 4] 4 <===>(w, x, y, z)
232 | Returns:
233 | Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
234 | """
235 | norm_quat = quat
236 | norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
237 | w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
238 |
239 | B = quat.size(0)
240 |
241 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
242 | wx, wy, wz = w * x, w * y, w * z
243 | xy, xz, yz = x * y, x * z, y * z
244 |
245 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
246 | 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
247 | 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
248 | return rotMat
249 |
250 | def quaternion_to_angle_axis(quaternion: torch.Tensor):
251 | """Convert quaternion vector to angle axis of rotation. TODO: CORRECT
252 |
253 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
254 |
255 | Args:
256 | quaternion (torch.Tensor): tensor with quaternions.
257 |
258 | Return:
259 | torch.Tensor: tensor with angle axis of rotation.
260 |
261 | Shape:
262 | - Input: :math:`(*, 4)` where `*` means, any number of dimensions
263 | - Output: :math:`(*, 3)`
264 |
265 | Example:
266 | >>> quaternion = torch.rand(2, 4) # Nx4
267 | >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
268 | """
269 | if not torch.is_tensor(quaternion):
270 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
271 | type(quaternion)))
272 |
273 | if not quaternion.shape[-1] == 4:
274 | raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
275 | .format(quaternion.shape))
276 | # unpack input and compute conversion
277 | q1: torch.Tensor = quaternion[..., 1]
278 | q2: torch.Tensor = quaternion[..., 2]
279 | q3: torch.Tensor = quaternion[..., 3]
280 | sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
281 |
282 | sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
283 | cos_theta: torch.Tensor = quaternion[..., 0]
284 | two_theta: torch.Tensor = 2.0 * torch.where(
285 | cos_theta < 0.0,
286 | torch.atan2(-sin_theta, -cos_theta),
287 | torch.atan2(sin_theta, cos_theta))
288 |
289 | k_pos: torch.Tensor = two_theta / sin_theta
290 | k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta).to(quaternion.device)
291 | k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
292 |
293 | angle_axis: torch.Tensor = torch.zeros_like(quaternion).to(quaternion.device)[..., :3]
294 | angle_axis[..., 0] += q1 * k
295 | angle_axis[..., 1] += q2 * k
296 | angle_axis[..., 2] += q3 * k
297 | return angle_axis
298 |
299 | #### batch converter
300 | def batch_euler2axis(r):
301 | return quaternion_to_angle_axis(euler_to_quaternion(r))
302 |
303 | def batch_euler2matrix(r):
304 | return quaternion_to_rotation_matrix(euler_to_quaternion(r))
305 |
306 | def batch_matrix2euler(rot_mats):
307 | # Calculates rotation matrix to euler angles
308 | # Be careful with extreme cases of Euler angles like [0.0, pi, 0.0]
309 | ### only y?
310 | # TODO:
311 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
312 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
313 | return torch.atan2(-rot_mats[:, 2, 0], sy)
314 |
315 | def batch_matrix2axis(rot_mats):
316 | return quaternion_to_angle_axis(rotation_matrix_to_quaternion(rot_mats))
317 |
318 | def batch_axis2matrix(theta):
319 | # angle axis to rotation matrix
320 | # theta N x 3
321 | # return quat2mat(quat)
322 | # batch_rodrigues
323 | return quaternion_to_rotation_matrix(angle_axis_to_quaternion(theta))
324 |
325 | def batch_axis2euler(theta):
326 | return batch_matrix2euler(batch_axis2matrix(theta))
327 |
328 | # def batch_axis2euler(r):
329 | #     return rot_mat_to_euler(batch_rodrigues(r))  # NOTE: disabled; it would shadow the definition above and rot_mat_to_euler is not defined in this module
330 |
331 |
332 | def batch_orth_proj(X, camera):
333 | '''
334 | X is N x num_points x 3
335 | '''
336 | camera = camera.clone().view(-1, 1, 3)
337 | X_trans = X[:, :, :2] + camera[:, :, 1:]
338 | X_trans = torch.cat([X_trans, X[:,:,2:]], 2)
339 | Xn = (camera[:, :, 0:1] * X_trans)
340 | return Xn
341 |
342 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
343 | ''' same as batch_matrix2axis
344 | Calculates the rotation matrices for a batch of rotation vectors
345 | Parameters
346 | ----------
347 | rot_vecs: torch.tensor Nx3
348 | array of N axis-angle vectors
349 | Returns
350 | -------
351 | R: torch.tensor Nx3x3
352 | The rotation matrices for the given axis-angle parameters
353 | '''
354 |
355 | batch_size = rot_vecs.shape[0]
356 | device = rot_vecs.device
357 |
358 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
359 | rot_dir = rot_vecs / angle
360 |
361 | cos = torch.unsqueeze(torch.cos(angle), dim=1)
362 | sin = torch.unsqueeze(torch.sin(angle), dim=1)
363 |
364 | # Bx1 arrays
365 | rx, ry, rz = torch.split(rot_dir, 1, dim=1)
366 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
367 |
368 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
369 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
370 | .view((batch_size, 3, 3))
371 |
372 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
373 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
374 | return rot_mat
375 |
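A small round-trip sketch showing how these converters compose (illustrative usage only, assuming the functions above are importable):

```python
import torch

# axis-angle -> rotation matrix -> quaternion -> axis-angle; values should agree up to numerical precision
aa = torch.tensor([[0.3, -0.2, 0.1]])   # [N, 3] axis-angle vectors
R = batch_rodrigues(aa)                 # [N, 3, 3] rotation matrices
aa_back = batch_matrix2axis(R)          # [N, 3], via rotation_matrix_to_quaternion + quaternion_to_angle_axis
print(aa, aa_back)
```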
--------------------------------------------------------------------------------
/decalib/utils/tensor_cropper.py:
--------------------------------------------------------------------------------
1 | '''
2 | crop
3 | for torch tensor
4 | Given image, bbox(center, bboxsize)
5 | return: cropped image, tform(used for transform the keypoint accordingly)
6 | only support crop to squared images
7 | '''
8 | import torch
9 | from kornia.geometry.transform.imgwarp import (
10 | warp_perspective, get_perspective_transform, warp_affine
11 | )
12 |
13 | def points2bbox(points, points_scale=None):
14 | if points_scale:
15 | assert points_scale[0]==points_scale[1]
16 | points = points.clone()
17 | points[:,:,:2] = (points[:,:,:2]*0.5 + 0.5)*points_scale[0]
18 | min_coords, _ = torch.min(points, dim=1)
19 | xmin, ymin = min_coords[:, 0], min_coords[:, 1]
20 | max_coords, _ = torch.max(points, dim=1)
21 | xmax, ymax = max_coords[:, 0], max_coords[:, 1]
22 | center = torch.stack([xmax + xmin, ymax + ymin], dim=-1) * 0.5
23 |
24 | width = (xmax - xmin)
25 | height = (ymax - ymin)
26 | # Convert the bounding box to a square box
27 | size = torch.max(width, height).unsqueeze(-1)
28 | return center, size
29 |
30 | def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.):
31 | batch_size = center.shape[0]
32 | trans_scale = (torch.rand([batch_size, 2], device=center.device)*2. -1.) * trans_scale
33 | center = center + trans_scale*bbox_size # 0.5
34 | scale = torch.rand([batch_size,1], device=center.device) * (scale[1] - scale[0]) + scale[0]
35 | size = bbox_size*scale
36 | return center, size
37 |
38 | def crop_tensor(image, center, bbox_size, crop_size, interpolation = 'bilinear', align_corners=False):
39 | ''' for batch image
40 | Args:
41 | image (torch.Tensor): the reference tensor of shape (B, C, H, W).
42 | center: [bz, 2]
43 | bboxsize: [bz, 1]
44 | crop_size: output size in pixels (the crop is square).
45 | interpolation (str): Interpolation flag. Default: 'bilinear'.
46 | align_corners (bool): mode for grid_generation. Default: False. See
47 | https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details
48 | Returns:
49 | cropped_image
50 | tform
51 | '''
52 | dtype = image.dtype
53 | device = image.device
54 | batch_size = image.shape[0]
55 | # points: top-left, top-right, bottom-right, bottom-left
56 | src_pts = torch.zeros([4,2], dtype=dtype, device=device).unsqueeze(0).expand(batch_size, -1, -1).contiguous()
57 |
58 | src_pts[:, 0, :] = center - bbox_size*0.5 # / (self.crop_size - 1)
59 | src_pts[:, 1, 0] = center[:, 0] + bbox_size[:, 0] * 0.5
60 | src_pts[:, 1, 1] = center[:, 1] - bbox_size[:, 0] * 0.5
61 | src_pts[:, 2, :] = center + bbox_size * 0.5
62 | src_pts[:, 3, 0] = center[:, 0] - bbox_size[:, 0] * 0.5
63 | src_pts[:, 3, 1] = center[:, 1] + bbox_size[:, 0] * 0.5
64 |
65 | DST_PTS = torch.tensor([[
66 | [0, 0],
67 | [crop_size - 1, 0],
68 | [crop_size - 1, crop_size - 1],
69 | [0, crop_size - 1],
70 | ]], dtype=dtype, device=device).expand(batch_size, -1, -1)
71 | # estimate transformation between points
72 | dst_trans_src = get_perspective_transform(src_pts, DST_PTS)
73 | # simulate broadcasting
74 | # dst_trans_src = dst_trans_src.expand(batch_size, -1, -1)
75 |
76 | # warp images
77 | cropped_image = warp_affine(
78 | image, dst_trans_src[:, :2, :], (crop_size, crop_size),
79 | flags=interpolation, align_corners=align_corners)
80 |
81 | tform = torch.transpose(dst_trans_src, 2, 1)
82 | # tform = torch.inverse(dst_trans_src)
83 | return cropped_image, tform
84 |
85 | class Cropper(object):
86 | def __init__(self, crop_size, scale=[1,1], trans_scale = 0.):
87 | self.crop_size = crop_size
88 | self.scale = scale
89 | self.trans_scale = trans_scale
90 |
91 | def crop(self, image, points, points_scale=None):
92 | # points to bbox
93 | center, bbox_size = points2bbox(points.clone(), points_scale)
94 | # augment bbox. TODO: add rotation?
95 | center, bbox_size = augment_bbox(center, bbox_size, scale=self.scale, trans_scale=self.trans_scale)
96 | # crop
97 | cropped_image, tform = crop_tensor(image, center, bbox_size, self.crop_size)
98 | return cropped_image, tform
99 |
100 | def transform_points(self, points, tform, points_scale=None, normalize = True):
101 | points_2d = points[:,:,:2]
102 |
103 | #'input points must use original range'
104 | if points_scale:
105 | assert points_scale[0]==points_scale[1]
106 | points_2d = (points_2d*0.5 + 0.5)*points_scale[0]
107 |
108 | batch_size, n_points, _ = points.shape
109 | trans_points_2d = torch.bmm(
110 | torch.cat([points_2d, torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype)], dim=-1),
111 | tform
112 | )
113 | trans_points = torch.cat([trans_points_2d[:,:,:2], points[:,:,2:]], dim=-1)
114 | if normalize:
115 | trans_points[:,:,:2] = trans_points[:,:,:2]/self.crop_size*2 - 1
116 | return trans_points
117 |
118 | def transform_points(points, tform, points_scale=None, out_scale=None):
119 | points_2d = points[:,:,:2]
120 |
121 | #'input points must use original range'
122 | if points_scale:
123 | assert points_scale[0]==points_scale[1]
124 | points_2d = (points_2d*0.5 + 0.5)*points_scale[0]
125 | # import ipdb; ipdb.set_trace()
126 |
127 | batch_size, n_points, _ = points.shape
128 | trans_points_2d = torch.bmm(
129 | torch.cat([points_2d, torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype)], dim=-1),
130 | tform
131 | )
132 | if out_scale: # h,w of output image size
133 | trans_points_2d[:,:,0] = trans_points_2d[:,:,0]/out_scale[1]*2 - 1
134 | trans_points_2d[:,:,1] = trans_points_2d[:,:,1]/out_scale[0]*2 - 1
135 | trans_points = torch.cat([trans_points_2d[:,:,:2], points[:,:,2:]], dim=-1)
136 | return trans_points
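A minimal usage sketch of Cropper (illustrative; shapes follow the docstrings above, with landmarks normalized to [-1, 1]):

```python
import torch

cropper = Cropper(crop_size=224, scale=[1.4, 1.8], trans_scale=0.)
images = torch.rand(2, 3, 512, 512)               # [bz, C, H, W]
landmarks = torch.rand(2, 68, 3) * 2 - 1          # normalized keypoints in [-1, 1]
cropped, tform = cropper.crop(images, landmarks, points_scale=(512, 512))
cropped_lmk = cropper.transform_points(landmarks, tform, points_scale=(512, 512))
print(cropped.shape, cropped_lmk.shape)           # [2, 3, 224, 224], [2, 68, 3]
```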
--------------------------------------------------------------------------------
/inference_rigface.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import os
4 |
5 | import numpy as np
6 | import torch
7 | import torch.utils.checkpoint
8 | import torchvision.transforms as transforms
9 | from PIL import Image
10 | from diffusers import AutoencoderKL
11 | from diffusers import (
12 | UniPCMultistepScheduler,
13 | )
14 | from transformers import CLIPTextModel, CLIPTokenizer
15 |
16 | from rigface.models.pipelineRigFace import RigFacePipeline as RigFacePipelineInference
17 | from rigface.models.unet_ID_2d_condition import UNetID2DConditionModel
18 | from rigface.models.unet_denoising_2d_condition import UNetDenoise2DConditionModel
19 |
20 |
21 |
22 | def parse_args(input_args=None):
23 | parser = argparse.ArgumentParser(description="Inference script.")
24 |
25 | parser.add_argument(
26 | "--pretrained_model_name_or_path",
27 | type=str,
28 | default='stable-diffusion-v1-5/stable-diffusion-v1-5',
29 | required=False,
30 | help="Path to pretrained model or model identifier from huggingface.co/models.",
31 | )
32 |
33 | parser.add_argument(
34 | "--revision",
35 | type=str,
36 | default=None,
37 | required=False,
38 | help="Revision of pretrained model identifier from huggingface.co/models.",
39 | )
40 | parser.add_argument(
41 | "--variant",
42 | type=str,
43 | default=None,
44 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
45 | )
46 |
47 | parser.add_argument("--seed", type=int, default=424, help="A seed for reproducible training.")
48 |
49 |
50 | parser.add_argument(
51 | "--inference_steps",
52 | type=int,
53 | default=50,
54 | )
55 |
56 |
57 | parser.add_argument(
58 | "--vit_path",
59 | type=str,
60 | default="openai/clip-vit-large-patch14",
61 | )
62 |
63 | parser.add_argument(
64 | "--vton_unet_path",
65 | type=str,
66 | default='./pre_trained/unet_denoise/checkpoint-70000',
67 | )
68 |
69 | parser.add_argument(
70 | "--garm_unet_path",
71 | type=str,
72 | default='./pre_trained/unet_id/checkpoint-70000',
73 | )
74 | parser.add_argument(
75 | "--id_path",
76 | type=str,
77 | default='',
78 | )
79 | parser.add_argument(
80 | "--bg_path",
81 | type=str,
82 | default='',
83 | )
84 | parser.add_argument(
85 | "--exp_path",
86 | type=str,
87 | default='',
88 | )
89 | parser.add_argument(
90 | "--render_path",
91 | type=str,
92 | default='',
93 | )
94 | parser.add_argument(
95 | "--save_path",
96 | type=str,
97 | default='',
98 | )
99 |
100 | if input_args is not None:
101 | args = parser.parse_args(input_args)
102 | else:
103 | args = parser.parse_args()
104 |
105 | return args
106 |
107 |
108 |
109 | def make_data(args):
110 |
111 | transform = transforms.ToTensor()
112 |
113 | img_name = args.id_path
114 | bg_name = args.bg_path
115 | render_name = args.render_path
116 |
117 | source = Image.open(img_name)
118 | source = transform(source)
119 |
120 | bg = Image.open(bg_name)
121 | bg = transform(bg)
122 |
123 | render = Image.open(render_name)
124 | # render = render.resize((512, 512))
125 | render = transform(render)
126 |
127 | return source, bg, render
128 |
129 |
130 | def tokenize_captions(tokenizer, captions, max_length):
131 | # NOTE: max_length is currently unused; tokenizer.model_max_length is applied below.
132 | inputs = tokenizer(
133 | captions,
134 | max_length=tokenizer.model_max_length,
135 | padding="max_length",
136 | truncation=True,
137 | return_tensors="pt"
138 | )
139 | return inputs.input_ids
140 |
141 |
142 | def main(args):
143 |
144 | device = 'cuda'
145 | vton_unet_path = args.vton_unet_path
146 | garm_unet_path = args.garm_unet_path
147 |
148 | vae = AutoencoderKL.from_pretrained(
149 | args.pretrained_model_name_or_path,
150 | subfolder="vae"
151 | ).to(device)
152 | text_encoder = CLIPTextModel.from_pretrained(
153 | args.pretrained_model_name_or_path,
154 | subfolder="text_encoder",
155 | ).to(device)
156 |
157 | tokenizer = CLIPTokenizer.from_pretrained(
158 | args.pretrained_model_name_or_path,
159 | subfolder="tokenizer",
160 | )
161 |
162 | unet_id = UNetID2DConditionModel.from_pretrained(
163 | garm_unet_path,
164 | # torch_dtype=torch.float16,
165 | use_safetensors=True,
166 | low_cpu_mem_usage=False,
167 | ignore_mismatched_sizes=True
168 | )
169 |
170 | unet_denoising = UNetDenoise2DConditionModel.from_pretrained(
171 | vton_unet_path,
172 | # torch_dtype=torch.float16,
173 | use_safetensors=True,
174 | low_cpu_mem_usage=False,
175 | ignore_mismatched_sizes=True
176 | )
177 |
178 | unet_denoising.requires_grad_(False)
179 | unet_id.requires_grad_(False)
180 | vae.requires_grad_(False)
181 | text_encoder.requires_grad_(False)
182 |
183 | weight_dtype = torch.float32
184 |
185 |
186 | pipeline = RigFacePipelineInference.from_pretrained(
187 | args.pretrained_model_name_or_path,
188 | vae=vae,
189 | text_encoder=text_encoder,
190 | tokenizer=tokenizer,
191 | unet_id=unet_id,
192 | unet_denoising=unet_denoising,
193 | safety_checker=None,
194 | revision=args.revision,
195 | variant=args.variant,
196 | torch_dtype=weight_dtype,
197 | ).to(device)
198 |
199 |
200 | pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
201 | pipeline.set_progress_bar_config(disable=True)
202 |
203 | if args.seed is None:
204 | generator = None
205 | else:
206 | generator = torch.Generator(device=device).manual_seed(args.seed)
207 |
208 | source, bg, rend = make_data(args)
209 | prompt = 'A close up of a person.'
210 | source = source.unsqueeze(0)
211 | bg = bg.unsqueeze(0)
212 | rend = rend.unsqueeze(0)
213 |
214 | prompt_embeds = text_encoder(tokenize_captions(tokenizer, [prompt], 2).to(device))[0]
215 |
216 |
217 | exp = np.load(args.exp_path)
218 |
219 | os.makedirs(args.save_path, exist_ok=True)
220 | tor_exp = torch.from_numpy(exp).unsqueeze(0)
221 |
222 | samples = pipeline(
223 | prompt_embeds=prompt_embeds,
224 | source=source,
225 | bg=bg,
226 | render=rend,
227 | exp=tor_exp,
228 | num_inference_steps=args.inference_steps,
229 | generator=generator,
230 | ).images[0]
231 | samples.save(os.path.join(args.save_path, 'out.png'))
232 |
233 | if __name__ == "__main__":
234 | args = parse_args()
235 |
236 | main(args)
237 |
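For reference, a usage sketch of this script against the bundled example images (it assumes the pretrained UNet checkpoints have been downloaded to the default `./pre_trained/...` paths):

```python
# Run from the repo root; the id1 example images ship under utils/test_images.
from inference_rigface import parse_args, main

args = parse_args([
    "--id_path", "utils/test_images/id1/sor.png",
    "--bg_path", "utils/test_images/id1/bg_pose+exp.png",
    "--exp_path", "utils/test_images/id1/exp_pose+exp.npy",
    "--render_path", "utils/test_images/id1/render_exp+pose.png",
    "--save_path", "results",
])
main(args)   # writes results/out.png
```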
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.3.0
2 | bitsandbytes==0.45.1
3 | diffusers==0.22.0
4 | einops==0.8.0
5 | face_alignment==1.4.1
6 | facenet_pytorch==2.6.0
7 | huggingface_hub==0.28.1
8 | insightface==0.7.3
9 | ipdb==0.13.13
10 | kornia==0.8.0
11 | loguru==0.7.3
12 | numpy==1.23.0
13 | omegaconf==2.3.0
14 | opencv_python==4.11.0.86
15 | opencv_python_headless==4.11.0.86
16 | packaging==24.2
17 | pandas==2.2.3
18 | Pillow==11.1.0
19 | preprocess==2.0.0
20 | PyYAML==6.0.2
22 | safetensors==0.5.2
23 | scipy==1.15.1
24 | setuptools==75.8.0
25 | scikit-image  # the code's `skimage` imports come from scikit-image; the bare PyPI `skimage` package is only a placeholder
26 | torchfile==0.1.0
27 | tqdm==4.67.1
28 | transformers==4.42.4
29 | wandb==0.19.6
30 | xformers==0.0.29.post2
31 | yacs==0.1.8
32 |
--------------------------------------------------------------------------------
/utils/compute_renders.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 |
5 | import argparse
6 | import torch as th
7 | from torchvision.utils import save_image
8 |
9 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
10 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".")))
11 |
12 |
13 | from decalib.deca import DECA
14 | from decalib.utils.config import cfg as deca_cfg
15 | from data_utils import get_image_dict
16 |
17 | # Build DECA
18 | deca_cfg.model.use_tex = True
19 | deca_cfg.model.tex_path = "./data/FLAME_texture.npz"
20 | deca_cfg.model.tex_type = "FLAME"
21 | deca = DECA(config=deca_cfg, device="cuda")
22 |
23 |
24 |
25 | def get_render(source, target, modes):
26 | src_dict = get_image_dict(source, 512, True)
27 | tar_dict = get_image_dict(target, 512, True)
28 | # ===================get DECA codes of the target image===============================
29 | tar_cropped = tar_dict["image"].unsqueeze(0).to("cuda")
30 | imgname = tar_dict["imagename"]
31 | with th.no_grad():
32 | tar_code = deca.encode(tar_cropped)
33 | tar_image = tar_dict["original_image"].unsqueeze(0).to("cuda")
34 | # ===================get DECA codes of the source image===============================
35 | src_cropped = src_dict["image"].unsqueeze(0).to("cuda")
36 | with th.no_grad():
37 | src_code = deca.encode(src_cropped)
38 | # To align the face when the pose is changing
39 | src_ffhq_center = deca.decode(src_code, return_ffhq_center=True)
40 | tar_ffhq_center = deca.decode(tar_code, return_ffhq_center=True)
41 |
42 | src_tform = src_dict["tform"].unsqueeze(0)
43 | src_tform = th.inverse(src_tform).transpose(1, 2).to("cuda")
44 | src_code["tform"] = src_tform
45 |
46 | tar_tform = tar_dict["tform"].unsqueeze(0)
47 | tar_tform = th.inverse(tar_tform).transpose(1, 2).to("cuda")
48 | tar_code["tform"] = tar_tform
49 |
50 | src_image = src_dict["original_image"].unsqueeze(0).to("cuda")  # averaged parameters
51 | tar_image = tar_dict["original_image"].unsqueeze(0).to("cuda")
52 |
53 | # code 1 means source code, code 2 means target code
54 | code1, code2 = {}, {}
55 | for k in src_code:
56 | code1[k] = src_code[k].clone()
57 |
58 | for k in tar_code:
59 | code2[k] = tar_code[k].clone()
60 |
61 | mode_list = modes.split("+")
62 | # Presumably: once pose is among the modes, the decoding target switches to the target image
63 | if 'pose' in mode_list:
64 | if 'exp' not in mode_list:
65 | code2['exp'] = src_code['exp']
66 | code2['pose'][:, 3:] = src_code['pose'][:, 3:]
67 | if 'light' not in mode_list:
68 | code2['light'] = src_code['light']
69 | opdict, _ = deca.decode(
70 | code2,
71 | render_orig=True,
72 | original_image=tar_image,
73 | tform=tar_code["tform"],
74 | align_ffhq=True,
75 | ffhq_center=tar_ffhq_center,
76 | )
77 | else:
78 | if 'exp' in mode_list:
79 | code1['exp'] = tar_code['exp']
80 | code1['pose'][:, 3:] = tar_code['pose'][:, 3:]
81 | if 'light' not in mode_list:
82 | code1['light'] = tar_code['light']
83 | opdict, _ = deca.decode(
84 | code1,
85 | render_orig=True,
86 | original_image=src_image,
87 | tform=src_code["tform"],
88 | align_ffhq=True,
89 | ffhq_center=src_ffhq_center,
90 | )
91 |
92 | rendered = opdict["rendered_images"].detach()
93 | os.makedirs('results', exist_ok=True)
94 | save_image(rendered[0], f"./results/render_{modes}.png")
95 |
96 |
97 |
98 | if __name__ == '__main__':
99 | parser = argparse.ArgumentParser()
100 | parser.add_argument(
101 | "--sor_path",
102 | type=str,
103 | default='',
104 | required=False
105 | )
106 | parser.add_argument(
107 | "--tar_path",
108 | type=str,
109 | default='',
110 | required=False
111 | )
112 | parser.add_argument(
113 | "--modes",
114 | type=str,
115 | default='',
116 | required=False
117 | )
118 |
119 | args = parser.parse_args()
120 | get_render(args.sor_path, args.tar_path, args.modes)
121 | print('done')
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import os
5 |
6 | import numpy as np
7 | import scipy
8 | import scipy.io
9 | import torch
10 | from skimage.io import imread
11 | from skimage.transform import estimate_transform, warp, resize
12 |
13 | from decalib.datasets import detectors
14 |
15 | face_detector = detectors.FAN()
16 | scale = 1.3
17 | resolution_inp = 224
18 |
19 |
20 |
21 | def bbox2point(left, right, top, bottom, type='bbox'):
22 |
23 | if type =='kpt68':
24 | old_size = (right - left + bottom - top) / 2 * 1.1
25 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])
26 | elif type =='bbox':
27 | old_size = (right - left + bottom - top ) /2
28 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size *0.12])
29 | else:
30 | raise NotImplementedError
31 | return old_size, center
32 |
33 |
34 |
35 |
36 |
37 | def get_image_dict(img_path, size, iscrop):
38 | img_name = img_path.split('/')[-1]
39 | im = imread(img_path)
40 | if size is not None:  # resize to a size x size square
41 | im = (resize(im, (size, size), anti_aliasing=True) * 255.).astype(np.uint8)
42 | # (256, 256, 3)
43 | image = np.array(im)
44 | if len(image.shape) == 2:
45 | image = np.repeat(image[:, :, None], 3, axis=2)  # grayscale -> 3 channels
46 | if len(image.shape) == 3 and image.shape[2] > 3:
47 | image = image[:, :, :3]
48 |
49 | h, w, _ = image.shape
50 | if iscrop: # true
51 | # provide kpt as txt file, or mat file (for AFLW2000)
52 | kpt_matpath = os.path.splitext(img_path)[0] + '.mat'
53 | kpt_txtpath = os.path.splitext(img_path)[0] + '.txt'
54 | if os.path.exists(kpt_matpath):
55 | kpt = scipy.io.loadmat(kpt_matpath)['pt3d_68'].T
56 | left = np.min(kpt[:, 0])
57 | right = np.max(kpt[:, 0])
58 | top = np.min(kpt[:, 1])
59 | bottom = np.max(kpt[:, 1])
60 | old_size, center = bbox2point(left, right, top, bottom, type='kpt68')
61 | elif os.path.exists(kpt_txtpath):
62 | kpt = np.loadtxt(kpt_txtpath)
63 | left = np.min(kpt[:, 0])
64 | right = np.max(kpt[:, 0])
65 | top = np.min(kpt[:, 1])
66 | bottom = np.max(kpt[:, 1])
67 | old_size, center = bbox2point(left, right, top, bottom, type='kpt68')
68 | else:
69 | bbox, bbox_type = face_detector.run(image)
70 | if len(bbox) < 4:
71 | print('no face detected! run original image')
72 | left = 0
73 | right = h - 1
74 | top = 0
75 | bottom = w - 1
76 | else:
77 | left = bbox[0]
78 | right = bbox[2]
79 | top = bbox[1]
80 | bottom = bbox[3]
81 | old_size, center = bbox2point(left, right, top, bottom, type=bbox_type)
82 | size = int(old_size * scale)
83 | src_pts = np.array([[center[0] - size / 2, center[1] - size / 2], [center[0] - size / 2, center[1] + size / 2],
84 | [center[0] + size / 2, center[1] - size / 2]])
85 | else:
86 | src_pts = np.array([[0, 0], [0, h - 1], [w - 1, 0]])
87 | # DST_PTS = np.array([[0, 0], [0, h-1], [w-1, 0]])
88 | DST_PTS = np.array([[0, 0], [0, resolution_inp - 1], [resolution_inp - 1, 0]])
89 | tform = estimate_transform('similarity', src_pts, DST_PTS)
90 |
91 | image = image / 255.
92 |
93 | dst_image = warp(image, tform.inverse, output_shape=(resolution_inp, resolution_inp))
94 | dst_image = dst_image.transpose(2, 0, 1)
95 | return {'image': torch.tensor(dst_image).float(),
96 | 'imagename': img_name,
97 | 'tform': torch.tensor(tform.params).float(),
98 | 'original_image': torch.tensor(image.transpose(2, 0, 1)).float(),
99 | }
100 |
--------------------------------------------------------------------------------
/utils/datasets_faceswap.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import cv2
4 | from PIL import ImageFile
5 |
6 |
7 | ImageFile.LOAD_TRUNCATED_IMAGES = True
8 |
9 | mean_face_lm5p_256 = np.array([
10 | [(30.2946+8)*2+16, 51.6963*2], # left eye pupil
11 | [(65.5318+8)*2+16, 51.5014*2], # right eye pupil
12 | [(48.0252+8)*2+16, 71.7366*2], # nose tip
13 | [(33.5493+8)*2+16, 92.3655*2], # left mouth corner
14 | [(62.7299+8)*2+16, 92.2041*2], # right mouth corner
15 | ], dtype=np.float32)
16 |
17 |
18 |
19 | mean_box_lm4p_512 = np.array([
20 | [80, 80],
21 | [80, 432],
22 | [432, 432],
23 | [432, 80],
24 | ], dtype=np.float32)
25 |
26 |
27 |
28 | def get_box_lm4p(pts):
29 | x1 = np.min(pts[:,0])
30 | x2 = np.max(pts[:,0])
31 | y1 = np.min(pts[:,1])
32 | y2 = np.max(pts[:,1])
33 |
34 | x_center = (x1+x2)*0.5
35 | y_center = (y1+y2)*0.5
36 | box_size = max(x2-x1, y2-y1)
37 |
38 | x1 = x_center-0.5*box_size
39 | x2 = x_center+0.5*box_size
40 | y1 = y_center-0.5*box_size
41 | y2 = y_center+0.5*box_size
42 |
43 | return np.array([[x1, y1], [x1, y2], [x2, y2], [x2, y1]], dtype=np.float32)
44 |
45 |
46 | def get_affine_transform(target_face_lm5p, mean_lm5p):
47 | mat_warp = np.zeros((2,3))
48 | A = np.zeros((4,4))
49 | B = np.zeros((4))
50 | for i in range(5):
51 | #sa[0][0] += a[i].x*a[i].x + a[i].y*a[i].y;
52 | A[0][0] += target_face_lm5p[i][0] * target_face_lm5p[i][0] + target_face_lm5p[i][1] * target_face_lm5p[i][1]
53 | #sa[0][2] += a[i].x;
54 | A[0][2] += target_face_lm5p[i][0]
55 | #sa[0][3] += a[i].y;
56 | A[0][3] += target_face_lm5p[i][1]
57 |
58 | #sb[0] += a[i].x*b[i].x + a[i].y*b[i].y;
59 | B[0] += target_face_lm5p[i][0] * mean_lm5p[i][0] + target_face_lm5p[i][1] * mean_lm5p[i][1]
60 | #sb[1] += a[i].x*b[i].y - a[i].y*b[i].x;
61 | B[1] += target_face_lm5p[i][0] * mean_lm5p[i][1] - target_face_lm5p[i][1] * mean_lm5p[i][0]
62 | #sb[2] += b[i].x;
63 | B[2] += mean_lm5p[i][0]
64 | #sb[3] += b[i].y;
65 | B[3] += mean_lm5p[i][1]
66 |
67 | #sa[1][1] = sa[0][0];
68 | A[1][1] = A[0][0]
69 | #sa[2][1] = sa[1][2] = -sa[0][3];
70 | A[2][1] = A[1][2] = -A[0][3]
71 | #sa[3][1] = sa[1][3] = sa[2][0] = sa[0][2];
72 | A[3][1] = A[1][3] = A[2][0] = A[0][2]
73 | #sa[2][2] = sa[3][3] = count;
74 | A[2][2] = A[3][3] = 5
75 | #sa[3][0] = sa[0][3];
76 | A[3][0] = A[0][3]
77 |
78 | _, mat23 = cv2.solve(A, B, flags=cv2.DECOMP_SVD)
79 | mat_warp[0][0] = mat23[0]
80 | mat_warp[1][1] = mat23[0]
81 | mat_warp[0][1] = -mat23[1]
82 | mat_warp[1][0] = mat23[1]
83 | mat_warp[0][2] = mat23[2]
84 | mat_warp[1][2] = mat23[3]
85 |
86 | return mat_warp
87 |
88 |
89 |
90 |
91 | def transformation_from_points(points1, points2):
92 | points1 = np.float64(np.matrix([[point[0], point[1]] for point in points1]))
93 | points2 = np.float64(np.matrix([[point[0], point[1]] for point in points2]))
94 |
95 | points1 = points1.astype(np.float64)
96 | points2 = points2.astype(np.float64)
97 | c1 = np.mean(points1, axis=0)
98 | c2 = np.mean(points2, axis=0)
99 | points1 -= c1
100 | points2 -= c2
101 | s1 = np.std(points1)
102 | s2 = np.std(points2)
103 | points1 /= s1
104 | points2 /= s2
105 | #points2 = np.array(points2)
106 | #write_pts('pt2.txt', points2)
107 | U, S, Vt = np.linalg.svd(points1.T * points2)
108 | R = (U * Vt).T
109 | return np.array(np.vstack([np.hstack(((s2 / s1) * R,c2.T - (s2 / s1) * R * c1.T)),np.matrix([0., 0., 1.])])[:2])
110 |
111 |
--------------------------------------------------------------------------------
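
save_exp_coeffs.py below feeds five detected keypoints through get_affine_transform to align a face to the 256x256 template defined here. A minimal sketch of that call pattern, assuming the module above is importable; the landmark values are placeholders standing in for a detector's output:

    import cv2
    import numpy as np

    # five (x, y) landmarks in source-image coordinates -- placeholder values only
    pts5 = np.array([[190., 242.], [330., 240.], [258., 330.], [212., 388.], [318., 385.]], dtype=np.float32)

    warp_mat = get_affine_transform(pts5, mean_face_lm5p_256)   # 2x3 similarity (rotation + uniform scale + translation)
    # aligned = cv2.warpAffine(image_np, warp_mat, (256, 256), flags=cv2.INTER_LINEAR)  # image_np: HxWx3 uint8
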
/utils/make_bgs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 | import torchvision.transforms as transforms
8 | from PIL import Image
9 | from insightface.app import FaceAnalysis
10 | from torchvision.utils import save_image
11 |
12 | from model import BiSeNet
13 |
14 | device = 'cuda'
15 | checkpoint = './checkpoints'
16 | app = FaceAnalysis(name='antelopev2', root=os.path.join('./', 'third_party_files'),
17 | providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
18 | app.prepare(ctx_id=0, det_size=(640, 640))
19 |
20 |
21 | n_classes = 19
22 | net = BiSeNet(n_classes=n_classes)
23 | net.cuda()
24 | model_pth = './third_party_files/79999_iter.pth'
25 | net.load_state_dict(torch.load(model_pth))
26 | net.eval()
27 |
28 |
29 |
30 |
31 | def keep_background(im, parsing_anno, stride):
32 | # every face-parsing class is mapped to black; only the unlabeled background stays white below
33 | part_colors = [[0, 0, 0], [0, 0, 0], [0, 0, 0],
34 | [0, 0, 0], [0, 0, 0],
35 | [0, 0, 0], [0, 0, 0], [0, 0, 0],
36 | [0, 0, 0], [0, 0, 0],
37 | [0, 0, 0], [0, 0, 0], [0, 0, 0],
38 | [0, 0, 0], [0, 0, 0],
39 | [0, 0, 0], [0, 0, 0], [0, 0, 0],
40 | [0, 0, 0], [0, 0, 0], [0, 0, 0],
41 | [0, 0, 0], [0, 0, 0], [0, 0, 0]]
42 |
43 | im = np.array(im)
44 | vis_im = im.copy().astype(np.uint8)
45 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8) # [1, 19, 512, 512]
46 |
47 | vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
48 | vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
49 |
50 | num_of_class = np.max(vis_parsing_anno)
51 |
52 |
53 | for pi in range(1, num_of_class + 1):
54 | # if pi == 8 or pi == 9 or pi == 14 or pi == 17 or pi == 18:
55 | # continue
56 | index = np.where(vis_parsing_anno == pi)
57 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
58 |
59 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
60 | tmp = vis_parsing_anno_color / 255
61 | mask_channel = 1 - tmp[:, :, 0]
62 | black = np.full_like(vis_im, 0, dtype=np.uint8)
63 | result_image = np.where(mask_channel[:, :, np.newaxis] == 1, 255, black)  # face region -> white, background -> black
64 | return result_image
65 |
66 |
67 | def deal_with_one_image(sorpth, tgtpth, modes):
68 |
69 | trans = transforms.ToTensor()
70 |
71 | to_tensor = transforms.Compose([
72 | transforms.Resize(512),
73 | transforms.ToTensor(),
74 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
75 | ])
76 |
77 | with torch.no_grad():
78 | tgtimg, sorimg = Image.open(tgtpth), Image.open(sorpth)
79 | tgtimage, sorimage = tgtimg.resize((512, 512), Image.BILINEAR), sorimg.resize((512, 512),
80 | Image.BILINEAR)
81 | # image = img
82 | tgtimg, sorimg = to_tensor(tgtimage), to_tensor(sorimage)
83 | tgtimg, sorimg = torch.unsqueeze(tgtimg, 0), torch.unsqueeze(sorimg, 0)
84 | tgtimg, sorimg = tgtimg.cuda(), sorimg.cuda()
85 | tgtout, sorout = net(tgtimg)[0], net(sorimg)[0] # [1, 19, 512, 512]
86 | tgtparsing, sorparsing = tgtout.squeeze(0).cpu().numpy().argmax(0), sorout.squeeze(
87 | 0).cpu().numpy().argmax(0)
88 |
89 | tgtbg, sorbg = keep_background(tgtimage, tgtparsing, stride=1), keep_background(sorimage, sorparsing,
90 | stride=1)
91 | tgtbg, sorbg = cv2.cvtColor(tgtbg, cv2.COLOR_RGB2BGR), cv2.cvtColor(sorbg, cv2.COLOR_RGB2BGR)
92 | mode_list = modes.split('+')
93 |
94 | if 'pose' in mode_list:
95 | logical_or = np.bitwise_or(tgtbg, sorbg)
96 | elif 'light' in mode_list or 'exp' in mode_list:
97 | logical_or = sorbg
98 | else:
99 | raise ValueError(f'Unknown mode: {modes}')
100 |
101 | tmp = logical_or / 255
102 | mask_channel = 1 - tmp[:, :, 0]
103 | tmp_sor = cv2.imread(sorpth)
104 | deep_gray = np.full_like(tmp_sor, 127, dtype=np.uint8)  # mid-gray fill for the masked-out face region
105 |
106 | im_cv = cv2.cvtColor(tmp_sor, cv2.COLOR_BGR2RGB)  # cv2.imread gives BGR; convert to RGB for save_image
107 | result_image = np.where(mask_channel[:, :, np.newaxis] == 1, im_cv, deep_gray)  # keep the background, gray out the face
108 | result_image = trans(result_image)
109 | # # res = bg + im_pts70
110 | os.makedirs('results', exist_ok=True)
111 | save_image(result_image, f"./results/bg_{modes}.png")
112 |
113 |
114 |
115 |
116 | if __name__ == '__main__':
117 | parser = argparse.ArgumentParser()
118 | parser.add_argument(
119 | "--sor_path",
120 | type=str,
121 | default='',
122 | required=False
123 | )
124 | parser.add_argument(
125 | "--tar_path",
126 | type=str,
127 | default='',
128 | required=False
129 | )
130 | parser.add_argument(
131 | "--modes",
132 | type=str,
133 | default='',
134 | required=False
135 | )
136 |
137 | args = parser.parse_args()
138 | deal_with_one_image(args.sor_path, args.tar_path, args.modes)
--------------------------------------------------------------------------------
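
A hedged usage sketch for the entry point above, run from the utils/ directory with the BiSeNet checkpoint in ./third_party_files; the paths point at the sample assets shipped under test_images/, and the output name follows the save_image call in deal_with_one_image. The equivalent CLI call is python make_bgs.py --sor_path ./test_images/id1/sor.png --tar_path ./test_images/id1/tar.png --modes pose+exp.

    deal_with_one_image('./test_images/id1/sor.png', './test_images/id1/tar.png', 'pose+exp')
    # -> writes ./results/bg_pose+exp.png (compare with the shipped utils/test_images/id1/bg_pose+exp.png)
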
/utils/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torchvision
9 |
10 | from resnet import Resnet18
11 | # from modules.bn import InPlaceABNSync as BatchNorm2d
12 |
13 |
14 | class ConvBNReLU(nn.Module):
15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16 | super(ConvBNReLU, self).__init__()
17 | self.conv = nn.Conv2d(in_chan,
18 | out_chan,
19 | kernel_size = ks,
20 | stride = stride,
21 | padding = padding,
22 | bias = False)
23 | self.bn = nn.BatchNorm2d(out_chan)
24 | self.init_weight()
25 |
26 | def forward(self, x):
27 | x = self.conv(x)
28 | x = F.relu(self.bn(x))
29 | return x
30 |
31 | def init_weight(self):
32 | for ly in self.children():
33 | if isinstance(ly, nn.Conv2d):
34 | nn.init.kaiming_normal_(ly.weight, a=1)
35 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36 |
37 | class BiSeNetOutput(nn.Module):
38 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39 | super(BiSeNetOutput, self).__init__()
40 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42 | self.init_weight()
43 |
44 | def forward(self, x):
45 | x = self.conv(x)
46 | x = self.conv_out(x)
47 | return x
48 |
49 | def init_weight(self):
50 | for ly in self.children():
51 | if isinstance(ly, nn.Conv2d):
52 | nn.init.kaiming_normal_(ly.weight, a=1)
53 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54 |
55 | def get_params(self):
56 | wd_params, nowd_params = [], []
57 | for name, module in self.named_modules():
58 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59 | wd_params.append(module.weight)
60 | if not module.bias is None:
61 | nowd_params.append(module.bias)
62 | elif isinstance(module, nn.BatchNorm2d):
63 | nowd_params += list(module.parameters())
64 | return wd_params, nowd_params
65 |
66 |
67 | class AttentionRefinementModule(nn.Module):
68 | def __init__(self, in_chan, out_chan, *args, **kwargs):
69 | super(AttentionRefinementModule, self).__init__()
70 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72 | self.bn_atten = nn.BatchNorm2d(out_chan)
73 | self.sigmoid_atten = nn.Sigmoid()
74 | self.init_weight()
75 |
76 | def forward(self, x):
77 | feat = self.conv(x)
78 | atten = F.avg_pool2d(feat, feat.size()[2:])
79 | atten = self.conv_atten(atten)
80 | atten = self.bn_atten(atten)
81 | atten = self.sigmoid_atten(atten)
82 | out = torch.mul(feat, atten)
83 | return out
84 |
85 | def init_weight(self):
86 | for ly in self.children():
87 | if isinstance(ly, nn.Conv2d):
88 | nn.init.kaiming_normal_(ly.weight, a=1)
89 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90 |
91 |
92 | class ContextPath(nn.Module):
93 | def __init__(self, *args, **kwargs):
94 | super(ContextPath, self).__init__()
95 | self.resnet = Resnet18()
96 | self.arm16 = AttentionRefinementModule(256, 128)
97 | self.arm32 = AttentionRefinementModule(512, 128)
98 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
101 |
102 | self.init_weight()
103 |
104 | def forward(self, x):
105 | H0, W0 = x.size()[2:]
106 | feat8, feat16, feat32 = self.resnet(x)
107 | H8, W8 = feat8.size()[2:]
108 | H16, W16 = feat16.size()[2:]
109 | H32, W32 = feat32.size()[2:]
110 |
111 | avg = F.avg_pool2d(feat32, feat32.size()[2:])
112 | avg = self.conv_avg(avg)
113 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
114 |
115 | feat32_arm = self.arm32(feat32)
116 | feat32_sum = feat32_arm + avg_up
117 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
118 | feat32_up = self.conv_head32(feat32_up)
119 |
120 | feat16_arm = self.arm16(feat16)
121 | feat16_sum = feat16_arm + feat32_up
122 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
123 | feat16_up = self.conv_head16(feat16_up)
124 |
125 | return feat8, feat16_up, feat32_up # x8, x8, x16
126 |
127 | def init_weight(self):
128 | for ly in self.children():
129 | if isinstance(ly, nn.Conv2d):
130 | nn.init.kaiming_normal_(ly.weight, a=1)
131 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
132 |
133 | def get_params(self):
134 | wd_params, nowd_params = [], []
135 | for name, module in self.named_modules():
136 | if isinstance(module, (nn.Linear, nn.Conv2d)):
137 | wd_params.append(module.weight)
138 | if not module.bias is None:
139 | nowd_params.append(module.bias)
140 | elif isinstance(module, nn.BatchNorm2d):
141 | nowd_params += list(module.parameters())
142 | return wd_params, nowd_params
143 |
144 |
145 | ### This is not used, since I replace this with the resnet feature with the same size
146 | class SpatialPath(nn.Module):
147 | def __init__(self, *args, **kwargs):
148 | super(SpatialPath, self).__init__()
149 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
150 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
153 | self.init_weight()
154 |
155 | def forward(self, x):
156 | feat = self.conv1(x)
157 | feat = self.conv2(feat)
158 | feat = self.conv3(feat)
159 | feat = self.conv_out(feat)
160 | return feat
161 |
162 | def init_weight(self):
163 | for ly in self.children():
164 | if isinstance(ly, nn.Conv2d):
165 | nn.init.kaiming_normal_(ly.weight, a=1)
166 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
167 |
168 | def get_params(self):
169 | wd_params, nowd_params = [], []
170 | for name, module in self.named_modules():
171 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
172 | wd_params.append(module.weight)
173 | if not module.bias is None:
174 | nowd_params.append(module.bias)
175 | elif isinstance(module, nn.BatchNorm2d):
176 | nowd_params += list(module.parameters())
177 | return wd_params, nowd_params
178 |
179 |
180 | class FeatureFusionModule(nn.Module):
181 | def __init__(self, in_chan, out_chan, *args, **kwargs):
182 | super(FeatureFusionModule, self).__init__()
183 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
184 | self.conv1 = nn.Conv2d(out_chan,
185 | out_chan//4,
186 | kernel_size = 1,
187 | stride = 1,
188 | padding = 0,
189 | bias = False)
190 | self.conv2 = nn.Conv2d(out_chan//4,
191 | out_chan,
192 | kernel_size = 1,
193 | stride = 1,
194 | padding = 0,
195 | bias = False)
196 | self.relu = nn.ReLU(inplace=True)
197 | self.sigmoid = nn.Sigmoid()
198 | self.init_weight()
199 |
200 | def forward(self, fsp, fcp):
201 | fcat = torch.cat([fsp, fcp], dim=1)
202 | feat = self.convblk(fcat)
203 | atten = F.avg_pool2d(feat, feat.size()[2:])
204 | atten = self.conv1(atten)
205 | atten = self.relu(atten)
206 | atten = self.conv2(atten)
207 | atten = self.sigmoid(atten)
208 | feat_atten = torch.mul(feat, atten)
209 | feat_out = feat_atten + feat
210 | return feat_out
211 |
212 | def init_weight(self):
213 | for ly in self.children():
214 | if isinstance(ly, nn.Conv2d):
215 | nn.init.kaiming_normal_(ly.weight, a=1)
216 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
217 |
218 | def get_params(self):
219 | wd_params, nowd_params = [], []
220 | for name, module in self.named_modules():
221 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
222 | wd_params.append(module.weight)
223 | if not module.bias is None:
224 | nowd_params.append(module.bias)
225 | elif isinstance(module, nn.BatchNorm2d):
226 | nowd_params += list(module.parameters())
227 | return wd_params, nowd_params
228 |
229 |
230 | class BiSeNet(nn.Module):
231 | def __init__(self, n_classes, *args, **kwargs):
232 | super(BiSeNet, self).__init__()
233 | self.cp = ContextPath()
234 | ## here self.sp is deleted
235 | self.ffm = FeatureFusionModule(256, 256)
236 | self.conv_out = BiSeNetOutput(256, 256, n_classes)
237 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
238 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
239 | self.init_weight()
240 |
241 | def forward(self, x):
242 | H, W = x.size()[2:]
243 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
244 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
245 | feat_fuse = self.ffm(feat_sp, feat_cp8)
246 |
247 | feat_out = self.conv_out(feat_fuse)
248 | feat_out16 = self.conv_out16(feat_cp8)
249 | feat_out32 = self.conv_out32(feat_cp16)
250 |
251 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
252 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
253 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
254 | return feat_out, feat_out16, feat_out32
255 |
256 | def init_weight(self):
257 | for ly in self.children():
258 | if isinstance(ly, nn.Conv2d):
259 | nn.init.kaiming_normal_(ly.weight, a=1)
260 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
261 |
262 | def get_params(self):
263 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
264 | for name, child in self.named_children():
265 | child_wd_params, child_nowd_params = child.get_params()
266 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
267 | lr_mul_wd_params += child_wd_params
268 | lr_mul_nowd_params += child_nowd_params
269 | else:
270 | wd_params += child_wd_params
271 | nowd_params += child_nowd_params
272 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
273 |
274 |
275 | if __name__ == "__main__":
276 | net = BiSeNet(19)
277 | net.cuda()
278 | net.eval()
279 | in_ten = torch.randn(16, 3, 640, 480).cuda()
280 | out, out16, out32 = net(in_ten)
281 | print(out.shape)
282 |
283 | net.get_params()
284 |
--------------------------------------------------------------------------------
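
make_bgs.py above drives this network as a face parser. A minimal sketch of that usage, assuming the same checkpoint location and a placeholder image path:

    import torch
    import torchvision.transforms as transforms
    from PIL import Image

    net = BiSeNet(n_classes=19).cuda().eval()
    net.load_state_dict(torch.load('./third_party_files/79999_iter.pth'))

    to_tensor = transforms.Compose([
        transforms.Resize(512),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    with torch.no_grad():
        img = Image.open('face.png').convert('RGB').resize((512, 512), Image.BILINEAR)  # placeholder path
        out = net(to_tensor(img).unsqueeze(0).cuda())[0]        # (1, 19, 512, 512) logits from the main head
        parsing = out.squeeze(0).argmax(0).cpu().numpy()        # (512, 512) per-pixel class ids
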
/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import cv2
4 | import numpy as np
5 | import torchvision.transforms as transforms
6 | from PIL import Image
7 | from insightface.app import FaceAnalysis
8 | from torchvision.utils import save_image
9 |
10 | import datasets_faceswap as datasets_faceswap
11 |
12 | # the 512x512 warped crop produced below only needs a plain ToTensor()
13 |
14 | pil2tensor = transforms.ToTensor()
15 |
16 | app = FaceAnalysis(name='antelopev2', root=os.path.join('./', 'third_party_files'),
17 | providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
18 | app.prepare(ctx_id=0, det_size=(640, 640))
19 |
20 |
21 | def get_bbox(dets, crop_ratio):
22 | if crop_ratio > 0:
23 | bbox = dets[0:4]
24 | bbox_size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])  # larger of box width and height
25 | bbox_x = 0.5 * (bbox[2] + bbox[0])
26 | bbox_y = 0.5 * (bbox[3] + bbox[1])
27 | x1 = bbox_x - bbox_size * crop_ratio
28 | x2 = bbox_x + bbox_size * crop_ratio
29 | y1 = bbox_y - bbox_size * crop_ratio
30 | y2 = bbox_y + bbox_size * crop_ratio
31 | bbox_pts4 = np.array([[x1, y1], [x1, y2], [x2, y2], [x2, y1]], dtype=np.float32)
32 | else:
33 | # original box
34 | bbox = dets[0:4].reshape((2, 2))
35 | bbox_pts4 = datasets_faceswap.get_box_lm4p(bbox)
36 | return bbox_pts4
37 |
38 |
39 |
40 | def crop_one_image(args):
41 | cur_img_sor_path = args.img_path
42 | im_pil_sor = Image.open(cur_img_sor_path).convert("RGB")
43 | face_info_sor = app.get(cv2.cvtColor(np.array(im_pil_sor), cv2.COLOR_RGB2BGR))
44 | assert len(face_info_sor) >= 1, 'The input image must contain a face!'
45 | if len(face_info_sor) > 1:
46 | print('The input image contains more than one face; only the largest face will be used')
47 | face_info_sor = \
48 | sorted(face_info_sor, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]  # sort by bbox area
49 | dets_sor= face_info_sor['bbox']
50 | bbox_pst_sor = get_bbox(dets_sor, crop_ratio=0.75)
51 |
52 | warp_mat_crop_sor = datasets_faceswap.transformation_from_points(bbox_pst_sor,
53 | datasets_faceswap.mean_box_lm4p_512)
54 | im_crop512_sor = cv2.warpAffine(np.array(im_pil_sor), warp_mat_crop_sor, (512, 512), flags=cv2.INTER_LINEAR)
55 |
56 | im_pil_sor = Image.fromarray(im_crop512_sor)
57 | im_pil_sor = pil2tensor(im_pil_sor)
58 | save_image(im_pil_sor, args.save_path)
59 |
60 |
61 |
62 | if __name__ == '__main__':
63 | parser = argparse.ArgumentParser()
64 | parser.add_argument(
65 | "--img_path",
66 | type=str,
67 | default='',
68 | required=False
69 | )
70 | parser.add_argument(
71 | "--save_path",
72 | type=str,
73 | default='',
74 | required=False
75 | )
76 | args = parser.parse_args()
77 | crop_one_image(args)
--------------------------------------------------------------------------------
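
A minimal usage sketch for the cropping entry point above; the input and output paths are placeholders, and the antelopev2 detector files are assumed to sit in ./third_party_files as in the FaceAnalysis call:

    from types import SimpleNamespace

    # equivalent to: python preprocess.py --img_path raw_photo.jpg --save_path sor.png
    crop_one_image(SimpleNamespace(img_path='raw_photo.jpg', save_path='sor.png'))
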
/utils/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.utils.model_zoo as modelzoo
8 |
9 | # from modules.bn import InPlaceABNSync as BatchNorm2d
10 |
11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12 |
13 |
14 | def conv3x3(in_planes, out_planes, stride=1):
15 | """3x3 convolution with padding"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False)
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | def __init__(self, in_chan, out_chan, stride=1):
22 | super(BasicBlock, self).__init__()
23 | self.conv1 = conv3x3(in_chan, out_chan, stride)
24 | self.bn1 = nn.BatchNorm2d(out_chan)
25 | self.conv2 = conv3x3(out_chan, out_chan)
26 | self.bn2 = nn.BatchNorm2d(out_chan)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.downsample = None
29 | if in_chan != out_chan or stride != 1:
30 | self.downsample = nn.Sequential(
31 | nn.Conv2d(in_chan, out_chan,
32 | kernel_size=1, stride=stride, bias=False),
33 | nn.BatchNorm2d(out_chan),
34 | )
35 |
36 | def forward(self, x):
37 | residual = self.conv1(x)
38 | residual = F.relu(self.bn1(residual))
39 | residual = self.conv2(residual)
40 | residual = self.bn2(residual)
41 |
42 | shortcut = x
43 | if self.downsample is not None:
44 | shortcut = self.downsample(x)
45 |
46 | out = shortcut + residual
47 | out = self.relu(out)
48 | return out
49 |
50 |
51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53 | for i in range(bnum-1):
54 | layers.append(BasicBlock(out_chan, out_chan, stride=1))
55 | return nn.Sequential(*layers)
56 |
57 |
58 | class Resnet18(nn.Module):
59 | def __init__(self):
60 | super(Resnet18, self).__init__()
61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62 | bias=False)
63 | self.bn1 = nn.BatchNorm2d(64)
64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69 | self.init_weight()
70 |
71 | def forward(self, x):
72 | x = self.conv1(x)
73 | x = F.relu(self.bn1(x))
74 | x = self.maxpool(x)
75 |
76 | x = self.layer1(x)
77 | feat8 = self.layer2(x) # 1/8
78 | feat16 = self.layer3(feat8) # 1/16
79 | feat32 = self.layer4(feat16) # 1/32
80 | return feat8, feat16, feat32
81 |
82 | def init_weight(self):
83 | state_dict = modelzoo.load_url(resnet18_url)
84 | self_state_dict = self.state_dict()
85 | for k, v in state_dict.items():
86 | if 'fc' in k: continue
87 | self_state_dict.update({k: v})
88 | self.load_state_dict(self_state_dict)
89 |
90 | def get_params(self):
91 | wd_params, nowd_params = [], []
92 | for name, module in self.named_modules():
93 | if isinstance(module, (nn.Linear, nn.Conv2d)):
94 | wd_params.append(module.weight)
95 | if not module.bias is None:
96 | nowd_params.append(module.bias)
97 | elif isinstance(module, nn.BatchNorm2d):
98 | nowd_params += list(module.parameters())
99 | return wd_params, nowd_params
100 |
101 |
102 | if __name__ == "__main__":
103 | net = Resnet18()
104 | x = torch.randn(16, 3, 224, 224)
105 | out = net(x)
106 | print(out[0].size())
107 | print(out[1].size())
108 | print(out[2].size())
109 | net.get_params()
110 |
--------------------------------------------------------------------------------
/utils/save_exp_coeffs.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import os
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import torchvision.transforms as transforms
9 | from PIL import Image
10 | from insightface.app import FaceAnalysis
11 |
12 | import datasets_faceswap as datasets_faceswap
13 | import third_party.d3dfr.bfm as bfm
14 | import third_party.model_resnet_d3dfr as model_resnet_d3dfr
15 |
16 | device = 'cpu'
17 | checkpoint = './checkpoints'
18 |
19 | app = FaceAnalysis(name='antelopev2', root=os.path.join('./', 'third_party_files'),
20 | providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
21 |
22 | app.prepare(ctx_id=0, det_size=(640, 640))
23 |
24 | pil2tensor = transforms.Compose([
25 | transforms.Resize((256, 256)),
26 | transforms.ToTensor(),
27 | transforms.Normalize(mean=0.5, std=0.5)])
28 |
29 |
30 | net_d3dfr = model_resnet_d3dfr.getd3dfr_res50(os.path.join(checkpoint, 'third_party/d3dfr_res50_nofc.pth')).eval().to(device)
31 |
32 | bfm_facemodel = bfm.BFM(focal=1015*256/224, image_size=256,
33 | bfm_model_path=os.path.join(checkpoint, 'third_party/BFM_model_front.mat')).to(device)
34 |
35 |
36 | def get_landmarks(image):
37 | face_info = app.get(cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR))
38 | if len(face_info) == 0:
39 | return 'error'
40 | face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]  # keep the largest detected face
41 | pts5 = face_info['kps']
42 |
43 | warp_mat = datasets_faceswap.get_affine_transform(pts5, datasets_faceswap.mean_face_lm5p_256)
44 | drive_im_crop256 = cv2.warpAffine(np.array(image), warp_mat, (256, 256), flags=cv2.INTER_LINEAR)
45 |
46 | drive_im_crop256_pil = Image.fromarray(drive_im_crop256)
47 | image_tar_crop256 = pil2tensor(drive_im_crop256_pil).view(1, 3, 256, 256)
48 |
49 | gt_d3d_coeff = net_d3dfr(image_tar_crop256)
50 | # _, ex_coeff = bfm_facemodel.get_lm68(gt_d3d_coeff)
51 | id_coeff, ex_coeff, tex_coeff, angles, gamma, translation = bfm_facemodel.split_coeff_orderly(gt_d3d_coeff)
52 |
53 | return (ex_coeff, angles, gamma, translation)
54 |
55 |
56 |
57 | def main(sorpth, tarpth, modes):
58 |
59 | mode_list = modes.split('+')
60 | if 'exp' in mode_list:
61 | dstpth = tarpth
62 | elif 'light' in mode_list or 'pose' in mode_list:
63 | dstpth = sorpth
64 | else:
65 | raise ValueError('Unrecognized mode')
66 | with torch.no_grad():
67 | img = Image.open(dstpth)
68 | res = get_landmarks(img)
69 | if isinstance(res, str):
70 | print('cannot find a face in', dstpth)
71 | return
72 | ex_coeff, angles, gamma, translation = res
73 | os.makedirs('results', exist_ok=True)
74 | np.save(f"./results/exp_{modes}.npy", ex_coeff[0])
75 |
76 |
77 |
78 | if __name__ == '__main__':
79 | parser = argparse.ArgumentParser()
80 | parser.add_argument(
81 | "--sor_path",
82 | type=str,
83 | default='',
84 | required=False
85 | )
86 | parser.add_argument(
87 | "--tar_path",
88 | type=str,
89 | default='',
90 | required=False
91 | )
92 | parser.add_argument(
93 | "--modes",
94 | type=str,
95 | default='',
96 | required=False
97 | )
98 |
99 | args = parser.parse_args()
100 | main(args.sor_path, args.tar_path, args.modes)
--------------------------------------------------------------------------------
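
A hedged invocation sketch for the coefficient extraction above, run from utils/ with the d3dfr and BFM checkpoints under ./checkpoints/third_party; the paths point at the shipped id2 sample and the mode matches its exp_light.npy asset:

    # equivalent to: python save_exp_coeffs.py --sor_path ./test_images/id2/sor.png --tar_path ./test_images/id2/tar.png --modes light
    main('./test_images/id2/sor.png', './test_images/id2/tar.png', 'light')
    # -> writes ./results/exp_light.npy (compare with the shipped utils/test_images/id2/exp_light.npy)
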
/utils/test_images/id1/bg_pose+exp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id1/bg_pose+exp.png
--------------------------------------------------------------------------------
/utils/test_images/id1/exp_pose+exp.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id1/exp_pose+exp.npy
--------------------------------------------------------------------------------
/utils/test_images/id1/render.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id1/render.png
--------------------------------------------------------------------------------
/utils/test_images/id1/render_exp+pose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id1/render_exp+pose.png
--------------------------------------------------------------------------------
/utils/test_images/id1/render_exp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id1/render_exp.png
--------------------------------------------------------------------------------
/utils/test_images/id1/sor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id1/sor.png
--------------------------------------------------------------------------------
/utils/test_images/id1/tar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id1/tar.png
--------------------------------------------------------------------------------
/utils/test_images/id2/bg_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id2/bg_light.png
--------------------------------------------------------------------------------
/utils/test_images/id2/exp_light.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id2/exp_light.npy
--------------------------------------------------------------------------------
/utils/test_images/id2/render_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id2/render_light.png
--------------------------------------------------------------------------------
/utils/test_images/id2/sor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id2/sor.png
--------------------------------------------------------------------------------
/utils/test_images/id2/tar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id2/tar.png
--------------------------------------------------------------------------------