37 |
38 |
68 |
69 |
70 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
Abstract
89 |
90 |
91 |
92 | Deep generative models can synthesize photorealistic images of human faces with novel identities.
93 | However, a key challenge to the wide applicability of such techniques is to provide independent control over semantically meaningful parameters: appearance, head pose, face shape, and facial expressions.
94 | In this paper, we propose VariTex, to the best of our knowledge the first method that learns a variational latent feature space of neural face textures, which allows sampling of novel identities.
95 | We combine this generative model with a parametric face model and gain explicit control over head pose and facial expressions.
96 | To generate complete images of human heads, we propose an additive decoder that adds plausible details such as hair.
97 | A novel training scheme enforces a pose-independent latent space and, in consequence, allows learning a one-to-many mapping between latent codes and pose-conditioned exterior regions.
98 | The resulting method can generate geometrically consistent images of novel identities under fine-grained control over head pose, face shape, and facial expressions. This facilitates a broad range of downstream tasks, such as sampling novel identities, changing the head pose, and transferring expressions.
99 |
100 |
101 |
102 |
103 |
104 |
105 |
Video
106 |
107 | VIDEO
108 |
109 |
110 |
111 |
112 |
113 |
VariTex Controls
114 |
115 |
116 |
117 |
118 |
119 |
Expressions
120 |
121 |
122 |
123 |
124 |
127 |
128 |
Pose
129 |
130 |
131 |
132 |
133 |
135 |
136 |
137 |
Identity
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
Method
146 |
147 |
148 | The objective of our pipeline is to learn a generator that can synthesize face images of arbitrary novel
149 | identities, with pose and expression controlled via face model parameters.
150 | During training, we use unlabeled monocular RGB images to learn a smooth latent space
151 | of natural face appearance using a variational encoder.
152 | A latent code sampled from this space is then decoded to a novel face image.
153 | At test time, we use samples drawn from a normal distribution to generate novel face images.
154 | Our variationally generated neural textures can also be stylistically interpolated to generate
155 | intermediate identities.
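
A minimal test-time sketch of this pipeline, assuming a trained generator handle G (hypothetical; the released interface lives in the varitex package) and using the variance-corrected interpolation helper from mutil/np_util.py:

import numpy as np
from mutil.np_util import interpolation

latent_dim = 256  # assumed size of the identity latent

# Novel identities are latent codes drawn from a standard normal.
z_a = np.random.randn(1, latent_dim).astype(np.float32)
z_b = np.random.randn(1, latent_dim).astype(np.float32)

# Ten variance-preserving interpolation steps between the two identities;
# codes has shape (10, 1, latent_dim).
codes = interpolation(10, z_a, z_b, gaussian_correction=True)

# Decode under explicit control (G, pose, sp, ep are placeholders):
# images = [G(z, pose=pose, shape=sp, expression=ep) for z in codes]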
156 |
157 |
158 |
159 |
160 |
161 |
162 |
Code and Models
163 | Available on GitHub. Make sure to check out our demo notebook.
164 |
165 |
166 |
167 |
168 |
Citation
169 |
170 | @inproceedings{buehler2021varitex,
171 | title={VariTex: Variational Neural Face Textures},
172 | author={Marcel C. Buehler and Abhimitra Meka and Gengyan Li and Thabo Beeler and Otmar Hilliges},
173 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
174 | year={2021}
175 | }
176 |
177 |
178 |
179 |
180 |
181 |
199 |
200 |
201 |
202 |
204 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: varitex
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=conda_forge
7 | - _openmp_mutex=4.5=2_kmp_llvm
8 | - _pytorch_select=0.2=gpu_0
9 | - absl-py=0.11.0=py38h578d9bd_0
10 | - aiohttp=3.7.3=py38h25fe258_0
11 | - appdirs=1.4.4=pyh9f0ad1d_0
12 | - async-timeout=3.0.1=py_1000
13 | - attrs=20.3.0=pyhd3deb0d_0
14 | - bcrypt=3.2.0=py38h1e0a361_1
15 | - blas=1.0=mkl
16 | - blinker=1.4=py_1
17 | - brotlipy=0.7.0=py38h8df0ef7_1001
18 | - c-ares=1.17.1=h36c2ea0_0
19 | - ca-certificates=2023.12.12=h06a4308_0
20 | - cachetools=4.1.1=py_0
21 | - certifi=2024.2.2=py38h06a4308_0
22 | - cffi=1.14.4=py38h261ae71_0
23 | - chardet=3.0.4=py38h924ce5b_1008
24 | - click=7.1.2=pyh9f0ad1d_0
25 | - cryptography=3.3.1=py38h3c74f83_0
26 | - cuda-cudart=12.3.101=hd3aeb46_0
27 | - cuda-cudart_linux-64=12.3.101=h59595ed_0
28 | - cuda-nvrtc=12.3.107=hd3aeb46_0
29 | - cuda-nvtx=12.3.101=h59595ed_0
30 | - cuda-version=12.3=h55a0123_2
31 | - cudnn=8.8.0.121=h264754d_4
32 | - cycler=0.10.0=py38_0
33 | - dbus=1.13.18=hb2f20db_0
34 | - expat=2.2.10=he6710b0_2
35 | - filelock=3.13.1=py38h06a4308_0
36 | - fontconfig=2.13.0=h9420a91_0
37 | - freetype=2.10.4=h5ab3b9f_0
38 | - fs=2.4.11=py38h32f6830_2
39 | - fs.sshfs=1.0.0=pyhd8ed1ab_1
40 | - future=0.18.2=py38h578d9bd_2
41 | - fvcore=0.1.2.post20201218=pyhd8ed1ab_0
42 | - glib=2.66.1=h92f7085_0
43 | - gmp=6.2.1=h295c915_3
44 | - gmpy2=2.1.2=py38heeb90bb_0
45 | - gst-plugins-base=1.14.0=hbbd80ab_1
46 | - gstreamer=1.14.0=hb31296c_0
47 | - hdf5=1.10.6=hb1b8bf9_0
48 | - icu=58.2=he6710b0_3
49 | - idna=2.10=pyh9f0ad1d_0
50 | - imageio=2.9.0=py_0
51 | - importlib-metadata=3.3.0=py38h578d9bd_3
52 | - intel-openmp=2020.2=254
53 | - joblib=1.0.0=pyhd3eb1b0_0
54 | - jpeg=9e=h5eee18b_1
55 | - kiwisolver=1.3.0=py38h2531618_0
56 | - lcms2=2.11=h396b838_0
57 | - ld_impl_linux-64=2.33.1=h53a641e_7
58 | - lerc=3.0=h295c915_0
59 | - libabseil=20230802.1=cxx17_h59595ed_0
60 | - libblas=3.9.0=20_linux64_mkl
61 | - libcblas=3.9.0=20_linux64_mkl
62 | - libcublas=12.3.4.1=hd3aeb46_0
63 | - libcufft=11.0.12.1=hd3aeb46_0
64 | - libcurand=10.3.4.107=hd3aeb46_0
65 | - libcusolver=11.5.4.101=hd3aeb46_0
66 | - libcusparse=12.2.0.103=hd3aeb46_0
67 | - libdeflate=1.17=h5eee18b_1
68 | - libedit=3.1.20191231=h14c3975_1
69 | - libffi=3.3=he6710b0_2
70 | - libgcc-ng=13.2.0=h807b86a_5
71 | - libgfortran-ng=7.3.0=hdf63c60_0
72 | - liblapack=3.9.0=20_linux64_mkl
73 | - libmagma=2.7.2=h173bb3b_2
74 | - libmagma_sparse=2.7.2=h173bb3b_2
75 | - libnvjitlink=12.3.101=hd3aeb46_0
76 | - libpng=1.6.39=h5eee18b_0
77 | - libprotobuf=4.25.1=hf27288f_2
78 | - libsodium=1.0.18=h36c2ea0_1
79 | - libstdcxx-ng=13.2.0=h7e041cc_5
80 | - libtiff=4.5.1=h6a678d5_0
81 | - libtorch=2.1.2=cuda120_h2aa5df7_301
82 | - libuuid=1.0.3=h1bed415_2
83 | - libuv=1.47.0=hd590300_0
84 | - libwebp-base=1.3.2=h5eee18b_0
85 | - libxcb=1.14=h7b6447c_0
86 | - libxml2=2.9.10=hb55368b_3
87 | - libzlib=1.2.13=hd590300_5
88 | - llvm-openmp=17.0.6=h4dfa4b3_0
89 | - lz4-c=1.9.2=heb0550a_3
90 | - magma=2.7.2=h51420fd_2
91 | - markdown=3.3.3=pyh9f0ad1d_0
92 | - matplotlib=3.3.2=h06a4308_0
93 | - matplotlib-base=3.3.2=py38h817c723_0
94 | - mkl=2023.2.0=h84fe81f_50496
95 | - mkl-service=2.4.0=py38h5eee18b_1
96 | - mkl_fft=1.3.8=py38h5eee18b_0
97 | - mkl_random=1.2.4=py38hdb19cb5_0
98 | - mpc=1.1.0=h10f8cd9_1
99 | - mpfr=4.0.2=hb69a4c5_1
100 | - mpmath=1.3.0=py38h06a4308_0
101 | - multidict=5.1.0=py38h27cfd23_2
102 | - nccl=2.20.3.1=h3a97aeb_0
103 | - ncurses=6.2=he6710b0_1
104 | - ninja=1.10.2=py38hff7bd54_0
105 | - oauthlib=3.0.1=py_0
106 | - olefile=0.46=py_0
107 | - openssl=1.1.1w=h7f8727e_0
108 | - packaging=20.8=pyhd3deb0d_0
109 | - paramiko=2.7.2=pyh9f0ad1d_0
110 | - pcre=8.44=he6710b0_0
111 | - pillow=8.0.1=py38he98fc37_0
112 | - pip=20.3.3=py38h06a4308_0
113 | - portalocker=1.7.0=py38h32f6830_1
114 | - property-cached=1.6.4=py_0
115 | - protobuf=4.25.1=py38hf14ab21_0
116 | - pyasn1=0.4.8=py_0
117 | - pyasn1-modules=0.2.7=py_0
118 | - pycparser=2.20=py_2
119 | - pyjwt=2.0.0=pyhd8ed1ab_0
120 | - pynacl=1.4.0=py38h1e0a361_2
121 | - pyopenssl=20.0.1=pyhd8ed1ab_0
122 | - pyparsing=2.4.7=py_0
123 | - pyqt=5.9.2=py38h05f1152_4
124 | - pysocks=1.7.1=py38h924ce5b_2
125 | - python=3.8.5=h7579374_1
126 | - python-dateutil=2.8.1=py_0
127 | - python_abi=3.8=1_cp38
128 | - pytorch=2.1.2=cuda120_py38ha6e2b7b_301
129 | - pytorch3d=0.7.5=cuda120py38h6e72adc_3
130 | - pytz=2021.1=pyhd8ed1ab_0
131 | - qt=5.9.7=h5867ecd_1
132 | - readline=8.0=h7b6447c_0
133 | - requests=2.25.1=pyhd3deb0d_0
134 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0
135 | - rsa=4.6=pyh9f0ad1d_0
136 | - setuptools=51.0.0=py38h06a4308_2
137 | - sip=4.19.13=py38he6710b0_0
138 | - six=1.15.0=py38h06a4308_0
139 | - sleef=3.5.1=h9b69904_2
140 | - sqlite=3.33.0=h62c20be_0
141 | - sympy=1.12=py38h06a4308_0
142 | - tabulate=0.8.7=pyh9f0ad1d_0
143 | - tbb=2021.8.0=hdb19cb5_0
144 | - tensorboard-plugin-wit=1.7.0=pyh9f0ad1d_0
145 | - tk=8.6.10=hbc83047_0
146 | - torchvision=0.15.2=cpu_py38h83e0c9b_0
147 | - tornado=6.1=py38h27cfd23_0
148 | - typing-extensions=3.7.4.3=0
149 | - typing_extensions=4.9.0=py38h06a4308_1
150 | - werkzeug=1.0.1=pyh9f0ad1d_0
151 | - wheel=0.36.2=pyhd3eb1b0_0
152 | - xz=5.4.5=h5eee18b_0
153 | - yacs=0.1.6=py_0
154 | - yaml=0.2.5=h516909a_0
155 | - yarl=1.6.3=py38h25fe258_0
156 | - zipp=3.4.0=py_0
157 | - zlib=1.2.13=hd590300_5
158 | - zstd=1.5.5=hfc55251_0
159 | - pip:
160 | - annotated-types==0.6.0
161 | - anyio==4.3.0
162 | - argon2-cffi==20.1.0
163 | - async-generator==1.10
164 | - backcall==0.2.0
165 | - bleach==3.3.0
166 | - configparser==6.0.0
167 | - dataclasses==0.6
168 | - decorator==4.4.2
169 | - defusedxml==0.6.0
170 | - deprecated==1.2.12
171 | - dill==0.3.3
172 | - docker-pycreds==0.4.0
173 | - easydict==1.9
174 | - entrypoints==0.3
175 | - exceptiongroup==1.2.0
176 | - facenet-pytorch==2.5.1
177 | - fastapi==0.109.2
178 | - fsspec==2024.2.0
179 | - gitdb==4.0.11
180 | - gitpython==3.1.42
181 | - google-auth==2.28.1
182 | - google-auth-oauthlib==1.0.0
183 | - grpcio==1.62.0
184 | - h3ds==0.1.1
185 | - h5py==3.10.0
186 | - imageio-ffmpeg==0.4.3
187 | - insightface==0.1.5
188 | - iopath==0.1.2
189 | - ipykernel==5.5.0
190 | - ipython==7.21.0
191 | - ipython-genutils==0.2.0
192 | - ipywidgets==7.6.3
193 | - jedi==0.18.0
194 | - jinja2==2.11.3
195 | - jsonschema==3.2.0
196 | - jupyter==1.0.0
197 | - jupyter-client==6.1.11
198 | - jupyter-console==6.2.0
199 | - jupyter-core==4.7.1
200 | - jupyterlab-pygments==0.1.2
201 | - jupyterlab-widgets==1.0.0
202 | - lightning-bolts==0.7.0
203 | - lightning-utilities==0.10.1
204 | - lpips==0.1.3
205 | - markupsafe==1.1.1
206 | - mistune==0.8.4
207 | - mxnet==1.8.0
208 | - mxnet-cu110==1.8.0.post0
209 | - nbclient==0.5.3
210 | - nbconvert==6.0.7
211 | - nbformat==5.1.2
212 | - nest-asyncio==1.5.1
213 | - networkx==2.5
214 | - notebook==6.2.0
215 | - numpy==1.24.4
216 | - nvidia-htop==1.0.2
217 | - onnx==1.15.0
218 | - opencv-python==4.5.1.48
219 | - pandas==1.2.4
220 | - pandocfilters==1.4.3
221 | - parso==0.8.1
222 | - pathtools==0.1.2
223 | - pexpect==4.8.0
224 | - pickleshare==0.7.5
225 | - plotly==5.0.0
226 | - prometheus-client==0.9.0
227 | - promise==2.3
228 | - prompt-toolkit==3.0.16
229 | - psutil==5.9.8
230 | - ptyprocess==0.7.0
231 | - pydantic==2.6.2
232 | - pydantic-core==2.16.3
233 | - pygments==2.8.0
234 | - pyqt5==5.15.4
235 | - pyqt5-qt5==5.15.2
236 | - pyqt5-sip==12.8.1
237 | - pyrsistent==0.17.3
238 | - python-graphviz==0.8.4
239 | - pytorch-fid==0.2.0
240 | - pytorch-lightning==1.9.5
241 | - pywavelets==1.1.1
242 | - pyyaml==6.0.1
243 | - pyzmq==22.0.3
244 | - qdarkgraystyle==1.0.2
245 | - qdarkstyle==3.0.1
246 | - qtconsole==5.0.2
247 | - qtpy==1.9.0
248 | - scikit-image==0.18.1
249 | - scikit-learn==0.24.1
250 | - scipy==1.6.1
251 | - send2trash==1.5.0
252 | - sentry-sdk==1.40.5
253 | - shortuuid==1.0.11
254 | - smmap==5.0.1
255 | - sniffio==1.3.0
256 | - starlette==0.36.3
257 | - subprocess32==3.5.4
258 | - tenacity==7.0.0
259 | - tensorboard==2.14.0
260 | - tensorboard-data-server==0.7.2
261 | - tensorboardx==2.1
262 | - termcolor==2.4.0
263 | - terminado==0.9.2
264 | - testpath==0.4.4
265 | - threadpoolctl==2.1.0
266 | - tifffile==2021.2.1
267 | - toml==0.10.2
268 | - torchmetrics==1.3.1
269 | - tqdm==4.66.2
270 | - traitlets==5.0.5
271 | - trimesh==3.9.0
272 | - urllib3==1.26.18
273 | - wandb==0.12.2
274 | - wcwidth==0.2.5
275 | - webencodings==0.5.1
276 | - widgetsnbextension==3.5.1
277 | - wrapt==1.12.1
278 | - yaspin==2.5.0
279 |
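For reference, this lockfile can be instantiated with the standard conda workflow, `conda env create -f environment.yml` followed by `conda activate varitex`; the pinned cuda-version=12.3 and cudnn builds assume a CUDA 12-capable NVIDIA driver.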
--------------------------------------------------------------------------------
/mutil/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mcbuehler/VariTex/a25979f96d600ed839799450958ace8952ce375a/mutil/__init__.py
--------------------------------------------------------------------------------
/mutil/bfm2017.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import torch
3 | import json
4 |
5 | import numpy as np
6 |
7 |
8 | class BFM:
9 | def __init__(self, path):
10 | self._shape_mu = None
11 | self._shape_pcab = None
12 | self._shape_pca_var = None
13 | self._faces = None
14 | self._expression_mu = None
15 | self._expression_pcab = None
16 | self._expression_pca_var = None
17 |
18 | @property
19 | def shape_mu(self):
20 | return np.array(self._shape_mu)
21 |
22 | @property
23 | def shape_pcab(self):
24 | return np.array(self._shape_pcab)
25 |
26 | @property
27 | def shape_pca_var(self):
28 | return np.array(self._shape_pca_var)
29 |
30 | @property
31 | def expression_mu(self):
32 | return np.array(self._expression_mu)
33 |
34 | @property
35 | def expression_pcab(self):
36 | return np.array(self._expression_pcab)
37 |
38 | @property
39 | def expression_pca_var(self):
40 | return np.array(self._expression_pca_var)
41 |
42 | @property
43 | def faces(self):
44 | return np.array(self._faces)
45 |
46 | def generate_vertices(self, shape_para, exp_para):
47 | '''
48 | Args:
49 | shape_para: (n_shape_para, 1)
50 | exp_para: (n_exp_para, 1)
51 | Returns:
52 | vertices: (nver, 3)
53 | '''
54 | raise NotImplementedError("Implement in subclass")
55 |
56 | def get_mean(self):
57 | return self.shape_mu.reshape(-1, 3)
58 |
59 | def apply_Rts(self, verts, R, t, s):
60 | transformed_verts = s * np.matmul(verts, R) + t
61 | return transformed_verts
62 |
63 |
64 | class BFM2017(BFM):
65 | N_COEFF_SHAPE = 199
66 | N_COEFF_EXPRESSION = 100
67 |
68 | def __init__(self, path_model, path_uv=None):
69 | super().__init__(path_model)
71 | self.file = h5py.File(path_model, 'r')
72 | self._shape_mu = self.file["shape"]["model"]["mean"]
73 | self._shape_pcab = self.file["shape"]["model"]["pcaBasis"]
74 | self._shape_pca_var = self.file["shape"]["model"]["pcaVariance"]
75 | self._faces = np.transpose(self.file["shape"]["representer"]["cells"])
76 | self._expression_mu = self.file["expression"]["model"]["mean"]
77 | self._expression_pcab = self.file["expression"]["model"]["pcaBasis"]
78 | self._expression_pca_var = self.file["expression"]["model"]["pcaVariance"]
79 |
80 | if "model2017-1_face12_nomouth.h5" in path_model:
81 | # Model without ears and throat
82 | self.N_VERTICES = 28588
83 | self.N_FACES = 56572
84 | else:
85 | # Model with ears and throat
86 | self.N_VERTICES = 53149
87 | self.N_FACES = 105694
88 |
89 | assert 0 == self.faces.min()
90 | assert self.faces.max() == self.N_VERTICES - 1
91 |
92 | if path_uv is not None:
93 | self.vertices_uv, self.faces_uv = self.load_uv(path_uv)
94 |
95 | def load_uv(self, path_uv):
96 | with open(path_uv, 'r') as f:
97 | uv_para = json.load(f)
98 | verts_uvs = np.array(uv_para['textureMapping']['pointData'])
99 | faces_uvs = np.array(uv_para['textureMapping']['triangles'])
100 | return verts_uvs, faces_uvs
101 |
102 |     def __del__(self):
103 |         self.file.close()
104 |
105 | def generate_vertices(self, shape_para, exp_para):
106 | if shape_para.shape != (self.N_COEFF_SHAPE, 1):
107 | shape_para = shape_para.reshape(self.N_COEFF_SHAPE, 1)
108 | if exp_para.shape != (self.N_COEFF_EXPRESSION, 1):
109 | exp_para = exp_para.reshape(self.N_COEFF_EXPRESSION, 1)
110 | vertices = self.shape_mu.reshape(-1, 1) + self.shape_pcab @ shape_para + self.expression_pcab @ exp_para
111 | vertices = np.reshape(vertices, [int(3), int(self.N_VERTICES)], 'F').T
112 | return vertices.astype(np.float32)
113 |
114 | def sample(self, n, std_multiplier, variance):
115 | std = np.sqrt(variance) * std_multiplier
116 | std = np.broadcast_to(std, (n, std.shape[0]))
117 | mu = np.zeros_like(std)
118 | samples = np.random.normal(mu, std)
119 | return samples
120 |
121 | def sample_shape(self, n, std_multiplier=1):
122 | return self.sample(n, std_multiplier, self.shape_pca_var)
123 |
124 | def sample_expression(self, n, std_multiplier=1):
125 | return self.sample(n, std_multiplier, self.expression_pca_var)
126 |
127 |
128 | class BFM2017Tensor:
129 | N_VERTICES = 28588
130 | N_FACES = 56572
131 | N_COEFF_SHAPE = 199
132 | N_COEFF_EXPRESSION = 100
133 |
134 | def __init__(self, path_model, path_uv=None, device='cuda', verbose=False):
135 | print("Loading BFM 2017 into GPU... (this can take a while)")
136 | self.device = device
137 | self.file = h5py.File(path_model, 'r')
138 | # This can take a few seconds
139 | self.shape_mu = torch.Tensor(self.file["shape"]["model"]["mean"]).reshape(self.N_VERTICES*3, 1).to(device).float()
140 | self.shape_pcab = torch.Tensor(self.file["shape"]["model"]["pcaBasis"]).reshape(self.N_VERTICES*3, self.N_COEFF_SHAPE).to(device).float()
141 | self.shape_pca_var = torch.Tensor(self.file["shape"]["model"]["pcaVariance"]).reshape(self.N_COEFF_SHAPE).to(device).float()
142 | self.expression_pcab = torch.Tensor(self.file["expression"]["model"]["pcaBasis"]).reshape(self.N_VERTICES*3, self.N_COEFF_EXPRESSION).to(device).float()
143 | self.expression_pca_var = torch.Tensor(self.file["expression"]["model"]["pcaVariance"]).reshape(self.N_COEFF_EXPRESSION).to(device).float()
144 | self.faces = torch.Tensor(np.transpose(self.file["shape"]["representer"]["cells"])).reshape(self.N_FACES, 3).to(device)
145 |
146 | if path_uv is not None:
147 | self.vertices_uv, self.faces_uv = self.load_uv(path_uv)
148 | print("Done")
149 |
150 | def load_uv(self, path_uv):
151 | with open(path_uv, 'r') as f:
152 | uv_para = json.load(f)
153 | verts_uvs = torch.Tensor(uv_para['textureMapping']['pointData']).reshape(-1, 3).to(self.device)
154 | faces_uvs = torch.Tensor(np.array(uv_para['textureMapping']['triangles'])).reshape(-1, 3).to(self.device)
155 | return verts_uvs, faces_uvs
156 |
157 |     def __del__(self):
158 |         self.file.close()
159 |
160 | def generate_vertices(self, shape_para, exp_para):
161 | if shape_para.shape != (self.N_COEFF_SHAPE, 1):
162 | shape_para = shape_para.reshape(self.N_COEFF_SHAPE, 1)
163 | if exp_para.shape != (self.N_COEFF_EXPRESSION, 1):
164 | exp_para = exp_para.reshape(self.N_COEFF_EXPRESSION, 1)
165 | vertices = self.shape_mu +\
166 | self.shape_pcab @ shape_para + \
167 | self.expression_pcab @ exp_para
168 | return vertices.reshape(-1, 3).float()
169 |
170 | def sample(self, n, std_multiplier, variance):
171 | std = torch.sqrt(variance) * std_multiplier
172 | std = std.expand((n, std.shape[0]))
173 |         q = torch.distributions.Normal(torch.zeros_like(std), std)  # std already includes the multiplier
174 | samples = q.rsample()
175 | return samples
176 |
177 | def sample_shape(self, n, std_multiplier=1):
178 | return self.sample(n, std_multiplier, self.shape_pca_var)
179 |
180 | def sample_expression(self, n, std_multiplier=1):
181 | return self.sample(n, std_multiplier, self.expression_pca_var)
182 |
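
A minimal usage sketch for the classes above; the .h5 and .json paths are placeholders for a locally obtained copy of the Basel Face Model 2017:

import numpy as np
from mutil.bfm2017 import BFM2017

bfm = BFM2017("model2017-1_face12_nomouth.h5", path_uv="face12.json")  # placeholder paths

# Draw one coefficient vector each from the shape and expression PCA priors
# and decode them into an (N_VERTICES, 3) vertex array.
sp = bfm.sample_shape(n=1)[0]         # (199,)
ep = bfm.sample_expression(n=1)[0]    # (100,)
vertices = bfm.generate_vertices(sp, ep)
print(vertices.shape)                 # (28588, 3) for the face12 model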
--------------------------------------------------------------------------------
/mutil/data_types.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import PIL.Image
4 |
5 |
6 | def to_np(t):
7 | if isinstance(t, torch.Tensor):
8 | t = t.detach().cpu().numpy()
9 | if isinstance(t, PIL.Image.Image):
10 | t = np.array(t)
11 | return t
--------------------------------------------------------------------------------
/mutil/files.py:
--------------------------------------------------------------------------------
1 | import errno
2 | import os
3 | import shutil
4 | from mutil.str_format import get_time_string
5 |
6 |
7 | def listdir(path, prefix='', postfix='', return_prefix=True,
8 | return_postfix=True, return_abs=False):
9 | """
10 | Lists all files in path that start with prefix and end with postfix.
11 | By default, this function returns all filenames. If you do not want to
12 | return the pre- or postfix, set the corresponding parameters to False.
13 | :param path:
14 | :param prefix:
15 | :param postfix:
16 | :param return_prefix:
17 | :param return_postfix:
18 | :return: list(str)
19 | """
20 | files = os.listdir(path)
21 |     filtered_files = list(filter(
22 |         lambda f: f.startswith(prefix) and f.endswith(postfix), files))
23 |     return_files = filtered_files
24 |     if not return_prefix:
25 |         idx_start = len(prefix)
26 |         return_files = [f[idx_start:] for f in return_files]
27 |     if not return_postfix and postfix:
28 |         idx_end = len(postfix)
29 |         return_files = [f[:-idx_end] for f in return_files]
30 | return_files = set(return_files)
31 | result = list(return_files)
32 | if return_abs:
33 | result = [os.path.join(path, r) for r in result]
34 | return result
35 |
36 |
37 | def mkdir(path):
38 | if not os.path.exists(path):
39 | os.makedirs(path, exist_ok=True)
40 |
41 |
42 | def copy(src, dest, overwrite=False):
43 | if os.path.exists(dest) and overwrite:
44 | if os.path.isdir(dest):
45 | shutil.rmtree(dest)
46 | else:
47 | os.remove(dest)
48 | try:
49 | shutil.copytree(src, dest)
50 | except OSError as e:
51 | # If the error was caused because the source wasn't a directory
52 | if e.errno == errno.ENOTDIR:
53 | shutil.copy(src, dest)
54 | else:
55 | print('Directory not copied. Error: %s' % e)
56 |
57 |
58 | def copy_src(path_from, path_to):
59 | """
60 | Make sure to have everything in path_from folder
61 | There should not be large files (e.g. checkpoints or images)
62 | Args:
63 | path_from:
64 | path_to:
65 |
66 | Returns:
67 |
68 | """
69 | assert os.path.isdir(path_from)
70 | # Collect all files and folders that contain python files
71 | tmp_folder = os.path.join(path_to, 'src/')
72 | mkdir(tmp_folder)
73 |
74 | from_folder = os.path.basename(path_from)
75 | copy(path_from, os.path.join(tmp_folder, from_folder), overwrite=True)
76 | time_str = get_time_string()
77 |
78 | path_archive = os.path.join(path_to, "src_{}".format(time_str))
79 | shutil.make_archive(path_archive, 'zip', tmp_folder)
80 | try:
81 | shutil.rmtree(tmp_folder)
82 | except FileNotFoundError:
83 |         # We observed a FileNotFoundError on the cluster, possibly a race condition.
84 | pass
85 | print("Copied folder {} to {}".format(path_from, path_archive))
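
For clarity, a hypothetical example of the stripping behavior of listdir:

from mutil.files import listdir

# Directory "data" containing img_001.png, img_002.png, notes.txt (hypothetical):
listdir("data", prefix="img_", postfix=".png")
# -> ['img_001.png', 'img_002.png'] (a list; order not guaranteed)
listdir("data", prefix="img_", postfix=".png", return_prefix=False, return_postfix=False)
# -> ['001', '002']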
--------------------------------------------------------------------------------
/mutil/np_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mutil.data_types import to_np
3 |
4 |
5 | def adjust_width(img, target_width, is_uv=False):
6 | assert len(img.shape) == 3, "Are you using a batch dimension?"
7 | if is_uv:
8 | assert img.shape[-1] == 2
9 | n_channels = 2
10 | fill_value = -1.0
11 | else:
12 | assert img.shape[-1] == 3
13 | n_channels = 3
14 | fill_value = 255
15 |     diff = target_width - img.shape[1]  # pad along the width axis (shape[0] is the height)
16 |     pad = np.ones((img.shape[0], diff // 2, n_channels)) * fill_value
17 |     img = np.concatenate([pad, img, pad], 1)
18 |     if is_uv:
19 |         img = img.astype(np.float64)  # np.float was removed in NumPy >= 1.24
20 |     else:
21 |         img = img.astype(np.uint8)
23 | return img
24 |
25 |
26 | def uv_to_color(rendered_uv, output_format="np"):
27 | uv_color = np.concatenate([rendered_uv, np.zeros((*rendered_uv.shape[:2], 1))], -1)
28 | if output_format == "pil":
29 | from PIL import Image
30 | uv_color = (uv_color + 1 ) * 127.5
31 | uv_color = Image.fromarray(uv_color.astype(np.uint8))
32 | return uv_color
33 |
34 |
35 | def center_crop(image, new_height, new_width):
36 | # width, height = image.size # Get dimensions
37 | height, width = image.shape[:2]
38 | center_h = height // 2
39 | center_w = width // 2
40 |
41 | left = center_w - new_width // 2
42 | top = center_h - new_height // 2
43 | right = left + new_width
44 | bottom = top + new_height
45 |
46 | # Crop the center of the image
47 | image = image[top:bottom, left:right]
48 | return image
49 |
50 |
51 | def interpolation(n, latent_from, latent_to, gaussian_correction=True):
52 | """
53 | Linear interpolate for a Gaussian RV
54 | :param n:
55 | :param latent_from: tensor or numpy array with batch dimension
56 | :param latent_to:
57 |     :param gaussian_correction: if True, rescale so the result keeps unit variance
58 | :return:
59 | """
60 | latent_from, latent_to = to_np(latent_from), to_np(latent_to)
61 | steps = np.linspace(0, 1, n).astype(np.float32)
62 |
63 | for _ in range(len(latent_to.shape)):
64 | steps = np.expand_dims(steps, -1)
65 |     all_latents = (1 - steps) * latent_from + steps * latent_to
66 |     if gaussian_correction:  # rescale so the squared weights sum to 1 (variance-preserving)
67 |         all_latents = all_latents / np.sqrt(steps ** 2 + (1 - steps) ** 2)
68 | return all_latents
69 |
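
Sanity check for the correction: at the midpoint t = 0.5 the raw weights (0.5, 0.5) would give the blend a variance of only 0.5 for unit-variance latents; dividing by sqrt(0.5^2 + 0.5^2) = 1/sqrt(2) rescales both effective weights to 1/sqrt(2), whose squares sum to 1, so the interpolated latent stays unit-variance.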
--------------------------------------------------------------------------------
/mutil/object_dict.py:
--------------------------------------------------------------------------------
1 | class ObjectDict(dict):
2 | def __init__(self, d):
3 |         super().__init__(d)  # keep item access in sync with attribute access
4 | for k, v in d.items():
5 | setattr(self, k, v)
--------------------------------------------------------------------------------
/mutil/pytorch_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from mutil.threed_utils import eulerAnglesToRotationMatrix
5 | from torchvision.transforms import transforms
6 |
7 |
8 | def to_tensor(a, device='cpu', dtype=None):
9 | if isinstance(a, torch.Tensor):
10 | return a.to(device)
11 | return torch.tensor(np.array(a, dtype=dtype)).to(device)
12 |
13 |
14 | def tensor2np(t):
15 | if isinstance(t, torch.Tensor):
16 | return t.detach().clone().cpu().numpy()
17 | return t
18 |
19 |
20 | def get_device():
21 | # Set the device
22 | if torch.cuda.is_available():
23 | device = torch.device("cuda:0")
24 | else:
25 | device = torch.device("cpu")
26 | print("WARNING: CPU only, this will be slow!")
27 | return device
28 |
29 |
30 | # Normalize images with the imagenet means and stds
31 | class ImageNetNormalizeTransform(transforms.Normalize):
32 | """
33 | Test case:
34 | img_tensor = transforms.Compose([transforms.ToTensor()])(img)
35 | describe_tensor(img_tensor, title="img tensor to tensor")
36 |
37 | img_tensor = transforms.Compose([
38 | transforms.ToTensor(), # Converts image to range [0, 1]
39 | ImageNetNormalizeTransformForward() # Uses ImageNet means and std
40 | ])(img)
41 | describe_tensor(img_tensor, title="img tensor")
42 |
43 | img_tensor = transforms.Compose([ImageNetNormalizeTransformInverse()])(img_tensor)
44 | describe_tensor(img_tensor, title="img")
45 | """
46 | NORMALIZE_MEAN = (0.485, 0.456, 0.406)
47 | NORMALIZE_STD = (0.229, 0.224, 0.225)
48 |
49 |
50 | class ImageNetNormalizeTransformForward(ImageNetNormalizeTransform):
51 | """
52 | Expects inputs in range [0, 1] (e.g. from transforms.ToTensor)
53 | """
54 | def __init__(self):
55 | super().__init__(mean=self.NORMALIZE_MEAN, std=self.NORMALIZE_STD)
56 |
57 |
58 | class ImageNetNormalizeTransformInverse(ImageNetNormalizeTransform):
59 | """
60 | Expects inputs as after using ImageNetNormalizeTransformForward.
61 | If this is the case, the inputs are returned in the range [0, 1] * scale
62 |
63 | Can handle tensors with and without batch dimension
64 | """
65 | def __init__(self, scale=1.0):
66 | mean = torch.as_tensor(self.NORMALIZE_MEAN)
67 | std = torch.as_tensor(self.NORMALIZE_STD)
68 | std_inv = 1 / (std + 1e-7)
69 | mean_inv = -mean * std_inv
70 | super().__init__(mean=mean_inv, std=std_inv)
71 |
72 | self.scale = scale # Can also be 255
73 |
74 | def __call__(self, tensor):
75 | if len(tensor.shape) == 3: # Image
76 | return super().__call__(tensor.clone()) * self.scale
77 | elif len(tensor.shape) == 4: # With batch dimensions
78 | results = [self(t) for t in tensor]
79 | return torch.stack(results)
80 |
81 |
82 | def theta2rotation_matrix(theta_x=0, theta_y=0, theta_z=0, theta_all=None):
83 | # Angles should be in degrees
84 | if theta_all is not None:
85 | theta = [np.deg2rad(t) for t in theta_all]
86 | else:
87 | theta = [np.deg2rad(theta_x), np.deg2rad(theta_y),
88 | np.deg2rad(-theta_z)] # x looking from bottom, y looking to right, z tilting to left
89 | R = eulerAnglesToRotationMatrix(theta)
90 | R = torch.Tensor(R).float()
91 | return R
92 |
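
A round-trip sanity check for the two transforms above (a minimal sketch, not from the repo):

import torch
from mutil.pytorch_utils import (ImageNetNormalizeTransformForward,
                                 ImageNetNormalizeTransformInverse)

img = torch.rand(3, 256, 256)  # any image in [0, 1]
normalized = ImageNetNormalizeTransformForward()(img)
restored = ImageNetNormalizeTransformInverse()(normalized)
assert torch.allclose(img, restored, atol=1e-4)  # inverse undoes forward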
--------------------------------------------------------------------------------
/mutil/renderer.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Union, List, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | import torchvision
7 | from pytorch3d.ops import interpolate_face_attributes
8 | from pytorch3d.renderer import (
9 | look_at_view_transform,
10 | FoVPerspectiveCameras,
11 | PointLights,
12 | RasterizationSettings,
13 | MeshRenderer,
14 | MeshRasterizer,
15 | SoftPhongShader,
16 | TexturesUV
17 | )
18 |
19 | from mutil.pytorch_utils import to_tensor
20 |
21 |
22 | class Renderer(torch.nn.Module):
23 | def __init__(self, dist=2, elev=0, azimuth=180, fov=40, image_size=256, R=None, T=None, cameras=None, return_format="torch", device='cuda'):
24 | super().__init__()
25 | # If you provide R and T, you don't need dist, elev, azimuth, fov
26 | self.device = device
27 | self.return_format = return_format
28 |
29 | # Data structures and functions for rendering
30 | if cameras is None:
31 | if R is None and T is None:
32 | R, T = look_at_view_transform(dist, elev, azimuth)
33 | cameras = FoVPerspectiveCameras(R=R, T=T, znear=1, zfar=10000, fov=fov, degrees=True, device=device)
34 | # cameras = PerspectiveCameras(R=R, T=T, focal_length=1.6319*10, device=device)
35 |
36 | self.raster_settings = RasterizationSettings(
37 | image_size=image_size,
38 | blur_radius=0.0, # no blur
39 | bin_size=0,
40 | )
41 | # Place lights at the same point as the camera
42 | location = T
43 | if location is None:
44 | location = ((0,0,0),)
45 | lights = PointLights(ambient_color=((0.3, 0.3, 0.3),), diffuse_color=((0.7, 0.7, 0.7),), device=device,
46 | location=location)
47 |
48 | self.mesh_rasterizer = MeshRasterizer(
49 | cameras=cameras,
50 | raster_settings=self.raster_settings
51 | )
52 | self._renderer = MeshRenderer(
53 | rasterizer=self.mesh_rasterizer,
54 | shader=SoftPhongShader(
55 | device=device,
56 | cameras=cameras,
57 | lights=lights
58 | )
59 | )
60 | self.cameras = self.mesh_rasterizer.cameras
61 |
62 | def _flatten(self, a):
63 | return torch.from_numpy(np.array(a)).reshape(-1, 1).to(self.device)
64 |
65 | def format_output(self, image_tensor, return_format=None):
66 | if return_format is None:
67 | return_format = self.return_format
68 |
69 | if return_format == "torch":
70 | return image_tensor
71 | elif return_format == "pil":
72 | if len(image_tensor.shape) == 4:
73 | vis = [self.format_output(t, return_format) for t in image_tensor]
74 | return vis
75 | else:
76 | to_pil = torchvision.transforms.ToPILImage()
77 | vis = image_tensor.detach().cpu()
78 | return to_pil(vis)
79 | elif return_format == "np_raw":
80 | return image_tensor.detach().cpu().permute(0,2,3,1).numpy()
81 | elif return_format == "np":
82 | pil_image = self.format_output(image_tensor, return_format='pil')
83 | if isinstance(pil_image, list):
84 | pil_image = [np.array(img) for img in pil_image]
85 | return np.array(pil_image)
86 |
87 |
88 | class ImagelessTexturesUV(TexturesUV):
89 | def __init__(self,
90 | faces_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
91 | verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
92 | padding_mode: str = "border",
93 | align_corners: bool = True,
94 | ):
95 | self.device = faces_uvs[0].device
96 | batch_size = faces_uvs.shape[0]
97 | maps = torch.zeros(batch_size, 2, 2, 3).to(self.device) # This is simply to instantiate a texture, but it is not used.
98 | super().__init__(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs, padding_mode=padding_mode, align_corners=align_corners)
99 |
100 | def sample_pixel_uvs(self, fragments, **kwargs) -> torch.Tensor:
101 | """
102 | Copied from super().sample_textures and adapted to output pixel_uvs instead of the sampled texture.
103 |
104 | Args:
105 | fragments:
106 | The outputs of rasterization. From this we use
107 |
108 | - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
109 | of the faces (in the packed representation) which
110 | overlap each pixel in the image.
111 | - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
112 |                 the barycentric coordinates of each pixel
113 | relative to the faces (in the packed
114 | representation) which overlap the pixel.
115 |
116 | Returns:
117 |             pixel_uvs: tensor of shape (N*K, H, W, 2) giving per-pixel UV
118 |                 coordinates in [-1, 1], ready for grid sampling.
119 | """
120 | if self.isempty():
121 | faces_verts_uvs = torch.zeros(
122 | (self._N, 3, 2), dtype=torch.float32, device=self.device
123 | )
124 | else:
125 | packing_list = [
126 | i[j] for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list())
127 | ]
128 | faces_verts_uvs = torch.cat(packing_list)
129 |         # Each face has 3 vertices with u,v coordinates: faces_verts_uvs is (F, 3, 2)
130 | # pixel_uvs: (N, H, W, K, 2)
131 | pixel_uvs = interpolate_face_attributes(
132 | fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs
133 | )
134 |
135 | N, H_out, W_out, K = fragments.pix_to_face.shape
136 | # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
137 | pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2)
138 | pixel_uvs = pixel_uvs * 2.0 - 1.0
139 | return pixel_uvs
140 |
141 |
142 | class UVRenderer(Renderer):
143 | def __init__(self, verts_uv, faces_uv, dist=2, elev=0, azimuth=180, fov=40, image_size=256, R=None, T=None, cameras=None):
144 | # obj path should be the path to the obj file with the UV parametrization
145 | super().__init__(dist=dist, elev=elev, azimuth=azimuth, fov=fov, image_size=image_size, R=R, T=T, cameras=cameras)
146 | self.verts_uvs = verts_uv
147 | self.faces_uvs = faces_uv
148 |
149 | def render(self, meshes):
150 | batch_size = len(meshes)
151 | texture = ImagelessTexturesUV(verts_uvs=self.verts_uvs.expand(batch_size, -1, -1), faces_uvs=self.faces_uvs.expand(batch_size, -1, -1))
152 | # Currently only supports one mesh in meshes
153 | fragments = self.mesh_rasterizer(meshes)
154 | rendered_uv = texture.sample_pixel_uvs(fragments)
155 | return self.format_output(rendered_uv)
156 |
157 |
158 | class BFMUVRenderer(Renderer):
159 | def __init__(self, json_path, *args, **kwargs):
160 | # json_path = ".../face12.json"
161 | super().__init__(*args, **kwargs)
162 |
163 | with open(json_path, 'r') as f:
164 | uv_para = json.load(f)
165 | verts_uvs = np.array(uv_para['textureMapping']['pointData'])
166 | faces_uvs = np.array(uv_para['textureMapping']['triangles'])
167 |
168 | verts_uvs = to_tensor(verts_uvs).unsqueeze(0).float()
169 | faces_uvs = to_tensor(faces_uvs).unsqueeze(0).long()
170 | self.texture = ImagelessTexturesUV(verts_uvs=verts_uvs, faces_uvs=faces_uvs).to(self.device)
171 |
172 | def render(self, meshes):
173 | # Currently only supports one mesh in meshes
174 | fragments = self.mesh_rasterizer(meshes)
175 | rendered_uv = self.texture.sample_pixel_uvs(fragments)
176 | rendered_uv = rendered_uv.permute(0, 3, 1, 2) # to CHW
177 | return self.format_output(rendered_uv)
178 |
179 |
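
A usage sketch combining the pieces above, assuming local copies of the BFM 2017 model and the face12.json UV parametrization (paths are placeholders) and a CUDA GPU; the default camera parameters may need tuning for the BFM's millimeter scale:

import torch
from pytorch3d.structures import Meshes
from mutil.bfm2017 import BFM2017
from mutil.renderer import BFMUVRenderer

bfm = BFM2017("model2017-1_face12_nomouth.h5")  # placeholder path
verts = torch.from_numpy(bfm.generate_vertices(
    bfm.sample_shape(1)[0], bfm.sample_expression(1)[0]))
faces = torch.from_numpy(bfm.faces.copy()).long()
meshes = Meshes(verts=[verts], faces=[faces]).to('cuda')

renderer = BFMUVRenderer("face12.json", image_size=256)  # placeholder path
uv = renderer.render(meshes)  # (1, 2, 256, 256), UV coordinates in [-1, 1]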
--------------------------------------------------------------------------------
/mutil/str_format.py:
--------------------------------------------------------------------------------
1 |
2 | import datetime
3 |
4 |
5 | def get_time_string(format="%Y%m%d_%H%M"):
6 | time_str = datetime.datetime.now().strftime(format)
7 | return time_str
--------------------------------------------------------------------------------
/mutil/threed_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import matplotlib as mpl
3 | import matplotlib.pyplot as plt
4 |
5 | from pytorch3d.ops import corresponding_points_alignment  # used by align_ls
6 | import numpy as np
7 | import torch
8 |
9 |
10 | def setup_plots():
11 | mpl.rcParams['savefig.dpi'] = 80
12 | mpl.rcParams['figure.dpi'] = 80
13 |
20 |
21 |
22 |
23 | def align_ls(mp_verts, bfm_verts, idx_mp, idx_bfm):
24 | tmp_mp = mp_verts.clone()[idx_mp].reshape(1, -1, 3).float()
25 | tmp_bfm = bfm_verts.clone()[idx_bfm].reshape(1, -1, 3).float()
26 | R, t, s = corresponding_points_alignment(tmp_mp, tmp_bfm, estimate_scale=True)
27 | mp_verts_aligned = s * mp_verts.mm(R[0]) + t
28 | return mp_verts_aligned, R, t, s
29 |
30 |
31 | def unwrap(vertices, faces):
32 | import trimesh
33 | mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False) # Set process to False, o.w. will change vertex order
34 | mesh = mesh.unwrap() # This will change the vertex order
35 | uv = mesh.visual.uv
36 | return uv
37 |
38 |
39 | def plot_mesh(verts, faces, show=False):
40 | import plotly.graph_objects as go
41 | from mutil.pytorch_utils import tensor2np
42 |
43 | verts = tensor2np((verts))
44 | # faces = tensor2np(faces)
45 | x, y, z = verts[:,0], verts[:,1],\
46 | verts[:,2]
47 | # points3d(x,y,z)
48 | # i, j, k = faces[:,0], faces[:,1], faces[:,2]
49 | fig = go.Figure(data=[go.Mesh3d(x=x, y=y, z=z, opacity=0.9)])
50 |     if show:
51 |         fig.show()  # this is a plotly figure; plt.show() would not display it
52 | return fig
53 |
54 |
55 | def apply_Rts(vertices, R, t, s, correct_translation=False):
56 | assert len(vertices.shape) == 3, "Needs batch size"
57 | bs, n_vertices = vertices.shape[:2]
58 | s = s.view(bs, 1, 1).expand(bs, n_vertices, 3)
59 | t = t.view(bs, 1, 3).expand(bs, n_vertices, 3)
60 | # print('verts mean', vertices.mean(1))
61 |     if correct_translation:
62 | # This is the center point that we rotate around
63 | center = vertices.mean(-2).unsqueeze(1).expand(vertices.shape)
64 | rotated_vertices = (vertices - center).matmul(R) + center
65 | else:
66 | rotated_vertices = vertices.matmul(R)
67 | # print('rotated', rotated_vertices.mean(1))
68 | transformed_vertices = s * rotated_vertices + t
69 | # print('transformed', transformed_vertices.mean(1))
70 | return transformed_vertices
71 |
72 |
73 | # Checks if a matrix is a valid rotation matrix.
74 | def isRotationMatrix(R, eps=1e-6):
75 | # https://www.learnopencv.com/rotation-matrix-to-euler-angles/
76 | Rt = np.transpose(R)
77 | shouldBeIdentity = np.dot(Rt, R)
78 | I = np.identity(3, dtype=R.dtype)
79 | n = np.linalg.norm(I - shouldBeIdentity)
80 | if n >= eps:
81 |         raise ValueError("Not a valid rotation matrix; ||I - R^T R|| = {}".format(n))
82 | return n < eps
83 |
84 |
85 | # Calculates rotation matrix to euler angles
86 | # The result is the same as MATLAB except the order
87 | # of the euler angles ( x and z are swapped ).
88 | def rotationMatrixToEulerAngles(R, eps=1e-6, degrees=False):
89 | # https://www.learnopencv.com/rotation-matrix-to-euler-angles/
90 |
91 | assert isRotationMatrix(R, eps)
92 |
93 | sy = math.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0])
94 |
95 | singular = sy < 1e-6
96 |
97 | if not singular:
98 | x = math.atan2(R[2, 1], R[2, 2])
99 | y = math.atan2(-R[2, 0], sy)
100 | z = math.atan2(R[1, 0], R[0, 0])
101 | else:
102 | x = math.atan2(-R[1, 2], R[1, 1])
103 | y = math.atan2(-R[2, 0], sy)
104 | z = 0
105 | result = np.array([x, y, z])
106 | if degrees:
107 | result = np.rad2deg(result)
108 | return result
109 |
110 |
111 | # Calculates Rotation Matrix given euler angles.
112 | def eulerAnglesToRotationMatrix(theta):
113 | # Angles in radians plz
114 | # theta order: x, y z
115 | # pitch, yaw, roll
116 | R_x = np.array([[1, 0, 0],
117 | [0, math.cos(theta[0]), -math.sin(theta[0])],
118 | [0, math.sin(theta[0]), math.cos(theta[0])]
119 | ])
120 |
121 | R_y = np.array([[math.cos(theta[1]), 0, math.sin(theta[1])],
122 | [0, 1, 0],
123 | [-math.sin(theta[1]), 0, math.cos(theta[1])]
124 | ])
125 |
126 | R_z = np.array([[math.cos(theta[2]), -math.sin(theta[2]), 0],
127 | [math.sin(theta[2]), math.cos(theta[2]), 0],
128 | [0, 0, 1]
129 | ])
130 | R = np.dot(R_z, np.dot(R_y, R_x))
131 |
132 | return R
133 |
134 |
135 | def view_matrix_fps(pitch, yaw, eye, project_eye=False):
136 | """
137 | Following https://www.3dgep.com/understanding-the-view-matrix/#The_View_Matrix
138 | Right-handed coordinate system
139 | if project_eye: the eye will be dotted with the axis, otherwise not.
140 |
141 | """
142 | sp = np.sin(pitch)
143 | sy = np.sin(yaw)
144 | cp = np.cos(pitch)
145 | cy = np.cos(yaw)
146 | x_axis = [cy, 0, -sy]
147 | y_axis = [sy * sp, cp, cy * sp]
148 | z_axis = [sy * cp, -sp, cp * cy]
149 | if project_eye:
150 | translation = [-np.dot(x_axis, eye), -np.dot(y_axis, eye), -np.dot(z_axis, eye)]
151 | else:
152 | translation = eye
153 | V = np.array([x_axis, y_axis, z_axis, translation]) # 4x3
154 | V = V.T # 3x4
155 | V = np.vstack((V, np.array([0, 0, 0, 1.0]).reshape(1, 4)))
156 | return V
157 |
158 |
159 | def fov2focal(fov, image_size, degree=True):
160 | """
161 | Computes focal length from similarity of triangles
162 | Args:
163 | fov: angle of view in degrees
164 | image_size: in pixels or any metric unit
165 | degree: True if fov is in degrees, else false
166 |
167 | Returns: focal length in the same unit as image_size
168 |
169 | """
170 | A = image_size / 2.0 # Half the image size
171 | a = fov / 2.0 # Half the fov angle
172 | if degree:
173 | a = np.deg2rad(a)
174 | f = A / np.tan(a) # numpy expects angles in radians
175 | return f
176 |
177 |
178 | def create_cam2world_matrix(forward_vector, origin):
179 | # Adapted from pigan repo
180 | """Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix."""
181 | def normalize_vec(v):
182 | return v / np.linalg.norm(v)
183 | forward_vector = normalize_vec(forward_vector)
184 | up_vector = np.array([0, 1.0, 0])
185 |
186 | left_vector = normalize_vec(np.cross(up_vector, forward_vector, axis=-1))
187 |
188 | up_vector = normalize_vec(np.cross(forward_vector, left_vector, axis=-1))
189 |
190 | rotation_matrix = np.eye(4)
191 | rotation_matrix[:3, :3] = np.stack((-left_vector, up_vector, -forward_vector), axis=-1)
192 |
193 | translation_matrix = np.eye(4)
194 | translation_matrix[:3, 3] = origin
195 |
196 | cam2world = translation_matrix @ rotation_matrix
197 |
198 | return cam2world
199 |
200 |
201 | def project_mesh(vertices, img, cam, color=[255, 0, 0]):
202 | """
203 |
204 | Args:
205 | vertices: Nx3 np.array
206 |         img: PIL.Image (the width/height attributes are used for clipping)
207 | cam: K, P: K is the 3x3 camera matrix (maps to camera space; focal, princiapl point, skew), P is the 4x4 matrix (R|T) and maps to the vertex space (will be inverted)
208 | color:
209 |
210 | Returns: PIL.Image with projected vertices
211 |
212 | """
213 | from PIL import Image
214 | # Adapted from https://github.com/CrisalixSA/h3ds
215 |
216 | # Expand cam attributes
217 | K, P = cam # K is the 3x3 camera matrix (focal, principal point, skew), P is the 4x4 projection matrix (R|T)
218 | # P projects from camera to world space, so we need the inverse
219 | P_inv = np.linalg.inv(P)
220 |
221 | # Project mesh vertices into 2D
222 | p3d_h = np.hstack((vertices, np.ones((vertices.shape[0], 1)))) # Homogeneous
223 | p2d_h = (K @ P_inv[:3, :] @ p3d_h.T).T # Apply the camera and the world^-1 projection. Transpose result
224 | p2d = p2d_h[:, :-1] / p2d_h[:, -1:] # Divide by Z to project to image plane
225 | # The p2d now contains the indices of the values as pixels
226 |
227 | # Draw p2d to image
228 | img_proj = np.array(img)
229 |     p2d = np.clip(p2d, 0, [img.width - 1, img.height - 1]).astype(np.uint32)  # clip x to width, y to height
230 | img_proj[p2d[:, 1], p2d[:, 0]] = color
231 |
232 | return Image.fromarray(img_proj.astype(np.uint8))
233 |
234 |
235 | def points2homo(points, times=1):
236 | if len(points.shape) == 2:
237 | points_h = np.hstack((points, np.ones((points.shape[0], times)))) # N,4
238 | elif len(points.shape) == 3:
239 | # Has a batch dimension
240 | points_h = np.dstack((points, np.ones((points.shape[0], points.shape[1], times))))
241 | else:
242 | raise Warning("Invalid shape")
243 | return points_h
244 |
245 |
246 | def scaling_transform(s):
247 | scale_transform = np.zeros((4,4))
248 | scale_transform[0,0] = s
249 | scale_transform[1,1] = s
250 | scale_transform[2,2] = s
251 | scale_transform[3,3] = 1
252 | return scale_transform
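
A quick consistency check for the two Euler-angle conversions above:

import numpy as np
from mutil.threed_utils import (eulerAnglesToRotationMatrix,
                                rotationMatrixToEulerAngles)

theta = np.deg2rad([10.0, -20.0, 30.0])  # pitch (x), yaw (y), roll (z) in radians
R = eulerAnglesToRotationMatrix(theta)
angles = rotationMatrixToEulerAngles(R, degrees=True)
print(angles)  # approximately [ 10. -20.  30.]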
--------------------------------------------------------------------------------
/varitex/custom_callbacks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mcbuehler/VariTex/a25979f96d600ed839799450958ace8952ce375a/varitex/custom_callbacks/__init__.py
--------------------------------------------------------------------------------
/varitex/custom_callbacks/callbacks.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 |
4 | from varitex.data.keys_enum import DataItemKey as DIK
5 | from varitex.data.uv_factory import BFMUVFactory
6 | from varitex.visualization.batch import CombinedVisualizer, SampledVisualizer, UVVisualizer, InterpolationVisualizer, \
7 | NeuralTextureVisualizer
8 |
9 |
10 | class ImageLogCallback(pl.Callback):
11 | """
12 | Logs a variety of visualizations during training.
13 | """
14 |
15 | def __init__(self, opt, use_bfm_gpu=False):
16 | self.MASK_VALUE = opt.uv_mask_value
17 | self.opt = opt
18 |
19 | uv_factory_bfm = BFMUVFactory(opt, use_bfm_gpu)
20 |
21 | self.visualizer_combined = CombinedVisualizer(opt, mask_value=self.MASK_VALUE)
22 | self.visualizer_sampled = SampledVisualizer(opt, n_samples=3, mask_value=self.MASK_VALUE)
23 | self.visualizer_uv = UVVisualizer(opt, bfm_uv_factory=uv_factory_bfm, mask_value=self.MASK_VALUE)
24 | self.visualizer_interpolation = InterpolationVisualizer(opt, mask_value=self.MASK_VALUE)
25 | self.visualizer_neural_texture = NeuralTextureVisualizer(opt)
26 |
27 | def log_image(self, logger, key, vis, step):
28 | if self.opt.logger == "tensorboard":
29 | logger.experiment.add_image(key, vis, step)
30 | elif self.opt.logger == "wandb":
31 | import wandb
32 | logger.experiment.log(
33 | {key: wandb.Image(vis)}
34 | )
35 |
36 | def _combined(self, pl_module, batch, prefix):
37 | vis_combined = self.visualizer_combined.visualize(batch)
38 | self.log_image(pl_module.logger, "{}/combined".format(prefix), vis_combined, pl_module.global_step)
39 |
40 | def _sampled(self, pl_module, batch, batch_idx, prefix):
41 | vis_sampled = self.visualizer_sampled.visualize(batch, batch_idx, pl_module, std_multiplier=3)
42 | self.log_image(pl_module.logger, "{}/outputs_sampled".format(prefix), vis_sampled, pl_module.global_step)
43 |
44 | vis_sampled = self.visualizer_sampled.visualize_unseen(batch, batch_idx, pl_module, std_multiplier=2)
45 | self.log_image(pl_module.logger, "{}/outputs_sampled_gaussian_std2".format(prefix), vis_sampled,
46 | pl_module.global_step)
47 |
48 | def _posed(self, pl_module, batch, batch_idx, prefix):
49 | deg_range = torch.arange(-45, 45 + 1, 30)
50 | vis = self.visualizer_uv.visualize_grid(pl_module, batch, batch_idx, deg_range)
51 | self.log_image(pl_module.logger, "{}/outputs_posed".format(prefix), vis, pl_module.global_step)
52 |
53 | def _interpolated(self, pl_module, batch, batch_idx, prefix, n=5):
54 | batch2 = pl_module.forward_sample_style(batch.copy(), batch_idx, std_multiplier=4) # Random new style code
55 | vis = self.visualizer_interpolation.visualize(pl_module, batch, batch2, n, bidirectional=False,
56 | include_gt=False)
57 |         self.log_image(pl_module.logger, "{}/interpolation/random_std4".format(prefix), vis, pl_module.global_step)
58 |
59 | batch2 = batch.copy()
60 | batch2[DIK.STYLE_LATENT] = torch.zeros_like(batch[DIK.STYLE_LATENT]).to(batch[DIK.STYLE_LATENT].device)
61 | vis = self.visualizer_interpolation.visualize(pl_module, batch, batch2, n, bidirectional=False,
62 | include_gt=False)
63 | self.log_image(pl_module.logger, "{}/interpolation/zeros".format(prefix), vis, pl_module.global_step)
64 |
65 | batch2 = batch.copy()
66 | batch2[DIK.STYLE_LATENT] = torch.randn_like(batch[DIK.STYLE_LATENT]).to(batch[DIK.STYLE_LATENT].device)
67 | vis = self.visualizer_interpolation.visualize(pl_module, batch, batch2, n, bidirectional=False,
68 | include_gt=False)
69 | self.log_image(pl_module.logger, "{}/interpolation/standard_gaussian".format(prefix), vis,
70 | pl_module.global_step)
71 |
72 | def _neural_texture(self, pl_module, batch, batch_idx, prefix):
73 | vis = self.visualizer_neural_texture.visualize_interior(batch, batch_idx)
74 | self.log_image(pl_module.logger, "{}/neural_texture/interior".format(prefix), vis,
75 | pl_module.global_step)
76 |
77 | vis = self.visualizer_neural_texture.visualize_interior_sampled(batch, batch_idx)
78 | self.log_image(pl_module.logger, "{}/neural_texture/interior_sampled".format(prefix), vis,
79 | pl_module.global_step)
80 |
81 | vis = self.visualizer_neural_texture.visualize_exterior_sampled(batch, batch_idx)
82 | self.log_image(pl_module.logger, "{}/neural_texture/exterior_sampled".format(prefix), vis,
83 | pl_module.global_step)
84 |
85 | vis = self.visualizer_neural_texture.visualize_enhanced(batch, batch_idx)
86 | self.log_image(pl_module.logger, "{}/neural_texture/enhanced".format(prefix), vis,
87 | pl_module.global_step)
88 |
89 | def log_batch(self, pl_module, batch, batch_idx, prefix):
90 | self._combined(pl_module, batch, prefix)
91 | self._sampled(pl_module, batch, batch_idx, prefix)
92 | self._interpolated(pl_module, batch, batch_idx, prefix)
93 | self._neural_texture(pl_module, batch, batch_idx, prefix)
94 | self._posed(pl_module, batch, batch_idx, prefix)
95 |
96 | def batch2gpu(self, batch):
97 | for k, v in batch.items():
98 | if isinstance(v, torch.Tensor):
99 | batch[k] = v.cuda()
100 | return batch
101 |
102 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
103 | if batch_idx % self.opt.display_freq == 0:
104 | batch = pl_module(batch, batch_idx)
105 | self.log_batch(pl_module, batch, batch_idx, "train")
106 |
107 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
108 | if batch_idx == 1:
109 | batch = pl_module(batch, batch_idx)
110 | self.log_batch(pl_module, batch, batch_idx, "val")
111 |
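
A hypothetical wiring of this callback into a PyTorch Lightning trainer; `opt` stands for the option namespace used throughout the repo (fields such as logger, display_freq, and uv_mask_value must be set):

import pytorch_lightning as pl
from varitex.custom_callbacks.callbacks import ImageLogCallback

# `model`, `datamodule`, and `opt` are assumed to come from the varitex training entry point.
trainer = pl.Trainer(max_epochs=50, accelerator="gpu",
                     callbacks=[ImageLogCallback(opt)])
# trainer.fit(model, datamodule)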
--------------------------------------------------------------------------------
/varitex/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mcbuehler/VariTex/a25979f96d600ed839799450958ace8952ce375a/varitex/data/__init__.py
--------------------------------------------------------------------------------
/varitex/data/augmentation.py:
--------------------------------------------------------------------------------
1 | """
2 | Random affine transforms used during training.
3 | """
4 | import PIL.Image
5 | import numpy
6 | import torchvision.transforms.functional as F
7 | from torchvision.transforms import RandomAffine
8 |
9 |
10 | class CustomRandomAffine(RandomAffine):
11 | def __init__(self, img_size, flip_p=0, *args, **kwargs):
12 | if not "degrees" in kwargs:
13 | kwargs["degrees"] = 0
14 | super().__init__(*args, **kwargs)
15 | self.ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
16 | self.flip = numpy.random.rand(1) < flip_p
17 |
18 | def to_pil(self, img):
19 | return PIL.Image.fromarray(img)
20 |
21 | def __call__(self, img):
22 | """
23 | img (PIL Image or Tensor): Image to be transformed.
24 |
25 | Returns:
26 | PIL Image or Tensor: Affine transformed image.
27 | """
28 |         is_uv = not isinstance(img, PIL.Image.Image) and img.shape[-1] == 2
29 | if not isinstance(img, PIL.Image.Image):
30 | if is_uv:
31 | # We convert the UV to a PIL Image such that we can apply the affine transform.
32 | img = numpy.concatenate([img, numpy.zeros((*img.shape[:2], 1))], -1)
33 | img = (img + 1) * 127.5
34 | img = img.astype(numpy.uint8)
35 | img = self.to_pil(img)
36 |
37 | if self.flip:
38 | img = F.hflip(img)
39 |
40 | transformed = F.affine(img, *self.ret,
41 | fill=self.fill)
42 | result = numpy.array(transformed)
43 | if is_uv:
44 | # Convert back to UV range.
45 | result = result[:, :, :2]
46 | result = result / 127.5 - 1
47 | return result
48 |
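
A minimal sketch of how this transform keeps an image and its UV map aligned; the same sampled parameters (and flip decision) apply to every call on the same instance:

import numpy as np
from varitex.data.augmentation import CustomRandomAffine

# One instance = one fixed set of random parameters (drawn in __init__ above).
transform = CustomRandomAffine(img_size=(512, 512), flip_p=0.5,
                               degrees=15, translate=(0.2, 0.2), scale=(1.0, 1.2))
img = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)  # H x W x 3 image
uv = np.random.rand(512, 512, 2).astype(np.float32) * 2 - 1     # H x W x 2 UV in [-1, 1]
img_aug = transform(img)  # identical affine applied to both
uv_aug = transform(uv)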
--------------------------------------------------------------------------------
/varitex/data/custom_dataset.py:
--------------------------------------------------------------------------------
1 | from varitex.data.dataset_specifics import FFHQ
2 |
3 |
4 | class CustomDataset:
5 | data = None
6 | indices = None
7 | N = None
8 |
9 | def __init__(self, opt, split=None, augmentation=None):
10 | self.opt = opt
11 | self.split = split if split is not None else self.opt.dataset_split
12 | self.augmentation = augmentation if augmentation is not None else self.opt.augmentation
13 |
14 | if self.opt.dataset.lower() == 'ffhq':
15 | self.initial_height, self.initial_width = FFHQ.image_height, FFHQ.image_width
16 | if self.augmentation:
17 | self.transform_params = FFHQ.get_transform_params(opt.transform_mode)
18 | else:
19 | raise NotImplementedError("Not implemented dataset '{}'".format(self.opt.dataset))
20 |
--------------------------------------------------------------------------------
/varitex/data/dataset_specifics.py:
--------------------------------------------------------------------------------
1 | from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform
2 |
3 |
4 | class FFHQ:
5 |     # We load the initial images at this size. It should match the resolution of the predicted segmentation masks.
6 |     # The original FFHQ images have resolution 1024x1024.
7 | image_width = 512
8 | image_height = 512
9 |
10 | transform_params = dict(
11 | degrees=15,
12 | translate=(0.2, 0.2),
13 | scale=(1, 1.2),
14 | flip_p=0.5,
15 | )
16 | transform_params_light = dict(
17 | degrees=5,
18 | translate=(0.1, 0.1),
19 | scale=(1, 1.2),
20 | flip_p=0.5,
21 | )
22 |
23 | @classmethod
24 | def get_transform_params(cls, mode):
25 | keys = []
26 | if mode == "all":
27 | keys = cls.transform_params.keys()
28 | else:
29 | if "d" in mode:
30 | keys.append("degrees")
31 | if "t" in mode:
32 | keys.append("translate")
33 | if "s" in mode:
34 | keys.append("scale")
35 | if "f" in mode:
36 | keys.append("flip_p")
37 | params = {k: v for k, v in cls.transform_params.items() if k in keys}
38 | return params
39 |
40 |
41 | class Camera:
42 | @staticmethod
43 | def get_camera():
44 | # Camera is at the origin, looking at the negative z axis.
45 | R, T = look_at_view_transform(eye=((0, 0, 0),), at=((0, 0, -1),), up=((0, 1, 0),))
46 | cameras = FoVPerspectiveCameras(device='cuda', R=R, T=T, fov=30, zfar=1000)
47 | return cameras
48 |
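
For illustration, the mode string of get_transform_params selects which augmentation parameters are active:

from varitex.data.dataset_specifics import FFHQ

FFHQ.get_transform_params("all")  # degrees, translate, scale, flip_p
FFHQ.get_transform_params("ts")   # {'translate': (0.2, 0.2), 'scale': (1, 1.2)}
FFHQ.get_transform_params("f")    # {'flip_p': 0.5}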
--------------------------------------------------------------------------------
/varitex/data/keys_enum.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class DataItemKey(Enum):
5 | IMAGE_IN = 1
6 | IMAGE_ENCODED = 2
7 | STYLE_LATENT = 3
8 | TEXTURE_PRIOR = 4
9 | TEXTURE_PERSON = 5
10 | FACE_FEATUREIMAGE = 6
11 | ADDITIVE_FEATUREIMAGE = 7
12 | IMAGE_OUT = 8
13 | UV_RENDERED = 9
14 | MASK_UV = 10
15 | STYLE_LATENT_MU = 12
16 | STYLE_LATENT_STD = 13
17 | FILENAME = 14
18 | COEFF_SHAPE = 15
19 | COEFF_EXPRESSION = 16
20 | SEGMENTATION_MASK = 17
21 | SEGMENTATION_PREDICTED = 18
22 | FULL_FEATUREIMAGE = 19
23 | MASK_FULL = 21
24 | IMAGE_IN_ENCODE = 26
25 | LATENT_INTERIOR = 27
26 | LATENT_EXTERIOR = 28
27 | R = 90
28 | T = 91
29 | SCALE = 92
30 |
31 | def __str__(self):
32 | return str(self._name_).lower()
33 |
--------------------------------------------------------------------------------
/varitex/data/npy_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from mutil.pytorch_utils import ImageNetNormalizeTransformForward
7 | from scipy import ndimage
8 | from skimage.morphology import convex_hull_image
9 | from torchvision import transforms
10 |
11 | from varitex.data.augmentation import CustomRandomAffine
12 | from varitex.data.keys_enum import DataItemKey as DIK
13 | from varitex.data.custom_dataset import CustomDataset
14 |
15 |
16 | class NPYDataset(CustomDataset):
17 | # Rotation, translation, scale, shape, expression, segmentation, uv
18 | keys = ["R", "t", "s", "sp", "ep", "segmentation", "uv"]
19 |
20 | def __init__(self, opt, *args, **kwargs):
21 | super().__init__(opt, *args, **kwargs)
22 |         # The preprocessed arrays are loaded memory-mapped (mmap_mode='r'),
23 |         # so each worker only reads the slices it actually indexes.
25 | self.image_folder = self.opt.image_folder
26 |
27 | self.dataroot_npy = opt.dataroot_npy
28 | self.data = {
29 | key: np.load(os.path.join(self.dataroot_npy, "{}.npy".format(key)), 'r') for key in self.keys
30 | }
31 |
32 | # These are strings
33 | self.data["filename"] = np.load(os.path.join(self.dataroot_npy, "filename.npy"), allow_pickle=True)
34 | self.data["dataset_splits"] = np.load(os.path.join(self.dataroot_npy, "dataset_splits.npz"), allow_pickle=True)
35 |
36 | self.set_indices()
37 |
38 | self.N = len(self.indices)
39 |
40 | assert self.opt.image_h == self.opt.image_w, "Only square images supported!"
41 |
42 | def set_indices(self, indices=None):
43 |         # Indices can be provided explicitly or loaded from the precomputed dataset splits
44 | if indices is None:
45 | if self.split == "all":
46 | self.indices = list(range(self.data["filename"].shape[0]))
47 | elif self.split in self.data["dataset_splits"].keys():
48 | self.indices = self.data["dataset_splits"][self.split][:]
49 | else:
50 |                 raise ValueError("Invalid split: {}".format(self.split))
51 | else:
52 | self.indices = indices
53 |
54 | @classmethod
55 | def _center_crop(cls, image, square_dim):
56 | if square_dim is None:
57 | square_dim = min(image.shape[:2])
58 | new_height, new_width = square_dim, square_dim
59 |
60 | # width, height = image.size # Get dimensions
61 | height, width = image.shape[:2]
62 | center_h = height // 2
63 | center_w = width // 2
64 |
65 | left = center_w - new_width // 2
66 | top = center_h - new_height // 2
67 | right = left + new_width
68 | bottom = top + new_height
69 |
70 | # Crop the center of the image
71 | image = image[top:bottom, left:right]
72 | return image
73 |
74 | @classmethod
75 | def _apply_transforms(cls, ndarray_in, height, width, interpolation_nearest=False, affine_transform=None,
76 | center_crop_dim=None):
77 | # Semantic masks should use 'nearest' interpolation.
78 | interpolation = cv2.INTER_NEAREST if interpolation_nearest else cv2.INTER_LINEAR
79 |
80 | if affine_transform is not None:
81 | ndarray = affine_transform(ndarray_in)
82 | else:
83 | ndarray = ndarray_in
84 |
85 | ndarray = cls._center_crop(ndarray, square_dim=center_crop_dim)
86 |
87 | if ndarray.shape[0] != height or ndarray.shape[1] != width:
88 |             ndarray = cv2.resize(ndarray, (width, height), interpolation=interpolation)
89 | return ndarray
90 |
91 | @classmethod
92 | def preprocess_image(cls, img, height, width, affine_transform=None, center_crop_dim=None):
93 | img = cls._apply_transforms(img, height, width, affine_transform=affine_transform,
94 | center_crop_dim=center_crop_dim)
95 |
96 | all_transforms = [
97 | transforms.ToTensor(), # Converts image to range [0, 1]
98 | ImageNetNormalizeTransformForward() # Uses ImageNet means and std
99 | ]
100 |
101 | img_tensor = transforms.Compose(all_transforms)(img)
102 |         return img_tensor  # values may fall below -1 or above 1, but are roughly centered at 0
103 |
104 | def preprocess_segmentation(self, segmentation, height, width, affine_transform=None, mask=None,
105 | center_crop_dim=None):
106 | segmentation_np = np.zeros(segmentation.shape)
107 | for region_idx in self.opt.semantic_regions:
108 | segmentation_np[segmentation == region_idx] = 1
109 | # We now have a binary mask with 1s for all regions that we specified
110 |
111 | segmentation = self._apply_transforms(segmentation_np, height, width, interpolation_nearest=True,
112 | affine_transform=affine_transform, center_crop_dim=center_crop_dim)
113 |
114 | segmentation_tensor = torch.from_numpy(segmentation)
115 | segmentation = segmentation_tensor.unsqueeze(0)
116 |
117 | if mask is not None:
118 | segmentation[mask] = 0
119 |
120 | return segmentation.float()
121 |
122 | @classmethod
123 | def preprocess_uv(cls, uv, height, width, affine_transform=None, center_crop_dim=None):
124 | uv = cls._apply_transforms(uv, height, width, interpolation_nearest=True, affine_transform=affine_transform,
125 | center_crop_dim=center_crop_dim)
126 | # same as ToTensor, but it's better to keep this explicit
127 | uv_tensor = torch.from_numpy(uv)
128 |
129 | if not (-1 <= uv_tensor.min() and uv_tensor.max() <= 1):
130 | raise ValueError("UV not in range [-1, 1]! min: {} max: {}".format(uv.min(), uv.max()))
131 |
132 | return uv_tensor.float()
133 |
134 | @classmethod
135 | def preprocess_mask(cls, uv, height, width, affine_transform=None):
136 | # We compute the mask from the UV (invalid values will be marked as -1)
137 | uv = cls.preprocess_uv(uv, height, width, affine_transform)
138 | mask_tensor = torch.logical_and((uv[:, :, 0] != -1), (uv[:, :, 1] != -1))
139 | return mask_tensor.unsqueeze(0)
140 |
141 | @staticmethod
142 | def convex_hull_mask_tensor(uv_mask, segmentations):
143 | mask = uv_mask
144 | for semantic_mask in segmentations:
145 | mask = mask.logical_or(semantic_mask.bool())
146 |
147 | mask_np = mask.squeeze().detach().cpu().numpy()
148 | mask_image = convex_hull_image(mask_np)
149 | mask_image_tensor = torch.from_numpy(mask_image).unsqueeze(0)
150 | return mask_image_tensor
151 |
152 | def full_mask(self, uv_mask, segmentation_mask):
153 | full_mask = uv_mask.logical_or(segmentation_mask)
154 |
155 | mask_image = ndimage.binary_fill_holes(full_mask.squeeze().numpy())
156 | mask_image_tensor = torch.from_numpy(mask_image).unsqueeze(0)
157 | return mask_image_tensor
158 |
159 | def preprocess_expressions(self, frame_id):
160 | return torch.tensor(self.data['ep'][frame_id])
161 |
162 | def _read_image(self, filename, size):
163 | path_image = os.path.join(self.image_folder, "{}.png".format(filename))
164 | if not os.path.exists(path_image):
165 | raise FileNotFoundError("This image has not been found: '{}'".format(path_image))
166 | image_bgr = cv2.imread(path_image)
167 | image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
168 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
169 | return image
170 |
171 | def __getitem__(self, index):
172 | frame_id = self.indices[index] # We keep the dataset splits loaded as indices
173 | filename = self.data["filename"][frame_id]
174 | height, width = self.opt.image_h, self.opt.image_w
175 |
176 | if self.augmentation:
177 | # Create a random affine transform for each iteration, but use the same transform for all in- and outputs
178 | affine_transform = CustomRandomAffine([self.initial_width, self.initial_height], **self.transform_params)
179 | else:
180 | affine_transform = None
181 | try:
182 | image_raw = self._read_image(filename, (self.initial_width, self.initial_height))
183 | uv = self.data["uv"][frame_id].astype(np.float32) # opencv expects float 32
184 | segmentation = self.data["segmentation"][frame_id]
185 |
186 | img_tensor_clean = self.preprocess_image(image_raw.copy(), height, width)
187 | img_tensor = self.preprocess_image(image_raw.copy(), height, width, affine_transform)
188 | uv_tensor = self.preprocess_uv(uv.copy(), height, width, affine_transform)
189 | mask_tensor = self.preprocess_mask(uv.copy(), height, width,
190 | affine_transform) # We compute the mask from the uv
191 |
192 | segmentation_tensor = self.preprocess_segmentation(segmentation.copy(), height, width, affine_transform)
193 |
194 | mask_full_tensor = self.full_mask(mask_tensor, segmentation_tensor)
195 |
196 | if not self.opt.keep_background:
197 | # We remove the background to focus the network capacity on the relevant regions
198 | if self.opt.bg_color == 'black':
199 | img_tensor[~mask_full_tensor.expand_as(img_tensor)] = img_tensor.min()
200 | else:
201 | img_tensor[~mask_full_tensor.expand_as(img_tensor)] = 0.0
202 |
203 | # We need to compute the full mask on the non augmented variant such that we can mask the encoding image
204 | mask_tensor_clean = self.preprocess_mask(uv.copy(), height, width) # We compute the mask from the uv
205 | segmentation_tensor_clean = self.preprocess_segmentation(segmentation.copy(),
206 | height, width)
207 | mask_full_tensor_clean = self.full_mask(mask_tensor_clean, segmentation_tensor_clean)
208 | if self.opt.bg_color == 'black':
209 | img_tensor_clean[~mask_full_tensor_clean.expand_as(img_tensor_clean)] = img_tensor_clean.min()
210 | else:
211 | img_tensor_clean[~mask_full_tensor_clean.expand_as(img_tensor_clean)] = 0
212 |
213 | return {
214 | DIK.IMAGE_IN_ENCODE: img_tensor_clean,
215 | DIK.IMAGE_IN: img_tensor,
216 | DIK.SEGMENTATION_MASK: segmentation_tensor,
217 | DIK.UV_RENDERED: uv_tensor,
218 | DIK.MASK_UV: mask_tensor,
219 | DIK.MASK_FULL: mask_full_tensor, # Used to mask input before encoding
220 | DIK.FILENAME: filename,
221 | DIK.COEFF_SHAPE: self.data["sp"][frame_id].copy().astype(np.float32),
222 | DIK.COEFF_EXPRESSION: self.data["ep"][frame_id].copy().astype(np.float32),
223 | DIK.R: torch.from_numpy(self.data["R"][frame_id].copy()).float(),
224 | DIK.T: torch.from_numpy(self.data["t"][frame_id].copy()).float(),
225 | DIK.SCALE: torch.from_numpy(self.data["s"][frame_id].copy()).float()
226 | }
227 | except ValueError as e:
228 | print("Value error when processing index {}".format(index))
229 | print(e)
230 |             exit(1)  # non-zero exit code to signal the failure
231 |
232 | def __len__(self):
233 | return self.N
234 |
235 | def get_unsqueezed(self, index, device='cuda'):
236 | batch = self[index]
237 | batch_unsqueeze = {}
238 | for k, v in batch.items():
239 | if hasattr(v, "unsqueeze"):
240 | batch_unsqueeze[k] = v.unsqueeze(0).to(device)
241 | elif k == DIK.FILENAME:
242 | batch_unsqueeze[k] = [v]
243 | else:
244 | batch_unsqueeze[k] = torch.Tensor(v).unsqueeze(0).to(device)
245 |
246 | return batch_unsqueeze
247 |
248 | def get_raw_image(self, index):
249 | frame_id = self.indices[index] # We keep the dataset splits loaded as indices
250 | filename = self.data["filename"][frame_id]
251 | image_raw = self._read_image(filename, (self.initial_width, self.initial_height))
252 |
253 | height, width = self.opt.image_h, self.opt.image_w
254 | return self._apply_transforms(image_raw, height, width)
255 |
--------------------------------------------------------------------------------
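Usage note: a minimal sketch of loading the dataset; the `opt` fields below are inferred from the code above, and their values (paths, region indices) are placeholders:

    from types import SimpleNamespace
    from torch.utils.data import DataLoader
    from varitex.data.npy_dataset import NPYDataset

    opt = SimpleNamespace(dataroot_npy="data/npy", image_folder="data/images",
                          image_h=256, image_w=256, semantic_regions=[1, 2, 3],
                          keep_background=False, bg_color="black")
    dataset = NPYDataset(opt, augmentation=False, split="val")
    loader = DataLoader(dataset, batch_size=4, num_workers=4, shuffle=True)
    batch = next(iter(loader))  # dict keyed by DataItemKey members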
/varitex/data/uv_factory.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import numpy as np
4 | import torch
5 | from mutil.bfm2017 import BFM2017Tensor, BFM2017
6 | from mutil.pytorch_utils import to_tensor, theta2rotation_matrix
7 | from mutil.renderer import UVRenderer
8 | from mutil.threed_utils import apply_Rts
9 | from pytorch3d.structures import Meshes
10 |
11 | from varitex.data.dataset_specifics import Camera
12 | from varitex.modules.custom_module import CustomModule
13 |
14 |
15 | class BFMUVFactory(CustomModule):
16 | """
17 | facemodels_path = ...
18 | path_bfm = os.path.join(facemodels_path, "basel_facemodel/model2017-1_face12_nomouth.h5")
19 | path_uv = os.path.join(facemodels_path, "basel_facemodel/face12.json")
20 | factory = BFMUVFactory(path_bfm, path_uv, image_size=256)
21 | """
22 |
23 | def __init__(self, opt, use_bfm_gpu=False):
24 | super().__init__(opt)
25 | self.to(opt.device)
26 |
27 | # This factory uses the Basel Face Model
28 | self.use_bfm_gpu = use_bfm_gpu # Loads the BFM into GPU for faster mesh generation
29 | if self.use_bfm_gpu:
30 | self.bfm = BFM2017Tensor(self.opt.path_bfm, device=opt.device)
31 | else:
32 | self.bfm = BFM2017(self.opt.path_bfm, self.opt.path_uv)
33 | self.faces_tensor = to_tensor(self.bfm.faces, self.device).unsqueeze(0)
34 |
35 | verts_uv, faces_uv = self.load_verts_faces_uv(opt.path_uv)
36 | self.uv_renderer = UVRenderer(verts_uv, faces_uv, image_size=self.opt.image_h, cameras=Camera.get_camera())
37 |
38 | def unsqueeze_transform(self, R, t, s):
39 | if len(R.shape) == 2:
40 | R, t, s = R.unsqueeze(0), t.unsqueeze(0), s.unsqueeze(0)
41 | if len(t.shape) == 1:
42 | t = t.unsqueeze(0)
43 | if len(s.shape) == 1:
44 | s = s.unsqueeze(0)
45 | R, t, s = R.float().to(self.device), t.float().to(self.device), s.float().to(self.device)
46 | return R, t, s
47 |
48 | def load_verts_faces_uv(self, json_path):
49 | with open(json_path, 'r') as f:
50 | uv_para = json.load(f)
51 | verts_uvs = np.array(uv_para['textureMapping']['pointData'])
52 | faces_uvs = np.array(uv_para['textureMapping']['triangles'])
53 |
54 | verts_uvs = to_tensor(verts_uvs, self.device, dtype=float).unsqueeze(0).float()
55 | faces_uvs = to_tensor(faces_uvs, self.device, dtype=float).unsqueeze(0).long()
56 |
57 | return verts_uvs, faces_uvs
58 |
59 | def generate_vertices(self, sp, ep):
60 | # sp, ep need to have a batch dimension
61 | batch_size = ep.shape[0]
62 | if self.use_bfm_gpu:
63 | vertices = [self.bfm.generate_vertices(sp[i], ep[i]) for i in range(batch_size)]
64 | else:
65 | vertices = [self.bfm.generate_vertices(sp[i].detach().cpu().numpy(), ep[i].detach().cpu().numpy()) for i in
66 | range(batch_size)]
67 | vertices = [to_tensor(v, self.device) for v in vertices]
68 | vertices = torch.stack(vertices)
69 | return vertices
70 |
71 | def get_posed_meshes(self, sp, ep, R, t, s, batch_size, correct_translation):
72 | # Generate the vertices
73 | vertices = self.generate_vertices(sp, ep)
74 | # Match the batch size
75 | faces = self.faces_tensor.expand(batch_size, self.bfm.N_FACES, 3)
76 | vertices = apply_Rts(vertices, R, t, s, correct_tanslation=correct_translation)
77 | meshes = Meshes(verts=vertices, faces=faces)
78 | return meshes
79 |
80 | def getUV(self, R=torch.eye(3).unsqueeze(0), t=torch.Tensor((0, 0, -35)).unsqueeze(0),
81 | s=torch.Tensor((0.1,)).unsqueeze(0), sp=torch.zeros((199,)).unsqueeze(0),
82 | ep=torch.zeros((100,)).unsqueeze(0), correct_translation=True):
83 | R, t, s = self.unsqueeze_transform(R, t, s)
84 | assert len(sp.shape) == 2 and sp.shape[-1] == 199, "Should come in shape (batch_size, 199), but is {}".format(
85 | sp.shape)
86 | assert len(ep.shape) == 2 and ep.shape[-1] == 100, "Should come in shape (batch_size, 100), but is {}".format(
87 | ep.shape)
88 | batch_size = R.shape[0]
89 | meshes = self.get_posed_meshes(sp, ep, R, t, s, batch_size, correct_translation)
90 | uv = self.uv_renderer.render(meshes)
91 | return uv
92 |
93 |     def get_sampled_uvs_shape(self, n, std_multiplier=1, ep=np.zeros((1, 100))):
94 |         shapes = self.bfm.sample_shape(n, std_multiplier)
95 |         uv_list = [self.getUV(sp=sp.reshape(1, -1), ep=ep) for sp in shapes]  # keyword args: getUV's positional order is (R, t, s, sp, ep)
96 | return uv_list
97 |
98 |     def get_sampled_uvs_expression(self, n, std_multiplier=1, sp=np.zeros((1, 199))):
99 |         expressions = self.bfm.sample_expression(n, std_multiplier)
100 |         uv_list = [self.getUV(sp=sp, ep=ep.reshape(1, -1)) for ep in expressions]
101 | return uv_list
102 |
103 |     def get_posed_uvs(self, sp=torch.zeros((1, 199)), ep=torch.zeros((1, 100)), deg_range=torch.arange(-30, 31, 15),
104 |                       t=torch.Tensor((0, 0, -35)), s=torch.Tensor((0.1,))):
105 | uv_list = list()
106 | for theta_y in deg_range:
107 | for theta_x in deg_range:
108 | theta = theta_x, theta_y, 0
109 | R = theta2rotation_matrix(theta_all=theta)
110 |                 uv = self.getUV(R=R, t=t, s=s, sp=sp, ep=ep)
111 | uv_list.append(uv)
112 | return uv_list
113 |
--------------------------------------------------------------------------------
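Usage note: a minimal sketch of rendering a posed UV map with keyword arguments (avoiding the positional-order pitfalls fixed above); `opt` must provide at least `path_bfm`, `path_uv`, `image_h`, and `device`:

    import torch
    from mutil.pytorch_utils import theta2rotation_matrix
    from varitex.data.uv_factory import BFMUVFactory

    factory = BFMUVFactory(opt)  # opt configured as for training/inference
    R = theta2rotation_matrix(theta_all=(0, 30, 0)).unsqueeze(0)  # 30 degrees yaw
    uv = factory.getUV(R=R, t=torch.Tensor([[0, 0, -35]]), s=torch.Tensor([[0.1]]),
                       sp=torch.zeros(1, 199), ep=torch.zeros(1, 100))
    # uv: rendered UV coordinate image in [-1, 1]; -1 marks unrendered pixels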
/varitex/demo.py:
--------------------------------------------------------------------------------
1 | import imageio
2 | import torch
3 |
4 | try:
5 |     from mutil.object_dict import ObjectDict
6 |     from varitex.data.keys_enum import DataItemKey as DIK
7 |     from varitex.data.uv_factory import BFMUVFactory
8 |     from varitex.modules.pipeline import PipelineModule
9 |     from varitex.visualization.batch import CompleteVisualizer
10 |     from varitex.options import varitex_default_options
11 | except ModuleNotFoundError as e:
12 |     print(e)
13 |     print("Have you added VariTex to your pythonpath?")
14 |     print('To fix this error, go to the root path of the repository ".../VariTex/" \n '
15 |           'and run \n'
16 |           "export PYTHONPATH=$PYTHONPATH:$(pwd)")
17 |     exit()
18 |
19 |
20 | class Demo:
21 | def __init__(self, opt):
22 | default_opt = varitex_default_options()
23 | default_opt.update(opt)
24 | self.opt = ObjectDict(default_opt)
25 |
26 | uv_factory_bfm = BFMUVFactory(opt=self.opt, use_bfm_gpu=self.opt.device == 'cuda')
27 | self.visualizer_complete = CompleteVisualizer(opt=self.opt, bfm_uv_factory=uv_factory_bfm)
28 |
29 | self.pipeline = PipelineModule.load_from_checkpoint(self.opt.checkpoint, opt=self.opt, strict=False).to(
30 | self.opt.device).eval()
31 | self.device = self.pipeline.device
32 |
33 | def run(self, z, sp, ep, theta, t=torch.Tensor([0, -2, -57])):
34 | batch = {
35 | DIK.STYLE_LATENT: z,
36 | DIK.COEFF_SHAPE: sp,
37 | DIK.COEFF_EXPRESSION: ep,
38 | DIK.T: t
39 | }
40 | batch = {k: v.to(self.device) for k, v in batch.items()}
41 |
42 | batch = self.visualizer_complete.visualize_single(self.pipeline, batch, 0, theta_all=theta,
43 | forward_type='style2image')
44 | batch = {k: v.detach().cpu() for k, v in batch.items()}
45 | return batch
46 |
47 | def to_image(self, batch_or_batch_list):
48 | if isinstance(batch_or_batch_list, list):
49 | out = torch.cat([batch_out[DIK.IMAGE_OUT][0] for batch_out in batch_or_batch_list], -1)
50 | elif isinstance(batch_or_batch_list, dict):
51 | out = batch_or_batch_list[DIK.IMAGE_OUT][0]
52 | else:
53 |             raise TypeError("Invalid type: '{}'".format(type(batch_or_batch_list)))
54 | return self.visualizer_complete.tensor2image(out, return_format='pil')
55 |
56 | def to_video(self, batch_list, path_out, fps=15, quality=9, reverse=False):
57 | assert path_out.endswith(".mp4"), "Path should end with .mp4"
58 | frames = [self.to_image(batch) for batch in batch_list]
59 | if reverse:
60 | frames = frames + frames[::-1]
61 | imageio.mimwrite(path_out, frames, fps=fps, quality=quality)
62 |
63 | def load_shape_expressions(self):
64 | import numpy as np
65 | import os
66 | validation_indices = list(np.load(os.path.join(self.opt.dataroot_npy, "dataset_splits.npz"))["val"])
67 | sp = np.load(os.path.join(self.opt.dataroot_npy, "sp.npy"))[validation_indices]
68 | ep = np.load(os.path.join(self.opt.dataroot_npy, "ep.npy"))[validation_indices]
69 | return torch.Tensor(sp), torch.Tensor(ep)
70 |
--------------------------------------------------------------------------------
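Usage note: a minimal sketch of sampling a novel identity with the Demo class; the checkpoint path is a placeholder, and the remaining options come from varitex_default_options:

    import torch
    from varitex.demo import Demo

    demo = Demo({"checkpoint": "pretrained/ep44.ckpt", "device": "cuda"})
    z = torch.randn(1, demo.opt.latent_dim)             # sample a novel identity
    sp, ep = torch.zeros(1, 199), torch.zeros(1, 100)   # neutral shape and expression
    batch = demo.run(z, sp, ep, theta=(0, 0, 0))        # frontal pose
    demo.to_image(batch).save("sample.png")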
/varitex/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mcbuehler/VariTex/a25979f96d600ed839799450958ace8952ce375a/varitex/evaluation/__init__.py
--------------------------------------------------------------------------------
/varitex/evaluation/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | from mutil.files import mkdir
6 | from mutil.pytorch_utils import theta2rotation_matrix
7 | from torch.utils.data import DataLoader
8 | from tqdm import tqdm
9 |
10 | from varitex.data.keys_enum import DataItemKey as DIK
11 | from varitex.data.npy_dataset import NPYDataset
12 | from varitex.data.uv_factory import BFMUVFactory
13 | from varitex.modules.pipeline import PipelineModule
14 | from varitex.visualization.batch import Visualizer
15 |
16 |
17 | def get_model(opt):
18 | model = PipelineModule.load_from_checkpoint(opt.checkpoint, opt=opt, strict=False)
19 | model = model.eval()
20 | model = model.cuda()
21 | return model
22 |
23 |
24 | def inference_ffhq(opt, results_folder, n=3000):
25 | """
26 | Runs inference on FFHQ. Save the resulting images in the results_folder.
27 | Also saves the latent codes and distributions.
28 | """
29 | print("Running inference on FFHQ. Using the extracted face model parameters and poses, and predicted latent codes.")
30 | mkdir(results_folder)
31 | visualizer = Visualizer(opt, return_format='pil')
32 |
33 | dataset = NPYDataset(opt, augmentation=False, split='val')
34 | dataloader = iter(DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False))
35 | model = get_model(opt)
36 | visualizer.mask_key = DIK.SEGMENTATION_PREDICTED
37 |
38 | result = list()
39 |
40 | for i, batch in tqdm(enumerate(dataloader)):
41 | batch = model.forward(batch, i, std_multiplier=0)
42 | file_id = batch[DIK.FILENAME][0]
43 | latent_code = batch[DIK.STYLE_LATENT_MU].detach().cpu().numpy()[0]
44 | latent_std = batch[DIK.STYLE_LATENT_STD].detach().cpu().numpy()[0]
45 | result.append([latent_code, latent_std])
46 |
47 | out = visualizer.tensor2image(batch[DIK.IMAGE_OUT][0], batch=batch)
48 | out = visualizer.mask(out, batch, white_bg=False)
49 |
50 | out = visualizer.format_output(out, return_format='pil')
51 | out.save(os.path.join(results_folder, "{}.png".format(file_id)))
52 |
53 | if i >= n:
54 | break
55 |
56 | result = np.array(result)
57 | np.save(os.path.join(results_folder, "latents.npy"), result)
58 | print("Done")
59 |
60 |
61 | def inference_posed(opt, results_folder, path_latent, n=3000):
62 | """
63 | Runs different poses for each latent code extracted from the FFHQ dataset.
64 | Uses random shape and expression parameters.
65 | """
66 |
67 | print("Running inference on FFHQ. Using random face model parameters, various poses, and predicted latent codes. This can be slow.")
68 | visualizer = Visualizer(opt, return_format='pil')
69 | visualizer.mask_key = DIK.SEGMENTATION_PREDICTED
70 |
71 | # We use the latent codes extracted from FFHQ, but they could also be sampled from the extracted distributions.
72 | latents = np.load(path_latent)[:, 0] # We want to use the distribution means directly (index 0)
73 |
74 | mkdir(results_folder)
75 |
76 | model = get_model(opt)
77 |
78 |     all_pose_axis = [[0], [1], [0, 1]]  # All pose axes: 0 is pitch, 1 is yaw
79 |
80 | uv_factory_bfm = BFMUVFactory(opt, use_bfm_gpu=True)
81 |
82 | s = torch.Tensor([0.1])
83 | t = torch.Tensor([[0, -2, -57]])
84 | results = list()
85 | all_poses = np.arange(-45, 46, 15)
86 |
87 | with torch.no_grad():
88 | for i in tqdm(range(n)):
89 | sp = uv_factory_bfm.bfm.sample_shape(1)
90 | ep = uv_factory_bfm.bfm.sample_expression(1)
91 | latent = torch.Tensor([latents[i]])
92 |
93 | for pose_axis in all_pose_axis:
94 | for pose in all_poses:
95 | theta = torch.Tensor([0, 0, 0])
96 | theta[pose_axis] = pose
97 | R = theta2rotation_matrix(theta_all=theta).unsqueeze(0)
98 | uv_tensor = uv_factory_bfm.getUV(R, t, s, sp, ep, correct_translation=True)
99 |
100 | batch = {
101 | DIK.UV_RENDERED: uv_tensor,
102 | DIK.R: R,
103 | DIK.T: t,
104 | DIK.SCALE: s,
105 | DIK.STYLE_LATENT: latent
106 | }
107 | batch = {k: v.cuda() for k, v in batch.items()}
108 | batch_out = model.forward_latent2image(batch, 0)
109 | out = visualizer.tensor2image(batch_out[DIK.IMAGE_OUT][0], batch=batch_out, white_bg=False)
110 | results.append(out)
111 | out = visualizer.format_output(out, return_format='pil')
112 |
113 | folder_out = os.path.join(results_folder, "axis{}_theta{}".format(pose_axis, pose))
114 | mkdir(folder_out)
115 |                     out.save(os.path.join(folder_out, "{:05d}.png".format(i)))  # folder_out already contains results_folder
116 |
117 |
118 | def inference_posed_ffhq(opt, results_folder, n=3000):
119 | """
120 | Runs different poses for each latent code extracted from the FFHQ dataset.
121 | Uses shape and expression parameters from validation set.
122 | """
123 | print("Running inference on FFHQ. Using the extracted face model parameters, various poses, and predicted latent codes. This can be slow.")
124 | mkdir(results_folder)
125 | visualizer = Visualizer(opt, return_format='pil')
126 |
127 | dataset = NPYDataset(opt, augmentation=False, split='val')
128 | dataloader = iter(DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False))
129 | model = get_model(opt)
130 | visualizer.mask_key = DIK.SEGMENTATION_PREDICTED
131 |
132 | n = min(n, 3000)
133 | all_pose_axis = [[0], [1], [0, 1]]
134 |
135 | uv_factory_bfm = BFMUVFactory(opt, use_bfm_gpu=True)
136 | s = torch.Tensor([0.1])
137 | t = torch.Tensor([[0, -2, -57]])
138 |
139 | all_poses = np.arange(-45, 46, 15)
140 |
141 | for i, batch_orig in tqdm(enumerate(dataloader)):
142 | batch_orig = model.forward(batch_orig, i, std_multiplier=0)
143 | file_id = batch_orig[DIK.FILENAME][0]
144 | sp = batch_orig[DIK.COEFF_SHAPE].clone()
145 | ep = batch_orig[DIK.COEFF_EXPRESSION].clone()
146 | latent = batch_orig[DIK.STYLE_LATENT].clone()
147 |
148 | for pose_axis in all_pose_axis:
149 | for pose in all_poses:
150 | theta = torch.Tensor([0, 0, 0])
151 | theta[pose_axis] = pose
152 | R = theta2rotation_matrix(theta_all=theta).unsqueeze(0)
153 | uv_tensor = uv_factory_bfm.getUV(R, t, s, sp, ep, correct_translation=True)
154 |
155 | batch = {
156 | DIK.UV_RENDERED: uv_tensor,
157 | DIK.R: R,
158 | DIK.T: t,
159 | DIK.SCALE: s,
160 | DIK.STYLE_LATENT: latent
161 | }
162 | batch = {k: v.cuda() for k, v in batch.items()}
163 | batch_out = model.forward_latent2image(batch, 0)
164 | out = visualizer.tensor2image(batch_out[DIK.IMAGE_OUT][0], batch=batch_out, white_bg=False)
165 | out = visualizer.format_output(out, return_format='pil')
166 |
167 | folder_out = os.path.join(results_folder, "axis{}_theta{}".format(pose_axis, pose))
168 | mkdir(folder_out)
169 |                     out.save(os.path.join(folder_out, "{}.png".format(file_id)))  # folder_out already contains results_folder
170 |
171 | if i >= n:
172 | break
173 |
174 | print("Done")
175 |
--------------------------------------------------------------------------------
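Usage note: `inference_ffhq` stacks one (mu, std) pair per image into `latents.npy`, and `inference_posed` reads the means back via `[:, 0]`. A minimal sketch of reading the file (the path is a placeholder):

    import numpy as np

    latents = np.load("out/inference_ffhq/latents.npy")  # shape (N, 2, latent_dim)
    mu, std = latents[7]                                  # distribution of the 8th image
    z = mu + std * np.random.randn(*std.shape)            # optionally re-sample a code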
/varitex/inference.py:
--------------------------------------------------------------------------------
1 | """
2 | CP=PATH/TO/CHECKPOINT.ckpt
3 | CP=$OP/own/LDS/pretrained/checkpoints/ep44.ckpt
4 | CUDA_VISIBLE_DEVICES=0 python varitex/inference.py --checkpoint $CP --dataset_split val
5 | """
6 | import os
7 |
8 | import pytorch_lightning as pl
9 | from mutil.object_dict import ObjectDict
10 |
11 | try:
12 | from varitex.options import varitex_default_options
13 | from varitex.evaluation import inference
14 | from varitex.options.eval_options import EvalOptions
15 | except ModuleNotFoundError:
16 | print("Have you added VariTex to your pythonpath?")
17 | print('To fix this error, go to the root path of the repository ".../VariTex/" \n '
18 | 'and run \n'
19 | "export PYTHONPATH=$PYTHONPATH:$(pwd)")
20 | exit()
21 |
22 | if __name__ == "__main__":
23 | pl.seed_everything(1234)
24 | opt = EvalOptions().parse().__dict__
25 |
26 | default_opt = varitex_default_options()
27 | default_opt.update(opt)
28 | opt = ObjectDict(default_opt)
29 | assert opt.checkpoint is not None, "Please specify a checkpoint file."
30 |
31 | checkpoint_folder = os.path.dirname(opt.checkpoint)
32 |
33 | # Runs inference on FFHQ
34 | inference.inference_ffhq(opt, n=30, results_folder=os.path.join(opt.path_out, 'inference_ffhq'))
35 |
36 |     # Runs different poses for each sample in FFHQ (shape and expressions extracted from FFHQ)
37 |     inference.inference_posed_ffhq(opt, n=30, results_folder=os.path.join(opt.path_out, 'inference_posed_ffhq'))
38 |
39 |     # Runs different poses for each sample in FFHQ (random shape and expressions).
40 |     # latents.npy should contain the latent distributions predicted from the holdout set, see inference.inference_ffhq(...).
41 | inference.inference_posed(opt, n=30, results_folder=os.path.join(opt.path_out, 'inference_posed_random'),
42 | path_latent=os.path.join(opt.path_out, 'inference_ffhq/latents.npy'))
43 |
--------------------------------------------------------------------------------
/varitex/inference_surface.py:
--------------------------------------------------------------------------------
1 | """
2 | CP=PATH/TO/CHECKPOINT.ckpt
3 | CP=$OP/own/LDS/pretrained/checkpoints/ep44.ckpt
4 | CUDA_VISIBLE_DEVICES=0 python varitex/inference_surface.py --checkpoint $CP --dataset_split val
5 | """
6 | import os
7 |
8 | import pytorch_lightning as pl
9 | from mutil.object_dict import ObjectDict
10 |
11 | try:
12 | from varitex.options import varitex_default_options
13 | from varitex.evaluation import inference
14 | from varitex.options.eval_options import EvalOptions
15 | except ModuleNotFoundError:
16 | print("Have you added VariTex to your pythonpath?")
17 | print('To fix this error, go to the root path of the repository ".../VariTex/" \n '
18 | 'and run \n'
19 | "export PYTHONPATH=$PYTHONPATH:$(pwd)")
20 | exit()
21 |
22 | if __name__ == "__main__":
23 | pl.seed_everything(1234)
24 | opt = EvalOptions().parse().__dict__
25 |
26 | default_opt = varitex_default_options()
27 | default_opt.update(opt)
28 | opt = ObjectDict(default_opt)
29 | assert opt.checkpoint is not None, "Please specify a checkpoint file."
30 |
31 | checkpoint_folder = os.path.dirname(opt.checkpoint)
32 |
33 | # Runs inference on FFHQ
34 | inference.inference_ffhq(opt, n=30, results_folder=os.path.join(opt.path_out, 'inference_ffhq'))
35 |
36 |     # Runs different poses for each sample in FFHQ (shape and expressions extracted from FFHQ)
37 |     inference.inference_posed_ffhq(opt, n=30, results_folder=os.path.join(opt.path_out, 'inference_posed_ffhq'))
38 |
39 |     # Runs different poses for each sample in FFHQ (random shape and expressions).
40 |     # latents.npy should contain the latent distributions predicted from the holdout set, see inference.inference_ffhq(...).
41 | inference.inference_posed(opt, n=30, results_folder=os.path.join(opt.path_out, 'inference_posed_random'),
42 | path_latent=os.path.join(opt.path_out, 'inference_ffhq/latents.npy'))
43 |
--------------------------------------------------------------------------------
/varitex/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mcbuehler/VariTex/a25979f96d600ed839799450958ace8952ce375a/varitex/modules/__init__.py
--------------------------------------------------------------------------------
/varitex/modules/custom_module.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 |
3 |
4 | class CustomModule(pl.LightningModule):
5 | def __init__(self, opt):
6 | super().__init__()
7 | self.opt = opt
8 | self.to(opt.device)
9 |
--------------------------------------------------------------------------------
/varitex/modules/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pl_bolts.models.autoencoders.components import (
3 | DecoderBlock, Interpolate, resize_conv1x1, resize_conv3x3
4 | )
5 | from torch import nn
6 |
7 | from varitex.data.keys_enum import DataItemKey as DIK
8 | from varitex.modules.custom_module import CustomModule
9 |
10 |
11 | class Decoder(CustomModule):
12 | """
13 | Decoder for the neural face texture.
14 | """
15 |
16 | def __init__(self, opt):
17 | super().__init__(opt)
18 |
19 | latent_dim_face = opt.latent_dim // 2
20 |
21 | input_height = opt.texture_dim
22 | # Resnet-18 variant
23 | self.decoder = SimpleResNetDecoder(DecoderBlock, [2, 2, 2, 2], latent_dim_face, input_height,
24 | nc_texture=opt.texture_nc)
25 |
26 | def forward(self, batch, batch_idx):
27 | style = batch[DIK.LATENT_INTERIOR]
28 | texture = self.decoder(style)
29 | batch[DIK.TEXTURE_PERSON] = texture
30 | return batch
31 |
32 |
33 | class AdditiveDecoder(CustomModule):
34 | """
35 | Additive decoder. Produces the additive feature image.
36 | """
37 |
38 | def __init__(self, opt):
39 | super().__init__(opt)
40 | latent_dim_additive = opt.latent_dim // 2
41 |
42 | input_height = opt.texture_dim
43 | out_dim = opt.texture_nc
44 | condition_dim = opt.texture_nc # We condition on a neural texture for the face interior
45 | # Resnet-18 variant
46 | self.decoder = EarlyConditionedSimpleRestNetDecocer(ConditionedDecoderBlock, [2, 2, 2, 2], latent_dim_additive,
47 | input_height, nc_texture=out_dim,
48 | condition_dim=condition_dim)
49 |
50 | def forward(self, batch, batch_idx):
51 | latent = batch[DIK.LATENT_EXTERIOR]
52 |
53 | # The sampled texture should already have been masked to the face interior
54 | # Errors should not back-propagate through the sampled face interior texture
55 | condition = batch[DIK.FACE_FEATUREIMAGE].detach()
56 | additive_featureimage = self.decoder(latent, condition)
57 |
58 | batch[DIK.ADDITIVE_FEATUREIMAGE] = additive_featureimage # This also has features for the interior
59 | return batch
60 |
61 |
62 | class SimpleResNetDecoder(nn.Module):
63 | """
64 | Resnet in reverse order.
65 | Most code from pl_bolts.models.autoencoders.components.
66 | """
67 |
68 | def __init__(self, block, layers, latent_dim, input_height, nc_texture, first_conv=False, maxpool1=False):
69 | super().__init__()
70 |
71 | self.expansion = block.expansion
72 | self.inplanes = 512 * block.expansion
73 | self.first_conv = first_conv
74 | self.maxpool1 = maxpool1
75 | self.input_height = input_height
76 |
77 | self.upscale_factor = 8
78 |
79 | self.linear = nn.Linear(latent_dim, self.inplanes * 4 * 4)
80 |
81 | self.initial = self._make_layer(block, 256, layers[0], scale=2)
82 |
83 | self.layer0 = self._make_layer(block, 256, layers[0], scale=2)
84 | self.layer1 = self._make_layer(block, 256, layers[0], scale=2)
85 | self.layer2 = self._make_layer(block, 128, layers[1], scale=2)
86 | self.layer3 = self._make_layer(block, 64, layers[2], scale=2)
87 |
88 | if self.input_height == 128:
89 | self.layer4 = self._make_layer(block, 64, layers[3])
90 | elif self.input_height == 256:
91 | self.layer4 = self._make_layer(block, 64, layers[3], scale=2)
92 | else:
93 |             raise ValueError("Invalid input height: '{}'".format(self.input_height))
94 |
95 | if self.first_conv:
96 | self.upscale = Interpolate(scale_factor=2)
97 | self.upscale_factor *= 2
98 | else:
99 | self.upscale = Interpolate(scale_factor=1)
100 |
101 | # interpolate after linear layer using scale factor
102 | self.upscale1 = Interpolate(size=input_height // self.upscale_factor)
103 |
104 | self.conv1 = nn.Conv2d(
105 | 64 * block.expansion, nc_texture, kernel_size=3, stride=1, padding=1, bias=False
106 | )
107 |
108 | def _make_layer(self, block, planes, blocks, scale=1):
109 | """
110 |
111 | Args:
112 | block:
113 | planes: int number of channels
114 | blocks: int number of blocks (e.g. 2)
115 | scale:
116 |
117 | Returns:
118 |
119 | """
120 | upsample = None
121 | if scale != 1 or self.inplanes != planes * block.expansion:
122 | upsample = nn.Sequential(
123 | resize_conv1x1(self.inplanes, planes * block.expansion, scale),
124 | nn.BatchNorm2d(planes * block.expansion),
125 | )
126 |
127 | layers = []
128 | layers.append(block(self.inplanes, planes, scale, upsample))
129 | self.inplanes = planes * block.expansion
130 | for _ in range(1, blocks):
131 | layers.append(block(self.inplanes, planes))
132 |
133 | return nn.Sequential(*layers)
134 |
135 | def forward(self, x):
136 | x = self.linear(x)
137 |
138 | x = x.view(x.size(0), 512 * self.expansion, 4, 4)
139 | x = self.initial(x)
140 |
141 | x = self.layer0(x)
142 | x = self.layer1(x)
143 | x = self.layer2(x)
144 | x = self.layer3(x)
145 | x = self.layer4(x)
146 |
147 | x = self.conv1(x)
148 | return x
149 |
150 |
151 | class EarlyConditionedSimpleRestNetDecocer(nn.Module):
152 | """
153 | Modified from pl_bolts.models.autoencoders.components.
154 | """
155 |
156 | def __init__(self, block, layers, latent_dim, input_height, nc_texture, first_conv=False, maxpool1=False,
157 | condition_dim=16, **kwargs):
158 | # TODO: refactor duplicated code
159 | super().__init__()
160 |
161 | self.condition_dim = condition_dim
162 |
163 | self.expansion = block.expansion
164 | self.inplanes = 512 * block.expansion
165 | self.first_conv = first_conv
166 | self.maxpool1 = maxpool1
167 | self.input_height = input_height
168 |
169 | self.linear = nn.Linear(latent_dim, self.inplanes)
170 |
171 | self.initial = self._make_layer(block, 512, layers[0], scale=2, condition_dim=self.condition_dim)
172 |
173 | self.layer0 = self._make_layer(block, 256, layers[0], scale=2, condition_dim=self.condition_dim)
174 | self.layer1 = self._make_layer(block, 256, layers[0], scale=2, condition_dim=self.condition_dim)
175 | self.layer2 = self._make_layer(block, 128, layers[1], scale=2, condition_dim=self.condition_dim)
176 | self.layer3 = self._make_layer(block, 64, layers[2], scale=2, condition_dim=self.condition_dim)
177 |
178 | if self.input_height == 128:
179 | self.layer4 = self._make_layer(block, 64, layers[3], condition_dim=self.condition_dim)
180 | elif self.input_height == 256:
181 | self.layer4 = self._make_layer(block, 64, layers[3], scale=2, condition_dim=self.condition_dim)
182 | else:
183 |             raise ValueError("Invalid input height: '{}'".format(self.input_height))
184 |
185 | if self.first_conv:
186 | self.upscale = Interpolate(scale_factor=2)
187 |             # (unlike SimpleResNetDecoder, no upscale factor is tracked in this class)
188 |
189 | self.conv1 = nn.Conv2d(
190 | 64 * block.expansion, nc_texture, kernel_size=3, stride=1, padding=1, bias=False
191 | )
192 |
193 | def _make_layer(self, block, planes, blocks, scale=1, condition_dim=16):
194 | """
195 |
196 | Args:
197 | block:
198 | planes: int number of channels
199 | blocks: int number of blocks (e.g. 2)
200 | scale:
201 |
202 | Returns:
203 |
204 | """
205 | upsample = None
206 | if scale != 1 or self.inplanes != planes * block.expansion:
207 | upsample = nn.Sequential(
208 | resize_conv1x1(self.inplanes, planes * block.expansion, scale),
209 | nn.BatchNorm2d(planes * block.expansion),
210 | )
211 |
212 | layers = []
213 | layers.append(block(self.inplanes, planes, scale, upsample, condition_dim=condition_dim))
214 | self.inplanes = planes * block.expansion
215 | for _ in range(1, blocks):
216 | layers.append(block(self.inplanes, planes, condition_dim=condition_dim))
217 |
218 | return nn.Sequential(*layers)
219 |
220 | def forward(self, x, condition):
221 | x = self.linear(x)
222 | # We now have 512 feature maps with the same value in all spatial locations
223 | # self.inplanes changes when creating blocks
224 | x = x.view(x.shape[0], 512, 1, 1).expand(-1, -1, 4, 4)
225 |
226 | in_dict = {"features": x, "condition": condition}
227 | out_dict = self.initial(in_dict)
228 |
229 | out_dict = self.layer0(out_dict)
230 | out_dict = self.layer1(out_dict)
231 | out_dict = self.layer2(out_dict)
232 | out_dict = self.layer3(out_dict)
233 | out_dict = self.layer4(out_dict)
234 | x = out_dict['features']
235 | x = self.conv1(x)
236 | return x
237 |
238 |
239 | class ConditionedDecoderBlock(nn.Module):
240 | """
241 | ResNet block, but convs replaced with resize convs, and channel increase is in
242 | second conv, not first.
243 | Also heavily borrowed from pl_bolts.models.autoencoders.components.
244 | """
245 |
246 | expansion = 1
247 |
248 | def __init__(self, inplanes, planes, scale=1, upsample=None, condition_dim=16):
249 | super().__init__()
250 | self.conv1 = resize_conv3x3(inplanes + condition_dim,
251 |                                     inplanes)  # condition_dim channels are concatenated for the conditioning
252 | self.bn1 = nn.BatchNorm2d(inplanes)
253 | self.relu = nn.ReLU(inplace=True)
254 | self.conv2 = resize_conv3x3(inplanes, planes, scale)
255 | self.bn2 = nn.BatchNorm2d(planes)
256 | self.upsample = upsample
257 |
258 | self.interpolation_mode = "bilinear"
259 |
260 | def forward(self, data):
261 | x = data["features"]
262 | condition = data["condition"]
263 |
264 | condition_scaled = torch.nn.functional.interpolate(condition, x.shape[-2:], mode=self.interpolation_mode)
265 | identity = x
266 |
267 | out = torch.cat([x, condition_scaled], 1) # Along the channel dimension
268 | out = self.conv1(out)
269 | out = self.bn1(out)
270 | out = self.relu(out)
271 |
272 | out = self.conv2(out)
273 | out = self.bn2(out)
274 |
275 | if self.upsample is not None:
276 | identity = self.upsample(x)
277 |
278 | out += identity
279 | out = self.relu(out)
280 |
281 | out_dict = {"features": out, "condition": condition}
282 | return out_dict
283 |
--------------------------------------------------------------------------------
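Shape check: with `input_height=256`, all six stages (initial plus layer0 through layer4) upsample by a factor of 2, taking the initial 4x4 map to 256x256. A minimal sketch; 128 and 16 stand in for `opt.latent_dim // 2` and `opt.texture_nc`:

    import torch
    from pl_bolts.models.autoencoders.components import DecoderBlock
    from varitex.modules.decoder import SimpleResNetDecoder

    dec = SimpleResNetDecoder(DecoderBlock, [2, 2, 2, 2], latent_dim=128,
                              input_height=256, nc_texture=16)
    tex = dec(torch.randn(2, 128))
    print(tex.shape)  # -> torch.Size([2, 16, 256, 256])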
/varitex/modules/discriminator.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import numpy as np
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.nn.utils.spectral_norm as spectral_norm
10 | from torch.nn import init
11 |
12 |
13 | # Returns a function that creates a normalization function
14 | # that does not condition on semantic map
15 | def get_norm_layer(norm_type='instance'):
16 | # helper function to get # output channels of the previous layer
17 | def get_out_channel(layer):
18 | if hasattr(layer, 'out_channels'):
19 | return getattr(layer, 'out_channels')
20 | return layer.weight.size(0)
21 |
22 | # this function will be returned
23 | def add_norm_layer(layer):
24 |         nonlocal norm_type
25 |         subnorm_type = norm_type  # default; avoids a NameError for non-spectral norm types
26 |         if norm_type.startswith('spectral'):
27 |             layer, subnorm_type = spectral_norm(layer), norm_type[len('spectral'):]
28 |
29 | if subnorm_type == 'none' or len(subnorm_type) == 0:
30 | return layer
31 |
32 | # remove bias in the previous layer, which is meaningless
33 | # since it has no effect after normalization
34 | if getattr(layer, 'bias', None) is not None:
35 | delattr(layer, 'bias')
36 | layer.register_parameter('bias', None)
37 |
38 | if subnorm_type == 'batch':
39 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
40 | # elif subnorm_type == 'sync_batch':
41 | # norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
42 | elif subnorm_type == 'instance':
43 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
44 | else:
45 | raise ValueError('normalization layer %s is not recognized' % subnorm_type)
46 |
47 | return nn.Sequential(layer, norm_layer)
48 |
49 | return add_norm_layer
50 |
51 |
52 | class BaseNetwork(nn.Module):
53 | def __init__(self):
54 | super(BaseNetwork, self).__init__()
55 |
56 | @staticmethod
57 | def modify_commandline_options(parser, is_train):
58 | return parser
59 |
60 | def print_network(self):
61 | if isinstance(self, list):
62 | self = self[0]
63 | num_params = 0
64 | for param in self.parameters():
65 | num_params += param.numel()
66 | print('Network [%s] was created. Total number of parameters: %.1f million. '
67 | 'To see the architecture, do print(network).'
68 | % (type(self).__name__, num_params / 1000000))
69 |
70 | def init_weights(self, init_type='normal', gain=0.02):
71 | def init_func(m):
72 | classname = m.__class__.__name__
73 | if classname.find('BatchNorm2d') != -1:
74 | if hasattr(m, 'weight') and m.weight is not None:
75 | init.normal_(m.weight.data, 1.0, gain)
76 | if hasattr(m, 'bias') and m.bias is not None:
77 | init.constant_(m.bias.data, 0.0)
78 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
79 | if init_type == 'normal':
80 | init.normal_(m.weight.data, 0.0, gain)
81 | elif init_type == 'xavier':
82 | init.xavier_normal_(m.weight.data, gain=gain)
83 | elif init_type == 'xavier_uniform':
84 | init.xavier_uniform_(m.weight.data, gain=1.0)
85 | elif init_type == 'kaiming':
86 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
87 | elif init_type == 'orthogonal':
88 | init.orthogonal_(m.weight.data, gain=gain)
89 | elif init_type == 'none': # uses pytorch's default init method
90 | m.reset_parameters()
91 | else:
92 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
93 | if hasattr(m, 'bias') and m.bias is not None:
94 | init.constant_(m.bias.data, 0.0)
95 |
96 | self.apply(init_func)
97 |
98 | # propagate to children
99 | for m in self.children():
100 | if hasattr(m, 'init_weights'):
101 | m.init_weights(init_type, gain)
102 |
103 |
104 | class MultiscaleDiscriminator(BaseNetwork):
105 |
106 | def __init__(self, opt):
107 | super().__init__()
108 | self.opt = opt
109 |
110 | for i in range(opt.num_discriminator):
111 | subnetD = NLayerDiscriminator(opt)
112 | self.add_module('discriminator_%d' % i, subnetD)
113 |
114 | def downsample(self, input):
115 | return F.avg_pool2d(input, kernel_size=3,
116 | stride=2, padding=[1, 1],
117 | count_include_pad=False)
118 |
119 | # Returns list of lists of discriminator outputs.
120 |     # The final result is of size opt.num_discriminator x opt.n_layers_discriminator
121 | def forward(self, input):
122 | result = []
123 | for name, D in self.named_children():
124 | out = D(input)
125 | result.append(out)
126 | input = self.downsample(input)
127 | return result
128 |
129 |
130 | # Defines the PatchGAN discriminator with the specified arguments.
131 | class NLayerDiscriminator(BaseNetwork):
132 |
133 | def __init__(self, opt):
134 | super().__init__()
135 | self.opt = opt
136 |
137 | kw = 4
138 | padw = int(np.ceil((kw - 1.0) / 2))
139 | nf = opt.nc_discriminator
140 | input_nc = 3 # RGB
141 |
142 | norm_layer = get_norm_layer(opt.norm_discriminator)
143 | sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
144 | nn.LeakyReLU(0.2, False)]]
145 |
146 | for n in range(1, opt.n_layers_discriminator):
147 | nf_prev = nf
148 | nf = min(nf * 2, 512)
149 | stride = 1 if n == opt.n_layers_discriminator - 1 else 2
150 | sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw,
151 | stride=stride, padding=padw)),
152 | nn.LeakyReLU(0.2, False)
153 | ]]
154 |
155 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
156 |
157 | # We divide the layers into groups to extract intermediate layer outputs
158 | for n in range(len(sequence)):
159 | self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
160 |
161 | def forward(self, input):
162 | results = [input]
163 | for submodel in self.children():
164 | intermediate_output = submodel(results[-1])
165 | results.append(intermediate_output)
166 | return results[1:]
167 |
--------------------------------------------------------------------------------
/varitex/modules/encoder.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.models.autoencoders.components import (
2 | resnet18_encoder
3 | )
4 |
5 | from varitex.data.keys_enum import DataItemKey as DIK
6 | from varitex.modules.custom_module import CustomModule
7 |
8 |
9 | class Encoder(CustomModule):
10 | def __init__(self, opt):
11 | super().__init__(opt)
12 | self.encoder = resnet18_encoder(False, False)
13 |
14 | def forward(self, batch, batch_idx):
15 | image = batch[DIK.IMAGE_IN_ENCODE]
16 | encoded = self.encoder(image)
17 | batch[DIK.IMAGE_ENCODED] = encoded
18 | return batch
19 |
--------------------------------------------------------------------------------
/varitex/modules/feature2image.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Sigmoid
2 |
3 | from varitex.data.keys_enum import DataItemKey as DIK
4 | from varitex.modules.custom_module import CustomModule
5 | from varitex.modules.unet import UNet
6 |
7 |
8 | class Feature2ImageRenderer(CustomModule):
9 | def __init__(self, opt):
10 | super().__init__(opt)
11 |
12 | n_input_channels = opt.texture_nc * 2 # We have two feature images: face and additive
13 | self.unet = UNet(output_nc=4, input_channels=n_input_channels, features_start=opt.nc_feature2image,
14 | num_layers=opt.feature2image_num_layers)
15 | self.probability = Sigmoid()
16 |
17 | def forward(self, batch, batch_idx):
18 | texture_enhanced = batch[DIK.FULL_FEATUREIMAGE]
19 | tensor_out = self.unet(texture_enhanced)
20 |         mask_out = tensor_out[:, :1]  # First channel is the foreground mask; slicing with :1 keeps the channel dimension
21 | image_out = tensor_out[:, 1:] # RGB image
22 |
23 | # Should be close to 0 for the background
24 | mask_proba = self.probability(mask_out)
25 | batch[DIK.SEGMENTATION_PREDICTED] = mask_proba
26 |
27 | batch[DIK.IMAGE_OUT] = image_out
28 | return batch
29 |
--------------------------------------------------------------------------------
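Usage note: a minimal sketch of the neural renderer; it assumes the UNet (not shown in this dump) preserves the spatial size, and the `opt` values are placeholders:

    import torch
    from types import SimpleNamespace
    from varitex.data.keys_enum import DataItemKey as DIK
    from varitex.modules.feature2image import Feature2ImageRenderer

    opt = SimpleNamespace(texture_nc=16, nc_feature2image=64,
                          feature2image_num_layers=5, device="cpu")
    renderer = Feature2ImageRenderer(opt)
    batch = {DIK.FULL_FEATUREIMAGE: torch.randn(1, 2 * 16, 64, 64)}
    batch = renderer(batch, 0)
    # batch[DIK.SEGMENTATION_PREDICTED]: (1, 1, 64, 64), probabilities in [0, 1]
    # batch[DIK.IMAGE_OUT]:              (1, 3, 64, 64), the rendered RGB image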
/varitex/modules/generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from varitex.data.keys_enum import DataItemKey as DIK
4 | from varitex.modules.decoder import Decoder, AdditiveDecoder
5 | from varitex.modules.encoder import Encoder
6 | from varitex.modules.feature2image import Feature2ImageRenderer
7 | from varitex.modules.custom_module import CustomModule
8 |
9 |
10 | class Generator(CustomModule):
11 | def __init__(self, opt):
12 | super().__init__(opt)
13 |
14 | self.MASK_VALUE = opt.uv_mask_value if hasattr(opt, "uv_mask_value") else 0
15 |
16 | enc_out_dim = 512
17 | latent_dim = opt.latent_dim
18 |
19 | self.fc_mu = torch.nn.Linear(enc_out_dim, latent_dim)
20 | self.fc_var = torch.nn.Linear(enc_out_dim, latent_dim)
21 | self.encoder = Encoder(opt)
22 | self.decoder = Decoder(opt)
23 |
24 | self.decoder_exterior = AdditiveDecoder(opt)
25 | self.texture2image = Feature2ImageRenderer(opt)
26 |
27 | def forward(self, batch, batch_idx, std_multiplier=1):
28 | batch = self.forward_encode(batch, batch_idx) # Only encoding, not yet a distribution
29 | batch = self.forward_encoded2latent_distribution(batch) # Compute mu and std
30 | batch = self.forward_sample_style(batch, batch_idx, std_multiplier=std_multiplier) # Sample a latent code
31 | batch = self.forward_latent2image(batch, batch_idx) # Decoders for face and exterior, followed by rendering
32 | return batch
33 |
34 | def forward_encode(self, batch, batch_idx):
35 | # Only encode the image, adds DIK.IMAGE_ENCODED.
36 | # This is not yet the latent distribution.
37 | batch = self.encoder(batch, batch_idx)
38 | return batch
39 |
40 | def forward_encoded2latent_distribution(self, batch):
41 | # Computes the latent distribution from the encoded image.
42 | mu, log_var = self.fc_mu(batch[DIK.IMAGE_ENCODED]), self.fc_var(batch[DIK.IMAGE_ENCODED])
43 | std = torch.exp(log_var / 2)
44 |
45 | batch[DIK.STYLE_LATENT_MU] = mu
46 | batch[DIK.STYLE_LATENT_STD] = std
47 | return batch
48 |
49 | def forward_sample_style(self, batch, batch_idx, std_multiplier=1):
50 | # Sample the latent code z from the given distribution.
51 | mu, std = batch[DIK.STYLE_LATENT_MU], batch[DIK.STYLE_LATENT_STD]
52 |
53 | q = torch.distributions.Normal(mu, std * std_multiplier)
54 | z = q.rsample()
55 | batch[DIK.STYLE_LATENT] = z
56 | return batch
57 |
58 | def forward_latent2image(self, batch, batch_idx):
59 | # Given a latent code, render an image.
60 | batch = self.forward_latent2featureimage(batch, batch_idx)
61 | # Neural rendering
62 | batch = self.texture2image(batch, batch_idx)
63 | # Note that we do not mask the output. The network should learn to produce 0 values for the background.
64 | return batch
65 |
66 | def forward_latent2featureimage(self, batch, batch_idx):
67 | """
68 | Given the full latent code DIK.STYLE_LATENT, process both interior and exterior region and generate
69 | both feature images: the face feature image and the additive feature image.
70 | """
71 | n = self.opt.latent_dim // 2
72 | z_interior = batch[DIK.STYLE_LATENT][:, :n]
73 | z_exterior = batch[DIK.STYLE_LATENT][:, n:]
74 |
75 | batch[DIK.LATENT_EXTERIOR] = z_exterior
76 | batch[DIK.LATENT_INTERIOR] = z_interior
77 |
78 | batch = self.forward_latent2texture_interior(batch, batch_idx)
79 | batch = self.forward_latent2additive_featureimage(batch, batch_idx)
80 | batch = self.forward_merge_textures(batch, batch_idx)
81 | return batch
82 |
83 | def forward_latent2texture_interior(self, batch, batch_idx):
84 | batch = self.forward_decoder_interior(batch, batch_idx)
85 | # Sampling texture using the UV map.
86 | batch = self.sample_texture(batch)
87 | return batch
88 |
89 | def forward_latent2additive_featureimage(self, batch, batch_idx):
90 | batch = self.decoder_exterior(batch, batch_idx)
91 | return batch
92 |
93 | def forward_merge_textures(self, batch, batch_idx):
94 | batch[DIK.FULL_FEATUREIMAGE] = torch.cat(
95 | [batch[DIK.FACE_FEATUREIMAGE], batch[DIK.ADDITIVE_FEATUREIMAGE]], 1)
96 | return batch
97 |
98 | def forward_decoder_interior(self, batch, batch_idx):
99 | batch = self.decoder(batch, batch_idx)
100 | return batch
101 |
102 | def sample_texture(self, batch):
103 | uv_texture = batch[DIK.TEXTURE_PERSON]
104 | uv_map = batch[DIK.UV_RENDERED]
105 |
106 | texture_sampled = torch.nn.functional.grid_sample(uv_texture, uv_map, mode='bilinear',
107 | padding_mode='border', align_corners=False)
108 |
109 |         # grid_sample fills regions that were not rendered with border values,
110 |         # so we overwrite those regions with the mask value
111 | batch[DIK.MASK_UV] = torch.logical_or(batch[DIK.UV_RENDERED][:, :, :, 1] != -1,
112 | batch[DIK.UV_RENDERED][:, :, :, 0] != -1).unsqueeze(1)
113 | mask = batch[DIK.MASK_UV].expand_as(texture_sampled)
114 | texture_sampled[~mask] = self.MASK_VALUE
115 |
116 | batch[DIK.FACE_FEATUREIMAGE] = texture_sampled
117 | return batch
118 |
--------------------------------------------------------------------------------
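Illustration: a toy version of the UV lookup in `sample_texture`; pixels whose rendered UV coordinates are (-1, -1) count as background and are overwritten with the mask value:

    import torch

    texture = torch.arange(16.0).view(1, 1, 4, 4)   # toy 4x4 single-channel texture
    uv = -torch.ones(1, 2, 2, 2)                    # 2x2 output, all "not rendered"
    uv[0, 0, 0] = torch.tensor([0.0, 0.0])          # one valid lookup at the center
    sampled = torch.nn.functional.grid_sample(texture, uv, mode='bilinear',
                                              padding_mode='border', align_corners=False)
    mask = torch.logical_or(uv[..., 1] != -1, uv[..., 0] != -1).unsqueeze(1)
    sampled[~mask.expand_as(sampled)] = 0           # MASK_VALUE
    # only the top-left output pixel keeps an interpolated texture value (7.5)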
/varitex/modules/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.models
3 | from mutil.pytorch_utils import ImageNetNormalizeTransformInverse
4 | from torch.nn import functional as F
5 |
6 |
7 | def kl_divergence(mu, std):
8 | p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
9 | q = torch.distributions.Normal(mu, std)
10 | kl = torch.distributions.kl.kl_divergence(p, q).mean(1)
11 | return kl
12 |
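# Example (editor's sketch, not part of the original file): with the standard-normal
# prior p = N(0, I) and posterior q = N(mu, std), this returns one KL value per batch
# element, e.g. kl_divergence(torch.zeros(4, 8), torch.ones(4, 8)) is a zero tensor of
# shape (4,). Note the argument order: torch's kl_divergence(p, q) computes KL(p || q),
# the reverse of the usual VAE regularizer KL(q || p); the two coincide when q equals
# the prior.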
13 |
14 | def reconstruction_loss(image_fake, image_real):
15 | # Simple L1 for now
16 | return torch.mean(torch.abs(image_fake - image_real))
17 |
18 |
19 | def l2_loss(image_fake, image_real):
20 | return torch.mean((image_fake - image_real) ** 2)
21 |
22 |
23 | class VGG16(torch.nn.Module):
24 | def __init__(self, path_weights):
25 | super().__init__()
26 | vgg16 = torchvision.models.vgg16(pretrained=False)
27 | import h5py
28 | with h5py.File(path_weights, 'r') as f:
29 | state_dict = self.get_keras_mapping(f)
30 | vgg16.load_state_dict(state_dict, strict=False) # Ignore the missing keys for the classifier
31 | vgg_pretrained_features = vgg16.features
32 |
33 | # 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
34 | self.slice1 = torch.nn.Sequential()
35 | self.slice2 = torch.nn.Sequential()
36 | self.slice3 = torch.nn.Sequential()
37 | self.slice4 = torch.nn.Sequential()
38 |
39 | self.slice1.add_module(str(0), vgg_pretrained_features[0])
40 | for i in range(1, 3):
41 | self.slice2.add_module(str(i), vgg_pretrained_features[i])
42 | for i in range(3, 22):
43 | self.slice3.add_module(str(i), vgg_pretrained_features[i])
44 | for i in range(22, 31):
45 | self.slice4.add_module(str(i), vgg_pretrained_features[i])
46 |
47 | def get_keras_mapping(self, f):
48 | # VGG 16 config D
49 | state_dict = {
50 | "features.0.weight": f["conv1_1"]["conv1_1_1"]["kernel:0"],
51 | "features.0.bias": f["conv1_1"]["conv1_1_1"]["bias:0"],
52 |
53 | "features.2.weight": f["conv1_2"]["conv1_2_1"]["kernel:0"],
54 | "features.2.bias": f["conv1_2"]["conv1_2_1"]["bias:0"],
55 |
56 | "features.5.weight": f["conv2_1"]["conv2_1_1"]["kernel:0"],
57 | "features.5.bias": f["conv2_1"]["conv2_1_1"]["bias:0"],
58 |
59 | "features.7.weight": f["conv2_2"]["conv2_2_1"]["kernel:0"],
60 | "features.7.bias": f["conv2_2"]["conv2_2_1"]["bias:0"],
61 |
62 | "features.10.weight": f["conv3_1"]["conv3_1_1"]["kernel:0"],
63 | "features.10.bias": f["conv3_1"]["conv3_1_1"]["bias:0"],
64 |
65 | "features.12.weight": f["conv3_2"]["conv3_2_1"]["kernel:0"],
66 | "features.12.bias": f["conv3_2"]["conv3_2_1"]["bias:0"],
67 |
68 | "features.14.weight": f["conv3_3"]["conv3_3_1"]["kernel:0"],
69 | "features.14.bias": f["conv3_3"]["conv3_3_1"]["bias:0"],
70 |
71 | "features.17.weight": f["conv4_1"]["conv4_1_1"]["kernel:0"],
72 | "features.17.bias": f["conv4_1"]["conv4_1_1"]["bias:0"],
73 |
74 | "features.19.weight": f["conv4_2"]["conv4_2_1"]["kernel:0"],
75 | "features.19.bias": f["conv4_2"]["conv4_2_1"]["bias:0"],
76 |
77 | "features.21.weight": f["conv4_3"]["conv4_3_1"]["kernel:0"],
78 | "features.21.bias": f["conv4_3"]["conv4_3_1"]["bias:0"],
79 |
80 | "features.24.weight": f["conv5_1"]["conv5_1_1"]["kernel:0"],
81 | "features.24.bias": f["conv5_1"]["conv5_1_1"]["bias:0"],
82 |
83 | "features.26.weight": f["conv5_2"]["conv5_2_1"]["kernel:0"],
84 | "features.26.bias": f["conv5_2"]["conv5_2_1"]["bias:0"],
85 |
86 | "features.28.weight": f["conv5_3"]["conv5_3_1"]["kernel:0"],
87 | "features.28.bias": f["conv5_3"]["conv5_3_1"]["bias:0"],
88 | }
89 | # keras: [3, 3, 3, 64])
90 | # pytorch: torch.Size([64, 3, 3, 3]).
91 | # Keras stores weights in the order (kernel_size, kernel_size, input_dim, output_dim),
92 | # but pytorch expects (output_dim, input_dim, kernel_size, kernel_size)
93 | # We need to transpose them https://discuss.pytorch.org/t/how-to-convert-keras-model-to-pytorch-and-run-inference-in-c-correctly/93451/3
94 | state_dict = {k: torch.Tensor(v[:].transpose()) for k, v in state_dict.items()}
95 | return state_dict
96 |
97 | def forward(self, x):
98 | x = F.interpolate(x, size=224, mode="bilinear")
99 | h_conv1_1 = self.slice1(x)
100 | h_conv1_2 = self.slice2(h_conv1_1)
101 | h_conv3_2 = self.slice3(h_conv1_2)
102 | h_conv4_2 = self.slice4(h_conv3_2)
103 | out = [h_conv1_1, h_conv1_2, h_conv3_2, h_conv4_2]
104 | return out
105 |
106 |
107 | class VGG19(torch.nn.Module):
108 | def __init__(self):
109 | super().__init__()
110 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
111 | self.slice1 = torch.nn.Sequential()
112 | self.slice2 = torch.nn.Sequential()
113 | self.slice3 = torch.nn.Sequential()
114 | self.slice4 = torch.nn.Sequential()
115 | self.slice5 = torch.nn.Sequential()
116 | for x in range(2):
117 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
118 | for x in range(2, 7):
119 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
120 | for x in range(7, 12):
121 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
122 | for x in range(12, 21):
123 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
124 | for x in range(21, 30):
125 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
126 |
127 | def forward(self, X):
128 | h_relu1 = self.slice1(X)
129 | h_relu2 = self.slice2(h_relu1)
130 | h_relu3 = self.slice3(h_relu2)
131 | h_relu4 = self.slice4(h_relu3)
132 | h_relu5 = self.slice5(h_relu4)
133 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
134 | return out
135 |
136 |
137 | # Perceptual loss that uses a VGG19 network pretrained on ImageNet
138 | class ImageNetVGG19Loss(torch.nn.Module):
139 | def __init__(self):
140 | super(ImageNetVGG19Loss, self).__init__()
141 | self.vgg = VGG19()
142 |         # The VGG parameters are frozen below; device placement is
143 |         # handled by the parent module.
144 | self.criterion = torch.nn.L1Loss()
145 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
146 |
147 | for param in self.parameters():
148 | param.requires_grad = False
149 |
150 | def forward(self, fake, real):
151 | x_vgg, y_vgg = self.vgg(fake), self.vgg(real)
152 | loss = 0
153 | for i in range(len(x_vgg)):
154 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
155 |             # Deeper feature maps receive larger weights; the targets are detached
156 |             # so gradients flow only through the generated image.
157 | return loss
158 |
159 |
160 | # Identity-preserving perceptual loss that uses a VGG16 network pretrained for face recognition (VGGFace)
161 | class FaceRecognitionVGG16Loss(torch.nn.Module):
162 | def __init__(self, path_weights):
163 | super().__init__()
164 | self.vgg = VGG16(path_weights)
165 | self.criterion = torch.nn.MSELoss()
166 | self.weights = [0.25, 0.25, 0.25, 0.25]
167 |
168 | self.unnormalize = ImageNetNormalizeTransformInverse()
169 |         # Per-channel mean pixel values from the VGGFace paper (applied in BGR order in the reference implementation)
170 |         self.normalize_values = torch.Tensor((93.5940, 104.7624, 129.1863)).view(1, 3, 1, 1)
171 |
172 | for param in self.parameters():
173 | param.requires_grad = False
174 |
175 |     def preprocess(self, tensor):
176 |         tensor = self.unnormalize(tensor)  # now roughly in [0, 1] (up to over-/undershoot)
177 |         self.normalize_values = self.normalize_values.to(tensor.device)
178 |         tensor = (tensor * 255) - self.normalize_values.expand_as(tensor)  # scale to [0, 255], subtract the VGGFace mean
179 |         # The original VGGFace preprocessing only subtracts the mean; the extra
180 |         # division keeps the network inputs in a small range around zero.
181 |         return tensor / 127.5
182 |
183 | def forward(self, fake, real):
184 | fake = self.preprocess(fake)
185 | real = self.preprocess(real)
186 | x_vgg, y_vgg = self.vgg(fake), self.vgg(real)
187 |
188 | loss = 0
189 | for i in range(len(x_vgg)):
190 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
191 |             # Equal weights across the four feature levels; the targets are detached
192 |             # so gradients flow only through the generated image.
193 | return loss
194 |
195 |
196 | """
197 | Code below:
198 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
199 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
200 | """
201 |
202 |
203 | # Defines the GAN loss, supporting LSGAN, the regular (cross-entropy) GAN,
204 | # hinge, and Wasserstein objectives. With LSGAN it is basically the same as
205 | # MSELoss, but it abstracts away the need to create a target label tensor
206 | # with the same size as the input (see the sketch at the end of this file).
207 | class GANLoss(torch.nn.Module):
208 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
209 | tensor=torch.FloatTensor, opt=None):
210 | super(GANLoss, self).__init__()
211 | self.real_label = target_real_label
212 | self.fake_label = target_fake_label
213 | self.real_label_tensor = None
214 | self.fake_label_tensor = None
215 | self.zero_tensor = None
216 | self.Tensor = tensor
217 | self.gan_mode = gan_mode
218 | self.opt = opt
219 | if gan_mode == 'ls':
220 | pass
221 | elif gan_mode == 'original':
222 | pass
223 | elif gan_mode == 'w':
224 | pass
225 | elif gan_mode == 'hinge':
226 | pass
227 | else:
228 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
229 |
230 | def get_target_tensor(self, input, target_is_real):
231 | if target_is_real:
232 | if self.real_label_tensor is None:
233 | self.real_label_tensor = self.Tensor(1).to(input.device).fill_(self.real_label)
234 | self.real_label_tensor.requires_grad_(False)
235 | return self.real_label_tensor.expand_as(input)
236 | else:
237 | if self.fake_label_tensor is None:
238 | self.fake_label_tensor = self.Tensor(1).to(input.device).fill_(self.fake_label)
239 | self.fake_label_tensor.requires_grad_(False)
240 | return self.fake_label_tensor.expand_as(input)
241 |
242 | def get_zero_tensor(self, input):
243 | if self.zero_tensor is None:
244 | self.zero_tensor = self.Tensor(1).fill_(0)
245 | self.zero_tensor.requires_grad_(False)
246 | return self.zero_tensor.expand_as(input).to(input.device)
247 |
248 | def loss(self, input, target_is_real, for_discriminator=True):
249 | if self.gan_mode == 'original': # cross entropy loss
250 | target_tensor = self.get_target_tensor(input, target_is_real)
251 | loss = F.binary_cross_entropy_with_logits(input, target_tensor)
252 | return loss
253 | elif self.gan_mode == 'ls':
254 | target_tensor = self.get_target_tensor(input, target_is_real)
255 | return F.mse_loss(input, target_tensor)
256 | elif self.gan_mode == 'hinge':
257 | if for_discriminator:
258 | if target_is_real:
259 | minval = torch.min(input - 1, self.get_zero_tensor(input))
260 | loss = -torch.mean(minval)
261 | else:
262 | minval = torch.min(-input - 1, self.get_zero_tensor(input))
263 | loss = -torch.mean(minval)
264 | else:
265 | assert target_is_real, "The generator's hinge loss must be aiming for real"
266 | loss = -torch.mean(input)
267 | return loss
268 | else:
269 | # wgan
270 | if target_is_real:
271 | return -input.mean()
272 | else:
273 | return input.mean()
274 |
275 | def __call__(self, input, target_is_real, for_discriminator=True):
276 |         # `input` may be a single tensor or, for a multiscale discriminator,
277 |         # a list of (lists of) tensors, so both cases are handled here.
278 | if isinstance(input, list):
279 | loss = 0
280 | for pred_i in input:
281 | if isinstance(pred_i, list):
282 | pred_i = pred_i[-1]
283 | loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
284 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
285 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
286 | loss += new_loss
287 | return loss / len(input)
288 | else:
289 | return self.loss(input, target_is_real, for_discriminator)
290 |
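291 |
292 | # Hypothetical usage sketch (added for illustration; not part of the original file).
293 | # It checks the Keras-to-PyTorch weight layout described in VGG16.get_keras_mapping
294 | # and exercises GANLoss in hinge mode on dummy discriminator logits.
295 | if __name__ == "__main__":
296 |     import numpy as np
297 |
298 |     # Reversing all axes turns a Keras kernel (kh, kw, in, out) into PyTorch's (out, in, kw, kh).
299 |     assert np.zeros((3, 3, 3, 64)).transpose().shape == (64, 3, 3, 3)
300 |
301 |     gan_loss = GANLoss('hinge')
302 |     logits = torch.randn(4, 1, 8, 8)  # dummy discriminator output
303 |     d_real = gan_loss(logits, True, for_discriminator=True)   # mean(relu(1 - logits))
304 |     d_fake = gan_loss(logits, False, for_discriminator=True)  # mean(relu(1 + logits))
305 |     g_loss = gan_loss(logits, True, for_discriminator=False)  # -mean(logits)
306 |     print(d_real.item(), d_fake.item(), g_loss.item())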
--------------------------------------------------------------------------------
/varitex/modules/metrics.py:
--------------------------------------------------------------------------------
1 | import lpips
2 | import torch
3 | from mutil.pytorch_utils import ImageNetNormalizeTransformInverse
4 | from torchmetrics.functional.image import peak_signal_noise_ratio as psnr
5 | from torchmetrics.functional.image import structural_similarity_index_measure as ssim
6 |
7 |
8 | class AbstractMetric(torch.nn.Module):
9 | scale = 255
10 |
11 | def __init__(self):
12 | super().__init__()
13 | self.unnormalize = ImageNetNormalizeTransformInverse(scale=self.scale)
14 |
15 |
16 | class PSNR(AbstractMetric):
17 | scale = 255
18 |
19 | def forward(self, preds, target):
20 | preds = self.unnormalize(preds)
21 | target = self.unnormalize(target)
22 | return psnr(preds, target, data_range=self.scale)
23 |
24 |
25 | class SSIM(AbstractMetric):
26 | scale = 255
27 |
28 | def forward(self, preds, target):
29 | preds = self.unnormalize(preds)
30 | target = self.unnormalize(target)
31 | return ssim(preds, target, data_range=self.scale)
32 |
33 |
34 | class LPIPS(AbstractMetric):
35 | scale = 1
36 |
37 | def __init__(self):
38 | super().__init__()
39 | self.metric = lpips.LPIPS(net='alex', verbose=False)
40 |
41 | def forward(self, preds, target):
42 | preds = self.unnormalize(preds)
43 | target = self.unnormalize(target)
44 |         # lpips.LPIPS returns one value per sample; .mean() reduces it to a scalar
45 | return self.metric(preds, target, normalize=True).mean() # With normalize, lpips expects the range [0, 1]
46 |
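47 |
48 | # Hypothetical usage sketch (added for illustration; not part of the original file).
49 | # All three metrics expect ImageNet-normalized tensors and un-normalize internally;
50 | # LPIPS downloads pretrained AlexNet weights on first use.
51 | if __name__ == "__main__":
52 |     preds = torch.randn(2, 3, 64, 64)  # stand-ins for normalized network outputs
53 |     target = torch.randn(2, 3, 64, 64)
54 |     print("PSNR:", PSNR()(preds, target).item())
55 |     print("SSIM:", SSIM()(preds, target).item())
56 |     print("LPIPS:", LPIPS()(preds, target).item())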
--------------------------------------------------------------------------------
/varitex/modules/pipeline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import BCELoss
3 |
4 | from varitex.data.keys_enum import DataItemKey as DIK
5 | from varitex.modules.discriminator import MultiscaleDiscriminator
6 | from varitex.modules.generator import Generator
7 | from varitex.modules.custom_module import CustomModule
8 | from varitex.modules.loss import ImageNetVGG19Loss, kl_divergence, l2_loss, GANLoss
9 | from varitex.modules.metrics import PSNR, SSIM, LPIPS
10 |
11 |
12 | class PipelineModule(CustomModule):
13 |
14 | def __init__(self, opt):
15 | super().__init__(opt)
16 |
17 | # Includes the full generation pipeline: encoder, texture decoder, additive decoder, neural renderer
18 | self.generator = Generator(opt)
19 |
20 | if getattr(self.opt, "lambda_gan", 0) > 0:
21 | self.discriminator = MultiscaleDiscriminator(opt)
22 | self.criterion_gan = GANLoss(opt.gan_mode)
23 | self.criterion_discriminator_features = torch.nn.L1Loss()
24 | if getattr(self.opt, "lambda_vgg", 0) > 0:
25 | self.loss_vgg = ImageNetVGG19Loss()
26 | if getattr(self.opt, "lambda_segmentation", 0) > 0:
27 | self.criterion_segmentation = BCELoss()
28 |
29 | self.metric_psnr = PSNR()
30 | self.metric_ssim = SSIM()
31 | self.metric_lpips = LPIPS()
32 |
33 | def to_device(self, o, device='cuda'):
34 | if isinstance(o, list):
35 | o = [self.to_device(o_i, device) for o_i in o]
36 | elif isinstance(o, dict):
37 | o = {k: self.to_device(v, device) for k, v in o.items()}
38 | elif isinstance(o, torch.Tensor):
39 | o = o.to(device)
40 | return o
41 |
42 | def forward(self, batch, batch_idx, std_multiplier=1):
43 | batch = self.to_device(batch, self.opt.device)
44 | batch = self.generator(batch, batch_idx, std_multiplier)
45 | return batch
46 |
47 | def training_step(self, batch, batch_idx, optimizer_idx=0):
48 | if optimizer_idx == 0:
49 | loss = self._generator_step(batch, batch_idx)
50 | elif optimizer_idx == 1:
51 | loss = self._discriminator_step(batch, batch_idx)
52 | else:
53 |             raise ValueError("Invalid optimizer index: {}".format(optimizer_idx))
54 | return loss
55 |
56 | def validation_step(self, batch, batch_idx, std_multiplier=1):
57 | batch = self.forward(batch, batch_idx, std_multiplier=std_multiplier)
58 | fake = batch[DIK.IMAGE_OUT]
59 | real = batch[DIK.IMAGE_IN]
60 |
61 | psnr = self.metric_psnr(fake, real)
62 | ssim = self.metric_ssim(fake, real)
63 | lpips = self.metric_lpips(fake, real)
64 |
65 | self.log_dict({
66 | "val/psnr": psnr,
67 | "val/ssim": ssim,
68 | "val/lpips": lpips
69 | })
70 |
71 |     # The methods below simply forward calls to the generator.
72 | def forward_encode(self, batch, batch_idx):
73 | return self.generator.forward_encode(batch, batch_idx)
74 |
75 | def forward_sample_style(self, *args, **kwargs):
76 | return self.generator.forward_sample_style(*args, **kwargs)
77 |
78 | def forward_latent2texture(self, batch, batch_idx):
79 | return self.generator.forward_latent2featureimage(batch, batch_idx)
80 |
81 | def forward_texture2image(self, *args, **kwargs):
82 | return self.generator.texture2image(*args, **kwargs)
83 |
84 | def forward_latent2image(self, *args, **kwargs):
85 | return self.generator.forward_latent2image(*args, **kwargs)
86 |
87 | def forward_interior2image(self, batch, batch_idx):
88 | batch = self.generator.sample_texture(batch)
89 | batch = self.generator.forward_latent2additive_featureimage(batch, batch_idx)
90 | batch = self.generator.forward_merge_textures(batch, batch_idx)
91 | batch = self.forward_texture2image(batch, batch_idx)
92 | return batch
93 |
94 | def configure_optimizers(self):
95 | # We use one optimizer for the generator and one for the discriminator.
96 | optimizers = list()
97 |         # Important: the generator optimizer must have index 0 (see training_step)
98 | optimizers.append(torch.optim.Adam(self.generator.parameters(), lr=self.opt.lr))
99 | if getattr(self.opt, "lambda_gan", 0) > 0:
100 |             # The discriminator optimizer must have index 1
101 | optimizers.append(torch.optim.Adam(self.discriminator.parameters(),
102 | lr=self.opt.lr_discriminator))
103 | return optimizers, []
104 |
105 | def _generator_step(self, batch, batch_idx):
106 | batch = self.forward(batch, batch_idx)
107 |
108 | loss_gan = 0
109 | loss_gan_features = 0
110 | loss_l2 = 0
111 | loss_vgg = 0
112 | loss_segmentation = 0
113 | loss_rgb_texture = 0
114 |
115 | image_out = batch[DIK.IMAGE_OUT]
116 | image_in = batch[DIK.IMAGE_IN]
117 |
118 | loss_kl = kl_divergence(batch[DIK.STYLE_LATENT_MU], batch[DIK.STYLE_LATENT_STD]).mean()
119 |
120 | if getattr(self.opt, "lambda_gan", 0) > 0:
121 | pred_fake, pred_real = self._forward_discriminate(image_out, image_in)
122 |
123 | loss_gan = self.criterion_gan(pred_fake, True,
124 | for_discriminator=False)
125 |
126 | # Feature loss
127 | num_D = len(pred_fake)
128 |             GAN_Feat_loss = torch.zeros(1, device=image_out.device)
129 | for i in range(num_D): # for each discriminator
130 | # last output is the final prediction, so we exclude it
131 | num_intermediate_outputs = len(pred_fake[i]) - 1
132 | for j in range(num_intermediate_outputs): # for each layer output
133 | unweighted_loss = self.criterion_discriminator_features(
134 | pred_fake[i][j], pred_real[i][j].detach())
135 | GAN_Feat_loss += unweighted_loss
136 | loss_gan_features = GAN_Feat_loss / num_D
137 |
138 | if self.opt.lambda_l2 > 0:
139 | loss_l2 = l2_loss(image_out, image_in)
140 |
141 | if self.opt.lambda_vgg > 0:
142 | loss_vgg = self.loss_vgg(image_out.clone(), image_in.clone())
143 |
144 | if self.opt.lambda_segmentation > 0:
145 | loss_segmentation = self.criterion_segmentation(batch[DIK.SEGMENTATION_PREDICTED],
146 | batch[DIK.SEGMENTATION_MASK])
147 |
148 | if self.opt.lambda_rgb_texture > 0:
149 |             texture = batch[DIK.FACE_FEATUREIMAGE].clone()[:, :3]  # the first three channels hold RGB
150 | masked_image_in = image_in.clone()
151 | masked_image_in *= batch[DIK.MASK_UV].expand_as(image_in)
152 | texture *= batch[DIK.MASK_UV].expand_as(texture)
153 | loss_rgb_texture = l2_loss(texture, masked_image_in)
154 |
155 | loss_unweighted = loss_gan + loss_gan_features + loss_l2 + loss_kl + loss_vgg + loss_segmentation + loss_rgb_texture
156 |
157 | loss = self.opt.lambda_gan * loss_gan + \
158 | self.opt.lambda_discriminator_features * loss_gan_features + \
159 | self.opt.lambda_l2 * loss_l2 + \
160 | self.opt.lambda_kl * loss_kl + \
161 | self.opt.lambda_vgg * loss_vgg + \
162 | self.opt.lambda_segmentation * loss_segmentation + \
163 | self.opt.lambda_rgb_texture * loss_rgb_texture
164 |
165 | data_log = {
166 | "train/generator": loss_gan,
167 | "train/gan_features": loss_gan_features,
168 | "train/reconstruction_l2": loss_l2,
169 | "train/kl": loss_kl,
170 | "train/vgg_l1": loss_vgg,
171 | "train/segmentation": loss_segmentation,
172 | "train/rgb_texture": loss_rgb_texture,
173 | "train/loss": loss_unweighted
174 | }
175 | # Filter out zero losses
176 | data_log = {k: v.clone().detach() for k, v in data_log.items() if v != 0}
177 | self.log_dict(data_log)
178 | return loss
179 |
180 | def _discriminator_step(self, batch, batch_idx):
181 | image_real = batch[DIK.IMAGE_IN]
182 | with torch.no_grad():
183 | batch = self.forward(batch, batch_idx)
184 | image_fake = batch[DIK.IMAGE_OUT].detach()
185 | image_fake.requires_grad = True
186 |
187 | fake_pred, real_pred = self._forward_discriminate(image_fake, image_real)
188 |
189 | real_loss = self.criterion_gan(real_pred, True,
190 | for_discriminator=True)
191 | fake_loss = self.criterion_gan(fake_pred, False,
192 | for_discriminator=True)
193 | loss_unweighted = (real_loss + fake_loss) / 2
194 |
195 | loss = self.opt.lambda_gan * loss_unweighted
196 |
197 | self.log_dict({
198 | "train/discriminator": loss_unweighted
199 | })
200 | return loss
201 |
202 | def _forward_discriminate(self, fake_image, real_image):
203 | # This method is from SPADE: https://github.com/NVlabs/SPADE
204 |         # With batch normalization, fake and real images should share a
205 |         # single batch so that the discriminator computes its batch
206 |         # statistics over both; otherwise fake and real statistics diverge.
207 |         # We therefore feed both to D at once.
208 | fake_and_real = torch.cat([fake_image, real_image], dim=0)
209 |         discriminator_out = self.discriminator(fake_and_real)  # list with one entry per discriminator scale
210 |
211 | # Take the prediction of fake and real images from the combined batch
212 | def divide_pred(pred):
213 | # the prediction contains the intermediate outputs of multiscale GAN,
214 | # so it's usually a list
215 |             if isinstance(pred, list):
216 | fake = []
217 | real = []
218 | for p in pred:
219 | fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
220 | real.append([tensor[tensor.size(0) // 2:] for tensor in p])
221 | else:
222 | fake = pred[:pred.size(0) // 2]
223 | real = pred[pred.size(0) // 2:]
224 |
225 | return fake, real
226 |
227 | pred_fake, pred_real = divide_pred(discriminator_out)
228 |
229 | return pred_fake, pred_real
230 |
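231 |
232 | # Hypothetical usage sketch (added for illustration; not part of the original file).
233 | # It demonstrates the fake/real batching trick from _forward_discriminate on dummy
234 | # tensors, without instantiating the full pipeline.
235 | if __name__ == "__main__":
236 |     fake = torch.zeros(2, 3, 8, 8)
237 |     real = torch.ones(2, 3, 8, 8)
238 |     both = torch.cat([fake, real], dim=0)          # one batch of 4 through D
239 |     pred = both.mean(dim=(1, 2, 3), keepdim=True)  # stand-in for a discriminator score
240 |     pred_fake = pred[:pred.size(0) // 2]           # first half: predictions on fakes
241 |     pred_real = pred[pred.size(0) // 2:]           # second half: predictions on reals
242 |     assert pred_fake.mean().item() == 0.0 and pred_real.mean().item() == 1.0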
--------------------------------------------------------------------------------
/varitex/modules/unet.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is from pl_bolts.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class UNet(nn.Module):
10 | """
11 | Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation
12 |