├── .gitignore ├── README.md ├── demo.ipynb ├── docs └── index.html ├── environment.yml ├── mutil ├── __init__.py ├── bfm2017.py ├── data_types.py ├── files.py ├── np_util.py ├── object_dict.py ├── pytorch_utils.py ├── renderer.py ├── str_format.py └── threed_utils.py └── varitex ├── custom_callbacks ├── __init__.py └── callbacks.py ├── data ├── __init__.py ├── augmentation.py ├── custom_dataset.py ├── dataset_specifics.py ├── keys_enum.py ├── npy_dataset.py └── uv_factory.py ├── demo.py ├── evaluation ├── __init__.py └── inference.py ├── inference.py ├── inference_surface.py ├── modules ├── __init__.py ├── custom_module.py ├── decoder.py ├── discriminator.py ├── encoder.py ├── feature2image.py ├── generator.py ├── loss.py ├── metrics.py ├── pipeline.py └── unet.py ├── options ├── __init__.py ├── base_options.py ├── eval_options.py └── train_options.py ├── train.py └── visualization ├── __init__.py └── batch.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .idea 132 | lightning_logs 133 | res 134 | cfg_*.py 135 | pretrained -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VariTex: Variational Neural Face Textures 2 | 3 | [![License CC BY-NC-SA 4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://creativecommons.org/licenses/by-nc-sa/4.0/) 4 | ![Python 3.6](https://img.shields.io/badge/python-3.8.5-green.svg) 5 | 6 | ![Teaser](https://dataset.ait.ethz.ch/downloads/varitex/teaser.png) 7 | 8 | This is the official repository of the paper: 9 | 10 | > **VariTex: Variational Neural Face Textures**
11 | > [Marcel C. Bühler](https://ait.ethz.ch/people/buehler), [Abhimitra Meka](https://www.meka.page/), [Gengyan Li](https://ait.ethz.ch/people/lig/), [Thabo Beeler](https://thabobeeler.com/), and [Otmar Hilliges](https://ait.ethz.ch/people/hilliges/).
12 | > **Abstract:** *Deep generative models have recently demonstrated the ability to synthesize photorealistic images of human faces with novel identities. A key challenge to the wide applicability of such techniques is to provide independent control over semantically meaningful parameters: appearance, head pose, face shape, and facial expressions. In this paper, we propose VariTex - to the best of our knowledge the first method that learns a variational latent feature space of neural face textures, which allows sampling of novel identities. We combine this generative model with a parametric face model and gain explicit control over head pose and facial expressions. To generate images of complete human heads, we propose an additive decoder that generates plausible additional details such as hair. A novel training scheme enforces a pose-independent latent space and, in consequence, allows learning of a one-to-many mapping between latent codes and pose-conditioned exterior regions. The resulting method can generate geometrically consistent images of novel identities allowing fine-grained control over head pose, face shape, and facial expressions, facilitating a broad range of downstream tasks, like sampling novel identities, re-posing, expression transfer, and more.* 13 | 14 | 15 | # Code and Models 16 | 17 | ## Code, Environment 18 | - [ ] Clone the repository: `git clone https://github.com/mcbuehler/VariTex.git` 19 | - [ ] Create the environment: `conda env create -f environment.yml` and activate it: `conda activate varitex`. 20 | 21 | 22 | ## Data 23 | We train on the [FFHQ dataset](https://github.com/NVlabs/ffhq-dataset) and we use the [Basel Face Model 2017](https://faces.dmi.unibas.ch/bfm/bfm2017.html) (BFM). Please download the following: 24 | 25 | - [ ] FFHQ: Follow the instructions in the [FFHQ repository](https://github.com/NVlabs/ffhq-dataset) to obtain the images (.png). Download "Aligned and cropped images at 1024×1024". 26 | - [ ] Preprocessed dataset: [Download](https://dataset.ait.ethz.ch/downloads/varitex/preprocessed_dataset.zip) (~15 GB) and unzip. 27 | - [ ] Basel Face Model: Request the [model](https://faces.dmi.unibas.ch/bfm/bfm2017.html) ("model2017-1_face12_nomouth.h5") and download the [UV parameterization](https://github.com/unibas-gravis/parametric-face-image-generator/blob/master/data/regions/face12.json). 28 | - [ ] Pretrained models: [Download](https://dataset.ait.ethz.ch/downloads/varitex/pretrained.zip) and unzip. 29 | - [ ] Move the downloaded files to the correct locations (see below). 30 | 31 | Environment variables should point to your data, face model, and (optional) output folders: `export DP=; export FP=; export OP=`. 32 | We assume the following folder structure: 33 | * `$DP/FFHQ/images`: Folder with *.png files from FFHQ 34 | * `$DP/FFHQ/preprocessed_dataset`: Folder with the preprocessed datasets. Should contain .npy files "R", "t", "s", "sp", "ep", "segmentation", "uv", "filename", and a .npz file "dataset_splits". 35 | * `$FP/basel_facemodel/`: Folder where the BFM model files are located. Should contain "model2017-1_face12_nomouth.h5" and "face12.json". 36 | 37 | ## Using the Pretrained Model 38 | 39 | Make sure you have downloaded the pretrained model (link above). 40 | Define the checkpoint file: `export CP=.ckpt` 41 | 42 | #### Demo Notebook 43 | Start Jupyter with `CUDA_VISIBLE_DEVICES=0 jupyter notebook` and open `demo.ipynb`. 44 | 45 | #### Inference Script 46 | The inference script runs three different modes on the FFHQ dataset: 47 | 1. 
Inference on the extracted geometries and original pose (`inference.inference_ffhq`) 48 | 2. Inference with extracted geometries and multiple poses (`inference.inference_posed_ffhq`) 49 | 3. Inference with random geometries and poses (`inference.inference_posed`) 50 | 51 | You can adjust the number of samples with the parameter `n`. 52 | 53 | `CUDA_VISIBLE_DEVICES=0 python varitex/inference.py --checkpoint $CP --dataset_split val`. 54 | 55 | 56 | ## Training 57 | Run `CUDA_VISIBLE_DEVICES=0 python varitex/train.py`. 58 | 59 | If you wish, you can set a variety of input parameters. Please see `varitex.options`. 60 | 61 | A GPU with 24 GB VMem should support batch size 7. If your GPU has only 12 GB, please use a lower batch size. 62 | 63 | Training should converge after 44 epochs, which takes roughly 72 hours on a NVIDIA Quadro RTX 6000/8000 GPU. 64 | 65 | ## Implementation Details 66 | 67 | The VariTex architecture consists of several components (in `varitex/modules`). We pass on a dictionary from one component to the next. The following table lists the classes / methods with their corresponding added tensors. 68 | 69 | 70 | | Class / Method | Adds... | 71 | |--- |--- 72 | | varitex.data.hdf_dataset.NPYDataset | IMAGE_IN, IMAGE_IN_ENCODE, SEGMENTATION_MASK, UV_RENDERED | 73 | | varitex.modules.encoder.Encoder | IMAGE_ENCODED | 74 | | varitex.modules.generator.Generator.forward_encoded2latent_distribution | STYLE_LATENT_MU, STYLE_LATENT_STD | 75 | | varitex.modules.generator.Generator.forward_sample_style | STYLE_LATENT | 76 | | varitex.modules.generator.Generator.forward_latent2featureimage | LATENT_INTERIOR, LATENT_EXTERIOR | 77 | | varitex.modules.decoder.Decoder | TEXTURE_PERSON | 78 | | varitex.modules.generator.Generator.sample_texture | FACE_FEATUREIMAGE | 79 | | varitex.modules.decoder.AdditiveDecoder | ADDITIVE_FEATUREIMAGE | 80 | | varitex.modules.generator.Generator.forward_merge_textures | FULL_FEATUREIMAGE | 81 | | varitex.modules.feature2image.Feature2ImageRenderer | IMAGE_OUT, SEGMENTATION_PREDICTED | 82 | 83 | ## Acknowledgements 84 | We implement our pipeline in [Lightning](https://www.pytorchlightning.ai/) and use the [SPADE](https://github.com/NVlabs/SPADE) discriminator. The neural rendering is inspired by [Neural Voice Puppetry](https://github.com/keetsky/NeuralVoicePuppetry). We found the [pytorch3d](https://github.com/facebookresearch/pytorch3d) renderer very helpful. 85 | 86 | 87 | ## License 88 | Copyright belongs to the authors. 89 | All rights reserved. Licensed under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) (**Attribution-NonCommercial-ShareAlike 4.0 International**) 90 | 91 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 3 | 4 | 5 | VariTex: Variational Neural Face Textures 6 | 7 | 8 | 10 | 11 | 15 | 16 | 17 | 18 | 25 | 26 | 27 | 36 |
37 | 38 |
39 |
40 |
41 |
42 |
43 |

VariTex:
Variational Neural Face Textures

44 | 45 |
46 | 47 |
48 |
49 | Marcel C. Bühler1    51 | Abhimitra Meka2    52 | Gengyan Li1,2    53 | Thabo Beeler2    54 | Otmar Hilliges1 55 |
56 |
57 |
58 | 1AIT, 59 | ETH Zurich    60 | 2Google 61 |
62 |
63 |
64 | 65 |
66 |
67 |
68 | 69 |
70 |
71 |
Paper | 72 | Supplementary | 73 | GitHub | 74 | Demo | 75 | Youtube | 76 | Blog 77 |
78 |
79 |
80 | 81 |
82 | 83 |
84 |
85 | 86 |
87 | 88 |

Abstract

89 | 90 |

91 |

92 | Deep generative models can synthesize photorealistic images of human faces with novel identities. 93 | However, a key challenge to the wide applicability of such techniques is to provide independent control over semantically meaningful parameters: appearance, head pose, face shape, and facial expressions. 94 | In this paper, we propose VariTex - to the best of our knowledge the first method that learns a variational latent feature space of neural face textures, which allows sampling of novel identities. 95 | We combine this generative model with a parametric face model and gain explicit control over head pose and facial expressions. 96 | To generate complete images of human heads, we propose an additive decoder that adds plausible details such as hair. 97 | A novel training scheme enforces a pose-independent latent space and in consequence, allows learning a one-to-many mapping between latent codes and pose-conditioned exterior regions. 98 | The resulting method can generate geometrically consistent images of novel identities under fine-grained control over head pose, face shape, and facial expressions. This facilitates a broad range of downstream tasks, like sampling novel identities, changing the head pose, expression transfer, and more.

99 |
100 | 101 |
102 |
103 | 104 |
105 |

Video

106 |
107 | 108 |
109 |
110 |
111 |
112 | 113 |

VariTex Controls

114 |
115 |
116 | 119 |

Expressions

120 |
121 |
122 | 128 |

Pose

129 |
130 |
131 | 137 |

Identity

138 |
139 |
140 |
141 | 142 |
143 | 144 |
145 |

Method

146 | 147 |

148 | The objective of our pipeline is to learn a Generator that can synthesize face images with arbitrary novel 149 | identities whose pose and expressions can be controlled using face model parameters. 150 | During training, we use unlabeled monocular RGB images to learn a smooth latent space 151 | of natural face appearance using a variational encoder. 152 | A latent code sampled from this space is then decoded to a novel face image. 153 | At test time, we use samples drawn from a normal distribution to generate novel face images. 154 | Our variationally generated neural textures can also be stylistically interpolated to generate 155 | intermediate identities. 156 |

157 |
158 | 159 |
160 | 161 |
162 |

Code and Models

163 | Available on GitHub. Make sure to check out our demo notebook. 164 |
165 |
166 | 167 |
168 |

Citation

169 |
170 | @inproceedings{buehler2021varitex,
171 |   title={VariTex: Variational Neural Face Textures},
172 |   author={Marcel C. Buehler and Abhimitra Meka and Gengyan Li and Thabo Beeler and Otmar Hilliges},
173 |   booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
174 |   year={2021}
175 | }        
176 | BibTeX 177 |
178 | 179 |
180 | 181 |
182 |

Acknowledgements

183 | 184 |

185 | We thank 186 | Xucong Zhang, 187 | Emre Aksan, 188 | Thomas Langerak, 189 | Xu Chen, 190 | Mohamad Shahbazi, 191 | Velko Vechev, 192 | Yue Li, 193 | and Arvind Somasundaram for their contributions; 194 | Ayush Tewari for the StyleRig visuals; and the anonymous reviewers. 195 | This project has received funding from the European Research Council (ERC) under the European Union’s Horizon 2020 research and innovation program grant agreement No 717054. 196 |

197 | 198 |
199 |
200 | 201 | 202 | 204 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: varitex 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=2_kmp_llvm 8 | - _pytorch_select=0.2=gpu_0 9 | - absl-py=0.11.0=py38h578d9bd_0 10 | - aiohttp=3.7.3=py38h25fe258_0 11 | - appdirs=1.4.4=pyh9f0ad1d_0 12 | - async-timeout=3.0.1=py_1000 13 | - attrs=20.3.0=pyhd3deb0d_0 14 | - bcrypt=3.2.0=py38h1e0a361_1 15 | - blas=1.0=mkl 16 | - blinker=1.4=py_1 17 | - brotlipy=0.7.0=py38h8df0ef7_1001 18 | - c-ares=1.17.1=h36c2ea0_0 19 | - ca-certificates=2023.12.12=h06a4308_0 20 | - cachetools=4.1.1=py_0 21 | - certifi=2024.2.2=py38h06a4308_0 22 | - cffi=1.14.4=py38h261ae71_0 23 | - chardet=3.0.4=py38h924ce5b_1008 24 | - click=7.1.2=pyh9f0ad1d_0 25 | - cryptography=3.3.1=py38h3c74f83_0 26 | - cuda-cudart=12.3.101=hd3aeb46_0 27 | - cuda-cudart_linux-64=12.3.101=h59595ed_0 28 | - cuda-nvrtc=12.3.107=hd3aeb46_0 29 | - cuda-nvtx=12.3.101=h59595ed_0 30 | - cuda-version=12.3=h55a0123_2 31 | - cudnn=8.8.0.121=h264754d_4 32 | - cycler=0.10.0=py38_0 33 | - dbus=1.13.18=hb2f20db_0 34 | - expat=2.2.10=he6710b0_2 35 | - filelock=3.13.1=py38h06a4308_0 36 | - fontconfig=2.13.0=h9420a91_0 37 | - freetype=2.10.4=h5ab3b9f_0 38 | - fs=2.4.11=py38h32f6830_2 39 | - fs.sshfs=1.0.0=pyhd8ed1ab_1 40 | - future=0.18.2=py38h578d9bd_2 41 | - fvcore=0.1.2.post20201218=pyhd8ed1ab_0 42 | - glib=2.66.1=h92f7085_0 43 | - gmp=6.2.1=h295c915_3 44 | - gmpy2=2.1.2=py38heeb90bb_0 45 | - gst-plugins-base=1.14.0=hbbd80ab_1 46 | - gstreamer=1.14.0=hb31296c_0 47 | - hdf5=1.10.6=hb1b8bf9_0 48 | - icu=58.2=he6710b0_3 49 | - idna=2.10=pyh9f0ad1d_0 50 | - imageio=2.9.0=py_0 51 | - importlib-metadata=3.3.0=py38h578d9bd_3 52 | - intel-openmp=2020.2=254 53 | - joblib=1.0.0=pyhd3eb1b0_0 54 | - jpeg=9e=h5eee18b_1 55 | - kiwisolver=1.3.0=py38h2531618_0 56 | - lcms2=2.11=h396b838_0 57 | - ld_impl_linux-64=2.33.1=h53a641e_7 58 | - lerc=3.0=h295c915_0 59 | - libabseil=20230802.1=cxx17_h59595ed_0 60 | - libblas=3.9.0=20_linux64_mkl 61 | - libcblas=3.9.0=20_linux64_mkl 62 | - libcublas=12.3.4.1=hd3aeb46_0 63 | - libcufft=11.0.12.1=hd3aeb46_0 64 | - libcurand=10.3.4.107=hd3aeb46_0 65 | - libcusolver=11.5.4.101=hd3aeb46_0 66 | - libcusparse=12.2.0.103=hd3aeb46_0 67 | - libdeflate=1.17=h5eee18b_1 68 | - libedit=3.1.20191231=h14c3975_1 69 | - libffi=3.3=he6710b0_2 70 | - libgcc-ng=13.2.0=h807b86a_5 71 | - libgfortran-ng=7.3.0=hdf63c60_0 72 | - liblapack=3.9.0=20_linux64_mkl 73 | - libmagma=2.7.2=h173bb3b_2 74 | - libmagma_sparse=2.7.2=h173bb3b_2 75 | - libnvjitlink=12.3.101=hd3aeb46_0 76 | - libpng=1.6.39=h5eee18b_0 77 | - libprotobuf=4.25.1=hf27288f_2 78 | - libsodium=1.0.18=h36c2ea0_1 79 | - libstdcxx-ng=13.2.0=h7e041cc_5 80 | - libtiff=4.5.1=h6a678d5_0 81 | - libtorch=2.1.2=cuda120_h2aa5df7_301 82 | - libuuid=1.0.3=h1bed415_2 83 | - libuv=1.47.0=hd590300_0 84 | - libwebp-base=1.3.2=h5eee18b_0 85 | - libxcb=1.14=h7b6447c_0 86 | - libxml2=2.9.10=hb55368b_3 87 | - libzlib=1.2.13=hd590300_5 88 | - llvm-openmp=17.0.6=h4dfa4b3_0 89 | - lz4-c=1.9.2=heb0550a_3 90 | - magma=2.7.2=h51420fd_2 91 | - markdown=3.3.3=pyh9f0ad1d_0 92 | - matplotlib=3.3.2=h06a4308_0 93 | - matplotlib-base=3.3.2=py38h817c723_0 94 | - mkl=2023.2.0=h84fe81f_50496 95 | - mkl-service=2.4.0=py38h5eee18b_1 96 | - mkl_fft=1.3.8=py38h5eee18b_0 97 | - mkl_random=1.2.4=py38hdb19cb5_0 98 | - mpc=1.1.0=h10f8cd9_1 99 
| - mpfr=4.0.2=hb69a4c5_1 100 | - mpmath=1.3.0=py38h06a4308_0 101 | - multidict=5.1.0=py38h27cfd23_2 102 | - nccl=2.20.3.1=h3a97aeb_0 103 | - ncurses=6.2=he6710b0_1 104 | - ninja=1.10.2=py38hff7bd54_0 105 | - oauthlib=3.0.1=py_0 106 | - olefile=0.46=py_0 107 | - openssl=1.1.1w=h7f8727e_0 108 | - packaging=20.8=pyhd3deb0d_0 109 | - paramiko=2.7.2=pyh9f0ad1d_0 110 | - pcre=8.44=he6710b0_0 111 | - pillow=8.0.1=py38he98fc37_0 112 | - pip=20.3.3=py38h06a4308_0 113 | - portalocker=1.7.0=py38h32f6830_1 114 | - property-cached=1.6.4=py_0 115 | - protobuf=4.25.1=py38hf14ab21_0 116 | - pyasn1=0.4.8=py_0 117 | - pyasn1-modules=0.2.7=py_0 118 | - pycparser=2.20=py_2 119 | - pyjwt=2.0.0=pyhd8ed1ab_0 120 | - pynacl=1.4.0=py38h1e0a361_2 121 | - pyopenssl=20.0.1=pyhd8ed1ab_0 122 | - pyparsing=2.4.7=py_0 123 | - pyqt=5.9.2=py38h05f1152_4 124 | - pysocks=1.7.1=py38h924ce5b_2 125 | - python=3.8.5=h7579374_1 126 | - python-dateutil=2.8.1=py_0 127 | - python_abi=3.8=1_cp38 128 | - pytorch=2.1.2=cuda120_py38ha6e2b7b_301 129 | - pytorch3d=0.7.5=cuda120py38h6e72adc_3 130 | - pytz=2021.1=pyhd8ed1ab_0 131 | - qt=5.9.7=h5867ecd_1 132 | - readline=8.0=h7b6447c_0 133 | - requests=2.25.1=pyhd3deb0d_0 134 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 135 | - rsa=4.6=pyh9f0ad1d_0 136 | - setuptools=51.0.0=py38h06a4308_2 137 | - sip=4.19.13=py38he6710b0_0 138 | - six=1.15.0=py38h06a4308_0 139 | - sleef=3.5.1=h9b69904_2 140 | - sqlite=3.33.0=h62c20be_0 141 | - sympy=1.12=py38h06a4308_0 142 | - tabulate=0.8.7=pyh9f0ad1d_0 143 | - tbb=2021.8.0=hdb19cb5_0 144 | - tensorboard-plugin-wit=1.7.0=pyh9f0ad1d_0 145 | - tk=8.6.10=hbc83047_0 146 | - torchvision=0.15.2=cpu_py38h83e0c9b_0 147 | - tornado=6.1=py38h27cfd23_0 148 | - typing-extensions=3.7.4.3=0 149 | - typing_extensions=4.9.0=py38h06a4308_1 150 | - werkzeug=1.0.1=pyh9f0ad1d_0 151 | - wheel=0.36.2=pyhd3eb1b0_0 152 | - xz=5.4.5=h5eee18b_0 153 | - yacs=0.1.6=py_0 154 | - yaml=0.2.5=h516909a_0 155 | - yarl=1.6.3=py38h25fe258_0 156 | - zipp=3.4.0=py_0 157 | - zlib=1.2.13=hd590300_5 158 | - zstd=1.5.5=hfc55251_0 159 | - pip: 160 | - annotated-types==0.6.0 161 | - anyio==4.3.0 162 | - argon2-cffi==20.1.0 163 | - async-generator==1.10 164 | - backcall==0.2.0 165 | - bleach==3.3.0 166 | - configparser==6.0.0 167 | - dataclasses==0.6 168 | - decorator==4.4.2 169 | - defusedxml==0.6.0 170 | - deprecated==1.2.12 171 | - dill==0.3.3 172 | - docker-pycreds==0.4.0 173 | - easydict==1.9 174 | - entrypoints==0.3 175 | - exceptiongroup==1.2.0 176 | - facenet-pytorch==2.5.1 177 | - fastapi==0.109.2 178 | - fsspec==2024.2.0 179 | - gitdb==4.0.11 180 | - gitpython==3.1.42 181 | - google-auth==2.28.1 182 | - google-auth-oauthlib==1.0.0 183 | - grpcio==1.62.0 184 | - h3ds==0.1.1 185 | - h5py==3.10.0 186 | - imageio-ffmpeg==0.4.3 187 | - insightface==0.1.5 188 | - iopath==0.1.2 189 | - ipykernel==5.5.0 190 | - ipython==7.21.0 191 | - ipython-genutils==0.2.0 192 | - ipywidgets==7.6.3 193 | - jedi==0.18.0 194 | - jinja2==2.11.3 195 | - jsonschema==3.2.0 196 | - jupyter==1.0.0 197 | - jupyter-client==6.1.11 198 | - jupyter-console==6.2.0 199 | - jupyter-core==4.7.1 200 | - jupyterlab-pygments==0.1.2 201 | - jupyterlab-widgets==1.0.0 202 | - lightning-bolts==0.7.0 203 | - lightning-utilities==0.10.1 204 | - lpips==0.1.3 205 | - markupsafe==1.1.1 206 | - mistune==0.8.4 207 | - mxnet==1.8.0 208 | - mxnet-cu110==1.8.0.post0 209 | - nbclient==0.5.3 210 | - nbconvert==6.0.7 211 | - nbformat==5.1.2 212 | - nest-asyncio==1.5.1 213 | - networkx==2.5 214 | - notebook==6.2.0 215 | - numpy==1.24.4 216 | - 
nvidia-htop==1.0.2 217 | - onnx==1.15.0 218 | - opencv-python==4.5.1.48 219 | - pandas==1.2.4 220 | - pandocfilters==1.4.3 221 | - parso==0.8.1 222 | - pathtools==0.1.2 223 | - pexpect==4.8.0 224 | - pickleshare==0.7.5 225 | - plotly==5.0.0 226 | - prometheus-client==0.9.0 227 | - promise==2.3 228 | - prompt-toolkit==3.0.16 229 | - psutil==5.9.8 230 | - ptyprocess==0.7.0 231 | - pydantic==2.6.2 232 | - pydantic-core==2.16.3 233 | - pygments==2.8.0 234 | - pyqt5==5.15.4 235 | - pyqt5-qt5==5.15.2 236 | - pyqt5-sip==12.8.1 237 | - pyrsistent==0.17.3 238 | - python-graphviz==0.8.4 239 | - pytorch-fid==0.2.0 240 | - pytorch-lightning==1.9.5 241 | - pywavelets==1.1.1 242 | - pyyaml==6.0.1 243 | - pyzmq==22.0.3 244 | - qdarkgraystyle==1.0.2 245 | - qdarkstyle==3.0.1 246 | - qtconsole==5.0.2 247 | - qtpy==1.9.0 248 | - scikit-image==0.18.1 249 | - scikit-learn==0.24.1 250 | - scipy==1.6.1 251 | - send2trash==1.5.0 252 | - sentry-sdk==1.40.5 253 | - shortuuid==1.0.11 254 | - smmap==5.0.1 255 | - sniffio==1.3.0 256 | - starlette==0.36.3 257 | - subprocess32==3.5.4 258 | - tenacity==7.0.0 259 | - tensorboard==2.14.0 260 | - tensorboard-data-server==0.7.2 261 | - tensorboardx==2.1 262 | - termcolor==2.4.0 263 | - terminado==0.9.2 264 | - testpath==0.4.4 265 | - threadpoolctl==2.1.0 266 | - tifffile==2021.2.1 267 | - toml==0.10.2 268 | - torchmetrics==1.3.1 269 | - tqdm==4.66.2 270 | - traitlets==5.0.5 271 | - trimesh==3.9.0 272 | - urllib3==1.26.18 273 | - wandb==0.12.2 274 | - wcwidth==0.2.5 275 | - webencodings==0.5.1 276 | - widgetsnbextension==3.5.1 277 | - wrapt==1.12.1 278 | - yaspin==2.5.0 279 | -------------------------------------------------------------------------------- /mutil/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mcbuehler/VariTex/a25979f96d600ed839799450958ace8952ce375a/mutil/__init__.py -------------------------------------------------------------------------------- /mutil/bfm2017.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import torch 3 | import json 4 | 5 | import numpy as np 6 | 7 | 8 | class BFM: 9 | def __init__(self, path): 10 | self._shape_mu = None 11 | self._shape_pcab = None 12 | self._shape_pca_var = None 13 | self._faces = None 14 | self._expression_mu = None 15 | self._expression_pcab = None 16 | self._expression_pca_var = None 17 | 18 | @property 19 | def shape_mu(self): 20 | return np.array(self._shape_mu) 21 | 22 | @property 23 | def shape_pcab(self): 24 | return np.array(self._shape_pcab) 25 | 26 | @property 27 | def shape_pca_var(self): 28 | return np.array(self._shape_pca_var) 29 | 30 | @property 31 | def expression_mu(self): 32 | return np.array(self._expression_mu) 33 | 34 | @property 35 | def expression_pcab(self): 36 | return np.array(self._expression_pcab) 37 | 38 | @property 39 | def expression_pca_var(self): 40 | return np.array(self._expression_pca_var) 41 | 42 | @property 43 | def faces(self): 44 | return np.array(self._faces) 45 | 46 | def generate_vertices(self, shape_para, exp_para): 47 | ''' 48 | Args: 49 | shape_para: (n_shape_para, 1) 50 | exp_para: (n_exp_para, 1) 51 | Returns: 52 | vertices: (nver, 3) 53 | ''' 54 | raise NotImplementedError("Implement in subclass") 55 | 56 | def get_mean(self): 57 | return self.shape_mu.reshape(-1, 3) 58 | 59 | def apply_Rts(self, verts, R, t, s): 60 | transformed_verts = s * np.matmul(verts, R) + t 61 | return transformed_verts 62 | 63 | 64 | class BFM2017(BFM): 65 | 
N_COEFF_SHAPE = 199 66 | N_COEFF_EXPRESSION = 100 67 | 68 | def __init__(self, path_model, path_uv=None): 69 | super().__init__(path_model) 70 | import h5py 71 | self.file = h5py.File(path_model, 'r') 72 | self._shape_mu = self.file["shape"]["model"]["mean"] 73 | self._shape_pcab = self.file["shape"]["model"]["pcaBasis"] 74 | self._shape_pca_var = self.file["shape"]["model"]["pcaVariance"] 75 | self._faces = np.transpose(self.file["shape"]["representer"]["cells"]) 76 | self._expression_mu = self.file["expression"]["model"]["mean"] 77 | self._expression_pcab = self.file["expression"]["model"]["pcaBasis"] 78 | self._expression_pca_var = self.file["expression"]["model"]["pcaVariance"] 79 | 80 | if "model2017-1_face12_nomouth.h5" in path_model: 81 | # Model without ears and throat 82 | self.N_VERTICES = 28588 83 | self.N_FACES = 56572 84 | else: 85 | # Model with ears and throat 86 | self.N_VERTICES = 53149 87 | self.N_FACES = 105694 88 | 89 | assert 0 == self.faces.min() 90 | assert self.faces.max() == self.N_VERTICES - 1 91 | 92 | if path_uv is not None: 93 | self.vertices_uv, self.faces_uv = self.load_uv(path_uv) 94 | 95 | def load_uv(self, path_uv): 96 | with open(path_uv, 'r') as f: 97 | uv_para = json.load(f) 98 | verts_uvs = np.array(uv_para['textureMapping']['pointData']) 99 | faces_uvs = np.array(uv_para['textureMapping']['triangles']) 100 | return verts_uvs, faces_uvs 101 | 102 | def __delete__(self, instance): 103 | self.file.close() 104 | 105 | def generate_vertices(self, shape_para, exp_para): 106 | if shape_para.shape != (self.N_COEFF_SHAPE, 1): 107 | shape_para = shape_para.reshape(self.N_COEFF_SHAPE, 1) 108 | if exp_para.shape != (self.N_COEFF_EXPRESSION, 1): 109 | exp_para = exp_para.reshape(self.N_COEFF_EXPRESSION, 1) 110 | vertices = self.shape_mu.reshape(-1, 1) + self.shape_pcab @ shape_para + self.expression_pcab @ exp_para 111 | vertices = np.reshape(vertices, [int(3), int(self.N_VERTICES)], 'F').T 112 | return vertices.astype(np.float32) 113 | 114 | def sample(self, n, std_multiplier, variance): 115 | std = np.sqrt(variance) * std_multiplier 116 | std = np.broadcast_to(std, (n, std.shape[0])) 117 | mu = np.zeros_like(std) 118 | samples = np.random.normal(mu, std) 119 | return samples 120 | 121 | def sample_shape(self, n, std_multiplier=1): 122 | return self.sample(n, std_multiplier, self.shape_pca_var) 123 | 124 | def sample_expression(self, n, std_multiplier=1): 125 | return self.sample(n, std_multiplier, self.expression_pca_var) 126 | 127 | 128 | class BFM2017Tensor: 129 | N_VERTICES = 28588 130 | N_FACES = 56572 131 | N_COEFF_SHAPE = 199 132 | N_COEFF_EXPRESSION = 100 133 | 134 | def __init__(self, path_model, path_uv=None, device='cuda', verbose=False): 135 | print("Loading BFM 2017 into GPU... 
(this can take a while)") 136 | self.device = device 137 | self.file = h5py.File(path_model, 'r') 138 | # This can take a few seconds 139 | self.shape_mu = torch.Tensor(self.file["shape"]["model"]["mean"]).reshape(self.N_VERTICES*3, 1).to(device).float() 140 | self.shape_pcab = torch.Tensor(self.file["shape"]["model"]["pcaBasis"]).reshape(self.N_VERTICES*3, self.N_COEFF_SHAPE).to(device).float() 141 | self.shape_pca_var = torch.Tensor(self.file["shape"]["model"]["pcaVariance"]).reshape(self.N_COEFF_SHAPE).to(device).float() 142 | self.expression_pcab = torch.Tensor(self.file["expression"]["model"]["pcaBasis"]).reshape(self.N_VERTICES*3, self.N_COEFF_EXPRESSION).to(device).float() 143 | self.expression_pca_var = torch.Tensor(self.file["expression"]["model"]["pcaVariance"]).reshape(self.N_COEFF_EXPRESSION).to(device).float() 144 | self.faces = torch.Tensor(np.transpose(self.file["shape"]["representer"]["cells"])).reshape(self.N_FACES, 3).to(device) 145 | 146 | if path_uv is not None: 147 | self.vertices_uv, self.faces_uv = self.load_uv(path_uv) 148 | print("Done") 149 | 150 | def load_uv(self, path_uv): 151 | with open(path_uv, 'r') as f: 152 | uv_para = json.load(f) 153 | verts_uvs = torch.Tensor(uv_para['textureMapping']['pointData']).reshape(-1, 3).to(self.device) 154 | faces_uvs = torch.Tensor(np.array(uv_para['textureMapping']['triangles'])).reshape(-1, 3).to(self.device) 155 | return verts_uvs, faces_uvs 156 | 157 | def __delete__(self, instance): 158 | self.file.close() 159 | 160 | def generate_vertices(self, shape_para, exp_para): 161 | if shape_para.shape != (self.N_COEFF_SHAPE, 1): 162 | shape_para = shape_para.reshape(self.N_COEFF_SHAPE, 1) 163 | if exp_para.shape != (self.N_COEFF_EXPRESSION, 1): 164 | exp_para = exp_para.reshape(self.N_COEFF_EXPRESSION, 1) 165 | vertices = self.shape_mu +\ 166 | self.shape_pcab @ shape_para + \ 167 | self.expression_pcab @ exp_para 168 | return vertices.reshape(-1, 3).float() 169 | 170 | def sample(self, n, std_multiplier, variance): 171 | std = torch.sqrt(variance) * std_multiplier 172 | std = std.expand((n, std.shape[0])) 173 | q = torch.distributions.Normal(torch.zeros_like(std).to(std.device), std * std_multiplier) 174 | samples = q.rsample() 175 | return samples 176 | 177 | def sample_shape(self, n, std_multiplier=1): 178 | return self.sample(n, std_multiplier, self.shape_pca_var) 179 | 180 | def sample_expression(self, n, std_multiplier=1): 181 | return self.sample(n, std_multiplier, self.expression_pca_var) 182 | -------------------------------------------------------------------------------- /mutil/data_types.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import PIL.Image 4 | 5 | 6 | def to_np(t): 7 | if isinstance(t, torch.Tensor): 8 | t = t.detach().cpu().numpy() 9 | if isinstance(t, PIL.Image.Image): 10 | t = np.array(t) 11 | return t -------------------------------------------------------------------------------- /mutil/files.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | import shutil 4 | from mutil.str_format import get_time_string 5 | 6 | 7 | def listdir(path, prefix='', postfix='', return_prefix=True, 8 | return_postfix=True, return_abs=False): 9 | """ 10 | Lists all files in path that start with prefix and end with postfix. 11 | By default, this function returns all filenames. If you do not want to 12 | return the pre- or postfix, set the corresponding parameters to False. 
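    A minimal illustrative example (hypothetical folder):
        listdir("/tmp/data", postfix=".npy", return_abs=True)
        # -> paths of all ".npy" files in /tmp/data, joined with the given path.
        #    Order is not guaranteed, since the results pass through a set.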
13 | :param path: 14 | :param prefix: 15 | :param postfix: 16 | :param return_prefix: 17 | :param return_postfix: 18 | :return: list(str) 19 | """ 20 | files = os.listdir(path) 21 | filtered_files = filter( 22 | lambda f: f.startswith(prefix) and f.endswith(postfix), files) 23 | return_files = filtered_files 24 | if not return_prefix: 25 | idx_start = len(prefix) - 1 26 | return_files = [f[idx_start:] for f in filtered_files] 27 | if not return_postfix: 28 | idx_end = len(postfix) - 1 29 | return_files = [f[:-idx_end-1] for f in filtered_files] 30 | return_files = set(return_files) 31 | result = list(return_files) 32 | if return_abs: 33 | result = [os.path.join(path, r) for r in result] 34 | return result 35 | 36 | 37 | def mkdir(path): 38 | if not os.path.exists(path): 39 | os.makedirs(path, exist_ok=True) 40 | 41 | 42 | def copy(src, dest, overwrite=False): 43 | if os.path.exists(dest) and overwrite: 44 | if os.path.isdir(dest): 45 | shutil.rmtree(dest) 46 | else: 47 | os.remove(dest) 48 | try: 49 | shutil.copytree(src, dest) 50 | except OSError as e: 51 | # If the error was caused because the source wasn't a directory 52 | if e.errno == errno.ENOTDIR: 53 | shutil.copy(src, dest) 54 | else: 55 | print('Directory not copied. Error: %s' % e) 56 | 57 | 58 | def copy_src(path_from, path_to): 59 | """ 60 | Make sure to have everything in path_from folder 61 | There should not be large files (e.g. checkpoints or images) 62 | Args: 63 | path_from: 64 | path_to: 65 | 66 | Returns: 67 | 68 | """ 69 | assert os.path.isdir(path_from) 70 | # Collect all files and folders that contain python files 71 | tmp_folder = os.path.join(path_to, 'src/') 72 | mkdir(tmp_folder) 73 | 74 | from_folder = os.path.basename(path_from) 75 | copy(path_from, os.path.join(tmp_folder, from_folder), overwrite=True) 76 | time_str = get_time_string() 77 | 78 | path_archive = os.path.join(path_to, "src_{}".format(time_str)) 79 | shutil.make_archive(path_archive, 'zip', tmp_folder) 80 | try: 81 | shutil.rmtree(tmp_folder) 82 | except FileNotFoundError: 83 | # We got a FileNotfound error on the cluster. Maybe some race conditions? 84 | pass 85 | print("Copied folder {} to {}".format(path_from, path_archive)) -------------------------------------------------------------------------------- /mutil/np_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mutil.data_types import to_np 3 | 4 | 5 | def adjust_width(img, target_width, is_uv=False): 6 | assert len(img.shape) == 3, "Are you using a batch dimension?" 
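    # Pads the image symmetrically along the width (axis 1) until it reaches target_width.
    # 2-channel UV maps are padded with -1.0, 3-channel RGB images with 255.
    # Note that the padding amount is derived from img.shape[0], so a square input is assumed.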
7 | if is_uv: 8 | assert img.shape[-1] == 2 9 | n_channels = 2 10 | fill_value = -1.0 11 | else: 12 | assert img.shape[-1] == 3 13 | n_channels = 3 14 | fill_value = 255 15 | diff = target_width - img.shape[0] 16 | pad = np.ones((img.shape[0], diff // 2, n_channels)) * fill_value 17 | img = np.concatenate([pad, img, pad], 18 | 1) 19 | if is_uv: 20 | img = img.astype(np.float) 21 | else: 22 | img = img.astype(np.uint8) 23 | return img 24 | 25 | 26 | def uv_to_color(rendered_uv, output_format="np"): 27 | uv_color = np.concatenate([rendered_uv, np.zeros((*rendered_uv.shape[:2], 1))], -1) 28 | if output_format == "pil": 29 | from PIL import Image 30 | uv_color = (uv_color + 1 ) * 127.5 31 | uv_color = Image.fromarray(uv_color.astype(np.uint8)) 32 | return uv_color 33 | 34 | 35 | def center_crop(image, new_height, new_width): 36 | # width, height = image.size # Get dimensions 37 | height, width = image.shape[:2] 38 | center_h = height // 2 39 | center_w = width // 2 40 | 41 | left = center_w - new_width // 2 42 | top = center_h - new_height // 2 43 | right = left + new_width 44 | bottom = top + new_height 45 | 46 | # Crop the center of the image 47 | image = image[top:bottom, left:right] 48 | return image 49 | 50 | 51 | def interpolation(n, latent_from, latent_to, gaussian_correction=True): 52 | """ 53 | Linear interpolate for a Gaussian RV 54 | :param n: 55 | :param latent_from: tensor or numpy array with batch dimension 56 | :param latent_to: 57 | :param type: 58 | :return: 59 | """ 60 | latent_from, latent_to = to_np(latent_from), to_np(latent_to) 61 | steps = np.linspace(0, 1, n).astype(np.float32) 62 | 63 | for _ in range(len(latent_to.shape)): 64 | steps = np.expand_dims(steps, -1) 65 | if gaussian_correction: 66 | steps = steps / np.sqrt(steps ** 2 + (1 - steps) ** 2) # Variance correction 67 | all_latents = (1 - steps) * latent_from + steps * latent_to 68 | return all_latents 69 | -------------------------------------------------------------------------------- /mutil/object_dict.py: -------------------------------------------------------------------------------- 1 | class ObjectDict(dict): 2 | def __init__(self, d): 3 | super().__init__() 4 | for k, v in d.items(): 5 | setattr(self, k, v) -------------------------------------------------------------------------------- /mutil/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from mutil.threed_utils import eulerAnglesToRotationMatrix 5 | from torchvision.transforms import transforms 6 | 7 | 8 | def to_tensor(a, device='cpu', dtype=None): 9 | if isinstance(a, torch.Tensor): 10 | return a.to(device) 11 | return torch.tensor(np.array(a, dtype=dtype)).to(device) 12 | 13 | 14 | def tensor2np(t): 15 | if isinstance(t, torch.Tensor): 16 | return t.detach().clone().cpu().numpy() 17 | return t 18 | 19 | 20 | def get_device(): 21 | # Set the device 22 | if torch.cuda.is_available(): 23 | device = torch.device("cuda:0") 24 | else: 25 | device = torch.device("cpu") 26 | print("WARNING: CPU only, this will be slow!") 27 | return device 28 | 29 | 30 | # Normalize images with the imagenet means and stds 31 | class ImageNetNormalizeTransform(transforms.Normalize): 32 | """ 33 | Test case: 34 | img_tensor = transforms.Compose([transforms.ToTensor()])(img) 35 | describe_tensor(img_tensor, title="img tensor to tensor") 36 | 37 | img_tensor = transforms.Compose([ 38 | transforms.ToTensor(), # Converts image to range [0, 1] 39 | ImageNetNormalizeTransformForward() 
# Uses ImageNet means and std 40 | ])(img) 41 | describe_tensor(img_tensor, title="img tensor") 42 | 43 | img_tensor = transforms.Compose([ImageNetNormalizeTransformInverse()])(img_tensor) 44 | describe_tensor(img_tensor, title="img") 45 | """ 46 | NORMALIZE_MEAN = (0.485, 0.456, 0.406) 47 | NORMALIZE_STD = (0.229, 0.224, 0.225) 48 | 49 | 50 | class ImageNetNormalizeTransformForward(ImageNetNormalizeTransform): 51 | """ 52 | Expects inputs in range [0, 1] (e.g. from transforms.ToTensor) 53 | """ 54 | def __init__(self): 55 | super().__init__(mean=self.NORMALIZE_MEAN, std=self.NORMALIZE_STD) 56 | 57 | 58 | class ImageNetNormalizeTransformInverse(ImageNetNormalizeTransform): 59 | """ 60 | Expects inputs as after using ImageNetNormalizeTransformForward. 61 | If this is the case, the inputs are returned in the range [0, 1] * scale 62 | 63 | Can handle tensors with and without batch dimension 64 | """ 65 | def __init__(self, scale=1.0): 66 | mean = torch.as_tensor(self.NORMALIZE_MEAN) 67 | std = torch.as_tensor(self.NORMALIZE_STD) 68 | std_inv = 1 / (std + 1e-7) 69 | mean_inv = -mean * std_inv 70 | super().__init__(mean=mean_inv, std=std_inv) 71 | 72 | self.scale = scale # Can also be 255 73 | 74 | def __call__(self, tensor): 75 | if len(tensor.shape) == 3: # Image 76 | return super().__call__(tensor.clone()) * self.scale 77 | elif len(tensor.shape) == 4: # With batch dimensions 78 | results = [self(t) for t in tensor] 79 | return torch.stack(results) 80 | 81 | 82 | def theta2rotation_matrix(theta_x=0, theta_y=0, theta_z=0, theta_all=None): 83 | # Angles should be in degrees 84 | if theta_all is not None: 85 | theta = [np.deg2rad(t) for t in theta_all] 86 | else: 87 | theta = [np.deg2rad(theta_x), np.deg2rad(theta_y), 88 | np.deg2rad(-theta_z)] # x looking from bottom, y looking to right, z tilting to left 89 | R = eulerAnglesToRotationMatrix(theta) 90 | R = torch.Tensor(R).float() 91 | return R 92 | -------------------------------------------------------------------------------- /mutil/renderer.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Union, List, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import torchvision 7 | from pytorch3d.ops import interpolate_face_attributes 8 | from pytorch3d.renderer import ( 9 | look_at_view_transform, 10 | FoVPerspectiveCameras, 11 | PointLights, 12 | RasterizationSettings, 13 | MeshRenderer, 14 | MeshRasterizer, 15 | SoftPhongShader, 16 | TexturesUV 17 | ) 18 | 19 | from mutil.pytorch_utils import to_tensor 20 | 21 | 22 | class Renderer(torch.nn.Module): 23 | def __init__(self, dist=2, elev=0, azimuth=180, fov=40, image_size=256, R=None, T=None, cameras=None, return_format="torch", device='cuda'): 24 | super().__init__() 25 | # If you provide R and T, you don't need dist, elev, azimuth, fov 26 | self.device = device 27 | self.return_format = return_format 28 | 29 | # Data structures and functions for rendering 30 | if cameras is None: 31 | if R is None and T is None: 32 | R, T = look_at_view_transform(dist, elev, azimuth) 33 | cameras = FoVPerspectiveCameras(R=R, T=T, znear=1, zfar=10000, fov=fov, degrees=True, device=device) 34 | # cameras = PerspectiveCameras(R=R, T=T, focal_length=1.6319*10, device=device) 35 | 36 | self.raster_settings = RasterizationSettings( 37 | image_size=image_size, 38 | blur_radius=0.0, # no blur 39 | bin_size=0, 40 | ) 41 | # Place lights at the same point as the camera 42 | location = T 43 | if location is None: 44 | location = ((0,0,0),) 45 | 
lights = PointLights(ambient_color=((0.3, 0.3, 0.3),), diffuse_color=((0.7, 0.7, 0.7),), device=device, 46 | location=location) 47 | 48 | self.mesh_rasterizer = MeshRasterizer( 49 | cameras=cameras, 50 | raster_settings=self.raster_settings 51 | ) 52 | self._renderer = MeshRenderer( 53 | rasterizer=self.mesh_rasterizer, 54 | shader=SoftPhongShader( 55 | device=device, 56 | cameras=cameras, 57 | lights=lights 58 | ) 59 | ) 60 | self.cameras = self.mesh_rasterizer.cameras 61 | 62 | def _flatten(self, a): 63 | return torch.from_numpy(np.array(a)).reshape(-1, 1).to(self.device) 64 | 65 | def format_output(self, image_tensor, return_format=None): 66 | if return_format is None: 67 | return_format = self.return_format 68 | 69 | if return_format == "torch": 70 | return image_tensor 71 | elif return_format == "pil": 72 | if len(image_tensor.shape) == 4: 73 | vis = [self.format_output(t, return_format) for t in image_tensor] 74 | return vis 75 | else: 76 | to_pil = torchvision.transforms.ToPILImage() 77 | vis = image_tensor.detach().cpu() 78 | return to_pil(vis) 79 | elif return_format == "np_raw": 80 | return image_tensor.detach().cpu().permute(0,2,3,1).numpy() 81 | elif return_format == "np": 82 | pil_image = self.format_output(image_tensor, return_format='pil') 83 | if isinstance(pil_image, list): 84 | pil_image = [np.array(img) for img in pil_image] 85 | return np.array(pil_image) 86 | 87 | 88 | class ImagelessTexturesUV(TexturesUV): 89 | def __init__(self, 90 | faces_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], 91 | verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], 92 | padding_mode: str = "border", 93 | align_corners: bool = True, 94 | ): 95 | self.device = faces_uvs[0].device 96 | batch_size = faces_uvs.shape[0] 97 | maps = torch.zeros(batch_size, 2, 2, 3).to(self.device) # This is simply to instantiate a texture, but it is not used. 98 | super().__init__(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs, padding_mode=padding_mode, align_corners=align_corners) 99 | 100 | def sample_pixel_uvs(self, fragments, **kwargs) -> torch.Tensor: 101 | """ 102 | Copied from super().sample_textures and adapted to output pixel_uvs instead of the sampled texture. 103 | 104 | Args: 105 | fragments: 106 | The outputs of rasterization. From this we use 107 | 108 | - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices 109 | of the faces (in the packed representation) which 110 | overlap each pixel in the image. 111 | - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying 112 | the barycentric coordianates of each pixel 113 | relative to the faces (in the packed 114 | representation) which overlap the pixel. 115 | 116 | Returns: 117 | texels: tensor of shape (N, H, W, K, C) giving the interpolated 118 | texture for each pixel in the rasterized image. 
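        Note: unlike the parent implementation, this adapted version does not sample
        texels. It returns the pixel UV coordinates of shape (N*K, H, W, 2), rescaled
        from [0, 1] to [-1, 1] (the range expected by grid_sample-style texture lookups).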
119 | """ 120 | if self.isempty(): 121 | faces_verts_uvs = torch.zeros( 122 | (self._N, 3, 2), dtype=torch.float32, device=self.device 123 | ) 124 | else: 125 | packing_list = [ 126 | i[j] for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list()) 127 | ] 128 | faces_verts_uvs = torch.cat(packing_list) 129 | # Each vertex yields 3 triangles with u,v coordinates (N, 3, 2) 130 | # pixel_uvs: (N, H, W, K, 2) 131 | pixel_uvs = interpolate_face_attributes( 132 | fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs 133 | ) 134 | 135 | N, H_out, W_out, K = fragments.pix_to_face.shape 136 | # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2) 137 | pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2) 138 | pixel_uvs = pixel_uvs * 2.0 - 1.0 139 | return pixel_uvs 140 | 141 | 142 | class UVRenderer(Renderer): 143 | def __init__(self, verts_uv, faces_uv, dist=2, elev=0, azimuth=180, fov=40, image_size=256, R=None, T=None, cameras=None): 144 | # obj path should be the path to the obj file with the UV parametrization 145 | super().__init__(dist=dist, elev=elev, azimuth=azimuth, fov=fov, image_size=image_size, R=R, T=T, cameras=cameras) 146 | self.verts_uvs = verts_uv 147 | self.faces_uvs = faces_uv 148 | 149 | def render(self, meshes): 150 | batch_size = len(meshes) 151 | texture = ImagelessTexturesUV(verts_uvs=self.verts_uvs.expand(batch_size, -1, -1), faces_uvs=self.faces_uvs.expand(batch_size, -1, -1)) 152 | # Currently only supports one mesh in meshes 153 | fragments = self.mesh_rasterizer(meshes) 154 | rendered_uv = texture.sample_pixel_uvs(fragments) 155 | return self.format_output(rendered_uv) 156 | 157 | 158 | class BFMUVRenderer(Renderer): 159 | def __init__(self, json_path, *args, **kwargs): 160 | # json_path = ".../face12.json" 161 | super().__init__(*args, **kwargs) 162 | 163 | with open(json_path, 'r') as f: 164 | uv_para = json.load(f) 165 | verts_uvs = np.array(uv_para['textureMapping']['pointData']) 166 | faces_uvs = np.array(uv_para['textureMapping']['triangles']) 167 | 168 | verts_uvs = to_tensor(verts_uvs).unsqueeze(0).float() 169 | faces_uvs = to_tensor(faces_uvs).unsqueeze(0).long() 170 | self.texture = ImagelessTexturesUV(verts_uvs=verts_uvs, faces_uvs=faces_uvs).to(self.device) 171 | 172 | def render(self, meshes): 173 | # Currently only supports one mesh in meshes 174 | fragments = self.mesh_rasterizer(meshes) 175 | rendered_uv = self.texture.sample_pixel_uvs(fragments) 176 | rendered_uv = rendered_uv.permute(0, 3, 1, 2) # to CHW 177 | return self.format_output(rendered_uv) 178 | 179 | -------------------------------------------------------------------------------- /mutil/str_format.py: -------------------------------------------------------------------------------- 1 | 2 | import datetime 3 | 4 | 5 | def get_time_string(format="%Y%m%d_%H%M"): 6 | time_str = datetime.datetime.now().strftime(format) 7 | return time_str -------------------------------------------------------------------------------- /mutil/threed_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import matplotlib as mpl 3 | import matplotlib.pyplot as plt 4 | 5 | # add path for demo utils functions 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def setup_plots(): 11 | mpl.rcParams['savefig.dpi'] = 80 12 | mpl.rcParams['figure.dpi'] = 80 13 | 14 | # Set the device 15 | if torch.cuda.is_available(): 16 | device = torch.device("cuda:0") 17 | else: 18 | device = torch.device("cpu") 19 | print("WARNING: 
CPU only, this will be slow!") 20 | 21 | 22 | 23 | def align_ls(mp_verts, bfm_verts, idx_mp, idx_bfm): 24 | tmp_mp = mp_verts.clone()[idx_mp].reshape(1, -1, 3).float() 25 | tmp_bfm = bfm_verts.clone()[idx_bfm].reshape(1, -1, 3).float() 26 | R, t, s = corresponding_points_alignment(tmp_mp, tmp_bfm, estimate_scale=True) 27 | mp_verts_aligned = s * mp_verts.mm(R[0]) + t 28 | return mp_verts_aligned, R, t, s 29 | 30 | 31 | def unwrap(vertices, faces): 32 | import trimesh 33 | mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False) # Set process to False, o.w. will change vertex order 34 | mesh = mesh.unwrap() # This will change the vertex order 35 | uv = mesh.visual.uv 36 | return uv 37 | 38 | 39 | def plot_mesh(verts, faces, show=False): 40 | import plotly.graph_objects as go 41 | from mutil.pytorch_utils import tensor2np 42 | 43 | verts = tensor2np((verts)) 44 | # faces = tensor2np(faces) 45 | x, y, z = verts[:,0], verts[:,1],\ 46 | verts[:,2] 47 | # points3d(x,y,z) 48 | # i, j, k = faces[:,0], faces[:,1], faces[:,2] 49 | fig = go.Figure(data=[go.Mesh3d(x=x, y=y, z=z, opacity=0.9)]) 50 | if show: 51 | plt.show() 52 | return fig 53 | 54 | 55 | def apply_Rts(vertices, R, t, s, correct_tanslation=False): 56 | assert len(vertices.shape) == 3, "Needs batch size" 57 | bs, n_vertices = vertices.shape[:2] 58 | s = s.view(bs, 1, 1).expand(bs, n_vertices, 3) 59 | t = t.view(bs, 1, 3).expand(bs, n_vertices, 3) 60 | # print('verts mean', vertices.mean(1)) 61 | if correct_tanslation: 62 | # This is the center point that we rotate around 63 | center = vertices.mean(-2).unsqueeze(1).expand(vertices.shape) 64 | rotated_vertices = (vertices - center).matmul(R) + center 65 | else: 66 | rotated_vertices = vertices.matmul(R) 67 | # print('rotated', rotated_vertices.mean(1)) 68 | transformed_vertices = s * rotated_vertices + t 69 | # print('transformed', transformed_vertices.mean(1)) 70 | return transformed_vertices 71 | 72 | 73 | # Checks if a matrix is a valid rotation matrix. 74 | def isRotationMatrix(R, eps=1e-6): 75 | # https://www.learnopencv.com/rotation-matrix-to-euler-angles/ 76 | Rt = np.transpose(R) 77 | shouldBeIdentity = np.dot(Rt, R) 78 | I = np.identity(3, dtype=R.dtype) 79 | n = np.linalg.norm(I - shouldBeIdentity) 80 | if n >= eps: 81 | raise Warning("Warning: Rt != R. delta norm: {}".format(n)) 82 | return n < eps 83 | 84 | 85 | # Calculates rotation matrix to euler angles 86 | # The result is the same as MATLAB except the order 87 | # of the euler angles ( x and z are swapped ). 88 | def rotationMatrixToEulerAngles(R, eps=1e-6, degrees=False): 89 | # https://www.learnopencv.com/rotation-matrix-to-euler-angles/ 90 | 91 | assert isRotationMatrix(R, eps) 92 | 93 | sy = math.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0]) 94 | 95 | singular = sy < 1e-6 96 | 97 | if not singular: 98 | x = math.atan2(R[2, 1], R[2, 2]) 99 | y = math.atan2(-R[2, 0], sy) 100 | z = math.atan2(R[1, 0], R[0, 0]) 101 | else: 102 | x = math.atan2(-R[1, 2], R[1, 1]) 103 | y = math.atan2(-R[2, 0], sy) 104 | z = 0 105 | result = np.array([x, y, z]) 106 | if degrees: 107 | result = np.rad2deg(result) 108 | return result 109 | 110 | 111 | # Calculates Rotation Matrix given euler angles. 
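# Expects theta in radians, ordered [x (pitch), y (yaw), z (roll)]; the individual
# rotations are composed as R = R_z @ R_y @ R_x. Callers that work in degrees
# (e.g. theta2rotation_matrix in mutil.pytorch_utils) convert to radians first.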
112 | def eulerAnglesToRotationMatrix(theta): 113 | # Angles in radians plz 114 | # theta order: x, y z 115 | # pitch, yaw, roll 116 | R_x = np.array([[1, 0, 0], 117 | [0, math.cos(theta[0]), -math.sin(theta[0])], 118 | [0, math.sin(theta[0]), math.cos(theta[0])] 119 | ]) 120 | 121 | R_y = np.array([[math.cos(theta[1]), 0, math.sin(theta[1])], 122 | [0, 1, 0], 123 | [-math.sin(theta[1]), 0, math.cos(theta[1])] 124 | ]) 125 | 126 | R_z = np.array([[math.cos(theta[2]), -math.sin(theta[2]), 0], 127 | [math.sin(theta[2]), math.cos(theta[2]), 0], 128 | [0, 0, 1] 129 | ]) 130 | R = np.dot(R_z, np.dot(R_y, R_x)) 131 | 132 | return R 133 | 134 | 135 | def view_matrix_fps(pitch, yaw, eye, project_eye=False): 136 | """ 137 | Following https://www.3dgep.com/understanding-the-view-matrix/#The_View_Matrix 138 | Right-handed coordinate system 139 | if project_eye: the eye will be dotted with the axis, otherwise not. 140 | 141 | """ 142 | sp = np.sin(pitch) 143 | sy = np.sin(yaw) 144 | cp = np.cos(pitch) 145 | cy = np.cos(yaw) 146 | x_axis = [cy, 0, -sy] 147 | y_axis = [sy * sp, cp, cy * sp] 148 | z_axis = [sy * cp, -sp, cp * cy] 149 | if project_eye: 150 | translation = [-np.dot(x_axis, eye), -np.dot(y_axis, eye), -np.dot(z_axis, eye)] 151 | else: 152 | translation = eye 153 | V = np.array([x_axis, y_axis, z_axis, translation]) # 4x3 154 | V = V.T # 3x4 155 | V = np.vstack((V, np.array([0, 0, 0, 1.0]).reshape(1, 4))) 156 | return V 157 | 158 | 159 | def fov2focal(fov, image_size, degree=True): 160 | """ 161 | Computes focal length from similarity of triangles 162 | Args: 163 | fov: angle of view in degrees 164 | image_size: in pixels or any metric unit 165 | degree: True if fov is in degrees, else false 166 | 167 | Returns: focal length in the same unit as image_size 168 | 169 | """ 170 | A = image_size / 2.0 # Half the image size 171 | a = fov / 2.0 # Half the fov angle 172 | if degree: 173 | a = np.deg2rad(a) 174 | f = A / np.tan(a) # numpy expects angles in radians 175 | return f 176 | 177 | 178 | def create_cam2world_matrix(forward_vector, origin): 179 | # Adapted from pigan repo 180 | """Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix.""" 181 | def normalize_vec(v): 182 | return v / np.linalg.norm(v) 183 | forward_vector = normalize_vec(forward_vector) 184 | up_vector = np.array([0, 1.0, 0]) 185 | 186 | left_vector = normalize_vec(np.cross(up_vector, forward_vector, axis=-1)) 187 | 188 | up_vector = normalize_vec(np.cross(forward_vector, left_vector, axis=-1)) 189 | 190 | rotation_matrix = np.eye(4) 191 | rotation_matrix[:3, :3] = np.stack((-left_vector, up_vector, -forward_vector), axis=-1) 192 | 193 | translation_matrix = np.eye(4) 194 | translation_matrix[:3, 3] = origin 195 | 196 | cam2world = translation_matrix @ rotation_matrix 197 | 198 | return cam2world 199 | 200 | 201 | def project_mesh(vertices, img, cam, color=[255, 0, 0]): 202 | """ 203 | 204 | Args: 205 | vertices: Nx3 np.array 206 | img: PIL.Image (np.array also works) 207 | cam: K, P: K is the 3x3 camera matrix (maps to camera space; focal, princiapl point, skew), P is the 4x4 matrix (R|T) and maps to the vertex space (will be inverted) 208 | color: 209 | 210 | Returns: PIL.Image with projected vertices 211 | 212 | """ 213 | from PIL import Image 214 | # Adapted from https://github.com/CrisalixSA/h3ds 215 | 216 | # Expand cam attributes 217 | K, P = cam # K is the 3x3 camera matrix (focal, principal point, skew), P is the 4x4 projection matrix (R|T) 218 | # P projects from camera to 
world space, so we need the inverse 219 | P_inv = np.linalg.inv(P) 220 | 221 | # Project mesh vertices into 2D 222 | p3d_h = np.hstack((vertices, np.ones((vertices.shape[0], 1)))) # Homogeneous 223 | p2d_h = (K @ P_inv[:3, :] @ p3d_h.T).T # Apply the camera and the world^-1 projection. Transpose result 224 | p2d = p2d_h[:, :-1] / p2d_h[:, -1:] # Divide by Z to project to image plane 225 | # The p2d now contains the indices of the values as pixels 226 | 227 | # Draw p2d to image 228 | img_proj = np.array(img) 229 | p2d = np.clip(p2d, 0, img.width - 1).astype(np.uint32) # Discretize them 230 | img_proj[p2d[:, 1], p2d[:, 0]] = color 231 | 232 | return Image.fromarray(img_proj.astype(np.uint8)) 233 | 234 | 235 | def points2homo(points, times=1): 236 | if len(points.shape) == 2: 237 | points_h = np.hstack((points, np.ones((points.shape[0], times)))) # N,4 238 | elif len(points.shape) == 3: 239 | # Has a batch dimension 240 | points_h = np.dstack((points, np.ones((points.shape[0], points.shape[1], times)))) 241 | else: 242 | raise Warning("Invalid shape") 243 | return points_h 244 | 245 | 246 | def scaling_transform(s): 247 | scale_transform = np.zeros((4,4)) 248 | scale_transform[0,0] = s 249 | scale_transform[1,1] = s 250 | scale_transform[2,2] = s 251 | scale_transform[3,3] = 1 252 | return scale_transform -------------------------------------------------------------------------------- /varitex/custom_callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mcbuehler/VariTex/a25979f96d600ed839799450958ace8952ce375a/varitex/custom_callbacks/__init__.py -------------------------------------------------------------------------------- /varitex/custom_callbacks/callbacks.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | 4 | from varitex.data.keys_enum import DataItemKey as DIK 5 | from varitex.data.uv_factory import BFMUVFactory 6 | from varitex.visualization.batch import CombinedVisualizer, SampledVisualizer, UVVisualizer, InterpolationVisualizer, \ 7 | NeuralTextureVisualizer 8 | 9 | 10 | class ImageLogCallback(pl.Callback): 11 | """ 12 | Logs a variety of visualizations during training. 
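    Covers combined outputs, newly sampled outputs, latent interpolations, neural
    textures, and re-posed outputs (see log_batch). Images are sent to the configured
    logger ("tensorboard" or "wandb") via log_image.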
13 | """ 14 | 15 | def __init__(self, opt, use_bfm_gpu=False): 16 | self.MASK_VALUE = opt.uv_mask_value 17 | self.opt = opt 18 | 19 | uv_factory_bfm = BFMUVFactory(opt, use_bfm_gpu) 20 | 21 | self.visualizer_combined = CombinedVisualizer(opt, mask_value=self.MASK_VALUE) 22 | self.visualizer_sampled = SampledVisualizer(opt, n_samples=3, mask_value=self.MASK_VALUE) 23 | self.visualizer_uv = UVVisualizer(opt, bfm_uv_factory=uv_factory_bfm, mask_value=self.MASK_VALUE) 24 | self.visualizer_interpolation = InterpolationVisualizer(opt, mask_value=self.MASK_VALUE) 25 | self.visualizer_neural_texture = NeuralTextureVisualizer(opt) 26 | 27 | def log_image(self, logger, key, vis, step): 28 | if self.opt.logger == "tensorboard": 29 | logger.experiment.add_image(key, vis, step) 30 | elif self.opt.logger == "wandb": 31 | import wandb 32 | logger.experiment.log( 33 | {key: wandb.Image(vis)} 34 | ) 35 | 36 | def _combined(self, pl_module, batch, prefix): 37 | vis_combined = self.visualizer_combined.visualize(batch) 38 | self.log_image(pl_module.logger, "{}/combined".format(prefix), vis_combined, pl_module.global_step) 39 | 40 | def _sampled(self, pl_module, batch, batch_idx, prefix): 41 | vis_sampled = self.visualizer_sampled.visualize(batch, batch_idx, pl_module, std_multiplier=3) 42 | self.log_image(pl_module.logger, "{}/outputs_sampled".format(prefix), vis_sampled, pl_module.global_step) 43 | 44 | vis_sampled = self.visualizer_sampled.visualize_unseen(batch, batch_idx, pl_module, std_multiplier=2) 45 | self.log_image(pl_module.logger, "{}/outputs_sampled_gaussian_std2".format(prefix), vis_sampled, 46 | pl_module.global_step) 47 | 48 | def _posed(self, pl_module, batch, batch_idx, prefix): 49 | deg_range = torch.arange(-45, 45 + 1, 30) 50 | vis = self.visualizer_uv.visualize_grid(pl_module, batch, batch_idx, deg_range) 51 | self.log_image(pl_module.logger, "{}/outputs_posed".format(prefix), vis, pl_module.global_step) 52 | 53 | def _interpolated(self, pl_module, batch, batch_idx, prefix, n=5): 54 | batch2 = pl_module.forward_sample_style(batch.copy(), batch_idx, std_multiplier=4) # Random new style code 55 | vis = self.visualizer_interpolation.visualize(pl_module, batch, batch2, n, bidirectional=False, 56 | include_gt=False) 57 | self.log_image(pl_module.logger, "{}/interpolation/random_std2".format(prefix), vis, pl_module.global_step) 58 | 59 | batch2 = batch.copy() 60 | batch2[DIK.STYLE_LATENT] = torch.zeros_like(batch[DIK.STYLE_LATENT]).to(batch[DIK.STYLE_LATENT].device) 61 | vis = self.visualizer_interpolation.visualize(pl_module, batch, batch2, n, bidirectional=False, 62 | include_gt=False) 63 | self.log_image(pl_module.logger, "{}/interpolation/zeros".format(prefix), vis, pl_module.global_step) 64 | 65 | batch2 = batch.copy() 66 | batch2[DIK.STYLE_LATENT] = torch.randn_like(batch[DIK.STYLE_LATENT]).to(batch[DIK.STYLE_LATENT].device) 67 | vis = self.visualizer_interpolation.visualize(pl_module, batch, batch2, n, bidirectional=False, 68 | include_gt=False) 69 | self.log_image(pl_module.logger, "{}/interpolation/standard_gaussian".format(prefix), vis, 70 | pl_module.global_step) 71 | 72 | def _neural_texture(self, pl_module, batch, batch_idx, prefix): 73 | vis = self.visualizer_neural_texture.visualize_interior(batch, batch_idx) 74 | self.log_image(pl_module.logger, "{}/neural_texture/interior".format(prefix), vis, 75 | pl_module.global_step) 76 | 77 | vis = self.visualizer_neural_texture.visualize_interior_sampled(batch, batch_idx) 78 | self.log_image(pl_module.logger, 
"{}/neural_texture/interior_sampled".format(prefix), vis, 79 | pl_module.global_step) 80 | 81 | vis = self.visualizer_neural_texture.visualize_exterior_sampled(batch, batch_idx) 82 | self.log_image(pl_module.logger, "{}/neural_texture/exterior_sampled".format(prefix), vis, 83 | pl_module.global_step) 84 | 85 | vis = self.visualizer_neural_texture.visualize_enhanced(batch, batch_idx) 86 | self.log_image(pl_module.logger, "{}/neural_texture/enhanced".format(prefix), vis, 87 | pl_module.global_step) 88 | 89 | def log_batch(self, pl_module, batch, batch_idx, prefix): 90 | self._combined(pl_module, batch, prefix) 91 | self._sampled(pl_module, batch, batch_idx, prefix) 92 | self._interpolated(pl_module, batch, batch_idx, prefix) 93 | self._neural_texture(pl_module, batch, batch_idx, prefix) 94 | self._posed(pl_module, batch, batch_idx, prefix) 95 | 96 | def batch2gpu(self, batch): 97 | for k, v in batch.items(): 98 | if isinstance(v, torch.Tensor): 99 | batch[k] = v.cuda() 100 | return batch 101 | 102 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 103 | if batch_idx % self.opt.display_freq == 0: 104 | batch = pl_module(batch, batch_idx) 105 | self.log_batch(pl_module, batch, batch_idx, "train") 106 | 107 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 108 | if batch_idx == 1: 109 | batch = pl_module(batch, batch_idx) 110 | self.log_batch(pl_module, batch, batch_idx, "val") 111 | -------------------------------------------------------------------------------- /varitex/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mcbuehler/VariTex/a25979f96d600ed839799450958ace8952ce375a/varitex/data/__init__.py -------------------------------------------------------------------------------- /varitex/data/augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Random affine transforms used during training. 3 | """ 4 | import PIL.Image 5 | import numpy 6 | import torchvision.transforms.functional as F 7 | from torchvision.transforms import RandomAffine 8 | 9 | 10 | class CustomRandomAffine(RandomAffine): 11 | def __init__(self, img_size, flip_p=0, *args, **kwargs): 12 | if not "degrees" in kwargs: 13 | kwargs["degrees"] = 0 14 | super().__init__(*args, **kwargs) 15 | self.ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) 16 | self.flip = numpy.random.rand(1) < flip_p 17 | 18 | def to_pil(self, img): 19 | return PIL.Image.fromarray(img) 20 | 21 | def __call__(self, img): 22 | """ 23 | img (PIL Image or Tensor): Image to be transformed. 24 | 25 | Returns: 26 | PIL Image or Tensor: Affine transformed image. 27 | """ 28 | is_uv = img.shape[-1] == 2 29 | if not isinstance(img, PIL.Image.Image): 30 | if is_uv: 31 | # We convert the UV to a PIL Image such that we can apply the affine transform. 32 | img = numpy.concatenate([img, numpy.zeros((*img.shape[:2], 1))], -1) 33 | img = (img + 1) * 127.5 34 | img = img.astype(numpy.uint8) 35 | img = self.to_pil(img) 36 | 37 | if self.flip: 38 | img = F.hflip(img) 39 | 40 | transformed = F.affine(img, *self.ret, 41 | fill=self.fill) 42 | result = numpy.array(transformed) 43 | if is_uv: 44 | # Convert back to UV range. 
45 | result = result[:, :, :2] 46 | result = result / 127.5 - 1 47 | return result 48 | -------------------------------------------------------------------------------- /varitex/data/custom_dataset.py: -------------------------------------------------------------------------------- 1 | from varitex.data.dataset_specifics import FFHQ 2 | 3 | 4 | class CustomDataset: 5 | data = None 6 | indices = None 7 | N = None 8 | 9 | def __init__(self, opt, split=None, augmentation=None): 10 | self.opt = opt 11 | self.split = split if split is not None else self.opt.dataset_split 12 | self.augmentation = augmentation if augmentation is not None else self.opt.augmentation 13 | 14 | if self.opt.dataset.lower() == 'ffhq': 15 | self.initial_height, self.initial_width = FFHQ.image_height, FFHQ.image_width 16 | if self.augmentation: 17 | self.transform_params = FFHQ.get_transform_params(opt.transform_mode) 18 | else: 19 | raise NotImplementedError("Not implemented dataset '{}'".format(self.opt.dataset)) 20 | -------------------------------------------------------------------------------- /varitex/data/dataset_specifics.py: -------------------------------------------------------------------------------- 1 | from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform 2 | 3 | 4 | class FFHQ: 5 | # We load the inital images with this size. It should be the same as the predicted segmentation masks. 6 | # The original FFHQ images have resolution 1024x1024. 7 | image_width = 512 8 | image_height = 512 9 | 10 | transform_params = dict( 11 | degrees=15, 12 | translate=(0.2, 0.2), 13 | scale=(1, 1.2), 14 | flip_p=0.5, 15 | ) 16 | transform_params_light = dict( 17 | degrees=5, 18 | translate=(0.1, 0.1), 19 | scale=(1, 1.2), 20 | flip_p=0.5, 21 | ) 22 | 23 | @classmethod 24 | def get_transform_params(cls, mode): 25 | keys = [] 26 | if mode == "all": 27 | keys = cls.transform_params.keys() 28 | else: 29 | if "d" in mode: 30 | keys.append("degrees") 31 | if "t" in mode: 32 | keys.append("translate") 33 | if "s" in mode: 34 | keys.append("scale") 35 | if "f" in mode: 36 | keys.append("flip_p") 37 | params = {k: v for k, v in cls.transform_params.items() if k in keys} 38 | return params 39 | 40 | 41 | class Camera: 42 | @staticmethod 43 | def get_camera(): 44 | # Camera is at the origin, looking at the negative z axis. 
45 | R, T = look_at_view_transform(eye=((0, 0, 0),), at=((0, 0, -1),), up=((0, 1, 0),)) 46 | cameras = FoVPerspectiveCameras(device='cuda', R=R, T=T, fov=30, zfar=1000) 47 | return cameras 48 | -------------------------------------------------------------------------------- /varitex/data/keys_enum.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class DataItemKey(Enum): 5 | IMAGE_IN = 1 6 | IMAGE_ENCODED = 2 7 | STYLE_LATENT = 3 8 | TEXTURE_PRIOR = 4 9 | TEXTURE_PERSON = 5 10 | FACE_FEATUREIMAGE = 6 11 | ADDITIVE_FEATUREIMAGE = 7 12 | IMAGE_OUT = 8 13 | UV_RENDERED = 9 14 | MASK_UV = 10 15 | STYLE_LATENT_MU = 12 16 | STYLE_LATENT_STD = 13 17 | FILENAME = 14 18 | COEFF_SHAPE = 15 19 | COEFF_EXPRESSION = 16 20 | SEGMENTATION_MASK = 17 21 | SEGMENTATION_PREDICTED = 18 22 | FULL_FEATUREIMAGE = 19 23 | MASK_FULL = 21 24 | IMAGE_IN_ENCODE = 26 25 | LATENT_INTERIOR = 27 26 | LATENT_EXTERIOR = 28 27 | R = 90 28 | T = 91 29 | SCALE = 92 30 | 31 | def __str__(self): 32 | return str(self._name_).lower() 33 | -------------------------------------------------------------------------------- /varitex/data/npy_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from mutil.pytorch_utils import ImageNetNormalizeTransformForward 7 | from scipy import ndimage 8 | from skimage.morphology import convex_hull_image 9 | from torchvision import transforms 10 | 11 | from varitex.data.augmentation import CustomRandomAffine 12 | from varitex.data.keys_enum import DataItemKey as DIK 13 | from varitex.data.custom_dataset import CustomDataset 14 | 15 | 16 | class NPYDataset(CustomDataset): 17 | # Rotation, translation, scale, shape, expression, segmentation, uv 18 | keys = ["R", "t", "s", "sp", "ep", "segmentation", "uv"] 19 | 20 | def __init__(self, opt, *args, **kwargs): 21 | super().__init__(opt, *args, **kwargs) 22 | # We need to call this to set self.__len()__ 23 | # h5py can only be read from multiple workers if the file is opened in sub-processes, not threads. 24 | # We need to open it later again for the __getitem()__ 25 | self.image_folder = self.opt.image_folder 26 | 27 | self.dataroot_npy = opt.dataroot_npy 28 | self.data = { 29 | key: np.load(os.path.join(self.dataroot_npy, "{}.npy".format(key)), 'r') for key in self.keys 30 | } 31 | 32 | # These are strings 33 | self.data["filename"] = np.load(os.path.join(self.dataroot_npy, "filename.npy"), allow_pickle=True) 34 | self.data["dataset_splits"] = np.load(os.path.join(self.dataroot_npy, "dataset_splits.npz"), allow_pickle=True) 35 | 36 | self.set_indices() 37 | 38 | self.N = len(self.indices) 39 | 40 | assert self.opt.image_h == self.opt.image_w, "Only square images supported!" 
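The transform_mode string consumed by FFHQ.get_transform_params in varitex/data/dataset_specifics.py above selects which augmentation parameters are active, and DataItemKey entries stringify to lowercase names. A quick sketch, assuming the repo is on the PYTHONPATH and its dependencies (dataset_specifics.py imports pytorch3d at the top) are installed:

from varitex.data.dataset_specifics import FFHQ
from varitex.data.keys_enum import DataItemKey as DIK

print(FFHQ.get_transform_params("all"))  # degrees, translate, scale, flip_p
print(FFHQ.get_transform_params("ts"))   # only translate and scale
print(FFHQ.get_transform_params("f"))    # only horizontal flips (flip_p)

print(str(DIK.IMAGE_IN), str(DIK.UV_RENDERED))  # image_in uv_rendered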
41 | 42 | def set_indices(self, indices=None): 43 | # We can provide indices or load them from h5 44 | if indices is None: 45 | if self.split == "all": 46 | self.indices = list(range(self.data["filename"].shape[0])) 47 | elif self.split in self.data["dataset_splits"].keys(): 48 | self.indices = self.data["dataset_splits"][self.split][:] 49 | else: 50 | raise Warning("Invalid split: {}".format(self.split)) 51 | else: 52 | self.indices = indices 53 | 54 | @classmethod 55 | def _center_crop(cls, image, square_dim): 56 | if square_dim is None: 57 | square_dim = min(image.shape[:2]) 58 | new_height, new_width = square_dim, square_dim 59 | 60 | # width, height = image.size # Get dimensions 61 | height, width = image.shape[:2] 62 | center_h = height // 2 63 | center_w = width // 2 64 | 65 | left = center_w - new_width // 2 66 | top = center_h - new_height // 2 67 | right = left + new_width 68 | bottom = top + new_height 69 | 70 | # Crop the center of the image 71 | image = image[top:bottom, left:right] 72 | return image 73 | 74 | @classmethod 75 | def _apply_transforms(cls, ndarray_in, height, width, interpolation_nearest=False, affine_transform=None, 76 | center_crop_dim=None): 77 | # Semantic masks should use 'nearest' interpolation. 78 | interpolation = cv2.INTER_NEAREST if interpolation_nearest else cv2.INTER_LINEAR 79 | 80 | if affine_transform is not None: 81 | ndarray = affine_transform(ndarray_in) 82 | else: 83 | ndarray = ndarray_in 84 | 85 | ndarray = cls._center_crop(ndarray, square_dim=center_crop_dim) 86 | 87 | if ndarray.shape[0] != height or ndarray.shape[1] != width: 88 | ndarray = cv2.resize(ndarray, (width, height), interpolation) 89 | return ndarray 90 | 91 | @classmethod 92 | def preprocess_image(cls, img, height, width, affine_transform=None, center_crop_dim=None): 93 | img = cls._apply_transforms(img, height, width, affine_transform=affine_transform, 94 | center_crop_dim=center_crop_dim) 95 | 96 | all_transforms = [ 97 | transforms.ToTensor(), # Converts image to range [0, 1] 98 | ImageNetNormalizeTransformForward() # Uses ImageNet means and std 99 | ] 100 | 101 | img_tensor = transforms.Compose(all_transforms)(img) 102 | return img_tensor # img_tensor can be < 1 and > 1, but is roughly centered at 0 103 | 104 | def preprocess_segmentation(self, segmentation, height, width, affine_transform=None, mask=None, 105 | center_crop_dim=None): 106 | segmentation_np = np.zeros(segmentation.shape) 107 | for region_idx in self.opt.semantic_regions: 108 | segmentation_np[segmentation == region_idx] = 1 109 | # We now have a binary mask with 1s for all regions that we specified 110 | 111 | segmentation = self._apply_transforms(segmentation_np, height, width, interpolation_nearest=True, 112 | affine_transform=affine_transform, center_crop_dim=center_crop_dim) 113 | 114 | segmentation_tensor = torch.from_numpy(segmentation) 115 | segmentation = segmentation_tensor.unsqueeze(0) 116 | 117 | if mask is not None: 118 | segmentation[mask] = 0 119 | 120 | return segmentation.float() 121 | 122 | @classmethod 123 | def preprocess_uv(cls, uv, height, width, affine_transform=None, center_crop_dim=None): 124 | uv = cls._apply_transforms(uv, height, width, interpolation_nearest=True, affine_transform=affine_transform, 125 | center_crop_dim=center_crop_dim) 126 | # same as ToTensor, but it's better to keep this explicit 127 | uv_tensor = torch.from_numpy(uv) 128 | 129 | if not (-1 <= uv_tensor.min() and uv_tensor.max() <= 1): 130 | raise ValueError("UV not in range [-1, 1]! 
min: {} max: {}".format(uv.min(), uv.max())) 131 | 132 | return uv_tensor.float() 133 | 134 | @classmethod 135 | def preprocess_mask(cls, uv, height, width, affine_transform=None): 136 | # We compute the mask from the UV (invalid values will be marked as -1) 137 | uv = cls.preprocess_uv(uv, height, width, affine_transform) 138 | mask_tensor = torch.logical_and((uv[:, :, 0] != -1), (uv[:, :, 1] != -1)) 139 | return mask_tensor.unsqueeze(0) 140 | 141 | @staticmethod 142 | def convex_hull_mask_tensor(uv_mask, segmentations): 143 | mask = uv_mask 144 | for semantic_mask in segmentations: 145 | mask = mask.logical_or(semantic_mask.bool()) 146 | 147 | mask_np = mask.squeeze().detach().cpu().numpy() 148 | mask_image = convex_hull_image(mask_np) 149 | mask_image_tensor = torch.from_numpy(mask_image).unsqueeze(0) 150 | return mask_image_tensor 151 | 152 | def full_mask(self, uv_mask, segmentation_mask): 153 | full_mask = uv_mask.logical_or(segmentation_mask) 154 | 155 | mask_image = ndimage.binary_fill_holes(full_mask.squeeze().numpy()) 156 | mask_image_tensor = torch.from_numpy(mask_image).unsqueeze(0) 157 | return mask_image_tensor 158 | 159 | def preprocess_expressions(self, frame_id): 160 | return torch.tensor(self.data['ep'][frame_id]) 161 | 162 | def _read_image(self, filename, size): 163 | path_image = os.path.join(self.image_folder, "{}.png".format(filename)) 164 | if not os.path.exists(path_image): 165 | raise FileNotFoundError("This image has not been found: '{}'".format(path_image)) 166 | image_bgr = cv2.imread(path_image) 167 | image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) 168 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR) 169 | return image 170 | 171 | def __getitem__(self, index): 172 | frame_id = self.indices[index] # We keep the dataset splits loaded as indices 173 | filename = self.data["filename"][frame_id] 174 | height, width = self.opt.image_h, self.opt.image_w 175 | 176 | if self.augmentation: 177 | # Create a random affine transform for each iteration, but use the same transform for all in- and outputs 178 | affine_transform = CustomRandomAffine([self.initial_width, self.initial_height], **self.transform_params) 179 | else: 180 | affine_transform = None 181 | try: 182 | image_raw = self._read_image(filename, (self.initial_width, self.initial_height)) 183 | uv = self.data["uv"][frame_id].astype(np.float32) # opencv expects float 32 184 | segmentation = self.data["segmentation"][frame_id] 185 | 186 | img_tensor_clean = self.preprocess_image(image_raw.copy(), height, width) 187 | img_tensor = self.preprocess_image(image_raw.copy(), height, width, affine_transform) 188 | uv_tensor = self.preprocess_uv(uv.copy(), height, width, affine_transform) 189 | mask_tensor = self.preprocess_mask(uv.copy(), height, width, 190 | affine_transform) # We compute the mask from the uv 191 | 192 | segmentation_tensor = self.preprocess_segmentation(segmentation.copy(), height, width, affine_transform) 193 | 194 | mask_full_tensor = self.full_mask(mask_tensor, segmentation_tensor) 195 | 196 | if not self.opt.keep_background: 197 | # We remove the background to focus the network capacity on the relevant regions 198 | if self.opt.bg_color == 'black': 199 | img_tensor[~mask_full_tensor.expand_as(img_tensor)] = img_tensor.min() 200 | else: 201 | img_tensor[~mask_full_tensor.expand_as(img_tensor)] = 0.0 202 | 203 | # We need to compute the full mask on the non augmented variant such that we can mask the encoding image 204 | mask_tensor_clean = self.preprocess_mask(uv.copy(), height, 
width) # We compute the mask from the uv 205 | segmentation_tensor_clean = self.preprocess_segmentation(segmentation.copy(), 206 | height, width) 207 | mask_full_tensor_clean = self.full_mask(mask_tensor_clean, segmentation_tensor_clean) 208 | if self.opt.bg_color == 'black': 209 | img_tensor_clean[~mask_full_tensor_clean.expand_as(img_tensor_clean)] = img_tensor_clean.min() 210 | else: 211 | img_tensor_clean[~mask_full_tensor_clean.expand_as(img_tensor_clean)] = 0 212 | # 213 | return { 214 | DIK.IMAGE_IN_ENCODE: img_tensor_clean, 215 | DIK.IMAGE_IN: img_tensor, 216 | DIK.SEGMENTATION_MASK: segmentation_tensor, 217 | DIK.UV_RENDERED: uv_tensor, 218 | DIK.MASK_UV: mask_tensor, 219 | DIK.MASK_FULL: mask_full_tensor, # Used to mask input before encoding 220 | DIK.FILENAME: filename, 221 | DIK.COEFF_SHAPE: self.data["sp"][frame_id].copy().astype(np.float32), 222 | DIK.COEFF_EXPRESSION: self.data["ep"][frame_id].copy().astype(np.float32), 223 | DIK.R: torch.from_numpy(self.data["R"][frame_id].copy()).float(), 224 | DIK.T: torch.from_numpy(self.data["t"][frame_id].copy()).float(), 225 | DIK.SCALE: torch.from_numpy(self.data["s"][frame_id].copy()).float() 226 | } 227 | except ValueError as e: 228 | print("Value error when processing index {}".format(index)) 229 | print(e) 230 | exit(0) 231 | 232 | def __len__(self): 233 | return self.N 234 | 235 | def get_unsqueezed(self, index, device='cuda'): 236 | batch = self[index] 237 | batch_unsqueeze = {} 238 | for k, v in batch.items(): 239 | if hasattr(v, "unsqueeze"): 240 | batch_unsqueeze[k] = v.unsqueeze(0).to(device) 241 | elif k == DIK.FILENAME: 242 | batch_unsqueeze[k] = [v] 243 | else: 244 | batch_unsqueeze[k] = torch.Tensor(v).unsqueeze(0).to(device) 245 | 246 | return batch_unsqueeze 247 | 248 | def get_raw_image(self, index): 249 | frame_id = self.indices[index] # We keep the dataset splits loaded as indices 250 | filename = self.data["filename"][frame_id] 251 | image_raw = self._read_image(filename, (self.initial_width, self.initial_height)) 252 | 253 | height, width = self.opt.image_h, self.opt.image_w 254 | return self._apply_transforms(image_raw, height, width) 255 | -------------------------------------------------------------------------------- /varitex/data/uv_factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | import torch 5 | from mutil.bfm2017 import BFM2017Tensor, BFM2017 6 | from mutil.pytorch_utils import to_tensor, theta2rotation_matrix 7 | from mutil.renderer import UVRenderer 8 | from mutil.threed_utils import apply_Rts 9 | from pytorch3d.structures import Meshes 10 | 11 | from varitex.data.dataset_specifics import Camera 12 | from varitex.modules.custom_module import CustomModule 13 | 14 | 15 | class BFMUVFactory(CustomModule): 16 | """ 17 | facemodels_path = ... 
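The mask helpers in npy_dataset.py above derive the UV mask from pixels whose rendered UV coordinates are not -1 (preprocess_mask) and then union it with the segmentation before hole filling (full_mask). A tiny self-contained sketch of the same rules on toy arrays:

import numpy as np
import torch
from scipy import ndimage

# Toy 6x6 UV map; background pixels carry -1 in both channels.
uv = -np.ones((6, 6, 2), dtype=np.float32)
uv[2:5, 2:5] = 0.3                      # a small rendered "face" patch with valid UVs
uv_t = torch.from_numpy(uv)

# Same rule as preprocess_mask: valid where neither channel equals -1.
mask_uv = torch.logical_and(uv_t[:, :, 0] != -1, uv_t[:, :, 1] != -1).unsqueeze(0)

# A toy segmentation (e.g. hair) partially overlapping the face patch.
segmentation = torch.zeros(1, 6, 6)
segmentation[0, 1:3, 1:4] = 1

# Same rule as full_mask: union of both masks, then fill enclosed holes.
full = mask_uv.logical_or(segmentation.bool())
full = torch.from_numpy(ndimage.binary_fill_holes(full.squeeze().numpy())).unsqueeze(0)
print(full.int())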
18 | path_bfm = os.path.join(facemodels_path, "basel_facemodel/model2017-1_face12_nomouth.h5") 19 | path_uv = os.path.join(facemodels_path, "basel_facemodel/face12.json") 20 | factory = BFMUVFactory(path_bfm, path_uv, image_size=256) 21 | """ 22 | 23 | def __init__(self, opt, use_bfm_gpu=False): 24 | super().__init__(opt) 25 | self.to(opt.device) 26 | 27 | # This factory uses the Basel Face Model 28 | self.use_bfm_gpu = use_bfm_gpu # Loads the BFM into GPU for faster mesh generation 29 | if self.use_bfm_gpu: 30 | self.bfm = BFM2017Tensor(self.opt.path_bfm, device=opt.device) 31 | else: 32 | self.bfm = BFM2017(self.opt.path_bfm, self.opt.path_uv) 33 | self.faces_tensor = to_tensor(self.bfm.faces, self.device).unsqueeze(0) 34 | 35 | verts_uv, faces_uv = self.load_verts_faces_uv(opt.path_uv) 36 | self.uv_renderer = UVRenderer(verts_uv, faces_uv, image_size=self.opt.image_h, cameras=Camera.get_camera()) 37 | 38 | def unsqueeze_transform(self, R, t, s): 39 | if len(R.shape) == 2: 40 | R, t, s = R.unsqueeze(0), t.unsqueeze(0), s.unsqueeze(0) 41 | if len(t.shape) == 1: 42 | t = t.unsqueeze(0) 43 | if len(s.shape) == 1: 44 | s = s.unsqueeze(0) 45 | R, t, s = R.float().to(self.device), t.float().to(self.device), s.float().to(self.device) 46 | return R, t, s 47 | 48 | def load_verts_faces_uv(self, json_path): 49 | with open(json_path, 'r') as f: 50 | uv_para = json.load(f) 51 | verts_uvs = np.array(uv_para['textureMapping']['pointData']) 52 | faces_uvs = np.array(uv_para['textureMapping']['triangles']) 53 | 54 | verts_uvs = to_tensor(verts_uvs, self.device, dtype=float).unsqueeze(0).float() 55 | faces_uvs = to_tensor(faces_uvs, self.device, dtype=float).unsqueeze(0).long() 56 | 57 | return verts_uvs, faces_uvs 58 | 59 | def generate_vertices(self, sp, ep): 60 | # sp, ep need to have a batch dimension 61 | batch_size = ep.shape[0] 62 | if self.use_bfm_gpu: 63 | vertices = [self.bfm.generate_vertices(sp[i], ep[i]) for i in range(batch_size)] 64 | else: 65 | vertices = [self.bfm.generate_vertices(sp[i].detach().cpu().numpy(), ep[i].detach().cpu().numpy()) for i in 66 | range(batch_size)] 67 | vertices = [to_tensor(v, self.device) for v in vertices] 68 | vertices = torch.stack(vertices) 69 | return vertices 70 | 71 | def get_posed_meshes(self, sp, ep, R, t, s, batch_size, correct_translation): 72 | # Generate the vertices 73 | vertices = self.generate_vertices(sp, ep) 74 | # Match the batch size 75 | faces = self.faces_tensor.expand(batch_size, self.bfm.N_FACES, 3) 76 | vertices = apply_Rts(vertices, R, t, s, correct_tanslation=correct_translation) 77 | meshes = Meshes(verts=vertices, faces=faces) 78 | return meshes 79 | 80 | def getUV(self, R=torch.eye(3).unsqueeze(0), t=torch.Tensor((0, 0, -35)).unsqueeze(0), 81 | s=torch.Tensor((0.1,)).unsqueeze(0), sp=torch.zeros((199,)).unsqueeze(0), 82 | ep=torch.zeros((100,)).unsqueeze(0), correct_translation=True): 83 | R, t, s = self.unsqueeze_transform(R, t, s) 84 | assert len(sp.shape) == 2 and sp.shape[-1] == 199, "Should come in shape (batch_size, 199), but is {}".format( 85 | sp.shape) 86 | assert len(ep.shape) == 2 and ep.shape[-1] == 100, "Should come in shape (batch_size, 100), but is {}".format( 87 | ep.shape) 88 | batch_size = R.shape[0] 89 | meshes = self.get_posed_meshes(sp, ep, R, t, s, batch_size, correct_translation) 90 | uv = self.uv_renderer.render(meshes) 91 | return uv 92 | 93 | def get_sampled_uvs_shape(self, n, std_multiplier=1, ep=np.zeros(100, )): 94 | shapes = self.bfm.sample_shape(n, std_multiplier) 95 | uv_list = [self.getUV(sp, ep) 
for sp in shapes] 96 | return uv_list 97 | 98 | def get_sampled_uvs_expression(self, n, std_multiplier=1, sp=np.zeros(199, )): 99 | expressions = self.bfm.sample_expression(n, std_multiplier) 100 | uv_list = [self.getUV(sp, ep) for ep in expressions] 101 | return uv_list 102 | 103 | def get_posed_uvs(self, sp=torch.zeros((199,)), ep=torch.zeros((100,)), deg_range=torch.arange(-30, 31, 15), 104 | t=torch.Tensor((0, 0, -35)), s=torch.Tensor((0.1,))): 105 | uv_list = list() 106 | for theta_y in deg_range: 107 | for theta_x in deg_range: 108 | theta = theta_x, theta_y, 0 109 | R = theta2rotation_matrix(theta_all=theta) 110 | uv = self.getUV(sp, ep, R, t, s) 111 | uv_list.append(uv) 112 | return uv_list 113 | -------------------------------------------------------------------------------- /varitex/demo.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import torch 3 | 4 | # try: 5 | from mutil.object_dict import ObjectDict 6 | from varitex.data.keys_enum import DataItemKey as DIK 7 | from varitex.data.uv_factory import BFMUVFactory 8 | from varitex.modules.pipeline import PipelineModule 9 | from varitex.visualization.batch import CompleteVisualizer 10 | from varitex.options import varitex_default_options 11 | # except ModuleNotFoundError as e: 12 | # print(e) 13 | # print("Have you added VariTex to your pythonpath?") 14 | # print('To fix this error, go to the root path of the repository ".../VariTex/" \n ' 15 | # 'and run \n' 16 | # "export PYTHONPATH=$PYTHONPATH:$(pwd)") 17 | # exit() 18 | 19 | 20 | class Demo: 21 | def __init__(self, opt): 22 | default_opt = varitex_default_options() 23 | default_opt.update(opt) 24 | self.opt = ObjectDict(default_opt) 25 | 26 | uv_factory_bfm = BFMUVFactory(opt=self.opt, use_bfm_gpu=self.opt.device == 'cuda') 27 | self.visualizer_complete = CompleteVisualizer(opt=self.opt, bfm_uv_factory=uv_factory_bfm) 28 | 29 | self.pipeline = PipelineModule.load_from_checkpoint(self.opt.checkpoint, opt=self.opt, strict=False).to( 30 | self.opt.device).eval() 31 | self.device = self.pipeline.device 32 | 33 | def run(self, z, sp, ep, theta, t=torch.Tensor([0, -2, -57])): 34 | batch = { 35 | DIK.STYLE_LATENT: z, 36 | DIK.COEFF_SHAPE: sp, 37 | DIK.COEFF_EXPRESSION: ep, 38 | DIK.T: t 39 | } 40 | batch = {k: v.to(self.device) for k, v in batch.items()} 41 | 42 | batch = self.visualizer_complete.visualize_single(self.pipeline, batch, 0, theta_all=theta, 43 | forward_type='style2image') 44 | batch = {k: v.detach().cpu() for k, v in batch.items()} 45 | return batch 46 | 47 | def to_image(self, batch_or_batch_list): 48 | if isinstance(batch_or_batch_list, list): 49 | out = torch.cat([batch_out[DIK.IMAGE_OUT][0] for batch_out in batch_or_batch_list], -1) 50 | elif isinstance(batch_or_batch_list, dict): 51 | out = batch_or_batch_list[DIK.IMAGE_OUT][0] 52 | else: 53 | raise Warning("Invalid type: '{}'".format(type(batch_or_batch_list))) 54 | return self.visualizer_complete.tensor2image(out, return_format='pil') 55 | 56 | def to_video(self, batch_list, path_out, fps=15, quality=9, reverse=False): 57 | assert path_out.endswith(".mp4"), "Path should end with .mp4" 58 | frames = [self.to_image(batch) for batch in batch_list] 59 | if reverse: 60 | frames = frames + frames[::-1] 61 | imageio.mimwrite(path_out, frames, fps=fps, quality=quality) 62 | 63 | def load_shape_expressions(self): 64 | import numpy as np 65 | import os 66 | validation_indices = list(np.load(os.path.join(self.opt.dataroot_npy, "dataset_splits.npz"))["val"]) 67 | 
sp = np.load(os.path.join(self.opt.dataroot_npy, "sp.npy"))[validation_indices] 68 | ep = np.load(os.path.join(self.opt.dataroot_npy, "ep.npy"))[validation_indices] 69 | return torch.Tensor(sp), torch.Tensor(ep) 70 | -------------------------------------------------------------------------------- /varitex/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mcbuehler/VariTex/a25979f96d600ed839799450958ace8952ce375a/varitex/evaluation/__init__.py -------------------------------------------------------------------------------- /varitex/evaluation/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from mutil.files import mkdir 6 | from mutil.pytorch_utils import theta2rotation_matrix 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | 10 | from varitex.data.keys_enum import DataItemKey as DIK 11 | from varitex.data.npy_dataset import NPYDataset 12 | from varitex.data.uv_factory import BFMUVFactory 13 | from varitex.modules.pipeline import PipelineModule 14 | from varitex.visualization.batch import Visualizer 15 | 16 | 17 | def get_model(opt): 18 | model = PipelineModule.load_from_checkpoint(opt.checkpoint, opt=opt, strict=False) 19 | model = model.eval() 20 | model = model.cuda() 21 | return model 22 | 23 | 24 | def inference_ffhq(opt, results_folder, n=3000): 25 | """ 26 | Runs inference on FFHQ. Save the resulting images in the results_folder. 27 | Also saves the latent codes and distributions. 28 | """ 29 | print("Running inference on FFHQ. Using the extracted face model parameters and poses, and predicted latent codes.") 30 | mkdir(results_folder) 31 | visualizer = Visualizer(opt, return_format='pil') 32 | 33 | dataset = NPYDataset(opt, augmentation=False, split='val') 34 | dataloader = iter(DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)) 35 | model = get_model(opt) 36 | visualizer.mask_key = DIK.SEGMENTATION_PREDICTED 37 | 38 | result = list() 39 | 40 | for i, batch in tqdm(enumerate(dataloader)): 41 | batch = model.forward(batch, i, std_multiplier=0) 42 | file_id = batch[DIK.FILENAME][0] 43 | latent_code = batch[DIK.STYLE_LATENT_MU].detach().cpu().numpy()[0] 44 | latent_std = batch[DIK.STYLE_LATENT_STD].detach().cpu().numpy()[0] 45 | result.append([latent_code, latent_std]) 46 | 47 | out = visualizer.tensor2image(batch[DIK.IMAGE_OUT][0], batch=batch) 48 | out = visualizer.mask(out, batch, white_bg=False) 49 | 50 | out = visualizer.format_output(out, return_format='pil') 51 | out.save(os.path.join(results_folder, "{}.png".format(file_id))) 52 | 53 | if i >= n: 54 | break 55 | 56 | result = np.array(result) 57 | np.save(os.path.join(results_folder, "latents.npy"), result) 58 | print("Done") 59 | 60 | 61 | def inference_posed(opt, results_folder, path_latent, n=3000): 62 | """ 63 | Runs different poses for each latent code extracted from the FFHQ dataset. 64 | Uses random shape and expression parameters. 65 | """ 66 | 67 | print("Running inference on FFHQ. Using random face model parameters, various poses, and predicted latent codes. This can be slow.") 68 | visualizer = Visualizer(opt, return_format='pil') 69 | visualizer.mask_key = DIK.SEGMENTATION_PREDICTED 70 | 71 | # We use the latent codes extracted from FFHQ, but they could also be sampled from the extracted distributions. 
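The Demo class in varitex/demo.py above wraps the BFM UV factory, the pretrained pipeline, and the visualizer behind a small API. A hedged usage sketch: the checkpoint and output paths are placeholders, a GPU is assumed, and the face model paths are taken from the default options; the theta ordering (pitch, yaw, roll) follows the inference code above.

import torch
from varitex.demo import Demo

demo = Demo({"checkpoint": "PATH/TO/CHECKPOINT.ckpt"})   # placeholder path

z = torch.zeros(1, demo.opt.latent_dim)   # style code; zeros correspond to the prior mean
sp = torch.zeros(1, 199)                  # BFM shape coefficients (mean face)
ep = torch.zeros(1, 100)                  # BFM expression coefficients (neutral)

# Render a small yaw sweep.
batches = [demo.run(z, sp, ep, theta=torch.Tensor([0, yaw, 0])) for yaw in range(-30, 31, 15)]
demo.to_image(batches).save("yaw_sweep.png")             # frames concatenated into one PIL image
demo.to_video(batches, "yaw_sweep.mp4", reverse=True)    # or an mp4 that plays forward and back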
72 | latents = np.load(path_latent)[:, 0] # We want to use the distribution means directly (index 0) 73 | 74 | mkdir(results_folder) 75 | 76 | model = get_model(opt) 77 | 78 | all_pose_axis = [[0], [1], [0, 1]] # List of all axis: 0 is pitch, 1 is yaw 79 | 80 | uv_factory_bfm = BFMUVFactory(opt, use_bfm_gpu=True) 81 | 82 | s = torch.Tensor([0.1]) 83 | t = torch.Tensor([[0, -2, -57]]) 84 | results = list() 85 | all_poses = np.arange(-45, 46, 15) 86 | 87 | with torch.no_grad(): 88 | for i in tqdm(range(n)): 89 | sp = uv_factory_bfm.bfm.sample_shape(1) 90 | ep = uv_factory_bfm.bfm.sample_expression(1) 91 | latent = torch.Tensor([latents[i]]) 92 | 93 | for pose_axis in all_pose_axis: 94 | for pose in all_poses: 95 | theta = torch.Tensor([0, 0, 0]) 96 | theta[pose_axis] = pose 97 | R = theta2rotation_matrix(theta_all=theta).unsqueeze(0) 98 | uv_tensor = uv_factory_bfm.getUV(R, t, s, sp, ep, correct_translation=True) 99 | 100 | batch = { 101 | DIK.UV_RENDERED: uv_tensor, 102 | DIK.R: R, 103 | DIK.T: t, 104 | DIK.SCALE: s, 105 | DIK.STYLE_LATENT: latent 106 | } 107 | batch = {k: v.cuda() for k, v in batch.items()} 108 | batch_out = model.forward_latent2image(batch, 0) 109 | out = visualizer.tensor2image(batch_out[DIK.IMAGE_OUT][0], batch=batch_out, white_bg=False) 110 | results.append(out) 111 | out = visualizer.format_output(out, return_format='pil') 112 | 113 | folder_out = os.path.join(results_folder, "axis{}_theta{}".format(pose_axis, pose)) 114 | mkdir(folder_out) 115 | out.save(os.path.join(results_folder, folder_out, "{:05d}.png".format(i))) 116 | 117 | 118 | def inference_posed_ffhq(opt, results_folder, n=3000): 119 | """ 120 | Runs different poses for each latent code extracted from the FFHQ dataset. 121 | Uses shape and expression parameters from validation set. 122 | """ 123 | print("Running inference on FFHQ. Using the extracted face model parameters, various poses, and predicted latent codes. 
This can be slow.") 124 | mkdir(results_folder) 125 | visualizer = Visualizer(opt, return_format='pil') 126 | 127 | dataset = NPYDataset(opt, augmentation=False, split='val') 128 | dataloader = iter(DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)) 129 | model = get_model(opt) 130 | visualizer.mask_key = DIK.SEGMENTATION_PREDICTED 131 | 132 | n = min(n, 3000) 133 | all_pose_axis = [[0], [1], [0, 1]] 134 | 135 | uv_factory_bfm = BFMUVFactory(opt, use_bfm_gpu=True) 136 | s = torch.Tensor([0.1]) 137 | t = torch.Tensor([[0, -2, -57]]) 138 | 139 | all_poses = np.arange(-45, 46, 15) 140 | 141 | for i, batch_orig in tqdm(enumerate(dataloader)): 142 | batch_orig = model.forward(batch_orig, i, std_multiplier=0) 143 | file_id = batch_orig[DIK.FILENAME][0] 144 | sp = batch_orig[DIK.COEFF_SHAPE].clone() 145 | ep = batch_orig[DIK.COEFF_EXPRESSION].clone() 146 | latent = batch_orig[DIK.STYLE_LATENT].clone() 147 | 148 | for pose_axis in all_pose_axis: 149 | for pose in all_poses: 150 | theta = torch.Tensor([0, 0, 0]) 151 | theta[pose_axis] = pose 152 | R = theta2rotation_matrix(theta_all=theta).unsqueeze(0) 153 | uv_tensor = uv_factory_bfm.getUV(R, t, s, sp, ep, correct_translation=True) 154 | 155 | batch = { 156 | DIK.UV_RENDERED: uv_tensor, 157 | DIK.R: R, 158 | DIK.T: t, 159 | DIK.SCALE: s, 160 | DIK.STYLE_LATENT: latent 161 | } 162 | batch = {k: v.cuda() for k, v in batch.items()} 163 | batch_out = model.forward_latent2image(batch, 0) 164 | out = visualizer.tensor2image(batch_out[DIK.IMAGE_OUT][0], batch=batch_out, white_bg=False) 165 | out = visualizer.format_output(out, return_format='pil') 166 | 167 | folder_out = os.path.join(results_folder, "axis{}_theta{}".format(pose_axis, pose)) 168 | mkdir(folder_out) 169 | out.save(os.path.join(results_folder, folder_out, "{}.png".format(file_id))) 170 | 171 | if i >= n: 172 | break 173 | 174 | print("Done") 175 | -------------------------------------------------------------------------------- /varitex/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | CP=PATH/TO/CHECKPOINT.ckpt 3 | CP=$OP/own/LDS/pretrained/checkpoints/ep44.ckpt 4 | CUDA_VISIBLE_DEVICES=0 python varitex/inference.py --checkpoint $CP --dataset_split val 5 | """ 6 | import os 7 | 8 | import pytorch_lightning as pl 9 | from mutil.object_dict import ObjectDict 10 | 11 | try: 12 | from varitex.options import varitex_default_options 13 | from varitex.evaluation import inference 14 | from varitex.options.eval_options import EvalOptions 15 | except ModuleNotFoundError: 16 | print("Have you added VariTex to your pythonpath?") 17 | print('To fix this error, go to the root path of the repository ".../VariTex/" \n ' 18 | 'and run \n' 19 | "export PYTHONPATH=$PYTHONPATH:$(pwd)") 20 | exit() 21 | 22 | if __name__ == "__main__": 23 | pl.seed_everything(1234) 24 | opt = EvalOptions().parse().__dict__ 25 | 26 | default_opt = varitex_default_options() 27 | default_opt.update(opt) 28 | opt = ObjectDict(default_opt) 29 | assert opt.checkpoint is not None, "Please specify a checkpoint file." 
30 | 31 | checkpoint_folder = os.path.dirname(opt.checkpoint) 32 | 33 | # Runs inference on FFHQ 34 | inference.inference_ffhq(opt, n=30, results_folder=os.path.join(opt.path_out, 'inference_ffhq')) 35 | 36 | # # Runs different poses for each sample in FFHQ (shape and expressions extracted from FFHQ) 37 | inference.inference_posed_ffhq(opt, n=30, results_folder=os.path.join(opt.path_out, 'inference_posed_ffhq')) 38 | # 39 | # # Runs different poses for each sample in FFHQ (random shape and expressions). 40 | # latents.npy should contain the latent distribuations predicted from the holdout set, see inference.inference_ffhq(...). 41 | inference.inference_posed(opt, n=30, results_folder=os.path.join(opt.path_out, 'inference_posed_random'), 42 | path_latent=os.path.join(opt.path_out, 'inference_ffhq/latents.npy')) 43 | -------------------------------------------------------------------------------- /varitex/inference_surface.py: -------------------------------------------------------------------------------- 1 | """ 2 | CP=PATH/TO/CHECKPOINT.ckpt 3 | CP=$OP/own/LDS/pretrained/checkpoints/ep44.ckpt 4 | CUDA_VISIBLE_DEVICES=0 python varitex/inference.py --checkpoint $CP --dataset_split val 5 | """ 6 | import os 7 | 8 | import pytorch_lightning as pl 9 | from mutil.object_dict import ObjectDict 10 | 11 | try: 12 | from varitex.options import varitex_default_options 13 | from varitex.evaluation import inference 14 | from varitex.options.eval_options import EvalOptions 15 | except ModuleNotFoundError: 16 | print("Have you added VariTex to your pythonpath?") 17 | print('To fix this error, go to the root path of the repository ".../VariTex/" \n ' 18 | 'and run \n' 19 | "export PYTHONPATH=$PYTHONPATH:$(pwd)") 20 | exit() 21 | 22 | if __name__ == "__main__": 23 | pl.seed_everything(1234) 24 | opt = EvalOptions().parse().__dict__ 25 | 26 | default_opt = varitex_default_options() 27 | default_opt.update(opt) 28 | opt = ObjectDict(default_opt) 29 | assert opt.checkpoint is not None, "Please specify a checkpoint file." 30 | 31 | checkpoint_folder = os.path.dirname(opt.checkpoint) 32 | 33 | # Runs inference on FFHQ 34 | inference.inference_ffhq(opt, n=30, results_folder=os.path.join(opt.path_out, 'inference_ffhq')) 35 | 36 | # # Runs different poses for each sample in FFHQ (shape and expressions extracted from FFHQ) 37 | inference.inference_posed_ffhq(opt, n=30, results_folder=os.path.join(opt.path_out, 'inference_posed_ffhq')) 38 | # 39 | # # Runs different poses for each sample in FFHQ (random shape and expressions). 40 | # latents.npy should contain the latent distribuations predicted from the holdout set, see inference.inference_ffhq(...). 
41 | inference.inference_posed(opt, n=30, results_folder=os.path.join(opt.path_out, 'inference_posed_random'), 42 | path_latent=os.path.join(opt.path_out, 'inference_ffhq/latents.npy')) 43 | -------------------------------------------------------------------------------- /varitex/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mcbuehler/VariTex/a25979f96d600ed839799450958ace8952ce375a/varitex/modules/__init__.py -------------------------------------------------------------------------------- /varitex/modules/custom_module.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | 3 | 4 | class CustomModule(pl.LightningModule): 5 | def __init__(self, opt): 6 | super().__init__() 7 | self.opt = opt 8 | self.to(opt.device) 9 | -------------------------------------------------------------------------------- /varitex/modules/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pl_bolts.models.autoencoders.components import ( 3 | DecoderBlock, Interpolate, resize_conv1x1, resize_conv3x3 4 | ) 5 | from torch import nn 6 | 7 | from varitex.data.keys_enum import DataItemKey as DIK 8 | from varitex.modules.custom_module import CustomModule 9 | 10 | 11 | class Decoder(CustomModule): 12 | """ 13 | Decoder for the neural face texture. 14 | """ 15 | 16 | def __init__(self, opt): 17 | super().__init__(opt) 18 | 19 | latent_dim_face = opt.latent_dim // 2 20 | 21 | input_height = opt.texture_dim 22 | # Resnet-18 variant 23 | self.decoder = SimpleResNetDecoder(DecoderBlock, [2, 2, 2, 2], latent_dim_face, input_height, 24 | nc_texture=opt.texture_nc) 25 | 26 | def forward(self, batch, batch_idx): 27 | style = batch[DIK.LATENT_INTERIOR] 28 | texture = self.decoder(style) 29 | batch[DIK.TEXTURE_PERSON] = texture 30 | return batch 31 | 32 | 33 | class AdditiveDecoder(CustomModule): 34 | """ 35 | Additive decoder. Produces the additive feature image. 36 | """ 37 | 38 | def __init__(self, opt): 39 | super().__init__(opt) 40 | latent_dim_additive = opt.latent_dim // 2 41 | 42 | input_height = opt.texture_dim 43 | out_dim = opt.texture_nc 44 | condition_dim = opt.texture_nc # We condition on a neural texture for the face interior 45 | # Resnet-18 variant 46 | self.decoder = EarlyConditionedSimpleRestNetDecocer(ConditionedDecoderBlock, [2, 2, 2, 2], latent_dim_additive, 47 | input_height, nc_texture=out_dim, 48 | condition_dim=condition_dim) 49 | 50 | def forward(self, batch, batch_idx): 51 | latent = batch[DIK.LATENT_EXTERIOR] 52 | 53 | # The sampled texture should already have been masked to the face interior 54 | # Errors should not back-propagate through the sampled face interior texture 55 | condition = batch[DIK.FACE_FEATUREIMAGE].detach() 56 | additive_featureimage = self.decoder(latent, condition) 57 | 58 | batch[DIK.ADDITIVE_FEATUREIMAGE] = additive_featureimage # This also has features for the interior 59 | return batch 60 | 61 | 62 | class SimpleResNetDecoder(nn.Module): 63 | """ 64 | Resnet in reverse order. 65 | Most code from pl_bolts.models.autoencoders.components. 
66 | """ 67 | 68 | def __init__(self, block, layers, latent_dim, input_height, nc_texture, first_conv=False, maxpool1=False): 69 | super().__init__() 70 | 71 | self.expansion = block.expansion 72 | self.inplanes = 512 * block.expansion 73 | self.first_conv = first_conv 74 | self.maxpool1 = maxpool1 75 | self.input_height = input_height 76 | 77 | self.upscale_factor = 8 78 | 79 | self.linear = nn.Linear(latent_dim, self.inplanes * 4 * 4) 80 | 81 | self.initial = self._make_layer(block, 256, layers[0], scale=2) 82 | 83 | self.layer0 = self._make_layer(block, 256, layers[0], scale=2) 84 | self.layer1 = self._make_layer(block, 256, layers[0], scale=2) 85 | self.layer2 = self._make_layer(block, 128, layers[1], scale=2) 86 | self.layer3 = self._make_layer(block, 64, layers[2], scale=2) 87 | 88 | if self.input_height == 128: 89 | self.layer4 = self._make_layer(block, 64, layers[3]) 90 | elif self.input_height == 256: 91 | self.layer4 = self._make_layer(block, 64, layers[3], scale=2) 92 | else: 93 | raise Warning("Invalid input height: '{}".format(self.input_height)) 94 | 95 | if self.first_conv: 96 | self.upscale = Interpolate(scale_factor=2) 97 | self.upscale_factor *= 2 98 | else: 99 | self.upscale = Interpolate(scale_factor=1) 100 | 101 | # interpolate after linear layer using scale factor 102 | self.upscale1 = Interpolate(size=input_height // self.upscale_factor) 103 | 104 | self.conv1 = nn.Conv2d( 105 | 64 * block.expansion, nc_texture, kernel_size=3, stride=1, padding=1, bias=False 106 | ) 107 | 108 | def _make_layer(self, block, planes, blocks, scale=1): 109 | """ 110 | 111 | Args: 112 | block: 113 | planes: int number of channels 114 | blocks: int number of blocks (e.g. 2) 115 | scale: 116 | 117 | Returns: 118 | 119 | """ 120 | upsample = None 121 | if scale != 1 or self.inplanes != planes * block.expansion: 122 | upsample = nn.Sequential( 123 | resize_conv1x1(self.inplanes, planes * block.expansion, scale), 124 | nn.BatchNorm2d(planes * block.expansion), 125 | ) 126 | 127 | layers = [] 128 | layers.append(block(self.inplanes, planes, scale, upsample)) 129 | self.inplanes = planes * block.expansion 130 | for _ in range(1, blocks): 131 | layers.append(block(self.inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | x = self.linear(x) 137 | 138 | x = x.view(x.size(0), 512 * self.expansion, 4, 4) 139 | x = self.initial(x) 140 | 141 | x = self.layer0(x) 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | x = self.conv1(x) 148 | return x 149 | 150 | 151 | class EarlyConditionedSimpleRestNetDecocer(nn.Module): 152 | """ 153 | Modified from pl_bolts.models.autoencoders.components. 
154 | """ 155 | 156 | def __init__(self, block, layers, latent_dim, input_height, nc_texture, first_conv=False, maxpool1=False, 157 | condition_dim=16, **kwargs): 158 | # TODO: refactor duplicated code 159 | super().__init__() 160 | 161 | self.condition_dim = condition_dim 162 | 163 | self.expansion = block.expansion 164 | self.inplanes = 512 * block.expansion 165 | self.first_conv = first_conv 166 | self.maxpool1 = maxpool1 167 | self.input_height = input_height 168 | 169 | self.linear = nn.Linear(latent_dim, self.inplanes) 170 | 171 | self.initial = self._make_layer(block, 512, layers[0], scale=2, condition_dim=self.condition_dim) 172 | 173 | self.layer0 = self._make_layer(block, 256, layers[0], scale=2, condition_dim=self.condition_dim) 174 | self.layer1 = self._make_layer(block, 256, layers[0], scale=2, condition_dim=self.condition_dim) 175 | self.layer2 = self._make_layer(block, 128, layers[1], scale=2, condition_dim=self.condition_dim) 176 | self.layer3 = self._make_layer(block, 64, layers[2], scale=2, condition_dim=self.condition_dim) 177 | 178 | if self.input_height == 128: 179 | self.layer4 = self._make_layer(block, 64, layers[3], condition_dim=self.condition_dim) 180 | elif self.input_height == 256: 181 | self.layer4 = self._make_layer(block, 64, layers[3], scale=2, condition_dim=self.condition_dim) 182 | else: 183 | raise Warning("Invalid input height: '{}".format(self.input_height)) 184 | 185 | if self.first_conv: 186 | self.upscale = Interpolate(scale_factor=2) 187 | self.upscale_factor *= 2 188 | 189 | self.conv1 = nn.Conv2d( 190 | 64 * block.expansion, nc_texture, kernel_size=3, stride=1, padding=1, bias=False 191 | ) 192 | 193 | def _make_layer(self, block, planes, blocks, scale=1, condition_dim=16): 194 | """ 195 | 196 | Args: 197 | block: 198 | planes: int number of channels 199 | blocks: int number of blocks (e.g. 2) 200 | scale: 201 | 202 | Returns: 203 | 204 | """ 205 | upsample = None 206 | if scale != 1 or self.inplanes != planes * block.expansion: 207 | upsample = nn.Sequential( 208 | resize_conv1x1(self.inplanes, planes * block.expansion, scale), 209 | nn.BatchNorm2d(planes * block.expansion), 210 | ) 211 | 212 | layers = [] 213 | layers.append(block(self.inplanes, planes, scale, upsample, condition_dim=condition_dim)) 214 | self.inplanes = planes * block.expansion 215 | for _ in range(1, blocks): 216 | layers.append(block(self.inplanes, planes, condition_dim=condition_dim)) 217 | 218 | return nn.Sequential(*layers) 219 | 220 | def forward(self, x, condition): 221 | x = self.linear(x) 222 | # We now have 512 feature maps with the same value in all spatial locations 223 | # self.inplanes changes when creating blocks 224 | x = x.view(x.shape[0], 512, 1, 1).expand(-1, -1, 4, 4) 225 | 226 | in_dict = {"features": x, "condition": condition} 227 | out_dict = self.initial(in_dict) 228 | 229 | out_dict = self.layer0(out_dict) 230 | out_dict = self.layer1(out_dict) 231 | out_dict = self.layer2(out_dict) 232 | out_dict = self.layer3(out_dict) 233 | out_dict = self.layer4(out_dict) 234 | x = out_dict['features'] 235 | x = self.conv1(x) 236 | return x 237 | 238 | 239 | class ConditionedDecoderBlock(nn.Module): 240 | """ 241 | ResNet block, but convs replaced with resize convs, and channel increase is in 242 | second conv, not first. 243 | Also heavily borrowed from pl_bolts.models.autoencoders.components. 
244 | """ 245 | 246 | expansion = 1 247 | 248 | def __init__(self, inplanes, planes, scale=1, upsample=None, condition_dim=16): 249 | super().__init__() 250 | self.conv1 = resize_conv3x3(inplanes + condition_dim, 251 | inplanes) # 2 is the feature dimension for the conditioning 252 | self.bn1 = nn.BatchNorm2d(inplanes) 253 | self.relu = nn.ReLU(inplace=True) 254 | self.conv2 = resize_conv3x3(inplanes, planes, scale) 255 | self.bn2 = nn.BatchNorm2d(planes) 256 | self.upsample = upsample 257 | 258 | self.interpolation_mode = "bilinear" 259 | 260 | def forward(self, data): 261 | x = data["features"] 262 | condition = data["condition"] 263 | 264 | condition_scaled = torch.nn.functional.interpolate(condition, x.shape[-2:], mode=self.interpolation_mode) 265 | identity = x 266 | 267 | out = torch.cat([x, condition_scaled], 1) # Along the channel dimension 268 | out = self.conv1(out) 269 | out = self.bn1(out) 270 | out = self.relu(out) 271 | 272 | out = self.conv2(out) 273 | out = self.bn2(out) 274 | 275 | if self.upsample is not None: 276 | identity = self.upsample(x) 277 | 278 | out += identity 279 | out = self.relu(out) 280 | 281 | out_dict = {"features": out, "condition": condition} 282 | return out_dict 283 | -------------------------------------------------------------------------------- /varitex/modules/discriminator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.nn.utils.spectral_norm as spectral_norm 10 | from torch.nn import init 11 | 12 | 13 | # Returns a function that creates a normalization function 14 | # that does not condition on semantic map 15 | def get_norm_layer(norm_type='instance'): 16 | # helper function to get # output channels of the previous layer 17 | def get_out_channel(layer): 18 | if hasattr(layer, 'out_channels'): 19 | return getattr(layer, 'out_channels') 20 | return layer.weight.size(0) 21 | 22 | # this function will be returned 23 | def add_norm_layer(layer): 24 | nonlocal norm_type 25 | if norm_type.startswith('spectral'): 26 | layer = spectral_norm(layer) 27 | subnorm_type = norm_type[len('spectral'):] 28 | 29 | if subnorm_type == 'none' or len(subnorm_type) == 0: 30 | return layer 31 | 32 | # remove bias in the previous layer, which is meaningless 33 | # since it has no effect after normalization 34 | if getattr(layer, 'bias', None) is not None: 35 | delattr(layer, 'bias') 36 | layer.register_parameter('bias', None) 37 | 38 | if subnorm_type == 'batch': 39 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) 40 | # elif subnorm_type == 'sync_batch': 41 | # norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True) 42 | elif subnorm_type == 'instance': 43 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) 44 | else: 45 | raise ValueError('normalization layer %s is not recognized' % subnorm_type) 46 | 47 | return nn.Sequential(layer, norm_layer) 48 | 49 | return add_norm_layer 50 | 51 | 52 | class BaseNetwork(nn.Module): 53 | def __init__(self): 54 | super(BaseNetwork, self).__init__() 55 | 56 | @staticmethod 57 | def modify_commandline_options(parser, is_train): 58 | return parser 59 | 60 | def print_network(self): 61 | if isinstance(self, list): 62 | self = self[0] 63 | num_params = 0 
64 | for param in self.parameters(): 65 | num_params += param.numel() 66 | print('Network [%s] was created. Total number of parameters: %.1f million. ' 67 | 'To see the architecture, do print(network).' 68 | % (type(self).__name__, num_params / 1000000)) 69 | 70 | def init_weights(self, init_type='normal', gain=0.02): 71 | def init_func(m): 72 | classname = m.__class__.__name__ 73 | if classname.find('BatchNorm2d') != -1: 74 | if hasattr(m, 'weight') and m.weight is not None: 75 | init.normal_(m.weight.data, 1.0, gain) 76 | if hasattr(m, 'bias') and m.bias is not None: 77 | init.constant_(m.bias.data, 0.0) 78 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 79 | if init_type == 'normal': 80 | init.normal_(m.weight.data, 0.0, gain) 81 | elif init_type == 'xavier': 82 | init.xavier_normal_(m.weight.data, gain=gain) 83 | elif init_type == 'xavier_uniform': 84 | init.xavier_uniform_(m.weight.data, gain=1.0) 85 | elif init_type == 'kaiming': 86 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 87 | elif init_type == 'orthogonal': 88 | init.orthogonal_(m.weight.data, gain=gain) 89 | elif init_type == 'none': # uses pytorch's default init method 90 | m.reset_parameters() 91 | else: 92 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 93 | if hasattr(m, 'bias') and m.bias is not None: 94 | init.constant_(m.bias.data, 0.0) 95 | 96 | self.apply(init_func) 97 | 98 | # propagate to children 99 | for m in self.children(): 100 | if hasattr(m, 'init_weights'): 101 | m.init_weights(init_type, gain) 102 | 103 | 104 | class MultiscaleDiscriminator(BaseNetwork): 105 | 106 | def __init__(self, opt): 107 | super().__init__() 108 | self.opt = opt 109 | 110 | for i in range(opt.num_discriminator): 111 | subnetD = NLayerDiscriminator(opt) 112 | self.add_module('discriminator_%d' % i, subnetD) 113 | 114 | def downsample(self, input): 115 | return F.avg_pool2d(input, kernel_size=3, 116 | stride=2, padding=[1, 1], 117 | count_include_pad=False) 118 | 119 | # Returns list of lists of discriminator outputs. 120 | # The final result is of size opt.num_D x opt.n_layers_D 121 | def forward(self, input): 122 | result = [] 123 | for name, D in self.named_children(): 124 | out = D(input) 125 | result.append(out) 126 | input = self.downsample(input) 127 | return result 128 | 129 | 130 | # Defines the PatchGAN discriminator with the specified arguments. 
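MultiscaleDiscriminator above runs the NLayerDiscriminator defined just below at several input scales and returns the intermediate outputs of every layer group. A sketch with example option values (the attribute names match what the code reads; the real defaults live in varitex/options/, and "spectralinstance" is an assumed norm setting that the helper above accepts):

import torch
from mutil.object_dict import ObjectDict
from varitex.modules.discriminator import MultiscaleDiscriminator

opt = ObjectDict(dict(num_discriminator=2, nc_discriminator=64,
                      norm_discriminator="spectralinstance",   # 'spectral' prefix + 'instance' subnorm
                      n_layers_discriminator=4))
D = MultiscaleDiscriminator(opt)
outputs = D(torch.randn(1, 3, 256, 256))
print(len(outputs), len(outputs[0]))   # one list per scale, one tensor per layer group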
131 | class NLayerDiscriminator(BaseNetwork): 132 | 133 | def __init__(self, opt): 134 | super().__init__() 135 | self.opt = opt 136 | 137 | kw = 4 138 | padw = int(np.ceil((kw - 1.0) / 2)) 139 | nf = opt.nc_discriminator 140 | input_nc = 3 # RGB 141 | 142 | norm_layer = get_norm_layer(opt.norm_discriminator) 143 | sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), 144 | nn.LeakyReLU(0.2, False)]] 145 | 146 | for n in range(1, opt.n_layers_discriminator): 147 | nf_prev = nf 148 | nf = min(nf * 2, 512) 149 | stride = 1 if n == opt.n_layers_discriminator - 1 else 2 150 | sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, 151 | stride=stride, padding=padw)), 152 | nn.LeakyReLU(0.2, False) 153 | ]] 154 | 155 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 156 | 157 | # We divide the layers into groups to extract intermediate layer outputs 158 | for n in range(len(sequence)): 159 | self.add_module('model' + str(n), nn.Sequential(*sequence[n])) 160 | 161 | def forward(self, input): 162 | results = [input] 163 | for submodel in self.children(): 164 | intermediate_output = submodel(results[-1]) 165 | results.append(intermediate_output) 166 | return results[1:] 167 | -------------------------------------------------------------------------------- /varitex/modules/encoder.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.autoencoders.components import ( 2 | resnet18_encoder 3 | ) 4 | 5 | from varitex.data.keys_enum import DataItemKey as DIK 6 | from varitex.modules.custom_module import CustomModule 7 | 8 | 9 | class Encoder(CustomModule): 10 | def __init__(self, opt): 11 | super().__init__(opt) 12 | self.encoder = resnet18_encoder(False, False) 13 | 14 | def forward(self, batch, batch_idx): 15 | image = batch[DIK.IMAGE_IN_ENCODE] 16 | encoded = self.encoder(image) 17 | batch[DIK.IMAGE_ENCODED] = encoded 18 | return batch 19 | -------------------------------------------------------------------------------- /varitex/modules/feature2image.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Sigmoid 2 | 3 | from varitex.data.keys_enum import DataItemKey as DIK 4 | from varitex.modules.custom_module import CustomModule 5 | from varitex.modules.unet import UNet 6 | 7 | 8 | class Feature2ImageRenderer(CustomModule): 9 | def __init__(self, opt): 10 | super().__init__(opt) 11 | 12 | n_input_channels = opt.texture_nc * 2 # We have two feature images: face and additive 13 | self.unet = UNet(output_nc=4, input_channels=n_input_channels, features_start=opt.nc_feature2image, 14 | num_layers=opt.feature2image_num_layers) 15 | self.probability = Sigmoid() 16 | 17 | def forward(self, batch, batch_idx): 18 | texture_enhanced = batch[DIK.FULL_FEATUREIMAGE] 19 | tensor_out = self.unet(texture_enhanced) 20 | mask_out = tensor_out[:, :1] # First channel should be the foreground mask. 
Note that we keep the dimension 21 | image_out = tensor_out[:, 1:] # RGB image 22 | 23 | # Should be close to 0 for the background 24 | mask_proba = self.probability(mask_out) 25 | batch[DIK.SEGMENTATION_PREDICTED] = mask_proba 26 | 27 | batch[DIK.IMAGE_OUT] = image_out 28 | return batch 29 | -------------------------------------------------------------------------------- /varitex/modules/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from varitex.data.keys_enum import DataItemKey as DIK 4 | from varitex.modules.decoder import Decoder, AdditiveDecoder 5 | from varitex.modules.encoder import Encoder 6 | from varitex.modules.feature2image import Feature2ImageRenderer 7 | from varitex.modules.custom_module import CustomModule 8 | 9 | 10 | class Generator(CustomModule): 11 | def __init__(self, opt): 12 | super().__init__(opt) 13 | 14 | self.MASK_VALUE = opt.uv_mask_value if hasattr(opt, "uv_mask_value") else 0 15 | 16 | enc_out_dim = 512 17 | latent_dim = opt.latent_dim 18 | 19 | self.fc_mu = torch.nn.Linear(enc_out_dim, latent_dim) 20 | self.fc_var = torch.nn.Linear(enc_out_dim, latent_dim) 21 | self.encoder = Encoder(opt) 22 | self.decoder = Decoder(opt) 23 | 24 | self.decoder_exterior = AdditiveDecoder(opt) 25 | self.texture2image = Feature2ImageRenderer(opt) 26 | 27 | def forward(self, batch, batch_idx, std_multiplier=1): 28 | batch = self.forward_encode(batch, batch_idx) # Only encoding, not yet a distribution 29 | batch = self.forward_encoded2latent_distribution(batch) # Compute mu and std 30 | batch = self.forward_sample_style(batch, batch_idx, std_multiplier=std_multiplier) # Sample a latent code 31 | batch = self.forward_latent2image(batch, batch_idx) # Decoders for face and exterior, followed by rendering 32 | return batch 33 | 34 | def forward_encode(self, batch, batch_idx): 35 | # Only encode the image, adds DIK.IMAGE_ENCODED. 36 | # This is not yet the latent distribution. 37 | batch = self.encoder(batch, batch_idx) 38 | return batch 39 | 40 | def forward_encoded2latent_distribution(self, batch): 41 | # Computes the latent distribution from the encoded image. 42 | mu, log_var = self.fc_mu(batch[DIK.IMAGE_ENCODED]), self.fc_var(batch[DIK.IMAGE_ENCODED]) 43 | std = torch.exp(log_var / 2) 44 | 45 | batch[DIK.STYLE_LATENT_MU] = mu 46 | batch[DIK.STYLE_LATENT_STD] = std 47 | return batch 48 | 49 | def forward_sample_style(self, batch, batch_idx, std_multiplier=1): 50 | # Sample the latent code z from the given distribution. 51 | mu, std = batch[DIK.STYLE_LATENT_MU], batch[DIK.STYLE_LATENT_STD] 52 | 53 | q = torch.distributions.Normal(mu, std * std_multiplier) 54 | z = q.rsample() 55 | batch[DIK.STYLE_LATENT] = z 56 | return batch 57 | 58 | def forward_latent2image(self, batch, batch_idx): 59 | # Given a latent code, render an image. 60 | batch = self.forward_latent2featureimage(batch, batch_idx) 61 | # Neural rendering 62 | batch = self.texture2image(batch, batch_idx) 63 | # Note that we do not mask the output. The network should learn to produce 0 values for the background. 64 | return batch 65 | 66 | def forward_latent2featureimage(self, batch, batch_idx): 67 | """ 68 | Given the full latent code DIK.STYLE_LATENT, process both interior and exterior region and generate 69 | both feature images: the face feature image and the additive feature image. 
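forward_encoded2latent_distribution and forward_sample_style in the Generator above implement the usual VAE reparameterization: the encoder's log-variance is turned into a standard deviation and rsample keeps the sampling step differentiable. The core of it, isolated on toy tensors:

import torch

mu = torch.zeros(2, 8, requires_grad=True)       # predicted means
log_var = torch.zeros(2, 8, requires_grad=True)  # predicted log-variances
std = torch.exp(log_var / 2)                     # log-variance -> standard deviation

q = torch.distributions.Normal(mu, std)
z = q.rsample()            # z = mu + std * eps with eps ~ N(0, 1), so gradients reach mu and log_var
z.sum().backward()
print(mu.grad is not None, log_var.grad is not None)   # True True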
70 | """ 71 | n = self.opt.latent_dim // 2 72 | z_interior = batch[DIK.STYLE_LATENT][:, :n] 73 | z_exterior = batch[DIK.STYLE_LATENT][:, n:] 74 | 75 | batch[DIK.LATENT_EXTERIOR] = z_exterior 76 | batch[DIK.LATENT_INTERIOR] = z_interior 77 | 78 | batch = self.forward_latent2texture_interior(batch, batch_idx) 79 | batch = self.forward_latent2additive_featureimage(batch, batch_idx) 80 | batch = self.forward_merge_textures(batch, batch_idx) 81 | return batch 82 | 83 | def forward_latent2texture_interior(self, batch, batch_idx): 84 | batch = self.forward_decoder_interior(batch, batch_idx) 85 | # Sampling texture using the UV map. 86 | batch = self.sample_texture(batch) 87 | return batch 88 | 89 | def forward_latent2additive_featureimage(self, batch, batch_idx): 90 | batch = self.decoder_exterior(batch, batch_idx) 91 | return batch 92 | 93 | def forward_merge_textures(self, batch, batch_idx): 94 | batch[DIK.FULL_FEATUREIMAGE] = torch.cat( 95 | [batch[DIK.FACE_FEATUREIMAGE], batch[DIK.ADDITIVE_FEATUREIMAGE]], 1) 96 | return batch 97 | 98 | def forward_decoder_interior(self, batch, batch_idx): 99 | batch = self.decoder(batch, batch_idx) 100 | return batch 101 | 102 | def sample_texture(self, batch): 103 | uv_texture = batch[DIK.TEXTURE_PERSON] 104 | uv_map = batch[DIK.UV_RENDERED] 105 | 106 | texture_sampled = torch.nn.functional.grid_sample(uv_texture, uv_map, mode='bilinear', 107 | padding_mode='border', align_corners=False) 108 | 109 | # Grid sample yields the same value for not rendered region 110 | # We mask that region 111 | batch[DIK.MASK_UV] = torch.logical_or(batch[DIK.UV_RENDERED][:, :, :, 1] != -1, 112 | batch[DIK.UV_RENDERED][:, :, :, 0] != -1).unsqueeze(1) 113 | mask = batch[DIK.MASK_UV].expand_as(texture_sampled) 114 | texture_sampled[~mask] = self.MASK_VALUE 115 | 116 | batch[DIK.FACE_FEATUREIMAGE] = texture_sampled 117 | return batch 118 | -------------------------------------------------------------------------------- /varitex/modules/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models 3 | from mutil.pytorch_utils import ImageNetNormalizeTransformInverse 4 | from torch.nn import functional as F 5 | 6 | 7 | def kl_divergence(mu, std): 8 | p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std)) 9 | q = torch.distributions.Normal(mu, std) 10 | kl = torch.distributions.kl.kl_divergence(p, q).mean(1) 11 | return kl 12 | 13 | 14 | def reconstruction_loss(image_fake, image_real): 15 | # Simple L1 for now 16 | return torch.mean(torch.abs(image_fake - image_real)) 17 | 18 | 19 | def l2_loss(image_fake, image_real): 20 | return torch.mean((image_fake - image_real) ** 2) 21 | 22 | 23 | class VGG16(torch.nn.Module): 24 | def __init__(self, path_weights): 25 | super().__init__() 26 | vgg16 = torchvision.models.vgg16(pretrained=False) 27 | import h5py 28 | with h5py.File(path_weights, 'r') as f: 29 | state_dict = self.get_keras_mapping(f) 30 | vgg16.load_state_dict(state_dict, strict=False) # Ignore the missing keys for the classifier 31 | vgg_pretrained_features = vgg16.features 32 | 33 | # 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 34 | self.slice1 = torch.nn.Sequential() 35 | self.slice2 = torch.nn.Sequential() 36 | self.slice3 = torch.nn.Sequential() 37 | self.slice4 = torch.nn.Sequential() 38 | 39 | self.slice1.add_module(str(0), vgg_pretrained_features[0]) 40 | for i in range(1, 3): 41 | self.slice2.add_module(str(i), 
vgg_pretrained_features[i]) 42 | for i in range(3, 22): 43 | self.slice3.add_module(str(i), vgg_pretrained_features[i]) 44 | for i in range(22, 31): 45 | self.slice4.add_module(str(i), vgg_pretrained_features[i]) 46 | 47 | def get_keras_mapping(self, f): 48 | # VGG 16 config D 49 | state_dict = { 50 | "features.0.weight": f["conv1_1"]["conv1_1_1"]["kernel:0"], 51 | "features.0.bias": f["conv1_1"]["conv1_1_1"]["bias:0"], 52 | 53 | "features.2.weight": f["conv1_2"]["conv1_2_1"]["kernel:0"], 54 | "features.2.bias": f["conv1_2"]["conv1_2_1"]["bias:0"], 55 | 56 | "features.5.weight": f["conv2_1"]["conv2_1_1"]["kernel:0"], 57 | "features.5.bias": f["conv2_1"]["conv2_1_1"]["bias:0"], 58 | 59 | "features.7.weight": f["conv2_2"]["conv2_2_1"]["kernel:0"], 60 | "features.7.bias": f["conv2_2"]["conv2_2_1"]["bias:0"], 61 | 62 | "features.10.weight": f["conv3_1"]["conv3_1_1"]["kernel:0"], 63 | "features.10.bias": f["conv3_1"]["conv3_1_1"]["bias:0"], 64 | 65 | "features.12.weight": f["conv3_2"]["conv3_2_1"]["kernel:0"], 66 | "features.12.bias": f["conv3_2"]["conv3_2_1"]["bias:0"], 67 | 68 | "features.14.weight": f["conv3_3"]["conv3_3_1"]["kernel:0"], 69 | "features.14.bias": f["conv3_3"]["conv3_3_1"]["bias:0"], 70 | 71 | "features.17.weight": f["conv4_1"]["conv4_1_1"]["kernel:0"], 72 | "features.17.bias": f["conv4_1"]["conv4_1_1"]["bias:0"], 73 | 74 | "features.19.weight": f["conv4_2"]["conv4_2_1"]["kernel:0"], 75 | "features.19.bias": f["conv4_2"]["conv4_2_1"]["bias:0"], 76 | 77 | "features.21.weight": f["conv4_3"]["conv4_3_1"]["kernel:0"], 78 | "features.21.bias": f["conv4_3"]["conv4_3_1"]["bias:0"], 79 | 80 | "features.24.weight": f["conv5_1"]["conv5_1_1"]["kernel:0"], 81 | "features.24.bias": f["conv5_1"]["conv5_1_1"]["bias:0"], 82 | 83 | "features.26.weight": f["conv5_2"]["conv5_2_1"]["kernel:0"], 84 | "features.26.bias": f["conv5_2"]["conv5_2_1"]["bias:0"], 85 | 86 | "features.28.weight": f["conv5_3"]["conv5_3_1"]["kernel:0"], 87 | "features.28.bias": f["conv5_3"]["conv5_3_1"]["bias:0"], 88 | } 89 | # keras: [3, 3, 3, 64]) 90 | # pytorch: torch.Size([64, 3, 3, 3]). 
91 | # Keras stores weights in the order (kernel_size, kernel_size, input_dim, output_dim), 92 | # but pytorch expects (output_dim, input_dim, kernel_size, kernel_size) 93 | # We need to transpose them https://discuss.pytorch.org/t/how-to-convert-keras-model-to-pytorch-and-run-inference-in-c-correctly/93451/3 94 | state_dict = {k: torch.Tensor(v[:].transpose()) for k, v in state_dict.items()} 95 | return state_dict 96 | 97 | def forward(self, x): 98 | x = F.interpolate(x, size=224, mode="bilinear") 99 | h_conv1_1 = self.slice1(x) 100 | h_conv1_2 = self.slice2(h_conv1_1) 101 | h_conv3_2 = self.slice3(h_conv1_2) 102 | h_conv4_2 = self.slice4(h_conv3_2) 103 | out = [h_conv1_1, h_conv1_2, h_conv3_2, h_conv4_2] 104 | return out 105 | 106 | 107 | class VGG19(torch.nn.Module): 108 | def __init__(self): 109 | super().__init__() 110 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features 111 | self.slice1 = torch.nn.Sequential() 112 | self.slice2 = torch.nn.Sequential() 113 | self.slice3 = torch.nn.Sequential() 114 | self.slice4 = torch.nn.Sequential() 115 | self.slice5 = torch.nn.Sequential() 116 | for x in range(2): 117 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 118 | for x in range(2, 7): 119 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 120 | for x in range(7, 12): 121 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 122 | for x in range(12, 21): 123 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 124 | for x in range(21, 30): 125 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 126 | 127 | def forward(self, X): 128 | h_relu1 = self.slice1(X) 129 | h_relu2 = self.slice2(h_relu1) 130 | h_relu3 = self.slice3(h_relu2) 131 | h_relu4 = self.slice4(h_relu3) 132 | h_relu5 = self.slice5(h_relu4) 133 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 134 | return out 135 | 136 | 137 | # Perceptual loss that uses a pretrained VGG network 138 | class ImageNetVGG19Loss(torch.nn.Module): 139 | def __init__(self): 140 | super(ImageNetVGG19Loss, self).__init__() 141 | self.vgg = VGG19() 142 | # if gpu_ids: 143 | # self.vgg.cuda() 144 | self.criterion = torch.nn.L1Loss() 145 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 146 | 147 | for param in self.parameters(): 148 | param.requires_grad = False 149 | 150 | def forward(self, fake, real): 151 | x_vgg, y_vgg = self.vgg(fake), self.vgg(real) 152 | loss = 0 153 | for i in range(len(x_vgg)): 154 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 155 | # loss = self.criterion(x_vgg[i], y_vgg[i].detach()) 156 | # print(list(x_vgg[i].shape), loss.data, (self.weights[i] * loss).data, self.weights[i]) 157 | return loss 158 | 159 | 160 | # Perceptual loss that uses a pretrained VGG network 161 | class FaceRecognitionVGG16Loss(torch.nn.Module): 162 | def __init__(self, path_weights): 163 | super().__init__() 164 | self.vgg = VGG16(path_weights) 165 | self.criterion = torch.nn.MSELoss() 166 | self.weights = [0.25, 0.25, 0.25, 0.25] 167 | 168 | self.unnormalize = ImageNetNormalizeTransformInverse() 169 | # Mean values of face images from the VGGFace paper 170 | self.normalize_values = torch.Tensor((93.5940, 104.7624, 129.1863)).unsqueeze(0).unsqueeze(2).unsqueeze(2) 171 | 172 | for param in self.parameters(): 173 | param.requires_grad = False 174 | 175 | def preprocess(self, tensor): 176 | tensor = self.unnormalize(tensor) # Is now in the range [0, 1] (if not under-/ overshooting) 177 | self.normalize_values = 
self.normalize_values.to(tensor.device) 178 | tensor = ((tensor * 255) - self.normalize_values.expand_as(tensor)) 179 | # tensor = tensor - self.normalize_values 180 | # tensor = 181 | return tensor / 127.5 # Should we do this? 182 | 183 | def forward(self, fake, real): 184 | fake = self.preprocess(fake) 185 | real = self.preprocess(real) 186 | x_vgg, y_vgg = self.vgg(fake), self.vgg(real) 187 | 188 | loss = 0 189 | for i in range(len(x_vgg)): 190 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 191 | # loss = self.criterion(x_vgg[i], y_vgg[i].detach()) 192 | # print(list(x_vgg[i].shape), loss.data, (self.weights[i] * loss).data, self.weights[i]) 193 | return loss 194 | 195 | 196 | """ 197 | Code below: 198 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 199 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 200 | """ 201 | 202 | 203 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 204 | # When LSGAN is used, it is basically same as MSELoss, 205 | # but it abstracts away the need to create the target label tensor 206 | # that has the same size as the input 207 | class GANLoss(torch.nn.Module): 208 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, 209 | tensor=torch.FloatTensor, opt=None): 210 | super(GANLoss, self).__init__() 211 | self.real_label = target_real_label 212 | self.fake_label = target_fake_label 213 | self.real_label_tensor = None 214 | self.fake_label_tensor = None 215 | self.zero_tensor = None 216 | self.Tensor = tensor 217 | self.gan_mode = gan_mode 218 | self.opt = opt 219 | if gan_mode == 'ls': 220 | pass 221 | elif gan_mode == 'original': 222 | pass 223 | elif gan_mode == 'w': 224 | pass 225 | elif gan_mode == 'hinge': 226 | pass 227 | else: 228 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode)) 229 | 230 | def get_target_tensor(self, input, target_is_real): 231 | if target_is_real: 232 | if self.real_label_tensor is None: 233 | self.real_label_tensor = self.Tensor(1).to(input.device).fill_(self.real_label) 234 | self.real_label_tensor.requires_grad_(False) 235 | return self.real_label_tensor.expand_as(input) 236 | else: 237 | if self.fake_label_tensor is None: 238 | self.fake_label_tensor = self.Tensor(1).to(input.device).fill_(self.fake_label) 239 | self.fake_label_tensor.requires_grad_(False) 240 | return self.fake_label_tensor.expand_as(input) 241 | 242 | def get_zero_tensor(self, input): 243 | if self.zero_tensor is None: 244 | self.zero_tensor = self.Tensor(1).fill_(0) 245 | self.zero_tensor.requires_grad_(False) 246 | return self.zero_tensor.expand_as(input).to(input.device) 247 | 248 | def loss(self, input, target_is_real, for_discriminator=True): 249 | if self.gan_mode == 'original': # cross entropy loss 250 | target_tensor = self.get_target_tensor(input, target_is_real) 251 | loss = F.binary_cross_entropy_with_logits(input, target_tensor) 252 | return loss 253 | elif self.gan_mode == 'ls': 254 | target_tensor = self.get_target_tensor(input, target_is_real) 255 | return F.mse_loss(input, target_tensor) 256 | elif self.gan_mode == 'hinge': 257 | if for_discriminator: 258 | if target_is_real: 259 | minval = torch.min(input - 1, self.get_zero_tensor(input)) 260 | loss = -torch.mean(minval) 261 | else: 262 | minval = torch.min(-input - 1, self.get_zero_tensor(input)) 263 | loss = -torch.mean(minval) 264 | else: 265 | assert target_is_real, "The generator's hinge loss must be aiming for real" 266 | loss = 
-torch.mean(input) 267 | return loss 268 | else: 269 | # wgan 270 | if target_is_real: 271 | return -input.mean() 272 | else: 273 | return input.mean() 274 | 275 | def __call__(self, input, target_is_real, for_discriminator=True): 276 | # computing loss is a bit complicated because |input| may not be 277 | # a tensor, but list of tensors in case of multiscale discriminator 278 | if isinstance(input, list): 279 | loss = 0 280 | for pred_i in input: 281 | if isinstance(pred_i, list): 282 | pred_i = pred_i[-1] 283 | loss_tensor = self.loss(pred_i, target_is_real, for_discriminator) 284 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) 285 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) 286 | loss += new_loss 287 | return loss / len(input) 288 | else: 289 | return self.loss(input, target_is_real, for_discriminator) 290 | -------------------------------------------------------------------------------- /varitex/modules/metrics.py: -------------------------------------------------------------------------------- 1 | import lpips 2 | import torch 3 | from mutil.pytorch_utils import ImageNetNormalizeTransformInverse 4 | from torchmetrics.functional.image import peak_signal_noise_ratio as psnr 5 | from torchmetrics.functional.image import structural_similarity_index_measure as ssim 6 | 7 | 8 | class AbstractMetric(torch.nn.Module): 9 | scale = 255 10 | 11 | def __init__(self): 12 | super().__init__() 13 | self.unnormalize = ImageNetNormalizeTransformInverse(scale=self.scale) 14 | 15 | 16 | class PSNR(AbstractMetric): 17 | scale = 255 18 | 19 | def forward(self, preds, target): 20 | preds = self.unnormalize(preds) 21 | target = self.unnormalize(target) 22 | return psnr(preds, target, data_range=self.scale) 23 | 24 | 25 | class SSIM(AbstractMetric): 26 | scale = 255 27 | 28 | def forward(self, preds, target): 29 | preds = self.unnormalize(preds) 30 | target = self.unnormalize(target) 31 | return ssim(preds, target, data_range=self.scale) 32 | 33 | 34 | class LPIPS(AbstractMetric): 35 | scale = 1 36 | 37 | def __init__(self): 38 | super().__init__() 39 | self.metric = lpips.LPIPS(net='alex', verbose=False) 40 | 41 | def forward(self, preds, target): 42 | preds = self.unnormalize(preds) 43 | target = self.unnormalize(target) 44 | # Might need a .mean() 45 | return self.metric(preds, target, normalize=True).mean() # With normalize, lpips expects the range [0, 1] 46 | -------------------------------------------------------------------------------- /varitex/modules/pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import BCELoss 3 | 4 | from varitex.data.keys_enum import DataItemKey as DIK 5 | from varitex.modules.discriminator import MultiscaleDiscriminator 6 | from varitex.modules.generator import Generator 7 | from varitex.modules.custom_module import CustomModule 8 | from varitex.modules.loss import ImageNetVGG19Loss, kl_divergence, l2_loss, GANLoss 9 | from varitex.modules.metrics import PSNR, SSIM, LPIPS 10 | 11 | 12 | class PipelineModule(CustomModule): 13 | 14 | def __init__(self, opt): 15 | super().__init__(opt) 16 | 17 | # Includes the full generation pipeline: encoder, texture decoder, additive decoder, neural renderer 18 | self.generator = Generator(opt) 19 | 20 | if getattr(self.opt, "lambda_gan", 0) > 0: 21 | self.discriminator = MultiscaleDiscriminator(opt) 22 | self.criterion_gan = GANLoss(opt.gan_mode) 23 | self.criterion_discriminator_features = torch.nn.L1Loss() 24 | if 
getattr(self.opt, "lambda_vgg", 0) > 0: 25 | self.loss_vgg = ImageNetVGG19Loss() 26 | if getattr(self.opt, "lambda_segmentation", 0) > 0: 27 | self.criterion_segmentation = BCELoss() 28 | 29 | self.metric_psnr = PSNR() 30 | self.metric_ssim = SSIM() 31 | self.metric_lpips = LPIPS() 32 | 33 | def to_device(self, o, device='cuda'): 34 | if isinstance(o, list): 35 | o = [self.to_device(o_i, device) for o_i in o] 36 | elif isinstance(o, dict): 37 | o = {k: self.to_device(v, device) for k, v in o.items()} 38 | elif isinstance(o, torch.Tensor): 39 | o = o.to(device) 40 | return o 41 | 42 | def forward(self, batch, batch_idx, std_multiplier=1): 43 | batch = self.to_device(batch, self.opt.device) 44 | batch = self.generator(batch, batch_idx, std_multiplier) 45 | return batch 46 | 47 | def training_step(self, batch, batch_idx, optimizer_idx=0): 48 | if optimizer_idx == 0: 49 | loss = self._generator_step(batch, batch_idx) 50 | elif optimizer_idx == 1: 51 | loss = self._discriminator_step(batch, batch_idx) 52 | else: 53 | raise Warning("Invalid optimizer index: {}".format(optimizer_idx)) 54 | return loss 55 | 56 | def validation_step(self, batch, batch_idx, std_multiplier=1): 57 | batch = self.forward(batch, batch_idx, std_multiplier=std_multiplier) 58 | fake = batch[DIK.IMAGE_OUT] 59 | real = batch[DIK.IMAGE_IN] 60 | 61 | psnr = self.metric_psnr(fake, real) 62 | ssim = self.metric_ssim(fake, real) 63 | lpips = self.metric_lpips(fake, real) 64 | 65 | self.log_dict({ 66 | "val/psnr": psnr, 67 | "val/ssim": ssim, 68 | "val/lpips": lpips 69 | }) 70 | 71 | # Below methods simply forward the calls to the generator 72 | def forward_encode(self, batch, batch_idx): 73 | return self.generator.forward_encode(batch, batch_idx) 74 | 75 | def forward_sample_style(self, *args, **kwargs): 76 | return self.generator.forward_sample_style(*args, **kwargs) 77 | 78 | def forward_latent2texture(self, batch, batch_idx): 79 | return self.generator.forward_latent2featureimage(batch, batch_idx) 80 | 81 | def forward_texture2image(self, *args, **kwargs): 82 | return self.generator.texture2image(*args, **kwargs) 83 | 84 | def forward_latent2image(self, *args, **kwargs): 85 | return self.generator.forward_latent2image(*args, **kwargs) 86 | 87 | def forward_interior2image(self, batch, batch_idx): 88 | batch = self.generator.sample_texture(batch) 89 | batch = self.generator.forward_latent2additive_featureimage(batch, batch_idx) 90 | batch = self.generator.forward_merge_textures(batch, batch_idx) 91 | batch = self.forward_texture2image(batch, batch_idx) 92 | return batch 93 | 94 | def configure_optimizers(self): 95 | # We use one optimizer for the generator and one for the discriminator. 
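# Note: with several optimizers and the pre-2.0 Lightning API used here, training_step is called
# once per optimizer and receives the matching optimizer_idx, so the list order below must mirror
# the branches in training_step (0 = generator, 1 = discriminator).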
96 | optimizers = list() 97 | # Important: Should have index 0 98 | optimizers.append(torch.optim.Adam(self.generator.parameters(), lr=self.opt.lr)) 99 | if getattr(self.opt, "lambda_gan", 0) > 0: 100 | # Needs index 1 101 | optimizers.append(torch.optim.Adam(self.discriminator.parameters(), 102 | lr=self.opt.lr_discriminator)) 103 | return optimizers, [] 104 | 105 | def _generator_step(self, batch, batch_idx): 106 | batch = self.forward(batch, batch_idx) 107 | 108 | loss_gan = 0 109 | loss_gan_features = 0 110 | loss_l2 = 0 111 | loss_vgg = 0 112 | loss_segmentation = 0 113 | loss_rgb_texture = 0 114 | 115 | image_out = batch[DIK.IMAGE_OUT] 116 | image_in = batch[DIK.IMAGE_IN] 117 | 118 | loss_kl = kl_divergence(batch[DIK.STYLE_LATENT_MU], batch[DIK.STYLE_LATENT_STD]).mean() 119 | 120 | if getattr(self.opt, "lambda_gan", 0) > 0: 121 | pred_fake, pred_real = self._forward_discriminate(image_out, image_in) 122 | 123 | loss_gan = self.criterion_gan(pred_fake, True, 124 | for_discriminator=False) 125 | 126 | # Feature loss 127 | num_D = len(pred_fake) 128 | GAN_Feat_loss = torch.FloatTensor(1).fill_(0).to(image_out.device) 129 | for i in range(num_D): # for each discriminator 130 | # last output is the final prediction, so we exclude it 131 | num_intermediate_outputs = len(pred_fake[i]) - 1 132 | for j in range(num_intermediate_outputs): # for each layer output 133 | unweighted_loss = self.criterion_discriminator_features( 134 | pred_fake[i][j], pred_real[i][j].detach()) 135 | GAN_Feat_loss += unweighted_loss 136 | loss_gan_features = GAN_Feat_loss / num_D 137 | 138 | if self.opt.lambda_l2 > 0: 139 | loss_l2 = l2_loss(image_out, image_in) 140 | 141 | if self.opt.lambda_vgg > 0: 142 | loss_vgg = self.loss_vgg(image_out.clone(), image_in.clone()) 143 | 144 | if self.opt.lambda_segmentation > 0: 145 | loss_segmentation = self.criterion_segmentation(batch[DIK.SEGMENTATION_PREDICTED], 146 | batch[DIK.SEGMENTATION_MASK]) 147 | 148 | if self.opt.lambda_rgb_texture > 0: 149 | texture = batch[DIK.FACE_FEATUREIMAGE].clone()[:, :3] # First three dimensions should be RGB only 150 | masked_image_in = image_in.clone() 151 | masked_image_in *= batch[DIK.MASK_UV].expand_as(image_in) 152 | texture *= batch[DIK.MASK_UV].expand_as(texture) 153 | loss_rgb_texture = l2_loss(texture, masked_image_in) 154 | 155 | loss_unweighted = loss_gan + loss_gan_features + loss_l2 + loss_kl + loss_vgg + loss_segmentation + loss_rgb_texture 156 | 157 | loss = self.opt.lambda_gan * loss_gan + \ 158 | self.opt.lambda_discriminator_features * loss_gan_features + \ 159 | self.opt.lambda_l2 * loss_l2 + \ 160 | self.opt.lambda_kl * loss_kl + \ 161 | self.opt.lambda_vgg * loss_vgg + \ 162 | self.opt.lambda_segmentation * loss_segmentation + \ 163 | self.opt.lambda_rgb_texture * loss_rgb_texture 164 | 165 | data_log = { 166 | "train/generator": loss_gan, 167 | "train/gan_features": loss_gan_features, 168 | "train/reconstruction_l2": loss_l2, 169 | "train/kl": loss_kl, 170 | "train/vgg_l1": loss_vgg, 171 | "train/segmentation": loss_segmentation, 172 | "train/rgb_texture": loss_rgb_texture, 173 | "train/loss": loss_unweighted 174 | } 175 | # Filter out zero losses 176 | data_log = {k: v.clone().detach() for k, v in data_log.items() if v != 0} 177 | self.log_dict(data_log) 178 | return loss 179 | 180 | def _discriminator_step(self, batch, batch_idx): 181 | image_real = batch[DIK.IMAGE_IN] 182 | with torch.no_grad(): 183 | batch = self.forward(batch, batch_idx) 184 | image_fake = batch[DIK.IMAGE_OUT].detach() 185 | 
image_fake.requires_grad = True 186 | 187 | fake_pred, real_pred = self._forward_discriminate(image_fake, image_real) 188 | 189 | real_loss = self.criterion_gan(real_pred, True, 190 | for_discriminator=True) 191 | fake_loss = self.criterion_gan(fake_pred, False, 192 | for_discriminator=True) 193 | loss_unweighted = (real_loss + fake_loss) / 2 194 | 195 | loss = self.opt.lambda_gan * loss_unweighted 196 | 197 | self.log_dict({ 198 | "train/discriminator": loss_unweighted 199 | }) 200 | return loss 201 | 202 | def _forward_discriminate(self, fake_image, real_image): 203 | # This method is from SPADE: https://github.com/NVlabs/SPADE 204 | # In Batch Normalization, the fake and real images are 205 | # recommended to be in the same batch to avoid disparate 206 | # statistics in fake and real images. 207 | # So both fake and real images are fed to D all at once. 208 | fake_and_real = torch.cat([fake_image, real_image], dim=0) 209 | discriminator_out = self.discriminator(fake_and_real) # len(2); one per discriminator 210 | 211 | # Take the prediction of fake and real images from the combined batch 212 | def divide_pred(pred): 213 | # the prediction contains the intermediate outputs of multiscale GAN, 214 | # so it's usually a list 215 | if type(pred) == list: 216 | fake = [] 217 | real = [] 218 | for p in pred: 219 | fake.append([tensor[:tensor.size(0) // 2] for tensor in p]) 220 | real.append([tensor[tensor.size(0) // 2:] for tensor in p]) 221 | else: 222 | fake = pred[:pred.size(0) // 2] 223 | real = pred[pred.size(0) // 2:] 224 | 225 | return fake, real 226 | 227 | pred_fake, pred_real = divide_pred(discriminator_out) 228 | 229 | return pred_fake, pred_real 230 | -------------------------------------------------------------------------------- /varitex/modules/unet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is from pl_bolts. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class UNet(nn.Module): 10 | """ 11 | Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation 12 | `_ 13 | 14 | Paper authors: Olaf Ronneberger, Philipp Fischer, Thomas Brox 15 | 16 | Implemented by: 17 | 18 | - `Annika Brundyn `_ 19 | - `Akshay Kulkarni `_ 20 | 21 | Args: 22 | output_nc: Number of output classes required 23 | input_channels: Number of channels in input images (default 3) 24 | num_layers: Number of layers in each side of U-net (default 5) 25 | features_start: Number of features in first layer (default 64) 26 | bilinear: Whether to use bilinear interpolation or transposed convolutions (default) for upsampling. 
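Example (an illustrative sketch, not from the repository; the channel counts follow how Feature2ImageRenderer instantiates this UNet with the default options, i.e. input_channels = texture_nc * 2 = 32 and output_nc = 4):
net = UNet(output_nc=4, input_channels=32, features_start=64, num_layers=5)
out = net(torch.rand(1, 32, 256, 256))  # -> torch.Size([1, 4, 256, 256])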
27 | """ 28 | 29 | def __init__( 30 | self, 31 | output_nc: int, 32 | input_channels: int = 3, 33 | num_layers: int = 5, 34 | features_start: int = 64, 35 | bilinear: bool = False 36 | ): 37 | super().__init__() 38 | self.num_layers = num_layers 39 | 40 | layers = [DoubleConv(input_channels, features_start)] 41 | 42 | feats = features_start 43 | for _ in range(num_layers - 1): 44 | layers.append(Down(feats, feats * 2)) 45 | feats *= 2 46 | 47 | for _ in range(num_layers - 1): 48 | layers.append(Up(feats, feats // 2, bilinear)) 49 | feats //= 2 50 | 51 | layers.append(nn.Conv2d(feats, output_nc, kernel_size=1)) 52 | 53 | self.layers = nn.ModuleList(layers) 54 | 55 | def forward(self, x): 56 | xi = [self.layers[0](x)] 57 | # Down path 58 | for layer in self.layers[1:self.num_layers]: 59 | xi.append(layer(xi[-1])) 60 | # Up path 61 | for i, layer in enumerate(self.layers[self.num_layers:-1]): 62 | xi[-1] = layer(xi[-1], xi[-2 - i]) 63 | x_out = self.layers[-1](xi[-1]) 64 | return x_out 65 | 66 | 67 | class DoubleConv(nn.Module): 68 | """ 69 | [ Conv2d => BatchNorm (optional) => ReLU ] x 2 70 | """ 71 | 72 | def __init__(self, in_ch: int, out_ch: int): 73 | super().__init__() 74 | self.net = nn.Sequential( 75 | nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), 76 | nn.BatchNorm2d(out_ch), 77 | nn.ReLU(inplace=True), 78 | nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), 79 | nn.BatchNorm2d(out_ch), 80 | nn.ReLU(inplace=True) 81 | ) 82 | 83 | def forward(self, x): 84 | return self.net(x) 85 | 86 | 87 | class Down(nn.Module): 88 | """ 89 | Downscale with MaxPool => DoubleConvolution block 90 | """ 91 | 92 | def __init__(self, in_ch: int, out_ch: int): 93 | super().__init__() 94 | self.net = nn.Sequential( 95 | nn.MaxPool2d(kernel_size=2, stride=2), 96 | DoubleConv(in_ch, out_ch) 97 | ) 98 | 99 | def forward(self, x): 100 | return self.net(x) 101 | 102 | 103 | class Up(nn.Module): 104 | """ 105 | Upsampling (by either bilinear interpolation or transpose convolutions) 106 | followed by concatenation of feature map from contracting path, followed by DoubleConv. 
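In forward, x1 is the feature map being upsampled and x2 is the skip connection from the contracting path; x1 is zero-padded to the spatial size of x2 before the two are concatenated along the channel axis.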
107 | """ 108 | 109 | def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False): 110 | super().__init__() 111 | self.upsample = None 112 | if bilinear: 113 | self.upsample = nn.Sequential( 114 | nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), 115 | nn.Conv2d(in_ch, in_ch // 2, kernel_size=1), 116 | ) 117 | else: 118 | self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2) 119 | 120 | self.conv = DoubleConv(in_ch, out_ch) 121 | 122 | def forward(self, x1, x2): 123 | x1 = self.upsample(x1) 124 | 125 | # Pad x1 to the size of x2 126 | diff_h = x2.shape[2] - x1.shape[2] 127 | diff_w = x2.shape[3] - x1.shape[3] 128 | 129 | x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]) 130 | 131 | # Concatenate along the channels axis 132 | x = torch.cat([x2, x1], dim=1) 133 | return self.conv(x) 134 | -------------------------------------------------------------------------------- /varitex/options/__init__.py: -------------------------------------------------------------------------------- 1 | def varitex_default_options(): 2 | opt = { 3 | "dataset": "FFHQ", 4 | "image_h": 256, 5 | "image_w": 256, 6 | "latent_dim": 256, 7 | "texture_dim": 256, 8 | "texture_nc": 16, 9 | "nc_feature2image": 64, 10 | "feature2image_num_layers": 5, 11 | "nc_decoder": 32, 12 | "semantic_regions": list(range(1, 16))} 13 | return opt -------------------------------------------------------------------------------- /varitex/options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from types import SimpleNamespace 5 | 6 | 7 | class BaseOptions: 8 | initialized = False 9 | 10 | def initialize(self, parser): 11 | parser.add_argument('--dataroot_npy', default=os.path.join(os.getenv("DP"), 'FFHQ/preprocessed_dataset'), 12 | help='Path to the folder with the preprocessed datasets. Should contain .npy files "R", "t", "s", "sp", "ep", "segmentation", "uv", "filename", and a .npz file "dataset_splits".') 13 | parser.add_argument('--image_folder', default=os.path.join(os.getenv("DP"), 'FFHQ/images'), 14 | help='Path to the folder that contains *.png images.') 15 | parser.add_argument('--path_out', default=os.path.join(os.getenv("OP", ""), 'varitex'), 16 | help='Path to the folder where the outputs should be saved.') 17 | parser.add_argument('--path_bfm', 18 | default=os.path.join(os.getenv("FP"), "basel_facemodel/model2017-1_face12_nomouth.h5"), 19 | help='Basel face model (face only). Please use "model2017-1_face12_nomouth.h5"') 20 | parser.add_argument('--path_uv', default=os.path.join(os.getenv("FP"), "basel_facemodel/face12.json"), 21 | help='UV parameterization. Download from "https://github.com/unibas-gravis/parametric-face-image-generator/blob/master/data/regions/face12.json".') 22 | parser.add_argument('--device', default='cuda', help='') 23 | parser.add_argument('--dataset', default='FFHQ', help='') 24 | parser.add_argument('--keep_background', action='store_true', 25 | help="If True the dataloader won't remove the background and the model will also generate the background.") 26 | parser.add_argument('--bg_color', type=str, default='black', 27 | help="Defines how to fill the masked regions. Only if keep_background=False. Check dataloader for details.") 28 | parser.add_argument('--transform_mode', type=str, default='all', 29 | help='string with letters in {s, t, s, f}. 
d: rotate, t: translate, s: scale, f: flip') 30 | parser.add_argument('--logger', type=str, default='wandb', help='tensorboard | wandb') 31 | 32 | parser.add_argument('--checkpoint', type=str, default=None, help="Path to checkpoint file.") 33 | parser.add_argument('--dataset_split', type=str, default='train', help='all | train | val') 34 | parser.add_argument('--texture_nc', type=int, default=16, help='# features in neural texture') 35 | parser.add_argument('--texture_dim', type=int, default=256, help='Height and width of square neural texture') 36 | parser.add_argument('--image_h', type=int, default=256, help='image height') 37 | parser.add_argument('--image_w', type=int, default=256, help='image width') 38 | parser.add_argument('--batch_size', type=int, default=7) 39 | parser.add_argument('--uv_mask_value', type=float, default=0.0, 40 | help='What values to use for invalid uv regions') 41 | 42 | parser.add_argument('--latent_dim', type=int, default=256, 43 | help='Dimension of the full latent code z before splitting into z_face and z_additive.') 44 | parser.add_argument('--nc_feature2image', type=int, default=64, 45 | help="# feature channels in the Feature2Image renderer.") 46 | parser.add_argument('--feature2image_num_layers', type=int, default=5, 47 | help="Number of leves in the Feature2Image renderer.") 48 | 49 | parser.add_argument('--nc_decoder', type=int, default=32, help="# feature channels in texture decoder.") 50 | parser.add_argument('--semantic_regions', type=int, nargs='+', default=list(range(1, 16)), 51 | help="Defines region indices that should be considered as foreground. You can find a label list here: https://github.com/switchablenorms/CelebAMask-HQ/blob/master/face_parsing/Data_preprocessing/g_mask.py") 52 | 53 | parser.add_argument('--experiment_name', default='default', help='Experiment name for logger') 54 | parser.add_argument('--project', default="varitex", help="Project name for wandb logger") 55 | parser.add_argument('--debug', action="store_true", 56 | help="Enable debug mode, i.e., running a fast_dev_run in the Trainer.") 57 | self.initialized = True 58 | return parser 59 | 60 | def parse(self): 61 | # initialize parser with basic options 62 | if not self.initialized: 63 | self.parser = argparse.ArgumentParser() 64 | self.parser = self.initialize(self.parser) 65 | 66 | opt = self.parser.parse_args() 67 | self.print_options(opt) 68 | return opt 69 | 70 | def print_options(self, opt): 71 | message = '' 72 | message += '----------------- Options ---------------\n' 73 | for k, v in sorted(vars(opt).items()): 74 | comment = '' 75 | default = self.parser.get_default(k) 76 | if v != default: 77 | comment = '\t[default: %s]' % str(default) 78 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 79 | message += '----------------- End -------------------' 80 | print(message) 81 | 82 | @staticmethod 83 | def load_from_json(path_opt): 84 | if not os.path.exists(path_opt): 85 | raise Warning("opt.json not found: '{}'".format(path_opt)) 86 | 87 | with open(path_opt, 'r') as f: 88 | # Load and overwrite from a stored opt file 89 | opt = json.load(f, object_hook=lambda d: SimpleNamespace(**d)) 90 | print("Loaded options from '{}'".format(path_opt)) 91 | return opt 92 | 93 | @classmethod 94 | def update_from_json(cls, opt, path_opt, keep_keys=('checkpoint',)): 95 | # We load from a checkpoint, so let's load the opt as well 96 | to_keep = {k: getattr(opt, k) for k in keep_keys} 97 | opt_new = cls.load_from_json(path_opt) 98 | for k in opt.__dict__: 99 | if 
getattr(opt_new, k, None) is None: 100 | setattr(opt_new, k, getattr(opt, k)) 101 | # We need to set this again 102 | for k, v in to_keep.items(): 103 | setattr(opt_new, k, v) 104 | return opt_new 105 | -------------------------------------------------------------------------------- /varitex/options/eval_options.py: -------------------------------------------------------------------------------- 1 | from varitex.options.base_options import BaseOptions 2 | 3 | 4 | class EvalOptions(BaseOptions): 5 | 6 | def initialize(self, parser): 7 | parser = super().initialize(parser) 8 | parser.add_argument('--path_opt', default=None, help='json with options.') 9 | return parser 10 | -------------------------------------------------------------------------------- /varitex/options/train_options.py: -------------------------------------------------------------------------------- 1 | from varitex.options.base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | 6 | def initialize(self, parser): 7 | parser = super().initialize(parser) 8 | parser.add_argument('--lr', type=float, default=1e-3, help="Learning rate for generator.") 9 | # Loss term weights: 10 | parser.add_argument('--lambda_l2', type=float, default=1.0) 11 | parser.add_argument('--lambda_kl', type=float, default=0.1) 12 | parser.add_argument('--lambda_vgg', type=float, default=2.0) 13 | parser.add_argument('--lambda_segmentation', type=float, default=1.0) 14 | parser.add_argument('--lambda_rgb_texture', type=float, default=0.0) 15 | 16 | parser.add_argument('--gan_mode', type=str, default='ls') 17 | parser.add_argument('--lambda_gan', type=float, default=1) 18 | parser.add_argument('--lambda_discriminator_features', type=float, default=1) 19 | parser.add_argument('--nc_discriminator', type=int, default=64) 20 | parser.add_argument('--num_discriminator', type=int, default=2, 21 | help='number of discriminators to be used in multiscale') 22 | parser.add_argument('--n_layers_discriminator', type=int, default=4, 23 | help='# layers in each discriminator') 24 | parser.add_argument('--norm_discriminator', type=str, default='spectralinstance') 25 | parser.add_argument('--lr_discriminator', type=float, default=1e-3) 26 | 27 | parser.add_argument('--num_workers', type=int, default=1, help='Number of workers in dataloader') 28 | parser.add_argument('--display_freq', type=int, default=1000, help='Display images every X iterations') 29 | parser.add_argument('--max_epochs', type=int, default=44, help='Should converge in 44 epochs.') 30 | return parser 31 | -------------------------------------------------------------------------------- /varitex/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | import wandb 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger 9 | from torch.utils.data import DataLoader 10 | 11 | from varitex.custom_callbacks.callbacks import ImageLogCallback 12 | from varitex.data.npy_dataset import NPYDataset 13 | from varitex.modules.pipeline import PipelineModule 14 | from varitex.options.train_options import TrainOptions 15 | from mutil.files import copy_src, mkdir 16 | 17 | 18 | if __name__ == "__main__": 19 | pl.seed_everything(1234) 20 | 21 | opt = TrainOptions().parse() 22 | if opt.checkpoint is not None: 23 | # We load from a checkpoint, so let's load the opt as well 24 | path_checkpoint = 
opt.checkpoint 25 | opt_new = TrainOptions.load_from_json(os.path.join(os.path.dirname(opt.checkpoint), "../opt.json")) 26 | for k in opt.__dict__: 27 | # Overwrite options from the json with current options 28 | if getattr(opt_new, k, None) is None: 29 | setattr(opt_new, k, getattr(opt, k)) 30 | # We need to set this again 31 | opt_new.checkpoint = path_checkpoint 32 | opt = opt_new 33 | 34 | if opt.dataset_split == "all": 35 | # The dataset has no splits or we want to use the full dataset. 36 | dataset = NPYDataset(opt, split="all", augmentation=True) 37 | dataloader = DataLoader(dataset, batch_size=opt.batch_size, num_workers=opt.num_workers, shuffle=True) 38 | do_validation = False 39 | else: 40 | # Separate dataloaders for train and validation. 41 | train_dataset, val_dataset = NPYDataset(opt, split="train", augmentation=True), NPYDataset(opt, split="val", 42 | augmentation=False) 43 | train_dataloader = DataLoader(train_dataset, batch_size=opt.batch_size, num_workers=opt.num_workers, 44 | shuffle=True) 45 | val_dataloader = DataLoader(val_dataset, batch_size=opt.batch_size, num_workers=opt.num_workers, shuffle=False) 46 | do_validation = True 47 | 48 | pipeline = PipelineModule(opt) 49 | gpus = torch.cuda.device_count() 50 | print("Using {} GPU".format(gpus)) 51 | print("Writing results to {}".format(opt.path_out)) 52 | mkdir(opt.path_out) 53 | 54 | if opt.logger == "wandb": 55 | wandb.login() 56 | logger = pl.loggers.WandbLogger(save_dir=opt.path_out, name=opt.experiment_name, project=opt.project) 57 | logger.log_hyperparams(opt) 58 | logger.watch(pipeline) 59 | elif opt.logger == "tensorboard": 60 | logger = TensorBoardLogger( 61 | save_dir=opt.path_out, name=opt.experiment_name 62 | ) 63 | else: 64 | logger = None 65 | 66 | trainer = pl.Trainer(logger, gpus=gpus, max_epochs=opt.max_epochs, default_root_dir=opt.path_out, 67 | limit_val_batches=0.25, callbacks=[ImageLogCallback(opt), ModelCheckpoint()], 68 | fast_dev_run=opt.debug, 69 | resume_from_checkpoint=opt.checkpoint, 70 | ) 71 | 72 | if not opt.debug: 73 | # We keep a copy of the current source code and opt config 74 | src_path = os.path.dirname(os.path.realpath(__file__)) 75 | copy_src(path_from=src_path, 76 | path_to=opt.path_out) 77 | with open(os.path.join(opt.path_out, "opt.json"), 'w') as f: 78 | json.dump(opt.__dict__, f) 79 | 80 | if do_validation: 81 | trainer.fit(pipeline, train_dataloader, val_dataloader) 82 | else: 83 | trainer.fit(pipeline, dataloader) 84 | -------------------------------------------------------------------------------- /varitex/visualization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mcbuehler/VariTex/a25979f96d600ed839799450958ace8952ce375a/varitex/visualization/__init__.py -------------------------------------------------------------------------------- /varitex/visualization/batch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.transforms 4 | from torch.nn.functional import grid_sample 5 | from torchvision.utils import make_grid 6 | 7 | from mutil.data_types import to_np 8 | from mutil.np_util import interpolation 9 | from mutil.pytorch_utils import ImageNetNormalizeTransformInverse, to_tensor, theta2rotation_matrix 10 | from varitex.data.keys_enum import DataItemKey as DIK 11 | 12 | 13 | class Visualizer: 14 | def __init__(self, opt, n_samples=1, mask_value=0, return_format='torch', mask_key=None, 
device='cuda'): 15 | self.opt = opt 16 | self.dataset = self.opt.dataset 17 | self.n_samples = n_samples 18 | self.mask_value = mask_value 19 | self.unnormalize = ImageNetNormalizeTransformInverse() 20 | self.return_format = return_format 21 | self.mask_key = mask_key 22 | self.device = device 23 | 24 | def uv2rgb(self, uv_image): 25 | assert len(uv_image.shape) == 3, "Invalid shape, should have 3 channels." 26 | uv_image = uv_image.clone().permute(2, 0, 1) # HWC to CHW 27 | uv_image_rgb = [uv_image, -1 * torch.ones((1, uv_image.shape[1], uv_image.shape[2])).to(uv_image.device)] 28 | uv_image_rgb = torch.cat(uv_image_rgb, 0) 29 | uv_image_rgb = (uv_image_rgb + 1) / 2 30 | return uv_image_rgb 31 | 32 | def tensor2image(self, image_tensor, mask=None, clamp=True, batch=None, white_bg=True, return_format=None): 33 | image = self.unnormalize(image_tensor) 34 | if batch is not None: 35 | image = self.mask(image, batch, white_bg=white_bg) 36 | if mask is not None: 37 | image[~mask] = self.mask_value 38 | if clamp: 39 | image = image.clamp(0, 1) 40 | out = image.detach().cpu() 41 | if return_format is not None: 42 | out = self.format_output(out, return_format) 43 | return out 44 | 45 | def format_output(self, vis, return_format=None): 46 | if return_format is None: 47 | return_format = self.return_format 48 | 49 | if return_format == "torch": 50 | return vis 51 | elif return_format == "pil": 52 | to_pil = torchvision.transforms.ToPILImage() 53 | vis = vis.detach().cpu() 54 | return to_pil(vis) 55 | elif return_format == "np": 56 | return np.array(self.format_output(vis, return_format='pil')) 57 | raise Warning("Invalid return format: {}".format(return_format)) 58 | 59 | def _sample(self, batch, std_multiplier): 60 | q = torch.distributions.Normal(batch[DIK.STYLE_LATENT_MU], batch[DIK.STYLE_LATENT_STD] * std_multiplier) 61 | z = q.rsample() 62 | batch[DIK.STYLE_LATENT] = z 63 | return batch 64 | 65 | def detach(self, o, to_cpu=False): 66 | if isinstance(o, list): 67 | o = [self.detach(o_i, to_cpu) for o_i in o] 68 | elif isinstance(o, dict): 69 | o = {k: self.detach(v, to_cpu) for k, v in o.items()} 70 | elif isinstance(o, torch.Tensor): 71 | o = o.clone().detach() 72 | if to_cpu: 73 | o = o.cpu() 74 | return o 75 | 76 | def _debug_view(self, vis): 77 | import matplotlib.pyplot as plt 78 | vis = self.format_output(vis, 'np') 79 | plt.imshow(np.array(vis)) 80 | plt.show() 81 | 82 | def _placeholder(self, like): 83 | return torch.zeros_like(like, device=like.device) 84 | 85 | def mask(self, img_out, batch, t=0.7, white_bg=False): 86 | if self.mask_key is not None: 87 | mask = batch[self.mask_key][0].expand_as(img_out).to(img_out.device) 88 | mask[mask < t] = 0 89 | img_out = img_out * mask 90 | if white_bg: 91 | img_out[~mask.bool()] = 1 92 | return img_out 93 | 94 | 95 | class CombinedVisualizer(Visualizer): 96 | def zoom(self, image, factor=3): 97 | assert len(image.shape) == 3, "No batch dim plz" 98 | h, w = self.opt.image_h, self.opt.image_w 99 | center_y, center_x = h // factor, w // factor 100 | image_cropped = image[:, center_y:center_y + h // factor, center_x:center_x + w // factor] 101 | zoomed = torch.nn.functional.interpolate(image_cropped.unsqueeze(0), image.shape[-2:], mode='bilinear').clamp(0, 102 | 1) 103 | return zoomed.squeeze() 104 | 105 | def visualize(self, batch): 106 | batch = self.detach(batch) 107 | image_in = self.tensor2image(batch[DIK.IMAGE_IN][0]) 108 | image_in_encode = self.tensor2image(batch[DIK.IMAGE_IN_ENCODE][0]) 109 | uv_image = 
self.uv2rgb(batch[DIK.UV_RENDERED][0]) 110 | segmentation_pred = batch[DIK.SEGMENTATION_PREDICTED][0][0] # zero is background 111 | segmentation_pred = segmentation_pred.expand_as(image_in) 112 | segmentation_gt = batch[DIK.SEGMENTATION_MASK][0].expand_as(image_in) 113 | image_out = self.tensor2image(batch[DIK.IMAGE_OUT][0]) 114 | zoomed_out = self.zoom(image_out) 115 | zoomed_in = self.zoom(image_in) 116 | 117 | vis = [uv_image, segmentation_gt, image_in, zoomed_in, 118 | image_in_encode, segmentation_pred, image_out, zoomed_out 119 | ] 120 | vis = self.detach(vis, to_cpu=True) 121 | vis = make_grid(vis, nrow=4) 122 | vis = self.format_output(vis) 123 | return vis 124 | 125 | 126 | class SampledVisualizer(Visualizer): 127 | def visualize(self, batch, batch_idx, pipeline, std_multiplier=1): 128 | batch = self.detach(batch) 129 | images_out = list() 130 | for s_i in range(self.n_samples): 131 | batch2 = batch.copy() 132 | batch2 = self._sample(batch2, std_multiplier) 133 | batch2 = pipeline.forward_latent2image(batch2, batch_idx) 134 | img_out = batch2[DIK.IMAGE_OUT] 135 | img_out = img_out[0] 136 | img_out = self.tensor2image(img_out) 137 | 138 | if self.mask_key is not None: 139 | img_out = img_out * batch[self.mask_key][0].expand_as(img_out).cpu() 140 | images_out.append(img_out) 141 | images_out = self.detach(images_out, to_cpu=True) 142 | vis = make_grid(images_out, nrow=len(images_out)) 143 | return self.format_output(vis) 144 | 145 | def visualize_unseen(self, batch, batch_idx, pipeline, std_multiplier=1): 146 | batch = self.detach(batch) 147 | images_out = list() 148 | for s_i in range(self.n_samples): 149 | batch2 = batch.copy() 150 | batch2[DIK.STYLE_LATENT] = torch.randn_like(batch2[DIK.STYLE_LATENT]).to( 151 | batch2[DIK.STYLE_LATENT].device) * std_multiplier 152 | batch2 = pipeline.forward_latent2image(batch2, batch_idx) 153 | img_out = batch2[DIK.IMAGE_OUT] 154 | img_out = img_out[0] 155 | img_out = self.tensor2image(img_out) 156 | 157 | if self.mask_key is not None: 158 | img_out = img_out * batch[self.mask_key][0].expand_as(img_out).cpu() 159 | images_out.append(img_out) 160 | images_out = self.detach(images_out, to_cpu=True) 161 | vis = make_grid(images_out, nrow=len(images_out)) 162 | return self.format_output(vis) 163 | 164 | def visualize_grid(self, batch, batch_idx, pipeline, std_multipliers=(1, 2, 3, 4)): 165 | sampled_vis = list() 166 | for multiplier in std_multipliers: 167 | vis = self.visualize(batch, 0, pipeline, std_multiplier=multiplier) 168 | sampled_vis.append(vis) 169 | vis = torch.cat(sampled_vis, -2) 170 | vis = torch.clamp(vis, 0, 1) 171 | return self.format_output(vis) 172 | 173 | def visualize_new_sample(self, batch, batch_idx, pipeline, std_multiplier): 174 | batch = pipeline.forward_sample_style(batch.copy(), batch_idx, std_multiplier) 175 | batch = pipeline.forward_latent2image(batch, batch_idx) 176 | vis = self.tensor2image(batch[DIK.IMAGE_OUT][0], batch=batch) 177 | return self.format_output(vis) 178 | 179 | 180 | class UVVisualizer(Visualizer): 181 | 182 | def __init__(self, *args, bfm_uv_factory, **kwargs): 183 | super().__init__(*args, **kwargs) 184 | self.bfm_uv_factory = bfm_uv_factory 185 | 186 | def run(self, pipeline, batch, batch_idx, return_format=None, forward_type="style2image"): 187 | """ 188 | 189 | Args: 190 | pipeline: 191 | batch: 192 | batch_idx: 193 | return_format: 194 | forward_type: include "fullbatch" if the function should return a tuple img, batch 195 | 196 | Returns: 197 | 198 | """ 199 | if "interior2image" in 
forward_type: 200 | batch = pipeline.forward_interior2image(batch, batch_idx) 201 | elif "style2image" in forward_type: 202 | batch = pipeline.forward_latent2image(batch, batch_idx) 203 | elif "texture2image" in forward_type: 204 | batch = pipeline.forward_texture2image(batch, batch_idx) 205 | else: 206 | raise Warning("invalid forward type") 207 | img_out = batch[DIK.IMAGE_OUT][0] 208 | img_out = self.tensor2image(img_out) 209 | img_out = self.mask(img_out, batch, white_bg=False) 210 | # img_out = self.uv2rgb(batch[DIK.UV_RENDERED][0]) 211 | if return_format is not None: 212 | img_out = self.format_output(img_out, return_format=return_format) 213 | if "fullbatch" in forward_type: 214 | return img_out, batch 215 | else: 216 | return img_out 217 | 218 | def get_neutral_t(self, batch): 219 | t = torch.Tensor((0, 0, batch[DIK.T][0, 2])).expand_as(batch[DIK.T]) 220 | return t 221 | 222 | def visualize_grid(self, pipeline, batch, batch_idx, deg_range, return_format=None): 223 | result = list() 224 | idx = 0 225 | 226 | for theta_y in deg_range: 227 | for theta_x in deg_range: 228 | batch2 = batch.copy() 229 | theta = [theta_x, theta_y, 0] 230 | R = theta2rotation_matrix(theta_all=theta).to(batch2[DIK.R].device).unsqueeze(0) 231 | t = self.get_neutral_t(batch) 232 | uv = self.bfm_uv_factory.getUV(R=R, t=t, s=batch2[DIK.SCALE], sp=batch2[DIK.COEFF_SHAPE], 233 | ep=batch2[DIK.COEFF_EXPRESSION]) 234 | 235 | batch2[DIK.UV_RENDERED] = uv.expand_as(batch[DIK.UV_RENDERED]).to(batch[DIK.UV_RENDERED].device) 236 | img_out = self.run(pipeline, batch2, idx) 237 | result.append(img_out) 238 | 239 | vis = make_grid(result, nrow=int(np.sqrt(len(result)))) 240 | return self.format_output(vis, return_format=return_format) 241 | 242 | def visualize_row(self, pipeline, batch, batch_idx, deg_range, axis=1, return_format=None): 243 | if not isinstance(axis, list): 244 | axis = [axis] 245 | result = list() 246 | idx = 0 247 | if 1 in axis: 248 | deg_range = -1 * deg_range 249 | for theta_d in deg_range: 250 | batch2 = batch.copy() 251 | batch_size = batch2[DIK.R].shape[0] 252 | theta = [0, 0, 0] 253 | for i in range(len(axis)): 254 | theta[axis[i]] = theta_d 255 | R = theta2rotation_matrix(theta_all=theta).to(batch2[DIK.R].device).unsqueeze(0) 256 | t = self.get_neutral_t(batch) 257 | uv = self.bfm_uv_factory.getUV(R=R, t=t, s=batch2[DIK.SCALE], sp=batch2[DIK.COEFF_SHAPE], 258 | ep=batch2[DIK.COEFF_EXPRESSION]) 259 | 260 | batch2[DIK.UV_RENDERED] = uv.expand_as(batch[DIK.UV_RENDERED]).to(batch[DIK.UV_RENDERED].device) 261 | img_out = self.run(pipeline, batch2, idx) 262 | result.append(img_out) 263 | 264 | nrow = len(result) if 1 in axis else 1 265 | vis = make_grid(result, nrow=nrow) 266 | return self.format_output(vis, return_format=return_format) 267 | 268 | def visualize_show(self, pipeline, batch, batch_idx, n_y, n_x, pose_range_x=30, pose_range_y=30): 269 | theta_x = torch.linspace(-pose_range_x, pose_range_x, n_x) if pose_range_x > 0 else [0] 270 | theta_y = torch.linspace(-pose_range_y, pose_range_y, n_y) if pose_range_y > 0 else [0] 271 | top = [[theta_x[0], theta_y[i], 0] for i in range(n_y - 1)] if n_y > 0 else [] 272 | down = [[theta_x[i], theta_y[-1], 0] for i in range(n_x - 1)] if n_x > 0 else [] 273 | bottom = [[theta_x[-1], theta_y[-i - 1], 0] for i in range(n_y - 1)] if n_y > 0 else [] 274 | up = [[theta_x[-i - 1], theta_y[0], 0] for i in range(n_x - 1)] if n_x > 0 else [] 275 | all = top + down + bottom + up 276 | 277 | all = [theta2rotation_matrix(theta_all=theta) for theta in all] 278 | 279 | 
result = list() 280 | for i, R in enumerate(all): 281 | batch2 = batch.copy() 282 | t = batch2[DIK.T] 283 | uv = self.bfm_uv_factory.getUV(sp=batch[DIK.COEFF_SHAPE], ep=batch[DIK.COEFF_EXPRESSION], R=R, t=t) 284 | batch2[DIK.UV_RENDERED] = uv.expand_as(batch2[DIK.UV_RENDERED]).to(batch2[DIK.UV_RENDERED].device) 285 | img_out = self.run(pipeline, batch2, batch_idx, return_format='np') 286 | result.append(img_out) 287 | return result 288 | 289 | def visualize_show_expression(self, pipeline, batch, batch_idx, n, expression_list): 290 | result = list() 291 | dev = batch[DIK.STYLE_LATENT].device 292 | R = torch.eye(3).unsqueeze(0) 293 | ep_prev = expression_list[0] 294 | expression_list = list(expression_list) + [expression_list[0]] 295 | for i in range(1, len(expression_list)): 296 | for ep in interpolation(n, ep_prev, expression_list[i]): 297 | ep = torch.Tensor(ep).unsqueeze(0).to(dev) 298 | R = R.to(dev) 299 | batch2 = batch.copy() 300 | uv = self.bfm_uv_factory.getUV(sp=batch[DIK.COEFF_SHAPE], ep=ep, R=R, t=batch[DIK.T], 301 | s=batch[DIK.SCALE]) 302 | batch2[DIK.UV_RENDERED] = uv.expand_as(batch2[DIK.UV_RENDERED]).to(batch2[DIK.UV_RENDERED].device) 303 | img_out = self.run(pipeline, batch2, batch_idx, return_format='np') 304 | result.append(img_out) 305 | ep_prev = expression_list[i] 306 | return result 307 | 308 | def visualize_single(self, pipeline, batch, batch_idx, theta_all=None, R=None, forward_type='style2image'): 309 | batch2 = self.detach(batch) 310 | if R is None: 311 | R = theta2rotation_matrix(theta_all=theta_all) 312 | uv = self.bfm_uv_factory.getUV(sp=batch[DIK.COEFF_SHAPE], ep=batch[DIK.COEFF_EXPRESSION], R=R, t=batch[DIK.T]) 313 | batch2[DIK.UV_RENDERED] = uv.expand_as(batch2[DIK.UV_RENDERED]).to(batch2[DIK.UV_RENDERED].device) 314 | out = self.run(pipeline, batch2, batch_idx, return_format='pil', forward_type=forward_type) 315 | return out 316 | 317 | def visualize_row_expressions(self, pipeline, batch, batch_idx, 318 | list_ep, return_format=None): 319 | result = list() 320 | for i in range(len(list_ep)): 321 | batch2 = batch.copy() 322 | t = torch.Tensor((0, 0, batch[DIK.T][0, 2])).expand_as(batch[DIK.T]) 323 | R = torch.eye(3).unsqueeze(0) 324 | ep = list_ep[i] 325 | uv = self.bfm_uv_factory.getUV(R=R, t=t, s=batch2[DIK.SCALE], sp=batch2[DIK.COEFF_SHAPE], ep=ep) 326 | batch2[DIK.UV_RENDERED] = uv.expand_as(batch[DIK.UV_RENDERED]).to(batch[DIK.UV_RENDERED].device) 327 | img_out = self.run(pipeline, batch2, batch_idx) 328 | result.append(img_out) 329 | 330 | nrow = len(result) 331 | vis = make_grid(result, nrow=nrow) 332 | return self.format_output(vis, return_format=return_format) 333 | 334 | def visualize_row_shapes(self, pipeline, batch, batch_idx, 335 | list_sp, return_format=None): 336 | result = list() 337 | for i in range(len(list_sp)): 338 | batch2 = batch.copy() 339 | t = torch.Tensor((0, 0, batch[DIK.T][0, 2])).expand_as(batch[DIK.T]) 340 | R = torch.eye(3).unsqueeze(0) 341 | sp = list_sp[i] 342 | uv = self.bfm_uv_factory.getUV(R=R, t=t, s=batch2[DIK.SCALE], ep=batch2[DIK.COEFF_EXPRESSION], sp=sp) 343 | batch2[DIK.UV_RENDERED] = uv.expand_as(batch[DIK.UV_RENDERED]).to(batch[DIK.UV_RENDERED].device) 344 | img_out = self.run(pipeline, batch2, batch_idx) 345 | result.append(img_out) 346 | 347 | nrow = len(result) 348 | vis = make_grid(result, nrow=nrow) 349 | return self.format_output(vis, return_format=return_format) 350 | 351 | 352 | class InterpolationVisualizer(Visualizer): 353 | 354 | def run(self, pipeline, batch, latent_from, latent_to, n): 355 | result = 
list() 356 | # linear interpolation 357 | all_latents = interpolation(n, latent_from=latent_from, latent_to=latent_to) 358 | all_latents = to_tensor(all_latents, batch[DIK.STYLE_LATENT].device) 359 | for latent in all_latents: 360 | batch2 = batch.copy() 361 | batch2[DIK.STYLE_LATENT] = latent.reshape(batch[DIK.STYLE_LATENT].shape) 362 | batch2 = pipeline.forward_latent2image(batch2, 0) 363 | img_out = batch2[DIK.IMAGE_OUT][0] 364 | img_out = self.tensor2image(img_out) 365 | 366 | if self.mask_key is not None: 367 | img_out = img_out * batch[self.mask_key][0].expand_as(img_out).cpu() 368 | result.append(img_out) 369 | return result 370 | 371 | def visualize(self, pipeline, batch_from, batch_to, n, bidirectional=True, include_gt=True): 372 | latent_from = to_np(batch_from[DIK.STYLE_LATENT]) 373 | latent_to = to_np(batch_to[DIK.STYLE_LATENT]) 374 | result = self.run(pipeline, batch_from, latent_from, latent_to, n) 375 | 376 | if include_gt: 377 | n = n + 2 378 | result = [self.tensor2image(batch_from[DIK.IMAGE_IN][0])] + result + [ 379 | self.tensor2image(batch_to[DIK.IMAGE_IN][0])] 380 | 381 | if bidirectional: 382 | result_backward = self.run(pipeline, batch_to, latent_from, latent_to, n) 383 | 384 | if include_gt: 385 | result_backward = [self.tensor2image(batch_from[DIK.IMAGE_IN][0])] + result_backward + [ 386 | self.tensor2image(batch_to[DIK.IMAGE_IN][0])] 387 | 388 | result = result + result_backward 389 | 390 | vis = make_grid(result, nrow=n) 391 | return self.format_output(vis) 392 | 393 | 394 | class NeuralTextureVisualizer(Visualizer): 395 | def __init__(self, *args, **kwargs): 396 | super().__init__(*args, **kwargs) 397 | 398 | def run(self, texture, return_format=None): 399 | n_channels = texture.shape[0] 400 | result = list() 401 | for c in range(n_channels): 402 | texture_c = texture[c].unsqueeze(0) 403 | result.append(texture_c) 404 | 405 | vis = make_grid(result, nrow=int(np.sqrt(n_channels)), scale_each=True, normalize=True) 406 | vis = self.format_output(vis, return_format=return_format) 407 | return vis 408 | 409 | def visualize_interior(self, batch, batch_idx, return_format=None): 410 | batch = self.detach(batch) 411 | texture = batch[DIK.TEXTURE_PERSON][0].clone() 412 | return self.run(texture, return_format) 413 | 414 | def visualize_interior_sampled(self, batch, batch_idx, return_format=None): 415 | batch = self.detach(batch) 416 | texture = batch[DIK.FACE_FEATUREIMAGE][0].clone() 417 | return self.run(texture, return_format) 418 | 419 | def visualize_exterior_sampled(self, batch, batch_id, return_format=None): 420 | batch = self.detach(batch) 421 | texture = batch[DIK.ADDITIVE_FEATUREIMAGE][0].clone() 422 | return self.run(texture, return_format) 423 | 424 | def visualize_enhanced(self, batch, batch_id, return_format=None): 425 | batch = self.detach(batch) 426 | texture = batch[DIK.ADDITIVE_FEATUREIMAGE][0].clone() 427 | return self.run(texture, return_format) 428 | 429 | 430 | class CompleteVisualizer(Visualizer): 431 | 432 | def __init__(self, *args, bfm_uv_factory, **kwargs): 433 | super().__init__(*args, **kwargs) 434 | self.bfm_uv_factory = bfm_uv_factory 435 | 436 | def run(self, pipeline, batch, batch_idx, return_format=None, forward_type="style2image"): 437 | if forward_type == "interior2image": 438 | batch = pipeline.forward_interior2image(batch, batch_idx) 439 | elif forward_type == "style2image": 440 | batch = pipeline.forward_latent2image(batch, batch_idx) 441 | elif forward_type == "texture2image": 442 | batch = pipeline.forward_texture2image(batch, 
batch_idx) 443 | else: 444 | raise Warning("invalid forward type") 445 | return batch 446 | 447 | def visualize_single(self, pipeline, batch, batch_idx, theta_all=None, R=None, forward_type='style2image', 448 | correct_translation=True): 449 | batch2 = batch 450 | if R is None: 451 | R = theta2rotation_matrix(theta_all=theta_all).to(self.device) 452 | batch2[DIK.R] = R.unsqueeze(0) 453 | uv = self.bfm_uv_factory.getUV(sp=batch[DIK.COEFF_SHAPE], ep=batch[DIK.COEFF_EXPRESSION], R=R, t=batch[DIK.T], 454 | correct_translation=correct_translation) 455 | batch2[DIK.UV_RENDERED] = uv.expand(batch2[DIK.COEFF_SHAPE].shape[0], -1, -1, -1).to(self.device) 456 | batch2 = self.run(pipeline, batch2, batch_idx, return_format='pil', forward_type=forward_type) 457 | return batch2 458 | --------------------------------------------------------------------------------
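A condensed, self-contained sketch of the latent sampling path implemented by Generator above (forward_encoded2latent_distribution followed by forward_sample_style). This snippet is illustrative and not part of the repository; the variable names are stand-ins, only the math mirrors the code.

import torch

enc_out_dim, latent_dim = 512, 256                 # Generator's hard-coded encoder width and the default latent_dim
fc_mu = torch.nn.Linear(enc_out_dim, latent_dim)   # mu head
fc_var = torch.nn.Linear(enc_out_dim, latent_dim)  # log-variance head

encoded = torch.randn(4, enc_out_dim)              # stand-in for batch[DIK.IMAGE_ENCODED]
mu, log_var = fc_mu(encoded), fc_var(encoded)
std = torch.exp(log_var / 2)                       # log-variance -> standard deviation
z = torch.distributions.Normal(mu, std).rsample()  # reparameterized sample, shape (4, 256)
z_interior, z_exterior = z[:, :latent_dim // 2], z[:, latent_dim // 2:]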