├── .gitignore ├── LICENSE ├── README.md ├── assets └── demo.png ├── decalib ├── __init__.py ├── datasets │ ├── aflw2000.py │ ├── build_datasets.py │ ├── datasets.py │ ├── detectors.py │ ├── ethnicity.py │ ├── now.py │ ├── train_datasets.py │ ├── vggface.py │ └── vox.py ├── deca.py ├── models │ ├── FLAME.py │ ├── decoders.py │ ├── encoders.py │ ├── frnet.py │ ├── lbs.py │ └── resnet.py ├── trainer.py └── utils │ ├── config.py │ ├── lossfunc.py │ ├── rasterizer │ ├── INSTALL.md │ ├── __init__.py │ ├── setup.py │ ├── standard_rasterize_cuda.cpp │ └── standard_rasterize_cuda_kernel.cu │ ├── renderer.py │ ├── rotation_converter.py │ ├── tensor_cropper.py │ ├── trainer.py │ └── util.py ├── inference_rigface.py ├── requirements.txt ├── rigface └── models │ ├── attention_ID.py │ ├── attention_denoising.py │ ├── pipelineRigFace.py │ ├── transformer_ID_2d.py │ ├── transformer_denoising_2d.py │ ├── unet_ID_2d_blocks.py │ ├── unet_ID_2d_condition.py │ ├── unet_denoising_2d_blocks.py │ └── unet_denoising_2d_condition.py └── utils ├── compute_renders.py ├── data_utils.py ├── datasets_faceswap.py ├── make_bgs.py ├── model.py ├── preprocess.py ├── resnet.py ├── save_exp_coeffs.py └── test_images ├── id1 ├── bg_pose+exp.png ├── exp_pose+exp.npy ├── render.png ├── render_exp+pose.png ├── render_exp.png ├── sor.png └── tar.png └── id2 ├── bg_light.png ├── exp_light.npy ├── render_light.png ├── sor.png └── tar.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Towards Consistent and Controllable Image Synthesis for Face Editing 4 | 5 | 6 | Mengting Wei, Tuomas Varanka, Yante Li, Xingxun Jiang, Huai-Qian Khor, Guoying Zhao 7 | 8 | 9 | University of Oulu 10 | 11 | 12 | ### [Paper](https://arxiv.org/abs/2502.02465) 13 | 14 |
15 | 16 | ## :mega: Updates 17 | 18 | [6/2/2025] Inference code and pre-trained models are released. 19 | 20 | 21 | ## Introduction 22 | 23 | We present RigFace, an efficient approach that edits the expression, pose and lighting of a given image while keeping 24 | the identity and other attributes consistent. 25 | 26 |

27 | 28 |

29 | 30 | 31 | ## Installation 32 | 33 | To deploy and run RigFace, run the following scripts: 34 | ``` 35 | conda create -n rigface python=3.9 36 | conda activate rigface 37 | pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116 38 | pip install -r requirements.txt 39 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath 40 | pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py39_cu116_pyt1130/download.html 41 | ``` 42 | 43 | ## Download models and data 44 | 45 | - Download the utils files for pre-processing from [Huggingface](https://huggingface.co/mengtingwei/rigface/tree/main). 46 | Put them in the `utils` directory. 47 | - Download the pre-trained weights of our model from the same website. 48 | - Make a new directory `data` inside `utils`. 49 | - Download the pre-trained DECA model `deca_model.tar` from [here](https://github.com/yfeng95/DECA). 50 | Put it in `data`. 51 | - Download `generic_model.pkl` from [FLAME2020](https://flame.is.tue.mpg.de/download.php). 52 | Also put it under `data`. 53 | - Download `FLAME_texture.npz` from [FLAME texture space](https://flame.is.tue.mpg.de/download.php). 54 | Put it in `data`. 55 | - Download the other files from the [DECA page](https://github.com/yfeng95/DECA/tree/master/data). Put all of them under `data`. 56 | 57 | ``` 58 | ... 59 | pre_trained 60 | -unet_denoise 61 | ... 62 | -unet_id 63 | ... 64 | utils 65 | -checkpoints 66 | ... 67 | -third_party 68 | ... 69 | -third_party_files 70 | ... 71 | -data 72 | ... 73 | ... 74 | ``` 75 | 76 | ## Test with our examples 77 | 78 | 79 | 80 | ``` 81 | python inference_rigface.py --id_path utils/test_images/id1/sor.png --bg_path utils/test_images/id1/bg_pose+exp.png --exp_path utils/test_images/id1/exp_pose+exp.npy --render_path utils/test_images/id1/render_exp+pose.png --save_path ./res 82 | ``` 83 | 84 | You will find the edited image under the `res` directory. 85 | 86 | 87 | ## Test your own data 88 | 89 | 1. First, you need to ensure that both the source and target images have a resolution of 512x512. 90 | 91 | ``` 92 | cd utils 93 | python preprocess.py --img_path <path-to-your-image> --save_path <path-to-save-the-resized-image> 94 | ``` 95 | 96 | 2. Parse the background using the resized source and target images. We provide an example here: 97 | 98 | ``` 99 | python make_bgs.py --sor_path ./test_images/id1/sor.png --tar_path ./test_images/id1/tar.png --modes pose+exp 100 | ``` 101 | 102 | If you want to edit only one mode, just provide that single mode (for example, `--modes exp`) here. 103 | 104 | 3. Compute the expression coefficients. When editing lighting or pose, 105 | the coefficients are computed from the source image; when editing expression, 106 | they are computed from the target image. 107 | 108 | ``` 109 | python save_exp_coeffs.py --sor_path ./test_images/id1/sor.png --tar_path ./test_images/id1/tar.png --mode pose+exp 110 | ``` 111 | 112 | 4. Compute the renderings according to the edit modes. 113 | 114 | ``` 115 | python compute_renders.py --sor_path ./test_images/id1/sor.png --tar_path ./test_images/id1/tar.png --mode pose+exp 116 | ``` 117 | 118 | 5. Then run the inference using all the conditions generated, as in the example below.
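A hedged example of this last step, assuming you run it from the repository root and that steps 2-4 wrote their condition files next to your source image with names following the `utils/test_images/id1` pattern (your exact file names will differ depending on the chosen modes):

```
# back to the repository root (steps 1-4 above were run inside utils/)
cd ..
python inference_rigface.py --id_path utils/test_images/id1/sor.png --bg_path utils/test_images/id1/bg_pose+exp.png --exp_path utils/test_images/id1/exp_pose+exp.npy --render_path utils/test_images/id1/render_exp+pose.png --save_path ./res
```

As in the earlier example, the edited image is written to the `res` directory.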
119 | 120 | ## Acknowledgements 121 | 122 | This project is built on source codes shared by [DECA](https://github.com/yfeng95/DECA), 123 | [Deep3DRecon](https://github.com/sicxu/Deep3DFaceRecon_pytorch), 124 | [faceParsing](https://github.com/zllrunning/face-parsing.PyTorch) and 125 | [ControlNet](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py). 126 | -------------------------------------------------------------------------------- /assets/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/assets/demo.png -------------------------------------------------------------------------------- /decalib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/decalib/__init__.py -------------------------------------------------------------------------------- /decalib/datasets/aflw2000.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import cv2 6 | import scipy 7 | from skimage.io import imread, imsave 8 | from skimage.transform import estimate_transform, warp, resize, rescale 9 | from glob import glob 10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 11 | import scipy.io 12 | 13 | class AFLW2000(Dataset): 14 | def __init__(self, testpath='/ps/scratch/yfeng/Data/AFLW2000/GT', crop_size=224): 15 | ''' 16 | data class for loading AFLW2000 dataset 17 | make sure each image has a corresponding mat file, which provides cropping information 18 | ''' 19 | if os.path.isdir(testpath): 20 | self.imagepath_list = glob(testpath + '/*.jpg') + glob(testpath + '/*.png') 21 | elif isinstance(testpath, list): 22 | self.imagepath_list = testpath 23 | elif os.path.isfile(testpath) and (testpath[-3:] in ['jpg', 'png']): 24 | self.imagepath_list = [testpath] 25 | else: 26 | print('please check the input path') 27 | exit() 28 | print('total {} images'.format(len(self.imagepath_list))) 29 | self.imagepath_list = sorted(self.imagepath_list) 30 | self.crop_size = crop_size 31 | self.scale = 1.6 32 | self.resolution_inp = crop_size 33 | 34 | def __len__(self): 35 | return len(self.imagepath_list) 36 | 37 | def __getitem__(self, index): 38 | imagepath = self.imagepath_list[index] 39 | imagename = imagepath.split('/')[-1].split('.')[0] 40 | image = imread(imagepath)[:,:,:3] 41 | kpt = scipy.io.loadmat(imagepath.replace('jpg', 'mat'))['pt3d_68'].T 42 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]); 43 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 44 | 45 | h, w, _ = image.shape 46 | old_size = (right - left + bottom - top)/2 47 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1]) 48 | size = int(old_size*self.scale) 49 | 50 | # crop image 51 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 52 | DST_PTS = np.array([[0,0], [0,self.resolution_inp - 1], [self.resolution_inp - 1, 0]]) 53 | tform = estimate_transform('similarity', src_pts, DST_PTS) 54 | 55 | image = image/255.
56 | dst_image = warp(image, tform.inverse, output_shape=(self.resolution_inp, self.resolution_inp)) 57 | dst_image = dst_image.transpose(2,0,1) 58 | return {'image': torch.tensor(dst_image).float(), 59 | 'imagename': imagename, 60 | # 'tform': tform, 61 | # 'original_image': torch.tensor(image.transpose(2,0,1)).float(), 62 | } -------------------------------------------------------------------------------- /decalib/datasets/build_datasets.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | from torch.utils.data import Dataset, ConcatDataset 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | import cv2 7 | import scipy 8 | from skimage.io import imread, imsave 9 | from skimage.transform import estimate_transform, warp, resize, rescale 10 | from glob import glob 11 | 12 | from .vggface import VGGFace2Dataset 13 | from .ethnicity import EthnicityDataset 14 | from .aflw2000 import AFLW2000 15 | from .now import NoWDataset 16 | from .vox import VoxelDataset 17 | 18 | def build_train(config, is_train=True): 19 | data_list = [] 20 | if 'vox2' in config.training_data: 21 | data_list.append(VoxelDataset(dataname='vox2', K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle)) 22 | if 'vggface2' in config.training_data: 23 | data_list.append(VGGFace2Dataset(K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle)) 24 | if 'vggface2hq' in config.training_data: 25 | data_list.append(VGGFace2HQDataset(K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle)) 26 | if 'ethnicity' in config.training_data: 27 | data_list.append(EthnicityDataset(K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle)) 28 | if 'coco' in config.training_data: 29 | data_list.append(COCODataset(image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale)) 30 | if 'celebahq' in config.training_data: 31 | data_list.append(CelebAHQDataset(image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale)) 32 | dataset = ConcatDataset(data_list) 33 | 34 | return dataset 35 | 36 | def build_val(config, is_train=True): 37 | data_list = [] 38 | if 'vggface2' in config.eval_data: 39 | data_list.append(VGGFace2Dataset(isEval=True, K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle)) 40 | if 'now' in config.eval_data: 41 | data_list.append(NoWDataset()) 42 | if 'aflw2000' in config.eval_data: 43 | data_list.append(AFLW2000()) 44 | dataset = ConcatDataset(data_list) 45 | 46 | return dataset 47 | -------------------------------------------------------------------------------- /decalib/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 
7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import os, sys 17 | import torch 18 | from torch.utils.data import Dataset, DataLoader 19 | import torchvision.transforms as transforms 20 | import numpy as np 21 | import cv2 22 | import scipy 23 | from skimage.io import imread, imsave 24 | from skimage.transform import estimate_transform, warp, resize, rescale 25 | from glob import glob 26 | import scipy.io 27 | from decalib.datasets import datasets 28 | from torchvision.utils import save_image 29 | 30 | from . import detectors 31 | 32 | def video2sequence(video_path, sample_step=10): 33 | videofolder = os.path.splitext(video_path)[0] 34 | os.makedirs(videofolder, exist_ok=True) 35 | video_name = os.path.splitext(os.path.split(video_path)[-1])[0] 36 | vidcap = cv2.VideoCapture(video_path) 37 | success,image = vidcap.read() 38 | count = 0 39 | imagepath_list = [] 40 | while success: 41 | # if count%sample_step == 0: 42 | imagepath = os.path.join(videofolder, f'{video_name}_frame{count:04d}.jpg') 43 | cv2.imwrite(imagepath, image) # save frame as JPEG file 44 | success,image = vidcap.read() 45 | count += 1 46 | imagepath_list.append(imagepath) 47 | print('video frames are stored in {}'.format(videofolder)) 48 | return imagepath_list 49 | 50 | class TestData(Dataset): 51 | # called with iscrop=True, size=256, sort=True passed in 52 | def __init__(self, testpath, iscrop=True, crop_size=224, scale=1.25, face_detector='fan', 53 | sample_step=10, size=256, sort=False): 54 | 55 | if isinstance(testpath, list): 56 | self.imagepath_list = testpath 57 | elif os.path.isdir(testpath): 58 | self.imagepath_list = glob(testpath + '/*.jpg') + glob(testpath + '/*.png') + glob(testpath + '/*.bmp') 59 | elif os.path.isfile(testpath) and (testpath[-3:] in ['jpg', 'png', 'bmp']): 60 | self.imagepath_list = [testpath] 61 | elif os.path.isfile(testpath) and (testpath[-3:] in ['mp4', 'csv', 'vid', 'ebm']): 62 | self.imagepath_list = video2sequence(testpath, sample_step) 63 | 64 | if sort: 65 | self.imagepath_list = sorted(self.imagepath_list) 66 | self.crop_size = crop_size 67 | self.scale = scale 68 | self.iscrop = iscrop 69 | self.resolution_inp = crop_size 70 | self.size = size 71 | # uses the face_alignment landmark detection tool 72 | if face_detector == 'fan': 73 | self.face_detector = detectors.FAN() 74 | else: 75 | print(f'please check the detector: {face_detector}') 76 | exit() 77 | 78 | def __len__(self): 79 | return len(self.imagepath_list) 80 | 81 | def bbox2point(self, left, right, top, bottom, type='bbox'): 82 | 83 | if type=='kpt68': 84 | old_size = (right - left + bottom - top) / 2 * 1.1 85 | # position of the face center 86 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ]) 87 | elif type=='bbox': 88 | old_size = (right - left + bottom - top)/2 89 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size*0.12]) 90 | else: 91 | raise NotImplementedError 92 | return old_size, center 93 | 94 | def get_image(self, image): 95 | h, w, _ = image.shape 96 | bbox, bbox_type = self.face_detector.run(image) 97 | if len(bbox) < 4: 98 | print('no face detected!
run original image') 99 | left = 0; right = h-1; top=0; bottom=w-1 100 | else: 101 | left = bbox[0]; right=bbox[2] 102 | top = bbox[1]; bottom=bbox[3] 103 | old_size, center = self.bbox2point(left, right, top, bottom, type=bbox_type) 104 | size = int(old_size*self.scale) 105 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 106 | 107 | DST_PTS = np.array([[0,0], [0,self.resolution_inp - 1], [self.resolution_inp - 1, 0]]) 108 | tform = estimate_transform('similarity', src_pts, DST_PTS) 109 | 110 | image = image / 255. 111 | 112 | dst_image = warp(image, tform.inverse, output_shape=(self.resolution_inp, self.resolution_inp)) 113 | dst_image = dst_image.transpose(2,0,1) 114 | return {'image': torch.tensor(dst_image).float(), 115 | 'tform': torch.tensor(tform.params).float(), 116 | 'original_image': torch.tensor(image.transpose(2,0,1)).float(), 117 | } 118 | 119 | 120 | def __getitem__(self, index): 121 | 122 | imagepath = self.imagepath_list[index] 123 | imagename = os.path.splitext(os.path.split(imagepath)[-1])[0] 124 | im = imread(imagepath) 125 | 126 | if self.size is not None: # size = 256 127 | im = (resize(im, (self.size, self.size), anti_aliasing=True) * 255.).astype(np.uint8) 128 | 129 | # (256, 256, 3) 130 | image = np.array(im) 131 | 132 | if len(image.shape) == 2: 133 | image = image[:, :, None].repeat(1,1,3) 134 | if len(image.shape) == 3 and image.shape[2] > 3: 135 | image = image[:, :, :3] 136 | 137 | h, w, _ = image.shape 138 | if self.iscrop: # true 139 | # provide kpt as txt file, or mat file (for AFLW2000) 140 | # check whether a landmark file already exists; if not, detect the landmarks ourselves 141 | kpt_matpath = os.path.splitext(imagepath)[0]+'.mat' 142 | kpt_txtpath = os.path.splitext(imagepath)[0]+'.txt' 143 | if os.path.exists(kpt_matpath): 144 | kpt = scipy.io.loadmat(kpt_matpath)['pt3d_68'].T 145 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]) 146 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 147 | old_size, center = self.bbox2point(left, right, top, bottom, type='kpt68') 148 | elif os.path.exists(kpt_txtpath): 149 | kpt = np.loadtxt(kpt_txtpath) 150 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]) 151 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 152 | old_size, center = self.bbox2point(left, right, top, bottom, type='kpt68') 153 | else: 154 | bbox, bbox_type = self.face_detector.run(image) 155 | if len(bbox) < 4: 156 | print('no face detected! run original image') 157 | left = 0; right = h-1; top=0; bottom=w-1 158 | else: 159 | left = bbox[0]; right=bbox[2] 160 | top = bbox[1]; bottom=bbox[3] 161 | old_size, center = self.bbox2point(left, right, top, bottom, type=bbox_type) 162 | size = int(old_size * self.scale) 163 | src_pts = np.array([[center[0] - size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 164 | else: 165 | src_pts = np.array([[0, 0], [0, h-1], [w-1, 0]]) 166 | # DST_PTS = np.array([[0, 0], [0, h-1], [w-1, 0]]) 167 | # self.resolution_inp = 224, the target image size 168 | DST_PTS = np.array([[0, 0], [0, self.resolution_inp - 1], [self.resolution_inp - 1, 0]]) 169 | # estimate the matrix transform that maps the source image to the target image 170 | tform = estimate_transform('similarity', src_pts, DST_PTS) 171 | 172 | image = image / 255.
# values in the 0-1 range 173 | 174 | dst_image = warp(image, tform.inverse, output_shape=(self.resolution_inp, self.resolution_inp)) 175 | dst_image = dst_image.transpose(2,0,1) 176 | 177 | return {'image': torch.tensor(dst_image).float(), # only the face region is kept 178 | 'imagename': imagename, 179 | 'tform': torch.tensor(tform.params).float(), 180 | 'original_image': torch.tensor(image.transpose(2, 0, 1)).float(), 181 | } 182 | 183 | 184 | 185 | 186 | 187 | if __name__ == '__main__': 188 | testdata_source = datasets.TestData( 189 | source, iscrop=True, size=512, sort=True 190 | ) -------------------------------------------------------------------------------- /decalib/datasets/detectors.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import numpy as np 17 | import torch 18 | 19 | class FAN(object): 20 | def __init__(self): 21 | import face_alignment 22 | self.model = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False) 23 | 24 | def run(self, image): 25 | ''' 26 | image: 0-255, uint8, rgb, [h, w, 3] 27 | return: detected box list 28 | ''' 29 | out = self.model.get_landmarks(image) 30 | if out is None: 31 | return [0], 'kpt68' 32 | else: 33 | kpt = out[0].squeeze() 34 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]) 35 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 36 | bbox = [left, top, right, bottom] 37 | return bbox, 'kpt68' 38 | 39 | class MTCNN(object): 40 | def __init__(self, device = 'cpu'): 41 | ''' 42 | https://github.com/timesler/facenet-pytorch/blob/master/examples/infer.ipynb 43 | ''' 44 | from facenet_pytorch import MTCNN as mtcnn 45 | self.device = device 46 | self.model = mtcnn(keep_all=True) 47 | def run(self, input): 48 | ''' 49 | image: 0-255, uint8, rgb, [h, w, 3] 50 | return: detected box 51 | ''' 52 | out = self.model.detect(input[None,...]) 53 | if out[0][0] is None: 54 | return [0] 55 | else: 56 | bbox = out[0][0].squeeze() 57 | return bbox, 'bbox' 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /decalib/datasets/ethnicity.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import cv2 6 | import scipy 7 | from skimage.io import imread, imsave 8 | from skimage.transform import estimate_transform, warp, resize, rescale 9 | from glob import glob 10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 11 | 12 | class EthnicityDataset(Dataset): 13 | def __init__(self, K, image_size, scale, trans_scale = 0, isTemporal=False, isEval=False, isSingle=False): 14 | ''' 15 | K must be less than 6 16 | ''' 17 | self.K = K 18 | self.image_size = image_size 19 | self.imagefolder =
'/ps/scratch/face2d3d/train' 20 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7/' 21 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch/' 22 | # hq: 23 | # datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy' 24 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_and_race_per_7000_african_asian_2d_train_list_max_normal_100_ring_5_1_serial.npy' 25 | self.data_lines = np.load(datafile).astype('str') 26 | 27 | self.isTemporal = isTemporal 28 | self.scale = scale #[scale_min, scale_max] 29 | self.trans_scale = trans_scale #[dx, dy] 30 | self.isSingle = isSingle 31 | if isSingle: 32 | self.K = 1 33 | 34 | def __len__(self): 35 | return len(self.data_lines) 36 | 37 | def __getitem__(self, idx): 38 | images_list = []; kpt_list = []; mask_list = [] 39 | for i in range(self.K): 40 | name = self.data_lines[idx, i] 41 | if name[0]=='n': 42 | self.imagefolder = '/ps/scratch/face2d3d/train/' 43 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7/' 44 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch/' 45 | elif name[0]=='A': 46 | self.imagefolder = '/ps/scratch/face2d3d/race_per_7000/' 47 | self.kptfolder = '/ps/scratch/face2d3d/race_per_7000_annotated_torch7_new/' 48 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/race7000_seg/test_crop_size_400_batch/' 49 | 50 | image_path = os.path.join(self.imagefolder, name + '.jpg') 51 | seg_path = os.path.join(self.segfolder, name + '.npy') 52 | kpt_path = os.path.join(self.kptfolder, name + '.npy') 53 | 54 | image = imread(image_path)/255. 
55 | kpt = np.load(kpt_path)[:,:2] 56 | mask = self.load_mask(seg_path, image.shape[0], image.shape[1]) 57 | 58 | ### crop information 59 | tform = self.crop(image, kpt) 60 | ## crop 61 | cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size)) 62 | cropped_mask = warp(mask, tform.inverse, output_shape=(self.image_size, self.image_size)) 63 | cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params) 64 | 65 | # normalized kpt 66 | cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2 - 1 67 | 68 | images_list.append(cropped_image.transpose(2,0,1)) 69 | kpt_list.append(cropped_kpt) 70 | mask_list.append(cropped_mask) 71 | 72 | ### 73 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32) #K,224,224,3 74 | kpt_array = torch.from_numpy(np.array(kpt_list)).type(dtype = torch.float32) #K,224,224,3 75 | mask_array = torch.from_numpy(np.array(mask_list)).type(dtype = torch.float32) #K,224,224,3 76 | 77 | if self.isSingle: 78 | images_array = images_array.squeeze() 79 | kpt_array = kpt_array.squeeze() 80 | mask_array = mask_array.squeeze() 81 | 82 | data_dict = { 83 | 'image': images_array, 84 | 'landmark': kpt_array, 85 | 'mask': mask_array 86 | } 87 | 88 | return data_dict 89 | 90 | def crop(self, image, kpt): 91 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]); 92 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 93 | 94 | h, w, _ = image.shape 95 | old_size = (right - left + bottom - top)/2 96 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1]) 97 | # translate center 98 | trans_scale = (np.random.rand(2)*2 -1) * self.trans_scale 99 | center = center + trans_scale*old_size # 0.5 100 | 101 | scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0] 102 | size = int(old_size*scale) 103 | 104 | # crop image 105 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 106 | DST_PTS = np.array([[0,0], [0,self.image_size - 1], [self.image_size - 1, 0]]) 107 | tform = estimate_transform('similarity', src_pts, DST_PTS) 108 | 109 | # cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size)) 110 | # # change kpt accordingly 111 | # cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params) 112 | return tform 113 | 114 | def load_mask(self, maskpath, h, w): 115 | # print(maskpath) 116 | if os.path.isfile(maskpath): 117 | vis_parsing_anno = np.load(maskpath) 118 | # atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', 119 | # 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'] 120 | mask = np.zeros_like(vis_parsing_anno) 121 | # for i in range(1, 16): 122 | mask[vis_parsing_anno>0.5] = 1. 
123 | else: 124 | mask = np.ones((h, w)) 125 | return mask 126 | 127 | -------------------------------------------------------------------------------- /decalib/datasets/now.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import cv2 6 | import scipy 7 | from skimage.io import imread, imsave 8 | from skimage.transform import estimate_transform, warp, resize, rescale 9 | from glob import glob 10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 11 | 12 | class NoWDataset(Dataset): 13 | def __init__(self, ring_elements=6, crop_size=224, scale=1.6): 14 | folder = '/ps/scratch/yfeng/other-github/now_evaluation/data/NoW_Dataset' 15 | self.data_path = os.path.join(folder, 'imagepathsvalidation.txt') 16 | with open(self.data_path) as f: 17 | self.data_lines = f.readlines() 18 | 19 | self.imagefolder = os.path.join(folder, 'final_release_version', 'iphone_pictures') 20 | self.bbxfolder = os.path.join(folder, 'final_release_version', 'detected_face') 21 | 22 | # self.data_path = '/ps/scratch/face2d3d/ringnetpp/eccv/test_data/evaluation/NoW_Dataset/final_release_version/test_image_paths_ring_6_elements.npy' 23 | # self.imagepath = '/ps/scratch/face2d3d/ringnetpp/eccv/test_data/evaluation/NoW_Dataset/final_release_version/iphone_pictures/' 24 | # self.bbxpath = '/ps/scratch/face2d3d/ringnetpp/eccv/test_data/evaluation/NoW_Dataset/final_release_version/detected_face/' 25 | self.crop_size = crop_size 26 | self.scale = scale 27 | 28 | def __len__(self): 29 | return len(self.data_lines) 30 | 31 | def __getitem__(self, index): 32 | imagepath = os.path.join(self.imagefolder, self.data_lines[index].strip()) #+ '.jpg' 33 | bbx_path = os.path.join(self.bbxfolder, self.data_lines[index].strip().replace('.jpg', '.npy')) 34 | bbx_data = np.load(bbx_path, allow_pickle=True, encoding='latin1').item() 35 | # box = np.array([[bbx_data['left'], bbx_data['top']], [bbx_data['right'], bbx_data['bottom']]]).astype('float32') 36 | left = bbx_data['left']; right = bbx_data['right'] 37 | top = bbx_data['top']; bottom = bbx_data['bottom'] 38 | 39 | imagename = imagepath.split('/')[-1].split('.')[0] 40 | image = imread(imagepath)[:,:,:3] 41 | 42 | h, w, _ = image.shape 43 | old_size = (right - left + bottom - top)/2 44 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ]) 45 | size = int(old_size*self.scale) 46 | 47 | # crop image 48 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 49 | DST_PTS = np.array([[0,0], [0,self.crop_size - 1], [self.crop_size - 1, 0]]) 50 | tform = estimate_transform('similarity', src_pts, DST_PTS) 51 | 52 | image = image/255. 
53 | dst_image = warp(image, tform.inverse, output_shape=(self.crop_size, self.crop_size)) 54 | dst_image = dst_image.transpose(2,0,1) 55 | return {'image': torch.tensor(dst_image).float(), 56 | 'imagename': self.data_lines[index].strip().replace('.jpg', ''), 57 | # 'tform': tform, 58 | # 'original_image': torch.tensor(image.transpose(2,0,1)).float(), 59 | } -------------------------------------------------------------------------------- /decalib/datasets/vggface.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import cv2 6 | import scipy 7 | from skimage.io import imread, imsave 8 | from skimage.transform import estimate_transform, warp, resize, rescale 9 | from glob import glob 10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 11 | 12 | class VGGFace2Dataset(Dataset): 13 | def __init__(self, K, image_size, scale, trans_scale = 0, isTemporal=False, isEval=False, isSingle=False): 14 | ''' 15 | K must be less than 6 16 | ''' 17 | self.K = K 18 | self.image_size = image_size 19 | self.imagefolder = '/ps/scratch/face2d3d/train' 20 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7' 21 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch' 22 | # hq: 23 | # datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy' 24 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_train_list_max_normal_100_ring_5_1_serial.npy' 25 | if isEval: 26 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_val_list_max_normal_100_ring_5_1_serial.npy' 27 | self.data_lines = np.load(datafile).astype('str') 28 | 29 | self.isTemporal = isTemporal 30 | self.scale = scale #[scale_min, scale_max] 31 | self.trans_scale = trans_scale #[dx, dy] 32 | self.isSingle = isSingle 33 | if isSingle: 34 | self.K = 1 35 | 36 | def __len__(self): 37 | return len(self.data_lines) 38 | 39 | def __getitem__(self, idx): 40 | images_list = []; kpt_list = []; mask_list = [] 41 | 42 | random_ind = np.random.permutation(5)[:self.K] 43 | for i in random_ind: 44 | name = self.data_lines[idx, i] 45 | image_path = os.path.join(self.imagefolder, name + '.jpg') 46 | seg_path = os.path.join(self.segfolder, name + '.npy') 47 | kpt_path = os.path.join(self.kptfolder, name + '.npy') 48 | 49 | image = imread(image_path)/255. 
50 | kpt = np.load(kpt_path)[:,:2] 51 | mask = self.load_mask(seg_path, image.shape[0], image.shape[1]) 52 | 53 | ### crop information 54 | tform = self.crop(image, kpt) 55 | ## crop 56 | cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size)) 57 | cropped_mask = warp(mask, tform.inverse, output_shape=(self.image_size, self.image_size)) 58 | cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params) 59 | 60 | # normalized kpt 61 | cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2 - 1 62 | 63 | images_list.append(cropped_image.transpose(2,0,1)) 64 | kpt_list.append(cropped_kpt) 65 | mask_list.append(cropped_mask) 66 | 67 | ### 68 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32) #K,224,224,3 69 | kpt_array = torch.from_numpy(np.array(kpt_list)).type(dtype = torch.float32) #K,224,224,3 70 | mask_array = torch.from_numpy(np.array(mask_list)).type(dtype = torch.float32) #K,224,224,3 71 | 72 | if self.isSingle: 73 | images_array = images_array.squeeze() 74 | kpt_array = kpt_array.squeeze() 75 | mask_array = mask_array.squeeze() 76 | 77 | data_dict = { 78 | 'image': images_array, 79 | 'landmark': kpt_array, 80 | 'mask': mask_array 81 | } 82 | 83 | return data_dict 84 | 85 | def crop(self, image, kpt): 86 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]); 87 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 88 | 89 | h, w, _ = image.shape 90 | old_size = (right - left + bottom - top)/2 91 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1]) 92 | # translate center 93 | trans_scale = (np.random.rand(2)*2 -1) * self.trans_scale 94 | center = center + trans_scale*old_size # 0.5 95 | 96 | scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0] 97 | size = int(old_size*scale) 98 | 99 | # crop image 100 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 101 | DST_PTS = np.array([[0,0], [0,self.image_size - 1], [self.image_size - 1, 0]]) 102 | tform = estimate_transform('similarity', src_pts, DST_PTS) 103 | 104 | # cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size)) 105 | # # change kpt accordingly 106 | # cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params) 107 | return tform 108 | 109 | def load_mask(self, maskpath, h, w): 110 | # print(maskpath) 111 | if os.path.isfile(maskpath): 112 | vis_parsing_anno = np.load(maskpath) 113 | # atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', 114 | # 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'] 115 | mask = np.zeros_like(vis_parsing_anno) 116 | # for i in range(1, 16): 117 | mask[vis_parsing_anno>0.5] = 1. 
118 | else: 119 | mask = np.ones((h, w)) 120 | return mask 121 | 122 | 123 | 124 | class VGGFace2HQDataset(Dataset): 125 | def __init__(self, K, image_size, scale, trans_scale = 0, isTemporal=False, isEval=False, isSingle=False): 126 | ''' 127 | K must be less than 6 128 | ''' 129 | self.K = K 130 | self.image_size = image_size 131 | self.imagefolder = '/ps/scratch/face2d3d/train' 132 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7' 133 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch' 134 | # hq: 135 | # datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy' 136 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy' 137 | self.data_lines = np.load(datafile).astype('str') 138 | 139 | self.isTemporal = isTemporal 140 | self.scale = scale #[scale_min, scale_max] 141 | self.trans_scale = trans_scale #[dx, dy] 142 | self.isSingle = isSingle 143 | if isSingle: 144 | self.K = 1 145 | 146 | def __len__(self): 147 | return len(self.data_lines) 148 | 149 | def __getitem__(self, idx): 150 | images_list = []; kpt_list = []; mask_list = [] 151 | 152 | for i in range(self.K): 153 | name = self.data_lines[idx, i] 154 | image_path = os.path.join(self.imagefolder, name + '.jpg') 155 | seg_path = os.path.join(self.segfolder, name + '.npy') 156 | kpt_path = os.path.join(self.kptfolder, name + '.npy') 157 | 158 | image = imread(image_path)/255. 159 | kpt = np.load(kpt_path)[:,:2] 160 | mask = self.load_mask(seg_path, image.shape[0], image.shape[1]) 161 | 162 | ### crop information 163 | tform = self.crop(image, kpt) 164 | ## crop 165 | cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size)) 166 | cropped_mask = warp(mask, tform.inverse, output_shape=(self.image_size, self.image_size)) 167 | cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params) 168 | 169 | # normalized kpt 170 | cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2 - 1 171 | 172 | images_list.append(cropped_image.transpose(2,0,1)) 173 | kpt_list.append(cropped_kpt) 174 | mask_list.append(cropped_mask) 175 | 176 | ### 177 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32) #K,224,224,3 178 | kpt_array = torch.from_numpy(np.array(kpt_list)).type(dtype = torch.float32) #K,224,224,3 179 | mask_array = torch.from_numpy(np.array(mask_list)).type(dtype = torch.float32) #K,224,224,3 180 | 181 | if self.isSingle: 182 | images_array = images_array.squeeze() 183 | kpt_array = kpt_array.squeeze() 184 | mask_array = mask_array.squeeze() 185 | 186 | data_dict = { 187 | 'image': images_array, 188 | 'landmark': kpt_array, 189 | 'mask': mask_array 190 | } 191 | 192 | return data_dict 193 | 194 | def crop(self, image, kpt): 195 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]); 196 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 197 | 198 | h, w, _ = image.shape 199 | old_size = (right - left + bottom - top)/2 200 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1]) 201 | # translate center 202 | trans_scale = (np.random.rand(2)*2 -1) * self.trans_scale 203 | center = center + trans_scale*old_size # 0.5 204 | 205 
| scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0] 206 | size = int(old_size*scale) 207 | 208 | # crop image 209 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 210 | DST_PTS = np.array([[0,0], [0,self.image_size - 1], [self.image_size - 1, 0]]) 211 | tform = estimate_transform('similarity', src_pts, DST_PTS) 212 | 213 | # cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size)) 214 | # # change kpt accordingly 215 | # cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params) 216 | return tform 217 | 218 | def load_mask(self, maskpath, h, w): 219 | # print(maskpath) 220 | if os.path.isfile(maskpath): 221 | vis_parsing_anno = np.load(maskpath) 222 | # atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', 223 | # 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'] 224 | mask = np.zeros_like(vis_parsing_anno) 225 | # for i in range(1, 16): 226 | mask[vis_parsing_anno>0.5] = 1. 227 | else: 228 | mask = np.ones((h, w)) 229 | return mask -------------------------------------------------------------------------------- /decalib/datasets/vox.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import cv2 6 | import scipy 7 | from skimage.io import imread, imsave 8 | from skimage.transform import estimate_transform, warp, resize, rescale 9 | from glob import glob 10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 11 | 12 | class VoxelDataset(Dataset): 13 | def __init__(self, K, image_size, scale, trans_scale = 0, dataname='vox2', n_train=100000, isTemporal=False, isEval=False, isSingle=False): 14 | self.K = K 15 | self.image_size = image_size 16 | if dataname == 'vox1': 17 | self.kpt_suffix = '.txt' 18 | self.imagefolder = '/ps/project/face2d3d/VoxCeleb/vox1/dev/images_cropped' 19 | self.kptfolder = '/ps/scratch/yfeng/Data/VoxCeleb/vox1/landmark_2d' 20 | 21 | self.face_dict = {} 22 | for person_id in sorted(os.listdir(self.kptfolder)): 23 | for video_id in os.listdir(os.path.join(self.kptfolder, person_id)): 24 | for face_id in os.listdir(os.path.join(self.kptfolder, person_id, video_id)): 25 | if 'txt' in face_id: 26 | continue 27 | key = person_id + '/' + video_id + '/' + face_id 28 | # if key not in self.face_dict.keys(): 29 | # self.face_dict[key] = [] 30 | name_list = os.listdir(os.path.join(self.kptfolder, person_id, video_id, face_id)) 31 | name_list = [name.split('.')[0] for name in name_list] 32 | if len(name_list)0.5] = 1. 162 | else: 163 | mask = np.ones((h, w)) 164 | return mask 165 | -------------------------------------------------------------------------------- /decalib/deca.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG).
acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import os, sys 17 | import torch 18 | import torchvision 19 | import torch.nn.functional as F 20 | import torch.nn as nn 21 | 22 | import numpy as np 23 | from time import time 24 | from skimage.io import imread 25 | import cv2 26 | import pickle 27 | from .utils.renderer import SRenderY, set_rasterizer 28 | from .models.encoders import ResnetEncoder 29 | from .models.FLAME import FLAME, FLAMETex 30 | from .models.decoders import Generator 31 | from .utils import util 32 | from .utils.rotation_converter import batch_euler2axis 33 | from .utils.tensor_cropper import transform_points 34 | from .datasets import datasets 35 | from .utils.config import cfg 36 | torch.backends.cudnn.benchmark = True 37 | 38 | class DECA(nn.Module): 39 | def __init__(self, config=None, device='cuda'): 40 | super(DECA, self).__init__() 41 | if config is None: 42 | self.cfg = cfg 43 | else: 44 | self.cfg = config 45 | self.device = device 46 | self.image_size = self.cfg.dataset.image_size 47 | self.uv_size = self.cfg.model.uv_size 48 | 49 | self._create_model(self.cfg.model) 50 | self._setup_renderer(self.cfg.model) 51 | 52 | def _setup_renderer(self, model_cfg): 53 | set_rasterizer(self.cfg.rasterizer_type) 54 | self.render = SRenderY(self.image_size, obj_filename=model_cfg.topology_path, uv_size=model_cfg.uv_size, rasterizer_type=self.cfg.rasterizer_type).to(self.device) 55 | # face mask for rendering details 56 | mask = imread(model_cfg.face_eye_mask_path).astype(np.float32)/255.; mask = torch.from_numpy(mask[:,:,0])[None,None,:,:].contiguous() 57 | self.uv_face_eye_mask = F.interpolate(mask, [model_cfg.uv_size, model_cfg.uv_size]).to(self.device) 58 | mask = imread(model_cfg.face_mask_path).astype(np.float32)/255.; mask = torch.from_numpy(mask[:,:,0])[None,None,:,:].contiguous() 59 | self.uv_face_mask = F.interpolate(mask, [model_cfg.uv_size, model_cfg.uv_size]).to(self.device) 60 | # displacement correction 61 | fixed_dis = np.load(model_cfg.fixed_displacement_path) 62 | self.fixed_uv_dis = torch.tensor(fixed_dis).float().to(self.device) 63 | # mean texture 64 | mean_texture = imread(model_cfg.mean_tex_path).astype(np.float32)/255.; mean_texture = torch.from_numpy(mean_texture.transpose(2,0,1))[None,:,:,:].contiguous() 65 | self.mean_texture = F.interpolate(mean_texture, [model_cfg.uv_size, model_cfg.uv_size]).to(self.device) 66 | # dense mesh template, for save detail mesh 67 | self.dense_template = np.load(model_cfg.dense_template_path, allow_pickle=True, encoding='latin1').item() 68 | 69 | def _create_model(self, model_cfg): 70 | # set up parameters 71 | self.n_param = model_cfg.n_shape+model_cfg.n_tex+model_cfg.n_exp+model_cfg.n_pose+model_cfg.n_cam+model_cfg.n_light 72 | self.n_detail = model_cfg.n_detail 73 | self.n_cond = model_cfg.n_exp + 3 # exp + jaw pose 74 | self.num_list = [model_cfg.n_shape, model_cfg.n_tex, model_cfg.n_exp, model_cfg.n_pose, model_cfg.n_cam, model_cfg.n_light] 75 | self.param_dict = {i:model_cfg.get('n_' + i) for i in model_cfg.param_list} 76 | 77 | # encoders 78 | self.E_flame = ResnetEncoder(outsize=self.n_param).to(self.device) 79 | self.E_detail = ResnetEncoder(outsize=self.n_detail).to(self.device) 80 | # decoders 81 | self.flame = FLAME(model_cfg).to(self.device) 82 | if model_cfg.use_tex: 83 | self.flametex = 
FLAMETex(model_cfg).to(self.device) 84 | self.D_detail = Generator(latent_dim=self.n_detail+self.n_cond, out_channels=1, out_scale=model_cfg.max_z, sample_mode = 'bilinear').to(self.device) 85 | # resume model 86 | model_path = self.cfg.pretrained_modelpath 87 | if os.path.exists(model_path): 88 | print(f'trained model found. load {model_path}') 89 | checkpoint = torch.load(model_path) 90 | self.checkpoint = checkpoint 91 | util.copy_state_dict(self.E_flame.state_dict(), checkpoint['E_flame']) 92 | util.copy_state_dict(self.E_detail.state_dict(), checkpoint['E_detail']) 93 | util.copy_state_dict(self.D_detail.state_dict(), checkpoint['D_detail']) 94 | else: 95 | print(f'please check model path: {model_path}') 96 | # eval mode 97 | self.E_flame.eval() 98 | self.E_detail.eval() 99 | self.D_detail.eval() 100 | 101 | def decompose_code(self, code, num_dict): 102 | ''' Convert a flattened parameter vector to a dictionary of parameters 103 | code_dict.keys() = ['shape', 'tex', 'exp', 'pose', 'cam', 'light'] 104 | ''' 105 | code_dict = {} 106 | start = 0 107 | for key in num_dict: 108 | end = start+int(num_dict[key]) 109 | code_dict[key] = code[:, start:end] 110 | start = end 111 | if key == 'light': 112 | code_dict[key] = code_dict[key].reshape(code_dict[key].shape[0], 9, 3) 113 | return code_dict 114 | 115 | def displacement2normal(self, uv_z, coarse_verts, coarse_normals): 116 | ''' Convert displacement map into detail normal map 117 | ''' 118 | batch_size = uv_z.shape[0] 119 | uv_coarse_vertices = self.render.world2uv(coarse_verts).detach() 120 | uv_coarse_normals = self.render.world2uv(coarse_normals).detach() 121 | 122 | uv_z = uv_z*self.uv_face_eye_mask 123 | uv_detail_vertices = uv_coarse_vertices + uv_z*uv_coarse_normals + self.fixed_uv_dis[None,None,:,:]*uv_coarse_normals.detach() 124 | dense_vertices = uv_detail_vertices.permute(0,2,3,1).reshape([batch_size, -1, 3]) 125 | uv_detail_normals = util.vertex_normals(dense_vertices, self.render.dense_faces.expand(batch_size, -1, -1)) 126 | uv_detail_normals = uv_detail_normals.reshape([batch_size, uv_coarse_vertices.shape[2], uv_coarse_vertices.shape[3], 3]).permute(0,3,1,2) 127 | uv_detail_normals = uv_detail_normals*self.uv_face_eye_mask + uv_coarse_normals*(1.-self.uv_face_eye_mask) 128 | return uv_detail_normals 129 | 130 | def visofp(self, normals): 131 | ''' visibility of keypoints, based on the normal direction 132 | ''' 133 | normals68 = self.flame.seletec_3d68(normals) 134 | vis68 = (normals68[:,:,2:] < 0.1).float() 135 | return vis68 136 | 137 | # @torch.no_grad() 138 | def encode(self, images, use_detail=True): 139 | if use_detail: 140 | # use_detail is for training detail model, need to set coarse model as eval mode 141 | with torch.no_grad(): 142 | parameters = self.E_flame(images) 143 | else: 144 | parameters = self.E_flame(images) 145 | codedict = self.decompose_code(parameters, self.param_dict) 146 | codedict['images'] = images 147 | if use_detail: 148 | detailcode = self.E_detail(images) 149 | codedict['detail'] = detailcode 150 | if self.cfg.model.jaw_type == 'euler': 151 | posecode = codedict['pose'] 152 | euler_jaw_pose = posecode[:,3:].clone() # x for yaw (open mouth), y for pitch (left ang right), z for roll 153 | posecode[:,3:] = batch_euler2axis(euler_jaw_pose) 154 | codedict['pose'] = posecode 155 | codedict['euler_jaw_pose'] = euler_jaw_pose 156 | return codedict 157 | 158 | # @torch.no_grad() 159 | def decode(self, codedict, rendering=True, iddict=None, vis_lmk=True, return_vis=True, use_detail=True, 160 | 
render_orig=False, original_image=None, tform=None, add_light=True, th=0, 161 | align_ffhq=False, return_ffhq_center=False, ffhq_center=None, 162 | light_type='point', render_norm=False): 163 | # these are still the cropped images 164 | images = codedict['images'] 165 | batch_size = images.shape[0] 166 | 167 | ## decode 168 | verts, landmarks2d, landmarks3d = self.flame(shape_params=codedict['shape'], 169 | expression_params=codedict['exp'], 170 | pose_params=codedict['pose']) 171 | 172 | if (align_ffhq and ffhq_center is not None) or return_ffhq_center: 173 | lm_eye_left = landmarks2d[:, 36:42] # left-clockwise 174 | lm_eye_right = landmarks2d[:, 42:48] # left-clockwise 175 | lm_mouth_outer = landmarks2d[:, 48:60] # left-clockwise 176 | 177 | eye_left = torch.mean(lm_eye_left, dim=1) 178 | eye_right = torch.mean(lm_eye_right, dim=1) 179 | eye_avg = (eye_left + eye_right) * 0.5 180 | # eye_to_eye = eye_right - eye_left 181 | mouth_left = lm_mouth_outer[:, 0] 182 | mouth_right = lm_mouth_outer[:, 6] 183 | mouth_avg = (mouth_left + mouth_right) * 0.5 184 | eye_to_mouth = mouth_avg - eye_avg 185 | 186 | center = eye_avg + eye_to_mouth * 0.1 187 | 188 | if return_ffhq_center: 189 | return center 190 | 191 | if align_ffhq: 192 | delta = ffhq_center - center 193 | verts = verts + delta 194 | 195 | if self.cfg.model.use_tex: 196 | albedo = self.flametex(codedict['tex']) 197 | else: 198 | albedo = torch.zeros([batch_size, 3, self.uv_size, self.uv_size], device=images.device) 199 | landmarks3d_world = landmarks3d.clone() 200 | 201 | ## projection 202 | landmarks2d = util.batch_orth_proj(landmarks2d, codedict['cam'])[:,:,:2]; landmarks2d[:,:,1:] = -landmarks2d[:,:,1:]#; landmarks2d = landmarks2d*self.image_size/2 + self.image_size/2 203 | landmarks3d = util.batch_orth_proj(landmarks3d, codedict['cam']); landmarks3d[:,:,1:] = -landmarks3d[:,:,1:] #; landmarks3d = landmarks3d*self.image_size/2 + self.image_size/2 204 | 205 | trans_verts = util.batch_orth_proj(verts, codedict['cam']); trans_verts[:,:,1:] = -trans_verts[:,:,1:] 206 | opdict = { 207 | 'verts': verts, 208 | 'trans_verts': trans_verts, 209 | 'landmarks2d': landmarks2d, 210 | 'landmarks3d': landmarks3d, 211 | 'landmarks3d_world': landmarks3d_world, 212 | } 213 | 214 | ## rendering 215 | if return_vis and render_orig and original_image is not None and tform is not None: 216 | points_scale = [self.image_size, self.image_size] 217 | _, _, h, w = original_image.shape 218 | # import ipdb; ipdb.set_trace() 219 | trans_verts = transform_points(trans_verts, tform, points_scale, [h, w]) 220 | landmarks2d = transform_points(landmarks2d, tform, points_scale, [h, w]) 221 | landmarks3d = transform_points(landmarks3d, tform, points_scale, [h, w]) 222 | images = original_image 223 | else: 224 | h, w = self.image_size, self.image_size 225 | 226 | if rendering: 227 | ops = self.render(verts, trans_verts, albedo, codedict['light'], h=h, w=w, 228 | add_light=add_light, th=th, light_type=light_type, render_norm=render_norm) 229 | ## output 230 | opdict['grid'] = ops['grid'] 231 | opdict['rendered_images'] = ops['images'] 232 | opdict['alpha_images'] = ops['alpha_images'] 233 | opdict['normal_images'] = ops['normal_images'] 234 | opdict['albedo_images'] = ops['albedo_images'] 235 | 236 | if self.cfg.model.use_tex: 237 | opdict['albedo'] = albedo 238 | 239 | return opdict, _ 240 | 241 | def visualize(self, visdict, size=224, dim=2): 242 | ''' 243 | image range should be [0,1] 244 | dim: 2 for horizontal.
1 for vertical 245 | ''' 246 | assert dim == 1 or dim==2 247 | grids = {} 248 | for key in visdict: 249 | _,_,h,w = visdict[key].shape 250 | if dim == 2: 251 | new_h = size; new_w = int(w*size/h) 252 | elif dim == 1: 253 | new_h = int(h*size/w); new_w = size 254 | grids[key] = torchvision.utils.make_grid(F.interpolate(visdict[key], [new_h, new_w]).detach().cpu()) 255 | grid = torch.cat(list(grids.values()), dim) 256 | grid_image = (grid.numpy().transpose(1,2,0).copy()*255)[:,:,[2,1,0]] 257 | grid_image = np.minimum(np.maximum(grid_image, 0), 255).astype(np.uint8) 258 | return grid_image 259 | 260 | def save_obj(self, filename, opdict): 261 | ''' 262 | vertices: [nv, 3], tensor 263 | texture: [3, h, w], tensor 264 | ''' 265 | i = 0 266 | vertices = opdict['verts'][i].cpu().numpy() 267 | faces = self.render.faces[0].cpu().numpy() 268 | texture = util.tensor2image(opdict['uv_texture_gt'][i]) 269 | uvcoords = self.render.raw_uvcoords[0].cpu().numpy() 270 | uvfaces = self.render.uvfaces[0].cpu().numpy() 271 | # save coarse mesh, with texture and normal map 272 | normal_map = util.tensor2image(opdict['uv_detail_normals'][i]*0.5 + 0.5) 273 | util.write_obj(filename, vertices, faces, 274 | texture=texture, 275 | uvcoords=uvcoords, 276 | uvfaces=uvfaces, 277 | normal_map=normal_map) 278 | # upsample mesh, save detailed mesh 279 | texture = texture[:,:,[2,1,0]] 280 | normals = opdict['normals'][i].cpu().numpy() 281 | displacement_map = opdict['displacement_map'][i].cpu().numpy().squeeze() 282 | dense_vertices, dense_colors, dense_faces = util.upsample_mesh(vertices, normals, faces, displacement_map, texture, self.dense_template) 283 | util.write_obj(filename.replace('.obj', '_detail.obj'), 284 | dense_vertices, 285 | dense_faces, 286 | colors = dense_colors, 287 | inverse_face_order=True) 288 | 289 | def run(self, imagepath, iscrop=True): 290 | ''' An api for running deca given an image path 291 | ''' 292 | testdata = datasets.TestData(imagepath) 293 | images = testdata[0]['image'].to(self.device)[None,...] 294 | codedict = self.encode(images) 295 | opdict, visdict = self.decode(codedict) 296 | return codedict, opdict, visdict 297 | 298 | def model_dict(self): 299 | return { 300 | 'E_flame': self.E_flame.state_dict(), 301 | 'E_detail': self.E_detail.state_dict(), 302 | 'D_detail': self.D_detail.state_dict() 303 | } 304 | -------------------------------------------------------------------------------- /decalib/models/FLAME.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 
12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import torch 17 | import torch.nn as nn 18 | import numpy as np 19 | import pickle 20 | import torch.nn.functional as F 21 | 22 | from .lbs import lbs, batch_rodrigues, vertices2landmarks, rot_mat_to_euler 23 | 24 | def to_tensor(array, dtype=torch.float32): 25 | if 'torch.tensor' not in str(type(array)): 26 | return torch.tensor(array, dtype=dtype) 27 | def to_np(array, dtype=np.float32): 28 | if 'scipy.sparse' in str(type(array)): 29 | array = array.todense() 30 | return np.array(array, dtype=dtype) 31 | 32 | class Struct(object): 33 | def __init__(self, **kwargs): 34 | for key, val in kwargs.items(): 35 | setattr(self, key, val) 36 | 37 | class FLAME(nn.Module): 38 | """ 39 | borrowed from https://github.com/soubhiksanyal/FLAME_PyTorch/blob/master/FLAME.py 40 | Given flame parameters this class generates a differentiable FLAME function 41 | which outputs the a mesh and 2D/3D facial landmarks 42 | """ 43 | def __init__(self, config): 44 | super(FLAME, self).__init__() 45 | print("creating the FLAME Decoder") 46 | with open(config.flame_model_path, 'rb') as f: 47 | ss = pickle.load(f, encoding='latin1') 48 | flame_model = Struct(**ss) 49 | 50 | self.dtype = torch.float32 51 | self.register_buffer('faces_tensor', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long)) 52 | # The vertices of the template model 53 | self.register_buffer('v_template', to_tensor(to_np(flame_model.v_template), dtype=self.dtype)) 54 | # The shape components and expression 55 | shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype) 56 | shapedirs = torch.cat([shapedirs[:,:,:config.n_shape], shapedirs[:,:,300:300+config.n_exp]], 2) 57 | self.register_buffer('shapedirs', shapedirs) 58 | # The pose components 59 | num_pose_basis = flame_model.posedirs.shape[-1] 60 | posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T 61 | self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype)) 62 | # 63 | self.register_buffer('J_regressor', to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype)) 64 | parents = to_tensor(to_np(flame_model.kintree_table[0])).long(); parents[0] = -1 65 | self.register_buffer('parents', parents) 66 | self.register_buffer('lbs_weights', to_tensor(to_np(flame_model.weights), dtype=self.dtype)) 67 | 68 | # Fixing Eyeball and neck rotation 69 | default_eyball_pose = torch.zeros([1, 6], dtype=self.dtype, requires_grad=False) 70 | self.register_parameter('eye_pose', nn.Parameter(default_eyball_pose, 71 | requires_grad=False)) 72 | default_neck_pose = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False) 73 | self.register_parameter('neck_pose', nn.Parameter(default_neck_pose, 74 | requires_grad=False)) 75 | 76 | # Static and Dynamic Landmark embeddings for FLAME 77 | lmk_embeddings = np.load(config.flame_lmk_embedding_path, allow_pickle=True, encoding='latin1') 78 | lmk_embeddings = lmk_embeddings[()] 79 | self.register_buffer('lmk_faces_idx', torch.from_numpy(lmk_embeddings['static_lmk_faces_idx']).long()) 80 | self.register_buffer('lmk_bary_coords', torch.from_numpy(lmk_embeddings['static_lmk_bary_coords']).to(self.dtype)) 81 | self.register_buffer('dynamic_lmk_faces_idx', lmk_embeddings['dynamic_lmk_faces_idx'].long()) 82 | self.register_buffer('dynamic_lmk_bary_coords', lmk_embeddings['dynamic_lmk_bary_coords'].to(self.dtype)) 83 | 
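        # The static embeddings above index triangles that stay fixed, while the
        # dynamic embeddings form a look-up table over head yaw (used in
        # _find_dynamic_lmk_idx_and_bcoords below) so that the face-contour landmarks
        # slide over the mesh as the head turns; the full embeddings registered next
        # are what seletec_3d68() uses to recover all 68 3D landmarks.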
self.register_buffer('full_lmk_faces_idx', torch.from_numpy(lmk_embeddings['full_lmk_faces_idx']).long()) 84 | self.register_buffer('full_lmk_bary_coords', torch.from_numpy(lmk_embeddings['full_lmk_bary_coords']).to(self.dtype)) 85 | 86 | neck_kin_chain = []; NECK_IDX=1 87 | curr_idx = torch.tensor(NECK_IDX, dtype=torch.long) 88 | while curr_idx != -1: 89 | neck_kin_chain.append(curr_idx) 90 | curr_idx = self.parents[curr_idx] 91 | self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain)) 92 | 93 | def _find_dynamic_lmk_idx_and_bcoords(self, pose, dynamic_lmk_faces_idx, 94 | dynamic_lmk_b_coords, 95 | neck_kin_chain, dtype=torch.float32): 96 | """ 97 | Selects the face contour depending on the reletive position of the head 98 | Input: 99 | vertices: N X num_of_vertices X 3 100 | pose: N X full pose 101 | dynamic_lmk_faces_idx: The list of contour face indexes 102 | dynamic_lmk_b_coords: The list of contour barycentric weights 103 | neck_kin_chain: The tree to consider for the relative rotation 104 | dtype: Data type 105 | return: 106 | The contour face indexes and the corresponding barycentric weights 107 | """ 108 | 109 | batch_size = pose.shape[0] 110 | 111 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, 112 | neck_kin_chain) 113 | rot_mats = batch_rodrigues( 114 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) 115 | 116 | rel_rot_mat = torch.eye(3, device=pose.device, 117 | dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1) 118 | for idx in range(len(neck_kin_chain)): 119 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) 120 | 121 | y_rot_angle = torch.round( 122 | torch.clamp(rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, 123 | max=39)).to(dtype=torch.long) 124 | 125 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) 126 | mask = y_rot_angle.lt(-39).to(dtype=torch.long) 127 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) 128 | y_rot_angle = (neg_mask * neg_vals + 129 | (1 - neg_mask) * y_rot_angle) 130 | 131 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 132 | 0, y_rot_angle) 133 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 134 | 0, y_rot_angle) 135 | return dyn_lmk_faces_idx, dyn_lmk_b_coords 136 | 137 | def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords): 138 | """ 139 | Calculates landmarks by barycentric interpolation 140 | Input: 141 | vertices: torch.tensor NxVx3, dtype = torch.float32 142 | The tensor of input vertices 143 | faces: torch.tensor (N*F)x3, dtype = torch.long 144 | The faces of the mesh 145 | lmk_faces_idx: torch.tensor N X L, dtype = torch.long 146 | The tensor with the indices of the faces used to calculate the 147 | landmarks. 
148 | lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32 149 | The tensor of barycentric coordinates that are used to interpolate 150 | the landmarks 151 | 152 | Returns: 153 | landmarks: torch.tensor NxLx3, dtype = torch.float32 154 | The coordinates of the landmarks for each mesh in the batch 155 | """ 156 | # Extract the indices of the vertices for each face 157 | # NxLx3 158 | batch_size, num_verts = vertices.shape[:2] 159 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( 160 | 1, -1, 3).view(batch_size, lmk_faces_idx.shape[1], -1) 161 | 162 | lmk_faces += torch.arange(batch_size, dtype=torch.long).view(-1, 1, 1).to( 163 | device=vertices.device) * num_verts 164 | 165 | lmk_vertices = vertices.view(-1, 3)[lmk_faces] 166 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) 167 | return landmarks 168 | 169 | def seletec_3d68(self, vertices): 170 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor, 171 | self.full_lmk_faces_idx.repeat(vertices.shape[0], 1), 172 | self.full_lmk_bary_coords.repeat(vertices.shape[0], 1, 1)) 173 | return landmarks3d 174 | 175 | def forward(self, shape_params=None, expression_params=None, pose_params=None, eye_pose_params=None): 176 | """ 177 | Input: 178 | shape_params: N X number of shape parameters 179 | expression_params: N X number of expression parameters 180 | pose_params: N X number of pose parameters (6) 181 | return: 182 | vertices: N X V X 3 183 | landmarks: N X number of landmarks X 3 184 | """ 185 | batch_size = shape_params.shape[0] 186 | if pose_params is None: 187 | pose_params = self.eye_pose.expand(batch_size, -1) 188 | if eye_pose_params is None: 189 | eye_pose_params = self.eye_pose.expand(batch_size, -1) 190 | betas = torch.cat([shape_params, expression_params], dim=1) 191 | full_pose = torch.cat([pose_params[:, :3], self.neck_pose.expand(batch_size, -1), pose_params[:, 3:], eye_pose_params], dim=1) 192 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1) 193 | 194 | vertices, _ = lbs(betas, full_pose, template_vertices, 195 | self.shapedirs, self.posedirs, 196 | self.J_regressor, self.parents, 197 | self.lbs_weights, dtype=self.dtype) 198 | 199 | lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1) 200 | lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1) 201 | 202 | dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords( 203 | full_pose, self.dynamic_lmk_faces_idx, 204 | self.dynamic_lmk_bary_coords, 205 | self.neck_kin_chain, dtype=self.dtype) 206 | lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1) 207 | lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1) 208 | 209 | landmarks2d = vertices2landmarks(vertices, self.faces_tensor, 210 | lmk_faces_idx, 211 | lmk_bary_coords) 212 | bz = vertices.shape[0] 213 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor, 214 | self.full_lmk_faces_idx.repeat(bz, 1), 215 | self.full_lmk_bary_coords.repeat(bz, 1, 1)) 216 | return vertices, landmarks2d, landmarks3d 217 | 218 | class FLAMETex(nn.Module): 219 | """ 220 | FLAME texture: 221 | https://github.com/TimoBolkart/TF_FLAME/blob/ade0ab152300ec5f0e8555d6765411555c5ed43d/sample_texture.py#L64 222 | FLAME texture converted from BFM: 223 | https://github.com/TimoBolkart/BFM_to_FLAME 224 | """ 225 | def __init__(self, config): 226 | super(FLAMETex, self).__init__() 227 | if config.tex_type == 'BFM': 228 | mu_key = 'MU' 229 | pc_key = 'PC'
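            # 'MU' and 'PC' hold the mean and the principal components of the BFM
            # albedo space; forward() below rebuilds a UV texture as
            # texture_mean + (texture_basis * texcode).sum(-1).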
230 | n_pc = 199 231 | tex_path = config.tex_path 232 | tex_space = np.load(tex_path) 233 | texture_mean = tex_space[mu_key].reshape(1, -1) 234 | texture_basis = tex_space[pc_key].reshape(-1, n_pc) 235 | 236 | elif config.tex_type == 'FLAME': 237 | mu_key = 'mean' 238 | pc_key = 'tex_dir' 239 | n_pc = 200 240 | tex_path = config.tex_path # config.flame_tex_path 241 | tex_space = np.load(tex_path) 242 | texture_mean = tex_space[mu_key].reshape(1, -1)/255. 243 | texture_basis = tex_space[pc_key].reshape(-1, n_pc)/255. 244 | else: 245 | print('texture type ', config.tex_type, 'not exist!') 246 | raise NotImplementedError 247 | 248 | n_tex = config.n_tex 249 | num_components = texture_basis.shape[1] 250 | texture_mean = torch.from_numpy(texture_mean).float()[None,...] 251 | texture_basis = torch.from_numpy(texture_basis[:,:n_tex]).float()[None,...] 252 | self.register_buffer('texture_mean', texture_mean) 253 | self.register_buffer('texture_basis', texture_basis) 254 | 255 | def forward(self, texcode): 256 | ''' 257 | texcode: [batchsize, n_tex] 258 | texture: [bz, 3, 256, 256], range: 0-1 259 | ''' 260 | texture = self.texture_mean + (self.texture_basis*texcode[:,None,:]).sum(-1) 261 | texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0,3,1,2) 262 | texture = F.interpolate(texture, [256, 256]) 263 | texture = texture[:,[2,1,0], :,:] 264 | return texture 265 | -------------------------------------------------------------------------------- /decalib/models/decoders.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 
12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | class Generator(nn.Module): 20 | def __init__(self, latent_dim=100, out_channels=1, out_scale=0.01, sample_mode = 'bilinear'): 21 | super(Generator, self).__init__() 22 | self.out_scale = out_scale 23 | 24 | self.init_size = 32 // 4 # Initial size before upsampling 25 | self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2)) 26 | self.conv_blocks = nn.Sequential( 27 | nn.BatchNorm2d(128), 28 | nn.Upsample(scale_factor=2, mode=sample_mode), #16 29 | nn.Conv2d(128, 128, 3, stride=1, padding=1), 30 | nn.BatchNorm2d(128, 0.8), 31 | nn.LeakyReLU(0.2, inplace=True), 32 | nn.Upsample(scale_factor=2, mode=sample_mode), #32 33 | nn.Conv2d(128, 64, 3, stride=1, padding=1), 34 | nn.BatchNorm2d(64, 0.8), 35 | nn.LeakyReLU(0.2, inplace=True), 36 | nn.Upsample(scale_factor=2, mode=sample_mode), #64 37 | nn.Conv2d(64, 64, 3, stride=1, padding=1), 38 | nn.BatchNorm2d(64, 0.8), 39 | nn.LeakyReLU(0.2, inplace=True), 40 | nn.Upsample(scale_factor=2, mode=sample_mode), #128 41 | nn.Conv2d(64, 32, 3, stride=1, padding=1), 42 | nn.BatchNorm2d(32, 0.8), 43 | nn.LeakyReLU(0.2, inplace=True), 44 | nn.Upsample(scale_factor=2, mode=sample_mode), #256 45 | nn.Conv2d(32, 16, 3, stride=1, padding=1), 46 | nn.BatchNorm2d(16, 0.8), 47 | nn.LeakyReLU(0.2, inplace=True), 48 | nn.Conv2d(16, out_channels, 3, stride=1, padding=1), 49 | nn.Tanh(), 50 | ) 51 | 52 | def forward(self, noise): 53 | out = self.l1(noise) 54 | out = out.view(out.shape[0], 128, self.init_size, self.init_size) 55 | img = self.conv_blocks(out) 56 | return img*self.out_scale -------------------------------------------------------------------------------- /decalib/models/encoders.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import numpy as np 17 | import torch.nn as nn 18 | import torch 19 | import torch.nn.functional as F 20 | from . 
import resnet 21 | 22 | class ResnetEncoder(nn.Module): 23 | def __init__(self, outsize, last_op=None): 24 | super(ResnetEncoder, self).__init__() 25 | feature_size = 2048 26 | self.encoder = resnet.load_ResNet50Model() #out: 2048 27 | ### regressor 28 | self.layers = nn.Sequential( 29 | nn.Linear(feature_size, 1024), 30 | nn.ReLU(), 31 | nn.Linear(1024, outsize) 32 | ) 33 | self.last_op = last_op 34 | 35 | def forward(self, inputs): 36 | features = self.encoder(inputs) 37 | parameters = self.layers(features) 38 | if self.last_op: 39 | parameters = self.last_op(parameters) 40 | return parameters 41 | -------------------------------------------------------------------------------- /decalib/models/frnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch 4 | # from pro_gan_pytorch.PRO_GAN import ProGAN, Generator, Discriminator 5 | import torch.nn.functional as F 6 | import cv2 7 | from torch.autograd import Variable 8 | import math 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None): 19 | super(BasicBlock, self).__init__() 20 | self.conv1 = conv3x3(inplanes, planes, stride) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.relu = nn.ReLU(inplace=True) 23 | self.conv2 = conv3x3(planes, planes) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | self.downsample = downsample 26 | self.stride = stride 27 | 28 | def forward(self, x): 29 | residual = x 30 | 31 | out = self.conv1(x) 32 | out = self.bn1(out) 33 | out = self.relu(out) 34 | 35 | out = self.conv2(out) 36 | out = self.bn2(out) 37 | 38 | if self.downsample is not None: 39 | residual = self.downsample(x) 40 | 41 | out += residual 42 | out = self.relu(out) 43 | 44 | return out 45 | 46 | 47 | class Bottleneck(nn.Module): 48 | expansion = 4 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super(Bottleneck, self).__init__() 52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 57 | self.bn3 = nn.BatchNorm2d(planes * 4) 58 | self.relu = nn.ReLU(inplace=True) 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | residual = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv3(out) 74 | out = self.bn3(out) 75 | 76 | if self.downsample is not None: 77 | residual = self.downsample(x) 78 | 79 | out += residual 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | 85 | class ResNet(nn.Module): 86 | 87 | def __init__(self, block, layers, num_classes=1000, include_top=True): 88 | self.inplanes = 64 89 | super(ResNet, self).__init__() 90 | self.include_top = include_top 91 | 92 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 93 | self.bn1 = nn.BatchNorm2d(64) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, 
ceil_mode=True) 96 | 97 | self.layer1 = self._make_layer(block, 64, layers[0]) 98 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 99 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 100 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 101 | self.avgpool = nn.AvgPool2d(7, stride=1) 102 | self.fc = nn.Linear(512 * block.expansion, num_classes) 103 | 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 107 | m.weight.data.normal_(0, math.sqrt(2. / n)) 108 | elif isinstance(m, nn.BatchNorm2d): 109 | m.weight.data.fill_(1) 110 | m.bias.data.zero_() 111 | 112 | def _make_layer(self, block, planes, blocks, stride=1): 113 | downsample = None 114 | if stride != 1 or self.inplanes != planes * block.expansion: 115 | downsample = nn.Sequential( 116 | nn.Conv2d(self.inplanes, planes * block.expansion, 117 | kernel_size=1, stride=stride, bias=False), 118 | nn.BatchNorm2d(planes * block.expansion), 119 | ) 120 | 121 | layers = [] 122 | layers.append(block(self.inplanes, planes, stride, downsample)) 123 | self.inplanes = planes * block.expansion 124 | for i in range(1, blocks): 125 | layers.append(block(self.inplanes, planes)) 126 | 127 | return nn.Sequential(*layers) 128 | 129 | def forward(self, x): 130 | x = self.conv1(x) 131 | x = self.bn1(x) 132 | x = self.relu(x) 133 | x = self.maxpool(x) 134 | 135 | x = self.layer1(x) 136 | x = self.layer2(x) 137 | x = self.layer3(x) 138 | x = self.layer4(x) 139 | 140 | x = self.avgpool(x) 141 | 142 | if not self.include_top: 143 | return x 144 | 145 | x = x.view(x.size(0), -1) 146 | x = self.fc(x) 147 | return x 148 | 149 | def resnet50(**kwargs): 150 | """Constructs a ResNet-50 model. 151 | """ 152 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 153 | return model 154 | 155 | import pickle 156 | def load_state_dict(model, fname): 157 | """ 158 | Set parameters converted from Caffe models authors of VGGFace2 provide. 159 | See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/. 160 | Arguments: 161 | model: model 162 | fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle. 163 | """ 164 | with open(fname, 'rb') as f: 165 | weights = pickle.load(f, encoding='latin1') 166 | 167 | own_state = model.state_dict() 168 | for name, param in weights.items(): 169 | if name in own_state: 170 | try: 171 | own_state[name].copy_(torch.from_numpy(param)) 172 | except Exception: 173 | raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\ 174 | 'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size())) 175 | else: 176 | raise KeyError('unexpected key "{}" in state_dict'.format(name)) 177 | 178 | -------------------------------------------------------------------------------- /decalib/models/lbs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 
10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from __future__ import absolute_import 18 | from __future__ import print_function 19 | from __future__ import division 20 | 21 | import numpy as np 22 | 23 | import torch 24 | import torch.nn.functional as F 25 | 26 | def rot_mat_to_euler(rot_mats): 27 | # Calculates rotation matrix to euler angles 28 | # Careful for extreme cases of eular angles like [0.0, pi, 0.0] 29 | 30 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + 31 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) 32 | return torch.atan2(-rot_mats[:, 2, 0], sy) 33 | 34 | def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx, 35 | dynamic_lmk_b_coords, 36 | neck_kin_chain, dtype=torch.float32): 37 | ''' Compute the faces, barycentric coordinates for the dynamic landmarks 38 | 39 | 40 | To do so, we first compute the rotation of the neck around the y-axis 41 | and then use a pre-computed look-up table to find the faces and the 42 | barycentric coordinates that will be used. 43 | 44 | Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de) 45 | for providing the original TensorFlow implementation and for the LUT. 46 | 47 | Parameters 48 | ---------- 49 | vertices: torch.tensor BxVx3, dtype = torch.float32 50 | The tensor of input vertices 51 | pose: torch.tensor Bx(Jx3), dtype = torch.float32 52 | The current pose of the body model 53 | dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long 54 | The look-up table from neck rotation to faces 55 | dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 56 | The look-up table from neck rotation to barycentric coordinates 57 | neck_kin_chain: list 58 | A python list that contains the indices of the joints that form the 59 | kinematic chain of the neck. 60 | dtype: torch.dtype, optional 61 | 62 | Returns 63 | ------- 64 | dyn_lmk_faces_idx: torch.tensor, dtype = torch.long 65 | A tensor of size BxL that contains the indices of the faces that 66 | will be used to compute the current dynamic landmarks. 67 | dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 68 | A tensor of size BxL that contains the indices of the faces that 69 | will be used to compute the current dynamic landmarks. 
70 | ''' 71 | 72 | batch_size = vertices.shape[0] 73 | 74 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, 75 | neck_kin_chain) 76 | rot_mats = batch_rodrigues( 77 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) 78 | 79 | rel_rot_mat = torch.eye(3, device=vertices.device, 80 | dtype=dtype).unsqueeze_(dim=0) 81 | for idx in range(len(neck_kin_chain)): 82 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) 83 | 84 | y_rot_angle = torch.round( 85 | torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, 86 | max=39)).to(dtype=torch.long) 87 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) 88 | mask = y_rot_angle.lt(-39).to(dtype=torch.long) 89 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) 90 | y_rot_angle = (neg_mask * neg_vals + 91 | (1 - neg_mask) * y_rot_angle) 92 | 93 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 94 | 0, y_rot_angle) 95 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 96 | 0, y_rot_angle) 97 | 98 | return dyn_lmk_faces_idx, dyn_lmk_b_coords 99 | 100 | 101 | def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords): 102 | ''' Calculates landmarks by barycentric interpolation 103 | 104 | Parameters 105 | ---------- 106 | vertices: torch.tensor BxVx3, dtype = torch.float32 107 | The tensor of input vertices 108 | faces: torch.tensor Fx3, dtype = torch.long 109 | The faces of the mesh 110 | lmk_faces_idx: torch.tensor L, dtype = torch.long 111 | The tensor with the indices of the faces used to calculate the 112 | landmarks. 113 | lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 114 | The tensor of barycentric coordinates that are used to interpolate 115 | the landmarks 116 | 117 | Returns 118 | ------- 119 | landmarks: torch.tensor BxLx3, dtype = torch.float32 120 | The coordinates of the landmarks for each mesh in the batch 121 | ''' 122 | # Extract the indices of the vertices for each face 123 | # BxLx3 124 | batch_size, num_verts = vertices.shape[:2] 125 | device = vertices.device 126 | 127 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( 128 | batch_size, -1, 3) 129 | 130 | lmk_faces += torch.arange( 131 | batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts 132 | 133 | lmk_vertices = vertices.view(-1, 3)[lmk_faces].view( 134 | batch_size, -1, 3, 3) 135 | 136 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) 137 | return landmarks 138 | 139 | 140 | def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents, 141 | lbs_weights, pose2rot=True, dtype=torch.float32): 142 | ''' Performs Linear Blend Skinning with the given shape and pose parameters 143 | 144 | Parameters 145 | ---------- 146 | betas : torch.tensor BxNB 147 | The tensor of shape parameters 148 | pose : torch.tensor Bx(J + 1) * 3 149 | The pose parameters in axis-angle format 150 | v_template torch.tensor BxVx3 151 | The template mesh that will be deformed 152 | shapedirs : torch.tensor 1xNB 153 | The tensor of PCA shape displacements 154 | posedirs : torch.tensor Px(V * 3) 155 | The pose PCA coefficients 156 | J_regressor : torch.tensor JxV 157 | The regressor array that is used to calculate the joints from 158 | the position of the vertices 159 | parents: torch.tensor J 160 | The array that describes the kinematic tree for the model 161 | lbs_weights: torch.tensor N x V x (J + 1) 162 | The linear blend skinning weights that represent how much the 163 | rotation matrix of each part affects each vertex 164 | pose2rot: bool, 
optional 165 | Flag on whether to convert the input pose tensor to rotation 166 | matrices. The default value is True. If False, then the pose tensor 167 | should already contain rotation matrices and have a size of 168 | Bx(J + 1)x9 169 | dtype: torch.dtype, optional 170 | 171 | Returns 172 | ------- 173 | verts: torch.tensor BxVx3 174 | The vertices of the mesh after applying the shape and pose 175 | displacements. 176 | joints: torch.tensor BxJx3 177 | The joints of the model 178 | ''' 179 | 180 | batch_size = max(betas.shape[0], pose.shape[0]) 181 | device = betas.device 182 | 183 | # Add shape contribution 184 | v_shaped = v_template + blend_shapes(betas, shapedirs) 185 | 186 | # Get the joints 187 | # NxJx3 array 188 | J = vertices2joints(J_regressor, v_shaped) 189 | 190 | # 3. Add pose blend shapes 191 | # N x J x 3 x 3 192 | ident = torch.eye(3, dtype=dtype, device=device) 193 | if pose2rot: 194 | rot_mats = batch_rodrigues( 195 | pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3]) 196 | 197 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) 198 | # (N x P) x (P, V * 3) -> N x V x 3 199 | pose_offsets = torch.matmul(pose_feature, posedirs) \ 200 | .view(batch_size, -1, 3) 201 | else: 202 | pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident 203 | rot_mats = pose.view(batch_size, -1, 3, 3) 204 | 205 | pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), 206 | posedirs).view(batch_size, -1, 3) 207 | 208 | v_posed = pose_offsets + v_shaped 209 | # 4. Get the global joint location 210 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) 211 | 212 | # 5. Do skinning: 213 | # W is N x V x (J + 1) 214 | W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) 215 | # (N x V x (J + 1)) x (N x (J + 1) x 16) 216 | num_joints = J_regressor.shape[0] 217 | T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ 218 | .view(batch_size, -1, 4, 4) 219 | 220 | homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], 221 | dtype=dtype, device=device) 222 | v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) 223 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) 224 | 225 | verts = v_homo[:, :, :3, 0] 226 | 227 | return verts, J_transformed 228 | 229 | 230 | def vertices2joints(J_regressor, vertices): 231 | ''' Calculates the 3D joint locations from the vertices 232 | 233 | Parameters 234 | ---------- 235 | J_regressor : torch.tensor JxV 236 | The regressor array that is used to calculate the joints from the 237 | position of the vertices 238 | vertices : torch.tensor BxVx3 239 | The tensor of mesh vertices 240 | 241 | Returns 242 | ------- 243 | torch.tensor BxJx3 244 | The location of the joints 245 | ''' 246 | 247 | return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) 248 | 249 | 250 | def blend_shapes(betas, shape_disps): 251 | ''' Calculates the per vertex displacement due to the blend shapes 252 | 253 | 254 | Parameters 255 | ---------- 256 | betas : torch.tensor Bx(num_betas) 257 | Blend shape coefficients 258 | shape_disps: torch.tensor Vx3x(num_betas) 259 | Blend shapes 260 | 261 | Returns 262 | ------- 263 | torch.tensor BxVx3 264 | The per-vertex displacement due to shape deformation 265 | ''' 266 | 267 | # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] 268 | # i.e. Multiply each shape displacement by its corresponding beta and 269 | # then sum them. 
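    # Equivalent reference formulation (not used here; the einsum below is clearer):
    #   blend_shape[b] = (shape_disps.reshape(V * 3, num_betas) @ betas[b]).reshape(V, 3)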
270 | blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) 271 | return blend_shape 272 | 273 | 274 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): 275 | ''' Calculates the rotation matrices for a batch of rotation vectors 276 | Parameters 277 | ---------- 278 | rot_vecs: torch.tensor Nx3 279 | array of N axis-angle vectors 280 | Returns 281 | ------- 282 | R: torch.tensor Nx3x3 283 | The rotation matrices for the given axis-angle parameters 284 | ''' 285 | 286 | batch_size = rot_vecs.shape[0] 287 | device = rot_vecs.device 288 | 289 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) 290 | rot_dir = rot_vecs / angle 291 | 292 | cos = torch.unsqueeze(torch.cos(angle), dim=1) 293 | sin = torch.unsqueeze(torch.sin(angle), dim=1) 294 | 295 | # Bx1 arrays 296 | rx, ry, rz = torch.split(rot_dir, 1, dim=1) 297 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) 298 | 299 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) 300 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ 301 | .view((batch_size, 3, 3)) 302 | 303 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) 304 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) 305 | return rot_mat 306 | 307 | 308 | def transform_mat(R, t): 309 | ''' Creates a batch of transformation matrices 310 | Args: 311 | - R: Bx3x3 array of a batch of rotation matrices 312 | - t: Bx3x1 array of a batch of translation vectors 313 | Returns: 314 | - T: Bx4x4 Transformation matrix 315 | ''' 316 | # No padding left or right, only add an extra row 317 | return torch.cat([F.pad(R, [0, 0, 0, 1]), 318 | F.pad(t, [0, 0, 0, 1], value=1)], dim=2) 319 | 320 | 321 | def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): 322 | """ 323 | Applies a batch of rigid transformations to the joints 324 | 325 | Parameters 326 | ---------- 327 | rot_mats : torch.tensor BxNx3x3 328 | Tensor of rotation matrices 329 | joints : torch.tensor BxNx3 330 | Locations of joints 331 | parents : torch.tensor BxN 332 | The kinematic tree of each object 333 | dtype : torch.dtype, optional: 334 | The data type of the created tensors, the default is torch.float32 335 | 336 | Returns 337 | ------- 338 | posed_joints : torch.tensor BxNx3 339 | The locations of the joints after applying the pose rotations 340 | rel_transforms : torch.tensor BxNx4x4 341 | The relative (with respect to the root joint) rigid transformations 342 | for all the joints 343 | """ 344 | 345 | joints = torch.unsqueeze(joints, dim=-1) 346 | 347 | rel_joints = joints.clone() 348 | rel_joints[:, 1:] -= joints[:, parents[1:]] 349 | 350 | # transforms_mat = transform_mat( 351 | # rot_mats.view(-1, 3, 3), 352 | # rel_joints.view(-1, 3, 1)).view(-1, joints.shape[1], 4, 4) 353 | transforms_mat = transform_mat( 354 | rot_mats.view(-1, 3, 3), 355 | rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) 356 | 357 | transform_chain = [transforms_mat[:, 0]] 358 | for i in range(1, parents.shape[0]): 359 | # Subtract the joint location at the rest pose 360 | # No need for rotation, since it's identity when at rest 361 | curr_res = torch.matmul(transform_chain[parents[i]], 362 | transforms_mat[:, i]) 363 | transform_chain.append(curr_res) 364 | 365 | transforms = torch.stack(transform_chain, dim=1) 366 | 367 | # The last column of the transformations contains the posed joints 368 | posed_joints = transforms[:, :, :3, 3] 369 | 370 | # The last column of the transformations contains the posed joints 371 | 
posed_joints = transforms[:, :, :3, 3] 372 | 373 | joints_homogen = F.pad(joints, [0, 0, 0, 1]) 374 | 375 | rel_transforms = transforms - F.pad( 376 | torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) 377 | 378 | return posed_joints, rel_transforms -------------------------------------------------------------------------------- /decalib/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Soubhik Sanyal 3 | Copyright (c) 2019, Soubhik Sanyal 4 | All rights reserved. 5 | Loads different resnet models 6 | """ 7 | ''' 8 | file: Resnet.py 9 | date: 2018_05_02 10 | author: zhangxiong(1025679612@qq.com) 11 | mark: copied from pytorch source code 12 | ''' 13 | 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch 17 | from torch.nn.parameter import Parameter 18 | import torch.optim as optim 19 | import numpy as np 20 | import math 21 | import torchvision 22 | 23 | class ResNet(nn.Module): 24 | def __init__(self, block, layers, num_classes=1000): 25 | self.inplanes = 64 26 | super(ResNet, self).__init__() 27 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 28 | bias=False) 29 | self.bn1 = nn.BatchNorm2d(64) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 32 | self.layer1 = self._make_layer(block, 64, layers[0]) 33 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 34 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 35 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 36 | self.avgpool = nn.AvgPool2d(7, stride=1) 37 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 38 | 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 42 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 43 | elif isinstance(m, nn.BatchNorm2d): 44 | m.weight.data.fill_(1) 45 | m.bias.data.zero_() 46 | 47 | def _make_layer(self, block, planes, blocks, stride=1): 48 | downsample = None 49 | if stride != 1 or self.inplanes != planes * block.expansion: 50 | downsample = nn.Sequential( 51 | nn.Conv2d(self.inplanes, planes * block.expansion, 52 | kernel_size=1, stride=stride, bias=False), 53 | nn.BatchNorm2d(planes * block.expansion), 54 | ) 55 | 56 | layers = [] 57 | layers.append(block(self.inplanes, planes, stride, downsample)) 58 | self.inplanes = planes * block.expansion 59 | for i in range(1, blocks): 60 | layers.append(block(self.inplanes, planes)) 61 | 62 | return nn.Sequential(*layers) 63 | 64 | def forward(self, x): 65 | x = self.conv1(x) 66 | x = self.bn1(x) 67 | x = self.relu(x) 68 | x = self.maxpool(x) 69 | 70 | x = self.layer1(x) 71 | x = self.layer2(x) 72 | x = self.layer3(x) 73 | x1 = self.layer4(x) 74 | 75 | x2 = self.avgpool(x1) 76 | x2 = x2.view(x2.size(0), -1) 77 | # x = self.fc(x) 78 | ## x2: [bz, 2048] for shape 79 | ## x1: [bz, 2048, 7, 7] for texture 80 | return x2 81 | 82 | class Bottleneck(nn.Module): 83 | expansion = 4 84 | 85 | def __init__(self, inplanes, planes, stride=1, downsample=None): 86 | super(Bottleneck, self).__init__() 87 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 88 | self.bn1 = nn.BatchNorm2d(planes) 89 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 90 | padding=1, bias=False) 91 | self.bn2 = nn.BatchNorm2d(planes) 92 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 93 | self.bn3 = nn.BatchNorm2d(planes * 4) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.downsample = downsample 96 | self.stride = stride 97 | 98 | def forward(self, x): 99 | residual = x 100 | 101 | out = self.conv1(x) 102 | out = self.bn1(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv2(out) 106 | out = self.bn2(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv3(out) 110 | out = self.bn3(out) 111 | 112 | if self.downsample is not None: 113 | residual = self.downsample(x) 114 | 115 | out += residual 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | def conv3x3(in_planes, out_planes, stride=1): 121 | """3x3 convolution with padding""" 122 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 123 | padding=1, bias=False) 124 | 125 | class BasicBlock(nn.Module): 126 | expansion = 1 127 | 128 | def __init__(self, inplanes, planes, stride=1, downsample=None): 129 | super(BasicBlock, self).__init__() 130 | self.conv1 = conv3x3(inplanes, planes, stride) 131 | self.bn1 = nn.BatchNorm2d(planes) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.conv2 = conv3x3(planes, planes) 134 | self.bn2 = nn.BatchNorm2d(planes) 135 | self.downsample = downsample 136 | self.stride = stride 137 | 138 | def forward(self, x): 139 | residual = x 140 | 141 | out = self.conv1(x) 142 | out = self.bn1(out) 143 | out = self.relu(out) 144 | 145 | out = self.conv2(out) 146 | out = self.bn2(out) 147 | 148 | if self.downsample is not None: 149 | residual = self.downsample(x) 150 | 151 | out += residual 152 | out = self.relu(out) 153 | 154 | return out 155 | 156 | def copy_parameter_from_resnet(model, resnet_dict): 157 | cur_state_dict = model.state_dict() 158 | # import ipdb; ipdb.set_trace() 159 | for name, param in list(resnet_dict.items())[0:None]: 160 | if name not in cur_state_dict: 161 | # print(name, ' not available in reconstructed resnet') 162 | continue 163 | if isinstance(param, 
Parameter): 164 | param = param.data 165 | try: 166 | cur_state_dict[name].copy_(param) 167 | except: 168 | # print(name, ' is inconsistent!') 169 | continue 170 | # print('copy resnet state dict finished!') 171 | # import ipdb; ipdb.set_trace() 172 | 173 | def load_ResNet50Model(): 174 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 175 | copy_parameter_from_resnet(model, torchvision.models.resnet50(pretrained = False).state_dict()) 176 | return model 177 | 178 | def load_ResNet101Model(): 179 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 180 | copy_parameter_from_resnet(model, torchvision.models.resnet101(pretrained = True).state_dict()) 181 | return model 182 | 183 | def load_ResNet152Model(): 184 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 185 | copy_parameter_from_resnet(model, torchvision.models.resnet152(pretrained = True).state_dict()) 186 | return model 187 | 188 | # model.load_state_dict(checkpoint['model_state_dict']) 189 | 190 | 191 | ######## Unet 192 | 193 | class DoubleConv(nn.Module): 194 | """(convolution => [BN] => ReLU) * 2""" 195 | 196 | def __init__(self, in_channels, out_channels): 197 | super().__init__() 198 | self.double_conv = nn.Sequential( 199 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 200 | nn.BatchNorm2d(out_channels), 201 | nn.ReLU(inplace=True), 202 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 203 | nn.BatchNorm2d(out_channels), 204 | nn.ReLU(inplace=True) 205 | ) 206 | 207 | def forward(self, x): 208 | return self.double_conv(x) 209 | 210 | 211 | class Down(nn.Module): 212 | """Downscaling with maxpool then double conv""" 213 | 214 | def __init__(self, in_channels, out_channels): 215 | super().__init__() 216 | self.maxpool_conv = nn.Sequential( 217 | nn.MaxPool2d(2), 218 | DoubleConv(in_channels, out_channels) 219 | ) 220 | 221 | def forward(self, x): 222 | return self.maxpool_conv(x) 223 | 224 | 225 | class Up(nn.Module): 226 | """Upscaling then double conv""" 227 | 228 | def __init__(self, in_channels, out_channels, bilinear=True): 229 | super().__init__() 230 | 231 | # if bilinear, use the normal convolutions to reduce the number of channels 232 | if bilinear: 233 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 234 | else: 235 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2) 236 | 237 | self.conv = DoubleConv(in_channels, out_channels) 238 | 239 | def forward(self, x1, x2): 240 | x1 = self.up(x1) 241 | # input is CHW 242 | diffY = x2.size()[2] - x1.size()[2] 243 | diffX = x2.size()[3] - x1.size()[3] 244 | 245 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 246 | diffY // 2, diffY - diffY // 2]) 247 | # if you have padding issues, see 248 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 249 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 250 | x = torch.cat([x2, x1], dim=1) 251 | return self.conv(x) 252 | 253 | 254 | class OutConv(nn.Module): 255 | def __init__(self, in_channels, out_channels): 256 | super(OutConv, self).__init__() 257 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 258 | 259 | def forward(self, x): 260 | return self.conv(x) -------------------------------------------------------------------------------- /decalib/utils/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Default config for DECA 3 | ''' 4 | from yacs.config import CfgNode as CN 5 
| import argparse 6 | import yaml 7 | import os 8 | 9 | cfg = CN() 10 | 11 | abs_deca_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'utils')) 12 | 13 | cfg.deca_dir = abs_deca_dir 14 | cfg.device = 'cuda' 15 | cfg.device_id = '0' 16 | 17 | cfg.pretrained_modelpath = os.path.join(cfg.deca_dir, 'data', 'deca_model.tar') 18 | cfg.output_dir = '' 19 | cfg.rasterizer_type = 'pytorch3d' 20 | # ---------------------------------------------------------------------------- # 21 | # Options for Face model 22 | # ---------------------------------------------------------------------------- # 23 | cfg.model = CN() 24 | cfg.model.topology_path = os.path.join(cfg.deca_dir, 'data', 'head_template.obj') 25 | # texture data original from http://files.is.tue.mpg.de/tbolkart/FLAME/FLAME_texture_data.zip 26 | cfg.model.dense_template_path = os.path.join(cfg.deca_dir, 'data', 'texture_data_256.npy') 27 | cfg.model.fixed_displacement_path = os.path.join(cfg.deca_dir, 'data', 'fixed_displacement_256.npy') 28 | cfg.model.flame_model_path = os.path.join(cfg.deca_dir, 'data', 'generic_model.pkl') 29 | cfg.model.flame_lmk_embedding_path = os.path.join(cfg.deca_dir, 'data', 'landmark_embedding.npy') 30 | cfg.model.face_mask_path = os.path.join(cfg.deca_dir, 'data', 'uv_face_mask.png') 31 | cfg.model.face_eye_mask_path = os.path.join(cfg.deca_dir, 'data', 'uv_face_eye_mask.png') 32 | cfg.model.mean_tex_path = os.path.join(cfg.deca_dir, 'data', 'mean_texture.jpg') 33 | cfg.model.tex_path = os.path.join(cfg.deca_dir, 'data', 'FLAME_albedo_from_BFM.npz') 34 | cfg.model.tex_type = 'BFM' # BFM, FLAME, albedoMM 35 | cfg.model.uv_size = 256 36 | cfg.model.param_list = ['shape', 'tex', 'exp', 'pose', 'cam', 'light'] 37 | cfg.model.n_shape = 100 38 | cfg.model.n_tex = 50 39 | cfg.model.n_exp = 50 40 | cfg.model.n_cam = 3 41 | cfg.model.n_pose = 6 42 | cfg.model.n_light = 27 43 | cfg.model.use_tex = True 44 | cfg.model.jaw_type = 'aa' # default use axis angle, another option: euler. Note that: aa is not stable in the beginning 45 | # face recognition model 46 | cfg.model.fr_model_path = os.path.join(cfg.deca_dir, 'data', 'resnet50_ft_weight.pkl') 47 | 48 | ## details 49 | cfg.model.n_detail = 128 50 | cfg.model.max_z = 0.01 51 | 52 | # ---------------------------------------------------------------------------- # 53 | # Options for Dataset 54 | # ---------------------------------------------------------------------------- # 55 | cfg.dataset = CN() 56 | cfg.dataset.training_data = ['vggface2', 'ethnicity'] 57 | # cfg.dataset.training_data = ['ethnicity'] 58 | cfg.dataset.eval_data = ['aflw2000'] 59 | cfg.dataset.test_data = [''] 60 | cfg.dataset.batch_size = 2 61 | cfg.dataset.K = 4 62 | cfg.dataset.isSingle = False 63 | cfg.dataset.num_workers = 2 64 | cfg.dataset.image_size = 224 65 | cfg.dataset.scale_min = 1.4 66 | cfg.dataset.scale_max = 1.8 67 | cfg.dataset.trans_scale = 0. 
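A minimal usage sketch for this config module (not part of config.py; the override YAML path below is an assumption): the defaults above can be cloned and merged with a user-supplied YAML file via the `get_cfg_defaults` and `update_cfg` helpers defined later in this file.

```python
# Sketch: clone the module-level defaults and merge overrides from a YAML file.
# 'my_deca_overrides.yml' is a placeholder path, not a file shipped with the repo.
from decalib.utils.config import get_cfg_defaults, update_cfg

cfg = get_cfg_defaults()                        # clone of the defaults defined above
cfg = update_cfg(cfg, 'my_deca_overrides.yml')  # yacs merge_from_file under the hood
print(cfg.dataset.batch_size, cfg.model.n_shape, cfg.train.lr)
```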
68 | 69 | # ---------------------------------------------------------------------------- # 70 | # Options for training 71 | # ---------------------------------------------------------------------------- # 72 | cfg.train = CN() 73 | cfg.train.train_detail = False 74 | cfg.train.max_epochs = 500 75 | cfg.train.max_steps = 1000000 76 | cfg.train.lr = 1e-4 77 | cfg.train.log_dir = 'logs' 78 | cfg.train.log_steps = 10 79 | cfg.train.vis_dir = 'train_images' 80 | cfg.train.vis_steps = 200 81 | cfg.train.write_summary = True 82 | cfg.train.checkpoint_steps = 500 83 | cfg.train.val_steps = 500 84 | cfg.train.val_vis_dir = 'val_images' 85 | cfg.train.eval_steps = 5000 86 | cfg.train.resume = True 87 | 88 | # ---------------------------------------------------------------------------- # 89 | # Options for Losses 90 | # ---------------------------------------------------------------------------- # 91 | cfg.loss = CN() 92 | cfg.loss.lmk = 1.0 93 | cfg.loss.useWlmk = True 94 | cfg.loss.eyed = 1.0 95 | cfg.loss.lipd = 0.5 96 | cfg.loss.photo = 2.0 97 | cfg.loss.useSeg = True 98 | cfg.loss.id = 0.2 99 | cfg.loss.id_shape_only = True 100 | cfg.loss.reg_shape = 1e-04 101 | cfg.loss.reg_exp = 1e-04 102 | cfg.loss.reg_tex = 1e-04 103 | cfg.loss.reg_light = 1. 104 | cfg.loss.reg_jaw_pose = 0. #1. 105 | cfg.loss.use_gender_prior = False 106 | cfg.loss.shape_consistency = True 107 | # loss for detail 108 | cfg.loss.detail_consistency = True 109 | cfg.loss.useConstraint = True 110 | cfg.loss.mrf = 5e-2 111 | cfg.loss.photo_D = 2. 112 | cfg.loss.reg_sym = 0.005 113 | cfg.loss.reg_z = 0.005 114 | cfg.loss.reg_diff = 0.005 115 | 116 | 117 | def get_cfg_defaults(): 118 | """Get a yacs CfgNode object with default values for my_project.""" 119 | # Return a clone so that the defaults will not be altered 120 | # This is for the "local variable" use pattern 121 | return cfg.clone() 122 | 123 | def update_cfg(cfg, cfg_file): 124 | cfg.merge_from_file(cfg_file) 125 | return cfg.clone() 126 | 127 | def parse_args(): 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--cfg', type=str, help='cfg file path') 130 | parser.add_argument('--mode', type=str, default = 'train', help='deca mode') 131 | 132 | args = parser.parse_args() 133 | print(args, end='\n\n') 134 | 135 | cfg = get_cfg_defaults() 136 | cfg.cfg_file = None 137 | cfg.mode = args.mode 138 | # import ipdb; ipdb.set_trace() 139 | if args.cfg is not None: 140 | cfg_file = args.cfg 141 | cfg = update_cfg(cfg, args.cfg) 142 | cfg.cfg_file = cfg_file 143 | 144 | return cfg 145 | -------------------------------------------------------------------------------- /decalib/utils/rasterizer/INSTALL.md: -------------------------------------------------------------------------------- 1 | ## Install 2 | from standard_rasterize_cuda import standard_rasterize 3 | # from .rasterizer.standard_rasterize_cuda import standard_rasterize 4 | 5 | In this folder, run 6 | ```python setup.py build_ext -i ``` 7 | 8 | Then remember to set --rasterizer_type=standard when running demos :) 9 | 10 | ## Alg 11 | https://www.scratchapixel.com/lessons/3d-basic-rendering/rasterization-practical-implementation 12 | 13 | ## Speed Comparison 14 | runtime for rasterization only 15 | In PIXIE, number of faces in SMPLX: 20908 16 | 17 | for image size = 1024 18 | pytorch3d: 0.031s 19 | standard: 0.01s 20 | 21 | for image size = 224 22 | pytorch3d: 0.0035s 23 | standard: 0.0014s 24 | 25 | Why is the standard rasterizer faster than pytorch3d?
26 | Ref: https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu 27 | pytorch3d: for each pixel in image space (pixels are parallel in CUDA), loop through the faces, check whether the pixel lies inside the projection bounding box of the face, then sort the faces by z and record the face ids of the closest K faces. 28 | standard rasterization: for each face in the mesh (faces are parallel in CUDA), loop through the pixels in its projection bounding box (normally a very small number), compare z, and record the face id for each pixel. 29 | 30 | -------------------------------------------------------------------------------- /decalib/utils/rasterizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/decalib/utils/rasterizer/__init__.py -------------------------------------------------------------------------------- /decalib/utils/rasterizer/setup.py: -------------------------------------------------------------------------------- 1 | # To install, run 2 | # python setup.py build_ext -i 3 | # Ref: https://github.com/pytorch/pytorch/blob/11a40410e755b1fe74efe9eaa635e7ba5712846b/test/cpp_extensions/setup.py#L62 4 | 5 | from setuptools import setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | import os 8 | 9 | # USE_NINJA = os.getenv('USE_NINJA') == '1' 10 | os.environ["CC"] = "gcc-7" 11 | os.environ["CXX"] = "gcc-7" 12 | 13 | USE_NINJA = os.getenv('USE_NINJA') == '1' 14 | 15 | setup( 16 | name='standard_rasterize_cuda', 17 | ext_modules=[ 18 | CUDAExtension('standard_rasterize_cuda', [ 19 | 'standard_rasterize_cuda.cpp', 20 | 'standard_rasterize_cuda_kernel.cu', 21 | ]) 22 | ], 23 | cmdclass={'build_ext': BuildExtension.with_options(use_ninja=USE_NINJA)} 24 | ) 25 | -------------------------------------------------------------------------------- /decalib/utils/rasterizer/standard_rasterize_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | #include <iostream> 4 | 5 | std::vector<at::Tensor> forward_rasterize_cuda( 6 | at::Tensor face_vertices, 7 | at::Tensor depth_buffer, 8 | at::Tensor triangle_buffer, 9 | at::Tensor baryw_buffer, 10 | int h, 11 | int w); 12 | 13 | std::vector<at::Tensor> standard_rasterize( 14 | at::Tensor face_vertices, 15 | at::Tensor depth_buffer, 16 | at::Tensor triangle_buffer, 17 | at::Tensor baryw_buffer, 18 | int height, int width 19 | ) { 20 | return forward_rasterize_cuda(face_vertices, depth_buffer, triangle_buffer, baryw_buffer, height, width); 21 | } 22 | 23 | std::vector<at::Tensor> forward_rasterize_colors_cuda( 24 | at::Tensor face_vertices, 25 | at::Tensor face_colors, 26 | at::Tensor depth_buffer, 27 | at::Tensor triangle_buffer, 28 | at::Tensor images, 29 | int h, 30 | int w); 31 | 32 | std::vector<at::Tensor> standard_rasterize_colors( 33 | at::Tensor face_vertices, 34 | at::Tensor face_colors, 35 | at::Tensor depth_buffer, 36 | at::Tensor triangle_buffer, 37 | at::Tensor images, 38 | int height, int width 39 | ) { 40 | return forward_rasterize_colors_cuda(face_vertices, face_colors, depth_buffer, triangle_buffer, images, height, width); 41 | } 42 | 43 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 44 | m.def("standard_rasterize", &standard_rasterize, "RASTERIZE (CUDA)"); 45 | m.def("standard_rasterize_colors", &standard_rasterize_colors, "RASTERIZE COLORS (CUDA)"); 46 | } 47 | 48 | // TODO: backward
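A minimal, hedged sketch of how the compiled extension bound above might be called from Python, assuming it was built in this folder with `python setup.py build_ext -i`. The buffer shapes and initial values are inferred from the kernel indexing in the .cu file below; they are not documented in the repo.

```python
# Sketch only: allocate the buffers the binding expects and rasterize one batch of faces.
import torch
from standard_rasterize_cuda import standard_rasterize

bz, ntri, h, w = 1, 20908, 224, 224                        # sizes taken from the speed comparison above
face_vertices = torch.rand(bz, ntri, 3, 3, device='cuda')  # [bz, nf, 3 verts, xyz] already in image space
depth_buffer = torch.full((bz, h, w), 1e6, device='cuda')  # init large: the kernel keeps the minimum depth
triangle_buffer = torch.full((bz, h, w), -1, dtype=torch.int32, device='cuda')
baryw_buffer = torch.zeros(bz, h, w, 3, device='cuda')     # barycentric weights per covered pixel
depth_buffer, triangle_buffer, baryw_buffer = standard_rasterize(
    face_vertices, depth_buffer, triangle_buffer, baryw_buffer, h, w)
```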
-------------------------------------------------------------------------------- /decalib/utils/rasterizer/standard_rasterize_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | // Ref: https://github.com/daniilidis-group/neural_renderer/blob/master/neural_renderer/cuda/rasterize_cuda_kernel.cu 2 | // https://github.com/YadiraF/face3d/blob/master/face3d/mesh/cython/mesh_core.cpp 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | namespace{ 10 | __device__ __forceinline__ float atomicMin(float* address, float val) 11 | { 12 | int* address_as_i = (int*) address; 13 | int old = *address_as_i, assumed; 14 | do { 15 | assumed = old; 16 | old = atomicCAS(address_as_i, assumed, 17 | __float_as_int(fminf(val, __int_as_float(assumed)))); 18 | } while (assumed != old); 19 | return __int_as_float(old); 20 | } 21 | __device__ __forceinline__ double atomicMin(double* address, double val) 22 | { 23 | unsigned long long int* address_as_i = (unsigned long long int*) address; 24 | unsigned long long int old = *address_as_i, assumed; 25 | do { 26 | assumed = old; 27 | old = atomicCAS(address_as_i, assumed, 28 | __double_as_longlong(fminf(val, __longlong_as_double(assumed)))); 29 | } while (assumed != old); 30 | return __longlong_as_double(old); 31 | } 32 | 33 | template 34 | __device__ __forceinline__ bool check_face_frontside(const scalar_t *face) { 35 | return (face[7] - face[1]) * (face[3] - face[0]) < (face[4] - face[1]) * (face[6] - face[0]); 36 | } 37 | 38 | 39 | template struct point 40 | { 41 | public: 42 | scalar_t x; 43 | scalar_t y; 44 | 45 | __host__ __device__ scalar_t dot(point p) 46 | { 47 | return this->x * p.x + this->y * p.y; 48 | }; 49 | 50 | __host__ __device__ point operator-(point& p) 51 | { 52 | point np; 53 | np.x = this->x - p.x; 54 | np.y = this->y - p.y; 55 | return np; 56 | }; 57 | 58 | __host__ __device__ point operator+(point& p) 59 | { 60 | point np; 61 | np.x = this->x + p.x; 62 | np.y = this->y + p.y; 63 | return np; 64 | }; 65 | 66 | __host__ __device__ point operator*(scalar_t s) 67 | { 68 | point np; 69 | np.x = s * this->x; 70 | np.y = s * this->y; 71 | return np; 72 | }; 73 | }; 74 | 75 | template 76 | __device__ __forceinline__ bool check_pixel_inside(const scalar_t *w) { 77 | return w[0] <= 1 && w[0] >= 0 && w[1] <= 1 && w[1] >= 0 && w[2] <= 1 && w[2] >= 0; 78 | } 79 | 80 | template 81 | __device__ __forceinline__ void barycentric_weight(scalar_t *w, point p, point p0, point p1, point p2) { 82 | 83 | // vectors 84 | point v0, v1, v2; 85 | scalar_t s = p.dot(p); 86 | v0 = p2 - p0; 87 | v1 = p1 - p0; 88 | v2 = p - p0; 89 | 90 | // dot products 91 | scalar_t dot00 = v0.dot(v0); //v0.x * v0.x + v0.y * v0.y //np.dot(v0.T, v0) 92 | scalar_t dot01 = v0.dot(v1); //v0.x * v1.x + v0.y * v1.y //np.dot(v0.T, v1) 93 | scalar_t dot02 = v0.dot(v2); //v0.x * v2.x + v0.y * v2.y //np.dot(v0.T, v2) 94 | scalar_t dot11 = v1.dot(v1); //v1.x * v1.x + v1.y * v1.y //np.dot(v1.T, v1) 95 | scalar_t dot12 = v1.dot(v2); //v1.x * v2.x + v1.y * v2.y//np.dot(v1.T, v2) 96 | 97 | // barycentric coordinates 98 | scalar_t inverDeno; 99 | if(dot00*dot11 - dot01*dot01 == 0) 100 | inverDeno = 0; 101 | else 102 | inverDeno = 1/(dot00*dot11 - dot01*dot01); 103 | 104 | scalar_t u = (dot11*dot02 - dot01*dot12)*inverDeno; 105 | scalar_t v = (dot00*dot12 - dot01*dot02)*inverDeno; 106 | 107 | // weight 108 | w[0] = 1 - u - v; 109 | w[1] = v; 110 | w[2] = u; 111 | } 112 | 113 | // Ref: 
https://www.scratchapixel.com/lessons/3d-basic-rendering/rasterization-practical-implementation/overview-rasterization-algorithm 114 | template 115 | __global__ void forward_rasterize_cuda_kernel( 116 | const scalar_t* __restrict__ face_vertices, //[bz, nf, 3, 3] 117 | scalar_t* depth_buffer, 118 | int* triangle_buffer, 119 | scalar_t* baryw_buffer, 120 | int batch_size, int h, int w, 121 | int ntri) { 122 | 123 | const int i = blockIdx.x * blockDim.x + threadIdx.x; 124 | if (i >= batch_size * ntri) { 125 | return; 126 | } 127 | int bn = i/ntri; 128 | const scalar_t* face = &face_vertices[i * 9]; 129 | scalar_t bw[3]; 130 | point p0, p1, p2, p; 131 | 132 | p0.x = face[0]; p0.y=face[1]; 133 | p1.x = face[3]; p1.y=face[4]; 134 | p2.x = face[6]; p2.y=face[7]; 135 | 136 | int x_min = max((int)ceil(min(p0.x, min(p1.x, p2.x))), 0); 137 | int x_max = min((int)floor(max(p0.x, max(p1.x, p2.x))), w - 1); 138 | int y_min = max((int)ceil(min(p0.y, min(p1.y, p2.y))), 0); 139 | int y_max = min((int)floor(max(p0.y, max(p1.y, p2.y))), h - 1); 140 | 141 | for(int y = y_min; y <= y_max; y++) //h 142 | { 143 | for(int x = x_min; x <= x_max; x++) //w 144 | { 145 | p.x = x; p.y = y; 146 | barycentric_weight(bw, p, p0, p1, p2); 147 | // if(((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0)) && check_face_frontside(face)) 148 | if((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0)) 149 | { 150 | // perspective correct: https://www.scratchapixel.com/lessons/3d-basic-rendering/rasterization-practical-implementation/perspective-correct-interpolation-vertex-attributes 151 | scalar_t zp = 1. / (bw[0] / face[2] + bw[1] / face[5] + bw[2] / face[8]); 152 | // printf("%f %f %f \n", (float)zp, (float)face[2], (float)bw[2]); 153 | atomicMin(&depth_buffer[bn*h*w + y*w + x], zp); 154 | if(depth_buffer[bn*h*w + y*w + x] == zp) 155 | { 156 | triangle_buffer[bn*h*w + y*w + x] = (int)(i%ntri); 157 | for(int k=0; k<3; k++){ 158 | baryw_buffer[bn*h*w*3 + y*w*3 + x*3 + k] = bw[k]; 159 | } 160 | } 161 | } 162 | } 163 | } 164 | 165 | } 166 | 167 | template 168 | __global__ void forward_rasterize_colors_cuda_kernel( 169 | const scalar_t* __restrict__ face_vertices, //[bz, nf, 3, 3] 170 | const scalar_t* __restrict__ face_colors, //[bz, nf, 3, 3] 171 | scalar_t* depth_buffer, 172 | int* triangle_buffer, 173 | scalar_t* images, 174 | int batch_size, int h, int w, 175 | int ntri) { 176 | const int i = blockIdx.x * blockDim.x + threadIdx.x; 177 | if (i >= batch_size * ntri) { 178 | return; 179 | } 180 | int bn = i/ntri; 181 | const scalar_t* face = &face_vertices[i * 9]; 182 | const scalar_t* color = &face_colors[i * 9]; 183 | scalar_t bw[3]; 184 | point p0, p1, p2, p; 185 | 186 | p0.x = face[0]; p0.y=face[1]; 187 | p1.x = face[3]; p1.y=face[4]; 188 | p2.x = face[6]; p2.y=face[7]; 189 | scalar_t cl[3][3]; 190 | for (int num = 0; num < 3; num++) { 191 | for (int dim = 0; dim < 3; dim++) { 192 | cl[num][dim] = color[3 * num + dim]; //[3p,3rgb] 193 | } 194 | } 195 | int x_min = max((int)ceil(min(p0.x, min(p1.x, p2.x))), 0); 196 | int x_max = min((int)floor(max(p0.x, max(p1.x, p2.x))), w - 1); 197 | int y_min = max((int)ceil(min(p0.y, min(p1.y, p2.y))), 0); 198 | int y_max = min((int)floor(max(p0.y, max(p1.y, p2.y))), h - 1); 199 | 200 | for(int y = y_min; y <= y_max; y++) //h 201 | { 202 | for(int x = x_min; x <= x_max; x++) //w 203 | { 204 | p.x = x; p.y = y; 205 | barycentric_weight(bw, p, p0, p1, p2); 206 | if(((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0)) && check_face_frontside(face)) 207 | // if((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0)) 208 | { 209 | 
scalar_t zp = 1. / (bw[0] / face[2] + bw[1] / face[5] + bw[2] / face[8]); 210 | 211 | atomicMin(&depth_buffer[bn*h*w + y*w + x], zp); 212 | if(depth_buffer[bn*h*w + y*w + x] == zp) 213 | { 214 | triangle_buffer[bn*h*w + y*w + x] = (int)(i%ntri); 215 | for(int k=0; k<3; k++){ 216 | // baryw_buffer[bn*h*w*3 + y*w*3 + x*3 + k] = bw[k]; 217 | images[bn*h*w*3 + y*w*3 + x*3 + k] = bw[0]*cl[0][k] + bw[1]*cl[1][k] + bw[2]*cl[2][k]; 218 | } 219 | // buffers[bn*h*w*2 + y*w*2 + x*2 + 1] = p_depth; 220 | } 221 | } 222 | } 223 | } 224 | 225 | } 226 | 227 | } 228 | 229 | std::vector forward_rasterize_cuda( 230 | at::Tensor face_vertices, 231 | at::Tensor depth_buffer, 232 | at::Tensor triangle_buffer, 233 | at::Tensor baryw_buffer, 234 | int h, 235 | int w){ 236 | 237 | const auto batch_size = face_vertices.size(0); 238 | const auto ntri = face_vertices.size(1); 239 | 240 | // print(channel_size) 241 | const int threads = 512; 242 | const dim3 blocks_1 ((batch_size * ntri - 1) / threads +1); 243 | 244 | AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_cuda1", ([&] { 245 | forward_rasterize_cuda_kernel<<>>( 246 | face_vertices.data(), 247 | depth_buffer.data(), 248 | triangle_buffer.data(), 249 | baryw_buffer.data(), 250 | batch_size, h, w, 251 | ntri); 252 | })); 253 | 254 | // better to do it twice (or there will be balck spots in the rendering) 255 | AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_cuda2", ([&] { 256 | forward_rasterize_cuda_kernel<<>>( 257 | face_vertices.data(), 258 | depth_buffer.data(), 259 | triangle_buffer.data(), 260 | baryw_buffer.data(), 261 | batch_size, h, w, 262 | ntri); 263 | })); 264 | cudaError_t err = cudaGetLastError(); 265 | if (err != cudaSuccess) 266 | printf("Error in forward_rasterize_cuda_kernel: %s\n", cudaGetErrorString(err)); 267 | 268 | return {depth_buffer, triangle_buffer, baryw_buffer}; 269 | } 270 | 271 | 272 | std::vector forward_rasterize_colors_cuda( 273 | at::Tensor face_vertices, 274 | at::Tensor face_colors, 275 | at::Tensor depth_buffer, 276 | at::Tensor triangle_buffer, 277 | at::Tensor images, 278 | int h, 279 | int w){ 280 | 281 | const auto batch_size = face_vertices.size(0); 282 | const auto ntri = face_vertices.size(1); 283 | 284 | // print(channel_size) 285 | const int threads = 512; 286 | const dim3 blocks_1 ((batch_size * ntri - 1) / threads +1); 287 | //initial 288 | 289 | AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_colors_cuda", ([&] { 290 | forward_rasterize_colors_cuda_kernel<<>>( 291 | face_vertices.data(), 292 | face_colors.data(), 293 | depth_buffer.data(), 294 | triangle_buffer.data(), 295 | images.data(), 296 | batch_size, h, w, 297 | ntri); 298 | })); 299 | // better to do it twice 300 | // AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_colors_cuda", ([&] { 301 | // forward_rasterize_colors_cuda_kernel<<>>( 302 | // face_vertices.data(), 303 | // face_colors.data(), 304 | // depth_buffer.data(), 305 | // triangle_buffer.data(), 306 | // images.data(), 307 | // batch_size, h, w, 308 | // ntri); 309 | // })); 310 | cudaError_t err = cudaGetLastError(); 311 | if (err != cudaSuccess) 312 | printf("Error in forward_rasterize_cuda_kernel: %s\n", cudaGetErrorString(err)); 313 | 314 | return {depth_buffer, triangle_buffer, images}; 315 | } 316 | 317 | 318 | 319 | 320 | -------------------------------------------------------------------------------- /decalib/utils/rotation_converter.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ''' Rotation Converter 4 | Repre: euler angle(3), angle axis(3), rotation matrix(3x3), quaternion(4) 5 | ref: https://kornia.readthedocs.io/en/v0.1.2/_modules/torchgeometry/core/conversions.html# 6 | "pi", 7 | "rad2deg", 8 | "deg2rad", 9 | # "angle_axis_to_rotation_matrix", batch_rodrigues 10 | "rotation_matrix_to_angle_axis", 11 | "rotation_matrix_to_quaternion", 12 | "quaternion_to_angle_axis", 13 | # "angle_axis_to_quaternion", 14 | 15 | euler2quat_conversion_sanity_batch 16 | 17 | ref: smplx/lbs 18 | batch_rodrigues: axis angle -> matrix 19 | # 20 | ''' 21 | pi = torch.Tensor([3.14159265358979323846]) 22 | 23 | def rad2deg(tensor): 24 | """Function that converts angles from radians to degrees. 25 | 26 | See :class:`~torchgeometry.RadToDeg` for details. 27 | 28 | Args: 29 | tensor (Tensor): Tensor of arbitrary shape. 30 | 31 | Returns: 32 | Tensor: Tensor with same shape as input. 33 | 34 | Example: 35 | >>> input = tgm.pi * torch.rand(1, 3, 3) 36 | >>> output = tgm.rad2deg(input) 37 | """ 38 | if not torch.is_tensor(tensor): 39 | raise TypeError("Input type is not a torch.Tensor. Got {}" 40 | .format(type(tensor))) 41 | 42 | return 180. * tensor / pi.to(tensor.device).type(tensor.dtype) 43 | 44 | def deg2rad(tensor): 45 | """Function that converts angles from degrees to radians. 46 | 47 | See :class:`~torchgeometry.DegToRad` for details. 48 | 49 | Args: 50 | tensor (Tensor): Tensor of arbitrary shape. 51 | 52 | Returns: 53 | Tensor: Tensor with same shape as input. 54 | 55 | Examples:: 56 | 57 | >>> input = 360. * torch.rand(1, 3, 3) 58 | >>> output = tgm.deg2rad(input) 59 | """ 60 | if not torch.is_tensor(tensor): 61 | raise TypeError("Input type is not a torch.Tensor. Got {}" 62 | .format(type(tensor))) 63 | 64 | return tensor * pi.to(tensor.device).type(tensor.dtype) / 180. 65 | 66 | ######### to quaternion 67 | def euler_to_quaternion(r): 68 | x = r[..., 0] 69 | y = r[..., 1] 70 | z = r[..., 2] 71 | 72 | z = z/2.0 73 | y = y/2.0 74 | x = x/2.0 75 | cz = torch.cos(z) 76 | sz = torch.sin(z) 77 | cy = torch.cos(y) 78 | sy = torch.sin(y) 79 | cx = torch.cos(x) 80 | sx = torch.sin(x) 81 | quaternion = torch.zeros_like(r.repeat(1,2))[..., :4].to(r.device) 82 | quaternion[..., 0] += cx*cy*cz - sx*sy*sz 83 | quaternion[..., 1] += cx*sy*sz + cy*cz*sx 84 | quaternion[..., 2] += cx*cz*sy - sx*cy*sz 85 | quaternion[..., 3] += cx*cy*sz + sx*cz*sy 86 | return quaternion 87 | 88 | def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): 89 | """Convert 3x4 rotation matrix to 4d quaternion vector 90 | 91 | This algorithm is based on algorithm described in 92 | https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 93 | 94 | Args: 95 | rotation_matrix (Tensor): the rotation matrix to convert. 96 | 97 | Return: 98 | Tensor: the rotation in quaternion 99 | 100 | Shape: 101 | - Input: :math:`(N, 3, 4)` 102 | - Output: :math:`(N, 4)` 103 | 104 | Example: 105 | >>> input = torch.rand(4, 3, 4) # Nx3x4 106 | >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 107 | """ 108 | if not torch.is_tensor(rotation_matrix): 109 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 110 | type(rotation_matrix))) 111 | 112 | if len(rotation_matrix.shape) > 3: 113 | raise ValueError( 114 | "Input size must be a three dimensional tensor. 
Got {}".format( 115 | rotation_matrix.shape)) 116 | # if not rotation_matrix.shape[-2:] == (3, 4): 117 | # raise ValueError( 118 | # "Input size must be a N x 3 x 4 tensor. Got {}".format( 119 | # rotation_matrix.shape)) 120 | 121 | rmat_t = torch.transpose(rotation_matrix, 1, 2) 122 | 123 | mask_d2 = rmat_t[:, 2, 2] < eps 124 | 125 | mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] 126 | mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] 127 | 128 | t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] 129 | q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1], 130 | t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], 131 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1) 132 | t0_rep = t0.repeat(4, 1).t() 133 | 134 | t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] 135 | q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2], 136 | rmat_t[:, 0, 1] + rmat_t[:, 1, 0], 137 | t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1) 138 | t1_rep = t1.repeat(4, 1).t() 139 | 140 | t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] 141 | q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0], 142 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2], 143 | rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1) 144 | t2_rep = t2.repeat(4, 1).t() 145 | 146 | t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] 147 | q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], 148 | rmat_t[:, 2, 0] - rmat_t[:, 0, 2], 149 | rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1) 150 | t3_rep = t3.repeat(4, 1).t() 151 | 152 | mask_c0 = mask_d2 * mask_d0_d1.float() 153 | mask_c1 = mask_d2 * (1 - mask_d0_d1.float()) 154 | mask_c2 = (1 - mask_d2.float()) * mask_d0_nd1 155 | mask_c3 = (1 - mask_d2.float()) * (1 - mask_d0_nd1.float()) 156 | mask_c0 = mask_c0.view(-1, 1).type_as(q0) 157 | mask_c1 = mask_c1.view(-1, 1).type_as(q1) 158 | mask_c2 = mask_c2.view(-1, 1).type_as(q2) 159 | mask_c3 = mask_c3.view(-1, 1).type_as(q3) 160 | 161 | q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 162 | q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa 163 | t2_rep * mask_c2 + t3_rep * mask_c3) # noqa 164 | q *= 0.5 165 | return q 166 | 167 | # def angle_axis_to_quaternion(theta): 168 | # batch_size = theta.shape[0] 169 | # l1norm = torch.norm(theta + 1e-8, p=2, dim=1) 170 | # angle = torch.unsqueeze(l1norm, -1) 171 | # normalized = torch.div(theta, angle) 172 | # angle = angle * 0.5 173 | # v_cos = torch.cos(angle) 174 | # v_sin = torch.sin(angle) 175 | # quat = torch.cat([v_cos, v_sin * normalized], dim=1) 176 | # return quat 177 | 178 | def angle_axis_to_quaternion(angle_axis: torch.Tensor) -> torch.Tensor: 179 | """Convert an angle axis to a quaternion. 180 | 181 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h 182 | 183 | Args: 184 | angle_axis (torch.Tensor): tensor with angle axis. 185 | 186 | Return: 187 | torch.Tensor: tensor with quaternion. 188 | 189 | Shape: 190 | - Input: :math:`(*, 3)` where `*` means, any number of dimensions 191 | - Output: :math:`(*, 4)` 192 | 193 | Example: 194 | >>> angle_axis = torch.rand(2, 4) # Nx4 195 | >>> quaternion = tgm.angle_axis_to_quaternion(angle_axis) # Nx3 196 | """ 197 | if not torch.is_tensor(angle_axis): 198 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 199 | type(angle_axis))) 200 | 201 | if not angle_axis.shape[-1] == 3: 202 | raise ValueError("Input must be a tensor of shape Nx3 or 3. 
Got {}" 203 | .format(angle_axis.shape)) 204 | # unpack input and compute conversion 205 | a0: torch.Tensor = angle_axis[..., 0:1] 206 | a1: torch.Tensor = angle_axis[..., 1:2] 207 | a2: torch.Tensor = angle_axis[..., 2:3] 208 | theta_squared: torch.Tensor = a0 * a0 + a1 * a1 + a2 * a2 209 | 210 | theta: torch.Tensor = torch.sqrt(theta_squared) 211 | half_theta: torch.Tensor = theta * 0.5 212 | 213 | mask: torch.Tensor = theta_squared > 0.0 214 | ones: torch.Tensor = torch.ones_like(half_theta) 215 | 216 | k_neg: torch.Tensor = 0.5 * ones 217 | k_pos: torch.Tensor = torch.sin(half_theta) / theta 218 | k: torch.Tensor = torch.where(mask, k_pos, k_neg) 219 | w: torch.Tensor = torch.where(mask, torch.cos(half_theta), ones) 220 | 221 | quaternion: torch.Tensor = torch.zeros_like(angle_axis) 222 | quaternion[..., 0:1] += a0 * k 223 | quaternion[..., 1:2] += a1 * k 224 | quaternion[..., 2:3] += a2 * k 225 | return torch.cat([w, quaternion], dim=-1) 226 | 227 | #### quaternion to 228 | def quaternion_to_rotation_matrix(quat): 229 | """Convert quaternion coefficients to rotation matrix. 230 | Args: 231 | quat: size = [B, 4] 4 <===>(w, x, y, z) 232 | Returns: 233 | Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] 234 | """ 235 | norm_quat = quat 236 | norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) 237 | w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3] 238 | 239 | B = quat.size(0) 240 | 241 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) 242 | wx, wy, wz = w * x, w * y, w * z 243 | xy, xz, yz = x * y, x * z, y * z 244 | 245 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 246 | 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 247 | 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) 248 | return rotMat 249 | 250 | def quaternion_to_angle_axis(quaternion: torch.Tensor): 251 | """Convert quaternion vector to angle axis of rotation. TODO: CORRECT 252 | 253 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h 254 | 255 | Args: 256 | quaternion (torch.Tensor): tensor with quaternions. 257 | 258 | Return: 259 | torch.Tensor: tensor with angle axis of rotation. 260 | 261 | Shape: 262 | - Input: :math:`(*, 4)` where `*` means, any number of dimensions 263 | - Output: :math:`(*, 3)` 264 | 265 | Example: 266 | >>> quaternion = torch.rand(2, 4) # Nx4 267 | >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 268 | """ 269 | if not torch.is_tensor(quaternion): 270 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 271 | type(quaternion))) 272 | 273 | if not quaternion.shape[-1] == 4: 274 | raise ValueError("Input must be a tensor of shape Nx4 or 4. 
Got {}" 275 | .format(quaternion.shape)) 276 | # unpack input and compute conversion 277 | q1: torch.Tensor = quaternion[..., 1] 278 | q2: torch.Tensor = quaternion[..., 2] 279 | q3: torch.Tensor = quaternion[..., 3] 280 | sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 281 | 282 | sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) 283 | cos_theta: torch.Tensor = quaternion[..., 0] 284 | two_theta: torch.Tensor = 2.0 * torch.where( 285 | cos_theta < 0.0, 286 | torch.atan2(-sin_theta, -cos_theta), 287 | torch.atan2(sin_theta, cos_theta)) 288 | 289 | k_pos: torch.Tensor = two_theta / sin_theta 290 | k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta).to(quaternion.device) 291 | k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) 292 | 293 | angle_axis: torch.Tensor = torch.zeros_like(quaternion).to(quaternion.device)[..., :3] 294 | angle_axis[..., 0] += q1 * k 295 | angle_axis[..., 1] += q2 * k 296 | angle_axis[..., 2] += q3 * k 297 | return angle_axis 298 | 299 | #### batch converter 300 | def batch_euler2axis(r): 301 | return quaternion_to_angle_axis(euler_to_quaternion(r)) 302 | 303 | def batch_euler2matrix(r): 304 | return quaternion_to_rotation_matrix(euler_to_quaternion(r)) 305 | 306 | def batch_matrix2euler(rot_mats): 307 | # Calculates rotation matrix to euler angles 308 | # Careful for extreme cases of eular angles like [0.0, pi, 0.0] 309 | ### only y? 310 | # TODO: 311 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + 312 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) 313 | return torch.atan2(-rot_mats[:, 2, 0], sy) 314 | 315 | def batch_matrix2axis(rot_mats): 316 | return quaternion_to_angle_axis(rotation_matrix_to_quaternion(rot_mats)) 317 | 318 | def batch_axis2matrix(theta): 319 | # angle axis to rotation matrix 320 | # theta N x 3 321 | # return quat2mat(quat) 322 | # batch_rodrigues 323 | return quaternion_to_rotation_matrix(angle_axis_to_quaternion(theta)) 324 | 325 | def batch_axis2euler(theta): 326 | return batch_matrix2euler(batch_axis2matrix(theta)) 327 | 328 | def batch_axis2euler(r): 329 | return rot_mat_to_euler(batch_rodrigues(r)) 330 | 331 | 332 | def batch_orth_proj(X, camera): 333 | ''' 334 | X is N x num_pquaternion_to_angle_axisoints x 3 335 | ''' 336 | camera = camera.clone().view(-1, 1, 3) 337 | X_trans = X[:, :, :2] + camera[:, :, 1:] 338 | X_trans = torch.cat([X_trans, X[:,:,2:]], 2) 339 | Xn = (camera[:, :, 0:1] * X_trans) 340 | return Xn 341 | 342 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): 343 | ''' same as batch_matrix2axis 344 | Calculates the rotation matrices for a batch of rotation vectors 345 | Parameters 346 | ---------- 347 | rot_vecs: torch.tensor Nx3 348 | array of N axis-angle vectors 349 | Returns 350 | ------- 351 | R: torch.tensor Nx3x3 352 | The rotation matrices for the given axis-angle parameters 353 | ''' 354 | 355 | batch_size = rot_vecs.shape[0] 356 | device = rot_vecs.device 357 | 358 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) 359 | rot_dir = rot_vecs / angle 360 | 361 | cos = torch.unsqueeze(torch.cos(angle), dim=1) 362 | sin = torch.unsqueeze(torch.sin(angle), dim=1) 363 | 364 | # Bx1 arrays 365 | rx, ry, rz = torch.split(rot_dir, 1, dim=1) 366 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) 367 | 368 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) 369 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ 370 | .view((batch_size, 3, 3)) 371 | 372 | ident = torch.eye(3, dtype=dtype, 
device=device).unsqueeze(dim=0) 373 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) 374 | return rot_mat 375 | -------------------------------------------------------------------------------- /decalib/utils/tensor_cropper.py: -------------------------------------------------------------------------------- 1 | ''' 2 | crop 3 | for torch tensor 4 | Given image, bbox(center, bboxsize) 5 | return: cropped image, tform(used for transform the keypoint accordingly) 6 | only support crop to squared images 7 | ''' 8 | import torch 9 | from kornia.geometry.transform.imgwarp import ( 10 | warp_perspective, get_perspective_transform, warp_affine 11 | ) 12 | 13 | def points2bbox(points, points_scale=None): 14 | if points_scale: 15 | assert points_scale[0]==points_scale[1] 16 | points = points.clone() 17 | points[:,:,:2] = (points[:,:,:2]*0.5 + 0.5)*points_scale[0] 18 | min_coords, _ = torch.min(points, dim=1) 19 | xmin, ymin = min_coords[:, 0], min_coords[:, 1] 20 | max_coords, _ = torch.max(points, dim=1) 21 | xmax, ymax = max_coords[:, 0], max_coords[:, 1] 22 | center = torch.stack([xmax + xmin, ymax + ymin], dim=-1) * 0.5 23 | 24 | width = (xmax - xmin) 25 | height = (ymax - ymin) 26 | # Convert the bounding box to a square box 27 | size = torch.max(width, height).unsqueeze(-1) 28 | return center, size 29 | 30 | def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.): 31 | batch_size = center.shape[0] 32 | trans_scale = (torch.rand([batch_size, 2], device=center.device)*2. -1.) * trans_scale 33 | center = center + trans_scale*bbox_size # 0.5 34 | scale = torch.rand([batch_size,1], device=center.device) * (scale[1] - scale[0]) + scale[0] 35 | size = bbox_size*scale 36 | return center, size 37 | 38 | def crop_tensor(image, center, bbox_size, crop_size, interpolation = 'bilinear', align_corners=False): 39 | ''' for batch image 40 | Args: 41 | image (torch.Tensor): the reference tensor of shape BXHxWXC. 42 | center: [bz, 2] 43 | bboxsize: [bz, 1] 44 | crop_size; 45 | interpolation (str): Interpolation flag. Default: 'bilinear'. 46 | align_corners (bool): mode for grid_generation. Default: False. 
See 47 | https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details 48 | Returns: 49 | cropped_image 50 | tform 51 | ''' 52 | dtype = image.dtype 53 | device = image.device 54 | batch_size = image.shape[0] 55 | # points: top-left, top-right, bottom-right, bottom-left 56 | src_pts = torch.zeros([4,2], dtype=dtype, device=device).unsqueeze(0).expand(batch_size, -1, -1).contiguous() 57 | 58 | src_pts[:, 0, :] = center - bbox_size*0.5 # / (self.crop_size - 1) 59 | src_pts[:, 1, 0] = center[:, 0] + bbox_size[:, 0] * 0.5 60 | src_pts[:, 1, 1] = center[:, 1] - bbox_size[:, 0] * 0.5 61 | src_pts[:, 2, :] = center + bbox_size * 0.5 62 | src_pts[:, 3, 0] = center[:, 0] - bbox_size[:, 0] * 0.5 63 | src_pts[:, 3, 1] = center[:, 1] + bbox_size[:, 0] * 0.5 64 | 65 | DST_PTS = torch.tensor([[ 66 | [0, 0], 67 | [crop_size - 1, 0], 68 | [crop_size - 1, crop_size - 1], 69 | [0, crop_size - 1], 70 | ]], dtype=dtype, device=device).expand(batch_size, -1, -1) 71 | # estimate transformation between points 72 | dst_trans_src = get_perspective_transform(src_pts, DST_PTS) 73 | # simulate broadcasting 74 | # dst_trans_src = dst_trans_src.expand(batch_size, -1, -1) 75 | 76 | # warp images 77 | cropped_image = warp_affine( 78 | image, dst_trans_src[:, :2, :], (crop_size, crop_size), 79 | flags=interpolation, align_corners=align_corners) 80 | 81 | tform = torch.transpose(dst_trans_src, 2, 1) 82 | # tform = torch.inverse(dst_trans_src) 83 | return cropped_image, tform 84 | 85 | class Cropper(object): 86 | def __init__(self, crop_size, scale=[1,1], trans_scale = 0.): 87 | self.crop_size = crop_size 88 | self.scale = scale 89 | self.trans_scale = trans_scale 90 | 91 | def crop(self, image, points, points_scale=None): 92 | # points to bbox 93 | center, bbox_size = points2bbox(points.clone(), points_scale) 94 | # argument bbox. TODO: add rotation? 
95 | center, bbox_size = augment_bbox(center, bbox_size, scale=self.scale, trans_scale=self.trans_scale) 96 | # crop 97 | cropped_image, tform = crop_tensor(image, center, bbox_size, self.crop_size) 98 | return cropped_image, tform 99 | 100 | def transform_points(self, points, tform, points_scale=None, normalize = True): 101 | points_2d = points[:,:,:2] 102 | 103 | #'input points must use original range' 104 | if points_scale: 105 | assert points_scale[0]==points_scale[1] 106 | points_2d = (points_2d*0.5 + 0.5)*points_scale[0] 107 | 108 | batch_size, n_points, _ = points.shape 109 | trans_points_2d = torch.bmm( 110 | torch.cat([points_2d, torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype)], dim=-1), 111 | tform 112 | ) 113 | trans_points = torch.cat([trans_points_2d[:,:,:2], points[:,:,2:]], dim=-1) 114 | if normalize: 115 | trans_points[:,:,:2] = trans_points[:,:,:2]/self.crop_size*2 - 1 116 | return trans_points 117 | 118 | def transform_points(points, tform, points_scale=None, out_scale=None): 119 | points_2d = points[:,:,:2] 120 | 121 | #'input points must use original range' 122 | if points_scale: 123 | assert points_scale[0]==points_scale[1] 124 | points_2d = (points_2d*0.5 + 0.5)*points_scale[0] 125 | # import ipdb; ipdb.set_trace() 126 | 127 | batch_size, n_points, _ = points.shape 128 | trans_points_2d = torch.bmm( 129 | torch.cat([points_2d, torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype)], dim=-1), 130 | tform 131 | ) 132 | if out_scale: # h,w of output image size 133 | trans_points_2d[:,:,0] = trans_points_2d[:,:,0]/out_scale[1]*2 - 1 134 | trans_points_2d[:,:,1] = trans_points_2d[:,:,1]/out_scale[0]*2 - 1 135 | trans_points = torch.cat([trans_points_2d[:,:,:2], points[:,:,2:]], dim=-1) 136 | return trans_points -------------------------------------------------------------------------------- /inference_rigface.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils.checkpoint 8 | import torchvision.transforms as transforms 9 | from PIL import Image 10 | from diffusers import AutoencoderKL 11 | from diffusers import ( 12 | UniPCMultistepScheduler, 13 | ) 14 | from transformers import CLIPTextModel, CLIPTokenizer 15 | 16 | from rigface.models.pipelineRigFace import RigFacePipeline as RigFacePipelineInference 17 | from rigface.models.unet_ID_2d_condition import UNetID2DConditionModel 18 | from rigface.models.unet_denoising_2d_condition import UNetDenoise2DConditionModel 19 | 20 | 21 | 22 | def parse_args(input_args=None): 23 | parser = argparse.ArgumentParser(description="Inference script.") 24 | 25 | parser.add_argument( 26 | "--pretrained_model_name_or_path", 27 | type=str, 28 | default='stable-diffusion-v1-5/stable-diffusion-v1-5', 29 | required=False, 30 | help="Path to pretrained model or model identifier from huggingface.co/models.", 31 | ) 32 | 33 | parser.add_argument( 34 | "--revision", 35 | type=str, 36 | default=None, 37 | required=False, 38 | help="Revision of pretrained model identifier from huggingface.co/models.", 39 | ) 40 | parser.add_argument( 41 | "--variant", 42 | type=str, 43 | default=None, 44 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16", 45 | ) 46 | 47 | parser.add_argument("--seed", type=int, default=424, help="A seed for reproducible training.") 48 | 49 | 50 | parser.add_argument( 51 | "--inference_steps", 52 | type=int, 53 | default=50, 54 | ) 55 | 56 | 57 | parser.add_argument( 58 | "--vit_path", 59 | type=str, 60 | default="openai/clip-vit-large-patch14", 61 | ) 62 | 63 | parser.add_argument( 64 | "--vton_unet_path", 65 | type=str, 66 | default='./pre_trained/unet_denoise/checkpoint-70000', 67 | ) 68 | 69 | parser.add_argument( 70 | "--garm_unet_path", 71 | type=str, 72 | default='./pre_trained/unet_id/checkpoint-70000', 73 | ) 74 | parser.add_argument( 75 | "--id_path", 76 | type=str, 77 | default='', 78 | ) 79 | parser.add_argument( 80 | "--bg_path", 81 | type=str, 82 | default='', 83 | ) 84 | parser.add_argument( 85 | "--exp_path", 86 | type=str, 87 | default='', 88 | ) 89 | parser.add_argument( 90 | "--render_path", 91 | type=str, 92 | default='', 93 | ) 94 | parser.add_argument( 95 | "--save_path", 96 | type=str, 97 | default='', 98 | ) 99 | 100 | if input_args is not None: 101 | args = parser.parse_args(input_args) 102 | else: 103 | args = parser.parse_args() 104 | 105 | return args 106 | 107 | 108 | 109 | def make_data(args): 110 | 111 | transform = transforms.ToTensor() 112 | 113 | img_name = args.id_path 114 | bg_name = args.bg_path 115 | render_name = args.render_path 116 | 117 | source = Image.open(img_name) 118 | source = transform(source) 119 | 120 | bg = Image.open(bg_name) 121 | bg = transform(bg) 122 | 123 | render = Image.open(render_name) 124 | # render = render.resize((512, 512)) 125 | render = transform(render) 126 | 127 | return source, bg, render 128 | 129 | 130 | def tokenize_captions(tokenizer, captions, max_length): 131 | 132 | inputs = tokenizer( 133 | captions, 134 | max_length=tokenizer.model_max_length, 135 | padding="max_length", 136 | truncation=True, 137 | return_tensors="pt" 138 | ) 139 | return inputs.input_ids 140 | 141 | 142 | def main(args): 143 | 144 | device = 'cuda' 145 | vton_unet_path = args.vton_unet_path 146 | garm_unet_path = args.garm_unet_path 147 | 148 | vae = AutoencoderKL.from_pretrained( 149 | args.pretrained_model_name_or_path, 150 | subfolder="vae" 151 | ).to(device) 152 | text_encoder = CLIPTextModel.from_pretrained( 153 | args.pretrained_model_name_or_path, 154 | subfolder="text_encoder", 155 | ).to(device) 156 | 157 | tokenizer = CLIPTokenizer.from_pretrained( 158 | args.pretrained_model_name_or_path, 159 | subfolder="tokenizer", 160 | ) 161 | 162 | unet_id = UNetID2DConditionModel.from_pretrained( 163 | garm_unet_path, 164 | # torch_dtype=torch.float16, 165 | use_safetensors=True, 166 | low_cpu_mem_usage=False, 167 | ignore_mismatched_sizes=True 168 | ) 169 | 170 | unet_denoising = UNetDenoise2DConditionModel.from_pretrained( 171 | vton_unet_path, 172 | # torch_dtype=torch.float16, 173 | use_safetensors=True, 174 | low_cpu_mem_usage=False, 175 | ignore_mismatched_sizes=True 176 | ) 177 | 178 | unet_denoising.requires_grad_(False) 179 | unet_id.requires_grad_(False) 180 | vae.requires_grad_(False) 181 | text_encoder.requires_grad_(False) 182 | 183 | weight_dtype = torch.float32 184 | 185 | 186 | pipeline = RigFacePipelineInference.from_pretrained( 187 | args.pretrained_model_name_or_path, 188 | vae=vae, 189 | text_encoder=text_encoder, 190 | tokenizer=tokenizer, 191 | unet_id=unet_id, 192 | unet_denoising=unet_denoising, 193 | safety_checker=None, 194 | revision=args.revision, 195 | variant=args.variant, 196 | torch_dtype=weight_dtype, 197 | ).to(device) 
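A hedged invocation sketch for this script, using the bundled utils/test_images/id1 files; the mapping of those files to the flags below is an assumption, not something the repo documents.

```python
# Sketch: drive inference_rigface.py programmatically with the id1 test images.
# The file-to-flag mapping is an assumption; point the paths at your own data as needed.
from inference_rigface import parse_args, main

args = parse_args([
    '--id_path', 'utils/test_images/id1/sor.png',
    '--bg_path', 'utils/test_images/id1/bg_pose+exp.png',
    '--exp_path', 'utils/test_images/id1/exp_pose+exp.npy',
    '--render_path', 'utils/test_images/id1/render_exp+pose.png',
    '--save_path', 'results',
])
main(args)
```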
198 | 199 | 200 | pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) 201 | pipeline.set_progress_bar_config(disable=True) 202 | 203 | if args.seed is None: 204 | generator = None 205 | else: 206 | generator = torch.Generator(device=device).manual_seed(args.seed) 207 | 208 | source, bg, rend = make_data(args) 209 | prompt = 'A close up of a person.' 210 | source = source.unsqueeze(0) 211 | bg = bg.unsqueeze(0) 212 | rend = rend.unsqueeze(0) 213 | 214 | prompt_embeds = text_encoder(tokenize_captions(tokenizer, [prompt], 2).to(device))[0] 215 | 216 | 217 | exp = np.load(args.exp_path) 218 | 219 | os.makedirs(args.save_path, exist_ok=True) 220 | tor_exp = torch.from_numpy(exp).unsqueeze(0) 221 | 222 | samples = pipeline( 223 | prompt_embeds=prompt_embeds, 224 | source=source, 225 | bg=bg, 226 | render=rend, 227 | exp=tor_exp, 228 | num_inference_steps=args.inference_steps, 229 | generator=generator, 230 | ).images[0] 231 | samples.save(os.path.join(args.save_path, f'out.png')) 232 | 233 | if __name__ == "__main__": 234 | args = parse_args() 235 | 236 | main(args) 237 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.3.0 2 | bitsandbytes==0.45.1 3 | diffusers==0.22.0 4 | einops==0.8.0 5 | face_alignment==1.4.1 6 | facenet_pytorch==2.6.0 7 | huggingface_hub==0.28.1 8 | insightface==0.7.3 9 | ipdb==0.13.13 10 | kornia==0.8.0 11 | loguru==0.7.3 12 | numpy==1.23.0 13 | omegaconf==2.3.0 14 | opencv_python==4.11.0.86 15 | opencv_python_headless==4.11.0.86 16 | packaging==24.2 17 | pandas==2.2.3 18 | Pillow==11.1.0 19 | preprocess==2.0.0 20 | PyYAML==6.0.2 21 | PyYAML==6.0.2 22 | safetensors==0.5.2 23 | scipy==1.15.1 24 | setuptools==75.8.0 25 | skimage==0.0 26 | torchfile==0.1.0 27 | tqdm==4.67.1 28 | transformers==4.42.4 29 | wandb==0.19.6 30 | xformers==0.0.29.post2 31 | yacs==0.1.8 32 | -------------------------------------------------------------------------------- /utils/compute_renders.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | import argparse 6 | import torch as th 7 | from torchvision.utils import save_image 8 | 9 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 10 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "."))) 11 | 12 | 13 | from decalib.deca import DECA 14 | from decalib.utils.config import cfg as deca_cfg 15 | from data_utils import get_image_dict 16 | 17 | # Build DECA 18 | deca_cfg.model.use_tex = True 19 | deca_cfg.model.tex_path = "./data/FLAME_texture.npz" 20 | deca_cfg.model.tex_type = "FLAME" 21 | deca = DECA(config=deca_cfg, device="cuda") 22 | 23 | 24 | 25 | def get_render(source, target, modes): 26 | src_dict = get_image_dict(source, 512, True) 27 | tar_dict = get_image_dict(target, 512, True) 28 | # ===================get DECA codes of the target image=============================== 29 | tar_cropped = tar_dict["image"].unsqueeze(0).to("cuda") 30 | imgname = tar_dict["imagename"] 31 | with th.no_grad(): 32 | tar_code = deca.encode(tar_cropped) 33 | tar_image = tar_dict["original_image"].unsqueeze(0).to("cuda") 34 | # ===================get DECA codes of the source image=============================== 35 | src_cropped = src_dict["image"].unsqueeze(0).to("cuda") 36 | with th.no_grad(): 37 | src_code = deca.encode(src_cropped) 38 | # To align the 
face when the pose is changing 39 | src_ffhq_center = deca.decode(src_code, return_ffhq_center=True) 40 | tar_ffhq_center = deca.decode(tar_code, return_ffhq_center=True) 41 | 42 | src_tform = src_dict["tform"].unsqueeze(0) 43 | src_tform = th.inverse(src_tform).transpose(1, 2).to("cuda") 44 | src_code["tform"] = src_tform 45 | 46 | tar_tform = tar_dict["tform"].unsqueeze(0) 47 | tar_tform = th.inverse(tar_tform).transpose(1, 2).to("cuda") 48 | tar_code["tform"] = tar_tform 49 | 50 | src_image = src_dict["original_image"].unsqueeze(0).to("cuda") # averaged parameters 51 | tar_image = tar_dict["original_image"].unsqueeze(0).to("cuda") 52 | 53 | # code 1 means source code, code 2 means target code 54 | code1, code2 = {}, {} 55 | for k in src_code: 56 | code1[k] = src_code[k].clone() 57 | 58 | for k in tar_code: 59 | code2[k] = tar_code[k].clone() 60 | 61 | mode_list = modes.split("+") 62 | # presumably: whenever pose is among the modes, the target code becomes the base for the conversion 63 | if 'pose' in mode_list: 64 | if 'exp' not in mode_list: 65 | code2['exp'] = src_code['exp'] 66 | code2['pose'][:, 3:] = src_code['pose'][:, 3:] 67 | if 'light' not in mode_list: 68 | code2['light'] = src_code['light'] 69 | opdict, _ = deca.decode( 70 | code2, 71 | render_orig=True, 72 | original_image=tar_image, 73 | tform=tar_code["tform"], 74 | align_ffhq=True, 75 | ffhq_center=tar_ffhq_center, 76 | ) 77 | else: 78 | if 'exp' in mode_list: 79 | code1['exp'] = tar_code['exp'] 80 | code1['pose'][:, 3:] = tar_code['pose'][:, 3:] 81 | if 'light' not in mode_list: 82 | code1['light'] = tar_code['light'] 83 | opdict, _ = deca.decode( 84 | code1, 85 | render_orig=True, 86 | original_image=src_image, 87 | tform=src_code["tform"], 88 | align_ffhq=True, 89 | ffhq_center=src_ffhq_center, 90 | ) 91 | 92 | rendered = opdict["rendered_images"].detach() 93 | os.makedirs('results', exist_ok=True) 94 | save_image(rendered[0], f"./results/render_{modes}.png") 95 | 96 | 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument( 101 | "--sor_path", 102 | type=str, 103 | default='', 104 | required=False 105 | ) 106 | parser.add_argument( 107 | "--tar_path", 108 | type=str, 109 | default='', 110 | required=False 111 | ) 112 | parser.add_argument( 113 | "--modes", 114 | type=str, 115 | default='', 116 | required=False 117 | ) 118 | 119 | args = parser.parse_args() 120 | get_render(args.sor_path, args.tar_path, args.modes) 121 | print('done') -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import os 5 | 6 | import numpy as np 7 | import scipy 8 | import scipy.io 9 | import torch 10 | from skimage.io import imread 11 | from skimage.transform import estimate_transform, warp, resize 12 | 13 | from decalib.datasets import detectors 14 | 15 | face_detector = detectors.FAN() 16 | scale = 1.3 17 | resolution_inp = 224 18 | 19 | 20 | 21 | def bbox2point(left, right, top, bottom, type='bbox'): 22 | 23 | if type =='kpt68': 24 | old_size = (right - left + bottom - top) / 2 * 1.1 25 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ]) 26 | elif type =='bbox': 27 | old_size = (right - left + bottom - top ) /2 28 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size *0.12]) 29 | else: 30 | raise NotImplementedError 31 | return old_size, center 32 | 33 | 34 | 35 | 36 | 37 | def get_image_dict(img_path, size, iscrop): 38 | img_name = img_path.split('/')[-1] 39 |
im = imread(img_path) 40 | if size is not None: # size = 256 41 | im = (resize(im, (size, size), anti_aliasing=True) * 255.).astype(np.uint8) 42 | # (256, 256, 3) 43 | image = np.array(im) 44 | if len(image.shape) == 2: 45 | image = image[:, :, None].repeat(1, 1, 3) 46 | if len(image.shape) == 3 and image.shape[2] > 3: 47 | image = image[:, :, :3] 48 | 49 | h, w, _ = image.shape 50 | if iscrop: # true 51 | # provide kpt as txt file, or mat file (for AFLW2000) 52 | kpt_matpath = os.path.splitext(img_path)[0] + '.mat' 53 | kpt_txtpath = os.path.splitext(img_path)[0] + '.txt' 54 | if os.path.exists(kpt_matpath): 55 | kpt = scipy.io.loadmat(kpt_matpath)['pt3d_68'].T 56 | left = np.min(kpt[:, 0]) 57 | right = np.max(kpt[:, 0]) 58 | top = np.min(kpt[:, 1]) 59 | bottom = np.max(kpt[:, 1]) 60 | old_size, center = bbox2point(left, right, top, bottom, type='kpt68') 61 | elif os.path.exists(kpt_txtpath): 62 | kpt = np.loadtxt(kpt_txtpath) 63 | left = np.min(kpt[:, 0]) 64 | right = np.max(kpt[:, 0]) 65 | top = np.min(kpt[:, 1]) 66 | bottom = np.max(kpt[:, 1]) 67 | old_size, center = bbox2point(left, right, top, bottom, type='kpt68') 68 | else: 69 | bbox, bbox_type = face_detector.run(image) 70 | if len(bbox) < 4: 71 | print('no face detected! run original image') 72 | left = 0 73 | right = h - 1 74 | top = 0 75 | bottom = w - 1 76 | else: 77 | left = bbox[0] 78 | right = bbox[2] 79 | top = bbox[1] 80 | bottom = bbox[3] 81 | old_size, center = bbox2point(left, right, top, bottom, type=bbox_type) 82 | size = int(old_size * scale) 83 | src_pts = np.array([[center[0] - size / 2, center[1] - size / 2], [center[0] - size / 2, center[1] + size / 2], 84 | [center[0] + size / 2, center[1] - size / 2]]) 85 | else: 86 | src_pts = np.array([[0, 0], [0, h - 1], [w - 1, 0]]) 87 | # DST_PTS = np.array([[0, 0], [0, h-1], [w-1, 0]]) 88 | DST_PTS = np.array([[0, 0], [0, resolution_inp - 1], [resolution_inp - 1, 0]]) 89 | tform = estimate_transform('similarity', src_pts, DST_PTS) 90 | 91 | image = image / 255. 
92 | 93 | dst_image = warp(image, tform.inverse, output_shape=(resolution_inp, resolution_inp)) 94 | dst_image = dst_image.transpose(2, 0, 1) 95 | return {'image': torch.tensor(dst_image).float(), 96 | 'imagename': img_name, 97 | 'tform': torch.tensor(tform.params).float(), 98 | 'original_image': torch.tensor(image.transpose(2, 0, 1)).float(), 99 | } 100 | -------------------------------------------------------------------------------- /utils/datasets_faceswap.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import cv2 4 | from PIL import ImageFile 5 | 6 | 7 | ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | 9 | mean_face_lm5p_256 = np.array([ 10 | [(30.2946+8)*2+16, 51.6963*2], # left eye pupil 11 | [(65.5318+8)*2+16, 51.5014*2], # right eye pupil 12 | [(48.0252+8)*2+16, 71.7366*2], # nose tip 13 | [(33.5493+8)*2+16, 92.3655*2], # left mouth corner 14 | [(62.7299+8)*2+16, 92.2041*2], # right mouth corner 15 | ], dtype=np.float32) 16 | 17 | 18 | 19 | mean_box_lm4p_512 = np.array([ 20 | [80, 80], 21 | [80, 432], 22 | [432, 432], 23 | [432, 80], 24 | ], dtype=np.float32) 25 | 26 | 27 | 28 | def get_box_lm4p(pts): 29 | x1 = np.min(pts[:,0]) 30 | x2 = np.max(pts[:,0]) 31 | y1 = np.min(pts[:,1]) 32 | y2 = np.max(pts[:,1]) 33 | 34 | x_center = (x1+x2)*0.5 35 | y_center = (y1+y2)*0.5 36 | box_size = max(x2-x1, y2-y1) 37 | 38 | x1 = x_center-0.5*box_size 39 | x2 = x_center+0.5*box_size 40 | y1 = y_center-0.5*box_size 41 | y2 = y_center+0.5*box_size 42 | 43 | return np.array([[x1, y1], [x1, y2], [x2, y2], [x2, y1]], dtype=np.float32) 44 | 45 | 46 | def get_affine_transform(target_face_lm5p, mean_lm5p): 47 | mat_warp = np.zeros((2,3)) 48 | A = np.zeros((4,4)) 49 | B = np.zeros((4)) 50 | for i in range(5): 51 | #sa[0][0] += a[i].x*a[i].x + a[i].y*a[i].y; 52 | A[0][0] += target_face_lm5p[i][0] * target_face_lm5p[i][0] + target_face_lm5p[i][1] * target_face_lm5p[i][1] 53 | #sa[0][2] += a[i].x; 54 | A[0][2] += target_face_lm5p[i][0] 55 | #sa[0][3] += a[i].y; 56 | A[0][3] += target_face_lm5p[i][1] 57 | 58 | #sb[0] += a[i].x*b[i].x + a[i].y*b[i].y; 59 | B[0] += target_face_lm5p[i][0] * mean_lm5p[i][0] + target_face_lm5p[i][1] * mean_lm5p[i][1] 60 | #sb[1] += a[i].x*b[i].y - a[i].y*b[i].x; 61 | B[1] += target_face_lm5p[i][0] * mean_lm5p[i][1] - target_face_lm5p[i][1] * mean_lm5p[i][0] 62 | #sb[2] += b[i].x; 63 | B[2] += mean_lm5p[i][0] 64 | #sb[3] += b[i].y; 65 | B[3] += mean_lm5p[i][1] 66 | 67 | #sa[1][1] = sa[0][0]; 68 | A[1][1] = A[0][0] 69 | #sa[2][1] = sa[1][2] = -sa[0][3]; 70 | A[2][1] = A[1][2] = -A[0][3] 71 | #sa[3][1] = sa[1][3] = sa[2][0] = sa[0][2]; 72 | A[3][1] = A[1][3] = A[2][0] = A[0][2] 73 | #sa[2][2] = sa[3][3] = count; 74 | A[2][2] = A[3][3] = 5 75 | #sa[3][0] = sa[0][3]; 76 | A[3][0] = A[0][3] 77 | 78 | _, mat23 = cv2.solve(A, B, flags=cv2.DECOMP_SVD) 79 | mat_warp[0][0] = mat23[0] 80 | mat_warp[1][1] = mat23[0] 81 | mat_warp[0][1] = -mat23[1] 82 | mat_warp[1][0] = mat23[1] 83 | mat_warp[0][2] = mat23[2] 84 | mat_warp[1][2] = mat23[3] 85 | 86 | return mat_warp 87 | 88 | 89 | 90 | 91 | def transformation_from_points(points1, points2): 92 | points1 = np.float64(np.matrix([[point[0], point[1]] for point in points1])) 93 | points2 = np.float64(np.matrix([[point[0], point[1]] for point in points2])) 94 | 95 | points1 = points1.astype(np.float64) 96 | points2 = points2.astype(np.float64) 97 | c1 = np.mean(points1, axis=0) 98 | c2 = np.mean(points2, axis=0) 99 | points1 -= c1 100 | points2 -= c2 101 | s1 = np.std(points1) 102 | s2 = 
np.std(points2) 103 | points1 /= s1 104 | points2 /= s2 105 | #points2 = np.array(points2) 106 | #write_pts('pt2.txt', points2) 107 | U, S, Vt = np.linalg.svd(points1.T * points2) 108 | R = (U * Vt).T 109 | return np.array(np.vstack([np.hstack(((s2 / s1) * R,c2.T - (s2 / s1) * R * c1.T)),np.matrix([0., 0., 1.])])[:2]) 110 | 111 | -------------------------------------------------------------------------------- /utils/make_bgs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms as transforms 8 | from PIL import Image 9 | from insightface.app import FaceAnalysis 10 | from torchvision.utils import save_image 11 | 12 | from model import BiSeNet 13 | 14 | device = 'cuda' 15 | checkpoint = './checkpoints' 16 | app = FaceAnalysis(name='antelopev2', root=os.path.join('./', 'third_party_files'), 17 | providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) 18 | app.prepare(ctx_id=0, det_size=(640, 640)) 19 | 20 | 21 | n_classes = 19 22 | net = BiSeNet(n_classes=n_classes) 23 | net.cuda() 24 | model_pth = './third_party_files/79999_iter.pth' 25 | net.load_state_dict(torch.load(model_pth)) 26 | net.eval() 27 | 28 | 29 | 30 | 31 | def keep_background(im, parsing_anno, stride): 32 | # Colors for all 20 parts 33 | part_colors = [[0, 0, 0], [0, 0, 0], [0, 0, 0], 34 | [0, 0, 0], [0, 0, 0], 35 | [0, 0, 0], [0, 0, 0], [0, 0, 0], 36 | [0, 0, 0], [0, 0, 0], 37 | [0, 0, 0], [0, 0, 0], [0, 0, 0], 38 | [0, 0, 0], [0, 0, 0], 39 | [0, 0, 0], [0, 0, 0], [0, 0, 0], 40 | [0, 0, 0], [0, 0, 0], [0, 0, 0], 41 | [0, 0, 0], [0, 0, 0], [0, 0, 0]] 42 | 43 | im = np.array(im) 44 | vis_im = im.copy().astype(np.uint8) 45 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8) # [1, 19, 512, 512] 46 | 47 | vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) 48 | vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 49 | 50 | num_of_class = np.max(vis_parsing_anno) 51 | 52 | 53 | for pi in range(1, num_of_class + 1): 54 | # if pi == 8 or pi == 9 or pi == 14 or pi == 17 or pi == 18: 55 | # continue 56 | index = np.where(vis_parsing_anno == pi) 57 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi] 58 | 59 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) 60 | tmp = vis_parsing_anno_color / 255 61 | mask_channel = 1 - tmp[:, :, 0] 62 | deep_gray = np.full_like(vis_im, (0, 0, 0), dtype=np.uint8) 63 | result_image = np.where(mask_channel[:, :, np.newaxis] == 1, 255, deep_gray) # [0-255] 64 | return result_image 65 | 66 | 67 | def deal_with_one_image(sorpth, tgtpth, modes): 68 | 69 | trans = transforms.ToTensor() 70 | 71 | to_tensor = transforms.Compose([ 72 | transforms.Resize(512), 73 | transforms.ToTensor(), 74 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 75 | ]) 76 | 77 | with torch.no_grad(): 78 | tgtimg, sorimg = Image.open(tgtpth), Image.open(sorpth) 79 | tgtimage, sorimage = tgtimg.resize((512, 512), Image.BILINEAR), sorimg.resize((512, 512), 80 | Image.BILINEAR) 81 | # image = img 82 | tgtimg, sorimg = to_tensor(tgtimage), to_tensor(sorimage) 83 | tgtimg, sorimg = torch.unsqueeze(tgtimg, 0), torch.unsqueeze(sorimg, 0) 84 | tgtimg, sorimg = tgtimg.cuda(), sorimg.cuda() 85 | tgtout, sorout = net(tgtimg)[0], net(sorimg)[0] # [1, 19, 512, 512] 86 | tgtparsing, sorparsing = 
tgtout.squeeze(0).cpu().numpy().argmax(0), sorout.squeeze( 87 | 0).cpu().numpy().argmax(0) 88 | 89 | tgtbg, sorbg = keep_background(tgtimage, tgtparsing, stride=1), keep_background(sorimage, sorparsing, 90 | stride=1) 91 | tgtbg, sorbg = cv2.cvtColor(tgtbg, cv2.COLOR_RGB2BGR), cv2.cvtColor(sorbg, cv2.COLOR_RGB2BGR) 92 | mode_list = modes.split('+') 93 | 94 | if 'pose' in mode_list: 95 | logical_or = np.bitwise_or(tgtbg, sorbg) 96 | elif 'light' in mode_list or 'exp' in mode_list: 97 | logical_or = sorbg 98 | else: 99 | raise ValueError(f'Unknown mode: {modes}') 100 | 101 | tmp = logical_or / 255 102 | mask_channel = 1 - tmp[:, :, 0] 103 | tmp_sor = cv2.imread(sorpth) 104 | deep_gray = np.full_like(tmp_sor, (127.5, 127.5, 127.5), dtype=np.uint8) 105 | 106 | im_cv = cv2.cvtColor(tmp_sor, cv2.COLOR_RGB2BGR) 107 | result_image = np.where(mask_channel[:, :, np.newaxis] == 1, im_cv, deep_gray) 108 | result_image = trans(result_image) 109 | # # res = bg + im_pts70 110 | os.makedirs('results', exist_ok=True) 111 | save_image(result_image, f"./results/bg_{modes}.png") 112 | 113 | 114 | 115 | 116 | if __name__ == '__main__': 117 | parser = argparse.ArgumentParser() 118 | parser.add_argument( 119 | "--sor_path", 120 | type=str, 121 | default='', 122 | required=False 123 | ) 124 | parser.add_argument( 125 | "--tar_path", 126 | type=str, 127 | default='', 128 | required=False 129 | ) 130 | parser.add_argument( 131 | "--modes", 132 | type=str, 133 | default='', 134 | required=False 135 | ) 136 | 137 | args = parser.parse_args() 138 | deal_with_one_image(args.sor_path, args.tar_path, args.modes) -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | 10 | from resnet import Resnet18 11 | # from modules.bn import InPlaceABNSync as BatchNorm2d 12 | 13 | 14 | class ConvBNReLU(nn.Module): 15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 16 | super(ConvBNReLU, self).__init__() 17 | self.conv = nn.Conv2d(in_chan, 18 | out_chan, 19 | kernel_size = ks, 20 | stride = stride, 21 | padding = padding, 22 | bias = False) 23 | self.bn = nn.BatchNorm2d(out_chan) 24 | self.init_weight() 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = F.relu(self.bn(x)) 29 | return x 30 | 31 | def init_weight(self): 32 | for ly in self.children(): 33 | if isinstance(ly, nn.Conv2d): 34 | nn.init.kaiming_normal_(ly.weight, a=1) 35 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 36 | 37 | class BiSeNetOutput(nn.Module): 38 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): 39 | super(BiSeNetOutput, self).__init__() 40 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 41 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) 42 | self.init_weight() 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | x = self.conv_out(x) 47 | return x 48 | 49 | def init_weight(self): 50 | for ly in self.children(): 51 | if isinstance(ly, nn.Conv2d): 52 | nn.init.kaiming_normal_(ly.weight, a=1) 53 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 54 | 55 | def get_params(self): 56 | wd_params, nowd_params = [], [] 57 | for name, module in self.named_modules(): 58 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 59 | 
wd_params.append(module.weight) 60 | if not module.bias is None: 61 | nowd_params.append(module.bias) 62 | elif isinstance(module, nn.BatchNorm2d): 63 | nowd_params += list(module.parameters()) 64 | return wd_params, nowd_params 65 | 66 | 67 | class AttentionRefinementModule(nn.Module): 68 | def __init__(self, in_chan, out_chan, *args, **kwargs): 69 | super(AttentionRefinementModule, self).__init__() 70 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 71 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) 72 | self.bn_atten = nn.BatchNorm2d(out_chan) 73 | self.sigmoid_atten = nn.Sigmoid() 74 | self.init_weight() 75 | 76 | def forward(self, x): 77 | feat = self.conv(x) 78 | atten = F.avg_pool2d(feat, feat.size()[2:]) 79 | atten = self.conv_atten(atten) 80 | atten = self.bn_atten(atten) 81 | atten = self.sigmoid_atten(atten) 82 | out = torch.mul(feat, atten) 83 | return out 84 | 85 | def init_weight(self): 86 | for ly in self.children(): 87 | if isinstance(ly, nn.Conv2d): 88 | nn.init.kaiming_normal_(ly.weight, a=1) 89 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 90 | 91 | 92 | class ContextPath(nn.Module): 93 | def __init__(self, *args, **kwargs): 94 | super(ContextPath, self).__init__() 95 | self.resnet = Resnet18() 96 | self.arm16 = AttentionRefinementModule(256, 128) 97 | self.arm32 = AttentionRefinementModule(512, 128) 98 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 99 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 100 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 101 | 102 | self.init_weight() 103 | 104 | def forward(self, x): 105 | H0, W0 = x.size()[2:] 106 | feat8, feat16, feat32 = self.resnet(x) 107 | H8, W8 = feat8.size()[2:] 108 | H16, W16 = feat16.size()[2:] 109 | H32, W32 = feat32.size()[2:] 110 | 111 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 112 | avg = self.conv_avg(avg) 113 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest') 114 | 115 | feat32_arm = self.arm32(feat32) 116 | feat32_sum = feat32_arm + avg_up 117 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') 118 | feat32_up = self.conv_head32(feat32_up) 119 | 120 | feat16_arm = self.arm16(feat16) 121 | feat16_sum = feat16_arm + feat32_up 122 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') 123 | feat16_up = self.conv_head16(feat16_up) 124 | 125 | return feat8, feat16_up, feat32_up # x8, x8, x16 126 | 127 | def init_weight(self): 128 | for ly in self.children(): 129 | if isinstance(ly, nn.Conv2d): 130 | nn.init.kaiming_normal_(ly.weight, a=1) 131 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 132 | 133 | def get_params(self): 134 | wd_params, nowd_params = [], [] 135 | for name, module in self.named_modules(): 136 | if isinstance(module, (nn.Linear, nn.Conv2d)): 137 | wd_params.append(module.weight) 138 | if not module.bias is None: 139 | nowd_params.append(module.bias) 140 | elif isinstance(module, nn.BatchNorm2d): 141 | nowd_params += list(module.parameters()) 142 | return wd_params, nowd_params 143 | 144 | 145 | ### This is not used, since I replace this with the resnet feature with the same size 146 | class SpatialPath(nn.Module): 147 | def __init__(self, *args, **kwargs): 148 | super(SpatialPath, self).__init__() 149 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) 150 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 151 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 152 | self.conv_out = ConvBNReLU(64, 
128, ks=1, stride=1, padding=0) 153 | self.init_weight() 154 | 155 | def forward(self, x): 156 | feat = self.conv1(x) 157 | feat = self.conv2(feat) 158 | feat = self.conv3(feat) 159 | feat = self.conv_out(feat) 160 | return feat 161 | 162 | def init_weight(self): 163 | for ly in self.children(): 164 | if isinstance(ly, nn.Conv2d): 165 | nn.init.kaiming_normal_(ly.weight, a=1) 166 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 167 | 168 | def get_params(self): 169 | wd_params, nowd_params = [], [] 170 | for name, module in self.named_modules(): 171 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 172 | wd_params.append(module.weight) 173 | if not module.bias is None: 174 | nowd_params.append(module.bias) 175 | elif isinstance(module, nn.BatchNorm2d): 176 | nowd_params += list(module.parameters()) 177 | return wd_params, nowd_params 178 | 179 | 180 | class FeatureFusionModule(nn.Module): 181 | def __init__(self, in_chan, out_chan, *args, **kwargs): 182 | super(FeatureFusionModule, self).__init__() 183 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 184 | self.conv1 = nn.Conv2d(out_chan, 185 | out_chan//4, 186 | kernel_size = 1, 187 | stride = 1, 188 | padding = 0, 189 | bias = False) 190 | self.conv2 = nn.Conv2d(out_chan//4, 191 | out_chan, 192 | kernel_size = 1, 193 | stride = 1, 194 | padding = 0, 195 | bias = False) 196 | self.relu = nn.ReLU(inplace=True) 197 | self.sigmoid = nn.Sigmoid() 198 | self.init_weight() 199 | 200 | def forward(self, fsp, fcp): 201 | fcat = torch.cat([fsp, fcp], dim=1) 202 | feat = self.convblk(fcat) 203 | atten = F.avg_pool2d(feat, feat.size()[2:]) 204 | atten = self.conv1(atten) 205 | atten = self.relu(atten) 206 | atten = self.conv2(atten) 207 | atten = self.sigmoid(atten) 208 | feat_atten = torch.mul(feat, atten) 209 | feat_out = feat_atten + feat 210 | return feat_out 211 | 212 | def init_weight(self): 213 | for ly in self.children(): 214 | if isinstance(ly, nn.Conv2d): 215 | nn.init.kaiming_normal_(ly.weight, a=1) 216 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 217 | 218 | def get_params(self): 219 | wd_params, nowd_params = [], [] 220 | for name, module in self.named_modules(): 221 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 222 | wd_params.append(module.weight) 223 | if not module.bias is None: 224 | nowd_params.append(module.bias) 225 | elif isinstance(module, nn.BatchNorm2d): 226 | nowd_params += list(module.parameters()) 227 | return wd_params, nowd_params 228 | 229 | 230 | class BiSeNet(nn.Module): 231 | def __init__(self, n_classes, *args, **kwargs): 232 | super(BiSeNet, self).__init__() 233 | self.cp = ContextPath() 234 | ## here self.sp is deleted 235 | self.ffm = FeatureFusionModule(256, 256) 236 | self.conv_out = BiSeNetOutput(256, 256, n_classes) 237 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes) 238 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes) 239 | self.init_weight() 240 | 241 | def forward(self, x): 242 | H, W = x.size()[2:] 243 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature 244 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature 245 | feat_fuse = self.ffm(feat_sp, feat_cp8) 246 | 247 | feat_out = self.conv_out(feat_fuse) 248 | feat_out16 = self.conv_out16(feat_cp8) 249 | feat_out32 = self.conv_out32(feat_cp16) 250 | 251 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) 252 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', 
align_corners=True) 253 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) 254 | return feat_out, feat_out16, feat_out32 255 | 256 | def init_weight(self): 257 | for ly in self.children(): 258 | if isinstance(ly, nn.Conv2d): 259 | nn.init.kaiming_normal_(ly.weight, a=1) 260 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 261 | 262 | def get_params(self): 263 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] 264 | for name, child in self.named_children(): 265 | child_wd_params, child_nowd_params = child.get_params() 266 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): 267 | lr_mul_wd_params += child_wd_params 268 | lr_mul_nowd_params += child_nowd_params 269 | else: 270 | wd_params += child_wd_params 271 | nowd_params += child_nowd_params 272 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params 273 | 274 | 275 | if __name__ == "__main__": 276 | net = BiSeNet(19) 277 | net.cuda() 278 | net.eval() 279 | in_ten = torch.randn(16, 3, 640, 480).cuda() 280 | out, out16, out32 = net(in_ten) 281 | print(out.shape) 282 | 283 | net.get_params() 284 | -------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import cv2 4 | import numpy as np 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | from insightface.app import FaceAnalysis 8 | from torchvision.utils import save_image 9 | 10 | import datasets_faceswap as datasets_faceswap 11 | 12 | pil2tensor = transforms.Compose([transforms.ToTensor(), transforms.Resize(512)]) 13 | 14 | pil2tensor = transforms.ToTensor() 15 | 16 | app = FaceAnalysis(name='antelopev2', root=os.path.join('./', 'third_party_files'), 17 | providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) 18 | app.prepare(ctx_id=0, det_size=(640, 640)) 19 | 20 | 21 | def get_bbox(dets, crop_ratio): 22 | if crop_ratio > 0: 23 | bbox = dets[0:4] 24 | bbox_size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) # longer side of the detection box 25 | bbox_x = 0.5 * (bbox[2] + bbox[0]) 26 | bbox_y = 0.5 * (bbox[3] + bbox[1]) 27 | x1 = bbox_x - bbox_size * crop_ratio 28 | x2 = bbox_x + bbox_size * crop_ratio 29 | y1 = bbox_y - bbox_size * crop_ratio 30 | y2 = bbox_y + bbox_size * crop_ratio 31 | bbox_pts4 = np.array([[x1, y1], [x1, y2], [x2, y2], [x2, y1]], dtype=np.float32) 32 | else: 33 | # original box 34 | bbox = dets[0:4].reshape((2, 2)) 35 | bbox_pts4 = datasets_faceswap.get_box_lm4p(bbox) 36 | return bbox_pts4 37 | 38 | 39 | 40 | def crop_one_image(args): 41 | cur_img_sor_path = args.img_path 42 | im_pil_sor = Image.open(cur_img_sor_path).convert("RGB") 43 | face_info_sor = app.get(cv2.cvtColor(np.array(im_pil_sor), cv2.COLOR_RGB2BGR)) 44 | assert len(face_info_sor) >= 1, 'The input image must contain a face!'
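# The 4-point box from get_bbox is aligned to mean_box_lm4p_512 by a least-squares similarity transform
# (datasets_faceswap.transformation_from_points), and cv2.warpAffine then produces the 512x512 crop used
# by the rest of the pipeline; when several faces are detected, only the largest one is kept.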
45 | if len(face_info_sor) > 1: 46 | print('The input image contains more than one face; only the largest face will be used') 47 | face_info_sor = \ 48 | sorted(face_info_sor, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1] 49 | dets_sor = face_info_sor['bbox'] 50 | bbox_pst_sor = get_bbox(dets_sor, crop_ratio=0.75) 51 | 52 | warp_mat_crop_sor = datasets_faceswap.transformation_from_points(bbox_pst_sor, 53 | datasets_faceswap.mean_box_lm4p_512) 54 | im_crop512_sor = cv2.warpAffine(np.array(im_pil_sor), warp_mat_crop_sor, (512, 512), flags=cv2.INTER_LINEAR) 55 | 56 | im_pil_sor = Image.fromarray(im_crop512_sor) 57 | im_pil_sor = pil2tensor(im_pil_sor) 58 | save_image(im_pil_sor, args.save_path) 59 | 60 | 61 | 62 | if __name__ == '__main__': 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument( 65 | "--img_path", 66 | type=str, 67 | default='', 68 | required=False 69 | ) 70 | parser.add_argument( 71 | "--save_path", 72 | type=str, 73 | default='', 74 | required=False 75 | ) 76 | args = parser.parse_args() 77 | crop_one_image(args) -------------------------------------------------------------------------------- /utils/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | # from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
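# layer2-layer4 each halve the spatial resolution, so forward() returns 1/8, 1/16 and 1/32 feature maps
# for the BiSeNet ContextPath; init_weight() below downloads the official ResNet-18 checkpoint via
# torch.utils.model_zoo and copies every pretrained weight except the fc layer.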
69 | self.init_weight() 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | state_dict = modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for name, module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if not module.bias is None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, nn.BatchNorm2d): 98 | nowd_params += list(module.parameters()) 99 | return wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() 110 | -------------------------------------------------------------------------------- /utils/save_exp_coeffs.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import os 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as transforms 9 | from PIL import Image 10 | from insightface.app import FaceAnalysis 11 | 12 | import datasets_faceswap as datasets_faceswap 13 | import third_party.d3dfr.bfm as bfm 14 | import third_party.model_resnet_d3dfr as model_resnet_d3dfr 15 | 16 | device = 'cpu' 17 | checkpoint = './checkpoints' 18 | 19 | app = FaceAnalysis(name='antelopev2', root=os.path.join('./', 'third_party_files'), 20 | providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) 21 | 22 | app.prepare(ctx_id=0, det_size=(640, 640)) 23 | 24 | pil2tensor = transforms.Compose([ 25 | transforms.Resize((256, 256)), 26 | transforms.ToTensor(), 27 | transforms.Normalize(mean=0.5, std=0.5)]) 28 | 29 | 30 | net_d3dfr = model_resnet_d3dfr.getd3dfr_res50(os.path.join(checkpoint, 'third_party/d3dfr_res50_nofc.pth')).eval().to(device) 31 | 32 | bfm_facemodel = bfm.BFM(focal=1015*256/224, image_size=256, 33 | bfm_model_path=os.path.join(checkpoint, 'third_party/BFM_model_front.mat')).to(device) 34 | 35 | 36 | def get_landmarks(image): 37 | face_info = app.get(cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)) 38 | if len(face_info) == 0: 39 | return 'error' 40 | face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1] 41 | pts5 = face_info['kps'] 42 | 43 | warp_mat = datasets_faceswap.get_affine_transform(pts5, datasets_faceswap.mean_face_lm5p_256) 44 | drive_im_crop256 = cv2.warpAffine(np.array(image), warp_mat, (256, 256), flags=cv2.INTER_LINEAR) 45 | 46 | drive_im_crop256_pil = Image.fromarray(drive_im_crop256) 47 | image_tar_crop256 = pil2tensor(drive_im_crop256_pil).view(1, 3, 256, 256) 48 | 49 | gt_d3d_coeff = net_d3dfr(image_tar_crop256) 50 | # _, ex_coeff = bfm_facemodel.get_lm68(gt_d3d_coeff) 51 | id_coeff, ex_coeff, tex_coeff, angles, gamma, translation = bfm_facemodel.split_coeff_orderly(gt_d3d_coeff) 52 | 53 | return (ex_coeff, angles, gamma, translation) 54 | 55 | 56 | 57 | def main(sorpth, tarpth, modes): 58 | 59 | mode_list = modes.split('+') 60 |
if 'exp' in mode_list: 61 | dstpth = tarpth 62 | elif 'light' in mode_list or 'pose' in mode_list: 63 | dstpth = sorpth 64 | else: 65 | raise ValueError('Unrecognized mode') 66 | with torch.no_grad(): 67 | img = Image.open(dstpth) 68 | res = get_landmarks(img) 69 | if isinstance(res, str): 70 | print('cannot find face on ', dstpth) 71 | return 72 | ex_coeff, angles, gamma, translation = res 73 | os.makedirs('results', exist_ok=True) 74 | np.save(f"./results/exp_{modes}.npy", ex_coeff[0]) 75 | 76 | 77 | 78 | if __name__ == '__main__': 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument( 81 | "--sor_path", 82 | type=str, 83 | default='', 84 | required=False 85 | ) 86 | parser.add_argument( 87 | "--tar_path", 88 | type=str, 89 | default='', 90 | required=False 91 | ) 92 | parser.add_argument( 93 | "--modes", 94 | type=str, 95 | default='', 96 | required=False 97 | ) 98 | 99 | args = parser.parse_args() 100 | main(args.sor_path, args.tar_path, args.modes) -------------------------------------------------------------------------------- /utils/test_images/id1/bg_pose+exp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id1/bg_pose+exp.png -------------------------------------------------------------------------------- /utils/test_images/id1/exp_pose+exp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id1/exp_pose+exp.npy -------------------------------------------------------------------------------- /utils/test_images/id1/render.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id1/render.png -------------------------------------------------------------------------------- /utils/test_images/id1/render_exp+pose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id1/render_exp+pose.png -------------------------------------------------------------------------------- /utils/test_images/id1/render_exp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id1/render_exp.png -------------------------------------------------------------------------------- /utils/test_images/id1/sor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id1/sor.png -------------------------------------------------------------------------------- /utils/test_images/id1/tar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id1/tar.png -------------------------------------------------------------------------------- /utils/test_images/id2/bg_light.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id2/bg_light.png -------------------------------------------------------------------------------- /utils/test_images/id2/exp_light.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id2/exp_light.npy -------------------------------------------------------------------------------- /utils/test_images/id2/render_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id2/render_light.png -------------------------------------------------------------------------------- /utils/test_images/id2/sor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id2/sor.png -------------------------------------------------------------------------------- /utils/test_images/id2/tar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/RigFace/dd3277108164830b19a8f15339b681c4b230079a/utils/test_images/id2/tar.png --------------------------------------------------------------------------------
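For orientation, the three command-line utilities above (preprocess.py, save_exp_coeffs.py, make_bgs.py) are typically chained before inference. A minimal driver sketch in Python follows; the raw input paths are illustrative placeholders, and the scripts are assumed to be run from the utils/ directory since they import datasets_faceswap and model directly.

import subprocess

def run(cmd):
    # thin wrapper so a failing step stops the pipeline immediately
    subprocess.run(cmd, check=True)

# 1) crop the raw source/target photos to the 512x512 face layout used by the rest of the repo
run(["python", "preprocess.py", "--img_path", "raw_source.jpg", "--save_path", "./test_images/id1/sor.png"])
run(["python", "preprocess.py", "--img_path", "raw_target.jpg", "--save_path", "./test_images/id1/tar.png"])

# 2) export the BFM expression coefficients for the edited factors -> ./results/exp_pose+exp.npy
run(["python", "save_exp_coeffs.py", "--sor_path", "./test_images/id1/sor.png",
     "--tar_path", "./test_images/id1/tar.png", "--modes", "pose+exp"])

# 3) build the background conditioning image -> ./results/bg_pose+exp.png
run(["python", "make_bgs.py", "--sor_path", "./test_images/id1/sor.png",
     "--tar_path", "./test_images/id1/tar.png", "--modes", "pose+exp"])

The rendering and diffusion-inference stages that consume the exp_*.npy and bg_*.png outputs are defined elsewhere in the repository and are not shown in this section.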