├── .gitignore ├── README.md ├── configs.py ├── datasets.py ├── environment.yml ├── finetune.py ├── finetune_loop.py ├── generators ├── .DS_Store ├── generators.py ├── renderers │ ├── .DS_Store │ ├── manifold_renderer.py │ └── math_utils_torch.py └── representations │ ├── .DS_Store │ └── gram.py ├── optimization.py ├── preprocess_dataset.py ├── rendering_using_finetuned_model.py ├── samples ├── cats │ ├── 00000005_001.png │ ├── 00000009_013.png │ └── poses │ │ ├── 00000005_001_pose.npy │ │ └── 00000009_013_pose.npy └── faces │ ├── 000656.png │ ├── 000990.png │ ├── 097665.png │ ├── R5.png │ ├── mask256 │ ├── 000656.png │ ├── 000990.png │ ├── 097665.png │ └── R5.png │ └── poses │ ├── 000656.mat │ ├── 000990.mat │ ├── 097665.mat │ └── R5.mat └── utils └── arcface ├── __init__.py ├── iresnet.py └── mobilefacenet.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | pretrained_models/ 3 | experiments/ 4 | 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeRFInvertor: High Fidelity NeRF-GAN Inversion for Single-shot Real Image Animation, CVPR'23 2 | 3 | https://github.com/YuYin1/NeRFInvertor/assets/24871206/00e94eeb-efc5-4a61-8570-f2784c741685 4 | 5 | This is an official pytorch implementation of our NeRFInvertor paper: 6 | 7 | Y. Yin, K. Ghasedi, H. Wu, J. Yang, X. Tong, Y. Fu, **NeRFInvertor: High Fidelity NeRF-GAN Inversion for Single-shot Real Image Animation**, IEEE Computer Vision and Pattern Recognition (CVPR), 2023. 8 | 9 | 10 | ### [[Paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Yin_NeRFInvertor_High_Fidelity_NeRF-GAN_Inversion_for_Single-Shot_Real_Image_Animation_CVPR_2023_paper.pdf)] [[ArXiv](https://arxiv.org/abs/2211.17235)] [[Project Page](https://yuyin1.github.io/NeRFInvertor_Homepage/)] 11 | 12 | Abstract: _Nerf-based Generative models (NeRF-GANs) have shown impressive capacity in generating high-quality images with consistent 3D geometry. In this paper, we propose a universal method to surgically fine-tune these NeRF-GANs in order to achieve high-fidelity animation of real subjects only by a single image. Given the optimized latent code for an out-of-domain real image, we employ 2D loss functions on the rendered image to reduce the identity gap. Furthermore, our method leverages explicit and implicit 3D regularizations using the in-domain neighborhood samples around the optimized latent code to remove geometrical and visual artifacts._ 13 | 14 | 15 | ## Recent Updates 16 | **2023.06.01:** Inversion of [GRAM](https://github.com/microsoft/GRAM/) 17 | 18 | **TODO:** 19 | - Inversion of [EG3D](https://github.com/NVlabs/eg3d) 20 | - Inversion of [AnifaceGAN](https://yuewuhkust.github.io/AniFaceGAN/) 21 | 22 | ## Requirements 23 | - Currently only Linux is supported. 24 | - 64-bit Python 3.8 installation or newer. We recommend using Anaconda3. 25 | - One or more high-end NVIDIA GPUs, NVIDIA drivers, and CUDA toolkit 10.1 or newer. We recommend using Tesla V100 GPUs with 32 GB memory for training to reproduce the results in the paper. 
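Once the conda environment from the Installation section below is active, the following optional sanity check (a minimal sketch, not part of the official pipeline) confirms that PyTorch sees your CUDA runtime and GPUs:
```
python - <<'EOF'
import torch
# report the PyTorch build, its CUDA version, and every visible GPU
print("torch:", torch.__version__, "| cuda:", torch.version.cuda, "| available:", torch.cuda.is_available())
for i in range(torch.cuda.device_count()):
    print(f"  gpu {i}: {torch.cuda.get_device_name(i)}")
EOF
```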
26 | 27 | ## Installation 28 | Clone the repository and set up a conda environment with all dependencies as follows: 29 | ``` 30 | git clone https://github.com/YuYin1/NeRFInvertor.git 31 | cd NeRFInvertor 32 | conda env create -f environment.yml 33 | source activate nerfinvertor 34 | ``` 35 | 36 | ## Preparation 37 | We provide the various auxiliary models needed for the NeRF-GAN inversion task. This includes the NeRF-based generators and the pre-trained models used for loss computation. 38 | ### Pretrained NeRF-GANs 39 | |Model|Dataset|Resolution|Download| 40 | |:----:|:----:|:-------:|:-----------:| 41 | | GRAM | FFHQ | 256x256 | [Github link](https://github.com/microsoft/GRAM/tree/main/pretrained_models/FFHQ_default) | 42 | | GRAM | Cats | 256x256 | [Github link](https://github.com/microsoft/GRAM/tree/main/pretrained_models/CATS_default) | 43 | | EG3D | FFHQ | 256x256 | [Github link](https://github.com/NVlabs/eg3d/blob/main/docs/models.md) | 44 | | AnifaceGAN | FFHQ | 512x512 | [Github link](https://yuewuhkust.github.io/AniFaceGAN/) | 45 | | arcface |--|--| [Google Drive link](https://drive.google.com/file/d/16t4yUyLlecyYR810WgxVTZF9qtB7NpLo/view?usp=sharing) | 46 | 47 | All models are also summarized in this [Google Drive folder](https://drive.google.com/drive/folders/16DM2qXGnmzY77V9XHBv_o9Ty4BU7q5Kc?usp=sharing). 48 | 49 | ### Prepare Dataset 50 | - Sample dataset: We provide some [sample images](https://github.com/YuYin1/NeRFInvertor/tree/main/samples). 51 | ``` 52 | NeRFInvertor/ 53 | │ 54 | └─── samples/ 55 | │ 56 | └─── faces/ 57 | │ 58 | └─── *.png # original 256x256 images 59 | | 60 | └─── poses/ # estimated face poses 61 | | 62 | └─── *.mat 63 | │ 64 | └─── mask256/ # mask of faces 65 | | 66 | └─── *.png 67 | ``` 68 | - FFHQ or CelebA-HQ: We additionally provide the [FFHQ (Google Drive)](https://drive.google.com/file/d/16qt3imMo1gsAvWrO9T1GqNuCWjxxwxBd/view?usp=sharing) and [CelebA-HQ (Google Drive)](https://drive.google.com/file/d/16ch326M_9bXQB1I_NpP5Mfb7PY-KJ9fv/view?usp=sharing) datasets for training and evaluation. The datasets include face images, masks, and face poses. Note that the face poses are estimated by [Deep3DFaceRecon](https://github.com/sicxu/Deep3DFaceRecon_pytorch). The datasets have the following structure: 69 | ``` 70 | datasets/ 71 | │ 72 | └─── ffhq/ 73 | │ 74 | └─── *.png # original 256x256 images 75 | | 76 | └─── poses/ # estimated face poses 77 | | 78 | └─── *.mat 79 | │ 80 | └─── mask256/ # mask of faces 81 | | 82 | └─── *.png 83 | │ 84 | └─── celebahq/ 85 | ... 86 | ``` 87 | 88 | ### Pretrained NeRFInvertor for sample images 89 | We provide [pretrained NeRFInvertor](https://drive.google.com/drive/folders/16RntgRTD09iqWtIWewrISrOYqdhY05ja?usp=sharing) (i.e., fine-tuned models) for each of the [samples](https://github.com/YuYin1/NeRFInvertor/tree/main/samples). The folder includes optimized latent codes, fine-tuned models, and inference results (i.e., rendering outputs). 90 | 91 | 92 | ## Inversion 93 | ### Optimize latent codes 94 | In order to invert a real image and edit it, you should first align and crop it to the correct size.
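If your own photo is not already an aligned 256x256 crop, the sketch below illustrates only the final center-crop and resize (the provided GRAM models expect 256x256 inputs); it does not perform the actual face alignment or pose estimation, which preprocess_dataset.py and [Deep3DFaceRecon](https://github.com/sicxu/Deep3DFaceRecon_pytorch) presumably handle, and my_photo.png is a placeholder path:
```
from PIL import Image

img = Image.open("my_photo.png").convert("RGB")        # placeholder input image
side = min(img.size)                                   # size of the largest centered square
left, top = (img.width - side) // 2, (img.height - side) // 2
img = img.crop((left, top, left + side, top + side))   # center crop to a square
img = img.resize((256, 256), Image.LANCZOS)            # resize to the 256x256 the models expect
img.save("samples/faces/my_photo.png")                 # place it next to the provided samples
```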
95 | Use --name=image_name.png to invert a specific image; otherwise, the following command will invert all images in img_dir: 96 | ``` 97 | python optimization.py \ 98 | --generator_file='pretrained_models/gram/FFHQ_default/generator.pth' \ 99 | --output_dir='experiments/gram/optimization' \ 100 | --data_img_dir='samples/faces/' \ 101 | --data_pose_dir='samples/faces/poses/' \ 102 | --config='FACES_default' \ 103 | --max_iter=1000 104 | ``` 105 | 106 | ### Finetune NeRF-GANs 107 | ``` 108 | CUDA_VISIBLE_DEVICES=0,1 python finetune.py \ 109 | --target_names='R1.png+R2.png' \ 110 | --config='FACES_finetune' \ 111 | --output_dir='experiments/gram/finetuned_model/' \ 112 | --data_img_dir='samples/faces/' \ 113 | --data_pose_dir='samples/faces/poses/' \ 114 | --data_emd_dir='experiments/gram/optimization/' \ 115 | --pretrain_model='pretrained_models/gram/FFHQ_default/generator.pth' \ 116 | --load_mask \ 117 | --regulizer_alpha=5 \ 118 | --lambda_id=0.1 \ 119 | --lambda_reg_rgbBefAggregation 10 \ 120 | --lambda_bg_sigma 10 121 | ``` 122 | 123 | ## Inference 124 | ### Rendering results for finetuned models 125 | ``` 126 | CUDA_VISIBLE_DEVICES=0 python rendering_using_finetuned_model.py \ 127 | --generator_file='experiments/gram/finetuned_model/000990/generator.pth' \ 128 | --target_name='000990' \ 129 | --output_dir='experiments/gram/rendering_results/' \ 130 | --data_img_dir='samples/faces/' \ 131 | --data_pose_dir='samples/faces/poses/' \ 132 | --data_emd_dir='experiments/gram/optimization/' \ 133 | --config='FACES_finetune' \ 134 | --image_size 256 \ 135 | --gen_video 136 | ``` 137 | 138 | 148 | 149 | 150 | ## Acknowledgements 151 | This repository's structure is based on the [GRAM](https://github.com/microsoft/GRAM/) and [PTI](https://github.com/danielroich/PTI) repositories. We thank the authors for their excellent work. 152 | 153 | ## Contact 154 | If you have any questions, please contact Yu Yin (yin.yu1@northeastern.edu). 155 | 156 | ## Citation 157 | @inproceedings{yin2023nerfinvertor, 158 | title={NeRFInvertor: High Fidelity NeRF-GAN Inversion for Single-shot Real Image Animation}, 159 | author={Yin, Yu and Ghasedi, Kamran and Wu, HsiangTao and Yang, Jiaolong and Tong, Xin and Fu, Yun}, 160 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 161 | pages={8539--8548}, 162 | year={2023} 163 | } 164 | 165 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | FACES_finetune = { 4 | 'global': { 5 | 'img_size': 256, 6 | 'batch_size': 1, # batchsize per GPU. 
We use 8 GPUs by default so that the effective batchsize for an iteration is 4*8=32 7 | 'z_dist': 'gaussian', 8 | }, 9 | 'optimizer': { 10 | 'gen_lr': 2e-5, 11 | 'disc_lr': 2e-4, 12 | 'betas': (0, 0.9), 13 | 'grad_clip': 1., 14 | }, 15 | 'process': { 16 | 'class': 'Gan3DProcess', 17 | 'kwargs': { 18 | 'batch_split': 4, 19 | 'real_pos_lambda': 15., 20 | 'r1_lambda': 1., 21 | 'pos_lambda': 15., 22 | } 23 | }, 24 | 'generator': { 25 | 'class': 'GramGenerator', 26 | 'kwargs': { 27 | 'z_dim': 256, 28 | 'img_size': 256, 29 | 'h_stddev': 0.3, 30 | 'v_stddev': 0.155, 31 | 'h_mean': math.pi*0.5, 32 | 'v_mean': math.pi*0.5, 33 | 'sample_dist': 'gaussian', 34 | }, 35 | 'representation': { 36 | 'class': 'gram', 37 | 'kwargs': { 38 | 'hidden_dim': 256, 39 | 'sigma_clamp_mode': 'softplus', 40 | 'rgb_clamp_mode': 'widen_sigmoid', 41 | 'hidden_dim_sample': 128, 42 | 'layer_num_sample': 3, 43 | 'center': (0, 0, -1.5), 44 | 'init_radius': 0, 45 | }, 46 | }, 47 | 'renderer': { 48 | 'class': 'manifold_renderer', 49 | 'kwargs': { 50 | 'num_samples': 64, 51 | 'num_manifolds': 24, 52 | 'levels_start': 23, 53 | 'levels_end': 8, 54 | 'last_back': False, 55 | 'white_back': False, 56 | 'background': True, 57 | } 58 | } 59 | }, 60 | 'dataset': { 61 | 'class': 'FACES_finetune', 62 | 'kwargs': { 63 | 'img_size': 256, 64 | 'real_pose': True, 65 | } 66 | }, 67 | 'camera': { 68 | 'fov': 12, 69 | 'ray_start': 0.88, 70 | 'ray_end': 1.12, 71 | } 72 | } 73 | 74 | 75 | CATS_finetune = { 76 | 'global': { 77 | 'img_size': 256, 78 | 'batch_size': 1, 79 | 'z_dist': 'gaussian', 80 | }, 81 | 'optimizer': { 82 | 'gen_lr': 2e-5, 83 | 'disc_lr': 2e-4, 84 | 'betas': (0, 0.9), 85 | 'grad_clip': 1., 86 | }, 87 | 'process': { 88 | 'class': 'Gan3DProcess', 89 | 'kwargs': { 90 | 'batch_split': 2, 91 | 'real_pos_lambda': 30., 92 | 'r1_lambda': 1., 93 | 'pos_lambda': 15., 94 | } 95 | }, 96 | 'generator': { 97 | 'class': 'GramGenerator', 98 | 'kwargs': { 99 | 'z_dim': 256, 100 | 'img_size': 256, 101 | 'h_stddev': 0.3, 102 | 'v_stddev': 0.155, 103 | 'h_mean': math.pi*0.5, 104 | 'v_mean': math.pi*0.5, 105 | 'sample_dist': 'gaussian', 106 | }, 107 | 'representation': { 108 | 'class': 'gram', 109 | 'kwargs': { 110 | 'hidden_dim': 256, 111 | 'sigma_clamp_mode': 'softplus', 112 | 'rgb_clamp_mode': 'widen_sigmoid', 113 | 'hidden_dim_sample': 64, 114 | 'layer_num_sample': 3, 115 | 'center': (0, 0, -1.5), 116 | 'init_radius': 0, 117 | }, 118 | }, 119 | 'renderer': { 120 | 'class': 'manifold_renderer', 121 | 'kwargs': { 122 | 'num_samples': 64, 123 | 'num_manifolds': 24, 124 | 'levels_start': 23, 125 | 'levels_end': 8, 126 | 'last_back': False, 127 | 'white_back': False, 128 | 'background': True, 129 | } 130 | } 131 | }, 132 | 'discriminator': { 133 | 'class': 'GramDiscriminator', 134 | 'kwargs': { 135 | 'img_size': 256, 136 | } 137 | }, 138 | 'dataset': { 139 | 'class': 'CATS_finetune', 140 | 'kwargs': { 141 | 'img_size': 256, 142 | 'real_pose': True, 143 | } 144 | }, 145 | 'camera': { 146 | 'fov': 12, 147 | 'ray_start': 0.88, 148 | 'ray_end': 1.12, 149 | } 150 | } 151 | 152 | 153 | FACES_default = { 154 | 'global': { 155 | 'img_size': 256, 156 | 'batch_size': 1, # batchsize per GPU. 
We use 8 GPUs by default so that the effective batchsize for an iteration is 4*8=32 157 | 'z_dist': 'gaussian', 158 | }, 159 | 'optimizer': { 160 | 'gen_lr': 2e-5, 161 | 'disc_lr': 2e-4, 162 | 'betas': (0, 0.9), 163 | 'grad_clip': 1., 164 | }, 165 | 'process': { 166 | 'class': 'Gan3DProcess', 167 | 'kwargs': { 168 | 'batch_split': 4, 169 | 'real_pos_lambda': 15., 170 | 'r1_lambda': 1., 171 | 'pos_lambda': 15., 172 | } 173 | }, 174 | 'generator': { 175 | 'class': 'GramGenerator', 176 | 'kwargs': { 177 | 'z_dim': 256, 178 | 'img_size': 256, 179 | 'h_stddev': 0.3, 180 | 'v_stddev': 0.155, 181 | 'h_mean': math.pi*0.5, 182 | 'v_mean': math.pi*0.5, 183 | 'sample_dist': 'gaussian', 184 | }, 185 | 'representation': { 186 | 'class': 'gram', 187 | 'kwargs': { 188 | 'hidden_dim': 256, 189 | 'sigma_clamp_mode': 'softplus', 190 | 'rgb_clamp_mode': 'widen_sigmoid', 191 | 'hidden_dim_sample': 128, 192 | 'layer_num_sample': 3, 193 | 'center': (0, 0, -1.5), 194 | 'init_radius': 0, 195 | }, 196 | }, 197 | 'renderer': { 198 | 'class': 'manifold_renderer', 199 | 'kwargs': { 200 | 'num_samples': 64, 201 | 'num_manifolds': 24, 202 | 'levels_start': 23, 203 | 'levels_end': 8, 204 | 'last_back': False, 205 | 'white_back': False, 206 | 'background': True, 207 | } 208 | } 209 | }, 210 | 'dataset': { 211 | 'class': 'FFHQ', 212 | 'kwargs': { 213 | 'img_size': 256, 214 | 'real_pose': True, 215 | } 216 | }, 217 | 'camera': { 218 | 'fov': 12, 219 | 'ray_start': 0.88, 220 | 'ray_end': 1.12, 221 | } 222 | } 223 | 224 | CATS_default = { 225 | 'global': { 226 | 'img_size': 256, 227 | 'batch_size': 1, 228 | 'z_dist': 'gaussian', 229 | }, 230 | 'optimizer': { 231 | 'gen_lr': 2e-5, 232 | 'disc_lr': 2e-4, 233 | 'betas': (0, 0.9), 234 | 'grad_clip': 1., 235 | }, 236 | 'process': { 237 | 'class': 'Gan3DProcess', 238 | 'kwargs': { 239 | 'batch_split': 2, 240 | 'real_pos_lambda': 30., 241 | 'r1_lambda': 1., 242 | 'pos_lambda': 15., 243 | } 244 | }, 245 | 'generator': { 246 | 'class': 'GramGenerator', 247 | 'kwargs': { 248 | 'z_dim': 256, 249 | 'img_size': 256, 250 | 'h_stddev': 0.3, 251 | 'v_stddev': 0.155, 252 | 'h_mean': math.pi*0.5, 253 | 'v_mean': math.pi*0.5, 254 | 'sample_dist': 'gaussian', 255 | }, 256 | 'representation': { 257 | 'class': 'gram', 258 | 'kwargs': { 259 | 'hidden_dim': 256, 260 | 'sigma_clamp_mode': 'softplus', 261 | 'rgb_clamp_mode': 'widen_sigmoid', 262 | 'hidden_dim_sample': 64, 263 | 'layer_num_sample': 3, 264 | 'center': (0, 0, -1.5), 265 | 'init_radius': 0, 266 | }, 267 | }, 268 | 'renderer': { 269 | 'class': 'manifold_renderer', 270 | 'kwargs': { 271 | 'num_samples': 64, 272 | 'num_manifolds': 24, 273 | 'levels_start': 23, 274 | 'levels_end': 8, 275 | 'last_back': False, 276 | 'white_back': False, 277 | 'background': True, 278 | } 279 | } 280 | }, 281 | 'discriminator': { 282 | 'class': 'GramDiscriminator', 283 | 'kwargs': { 284 | 'img_size': 256, 285 | } 286 | }, 287 | 'dataset': { 288 | 'class': 'CATS', 289 | 'kwargs': { 290 | 'img_size': 256, 291 | 'real_pose': True, 292 | } 293 | }, 294 | 'camera': { 295 | 'fov': 12, 296 | 'ray_start': 0.88, 297 | 'ray_end': 1.12, 298 | } 299 | } 300 | 301 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.data import DataLoader, Dataset 5 | from torchvision import datasets 6 | import torchvision.transforms as transforms 7 | from torchvision.transforms import functional as F 8 | import 
torchvision 9 | import glob 10 | import PIL 11 | import math 12 | import numpy as np 13 | import zipfile 14 | import time 15 | from scipy.io import loadmat 16 | 17 | def read_pose(name,flip=False): 18 | P = loadmat(name)['angle'] 19 | P_x = -(P[0,0] - 0.1) + math.pi/2 20 | if not flip: 21 | P_y = P[0,1] + math.pi/2 22 | else: 23 | P_y = -P[0,1] + math.pi/2 24 | 25 | P = torch.tensor([P_x,P_y],dtype=torch.float32) 26 | 27 | return P 28 | 29 | def read_pose_npy(name,flip=False): 30 | P = np.load(name) 31 | P_x = P[0] + 0.14 32 | if not flip: 33 | P_y = P[1] 34 | else: 35 | P_y = -P[1] + math.pi 36 | 37 | P = torch.tensor([P_x,P_y],dtype=torch.float32) 38 | 39 | return P 40 | 41 | def read_latents_txt(name, device="cpu"): 42 | # load the latent codes for id, expression and so on. 43 | 44 | ''' 45 | the data structure of ffhq_pose 46 | latents: noise 47 | ''' 48 | latents = np.loadtxt(name) 49 | latents = torch.from_numpy(latents).float() #.unsqueeze(0).to(device) 50 | 51 | return latents 52 | 53 | def transform_matrix_to_camera_pos(c2w,flip=False): 54 | """ 55 | Get camera position with transform matrix 56 | 57 | :param c2w: camera to world transform matrix 58 | :return: camera position on spherical coord 59 | """ 60 | 61 | c2w[[0,1,2]] = c2w[[1,2,0]] 62 | pos = c2w[:, -1].squeeze() 63 | radius = float(np.linalg.norm(pos)) 64 | theta = float(np.arctan2(-pos[0], pos[2])) 65 | phi = float(np.arctan(-pos[1] / np.linalg.norm(pos[::2]))) 66 | theta = theta + np.pi * 0.5 67 | phi = phi + np.pi * 0.5 68 | if flip: 69 | theta = -theta + math.pi 70 | P = torch.tensor([phi,theta],dtype=torch.float32) 71 | return P 72 | 73 | class CATS_finetune(Dataset): 74 | def __init__(self, opt, img_size, **kwargs): 75 | super().__init__() 76 | imgname = opt.target_name 77 | 78 | self.img_size = img_size 79 | self.real_pose = False 80 | if 'real_pose' in kwargs and kwargs['real_pose'] == True: 81 | self.real_pose = True 82 | 83 | for i in range(10): 84 | try: 85 | self.data = glob.glob(os.path.join(opt.data_img_dir, f'{imgname}.png')) 86 | assert len(self.data) > 0, "Can't find data; make sure you specify the path to your dataset" 87 | if self.real_pose: 88 | self.pose = [os.path.join(opt.data_pose_dir, f.split('/')[-1].replace('.png','_pose.npy')) for f in self.data] 89 | break 90 | except: 91 | print('failed to load dataset, try %02d times'%i) 92 | time.sleep(0.5) 93 | self.transform = transforms.Compose( 94 | [transforms.Resize((img_size, img_size), interpolation=1), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) 95 | 96 | self.opt_pose = False 97 | if opt.data_emd_dir.find('pose') > 0: 98 | self.opt_pose = True 99 | self.pose = [os.path.join(opt.data_emd_dir, f'{imgname}/{opt.target_inv_epoch}_pose_.txt')] 100 | 101 | self.emd = [os.path.join(opt.data_emd_dir, f'{imgname}/{opt.target_inv_epoch}_.txt')] 102 | self.green_bg = opt.green_bg 103 | self.load_mat = opt.load_mat 104 | if self.green_bg or self.load_mat: 105 | self.mat = [] 106 | for img in self.data: 107 | split = img.split("/") 108 | self.mat.append(img.replace(split[-1], f"mat256/{split[-1]}")) 109 | self.transform_mat = transforms.Compose([transforms.Resize((img_size, img_size), interpolation=1), transforms.ToTensor()]) 110 | 111 | def __len__(self): 112 | return len(self.data) 113 | 114 | def __getitem__(self, index): 115 | X = PIL.Image.open(self.data[index]) 116 | if self.green_bg: 117 | mat = PIL.Image.open(self.mat[index]) 118 | # mat.save("mat.png") 119 | mat_np = np.expand_dims(np.array(mat), axis=2) 120 | mat_np = mat_np / 255 121 
| X_np = np.array(X) 122 | 123 | # green: [0, 177, 64] 124 | X_np = (X_np * mat_np + [0, 177, 64] * (1-mat_np)).astype('uint8') 125 | X = PIL.Image.fromarray(X_np) 126 | # X.save("rgb_mat.png") 127 | 128 | X = self.transform(X) 129 | if self.opt_pose: 130 | P = read_pose_txt(self.pose[index]) # optimized pose 131 | else: 132 | P = read_pose_npy(self.pose[index]) # ori pose 133 | 134 | Z = read_latents_txt(self.emd[index]) 135 | if self.load_mat: 136 | mat = PIL.Image.open(self.mat[index]) 137 | mat = self.transform_mat(mat) 138 | 139 | X = torch.cat((X,mat), 0) 140 | 141 | return X, P, Z 142 | 143 | 144 | class FACES_finetune(Dataset): 145 | def __init__(self, opt, img_size, **kwargs): 146 | super().__init__() 147 | imgname = opt.target_name 148 | 149 | self.img_size = img_size 150 | self.real_pose = False 151 | if 'real_pose' in kwargs and kwargs['real_pose'] == True: 152 | self.real_pose = True 153 | 154 | for i in range(10): 155 | try: 156 | self.data = glob.glob(os.path.join(opt.data_img_dir, f'{imgname}')) 157 | assert len(self.data) > 0, "Can't find data; make sure you specify the path to your dataset" 158 | if self.real_pose: 159 | self.pose = [os.path.join(opt.data_pose_dir, f.split('/')[-1].replace('png','mat')) for f in self.data] 160 | break 161 | except: 162 | print('failed to load dataset, try %02d times'%i) 163 | print(os.path.join(opt.data_img_dir, f'{imgname}')) 164 | time.sleep(0.5) 165 | self.transform = transforms.Compose( 166 | [transforms.Resize((img_size, img_size), interpolation=1), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) 167 | 168 | self.opt_pose = False 169 | if opt.data_emd_dir.find('pose') > 0: 170 | self.opt_pose = True 171 | self.pose = [os.path.join(opt.data_emd_dir, f'{imgname.split(".")[0]}/{opt.target_inv_epoch}_pose_.txt')] 172 | 173 | self.emd = [os.path.join(opt.data_emd_dir, f'{imgname.split(".")[0]}/{opt.target_inv_epoch}_.txt')] 174 | self.load_mask = opt.load_mask 175 | if self.load_mask: 176 | self.mat = [] 177 | for img in self.data: 178 | split = img.split("/") 179 | self.mat.append(img.replace(split[-1], f"mask256/{split[-1]}")) 180 | self.transform_mat = transforms.Compose([transforms.Resize((img_size, img_size), interpolation=1), transforms.ToTensor()]) 181 | 182 | def __len__(self): 183 | return len(self.data) 184 | 185 | def __getitem__(self, index): 186 | X = PIL.Image.open(self.data[index]) 187 | 188 | X = self.transform(X) 189 | if self.opt_pose: 190 | P = read_pose_txt(self.pose[index]) # optimized pose 191 | else: 192 | P = read_pose(self.pose[index]) # ori pose 193 | 194 | Z = read_latents_txt(self.emd[index]) 195 | if self.load_mask: 196 | mat = PIL.Image.open(self.mat[index]) 197 | mat = self.transform_mat(mat) 198 | 199 | X = torch.cat((X,mat), 0) 200 | 201 | return X, P, Z 202 | 203 | 204 | def get_dataset(name, subsample=None, batch_size=1, **kwargs): 205 | dataset = globals()[name](**kwargs) 206 | 207 | dataloader = torch.utils.data.DataLoader( 208 | dataset, 209 | batch_size=batch_size, 210 | shuffle=True, 211 | drop_last=True, 212 | pin_memory=False, 213 | num_workers=8 214 | ) 215 | return dataloader, 3 216 | 217 | def get_dataset_(dataset, subsample=None, batch_size=1, **kwargs): 218 | 219 | dataloader = torch.utils.data.DataLoader( 220 | dataset, 221 | batch_size=batch_size, 222 | shuffle=True, 223 | drop_last=True, 224 | pin_memory=False, 225 | num_workers=8 226 | ) 227 | return dataloader, 3 228 | 229 | def get_dataset_distributed(name, world_size, rank, batch_size, **kwargs): 230 | 231 | dataset = 
globals()[name](**kwargs) 232 | 233 | sampler = torch.utils.data.distributed.DistributedSampler( 234 | dataset, 235 | num_replicas=world_size, 236 | rank=rank, 237 | ) 238 | dataloader = torch.utils.data.DataLoader( 239 | dataset, 240 | sampler=sampler, 241 | batch_size=batch_size, 242 | shuffle=False, 243 | drop_last=True, 244 | pin_memory=False, 245 | num_workers=1, 246 | persistent_workers=True, 247 | ) 248 | 249 | return dataloader, 3 250 | 251 | def get_dataset_distributed_(_dataset, world_size, rank, batch_size, **kwargs): 252 | 253 | sampler = torch.utils.data.distributed.DistributedSampler( 254 | _dataset, 255 | num_replicas=world_size, 256 | rank=rank, 257 | ) 258 | dataloader = torch.utils.data.DataLoader( 259 | _dataset, 260 | sampler=sampler, 261 | batch_size=batch_size, 262 | shuffle=False, 263 | drop_last=True, 264 | pin_memory=False, 265 | num_workers=1, 266 | persistent_workers=True, 267 | ) 268 | 269 | return dataloader, 3 270 | 271 | 272 | if __name__ == '__main__': 273 | import imageio 274 | from tqdm import tqdm 275 | dataset = FACES_finetune(64, **{'real_pose': True}) 276 | dataset, _ = get_dataset_(dataset) 277 | for i, (image, pose) in tqdm(enumerate(dataset)): 278 | print(pose * 180 / np.pi) 279 | imageio.imwrite('test.png', ((image.squeeze().permute(1, 2, 0)*0.5+0.5)*255).type(torch.uint8)) 280 | break 281 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: nerfinvertor 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - python=3.8.0 8 | - pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0 9 | - torchaudio=0.11.0=py38_cu113 10 | - torchvision=0.12.0=py38_cu113 11 | # - pytorch3d=0.7.2=py38_cu113_pyt1110 12 | - pip: 13 | - matplotlib==3.2.1 14 | - mrcfile==1.3.0 15 | - numpy==1.24.4 16 | - open3d==0.11.2 17 | - opencv-python==4.4.0.44 18 | - pandas==1.1.1 19 | - Pillow==6.0.0 20 | - plotly==4.9.0 21 | - python-dateutil==2.8.1 22 | - pytorch-fid==0.1.1 23 | - PyYAML==5.3.1 24 | - scikit-image==0.21.0 25 | - scikit-learn==1.3.0 26 | - scikit-video==1.1.11 27 | - scipy==1.10.1 28 | - tensorboard==1.15.0 #2.13.0 29 | - torch-fidelity==0.2.0 30 | - tensorboardX==2.2 31 | - sk-video==1.1.10 32 | - lpips==0.1.3 33 | - torch-ema 34 | - https://github.com/podgorskiy/dnnlib/releases/download/0.0.1/dnnlib-0.0.1-py3-none-any.whl 35 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import glob 4 | import numpy as np 5 | 6 | import torch 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | import time 10 | from tqdm import tqdm 11 | from finetune_loop import training_process 12 | 13 | torch.backends.cudnn.benchmark = True 14 | 15 | def synchronize(): 16 | if not dist.is_available(): 17 | return 18 | 19 | if not dist.is_initialized(): 20 | return 21 | 22 | world_size = dist.get_world_size() 23 | 24 | if world_size == 1: 25 | return 26 | 27 | dist.barrier() 28 | 29 | def setup(rank, world_size, port): 30 | os.environ['MASTER_ADDR'] = 'localhost' 31 | os.environ['MASTER_PORT'] = port 32 | 33 | # initialize the process group 34 | torch.cuda.set_device(rank) 35 | # dist.init_process_group("nccl", rank=rank, world_size=world_size) 36 | 37 | dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size) 
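    # note: with init_method="env://", the NCCL rendezvous reads MASTER_ADDR/MASTER_PORT from the environment variables set above; since rank and world_size are passed explicitly, RANK/WORLD_SIZE do not need to be exported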
38 | synchronize() 39 | 40 | def cleanup(): 41 | dist.destroy_process_group() 42 | 43 | 44 | def train(rank, world_size, opt): 45 | torch.manual_seed(0) 46 | 47 | setup(rank, world_size, opt.port) # multi_process initialization 48 | device = torch.device(rank) 49 | training_process(rank, world_size, opt, device) # main training loop 50 | cleanup() 51 | 52 | 53 | if __name__ == '__main__': 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--n_epochs", type=int, default=1000, help="number of epochs of training") # maximum training epochs 56 | parser.add_argument('--output_dir', type=str, default='experiments/gram/rendering_results/') 57 | parser.add_argument('--experiment_name', type=str, default='') 58 | parser.add_argument('--load_dir', type=str, default='../pretrained_models') 59 | parser.add_argument('--data_img_dir', type=str, default='samples/faces/') 60 | parser.add_argument('--data_pose_dir', type=str, default='samples/faces/poses/') 61 | parser.add_argument('--data_emd_dir', type=str, default='experiments/gram/inversion') 62 | parser.add_argument('--pretrain_model', type=str, default='pretrained_models/gram/FFHQ_default/generator.pth') 63 | parser.add_argument('--config', type=str, default='FACES_default') 64 | parser.add_argument('--eval_freq', type=int, default=5000) 65 | parser.add_argument('--save_mesh', type=int, default=1000000) 66 | parser.add_argument('--port', type=str, default='12356') 67 | parser.add_argument('--set_step', type=int, default=None) # set to None if train from scratch 68 | parser.add_argument('--model_save_interval', type=int, default=200) 69 | parser.add_argument('--print_freq', type=int, default=20) 70 | parser.add_argument('--log_freq', type=int, default=20) 71 | parser.add_argument("--sample_interval", type=int, default=50, help="interval between image sampling") # evaluation interval 72 | parser.add_argument('--target_inv_epoch', type=str, default='00999', help='epoch num of inversion') 73 | parser.add_argument('--target_names', type=str, default='') 74 | parser.add_argument('--load_mask', action='store_true', default=False, help='if specificed, ') 75 | 76 | # loss lambda 77 | parser.add_argument('--psi', type=float, default=0.7, help='truncation') 78 | parser.add_argument('--regulizer_alpha', type=float, default=5) 79 | 80 | parser.add_argument('--lambda_loc_reg_l2', type=float, default=1.0) 81 | parser.add_argument('--lambda_loc_reg_perceptual', type=float, default=1.0) 82 | parser.add_argument('--lambda_reg_volumeDensity', type=float, default=0) 83 | parser.add_argument('--lambda_reg_rgbBefAggregation', type=float, default=10) 84 | parser.add_argument('--lambda_reg_sigmaBefAggregation', type=float, default=0) 85 | parser.add_argument('--lambda_bg_sigma', type=float, default=10) 86 | parser.add_argument('--lambda_l2', type=float, default=1) 87 | parser.add_argument('--lambda_perceptual', type=float, default=1.0) 88 | parser.add_argument('--lambda_id', type=float, default=0.1) 89 | 90 | parser.add_argument('--warm_up_deform', type=int, default=2000, help='the warm up iterations for training DIF solely') 91 | parser.add_argument('--switch_interval', type=int, default=3, help='switch inverval between deform net and GRAM, 3 means training deformnet twice and train GRAM once') 92 | parser.add_argument('--gen_gt', action='store_true', help='gen_gt means for BFM, samples points on the rays; otherwise directly use points from BFM for training') 93 | parser.add_argument('--with_smoothness', action='store_true', help='whether use smoothness, 
need a high memory demand') 94 | parser.add_argument('--debug_mode', action='store_true', help='if specificed, use the debug mode') 95 | parser.add_argument('--real_latents', action='store_true', help='if specificed, use the real latents') 96 | parser.add_argument('--gen_points_threshold', type=float, default=0.00001) 97 | parser.add_argument('--sample_rays', action='store_true', help='whether sample rays during the training of DIFNET') 98 | parser.add_argument('--train_rignerf', action='store_true', help='whether use rignerf methods to train 3dmm guidance') 99 | parser.add_argument('--sample_3dmm', type=float, default=0.5, help='sample how much points on 3DMM face') 100 | parser.add_argument('--generator_model', type=str, default='GRAM', help='the generative model, choose from GRAM or pi-gan') 101 | parser.add_argument('--neutral_ratio', type=float, default=0.1, help='the ratio of input to simulate canonic process') 102 | parser.add_argument('--n_workers', type=int, default=1, help='the workers for dataloader') 103 | parser.add_argument('--deform_backbone', type=str, default='siren', help='the backbone of siren') 104 | 105 | parser.add_argument('--to_gram', type=str, default='v1', help='the backbone of siren') 106 | 107 | # parser = facerecon_params(parser) 108 | opt = parser.parse_args() 109 | 110 | # opt.checkpoints_dir = os.path.join(opt.load_dir, 'FaceRecon_Pytorch/checkpoints') 111 | # opt.bfm_folder = os.path.join(opt.load_dir, 'FaceRecon_Pytorch/BFM') 112 | # opt.init_path = os.path.join(opt.load_dir, 'FaceRecon_Pytorch/checkpoints/init_model/resnet50-0676ba61.pth') 113 | 114 | 115 | # print(opt) 116 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in list(range(torch.cuda.device_count()))) 117 | num_gpus = len(os.environ['CUDA_VISIBLE_DEVICES'].split(',')) 118 | 119 | print("utilizing %02d gpus"%num_gpus) 120 | opt.target_names = opt.target_names.split('+') 121 | output_dir = opt.output_dir 122 | for target_name in opt.target_names: 123 | if target_name.find("start_from") >= 0: 124 | ## start from # in the dataset 125 | start_ind = int(target_name.split("_")[-1]) 126 | img_paths_all = sorted(os.listdir(opt.data_emd_dir)) 127 | for i, file in enumerate(img_paths_all): 128 | if i < start_ind: 129 | continue 130 | # -------------- modify the output dir 131 | opt.target_name = file 132 | timestr = time.strftime("%Y%m%d-%H%M%S") 133 | opt.output_dir = os.path.join(output_dir, '%s_%s_%s' % (timestr, opt.experiment_name, file)) 134 | os.makedirs(opt.output_dir, exist_ok=True) 135 | 136 | print("*" * 60) 137 | print(f"subject: {opt.target_name} (idx{i})") 138 | print("*" * 60) 139 | mp.spawn(train, args=(num_gpus, opt), nprocs=num_gpus, join=True) 140 | else: 141 | ## use specific target_names 142 | # -------------- modify the output dir 143 | opt.target_name = target_name 144 | timestr = time.strftime("%Y%m%d-%H%M%S") 145 | opt.output_dir = os.path.join(output_dir, '%s_%s_%s' % (timestr, opt.experiment_name, target_name.split(".")[0])) 146 | os.makedirs(opt.output_dir, exist_ok=True) 147 | 148 | print("*" * 60) 149 | print(f"subject: {opt.target_name}") 150 | print("*" * 60) 151 | mp.spawn(train, args=(num_gpus, opt), nprocs=num_gpus, join=True) 152 | # try: 153 | # mp.spawn(train, args=(num_gpus, opt), nprocs=num_gpus, join=True) 154 | # except: 155 | # continue 156 | -------------------------------------------------------------------------------- /finetune_loop.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from 
curses import meta 3 | from dis import dis 4 | from itertools import cycle 5 | from locale import normalize 6 | import os 7 | import sys 8 | from random import triangular 9 | from sqlite3 import PARSE_DECLTYPES 10 | from textwrap import indent 11 | from turtle import pos 12 | 13 | from sklearn.datasets import load_diabetes 14 | from sklearn.metrics import zero_one_loss 15 | from grpc import metadata_call_credentials 16 | import numpy as np 17 | import math 18 | from collections import deque 19 | import torch 20 | import torch.distributed as dist 21 | import torch.multiprocessing as mp 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | from torch.nn.parallel import DistributedDataParallel as DDP 25 | from torchvision.utils import save_image 26 | import torchvision.transforms as transforms 27 | import importlib 28 | import time 29 | # import trimesh 30 | # from discriminators import discriminators 31 | # from siren import siren 32 | from generators import generators 33 | import configs 34 | # import fid_evaluation 35 | import datasets 36 | from tqdm import tqdm 37 | from datetime import datetime 38 | import copy 39 | from torch_ema import ExponentialMovingAverage 40 | # import pytorch3d 41 | # from loss import * 42 | from torch.utils.tensorboard import SummaryWriter 43 | import pickle, PIL 44 | from PIL import Image 45 | # import utils 46 | import dnnlib 47 | from utils.arcface import get_model 48 | 49 | torch.backends.cudnn.benchmark = True 50 | 51 | 52 | # sample noises 53 | def z_sampler(shape, device, dist): 54 | if dist == 'gaussian': 55 | z = torch.randn(shape, device=device) 56 | # torch.randn - sample random numbers from a normal distribution with mean 0 and varaiance 1 57 | elif dist == 'uniform': 58 | z = torch.rand(shape, device=device) * 2 - 1 59 | # torch.rand - sample random numbers froma uniform distribution 60 | return z 61 | 62 | ##### --------------------------------------- set the networks --------------------------------------------------- 63 | 64 | 65 | def load_models_for_loss(device, opt): 66 | #for LPIPS loss 67 | if opt.config.find('FACES') >= 0: 68 | url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' 69 | with dnnlib.util.open_url(url) as f: 70 | vgg16 = torch.jit.load(f).eval().to(device) 71 | elif opt.config.find('CATS') >= 0: # CATS, CARLA 72 | import lpips 73 | vgg16 = lpips.LPIPS(net='vgg').eval().to(device) # closer to "traditional" perceptual loss, when used for optimization 74 | print("load vgg for LPIPS loss") 75 | 76 | face_recog = get_model('r50', fp16=False) 77 | face_recog.load_state_dict(torch.load('pretrained_models/arcface.pth')) 78 | print("load face_recog model for ID loss") 79 | id_loss = IDLoss(face_recog.eval()).to(device) 80 | 81 | return vgg16, id_loss 82 | 83 | 84 | # define generator 85 | def set_generator(config, device, opt): 86 | generator_args = {} 87 | if 'representation' in config['generator']: 88 | generator_args['representation_kwargs'] = config['generator']['representation']['kwargs'] 89 | if 'renderer' in config['generator']: 90 | generator_args['renderer_kwargs'] = config['generator']['renderer']['kwargs'] 91 | generator = getattr(generators, config['generator']['class'])( 92 | **generator_args, 93 | **config['generator']['kwargs'] 94 | ) 95 | 96 | print(f"Loaded pretrained network: {opt.pretrain_model}") 97 | if opt.pretrain_model != '': 98 | generator.load_state_dict(torch.load(opt.pretrain_model, map_location='cpu')) 99 | 100 | generator = generator.to(device) 101 | 102 
| if opt.pretrain_model != '': 103 | print(f"loaded ema network!") 104 | ema = ExponentialMovingAverage(generator.parameters(), decay=0.999) 105 | ema2 = ExponentialMovingAverage(generator.parameters(), decay=0.9999) 106 | 107 | ema = torch.load(opt.pretrain_model.replace('generator.pth','ema.pth'), map_location=device) 108 | parameters = [p for p in generator.parameters() if p.requires_grad] 109 | ema.copy_to(parameters) 110 | else: 111 | # exponential moving avaerage is to place a greater weight on the most recent data points 112 | ema = ExponentialMovingAverage(generator.parameters(), decay=0.999) 113 | ema2 = ExponentialMovingAverage(generator.parameters(), decay=0.9999) 114 | 115 | return generator, ema, ema2 116 | 117 | 118 | # set IDLoss 119 | class IDLoss(nn.Module): 120 | def __init__(self, facenet): 121 | super(IDLoss, self).__init__() 122 | self.facenet = facenet 123 | 124 | def forward(self, x, y): 125 | x = F.interpolate(x, size=[112, 112], mode='bilinear') 126 | y = F.interpolate(y, size=[112, 112], mode='bilinear') 127 | 128 | # x = 2*(x-0.5) 129 | # y = 2*(y-0.5) 130 | feat_x = self.facenet(x) 131 | feat_y = self.facenet(y.detach()) 132 | 133 | loss = 1 - F.cosine_similarity(feat_x, feat_y, dim=-1) 134 | 135 | return loss 136 | 137 | ##### ------------------------------------------- set the optimizers --------------------------------------------------- 138 | 139 | # define optimizer 140 | def set_optimizer_G(generator_ddp, config, opt): 141 | param_groups = [] 142 | if 'mapping_network_lr' in config['optimizer']: 143 | mapping_network_parameters = [p for n, p in generator_ddp.named_parameters() if 'module.representation.rf_network.mapping_network' in n] 144 | param_groups.append({'params': mapping_network_parameters, 'name': 'mapping_network', 'lr':config['optimizer']['mapping_network_lr']}) 145 | if 'sampling_network_lr' in config['optimizer']: 146 | sampling_network_parameters = [p for n, p in generator_ddp.named_parameters() if 'module.representation.sample_network' in n] 147 | param_groups.append({'params': sampling_network_parameters, 'name': 'sampling_network', 'lr':config['optimizer']['sampling_network_lr']}) 148 | generator_parameters = [p for n, p in generator_ddp.named_parameters() if 149 | ('mapping_network_lr' not in config['optimizer'] or 'module.representation.rf_network.mapping_network' not in n) and 150 | ('sampling_network_lr' not in config['optimizer'] or 'module.representation.sample_network' not in n)] 151 | param_groups.append({'params': generator_parameters, 'name': 'generator'}) 152 | 153 | optimizer_G = torch.optim.Adam(param_groups, lr=config['optimizer']['gen_lr'], betas=config['optimizer']['betas']) 154 | 155 | return optimizer_G 156 | 157 | 158 | def training_step_G(sample_z, sample_pose, input_imgs, zs, real_poses, generator_ddp, ema, ema2, 159 | generator_ori_ddp, vgg16, id_loss, optimizer_G, scaler, config, opt, device): 160 | batch_split = 1 161 | if opt.load_mask: 162 | real_imgs = input_imgs[:, :3, :, :] 163 | mat_imgs = input_imgs[:, 3:, :, :] 164 | else: 165 | real_imgs = input_imgs 166 | bs = zs.size()[0] 167 | split_batch_size = zs.shape[0] // batch_split # minibatch split for memory reduction 168 | img_size = input_imgs.size(-1) 169 | 170 | # --------------------------- interpolate zs and sampled z --------------------------------- 171 | interpolation_direction = sample_z - zs 172 | interpolation_direction_norm = torch.norm(interpolation_direction, p=2) 173 | result_zs = zs + opt.regulizer_alpha * interpolation_direction / 
interpolation_direction_norm 174 | 175 | gen_imgs_list = [] 176 | losses_dict = {} 177 | for split in range(batch_split): 178 | g_loss = 0 179 | with torch.cuda.amp.autocast(): 180 | subset_z = zs[split * split_batch_size:(split+1) * split_batch_size] 181 | generator_ddp.module.get_avg_w() 182 | gen_imgs, gen_bef_aggr, _ = generator_ddp(subset_z, **config['camera'], detailed_output=True, truncation_psi=opt.psi) 183 | 184 | # --------------------------- loss constraint----------------------- 185 | if opt.lambda_id > 0: 186 | id_l = id_loss(gen_imgs, real_imgs).mean() * opt.lambda_id 187 | g_loss += id_l 188 | losses_dict['id_l'] = id_l 189 | if opt.load_mask and opt.lambda_bg_sigma > 0: 190 | ## force bg sigma to be 0 191 | rgb_sigma = gen_bef_aggr['outputs'] 192 | N_steps = rgb_sigma.size(-2) 193 | mat_imgs = mat_imgs.permute(0, 2, 3, 1).expand(-1, -1, -1, N_steps).reshape(bs, -1, N_steps) 194 | 195 | weights = gen_bef_aggr['weights'].reshape(rgb_sigma.size(0), img_size*img_size, N_steps, 1) 196 | bg_sigma = (1 - mat_imgs[:, :, -1]) * weights[:, :, -1, 0] 197 | l2_bg_sigma = torch.mean((bg_sigma - 1) ** 2) * opt.lambda_bg_sigma 198 | 199 | # # error 4: 200 | # bg_sigma = (1 - mat_imgs) * weights[:, :, :, 0] 201 | # l2_bg_sigma = torch.mean(bg_sigma ** 2) * opt.lambda_bg_sigma 202 | # # error 3: 203 | # bg_sigma = (1 - mat_imgs) * weights[:, :, :, 0] 204 | # l2_bg_sigma = torch.mean((bg_sigma - 1) ** 2) * opt.lambda_bg_sigma 205 | 206 | # bg_sigma = (1 - mat_imgs[:, :, -1]) * weights[:, :, -1, 0] 207 | # l2_bg_sigma = torch.mean((bg_sigma - 1) ** 2) * opt.lambda_bg_sigma 208 | g_loss += l2_bg_sigma 209 | losses_dict['l2_bg_sigma'] = l2_bg_sigma 210 | 211 | # gen_imgs = mat_imgs * gen_imgs 212 | # real_imgs = mat_imgs * real_imgs 213 | 214 | # gen_imgs_bg = (1 - mat_imgs) * gen_imgs 215 | # real_imgs_bg = (1 - mat_imgs) * real_imgs 216 | # if opt.lambda_l2 > 0: 217 | # l2_bg = torch.mean((gen_imgs_bg - real_imgs_bg) ** 2) * opt.lambda_l2 * 0.1 218 | # g_loss += l2_bg 219 | # losses_dict['l2_bg'] = l2_bg 220 | # if opt.lambda_perceptual > 0: 221 | # gen_features_bg = vgg16(127.5 * (gen_imgs_bg + 1), resize_images=False, return_lpips=True) 222 | # real_features_bg = vgg16(127.5 * (real_imgs_bg + 1), resize_images=False, return_lpips=True) 223 | # perceptual_loss_bg = ((1000 * gen_features_bg - 1000 * real_features_bg) ** 2).mean() * opt.lambda_perceptual * 0.1 224 | # g_loss += perceptual_loss_bg 225 | # losses_dict['perceptual_loss_bg'] = perceptual_loss_bg 226 | if opt.lambda_l2 > 0: 227 | l2 = torch.mean((gen_imgs - real_imgs) ** 2) * opt.lambda_l2 228 | ## l2 = nn.MSELoss()(gen_imgs, real_imgs) * opt.lambda_l2 229 | 230 | # img_size = real_imgs.size(-1) 231 | # gen_imgs_d2 = F.upsample(gen_imgs, size=(img_size//2,img_size//2), mode='bilinear') 232 | # real_imgs_d2 = F.upsample(real_imgs, size=(img_size//2,img_size//2), mode='bilinear') 233 | # l2 += torch.mean((gen_imgs_d2 - real_imgs_d2)**2) * opt.lambda_l2 234 | 235 | # gen_imgs_d4 = F.upsample(gen_imgs, size=(img_size//4,img_size//4), mode='bilinear') 236 | # real_imgs_d4 = F.upsample(real_imgs, size=(img_size//4,img_size//4), mode='bilinear') 237 | # l2 += torch.mean((gen_imgs_d4-real_imgs_d4)**2) * opt.lambda_l2 238 | # l2 = l2 / 3.0 239 | 240 | g_loss += l2 241 | losses_dict['l2'] = l2 242 | if opt.lambda_perceptual > 0: 243 | if opt.config.find('FACES') >= 0: 244 | gen_features = vgg16(127.5 * (gen_imgs + 1), resize_images=False, return_lpips=True) 245 | real_features = vgg16(127.5 * (real_imgs + 1), resize_images=False, 
return_lpips=True) 246 | perceptual_loss = ((1000 * gen_features - 1000 * real_features) ** 2).mean() * opt.lambda_perceptual 247 | 248 | # gen_features_d2 = vgg16(127.5*(gen_imgs_d2+1), resize_images=False, return_lpips=True) 249 | # real_features_d2 = vgg16(127.5*(real_imgs_d2+1), resize_images=False, return_lpips=True) 250 | # perceptual_loss += ((1000*gen_features_d2-1000*real_features_d2)**2).mean() * opt.lambda_perceptual 251 | 252 | # gen_features_d4 = vgg16(127.5*(gen_imgs_d4+1), resize_images=False, return_lpips=True) 253 | # real_features_d4 = vgg16(127.5*(real_imgs_d4+1), resize_images=False, return_lpips=True) 254 | # perceptual_loss += ((1000*gen_features_d4-1000*real_features_d4)**2).mean() * opt.lambda_perceptual 255 | 256 | # perceptual_loss = perceptual_loss / 3.0 257 | 258 | elif opt.config.find('CATS') >= 0: # CATS, CARLA 259 | perceptual_loss = vgg16(gen_imgs, real_imgs).mean() * opt.lambda_perceptual 260 | # perceptual_loss += vgg16(gen_imgs_d2, real_imgs_d2).mean() * opt.lambda_perceptual 261 | # perceptual_loss += vgg16(gen_imgs_d4, real_imgs_d4).mean() * opt.lambda_perceptual 262 | # perceptual_loss = perceptual_loss / 3.0 263 | 264 | g_loss += perceptual_loss 265 | losses_dict['perceptual_loss'] = perceptual_loss 266 | 267 | # --------------------------- loc_regularization ----------------------- 268 | ## ori G 269 | subset_sample_z = result_zs[split * split_batch_size:(split + 1) * split_batch_size] 270 | with torch.no_grad(): 271 | generator_ori_ddp.module.get_avg_w() 272 | sampled_img, sampled_details, gen_positions = generator_ori_ddp(subset_sample_z, **config['camera'], img_size=128, detailed_output=True, truncation_psi=opt.psi) 273 | ## finetuned G 274 | output_updated, details_updated, _ = generator_ddp(subset_sample_z, **config['camera'], img_size=128, camera_pos=gen_positions, detailed_output=True, truncation_psi=opt.psi) 275 | 276 | 277 | if opt.lambda_reg_rgbBefAggregation > 0: 278 | # [0]: pixels, [1]: depth, [2]: weights, [3]: T, [4]: rgb_sigma, [5]: z_vals, [6]: is_valid 279 | sampled_rgb_bef_aggregation = sampled_details['weights'] * sampled_details['outputs'][..., :3] 280 | output_rgb_bef_aggregation = details_updated['weights'] * details_updated['outputs'][..., :3] 281 | reg_rgbBefAggregation = torch.nn.L1Loss()(sampled_rgb_bef_aggregation, output_rgb_bef_aggregation) \ 282 | * opt.lambda_reg_rgbBefAggregation 283 | g_loss += reg_rgbBefAggregation 284 | losses_dict['reg_rgbBefAggregation'] = reg_rgbBefAggregation 285 | if opt.lambda_reg_sigmaBefAggregation > 0: 286 | sampled_sigma_bef_aggregation = sampled_details['outputs'][..., 3:] 287 | output_sigma_bef_aggregation = details_updated['outputs'][..., 3:] 288 | reg_sigmaBefAggregation = torch.nn.L1Loss()(sampled_sigma_bef_aggregation, output_sigma_bef_aggregation) \ 289 | * opt.lambda_reg_sigmaBefAggregation 290 | g_loss += reg_sigmaBefAggregation 291 | losses_dict['reg_sigmaBefAggregation'] = reg_sigmaBefAggregation 292 | if opt.lambda_reg_volumeDensity > 0: 293 | reg_volumeDensity = torch.nn.L1Loss()(details_updated['depth'], sampled_details['depth']) * opt.lambda_reg_volumeDensity 294 | g_loss += reg_volumeDensity 295 | losses_dict['reg_volumeDensity'] = reg_volumeDensity 296 | if opt.lambda_loc_reg_l2 > 0: 297 | reg_l2 = torch.mean((output_updated - sampled_img) ** 2) * opt.lambda_loc_reg_l2 298 | g_loss += reg_l2 299 | losses_dict['reg_l2'] = reg_l2 300 | if opt.lambda_loc_reg_perceptual > 0: 301 | if opt.config.find('FACES') >= 0: 302 | gen_features = vgg16(127.5 * (output_updated + 1), 
resize_images=False, return_lpips=True) 303 | real_features = vgg16(127.5 * (sampled_img + 1), resize_images=False, return_lpips=True) 304 | reg_perceptual_loss = ((1000 * gen_features - 1000 * real_features) ** 2).mean() * opt.lambda_perceptual 305 | elif opt.config.find('CATS') >= 0: # CATS, CARLA 306 | reg_perceptual_loss = vgg16(output_updated, sampled_img).mean() * opt.lambda_perceptual 307 | g_loss += reg_perceptual_loss 308 | losses_dict['reg_perceptual_loss'] = reg_perceptual_loss 309 | gen_imgs_list.append(output_updated) 310 | gen_imgs_list.append(sampled_img) 311 | gen_imgs_list.append(gen_imgs) 312 | scaler.scale(g_loss).backward() 313 | 314 | scaler.unscale_(optimizer_G) 315 | torch.nn.utils.clip_grad_norm_(generator_ddp.parameters(), config['optimizer'].get('grad_clip', 0.3)) 316 | scaler.step(optimizer_G) 317 | scaler.update() 318 | optimizer_G.zero_grad() 319 | ema.update(generator_ddp.parameters()) 320 | ema2.update(generator_ddp.parameters()) 321 | 322 | loss_list = [ 323 | l2.detach() if opt.lambda_l2 else 0, 324 | perceptual_loss.detach() if opt.lambda_perceptual else 0, 325 | id_l.detach() if opt.lambda_id else 0, 326 | reg_rgbBefAggregation.detach() if opt.lambda_reg_rgbBefAggregation else 0, 327 | reg_sigmaBefAggregation.detach() if opt.lambda_reg_sigmaBefAggregation else 0, 328 | reg_volumeDensity.detach() if opt.lambda_reg_volumeDensity else 0, 329 | reg_l2.detach() if opt.lambda_loc_reg_l2 else 0, 330 | reg_perceptual_loss.detach() if opt.lambda_loc_reg_perceptual else 0, 331 | ] 332 | 333 | return g_loss.detach(), losses_dict, gen_imgs_list 334 | 335 | 336 | def training_process(rank, world_size, opt, device): 337 | # -------------------------------------------------------------------------------------- 338 | # extract training config 339 | config = getattr(configs, opt.config) 340 | if rank == 0: 341 | # print(metadata) 342 | log_dir = opt.output_dir + '/tensorboard/' 343 | os.makedirs(log_dir, exist_ok=True) 344 | writer = SummaryWriter(log_dir, 0) 345 | 346 | # -------------------------------------------------------------------------------------- 347 | # set amp gradient scaler 348 | scaler = torch.cuda.amp.GradScaler() 349 | if config['global'].get('disable_scaler', False): 350 | scaler = torch.cuda.amp.GradScaler(enabled=False) 351 | 352 | 353 | # -------------------------------------------------------------------------------------- 354 | # set LPIPS loss and id loss 355 | vgg16, id_loss = load_models_for_loss(device, opt) 356 | 357 | # -------------------------------------------------------------------------------------- 358 | # set the GRAM generator 359 | generator, ema, ema2 = set_generator(config, device, opt) 360 | generator_ddp = DDP(generator, device_ids=[rank], find_unused_parameters=True) 361 | generator = generator_ddp.module 362 | generator.renderer.lock_view_dependence = True 363 | 364 | if rank == 0: 365 | total_num = sum(p.numel() for p in generator_ddp.parameters()) 366 | trainable_num = sum(p.numel() for p in generator_ddp.parameters() if p.requires_grad) 367 | print('G: Total ', total_num, ' Trainable ', trainable_num) 368 | 369 | generator_ori, _, _ = set_generator(config, device, opt) 370 | generator_ori_ddp = DDP(generator_ori, device_ids=[rank], find_unused_parameters=True) 371 | generator_ori = generator_ori_ddp.module 372 | generator_ori.eval() 373 | 374 | # -------------------------------------------------------------------------------------- 375 | # set optimizers 376 | optimizer_G = set_optimizer_G(generator_ddp, config, opt) 377 | 
torch.cuda.empty_cache() 378 | generator_losses = [] 379 | 380 | # ---------- 381 | # Training 382 | # ---------- 383 | if rank == 0: 384 | log_file = os.path.join(opt.output_dir, 'logs.txt') 385 | with open(log_file, 'w') as f: 386 | f.write(str(opt)) 387 | f.write('\n\n') 388 | f.write(str(config)) 389 | f.write('\n\n') 390 | f.write(str(generator)) 391 | f.write('\n\n') 392 | 393 | 394 | total_progress_bar = tqdm(total=opt.n_epochs, desc="Total progress", dynamic_ncols=True, disable=True) 395 | torch.manual_seed(3) 396 | 397 | #-------------------------------------------------------------------------------------- 398 | # get dataset 399 | dataset = getattr(datasets, config['dataset']['class'])(opt, **config['dataset']['kwargs']) 400 | dataloader, CHANNELS = datasets.get_dataset_distributed_( 401 | dataset, 402 | world_size, 403 | rank, 404 | config['global']['batch_size'] 405 | ) 406 | 407 | # -------------------------------------------------------------------------------------- 408 | # main training loop 409 | generator_ddp.train() 410 | print("Total num epochs = ", opt.n_epochs) 411 | start_time = time.time() 412 | for epoch in range(opt.n_epochs): 413 | total_progress_bar.update(1) 414 | generator.epoch += 1 415 | # -------------------------------------------------------------------------------------- 416 | # trainging iterations 417 | for i, (imgs, poses, zs) in enumerate(dataloader): 418 | generator.step += 1 419 | zs = zs.to(device) 420 | fixed_z = zs 421 | 422 | real_imgs = imgs.to(device, non_blocking=True) 423 | real_poses = poses.to(device, non_blocking=True) 424 | generator.v_mean = poses[0, 0] 425 | generator.h_mean = poses[0, 1] 426 | generator.h_stddev = generator.v_stddev = 0 427 | 428 | if scaler.get_scale() < 1: 429 | scaler.update(1.) 
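            # guard against AMP loss-scale collapse: if repeated inf/NaN gradients have pushed the scale below 1, the check above resets it to 1 so gradients no longer underflow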
430 | # TRAIN GENERATOR 431 | ## ------------------------ sample latend codes for regularization ------------------------------ 432 | # sample z 433 | sample_z = z_sampler((1, 256), device=device, dist='gaussian') 434 | # sample pose 435 | yaw = torch.randn((1, 1), device=device) * 0.3 + math.pi * 0.5 436 | pitch = torch.randn((1, 1), device=device) * 0.155 + math.pi * 0.5 437 | yaw = torch.clamp(yaw, math.pi * 0.5 - 1.3, math.pi * 0.5 + 1.3) 438 | pitch = torch.clamp(pitch, math.pi * 0.5 - 1.3, math.pi * 0.5 + 1.3) 439 | sample_pose = torch.cat((pitch, yaw), dim=1) 440 | # sample_pose = poses.deepcopy() 441 | generator_ori.v_mean = sample_pose[0, 0] 442 | generator_ori.h_mean = sample_pose[0, 1] 443 | generator_ori.h_stddev = generator_ori.v_stddev = 0 444 | 445 | g_loss, losses_dict, gen_imgs_list = training_step_G(sample_z, sample_pose, real_imgs, 446 | zs, real_poses, generator_ddp, ema, ema2, generator_ori_ddp, vgg16, id_loss, 447 | optimizer_G, scaler, config, opt, device) 448 | 449 | generator_losses.append(g_loss) 450 | if rank == 0: 451 | # interior_step_bar.update(1) 452 | if (epoch+1) % opt.print_freq == 0: 453 | elapsed = time.time() - start_time 454 | rate = elapsed / (epoch + 1.0) 455 | remaining = (opt.n_epochs - epoch) * rate if rate else 0 456 | out_str = f"[Experiment: {opt.output_dir}]\n[Epoch: {epoch}/{opt.n_epochs}] [Time: " \ 457 | f"{total_progress_bar.format_interval(elapsed)} < "\ 458 | f"{total_progress_bar.format_interval(remaining)}] " 459 | with open(log_file, 'a') as f: 460 | f.write(out_str) 461 | f.write("\n") 462 | print(out_str) 463 | 464 | for loss_key, loss_value in losses_dict.items(): 465 | with open(log_file, 'a') as f: 466 | f.write(f"\t{loss_key}: {loss_value:.4f}\n") 467 | print(f"\t{loss_key}: {loss_value:.4f}") 468 | 469 | if (epoch+1) % opt.log_freq == 0: 470 | for loss_key, loss_value in losses_dict.items(): 471 | writer.add_scalar(f'G/{loss_key}', loss_value, global_step=epoch) 472 | 473 | 474 | # save fixed angle generated images 475 | if (epoch+1) % opt.sample_interval == 0: 476 | save_image(gen_imgs_list[2], os.path.join(opt.output_dir, "%06d_debug.png" % epoch), 477 | nrow=1, normalize=True, value_range=(-1, 1)) 478 | save_image(gen_imgs_list[0], os.path.join(opt.output_dir, "%06d_debug_reg0.png" % epoch), 479 | nrow=1, normalize=True, value_range=(-1, 1)) 480 | save_image(gen_imgs_list[1], os.path.join(opt.output_dir, "%06d_debug_reg1.png" % epoch), 481 | nrow=1, normalize=True, value_range=(-1, 1)) 482 | 483 | ## save model 484 | if (epoch+1) % opt.model_save_interval == 0: 485 | torch.save(ema.state_dict(), os.path.join(opt.output_dir, 'step%06d_ema.pth' % epoch)) 486 | # torch.save(ema2.state_dict(), os.path.join(opt.output_dir, 'step%06d_ema2.pth' % dif_net.step)) 487 | torch.save(generator_ddp.module.state_dict(), 488 | os.path.join(opt.output_dir, 'step%06d_generator.pth' % epoch)) 489 | # save_model 490 | if rank == 0: 491 | torch.save(ema.state_dict(), os.path.join(opt.output_dir, 'ema.pth')) 492 | torch.save(ema2.state_dict(), os.path.join(opt.output_dir, 'ema2.pth')) 493 | torch.save(generator_ddp.module.state_dict(), os.path.join(opt.output_dir, 'generator.pth')) 494 | torch.save(optimizer_G.state_dict(), os.path.join(opt.output_dir, 'optimizer_G.pth')) 495 | torch.save(scaler.state_dict(), os.path.join(opt.output_dir, 'scaler.pth')) 496 | torch.save(generator_losses, os.path.join(opt.output_dir, 'generator.losses')) 497 | -------------------------------------------------------------------------------- /generators/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/generators/.DS_Store -------------------------------------------------------------------------------- /generators/generators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import random 5 | 6 | from .representations.gram import * 7 | 8 | from .renderers.manifold_renderer import * 9 | 10 | 11 | def sample_camera_positions(device, n=1, r=1, horizontal_stddev=1, vertical_stddev=1, horizontal_mean=math.pi*0.5, vertical_mean=math.pi*0.5, mode='normal'): 12 | """Samples n random locations along a sphere of radius r. Uses a gaussian distribution for pitch and yaw""" 13 | if mode == 'uniform': 14 | theta = (torch.rand((n, 1), device=device) - 0.5) * 2 * horizontal_stddev + horizontal_mean 15 | phi = (torch.rand((n, 1), device=device) - 0.5) * 2 * vertical_stddev + vertical_mean 16 | elif mode == 'normal' or mode == 'gaussian': 17 | theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean 18 | phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean 19 | elif mode == 'spherical_uniform': 20 | theta = (torch.rand((n, 1), device=device) - .5) * 2 * horizontal_stddev + horizontal_mean 21 | v_stddev, v_mean = vertical_stddev / math.pi, vertical_mean / math.pi # convert from radians to [0,1] 22 | v = ((torch.rand((n,1), device=device) - .5) * 2 * v_stddev + v_mean) 23 | v = torch.clamp(v, 1e-5, 1 - 1e-5) 24 | phi = torch.arccos(1 - 2 * v) 25 | else: 26 | theta = torch.ones((n, 1), device=device, dtype=torch.float) * horizontal_mean 27 | phi = torch.ones((n, 1), device=device, dtype=torch.float) * vertical_mean 28 | 29 | phi = torch.clamp(phi, 1e-5, math.pi - 1e-5) 30 | 31 | camera_origin = torch.zeros((n, 3), device=device)# torch.cuda.FloatTensor(n, 3).fill_(0)#torch.zeros((n, 3)) 32 | 33 | camera_origin[:, 0:1] = r*torch.sin(phi) * torch.cos(theta) 34 | camera_origin[:, 2:3] = r*torch.sin(phi) * torch.sin(theta) 35 | camera_origin[:, 1:2] = r*torch.cos(phi) 36 | 37 | return camera_origin, torch.cat([phi, theta], dim=-1) 38 | 39 | 40 | def get_camera_origins(camera_pos, r=1): 41 | n = camera_pos.shape[0] 42 | device = camera_pos.device 43 | phi = camera_pos[:, 0] 44 | theta = camera_pos[:, 1] 45 | camera_origin = torch.zeros((n, 3), device=device)# torch.cuda.FloatTensor(n, 3).fill_(0)#torch.zeros((n, 3)) 46 | 47 | camera_origin[:, 0:1] = r*torch.sin(phi) * torch.cos(theta) 48 | camera_origin[:, 2:3] = r*torch.sin(phi) * torch.sin(theta) 49 | camera_origin[:, 1:2] = r*torch.cos(phi) 50 | 51 | return camera_origin 52 | 53 | 54 | class Generator(torch.nn.Module): 55 | def __init__(self) -> None: 56 | super().__init__() 57 | self.epoch = 0 58 | self.step = 0 59 | 60 | 61 | class GramGenerator(Generator): 62 | def __init__(self, z_dim, img_size, h_stddev ,v_stddev, h_mean, v_mean, sample_dist, representation_kwargs, renderer_kwargs, partial_grad=False): 63 | super().__init__() 64 | self.z_dim = z_dim 65 | self.img_size = img_size 66 | self.h_stddev = h_stddev 67 | self.v_stddev = v_stddev 68 | self.h_mean = h_mean 69 | self.v_mean = v_mean 70 | self.sample_dist = sample_dist 71 | self.partial_grad = partial_grad 72 | self.representation = Gram(z_dim, **representation_kwargs) 73 | self.renderer = ManifoldRenderer(**renderer_kwargs) 74 | 75 | def _volume(self, z, 
truncation_psi=1): 76 | return lambda points, ray_directions: self.representation.get_radiance(z, points, ray_directions, truncation_psi) 77 | 78 | def _volume_with_frequencies_phase_shifts(self, freq, phase): 79 | return lambda points, ray_directions: self.representation.get_radiance_with_frequencies_phase_shifts(freq, phase, points, ray_directions) 80 | 81 | def _intersections(self, points, levels): 82 | return self.representation.get_intersections(points, levels) 83 | 84 | def get_avg_w(self): 85 | self.representation.get_avg_w() 86 | 87 | def forward_with_frequencies_phase_shifts(self, freq, phase, fov, ray_start, ray_end, img_size=None, camera_origin=None, camera_pos=None, patch=None): 88 | if camera_origin is None and camera_pos is None: 89 | camera_origin, camera_pos = sample_camera_positions(freq.device, freq.shape[0], 1, self.h_stddev, self.v_stddev, self.h_mean, self.v_mean, self.sample_dist) 90 | elif camera_origin is not None: 91 | camera_origin = torch.tensor(camera_origin, dtype=torch.float32, device=freq.device).reshape(1, 3).expand(freq.shape[0], 3) 92 | else: 93 | camera_origin = get_camera_origins(camera_pos) 94 | if img_size is None: 95 | img_size = self.img_size 96 | if patch is None: 97 | img, _ = self.renderer.render(self._intersections, self._volume_with_frequencies_phase_shifts(freq, phase), img_size, camera_origin, camera_pos, fov, ray_start, ray_end, freq.device, partial_grad = self.partial_grad) 98 | else: 99 | img = self.renderer.render_patch(self._intersections, self._volume_with_frequencies_phase_shifts(freq, phase), img_size, camera_origin, camera_pos, fov, ray_start, ray_end, freq.device, patch, partial_grad = self.partial_grad) 100 | return img, camera_pos 101 | 102 | def forward(self, z, fov, ray_start, ray_end, img_size=None, camera_origin=None, camera_pos=None, truncation_psi=1, patch=None, detailed_output=False): 103 | if camera_origin is None and camera_pos is None: 104 | camera_origin, camera_pos = sample_camera_positions(z.device, z.shape[0], 1, self.h_stddev, self.v_stddev, self.h_mean, self.v_mean, self.sample_dist) 105 | elif camera_origin is not None: 106 | camera_origin = torch.tensor(camera_origin, dtype=torch.float32, device=z.device).reshape(1, 3).expand(z.shape[0], 3) 107 | else: 108 | camera_origin = get_camera_origins(camera_pos) 109 | if img_size is None: 110 | img_size = self.img_size 111 | if patch is None: 112 | img, details = self.renderer.render(self._intersections, self._volume(z, truncation_psi), img_size, camera_origin, camera_pos, fov, ray_start, ray_end, z.device, detailed_output=detailed_output) 113 | else: 114 | img = self.renderer.render_patch(self._intersections, self._volume(z, truncation_psi), img_size, camera_origin, camera_pos, fov, ray_start, ray_end, z.device, patch) 115 | if detailed_output: 116 | return img, details, camera_pos 117 | else: 118 | return img, camera_pos 119 | 120 | 121 | 122 | @torch.no_grad() 123 | def experiment(self, z, fov, ray_start, ray_end, img_size=None, camera_origin=None, camera_pos=None, truncation_psi=1): 124 | if camera_origin is None and camera_pos is None: 125 | camera_origin, camera_pos = sample_camera_positions(z.device, z.shape[0], 1, self.h_stddev, self.v_stddev, self.h_mean, self.v_mean, self.sample_dist) 126 | elif camera_origin is not None: 127 | camera_origin = torch.tensor(camera_origin, dtype=torch.float32, device=z.device).reshape(1, 3).expand(z.shape[0], 3) 128 | else: 129 | camera_origin = get_camera_origins(camera_pos) 130 | if img_size is None: 131 | img_size = 
self.img_size 132 | img, exp_data = self.renderer.render(self._intersections, self._volume(z, truncation_psi), img_size, camera_origin, camera_pos, fov, ray_start, ray_end, z.device, detailed_output=True) 133 | return img, camera_pos, exp_data 134 | -------------------------------------------------------------------------------- /generators/renderers/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/generators/renderers/.DS_Store -------------------------------------------------------------------------------- /generators/renderers/manifold_renderer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from .math_utils_torch import * 7 | 8 | 9 | def fancy_integration(rgb_sigma, z_vals, device, is_valid, bg_pos, last_back=False, white_back=False, delta_alpha=0.04, delta_final=1e10): 10 | rgbs = rgb_sigma[..., :3] 11 | sigmas = rgb_sigma[..., 3:] 12 | 13 | deltas = torch.ones_like(z_vals[:, :, 1:] - z_vals[:, :, :-1])*delta_alpha 14 | delta_inf = delta_final * torch.ones_like(deltas[:, :, :1]) 15 | deltas = torch.cat([deltas, delta_inf], -2) # [batch,N_rays,num_manifolds,1] 16 | 17 | bg_pos = F.one_hot(bg_pos.squeeze(-1),num_classes=deltas.shape[-2]).to(torch.bool) # [batch,N_rays,num_manifolds] 18 | bg_pos = bg_pos.unsqueeze(-1) # [batch,N_rays,num_manifolds,1] 19 | deltas[bg_pos] = delta_final 20 | 21 | alphas = 1-torch.exp(-deltas * sigmas) 22 | alphas = alphas*is_valid 23 | alphas_shifted = torch.cat([torch.ones_like(alphas[:, :, :1]), 1-alphas + 1e-10], -2) 24 | T = torch.cumprod(alphas_shifted, -2)[:, :, :-1] 25 | weights = alphas * T 26 | weights_sum = weights.sum(2) 27 | 28 | if last_back: 29 | weights[:, :, -1] += (1 - weights_sum) 30 | 31 | rgb_final = torch.sum(weights * rgbs, -2) 32 | depth_final = torch.sum(weights * z_vals, -2)/weights_sum 33 | 34 | if white_back: 35 | rgb_final = rgb_final + 1-weights_sum 36 | 37 | return rgb_final, depth_final, weights, T 38 | 39 | 40 | def get_initial_rays_trig(n, num_samples, device, fov, resolution, ray_start, ray_end, randomize=True, patch_range=None): 41 | """Returns sample points, z_vals, ray directions in camera space.""" 42 | W, H = resolution 43 | # Create full screen NDC (-1 to +1) coords [x, y, 0, 1]. 44 | # Y is flipped to follow image memory layouts. 
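# The block below builds one camera-space ray per output pixel:
#   - `x`/`y` form a dense grid over NDC [-1, 1]; transposing and flattening yields per-pixel
#     coordinates in scanline order (when `patch_range` is given, only that contiguous slice of
#     pixels is kept, which is how render_patch splits the image).
#   - `z = -1 / tan(fov_rad / 2)` places every pixel on the image plane of a pinhole camera
#     looking down the -z axis (fov is supplied in degrees and converted to radians).
#   - normalize_vecs([x, y, z]) gives unit ray directions, and `z_vals` spaces num_samples depths
#     uniformly between ray_start and ray_end along each ray to form the initial sample points.
# Note that the `randomize` branch at the end calls perturb_points() without assigning its return
# values, so the jitter is discarded; ManifoldRenderer.render() passes randomize=False anyway.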
45 | x, y = torch.meshgrid(torch.linspace(-1, 1, W, device=device), 46 | torch.linspace(1, -1, H, device=device)) 47 | x = x.T.flatten() 48 | y = y.T.flatten() 49 | if patch_range is not None: 50 | x = x[patch_range[0]:patch_range[1]] 51 | y = y[patch_range[0]:patch_range[1]] 52 | z = -torch.ones_like(x, device=device) / np.tan((2 * math.pi * fov / 360)/2) 53 | 54 | rays_d_cam = normalize_vecs(torch.stack([x, y, z], -1)) 55 | 56 | z_vals = torch.linspace(ray_start, ray_end, num_samples, device=device).reshape(1, num_samples, 1).repeat(rays_d_cam.shape[0], 1, 1) 57 | points = rays_d_cam.unsqueeze(1).repeat(1, num_samples, 1) * z_vals 58 | 59 | points = torch.stack(n*[points]) 60 | z_vals = torch.stack(n*[z_vals]) 61 | rays_d_cam = torch.stack(n*[rays_d_cam]).to(device) 62 | 63 | if randomize: 64 | perturb_points(points, z_vals, rays_d_cam, device) 65 | 66 | return points, z_vals, rays_d_cam 67 | 68 | 69 | def perturb_points(points, z_vals, ray_directions, device): 70 | distance_between_points = z_vals[:,:,1:2,:] - z_vals[:,:,0:1,:] 71 | offset = (torch.rand(z_vals.shape, device=device)-0.5) * distance_between_points 72 | z_vals = z_vals + offset 73 | 74 | points = points + offset * ray_directions.unsqueeze(2) 75 | return points, z_vals 76 | 77 | 78 | def get_intersection_with_MPI(transformed_ray_directions,transformed_ray_origins,device, mpi_start=0.12,mpi_end=-0.12,mpi_num=24): 79 | mpi_z_vals = torch.linspace(mpi_start, mpi_end, mpi_num, device=device) 80 | z_vals = mpi_z_vals.view(1,1,mpi_num) - transformed_ray_origins[...,-1:] #[batch,N,mpi_num] 81 | z_vals = z_vals/transformed_ray_directions[...,-1:] #[batch,N,mpi_num] 82 | z_vals = z_vals.unsqueeze(-1) 83 | points = transformed_ray_origins.unsqueeze(2) + transformed_ray_directions.unsqueeze(2)*z_vals 84 | 85 | return points, z_vals 86 | 87 | 88 | def transform_sampled_points(points, ray_directions, camera_origin, camera_pos, device): 89 | n, num_rays, num_samples, channels = points.shape 90 | forward_vector = normalize_vecs(-camera_origin) 91 | 92 | cam2world_matrix = create_cam2world_matrix(forward_vector, camera_origin, device=device) 93 | 94 | points_homogeneous = torch.ones((points.shape[0], points.shape[1], points.shape[2], points.shape[3] + 1), device=device) 95 | points_homogeneous[:, :, :, :3] = points 96 | 97 | # should be n x 4 x 4 , n x r^2 x num_samples x 4 98 | transformed_points = torch.bmm(cam2world_matrix, points_homogeneous.reshape(n, -1, 4).permute(0,2,1)).permute(0, 2, 1).reshape(n, num_rays, num_samples, 4) 99 | transformed_ray_directions = torch.bmm(cam2world_matrix[..., :3, :3], ray_directions.reshape(n, -1, 3).permute(0,2,1)).permute(0, 2, 1).reshape(n, num_rays, 3) 100 | 101 | homogeneous_origins = torch.zeros((n, 4, num_rays), device=device) 102 | homogeneous_origins[:, 3, :] = 1 103 | transformed_ray_origins = torch.bmm(cam2world_matrix, homogeneous_origins).permute(0, 2, 1).reshape(n, num_rays, 4)[..., :3] 104 | 105 | return transformed_points[..., :3], transformed_ray_directions, transformed_ray_origins, camera_pos 106 | 107 | 108 | def create_cam2world_matrix(forward_vector, origin, device=None): 109 | """Takes in the direction the camera is pointing and the camera origin and returns a world2cam matrix.""" 110 | forward_vector = normalize_vecs(forward_vector) 111 | up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=device).expand_as(forward_vector) 112 | 113 | left_vector = normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1)) 114 | 115 | up_vector = 
normalize_vecs(torch.cross(forward_vector, left_vector, dim=-1)) 116 | 117 | rotation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1) 118 | rotation_matrix[:, :3, :3] = torch.stack((-left_vector, up_vector, -forward_vector), axis=-1) 119 | 120 | translation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1) 121 | translation_matrix[:, :3, 3] = origin 122 | 123 | cam2world = translation_matrix @ rotation_matrix 124 | 125 | return cam2world 126 | 127 | 128 | class ManifoldRenderer: 129 | def __init__(self, num_manifolds, levels_start, levels_end, num_samples, last_back=False, white_back=False, background=True, delta_alpha=0.04, delta_final=1e10, lock_view_dependence=False) -> None: 130 | self.num_manifolds = num_manifolds 131 | self.levels_start = levels_start 132 | self.levels_end = levels_end 133 | self.num_samples = num_samples 134 | self.last_back = last_back 135 | self.white_back = white_back 136 | self.background = background 137 | self.delta_alpha = delta_alpha 138 | self.delta_final = delta_final if background else delta_alpha 139 | self.lock_view_dependence = lock_view_dependence 140 | 141 | def render(self, intersection, volume, img_size, camera_origin, camera_pos, fov, ray_start, ray_end, device, detailed_output=False, partial_grad=False): 142 | batchsize = camera_origin.shape[0] 143 | 144 | with torch.no_grad(): 145 | points_cam, z_vals, rays_d_cam = get_initial_rays_trig(batchsize, self.num_samples, resolution=(img_size, img_size), device=device, fov=fov, ray_start=ray_start, ray_end=ray_end, randomize=False) # batchsize, pixels, num_manifolds, 1 146 | transformed_points_sample, transformed_ray_directions, transformed_ray_origins, _ = transform_sampled_points(points_cam, rays_d_cam, camera_origin, camera_pos, device=device) 147 | transformed_points_sample = transformed_points_sample.reshape(batchsize, img_size*img_size, -1, 3) 148 | if self.background: 149 | levels = torch.linspace(self.levels_start, self.levels_end, self.num_manifolds-1).to(device) 150 | transformed_points_back, _ = get_intersection_with_MPI(transformed_ray_directions,transformed_ray_origins,device=device,mpi_start=-0.12,mpi_end=-0.12,mpi_num=1) 151 | else: 152 | levels = torch.linspace(self.levels_start, self.levels_end, self.num_manifolds).to(device) 153 | 154 | if not partial_grad: 155 | transformed_points,_,is_valid = intersection(transformed_points_sample, levels) # [batch,H*W,num_manifolds,3] 156 | else: 157 | with torch.no_grad(): 158 | transformed_points,_,is_valid = intersection(transformed_points_sample, levels) # [batch,H*W,num_manifolds,3] 159 | 160 | if self.background: 161 | transformed_points = torch.cat([transformed_points,transformed_points_back],dim=-2) 162 | is_valid = torch.cat([is_valid,torch.ones(is_valid.shape[0],is_valid.shape[1],1,is_valid.shape[-1]).to(is_valid.device)],dim=-2) 163 | 164 | with torch.no_grad(): 165 | z_vals = torch.sqrt(torch.sum((transformed_points - transformed_ray_origins.unsqueeze(2))**2,dim=-1,keepdim=True)) # [batch,H*W,num_manifolds,1] 166 | 167 | transformed_ray_directions_expanded = torch.unsqueeze(transformed_ray_directions, -2) 168 | transformed_ray_directions_expanded = transformed_ray_directions_expanded.expand(-1, -1, self.num_manifolds, -1) 169 | 170 | if self.lock_view_dependence: 171 | transformed_ray_directions_expanded = torch.zeros_like(transformed_ray_directions_expanded) 172 | transformed_ray_directions_expanded[..., -1] = -1 173 | 174 | 175 | if not partial_grad: 176 | 
transformed_ray_directions_expanded = transformed_ray_directions_expanded.reshape(batchsize, img_size*img_size*self.num_manifolds, 3) 177 | transformed_points = transformed_points.reshape(batchsize, img_size*img_size*self.num_manifolds, 3) 178 | coarse_output = volume(transformed_points, transformed_ray_directions_expanded).reshape(batchsize, img_size * img_size, self.num_manifolds, 4) 179 | else: 180 | indices_pts = torch.randperm(img_size*img_size) 181 | indices_with_grad = indices_pts[:img_size*img_size//4] 182 | indices_wo_grad = indices_pts[img_size*img_size//4:] 183 | 184 | transformed_ray_directions_expanded_with_grad = transformed_ray_directions_expanded[:,indices_with_grad] 185 | transformed_ray_directions_expanded_wo_grad = transformed_ray_directions_expanded[:,indices_wo_grad] 186 | 187 | transformed_points_with_grad = transformed_points[:,indices_with_grad] 188 | transformed_points_wo_grad = transformed_points[:,indices_wo_grad] 189 | transformed_points_with_grad = transformed_points_with_grad.reshape(batchsize, -1, 3) 190 | transformed_ray_directions_expanded_with_grad = transformed_ray_directions_expanded_with_grad.reshape(batchsize, -1, 3) 191 | 192 | coarse_output_with_grad = volume(transformed_points_with_grad, transformed_ray_directions_expanded_with_grad).reshape(batchsize, -1, self.num_manifolds, 4) 193 | with torch.no_grad(): 194 | transformed_points_wo_grad = transformed_points_wo_grad.reshape(batchsize, -1, 3) 195 | transformed_ray_directions_expanded_wo_grad = transformed_ray_directions_expanded_wo_grad.reshape(batchsize, -1, 3) 196 | coarse_output_wo_grad = volume(transformed_points_wo_grad, transformed_ray_directions_expanded_wo_grad).reshape(batchsize, -1, self.num_manifolds, 4) 197 | 198 | coarse_output = torch.zeros(batchsize, img_size*img_size, self.num_manifolds, 4).to(coarse_output_with_grad.device).to(coarse_output_with_grad.dtype) 199 | coarse_output[:,indices_with_grad] = coarse_output_with_grad 200 | coarse_output[:,indices_wo_grad] = coarse_output_wo_grad 201 | 202 | all_points = transformed_points.reshape(batchsize, img_size * img_size, self.num_manifolds, 3) 203 | all_outputs = coarse_output 204 | all_z_vals = z_vals 205 | 206 | _, indices = torch.sort(all_z_vals, dim=-2) 207 | all_z_vals = torch.gather(all_z_vals, -2, indices) 208 | all_outputs = torch.gather(all_outputs, -2, indices.expand(-1, -1, -1, 4)) 209 | all_points = torch.gather(all_points, -2, indices.expand(-1, -1, -1, 3)) 210 | is_valid = torch.gather(is_valid,-2,indices) 211 | 212 | bg_pos = torch.argmax(indices,dim=-2) 213 | 214 | pixels, depth, weights, T = fancy_integration(all_outputs, all_z_vals, is_valid=is_valid, bg_pos=bg_pos, device=device, white_back=self.white_back, last_back=self.last_back, delta_final=self.delta_final, delta_alpha=self.delta_alpha) 215 | 216 | pixels = pixels.reshape((batchsize, img_size, img_size, 3)) 217 | pixels = pixels.permute(0, 3, 1, 2).contiguous() * 2 - 1 218 | 219 | if detailed_output: 220 | detail = { 221 | 'points': all_points.reshape(batchsize, img_size, img_size, -1, 3), 222 | 'outputs': all_outputs.reshape(batchsize, img_size, img_size, -1, 4), 223 | 'z_vals': all_z_vals.reshape(batchsize, img_size, img_size, -1, 1), 224 | 'depth': depth.reshape(batchsize, img_size, img_size, -1, 1), 225 | 'weights': weights.reshape(batchsize, img_size, img_size, -1, 1), 226 | } 227 | else: 228 | detail = None 229 | 230 | return pixels, detail 231 | 232 | def render_patch(self, intersection, volume, img_size, camera_origin, camera_pos, fov, ray_start, ray_end, 
device, patch, partial_grad=False): 233 | batchsize = camera_origin.shape[0] 234 | patch_idx = patch[0] 235 | patch_num = patch[1] 236 | patch_range = (img_size*img_size*patch_idx//patch_num, img_size*img_size*(patch_idx+1)//patch_num) 237 | patch_len = patch_range[1] - patch_range[0] 238 | 239 | with torch.no_grad(): 240 | points_cam, z_vals, rays_d_cam = get_initial_rays_trig(batchsize, self.num_samples, resolution=(img_size, img_size), device=device, fov=fov, ray_start=ray_start, ray_end=ray_end, randomize=False, patch_range=patch_range) # batchsize, pixels, num_manifolds, 1 241 | transformed_points_sample, transformed_ray_directions, transformed_ray_origins, _ = transform_sampled_points(points_cam, rays_d_cam, camera_origin, camera_pos, device=device) 242 | transformed_points_sample = transformed_points_sample.reshape(batchsize, patch_len, -1, 3) 243 | if self.background: 244 | levels = torch.linspace(self.levels_start, self.levels_end, self.num_manifolds-1).to(device) 245 | transformed_points_back, _ = get_intersection_with_MPI(transformed_ray_directions,transformed_ray_origins,device=device,mpi_start=-0.12,mpi_end=-0.12,mpi_num=1) 246 | else: 247 | levels = torch.linspace(self.levels_start, self.levels_end, self.num_manifolds).to(device) 248 | 249 | if not partial_grad: 250 | transformed_points,_,is_valid = intersection(transformed_points_sample, levels) # [batch,H*W,num_manifolds,3] 251 | else: 252 | with torch.no_grad(): 253 | transformed_points,_,is_valid = intersection(transformed_points_sample, levels) # [batch,H*W,num_manifolds,3] 254 | 255 | if self.background: 256 | transformed_points = torch.cat([transformed_points,transformed_points_back],dim=-2) 257 | is_valid = torch.cat([is_valid,torch.ones(is_valid.shape[0],is_valid.shape[1],1,is_valid.shape[-1]).to(is_valid.device)],dim=-2) 258 | 259 | with torch.no_grad(): 260 | z_vals = torch.sqrt(torch.sum((transformed_points - transformed_ray_origins.unsqueeze(2))**2,dim=-1,keepdim=True)) # [batch,patch_len,num_manifolds,1] 261 | 262 | transformed_ray_directions_expanded = torch.unsqueeze(transformed_ray_directions, -2) 263 | transformed_ray_directions_expanded = transformed_ray_directions_expanded.expand(-1, -1, self.num_manifolds, -1) 264 | 265 | if self.lock_view_dependence: 266 | transformed_ray_directions_expanded = torch.zeros_like(transformed_ray_directions_expanded) 267 | transformed_ray_directions_expanded[..., -1] = -1 268 | 269 | if not partial_grad: 270 | transformed_ray_directions_expanded = transformed_ray_directions_expanded.reshape(batchsize, patch_len*self.num_manifolds, 3) 271 | transformed_points = transformed_points.reshape(batchsize, patch_len*self.num_manifolds, 3) 272 | coarse_output = volume(transformed_points, transformed_ray_directions_expanded).reshape(batchsize, patch_len, self.num_manifolds, 4) 273 | else: 274 | indices_pts = torch.randperm(patch_len) 275 | indices_with_grad = indices_pts[:patch_len//4] 276 | indices_wo_grad = indices_pts[patch_len//4:] 277 | 278 | transformed_ray_directions_expanded_with_grad = transformed_ray_directions_expanded[:,indices_with_grad] 279 | transformed_ray_directions_expanded_wo_grad = transformed_ray_directions_expanded[:,indices_wo_grad] 280 | 281 | transformed_points_with_grad = transformed_points[:,indices_with_grad] 282 | transformed_points_wo_grad = transformed_points[:,indices_wo_grad] 283 | transformed_points_with_grad = transformed_points_with_grad.reshape(batchsize, -1, 3) 284 | transformed_ray_directions_expanded_with_grad = 
transformed_ray_directions_expanded_with_grad.reshape(batchsize, -1, 3) 285 | 286 | coarse_output_with_grad = volume(transformed_points_with_grad, transformed_ray_directions_expanded_with_grad).reshape(batchsize, -1, self.num_manifolds, 4) 287 | with torch.no_grad(): 288 | transformed_points_wo_grad = transformed_points_wo_grad.reshape(batchsize, -1, 3) 289 | transformed_ray_directions_expanded_wo_grad = transformed_ray_directions_expanded_wo_grad.reshape(batchsize, -1, 3) 290 | coarse_output_wo_grad = volume(transformed_points_wo_grad, transformed_ray_directions_expanded_wo_grad).reshape(batchsize, -1, self.num_manifolds, 4) 291 | 292 | coarse_output = torch.zeros(batchsize, patch_len, self.num_manifolds, 4).to(coarse_output_with_grad.device).to(coarse_output_with_grad.dtype) 293 | coarse_output[:,indices_with_grad] = coarse_output_with_grad 294 | coarse_output[:,indices_wo_grad] = coarse_output_wo_grad 295 | 296 | all_points = transformed_points.reshape(batchsize, patch_len, self.num_manifolds, 3) 297 | all_outputs = coarse_output 298 | all_z_vals = z_vals 299 | 300 | _, indices = torch.sort(all_z_vals, dim=-2) 301 | all_z_vals = torch.gather(all_z_vals, -2, indices) 302 | all_outputs = torch.gather(all_outputs, -2, indices.expand(-1, -1, -1, 4)) 303 | all_points = torch.gather(all_points, -2, indices.expand(-1, -1, -1, 3)) 304 | is_valid = torch.gather(is_valid,-2,indices) 305 | 306 | bg_pos = torch.argmax(indices,dim=-2) 307 | 308 | pixels, depth, weights, T = fancy_integration(all_outputs, all_z_vals, is_valid=is_valid, bg_pos=bg_pos, device=device, white_back=self.white_back, last_back=self.last_back, delta_final=self.delta_final, delta_alpha=self.delta_alpha) 309 | 310 | pixels = pixels.reshape((batchsize, patch_len, 3)).permute(0, 2, 1) * 2 - 1 311 | 312 | return pixels 313 | -------------------------------------------------------------------------------- /generators/renderers/math_utils_torch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for geometry etc. 3 | """ 4 | 5 | import torch 6 | 7 | 8 | def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: 9 | """ 10 | Left-multiplies MxM @ NxM. Returns NxM. 11 | """ 12 | res = torch.matmul(vectors4, matrix.T) 13 | return res 14 | 15 | 16 | def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: 17 | """ 18 | Normalize vector lengths. 19 | """ 20 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) 21 | 22 | def torch_dot(x: torch.Tensor, y: torch.Tensor): 23 | """ 24 | Dot product of two tensors. 
25 | """ 26 | return (x * y).sum(-1) 27 | -------------------------------------------------------------------------------- /generators/representations/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/generators/representations/.DS_Store -------------------------------------------------------------------------------- /generators/representations/gram.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def frequency_init(freq): 8 | def init(m): 9 | with torch.no_grad(): 10 | # if hasattr(m, 'weight'): 11 | if isinstance(m, nn.Linear): 12 | num_input = m.weight.size(-1) 13 | m.weight.uniform_(-np.sqrt(6 / num_input) / freq, np.sqrt(6 / num_input) / freq) 14 | return init 15 | 16 | 17 | def first_layer_film_sine_init(m): 18 | with torch.no_grad(): 19 | # if hasattr(m, 'weight'): 20 | if isinstance(m, nn.Linear): 21 | num_input = m.weight.size(-1) 22 | m.weight.uniform_(-1 / num_input, 1 / num_input) 23 | 24 | 25 | def kaiming_leaky_init(m): 26 | classname = m.__class__.__name__ 27 | if classname.find('Linear') != -1: 28 | torch.nn.init.kaiming_normal_(m.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu') 29 | 30 | 31 | def geometry_init(m): 32 | with torch.no_grad(): 33 | # if hasattr(m, 'weight'): 34 | if isinstance(m, nn.Linear): 35 | num_output = m.weight.size(0) 36 | m.weight.normal_(0,np.sqrt(2/num_output)) 37 | nn.init.constant_(m.bias,0) 38 | 39 | 40 | def geometry_init_last_layer(radius): 41 | def init(m): 42 | with torch.no_grad(): 43 | # if hasattr(m, 'weight'): 44 | if isinstance(m, nn.Linear): 45 | num_input = m.weight.size(-1) 46 | nn.init.constant_(m.weight,10*np.sqrt(np.pi/num_input)) 47 | nn.init.constant_(m.bias,-radius) 48 | return init 49 | 50 | 51 | class FiLMLayer(nn.Module): 52 | def __init__(self, input_dim, hidden_dim, activation=torch.sin): 53 | super().__init__() 54 | self.layer = nn.Linear(input_dim, hidden_dim) 55 | self.activation = activation 56 | 57 | def forward(self, x, freq, phase_shift): 58 | x = self.layer(x) 59 | freq = freq.unsqueeze(1).expand_as(x) 60 | phase_shift = phase_shift.unsqueeze(1).expand_as(x) 61 | return self.activation(freq * x + phase_shift) 62 | 63 | def statistic(self, x, freq, phase_shift): 64 | x = self.layer(x) 65 | freq = freq.unsqueeze(1).expand_as(x) 66 | phase_shift = phase_shift.unsqueeze(1).expand_as(x) 67 | return self.activation(freq * x + phase_shift) 68 | 69 | 70 | class CustomMappingNetwork(nn.Module): 71 | def __init__(self, z_dim, map_hidden_dim, map_output_dim): 72 | super().__init__() 73 | 74 | self.network = nn.Sequential( 75 | nn.Linear(z_dim, map_hidden_dim), 76 | nn.LeakyReLU(0.2, inplace=True), 77 | 78 | nn.Linear(map_hidden_dim, map_hidden_dim), 79 | nn.LeakyReLU(0.2, inplace=True), 80 | 81 | nn.Linear(map_hidden_dim, map_hidden_dim), 82 | nn.LeakyReLU(0.2, inplace=True), 83 | 84 | nn.Linear(map_hidden_dim, map_output_dim) 85 | ) 86 | 87 | self.network.apply(kaiming_leaky_init) 88 | with torch.no_grad(): 89 | self.network[-1].weight *= 0.25 90 | 91 | def forward(self, z): 92 | frequencies_offsets = self.network(z) 93 | frequencies = frequencies_offsets[..., :frequencies_offsets.shape[-1]//2] 94 | phase_shifts = frequencies_offsets[..., frequencies_offsets.shape[-1]//2:] 95 | 96 | return frequencies, phase_shifts 97 | 98 | 99 | class 
UniformBoxWarp(nn.Module): 100 | def __init__(self, sidelength): 101 | super().__init__() 102 | self.scale_factor = 2/sidelength 103 | 104 | def forward(self, coordinates): 105 | return coordinates * self.scale_factor 106 | 107 | 108 | class GramSample(nn.Module): 109 | def __init__(self, hidden_dim_sample=64, layer_num_sample=3, center=(0,0,0), init_radius=0): 110 | super().__init__() 111 | self.hidden_dim = hidden_dim_sample 112 | self.layer_num = layer_num_sample 113 | 114 | self.network = [nn.Linear(3, self.hidden_dim), nn.ReLU(inplace=True)] 115 | for _ in range(self.layer_num - 1): 116 | self.network += [nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(inplace=True)] 117 | 118 | self.network = nn.Sequential(*self.network) 119 | 120 | self.output_layer = nn.Linear(self.hidden_dim, 1) 121 | 122 | self.network.apply(geometry_init) 123 | self.output_layer.apply(geometry_init_last_layer(init_radius)) 124 | self.center = torch.tensor(center) 125 | 126 | self.gridwarper = UniformBoxWarp(0.24) # Don't worry about this, it was added to ensure compatibility with another model. Shouldn't affect performance. 127 | 128 | def calculate_intersection(self,intervals,vals,levels): 129 | intersections = [] 130 | is_valid = [] 131 | for interval,val,l in zip(intervals,vals,levels): 132 | x_l = interval[:,:,0] 133 | x_h = interval[:,:,1] 134 | s_l = val[:,:,0] 135 | s_h = val[:,:,1] 136 | scale = torch.where(torch.abs(s_h-s_l) > 0.05,s_h-s_l,torch.ones_like(s_h)*0.05) 137 | intersect = torch.where(((s_h-l<=0)*(l-s_l<=0)) & (torch.abs(s_h-s_l) > 0.05),((s_h-l)*x_l + (l-s_l)*x_h)/scale,x_h) 138 | intersections.append(intersect) 139 | is_valid.append(((s_h-l<=0)*(l-s_l<=0)).to(intersect.dtype)) 140 | 141 | return torch.stack(intersections,dim=-2),torch.stack(is_valid,dim=-2) #[batch,N_rays,level,3] 142 | 143 | def forward(self,input): 144 | x = input 145 | x = self.gridwarper(x) 146 | x = x - self.center.to(x.device) 147 | x = self.network(x) 148 | s = self.output_layer(x) 149 | 150 | return s 151 | 152 | def get_intersections(self, input, levels, **kwargs): 153 | # levels num_l 154 | batch,N_rays,N_points,_ = input.shape 155 | 156 | x = input.reshape(batch,-1,3) 157 | x = self.gridwarper(x) 158 | 159 | x = x - self.center.to(x.device) 160 | 161 | x = self.network(x) 162 | s = self.output_layer(x) 163 | 164 | s = s.reshape(batch,N_rays,N_points,1) 165 | s_l = s[:,:,:-1] 166 | s_h = s[:,:,1:] 167 | 168 | cost = torch.linspace(N_points-1,0,N_points-1).float().to(input.device).reshape(1,1,-1,1) 169 | x_interval = [] 170 | s_interval = [] 171 | for l in levels: 172 | r = (s_h-l <= 0) * (l-s_l <= 0) * 2 - 1 173 | r = r*cost 174 | _, indices = torch.max(r,dim=-2,keepdim=True) 175 | x_l_select = torch.gather(input,-2,indices.expand(-1, -1, -1, 3)) # [batch,N_rays,1] 176 | x_h_select = torch.gather(input,-2,indices.expand(-1, -1, -1, 3)+1) # [batch,N_rays,1] 177 | s_l_select = torch.gather(s_l,-2,indices) 178 | s_h_select = torch.gather(s_h,-2,indices) 179 | x_interval.append(torch.cat([x_l_select,x_h_select],dim=-2)) 180 | s_interval.append(torch.cat([s_l_select,s_h_select],dim=-2)) 181 | 182 | intersections,is_valid = self.calculate_intersection(x_interval,s_interval,levels) 183 | 184 | return intersections,s,is_valid 185 | 186 | 187 | class GramRF(nn.Module): 188 | def __init__(self, z_dim=100, hidden_dim=256, normalize=0.24, sigma_clamp_mode='softplus', rgb_clamp_mode='widen_sigmoid'): 189 | super().__init__() 190 | self.z_dim = z_dim 191 | self.hidden_dim = hidden_dim 192 | self.sigma_clamp_mode = 
sigma_clamp_mode 193 | self.rgb_clamp_mode = rgb_clamp_mode 194 | self.avg_frequencies = None 195 | self.avg_phase_shifts = None 196 | 197 | self.network = nn.ModuleList([ 198 | FiLMLayer(3, hidden_dim), 199 | FiLMLayer(hidden_dim, hidden_dim), 200 | FiLMLayer(hidden_dim, hidden_dim), 201 | FiLMLayer(hidden_dim, hidden_dim), 202 | FiLMLayer(hidden_dim, hidden_dim), 203 | FiLMLayer(hidden_dim, hidden_dim), 204 | FiLMLayer(hidden_dim, hidden_dim), 205 | FiLMLayer(hidden_dim, hidden_dim), 206 | ]) 207 | 208 | self.color_layer = nn.ModuleList([FiLMLayer(hidden_dim + 3, hidden_dim)]) 209 | 210 | self.output_sigma = nn.ModuleList([ 211 | nn.Linear(hidden_dim, 1), 212 | nn.Linear(hidden_dim, 1), 213 | nn.Linear(hidden_dim, 1), 214 | nn.Linear(hidden_dim, 1), 215 | nn.Linear(hidden_dim, 1), 216 | nn.Linear(hidden_dim, 1), 217 | nn.Linear(hidden_dim, 1), 218 | ]) 219 | 220 | self.output_color = nn.ModuleList([ 221 | nn.Linear(hidden_dim, 3), 222 | nn.Linear(hidden_dim, 3), 223 | nn.Linear(hidden_dim, 3), 224 | nn.Linear(hidden_dim, 3), 225 | nn.Linear(hidden_dim, 3), 226 | nn.Linear(hidden_dim, 3), 227 | nn.Linear(hidden_dim, 3), 228 | ]) 229 | 230 | self.mapping_network = CustomMappingNetwork(z_dim, 256, (len(self.network) + len(self.color_layer))*hidden_dim*2) 231 | 232 | self.network.apply(frequency_init(25)) 233 | self.output_sigma.apply(frequency_init(25)) 234 | self.color_layer.apply(frequency_init(25)) 235 | self.output_color.apply(frequency_init(25)) 236 | self.network[0].apply(first_layer_film_sine_init) 237 | 238 | self.gridwarper = UniformBoxWarp(normalize) # Don't worry about this, it was added to ensure compatibility with another model. Shouldn't affect performance. 239 | 240 | def get_avg_w(self): 241 | z = torch.randn((10000, self.z_dim), device=next(self.parameters()).device) 242 | with torch.no_grad(): 243 | frequencies, phase_shifts = self.mapping_network(z) 244 | self.avg_frequencies = frequencies.mean(0, keepdim=True) 245 | self.avg_phase_shifts = phase_shifts.mean(0, keepdim=True) 246 | return self.avg_frequencies, self.avg_phase_shifts 247 | 248 | def forward(self, input, z, ray_directions, truncation_psi=1): 249 | frequencies, phase_shifts = self.mapping_network(z) 250 | if truncation_psi < 1: 251 | frequencies = self.avg_frequencies.lerp(frequencies, truncation_psi) 252 | phase_shifts = self.avg_phase_shifts.lerp(phase_shifts, truncation_psi) 253 | return self.forward_with_frequencies_phase_shifts(input, frequencies, phase_shifts, ray_directions) 254 | 255 | def forward_with_frequencies_phase_shifts(self, input, frequencies, phase_shifts, ray_directions, eps=1e-3): 256 | frequencies = frequencies*15 + 30 257 | 258 | input = self.gridwarper(input) 259 | x = input 260 | sigma = 0 261 | rgb = 0 262 | 263 | for index, layer in enumerate(self.network): 264 | start = index * self.hidden_dim 265 | end = (index+1) * self.hidden_dim 266 | x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end]) 267 | if index > 0: 268 | layer_sigma = self.output_sigma[index-1](x) 269 | if not index == 7: 270 | layer_rgb_feature = x 271 | else: 272 | layer_rgb_feature = self.color_layer[0](torch.cat([ray_directions, x], dim=-1),\ 273 | frequencies[..., len(self.network)*self.hidden_dim:(len(self.network)+1)*self.hidden_dim], phase_shifts[..., len(self.network)*self.hidden_dim:(len(self.network)+1)*self.hidden_dim]) 274 | layer_rgb = self.output_color[index-1](layer_rgb_feature) 275 | 276 | sigma += layer_sigma 277 | rgb += layer_rgb 278 | 279 | if self.rgb_clamp_mode == 'sigmoid': 280 | 
rgb = torch.sigmoid(rgb) 281 | elif self.rgb_clamp_mode == 'widen_sigmoid': 282 | rgb = torch.sigmoid(rgb)*(1+2*eps) - eps 283 | 284 | if self.sigma_clamp_mode == 'relu': 285 | sigma = F.relu(sigma) 286 | elif self.sigma_clamp_mode == 'softplus': 287 | sigma = F.softplus(sigma) 288 | 289 | return torch.cat([rgb, sigma], dim=-1) 290 | 291 | 292 | class Gram(nn.Module): 293 | def __init__(self, z_dim=256, hidden_dim=256, normalize=0.24, sigma_clamp_mode='softplus', rgb_clamp_mode='widen_sigmoid', **sample_network_kwargs): 294 | super().__init__() 295 | self.sample_network = GramSample(**sample_network_kwargs) 296 | self.rf_network = GramRF(z_dim, hidden_dim, normalize, sigma_clamp_mode, rgb_clamp_mode) 297 | 298 | def get_avg_w(self): 299 | self.rf_network.get_avg_w() 300 | 301 | def get_intersections(self, points, levels): 302 | return self.sample_network.get_intersections(points, levels) 303 | 304 | def get_radiance(self, z, x, ray_directions, truncation_psi=1): 305 | return self.rf_network(x, z, ray_directions, truncation_psi) 306 | 307 | def get_radiance_with_frequencies_phase_shifts(self, frequencies, phase_shifts, x, ray_directions): 308 | return self.rf_network.forward_with_frequencies_phase_shifts(x, frequencies, phase_shifts, ray_directions) 309 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob, shutil 3 | import torch 4 | 5 | from torchvision.utils import save_image 6 | from tqdm import tqdm 7 | import copy 8 | import argparse 9 | 10 | # import pytorch3d 11 | 12 | from generators import generators 13 | import configs 14 | import math 15 | 16 | import time 17 | from PIL import Image 18 | import torchvision.transforms as transforms 19 | import dnnlib 20 | import numpy as np 21 | from scipy.io import loadmat 22 | import torch.nn.functional as F 23 | import torch.nn as nn 24 | import importlib 25 | from torch_ema import ExponentialMovingAverage 26 | from utils.arcface import get_model 27 | 28 | 29 | class IDLoss(nn.Module): 30 | def __init__(self, facenet): 31 | super(IDLoss, self).__init__() 32 | self.facenet = facenet 33 | 34 | def forward(self,x,y): 35 | x = F.interpolate(x,size=[112,112],mode='bilinear') 36 | y = F.interpolate(y,size=[112,112],mode='bilinear') 37 | 38 | # x = 2*(x-0.5) 39 | # y = 2*(y-0.5) 40 | feat_x = self.facenet(x) 41 | feat_y = self.facenet(y.detach()) 42 | 43 | loss = 1 - F.cosine_similarity(feat_x,feat_y,dim=-1) 44 | 45 | return loss 46 | 47 | def read_pose(name,flip=False): 48 | P = loadmat(name)['angle'] 49 | P_x = -(P[0,0] - 0.1) + math.pi/2 50 | if not flip: 51 | P_y = P[0,1] + math.pi/2 52 | else: 53 | P_y = -P[0,1] + math.pi/2 54 | 55 | P = torch.tensor([P_x,P_y],dtype=torch.float32) 56 | 57 | return P 58 | 59 | def read_pose_npy(name,flip=False): 60 | P = np.load(name) 61 | P_x = P[0] + 0.14 62 | if not flip: 63 | P_y = P[1] 64 | else: 65 | P_y = -P[1] + math.pi 66 | 67 | P = torch.tensor([P_x,P_y],dtype=torch.float32) 68 | 69 | return P 70 | 71 | 72 | def transform_matrix_to_camera_pos(c2w,flip=False): 73 | """ 74 | Get camera position with transform matrix 75 | 76 | :param c2w: camera to world transform matrix 77 | :return: camera position on spherical coord 78 | """ 79 | 80 | c2w[[0,1,2]] = c2w[[1,2,0]] 81 | pos = c2w[:, -1].squeeze() 82 | radius = float(np.linalg.norm(pos)) 83 | theta = float(np.arctan2(-pos[0], pos[2])) 84 | phi = float(np.arctan(-pos[1] / np.linalg.norm(pos[::2]))) 85 | 
theta = theta + np.pi * 0.5 86 | phi = phi + np.pi * 0.5 87 | if flip: 88 | theta = -theta + math.pi 89 | P = torch.tensor([phi,theta],dtype=torch.float32) 90 | return P 91 | 92 | 93 | def load_models(opt, config, device): 94 | print("loading models...") 95 | generator_args = {} 96 | if 'representation' in config['generator']: 97 | generator_args['representation_kwargs'] = config['generator']['representation']['kwargs'] 98 | if 'renderer' in config['generator']: 99 | generator_args['renderer_kwargs'] = config['generator']['renderer']['kwargs'] 100 | generator = getattr(generators, config['generator']['class'])( 101 | **generator_args, 102 | **config['generator']['kwargs'] 103 | ) 104 | 105 | generator.load_state_dict(torch.load(os.path.join(opt.generator_file), map_location='cpu'),strict=False) 106 | generator = generator.to('cuda') 107 | generator.eval() 108 | 109 | ema = torch.load(os.path.join(opt.generator_file.replace('generator', 'ema')), map_location='cuda') 110 | parameters = [p for p in generator.parameters() if p.requires_grad] 111 | ema.copy_to(parameters) 112 | 113 | #for LPIPS loss 114 | if opt.config == 'FACES_default': 115 | url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' 116 | with dnnlib.util.open_url(url) as f: 117 | vgg16 = torch.jit.load(f).eval().to(device) 118 | elif opt.config == 'CATS_default': # CATS, CARLA 119 | import lpips 120 | vgg16 = lpips.LPIPS(net='vgg').eval().to(device) # closer to "traditional" perceptual loss, when used for optimization 121 | else: 122 | raise 123 | 124 | face_recog = get_model('r50', fp16=False) 125 | face_recog.load_state_dict(torch.load('pretrained_models/arcface.pth')) 126 | face_recog.eval() 127 | 128 | return generator, vgg16, face_recog 129 | 130 | if __name__ == '__main__': 131 | parser = argparse.ArgumentParser() 132 | parser.add_argument('--generator_file', type=str, default='pretrained_models/gram/FACES_default/generator.pth') 133 | parser.add_argument('--output_dir', type=str, default='experiments/gram/inversion') 134 | parser.add_argument('--data_img_dir', type=str, default='samples/faces/') 135 | parser.add_argument('--data_pose_dir', type=str, default='samples/faces/poses/') 136 | parser.add_argument('--name', type=str, default=None, help="specifc image name (e.g. 
'28606.png'), or None (will invert all images)") 137 | parser.add_argument('--config', type=str, default='FACES_default') 138 | parser.add_argument('--ema', action='store_true') 139 | parser.add_argument('--max_batch_size', type=int, default=None) 140 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') 141 | parser.add_argument('--psi', type=str, default=0.7) 142 | parser.add_argument('--lambda_perceptual', type=float, default=1) 143 | parser.add_argument('--lambda_l2', type=float, default=0.01) 144 | parser.add_argument('--lambda_id', type=float, default=0.01) 145 | parser.add_argument('--lambda_reg', type=float, default=0.04) 146 | 147 | parser.add_argument('--start_iter', type=int, default=2000) 148 | parser.add_argument('--max_iter', type=int, default=1000) 149 | parser.add_argument('--sv_interval', type=int, default=50) 150 | parser.add_argument('--vis_loss', type=bool, default=False) 151 | 152 | opt = parser.parse_args() 153 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 154 | config = getattr(configs, opt.config) 155 | 156 | ## load models 157 | generator, vgg16, face_recog = load_models(opt, config, device) 158 | generator.renderer.lock_view_dependence = True 159 | 160 | ## load data 161 | img_size = config['global']['img_size'] 162 | transform = transforms.Compose([transforms.Resize((img_size, img_size), interpolation=1), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) 163 | # search all images 164 | img_fullpaths_all = [] 165 | if opt.name: 166 | name = opt.name 167 | img_fullpath = os.path.join(opt.data_img_dir, f"{opt.name}.png") 168 | img_fullpaths_all.append(img_fullpath) 169 | else: 170 | img_fullpaths_all = sorted(glob.glob(os.path.join(opt.data_img_dir, f"*.png"))) 171 | img_fullpaths = [] 172 | for imgpath in img_fullpaths_all: 173 | subject = imgpath.split('/')[-1].split('.')[0] 174 | inv_path = os.path.join(opt.output_dir, subject, f"{(opt.max_iter-1):05d}_.txt") 175 | if not os.path.exists(inv_path): 176 | img_fullpaths.append(imgpath) 177 | print 178 | else: 179 | print(f"Ignoring {subject}...") 180 | 181 | ## start optimization 182 | for img_fullpath in img_fullpaths: 183 | # load image and mat file 184 | print(f"Processing {img_fullpath}...") 185 | img = Image.open(img_fullpath) 186 | img = transform(img).cuda() 187 | img = img.unsqueeze(0) 188 | name = img_fullpath.split("/")[-1][:-4] 189 | if opt.config.find('FACES') >= 0: 190 | mat_fullpath = os.path.join(opt.data_pose_dir, f"{name.split('.')[0]}.mat") 191 | pose = read_pose(mat_fullpath) 192 | elif opt.config.find('CATS') >= 0: # CATS 193 | mat_fullpath = os.path.join(opt.data_pose_dir, f"{name.split('.')[0]}_pose.npy") 194 | pose = read_pose_npy(mat_fullpath) 195 | else: 196 | raise 197 | # load camera pose 198 | generator.h_mean = pose[1] 199 | generator.v_mean = pose[0] 200 | generator.h_stddev = generator.v_stddev = 0 201 | 202 | # set output_dir 203 | output_dir = os.path.join(opt.output_dir, f"{name}") 204 | os.makedirs(output_dir, exist_ok=True) 205 | f = open(os.path.join(output_dir, 'logs.txt'), "w") 206 | f.write(str(opt)) 207 | f.write('\n\n') 208 | f.write(str(config)) 209 | f.write('\n\n') 210 | 211 | load_prev_file = os.path.join(output_dir, '%05d_%s.txt' % (opt.start_iter-1, opt.suffix)) 212 | patch_split = None 213 | with torch.cuda.amp.autocast(): 214 | generator.get_avg_w() 215 | if not os.path.exists(load_prev_file): 216 | start_iter = 0 217 | # initialize z 
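# When no checkpoint from an earlier run exists, the latent code is initialized from a standard
# Gaussian of shape (1, 256), matching GRAM's z_dim, and optimized directly; otherwise the .txt
# file written at iteration opt.start_iter - 1 is reloaded so optimization can resume. The latent
# code is the only parameter handed to the Adam optimizer below (lr=1e-1), and each step runs
# under torch.cuda.amp autocast with a GradScaler.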
218 | init_z_noise = torch.randn((1, 256), device=device) 219 | latent_code = init_z_noise.detach().clone() 220 | latent_code.requires_grad = True 221 | latent_code = latent_code.to(device) 222 | else: 223 | start_iter = opt.start_iter 224 | latents = np.loadtxt(load_prev_file) 225 | latent_code = torch.from_numpy(latents).float().unsqueeze(0).to(device) 226 | latent_code.requires_grad = True 227 | optimizer = torch.optim.Adam([latent_code], lr=1e-1) # z 228 | 229 | 230 | scaler = torch.cuda.amp.GradScaler() 231 | scaler._init_scale = 32 232 | 233 | id_loss = IDLoss(face_recog.eval()).cuda() 234 | save_image(img.detach().cpu(), os.path.join(output_dir, 'input.png'), normalize=True, range=(-1, 1)) 235 | for i in tqdm(range(start_iter, opt.max_iter)): 236 | loss = 0 237 | if patch_split is None: 238 | with torch.cuda.amp.autocast(): 239 | gen_img = generator(latent_code, **config['camera'], truncation_psi=opt.psi)[0] 240 | 241 | img_size = img.size(-1) 242 | if opt.lambda_l2 > 0: 243 | l2 = torch.mean((gen_img-img)**2) * opt.lambda_l2 244 | 245 | gen_img_d2 = F.upsample(gen_img, size=(img_size//2,img_size//2), mode='bilinear') 246 | img_d2 = F.upsample(img, size=(img_size//2,img_size//2), mode='bilinear') 247 | l2 += torch.mean((gen_img_d2-img_d2)**2) * opt.lambda_l2 248 | 249 | gen_img_d4 = F.upsample(gen_img, size=(img_size//4,img_size//4), mode='bilinear') 250 | img_d4 = F.upsample(img, size=(img_size//4,img_size//4), mode='bilinear') 251 | l2 += torch.mean((gen_img_d4-img_d4)**2) * opt.lambda_l2 252 | l2 = l2 / 3.0 253 | 254 | loss += l2 255 | if opt.lambda_perceptual > 0: 256 | if opt.config == 'FACES_default': 257 | gen_features = vgg16(127.5*(gen_img+1), resize_images=False, return_lpips=True) 258 | real_features = vgg16(127.5*(img+1), resize_images=False, return_lpips=True) 259 | perceptual_loss = ((1000*gen_features-1000*real_features)**2).mean() * opt.lambda_perceptual 260 | 261 | gen_features_d2 = vgg16(127.5*(gen_img_d2+1), resize_images=False, return_lpips=True) 262 | real_features_d2 = vgg16(127.5*(img_d2+1), resize_images=False, return_lpips=True) 263 | perceptual_loss += ((1000*gen_features_d2-1000*real_features_d2)**2).mean() * opt.lambda_perceptual 264 | 265 | gen_features_d4 = vgg16(127.5*(gen_img_d4+1), resize_images=False, return_lpips=True) 266 | real_features_d4 = vgg16(127.5*(img_d4+1), resize_images=False, return_lpips=True) 267 | perceptual_loss += ((1000*gen_features_d4-1000*real_features_d4)**2).mean() * opt.lambda_perceptual 268 | 269 | perceptual_loss = perceptual_loss / 3.0 270 | elif opt.config == 'CATS_default': 271 | perceptual_loss = vgg16(gen_img, img).mean() * opt.lambda_perceptual 272 | perceptual_loss += vgg16(gen_img_d2, img_d2).mean() * opt.lambda_perceptual 273 | perceptual_loss += vgg16(gen_img_d4, img_d4).mean() * opt.lambda_perceptual 274 | perceptual_loss = perceptual_loss / 3.0 275 | loss += perceptual_loss 276 | if opt.lambda_id > 0: 277 | id_l = id_loss(gen_img,img).mean() * opt.lambda_id 278 | loss += id_l 279 | scaler.scale(loss).backward() 280 | else: 281 | with torch.cuda.amp.autocast(): 282 | gen_img = [] 283 | with torch.no_grad(): 284 | for patch_idx in range(patch_split): 285 | gen_imgs_patch = generator(latent_code, **config['camera'], truncation_psi=opt.psi, patch=(patch_idx, patch_split))[0] 286 | gen_img.append(gen_imgs_patch) 287 | gen_img = torch.cat(gen_img,-1).reshape(1,3,generator.img_size,generator.img_size) 288 | gen_img.requires_grad = True 289 | 290 | if opt.lambda_l2 > 0: 291 | l2 = torch.mean((gen_img-img)**2) * 
opt.lambda_l2 292 | 293 | gen_img_d2 = F.upsample(gen_img, size=(img_size//2,img_size//2), mode='bilinear') 294 | img_d2 = F.upsample(img, size=(img_size//2,img_size//2), mode='bilinear') 295 | l2 += torch.mean((gen_img_d2-img_d2)**2) * opt.lambda_l2 296 | 297 | gen_img_d4 = F.upsample(gen_img, size=(img_size//4,img_size//4), mode='bilinear') 298 | img_d4 = F.upsample(img, size=(img_size//4,img_size//4), mode='bilinear') 299 | l2 += torch.mean((gen_img_d4-img_d4)**2) * opt.lambda_l2 300 | l2 = l2 / 3.0 301 | 302 | loss += l2 303 | if opt.lambda_perceptual > 0: 304 | if opt.config == 'FACES_default': 305 | gen_features = vgg16(127.5*(gen_img+1), resize_images=False, return_lpips=True) 306 | real_features = vgg16(127.5*(img+1), resize_images=False, return_lpips=True) 307 | perceptual_loss = ((1000*gen_features-1000*real_features)**2).mean() * opt.lambda_perceptual 308 | 309 | gen_features_d2 = vgg16(127.5*(gen_img_d2+1), resize_images=False, return_lpips=True) 310 | real_features_d2 = vgg16(127.5*(img_d2+1), resize_images=False, return_lpips=True) 311 | perceptual_loss += ((1000*gen_features_d2-1000*real_features_d2)**2).mean() * opt.lambda_perceptual 312 | 313 | gen_features_d4 = vgg16(127.5*(gen_img_d4+1), resize_images=False, return_lpips=True) 314 | real_features_d4 = vgg16(127.5*(img_d4+1), resize_images=False, return_lpips=True) 315 | perceptual_loss += ((1000*gen_features_d4-1000*real_features_d4)**2).mean() * opt.lambda_perceptual 316 | 317 | perceptual_loss = perceptual_loss / 3.0 318 | elif opt.config == 'CATS_default': 319 | perceptual_loss = vgg16(gen_img, img).mean() * opt.lambda_perceptual 320 | perceptual_loss += vgg16(gen_img_d2, img_d2).mean() * opt.lambda_perceptual 321 | perceptual_loss += vgg16(gen_img_d4, img_d4).mean() * opt.lambda_perceptual 322 | perceptual_loss = perceptual_loss / 3.0 323 | loss += perceptual_loss 324 | if opt.lambda_id > 0: 325 | id_l = id_loss(gen_img,img).mean() * opt.lambda_id 326 | loss += id_l 327 | 328 | grad_gen_imgs = torch.autograd.grad(outputs=scaler.scale(loss), inputs=gen_img, create_graph=False)[0] 329 | grad_gen_imgs = grad_gen_imgs.reshape(1,3,-1) 330 | grad_gen_imgs = grad_gen_imgs.detach() 331 | 332 | for patch_idx in range(patch_split): 333 | with torch.cuda.amp.autocast(): 334 | gen_imgs_patch = generator(latent_code, **config['camera'], truncation_psi=opt.psi, patch=(patch_idx, patch_split))[0] 335 | 336 | start = generator.img_size*generator.img_size*patch_idx//patch_split 337 | end = generator.img_size*generator.img_size*(patch_idx+1)//patch_split 338 | gen_imgs_patch.backward(grad_gen_imgs[...,start:end]) 339 | 340 | 341 | scaler.unscale_(optimizer) 342 | nn.utils.clip_grad_norm_(latent_code, config['optimizer'].get('grad_clip', 0.3)) 343 | scaler.step(optimizer) 344 | scaler.update() 345 | optimizer.zero_grad() 346 | 347 | out_img = gen_img.clone().detach().cpu() 348 | if i ==0: 349 | save_image(out_img, os.path.join(output_dir, 'init.png'), normalize=True, range=(-1, 1)) 350 | 351 | l_2 = l2.detach().cpu().numpy() if opt.lambda_l2 else 0 352 | lpips = perceptual_loss.detach().cpu().numpy() if opt.lambda_perceptual else 0 353 | l_id = id_l.detach().cpu().numpy() if opt.lambda_id else 0 354 | 355 | if opt.vis_loss: 356 | print(f"LPIPS: {lpips}; id_loss: {l_id}; l2: {l_2};") 357 | 358 | f.write(f"Iter {i}: ") 359 | f.write(f"LPIPS: {lpips}; id_loss: {l_id}; l2: {l_2};") 360 | f.write('\n\n') 361 | 362 | # debug 363 | # if i == 0: 364 | # save_image(out_img, os.path.join(output_dir, '%05d_%s.png'%(i, opt.suffix)), 
normalize=True, range=(-1, 1)) 365 | # import ipdb; ipdb.set_trace() 366 | 367 | if i % opt.sv_interval == 0 and i > 0: 368 | save_image(out_img, os.path.join(output_dir, '%05d_%s.png'%(i, opt.suffix)), normalize=True, range=(-1, 1)) 369 | lat = latent_code.detach().cpu().numpy() 370 | np.savetxt(os.path.join(output_dir, '%05d_%s.txt' % (i, opt.suffix)), lat) 371 | 372 | f.write(f"Save output to {os.path.join(output_dir, '%05d_%s.png' % (i, opt.suffix))}") 373 | f.close() 374 | save_image(out_img, os.path.join(output_dir, '%05d_%s.png' %(i, opt.suffix)), normalize=True, range=(-1, 1)) 375 | lat = latent_code.detach().cpu().numpy() 376 | np.savetxt(os.path.join(output_dir, '%05d_%s.txt' % (i, opt.suffix)), lat) 377 | 378 | -------------------------------------------------------------------------------- /preprocess_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import glob 5 | from matplotlib import pyplot as plt 6 | from scipy.spatial.transform import Rotation 7 | from tqdm import tqdm 8 | import argparse 9 | import shutil 10 | 11 | 12 | # 3D landmarks of a template cat head 13 | cat_lm3D = np.array([ 14 | [-4.893227, 0.255504, 3.936153], 15 | [4.893227, 0.255504, 3.936153], 16 | [0.000000, -5.859148, 8.948051], 17 | [-11.579516, 3.353250, -6.676847], 18 | [-12.895623, 15.929962, -4.881758], 19 | [-5.203006, 9.292290, -3.132928], 20 | [5.203006, 9.292290, -3.132928], 21 | [12.895623, 15.929962, -4.881758], 22 | [11.579516, 3.353250, -6.676847], 23 | ]) 24 | 25 | # 3D landmarks of a template human face 26 | face_lm3D = np.array([ 27 | [-0.31148657, 0.09036078, 0.13377953], 28 | [ 0.30979887, 0.08972035, 0.13179526], 29 | [ 0.0032535 , -0.24617933, 0.55244243], 30 | [-0.25216928, -0.5813392 , 0.22405732], 31 | [ 0.2484662 , -0.5812824 , 0.22235769], 32 | ]) 33 | 34 | 35 | # calculating least squres problem between 3D landmarks and 2D landmarks for image alignment 36 | def POS(xp,x,cate=None): 37 | npts = xp.shape[0] 38 | 39 | A = np.zeros([2*npts,8]) 40 | A[0:2*npts-1:2,0:3] = x 41 | A[0:2*npts-1:2,3] = 1 42 | A[1:2*npts:2,4:7] = x 43 | A[1:2*npts:2,7] = 1 44 | b = np.reshape(xp,[2*npts,1]) 45 | 46 | if cate=='cats': 47 | weight = np.array([[4]] * 4 + [[2]] * 2 + [[0.5]] * 2 + [[0.2]] * 2 + [[0.2]] * 2 + [[0.2]] * 2 + [[0.2]] * 2 + [[0.5]] * 2) # set different importances for different landmarks 48 | else: 49 | weight = 1 50 | 51 | A = A * weight 52 | b = b * weight 53 | 54 | k,_,_,_ = np.linalg.lstsq(A,b) 55 | 56 | R1 = k[0:3].squeeze() 57 | R2 = k[4:7].squeeze() 58 | sTx = k[3] 59 | sTy = k[7] 60 | 61 | cz = np.cross(R1, R2) 62 | y = np.array([0, 1, 0]) 63 | cx = np.cross(y, cz) 64 | cy = np.cross(cz, cx) 65 | cx = cx / np.linalg.norm(cx) 66 | cy = cy / np.linalg.norm(cy) 67 | cz = cz / np.linalg.norm(cz) 68 | 69 | yaw = np.arctan2(-cz[0], cz[2]) + 0.5 * np.pi 70 | pitch = np.arctan(-cz[1] / np.linalg.norm(cz[::2])) + 0.5 * np.pi 71 | roll1 = (np.sign(np.dot(cz, np.cross(cx, R1))) * np.arccos(np.dot(R1, cx) / np.linalg.norm(R1)) + np.sign(np.dot(cz, np.cross(cy, R2))) * np.arccos(np.dot(R2, cy) / np.linalg.norm(R2))) / 2 72 | roll2 = np.arctan2(-xp[1, 1] + xp[0, 1], xp[1, 0] - xp[0, 0]) 73 | roll = roll2 + np.sign(roll1 - roll2) * np.log(np.abs(roll1 - roll2)/np.pi*180)*np.pi/180 74 | 75 | if cate=='cats': 76 | scale = 0.75 * np.linalg.norm(R1) + 0.25 * np.linalg.norm(R2) # for cats, we try to ensure the head scales along x-axis are similar for different subjects 77 | else: 78 | 
scale = 0.5 * np.linalg.norm(R1) + 0.5 * np.linalg.norm(R2) 79 | 80 | translate = np.stack([sTx, sTy],axis = 0) 81 | 82 | return yaw, pitch, roll, translate, scale 83 | 84 | 85 | def align_img_ffhq(img,pos,target_size=256): 86 | _, _, _, translate, scale = pos 87 | w0,h0 = img.size 88 | scale = scale/target_size*224 89 | 90 | w = (w0/scale*95).astype(np.int32) 91 | h = (h0/scale*95).astype(np.int32) 92 | img = img.resize((w,h),resample = Image.LANCZOS) 93 | 94 | left = (w/2 - target_size/2 + float((translate[0] - w0/2)*95/scale)).astype(np.int32) 95 | right = left + target_size 96 | up = (h/2 - target_size/2 + float((h0/2 - translate[1])*95/scale)).astype(np.int32) 97 | below = up + target_size 98 | 99 | padding_len = max([abs(min(0,left)),abs(min(0,up)),max(right-w,0),max(below-h,0)]) 100 | if padding_len > 0: 101 | img = np.array(img) 102 | img = np.pad(img,pad_width=((padding_len,padding_len),(padding_len,padding_len),(0,0)),mode='reflect') 103 | img = Image.fromarray(img) 104 | 105 | crop_img = img.crop((left+padding_len,up+padding_len,right+padding_len,below+padding_len)) 106 | 107 | return crop_img 108 | 109 | 110 | def align_img_cats(img, pos, target_size=256): 111 | _, _, roll, translate, scale = pos 112 | img = np.array(img) 113 | translate[1] = img.shape[0] - translate[1] 114 | cos_ = np.cos(roll) 115 | sin_ = np.sin(roll) 116 | rotate = np.array([[cos_, -sin_], [sin_, cos_]]) 117 | crop = 15 * scale * np.array([[1, 1, -1, -1, 1], [1, -1, -1, 1, 1]]) 118 | crop = rotate @ crop + translate.reshape((2, 1)) 119 | padding = int(15 * scale) 120 | translate = translate + padding 121 | img = np.pad(img, ((padding, padding), (padding, padding), (0, 0)), 'constant') 122 | crop_img = Image.fromarray(img) 123 | 124 | # we eliminate roll angles for cat heads 125 | crop_img = crop_img\ 126 | .rotate(roll/np.pi*180, resample=Image.BICUBIC, center=(translate[0], translate[1]))\ 127 | .resize((target_size, target_size), Image.LANCZOS, box=(translate[0] - 15 * scale, translate[1] - 15 * scale, translate[0] + 15 * scale, translate[1] + 15 * scale)) 128 | 129 | return crop_img 130 | 131 | def preprocess_ffhq(img_path,lm_path,save_path,target_size=256,cate=None): 132 | img_name = img_path.split('/')[-1] 133 | print(os.path.join(img_path.replace(img_name,''),'poses',img_name.replace('png','mat'))) 134 | print(os.path.join(save_path,'poses',img_name.replace('png','mat'))) 135 | # shutil.copy(os.path.join(img_path.replace(img_name,''),'poses',img_name.replace('png','mat')),os.path.join(save_path,'poses',img_name.replace('png','mat'))) 136 | 137 | img, lm = load_data_ffhq(img_path,lm_path) 138 | pos = POS(lm,face_lm3D,cate=cate) 139 | crop_img = align_img_ffhq(img,pos,target_size=target_size) 140 | crop_img.save(os.path.join(save_path,img_name)) 141 | 142 | def preprocess_cats(img_path,lm_path,save_path,target_size=256,cate=None): 143 | img_name = img_path.split('/')[-1] 144 | 145 | img, lm = load_data_cats(img_path,lm_path) 146 | pos = POS(lm,cat_lm3D,cate=cate) 147 | crop_img = align_img_cats(img,pos,target_size=target_size) 148 | crop_img.save(os.path.join(save_path,img_name.replace('jpg','png'))) 149 | yaw, pitch, _, _, _ = pos 150 | np.save(os.path.join(save_path,'poses',img_name.replace('.jpg','_pose.npy')), np.array([float(pitch), float(yaw)])) 151 | 152 | def preprocess_carla(img_path,lm_path,save_path,target_size=128,cate=None): 153 | img_name = img_path.split('/')[-1] 154 | 155 | 
shutil.copy(os.path.join(img_path.replace(img_name,''),'poses',img_name.replace('.png','_extrinsics.npy')),os.path.join(save_path,'poses',img_name.replace('.png','_extrinsics.npy'))) 156 | img = Image.open(img_path) 157 | img = img.resize((target_size,target_size),resample = Image.LANCZOS) 158 | img.save(os.path.join(save_path,img_name)) 159 | 160 | def load_data_ffhq(img_path,lm_path): 161 | img = Image.open(img_path) 162 | 163 | lm = np.loadtxt(lm_path) 164 | lm[:,1] = img.size[1] - 1 - lm[:,1] #flip y-axis for detected landmarks 165 | 166 | return img, lm 167 | 168 | def load_data_cats(img_path,lm_path): 169 | img = Image.open(img_path) 170 | 171 | with open(lm_path) as lm_file: 172 | lm = lm_file.read() 173 | lm = lm.split()[1:] 174 | lm = np.array([float(i) for i in lm]) 175 | lm = lm.reshape((-1, 2)) 176 | lm[:,1] = img.size[1] - lm[:,1] #flip y-axis for provided landmarks 177 | 178 | return img, lm 179 | 180 | def preprocess_data(raw_dataset_path, cate='ffhq'): 181 | 182 | if cate == 'ffhq': 183 | all_img_path = sorted(glob.glob(os.path.join(raw_dataset_path,'*.png'))) 184 | all_lm_path = [os.path.join(raw_dataset_path,'detections',f.split('/')[-1].replace('png','txt')) for f in all_img_path] 185 | preprocess_func = preprocess_ffhq 186 | elif cate == 'cats': 187 | all_img_path = sorted(glob.glob(os.path.join(raw_dataset_path,'*.jpg'))) 188 | all_lm_path = [f+'.cat' for f in all_img_path] 189 | preprocess_func = preprocess_cats 190 | elif cate == 'carla': 191 | all_img_path = sorted(glob.glob(os.path.join(raw_dataset_path,'*.png'))) 192 | all_lm_path = all_img_path 193 | preprocess_func = preprocess_carla 194 | else: 195 | raise Exception("Invalid dataset type") 196 | 197 | print('Number of images: %d'%len(all_img_path)) 198 | 199 | save_path = os.path.join('datasets',cate) 200 | os.makedirs(save_path, exist_ok=True) 201 | os.makedirs(os.path.join(save_path,'poses'), exist_ok=True) 202 | 203 | print("all_img_path", all_img_path) 204 | print("all_lm_path", all_lm_path) 205 | print("save_path", save_path) 206 | for img_path, lm_path in tqdm(zip(all_img_path,all_lm_path)): 207 | try: 208 | preprocess_func(img_path,lm_path,save_path,cate=cate) # skip a raw image if it does not have corresponding landmarks or poses 209 | except: 210 | print('skip invalid data...') 211 | continue 212 | 213 | 214 | if __name__ == '__main__': 215 | parser = argparse.ArgumentParser() 216 | parser.add_argument('--raw_dataset_path', type=str, default='./raw_data/ffhq', help='raw dataset path') 217 | parser.add_argument('--cate', type=str, default='ffhq', help='dataset type [ffhq | cats | carla]') 218 | opt = parser.parse_args() 219 | 220 | preprocess_data(opt.raw_dataset_path,cate=opt.cate) 221 | -------------------------------------------------------------------------------- /rendering_using_finetuned_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from matplotlib.pyplot import prism 3 | import numpy as np 4 | import math 5 | from collections import deque 6 | 7 | from yaml import parse 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torchvision.utils import save_image 12 | import torchvision.transforms as transforms 13 | import importlib 14 | import time 15 | import glob, shutil 16 | from scipy.io import loadmat 17 | import copy 18 | from generators import generators 19 | import configs 20 | 21 | from tqdm import tqdm 22 | from torch_ema import ExponentialMovingAverage 23 | from torch.utils.tensorboard import 
SummaryWriter 24 | 25 | import argparse 26 | from PIL import Image 27 | import skvideo 28 | skvideo.setFFmpegPath("/usr/bin/") 29 | # import skvideo.io 30 | from skvideo.io import FFmpegWriter 31 | import PIL.ImageDraw as ImageDraw 32 | 33 | import plyfile 34 | import mrcfile 35 | import skimage.measure 36 | 37 | 38 | def load_models(opt, config, device): 39 | generator_args = {} 40 | if 'representation' in config['generator']: 41 | generator_args['representation_kwargs'] = config['generator']['representation']['kwargs'] 42 | if 'renderer' in config['generator']: 43 | generator_args['renderer_kwargs'] = config['generator']['renderer']['kwargs'] 44 | generator = getattr(generators, config['generator']['class'])( 45 | **generator_args, 46 | **config['generator']['kwargs'] 47 | ) 48 | print(opt.generator_file) 49 | generator.load_state_dict(torch.load(os.path.join(opt.generator_file), map_location='cpu'),strict=False) 50 | generator = generator.to('cuda') 51 | generator.eval() 52 | 53 | try: 54 | ema = torch.load(os.path.join(opt.generator_file.replace('generator', 'ema')), map_location='cpu') 55 | parameters = [p for p in generator.parameters() if p.requires_grad] 56 | ema.copy_to(parameters) 57 | except: 58 | pass 59 | 60 | 61 | return generator 62 | 63 | def read_pose_ori(name,flip=False): 64 | P = loadmat(name)['angle'] 65 | P_x = -(P[0,0] - 0.1) + math.pi/2 66 | if not flip: 67 | P_y = P[0,1] + math.pi/2 68 | else: 69 | P_y = -P[0,1] + math.pi/2 70 | 71 | 72 | P = torch.tensor([P_x,P_y],dtype=torch.float32) 73 | 74 | return P 75 | 76 | def read_pose_npy(name,flip=False): 77 | P = np.load(name) 78 | P_x = P[0] + 0.14 79 | if not flip: 80 | P_y = P[1] 81 | else: 82 | P_y = -P[1] + math.pi 83 | 84 | P = torch.tensor([P_x,P_y],dtype=torch.float32) 85 | 86 | return P 87 | 88 | def read_latents_txt_z(name, device="cpu"): 89 | ''' 90 | the data structure of z inversion 91 | ''' 92 | latents = np.loadtxt(name) 93 | latents = torch.from_numpy(latents).float().unsqueeze(0).to(device) 94 | 95 | return latents 96 | 97 | def get_trajectory(type, num_frames, latent_code1, latent_code2=None): 98 | latent_codes = [] 99 | if type == 'still': 100 | for pp in range(num_frames): 101 | latent_codes.append((latent_code1)) 102 | elif type == 'gradual': 103 | ratio = np.linspace(0, 1.0, num_frames) 104 | for pp in range(num_frames): 105 | latent_codes.append((ratio[pp] * latent_code1)) 106 | elif type == 'interpolate': 107 | # interpolate between two inverted images 108 | ratio = np.linspace(0, 1.0, num_frames) 109 | for pp in range(num_frames): 110 | latent_code_interpolate = ratio[pp] * latent_code2 + (1 - ratio[pp]) * latent_code1 111 | latent_codes.append(latent_code_interpolate) 112 | return latent_codes 113 | 114 | def tensor_to_PIL(img): 115 | img = img.squeeze() * 0.5 + 0.5 116 | return Image.fromarray(img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()) 117 | 118 | 119 | def convert_sdf_samples_to_ply( 120 | pytorch_3d_sdf_tensor, 121 | voxel_grid_origin, 122 | voxel_size, 123 | ply_filename_out, 124 | offset=None, 125 | scale=None, 126 | level=0.0, 127 | ): 128 | """ 129 | Convert sdf samples to .ply 130 | 131 | :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n) 132 | :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid 133 | :voxel_size: float, the size of the voxels 134 | :ply_filename_out: string, path of the filename to save to 135 | 136 | This function adapted from: https://github.com/RobotLocomotion/spartan 
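:offset: optional offset subtracted from the mesh points after scaling
:scale: optional scale factor the mesh points are divided by
:level: float, iso-surface value passed to marching cubes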
137 | """ 138 | numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor 139 | 140 | verts, faces, normals, values = np.zeros((0, 3)), np.zeros((0, 3)), np.zeros((0, 3)), np.zeros(0) 141 | verts, faces, normals, values = skimage.measure.marching_cubes( 142 | numpy_3d_sdf_tensor, level=level, spacing=[voxel_size] * 3 143 | ) 144 | 145 | faces = faces[:,::-1] 146 | 147 | # transform from voxel coordinates to camera coordinates 148 | # note x and y are flipped in the output of marching_cubes 149 | mesh_points = np.zeros_like(verts) 150 | mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0] 151 | mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1] 152 | mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2] 153 | 154 | # apply additional offset and scale 155 | if scale is not None: 156 | mesh_points = mesh_points / scale 157 | if offset is not None: 158 | mesh_points = mesh_points - offset 159 | 160 | # try writing to the ply file 161 | 162 | num_verts = verts.shape[0] 163 | num_faces = faces.shape[0] 164 | 165 | verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")]) 166 | 167 | for i in range(0, num_verts): 168 | verts_tuple[i] = tuple(mesh_points[i, :]) 169 | 170 | faces_building = [] 171 | for i in range(0, num_faces): 172 | faces_building.append(((faces[i, :].tolist(),))) 173 | faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))]) 174 | 175 | el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex") 176 | el_faces = plyfile.PlyElement.describe(faces_tuple, "face") 177 | 178 | ply_data = plyfile.PlyData([el_verts, el_faces]) 179 | ply_data.write(ply_filename_out) 180 | 181 | 182 | def create_samples(N=256, voxel_origin=[0, 0, 0], cube_length=2.0): 183 | # NOTE: the voxel_origin is actually the (bottom, left, down) corner, not the middle 184 | voxel_origin = np.array(voxel_origin) - cube_length/2 185 | voxel_size = cube_length / (N - 1) 186 | 187 | overall_index = torch.arange(0, N ** 3, 1, out=torch.LongTensor()) 188 | samples = torch.zeros(N ** 3, 3) 189 | 190 | # transform first 3 columns 191 | # to be the x, y, z index 192 | samples[:, 2] = overall_index % N 193 | samples[:, 1] = (overall_index.float() / N) % N 194 | samples[:, 0] = ((overall_index.float() / N) / N) % N 195 | 196 | # transform first 3 columns 197 | # to be the x, y, z coordinate 198 | samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2] 199 | samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1] 200 | samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0] 201 | 202 | num_samples = N ** 3 203 | 204 | return samples.unsqueeze(0), voxel_origin, voxel_size 205 | 206 | 207 | def sample_generator(generator, z, max_batch=100000, voxel_resolution=256, voxel_origin=[0,0,0], cube_length=2.0, psi=0.7): 208 | head = 0 209 | samples, voxel_origin, voxel_size = create_samples(voxel_resolution, voxel_origin, cube_length) 210 | samples = samples.to(z.device) 211 | sigmas = torch.zeros((samples.shape[0], samples.shape[1], 1), device=z.device) 212 | 213 | transformed_ray_directions_expanded = torch.zeros((samples.shape[0], max_batch, 3), device=z.device) 214 | transformed_ray_directions_expanded[..., -1] = -1 215 | 216 | generator.get_avg_w() 217 | with torch.no_grad(): 218 | while head < samples.shape[1]: 219 | coarse_output = generator._volume(z, truncation_psi=psi)(samples[:, head:head+max_batch], transformed_ray_directions_expanded[:, :samples.shape[1]-head]) 220 | 221 | sigmas[:, head:head+max_batch] = coarse_output[:, :, -1:] 222 | head += max_batch 223 | 
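# fold the flat per-sample densities back into an N x N x N voxel grid for iso-surface extraction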
224 | sigmas = sigmas.reshape((voxel_resolution, voxel_resolution, voxel_resolution)).cpu().numpy() 225 | 226 | return sigmas, voxel_origin, voxel_size 227 | 228 | if __name__ == '__main__': 229 | parser = argparse.ArgumentParser() 230 | parser.add_argument('--generator_file', type=str, default='experiments/gram/finetuned_model/subject_name/generator.pth') 231 | parser.add_argument('--target_name', type=str, default=None) 232 | parser.add_argument('--output_dir', type=str, default='experiments/gram/rendering_results/') 233 | parser.add_argument('--data_img_dir', type=str, default='samples/faces/') 234 | parser.add_argument('--data_pose_dir', type=str, default='samples/faces/poses/') 235 | parser.add_argument('--data_emd_dir', type=str, default='experiments/gram/inversion') 236 | parser.add_argument('--config', type=str, default='FACES_default') 237 | parser.add_argument('--max_batch_size', type=int, default=1200000) 238 | parser.add_argument('--lock_view_dependence', action='store_true') 239 | parser.add_argument('--image_size', type=int, default=256) 240 | parser.add_argument('--name', type=str, default='render', help='name of the experiment. It decides where to store samples and models') 241 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 242 | parser.add_argument('--psi', type=float, default=0.7) 243 | 244 | parser.add_argument('--seed', type=int, default=0) 245 | parser.add_argument('--trajectory', type=str, default='front', help='still, front, orbit') 246 | parser.add_argument('--z_trajectory', type=str, default='still', help='still, gradual, interpolate') 247 | parser.add_argument('--freq_trajectory', type=str, default='still', help='still, gradual, interpolate') 248 | parser.add_argument('--phase_trajectory', type=str, default='still', help='still, gradual, interpolate') 249 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') 250 | parser.add_argument('--gen_video', action='store_true', help='whether generate video') 251 | parser.add_argument('--cube_size', type=float, default=0.3) 252 | parser.add_argument('--voxel_resolution', type=int, default=256) 253 | parser.add_argument('--use_depth', action='store_true', help='whether use depth loss for geomotry generation') 254 | parser.add_argument('--white_bg', action='store_true', help='whether use white background') 255 | opt = parser.parse_args() 256 | os.makedirs(opt.output_dir, exist_ok=True) 257 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 258 | config = getattr(configs, opt.config) 259 | if opt.white_bg: 260 | config['generator']['renderer']['kwargs']['white_back'] = True 261 | config['generator']['renderer']['kwargs']['background'] = False 262 | ## load models 263 | generator = load_models(opt, config, device) 264 | 265 | ## load data 266 | generator.renderer.lock_view_dependence = True 267 | img_size = config['global']['img_size'] 268 | target_emb_name = f"{opt.target_name}/00999_.txt" 269 | optimized_latents = sorted(glob.glob(os.path.join(opt.data_emd_dir, target_emb_name))) 270 | 271 | for optimized_latent in optimized_latents: 272 | print(f"Rendering for {optimized_latent}") 273 | if not os.path.exists(optimized_latent): 274 | print(f"The file '{optimized_latent}' does not exist.") 275 | raise 276 | 277 | extract_shape = False 278 | ## load pose, inverted latent code 279 | target_name = optimized_latent.split("/")[-2] 280 | 281 | if 
opt.config.find('FACES') >= 0: 282 | mat_target = os.path.join(opt.data_pose_dir, f"{target_name}.mat") 283 | pose = read_pose_ori(mat_target, flip=False) 284 | elif opt.config.find('CATS') >= 0: # CATS 285 | mat_target = os.path.join(opt.data_pose_dir, f"{target_name}_pose.npy") 286 | pose = read_pose_npy(mat_target, flip=False) 287 | else: 288 | raise 289 | 290 | if opt.trajectory == 'still_pose': 291 | num_frames = 1 292 | extract_shape = True 293 | else: 294 | num_frames = 100 295 | ## set latent code z 296 | z = read_latents_txt_z(optimized_latent, device=device) 297 | zs = get_trajectory(opt.z_trajectory, num_frames, z) 298 | 299 | ## set trajectory 300 | if opt.trajectory == 'still_front': 301 | trajectory = [] 302 | pose_ratio = np.linspace(0, 1, num_frames) 303 | for t in np.linspace(0, 1, num_frames): 304 | ## frontal face 305 | fixed_t = pose_ratio[0] # t=pose_ratio[19] 306 | pitch = math.pi / 2 # 0.2 * np.cos(t * 2 * math.pi) + math.pi / 2 307 | yaw = 0.4 * np.sin(fixed_t * 2 * math.pi) + math.pi / 2 308 | fov = 12 309 | trajectory.append((pitch, yaw, fov)) 310 | elif opt.trajectory == 'still_pose': 311 | ## pose of ori image 312 | trajectory = [] 313 | for t in np.linspace(0, 1, num_frames): 314 | pitch = pose[0] 315 | yaw = pose[1] 316 | fov = 12 317 | trajectory.append((pitch, yaw, fov)) 318 | elif opt.trajectory == 'front': 319 | trajectory = [] 320 | for t in np.linspace(0, 1, num_frames): 321 | pitch = 0.2 * np.cos(t * 2 * math.pi) + math.pi / 2 322 | yaw = 0.4 * np.sin(t * 2 * math.pi) + math.pi / 2 323 | fov = 12 324 | trajectory.append((pitch, yaw, fov)) 325 | elif opt.trajectory == 'orbit': 326 | trajectory = [] 327 | for t in np.linspace(0, 1, num_frames): 328 | pitch = math.pi / 4 329 | yaw = t * 2 * math.pi 330 | fov = curriculum['fov'] 331 | 332 | trajectory.append((pitch, yaw, fov)) 333 | 334 | ## generate images 335 | with torch.no_grad(): 336 | flag = True 337 | images = [] 338 | depths = [] 339 | 340 | generator.get_avg_w() 341 | output_name = os.path.join(opt.output_dir, f"{target_name}_{opt.suffix}.mp4") 342 | if os.path.exists(output_name): 343 | continue 344 | writer = FFmpegWriter(output_name, outputdict={'-pix_fmt': 'yuv420p', '-crf': '21'}, verbosity=10) 345 | frames = [] 346 | 347 | cnt_output_dir = os.path.join(opt.output_dir, '%s_%s/'%(target_name, opt.suffix)) 348 | os.makedirs(cnt_output_dir, exist_ok=True) 349 | 350 | for frame_idx in range(num_frames): 351 | pitch, yaw, fov = trajectory[frame_idx] 352 | 353 | generator.h_mean = yaw 354 | generator.v_mean = pitch 355 | generator.h_stddev = generator.v_stddev = 0 356 | 357 | # generate img 358 | z = zs[frame_idx] 359 | tensor_img = generator(z, **config['camera'], truncation_psi=opt.psi)[0] 360 | 361 | if extract_shape: 362 | voxel_grid, voxel_origin, voxel_size = sample_generator( 363 | generator, z, cube_length=opt.cube_size, voxel_resolution=opt.voxel_resolution) 364 | 365 | save_image(tensor_img, os.path.join(cnt_output_dir, f"{target_name}_{frame_idx}_.png"), normalize=True,range=(-1,1)) 366 | frames.append(tensor_to_PIL(tensor_img)) 367 | ## save shape 368 | if extract_shape: 369 | l = 5 370 | try: 371 | convert_sdf_samples_to_ply(voxel_grid, voxel_origin, voxel_size, 372 | os.path.join(opt.output_dir,f'{target_name}.ply'), level=l) 373 | # with mrcfile.new_mmap(os.path.join(opt.output_dir, f'{target_name}.mrc'), overwrite=True, shape=voxel_grid.shape, mrc_mode=2) as mrc: 374 | # mrc.data[:] = voxel_grid 375 | except: 376 | continue 377 | 378 | for frame in frames: 379 | 
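# stream the accumulated PIL frames to the mp4 writer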
writer.writeFrame(np.array(frame)) 380 | 381 | writer.close() 382 | -------------------------------------------------------------------------------- /samples/cats/00000005_001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/cats/00000005_001.png -------------------------------------------------------------------------------- /samples/cats/00000009_013.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/cats/00000009_013.png -------------------------------------------------------------------------------- /samples/cats/poses/00000005_001_pose.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/cats/poses/00000005_001_pose.npy -------------------------------------------------------------------------------- /samples/cats/poses/00000009_013_pose.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/cats/poses/00000009_013_pose.npy -------------------------------------------------------------------------------- /samples/faces/000656.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/faces/000656.png -------------------------------------------------------------------------------- /samples/faces/000990.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/faces/000990.png -------------------------------------------------------------------------------- /samples/faces/097665.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/faces/097665.png -------------------------------------------------------------------------------- /samples/faces/R5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/faces/R5.png -------------------------------------------------------------------------------- /samples/faces/mask256/000656.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/faces/mask256/000656.png -------------------------------------------------------------------------------- /samples/faces/mask256/000990.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/faces/mask256/000990.png -------------------------------------------------------------------------------- /samples/faces/mask256/097665.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/faces/mask256/097665.png -------------------------------------------------------------------------------- /samples/faces/mask256/R5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/faces/mask256/R5.png -------------------------------------------------------------------------------- /samples/faces/poses/000656.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/faces/poses/000656.mat -------------------------------------------------------------------------------- /samples/faces/poses/000990.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/faces/poses/000990.mat -------------------------------------------------------------------------------- /samples/faces/poses/097665.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/faces/poses/097665.mat -------------------------------------------------------------------------------- /samples/faces/poses/R5.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuYin1/NeRFInvertor/1006dce92f1749373bdc26052e0b5d2662663315/samples/faces/poses/R5.mat -------------------------------------------------------------------------------- /utils/arcface/__init__.py: -------------------------------------------------------------------------------- 1 | from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 2 | from .mobilefacenet import MobileFaceNet 3 | 4 | 5 | def get_model(name, **kwargs): 6 | if name == "r18": 7 | return iresnet18(False, **kwargs) 8 | elif name == "r34": 9 | return iresnet34(False, **kwargs) 10 | elif name == "r50": 11 | return iresnet50(False, **kwargs) 12 | elif name == "r100": 13 | return iresnet100(False, **kwargs) 14 | elif name == "r200": 15 | return iresnet200(False, **kwargs) 16 | elif name == "mbf": 17 | return MobileFaceNet((112, 112), **kwargs) 18 | else: 19 | raise ValueError() 20 | -------------------------------------------------------------------------------- /utils/arcface/iresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, 10 | out_planes, 11 | kernel_size=3, 12 | stride=stride, 13 | padding=dilation, 14 | groups=groups, 15 | bias=False, 16 | dilation=dilation) 17 | 18 | 19 | def conv1x1(in_planes, out_planes, stride=1): 20 | """1x1 convolution""" 21 | return nn.Conv2d(in_planes, 22 | out_planes, 23 | kernel_size=1, 24 | stride=stride, 25 | bias=False) 26 | 27 | 28 | class IBasicBlock(nn.Module): 29 | expansion = 1 30 | def __init__(self, inplanes, planes, stride=1, downsample=None, 31 | groups=1, base_width=64, dilation=1): 32 | super(IBasicBlock, self).__init__() 33 | if groups != 1 or 
base_width != 64: 34 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 35 | if dilation > 1: 36 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 37 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) 38 | self.conv1 = conv3x3(inplanes, planes) 39 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) 40 | self.prelu = nn.PReLU(planes) 41 | self.conv2 = conv3x3(planes, planes, stride) 42 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | identity = x 48 | out = self.bn1(x) 49 | out = self.conv1(out) 50 | out = self.bn2(out) 51 | out = self.prelu(out) 52 | out = self.conv2(out) 53 | out = self.bn3(out) 54 | if self.downsample is not None: 55 | identity = self.downsample(x) 56 | out += identity 57 | return out 58 | 59 | 60 | class IResNet(nn.Module): 61 | fc_scale = 7 * 7 62 | def __init__(self, 63 | block, layers, dropout=0, num_features=512, zero_init_residual=False, 64 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): 65 | super(IResNet, self).__init__() 66 | self.fp16 = fp16 67 | self.inplanes = 64 68 | self.dilation = 1 69 | if replace_stride_with_dilation is None: 70 | replace_stride_with_dilation = [False, False, False] 71 | if len(replace_stride_with_dilation) != 3: 72 | raise ValueError("replace_stride_with_dilation should be None " 73 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 74 | self.groups = groups 75 | self.base_width = width_per_group 76 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 78 | self.prelu = nn.PReLU(self.inplanes) 79 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 80 | self.layer2 = self._make_layer(block, 81 | 128, 82 | layers[1], 83 | stride=2, 84 | dilate=replace_stride_with_dilation[0]) 85 | self.layer3 = self._make_layer(block, 86 | 256, 87 | layers[2], 88 | stride=2, 89 | dilate=replace_stride_with_dilation[1]) 90 | self.layer4 = self._make_layer(block, 91 | 512, 92 | layers[3], 93 | stride=2, 94 | dilate=replace_stride_with_dilation[2]) 95 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) 96 | self.dropout = nn.Dropout(p=dropout, inplace=True) 97 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) 98 | self.features = nn.BatchNorm1d(num_features, eps=1e-05) 99 | nn.init.constant_(self.features.weight, 1.0) 100 | self.features.weight.requires_grad = False 101 | 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | nn.init.normal_(m.weight, 0, 0.1) 105 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 106 | nn.init.constant_(m.weight, 1) 107 | nn.init.constant_(m.bias, 0) 108 | 109 | if zero_init_residual: 110 | for m in self.modules(): 111 | if isinstance(m, IBasicBlock): 112 | nn.init.constant_(m.bn2.weight, 0) 113 | 114 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 115 | downsample = None 116 | previous_dilation = self.dilation 117 | if dilate: 118 | self.dilation *= stride 119 | stride = 1 120 | if stride != 1 or self.inplanes != planes * block.expansion: 121 | downsample = nn.Sequential( 122 | conv1x1(self.inplanes, planes * block.expansion, stride), 123 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 124 | ) 125 | layers = [] 126 | layers.append( 127 | block(self.inplanes, planes, stride, downsample, self.groups, 128 | self.base_width, previous_dilation)) 129 | 
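# only the first block of each stage applies the stride/downsample; later blocks run at stride 1 on the expanded channels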
self.inplanes = planes * block.expansion 130 | for _ in range(1, blocks): 131 | layers.append( 132 | block(self.inplanes, 133 | planes, 134 | groups=self.groups, 135 | base_width=self.base_width, 136 | dilation=self.dilation)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, x): 141 | with torch.cuda.amp.autocast(self.fp16): 142 | x = self.conv1(x) 143 | x = self.bn1(x) 144 | x = self.prelu(x) 145 | x = self.layer1(x) 146 | x = self.layer2(x) 147 | x = self.layer3(x) 148 | x = self.layer4(x) 149 | x = self.bn2(x) 150 | x = torch.flatten(x, 1) 151 | x = self.dropout(x) 152 | x = self.fc(x.float() if self.fp16 else x) 153 | x = self.features(x) 154 | return x 155 | 156 | 157 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs): 158 | model = IResNet(block, layers, **kwargs) 159 | if pretrained: 160 | raise ValueError() 161 | return model 162 | 163 | 164 | def iresnet18(pretrained=False, progress=True, **kwargs): 165 | return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, 166 | progress, **kwargs) 167 | 168 | 169 | def iresnet34(pretrained=False, progress=True, **kwargs): 170 | return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, 171 | progress, **kwargs) 172 | 173 | 174 | def iresnet50(pretrained=False, progress=True, **kwargs): 175 | return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, 176 | progress, **kwargs) 177 | 178 | 179 | def iresnet100(pretrained=False, progress=True, **kwargs): 180 | return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, 181 | progress, **kwargs) 182 | 183 | 184 | def iresnet200(pretrained=False, progress=True, **kwargs): 185 | return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, 186 | progress, **kwargs) 187 | 188 | -------------------------------------------------------------------------------- /utils/arcface/mobilefacenet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py 3 | Original author cavalleria 4 | ''' 5 | 6 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout2d, Dropout, AvgPool2d, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module, Parameter 7 | import torch.nn.functional as F 8 | import torch 9 | import torch.nn as nn 10 | from collections import namedtuple 11 | import math 12 | 13 | 14 | ################################## Common ############################################################# 15 | def round_channels(channels, divisor=8): 16 | """ 17 | Round weighted channel number (make divisible operation). 18 | Parameters: 19 | ---------- 20 | channels : int or float 21 | Original number of channels. 22 | divisor : int, default 8 23 | Alignment value. 24 | Returns 25 | ------- 26 | int 27 | Weighted number of channels. 
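Example: round_channels(100) returns 104 (rounded to the nearest multiple of 8, bumped up if below 90% of the input).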
28 | """ 29 | rounded_channels = max( 30 | int(channels + divisor / 2.0) // divisor * divisor, divisor) 31 | if float(rounded_channels) < 0.9 * channels: 32 | rounded_channels += divisor 33 | return rounded_channels 34 | 35 | 36 | class ECA_Layer(nn.Module): 37 | def __init__(self, channels, gamma=2, b=1): 38 | super(ECA_Layer, self).__init__() 39 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 40 | 41 | t = int(abs((math.log(channels, 2) + b) / gamma)) 42 | k_size = t if t % 2 else t + 1 43 | self.conv = nn.Conv1d(1, 44 | 1, 45 | kernel_size=k_size, 46 | padding=(k_size - 1) // 2, 47 | bias=False) 48 | 49 | self.sigmoid = nn.Sigmoid() 50 | 51 | def forward(self, x): 52 | # feature descriptor on the global spatial information 53 | y = self.avg_pool(x) 54 | 55 | # Two different branches of ECA module 56 | y = self.conv(y.squeeze(-1).transpose(-1, 57 | -2)).transpose(-1, 58 | -2).unsqueeze(-1) 59 | 60 | # Multi-scale information fusion 61 | y = self.sigmoid(y) 62 | 63 | return x * y.expand_as(x) 64 | 65 | 66 | class SEBlock(nn.Module): 67 | """ 68 | Squeeze-and-Excitation block from 'Squeeze-and-Excitation Networks,' https://arxiv.org/abs/1709.01507. 69 | """ 70 | def __init__(self, 71 | channels, 72 | reduction=16, 73 | round_mid=False, 74 | use_conv=True, 75 | mid_activation=(lambda: nn.ReLU(inplace=True)), 76 | out_activation=(lambda: nn.Sigmoid())): 77 | super(SEBlock, self).__init__() 78 | self.use_conv = use_conv 79 | mid_channels = channels // reduction if not round_mid else round_channels( 80 | float(channels) / reduction) 81 | 82 | self.pool = nn.AdaptiveAvgPool2d(output_size=1) 83 | if use_conv: 84 | self.conv1 = nn.Conv2d(in_channels=channels, 85 | out_channels=mid_channels, 86 | kernel_size=1, 87 | stride=1, 88 | groups=1, 89 | bias=True) 90 | else: 91 | self.fc1 = nn.Linear(in_features=channels, 92 | out_features=mid_channels) 93 | self.activ = nn.ReLU(inplace=True) 94 | if use_conv: 95 | self.conv2 = nn.Conv2d(in_channels=mid_channels, 96 | out_channels=channels, 97 | kernel_size=1, 98 | stride=1, 99 | groups=1, 100 | bias=True) 101 | else: 102 | self.fc2 = nn.Linear(in_features=mid_channels, 103 | out_features=channels) 104 | self.sigmoid = nn.Sigmoid() 105 | 106 | def forward(self, x): 107 | w = self.pool(x) 108 | if not self.use_conv: 109 | w = w.view(x.size(0), -1) 110 | w = self.conv1(w) if self.use_conv else self.fc1(w) 111 | w = self.activ(w) 112 | w = self.conv2(w) if self.use_conv else self.fc2(w) 113 | w = self.sigmoid(w) 114 | if not self.use_conv: 115 | w = w.unsqueeze(2).unsqueeze(3) 116 | x = x * w 117 | return x 118 | 119 | 120 | ################################## Original Arcface Model ############################################################# 121 | class Flatten(Module): 122 | def forward(self, input): 123 | return input.view(input.size(0), -1) 124 | 125 | 126 | ################################## MobileFaceNet ############################################################# 127 | class Conv_block(Module): 128 | def __init__(self, 129 | in_c, 130 | out_c, 131 | kernel=(1, 1), 132 | stride=(1, 1), 133 | padding=(0, 0), 134 | groups=1): 135 | super(Conv_block, self).__init__() 136 | self.conv = Conv2d(in_c, 137 | out_channels=out_c, 138 | kernel_size=kernel, 139 | groups=groups, 140 | stride=stride, 141 | padding=padding, 142 | bias=False) 143 | self.bn = BatchNorm2d(out_c) 144 | self.prelu = PReLU(out_c) 145 | 146 | def forward(self, x): 147 | x = self.conv(x) 148 | x = self.bn(x) 149 | x = self.prelu(x) 150 | return x 151 | 152 | 153 | class Linear_block(Module): 154 
| def __init__(self, 155 | in_c, 156 | out_c, 157 | kernel=(1, 1), 158 | stride=(1, 1), 159 | padding=(0, 0), 160 | groups=1): 161 | super(Linear_block, self).__init__() 162 | self.conv = Conv2d(in_c, 163 | out_channels=out_c, 164 | kernel_size=kernel, 165 | groups=groups, 166 | stride=stride, 167 | padding=padding, 168 | bias=False) 169 | self.bn = BatchNorm2d(out_c) 170 | 171 | def forward(self, x): 172 | x = self.conv(x) 173 | x = self.bn(x) 174 | return x 175 | 176 | 177 | class Depth_Wise(Module): 178 | def __init__(self, 179 | in_c, 180 | out_c, 181 | attention, 182 | residual=False, 183 | kernel=(3, 3), 184 | stride=(2, 2), 185 | padding=(1, 1), 186 | groups=1): 187 | super(Depth_Wise, self).__init__() 188 | self.conv = Conv_block(in_c, 189 | out_c=groups, 190 | kernel=(1, 1), 191 | padding=(0, 0), 192 | stride=(1, 1)) 193 | self.conv_dw = Conv_block(groups, 194 | groups, 195 | groups=groups, 196 | kernel=kernel, 197 | padding=padding, 198 | stride=stride) 199 | self.project = Linear_block(groups, 200 | out_c, 201 | kernel=(1, 1), 202 | padding=(0, 0), 203 | stride=(1, 1)) 204 | self.attention = attention 205 | if self.attention == 'eca': 206 | self.attention_layer = ECA_Layer(out_c) 207 | elif self.attention == 'se': 208 | self.attention_layer = SEBlock(out_c) 209 | # elif self.attention == 'cbam': 210 | # self.attention_layer = CbamBlock(out_c) 211 | # elif self.attention == 'gct': 212 | # self.attention_layer = GCT(out_c) 213 | 214 | self.residual = residual 215 | 216 | self.attention = attention #se, eca, cbam 217 | 218 | def forward(self, x): 219 | if self.residual: 220 | short_cut = x 221 | x = self.conv(x) 222 | x = self.conv_dw(x) 223 | x = self.project(x) 224 | if self.attention != 'none': 225 | x = self.attention_layer(x) 226 | if self.residual: 227 | output = short_cut + x 228 | else: 229 | output = x 230 | return output 231 | 232 | 233 | class Residual(Module): 234 | def __init__(self, 235 | c, 236 | attention, 237 | num_block, 238 | groups, 239 | kernel=(3, 3), 240 | stride=(1, 1), 241 | padding=(1, 1)): 242 | super(Residual, self).__init__() 243 | modules = [] 244 | for _ in range(num_block): 245 | modules.append( 246 | Depth_Wise(c, 247 | c, 248 | attention, 249 | residual=True, 250 | kernel=kernel, 251 | padding=padding, 252 | stride=stride, 253 | groups=groups)) 254 | self.model = Sequential(*modules) 255 | 256 | def forward(self, x): 257 | return self.model(x) 258 | 259 | 260 | class GNAP(Module): 261 | def __init__(self, embedding_size): 262 | super(GNAP, self).__init__() 263 | assert embedding_size == 512 264 | self.bn1 = BatchNorm2d(512, affine=False) 265 | self.pool = nn.AdaptiveAvgPool2d((1, 1)) 266 | 267 | self.bn2 = BatchNorm1d(512, affine=False) 268 | 269 | def forward(self, x): 270 | x = self.bn1(x) 271 | x_norm = torch.norm(x, 2, 1, True) 272 | x_norm_mean = torch.mean(x_norm) 273 | weight = x_norm_mean / x_norm 274 | x = x * weight 275 | x = self.pool(x) 276 | x = x.view(x.shape[0], -1) 277 | feature = self.bn2(x) 278 | return feature 279 | 280 | 281 | class GDC(Module): 282 | def __init__(self, embedding_size): 283 | super(GDC, self).__init__() 284 | self.conv_6_dw = Linear_block(512, 285 | 512, 286 | groups=512, 287 | kernel=(7, 7), 288 | stride=(1, 1), 289 | padding=(0, 0)) 290 | self.conv_6_flatten = Flatten() 291 | self.linear = Linear(512, embedding_size, bias=False) 292 | #self.bn = BatchNorm1d(embedding_size, affine=False) 293 | self.bn = BatchNorm1d(embedding_size) 294 | 295 | def forward(self, x): 296 | x = self.conv_6_dw(x) 297 | x = 
self.conv_6_flatten(x) 298 | x = self.linear(x) 299 | x = self.bn(x) 300 | return x 301 | 302 | 303 | class MobileFaceNet(Module): 304 | def __init__(self, 305 | input_size, 306 | dropout=0, 307 | fp16=False, 308 | num_features=512, 309 | output_name="GDC", 310 | attention='none'): 311 | super(MobileFaceNet, self).__init__() 312 | assert output_name in ['GNAP', 'GDC'] 313 | assert input_size[0] in [112] 314 | assert fp16 is False, "MobileFaceNet not support fp16 mode;)" 315 | 316 | self.conv1 = Conv_block(3, 317 | 64, 318 | kernel=(3, 3), 319 | stride=(2, 2), 320 | padding=(1, 1)) 321 | self.conv2_dw = Conv_block(64, 322 | 64, 323 | kernel=(3, 3), 324 | stride=(1, 1), 325 | padding=(1, 1), 326 | groups=64) 327 | self.conv_23 = Depth_Wise(64, 328 | 64, 329 | attention, 330 | kernel=(3, 3), 331 | stride=(2, 2), 332 | padding=(1, 1), 333 | groups=128) 334 | self.conv_3 = Residual(64, 335 | attention, 336 | num_block=4, 337 | groups=128, 338 | kernel=(3, 3), 339 | stride=(1, 1), 340 | padding=(1, 1)) 341 | self.conv_34 = Depth_Wise(64, 342 | 128, 343 | attention, 344 | kernel=(3, 3), 345 | stride=(2, 2), 346 | padding=(1, 1), 347 | groups=256) 348 | self.conv_4 = Residual(128, 349 | attention, 350 | num_block=6, 351 | groups=256, 352 | kernel=(3, 3), 353 | stride=(1, 1), 354 | padding=(1, 1)) 355 | self.conv_45 = Depth_Wise(128, 356 | 128, 357 | attention, 358 | kernel=(3, 3), 359 | stride=(2, 2), 360 | padding=(1, 1), 361 | groups=512) 362 | self.conv_5 = Residual(128, 363 | attention, 364 | num_block=2, 365 | groups=256, 366 | kernel=(3, 3), 367 | stride=(1, 1), 368 | padding=(1, 1)) 369 | self.conv_6_sep = Conv_block(128, 370 | 512, 371 | kernel=(1, 1), 372 | stride=(1, 1), 373 | padding=(0, 0)) 374 | if output_name == "GNAP": 375 | self.output_layer = GNAP(512) 376 | else: 377 | self.output_layer = GDC(num_features) 378 | 379 | self._initialize_weights() 380 | 381 | def _initialize_weights(self): 382 | for m in self.modules(): 383 | if isinstance(m, nn.Conv2d): 384 | nn.init.kaiming_normal_(m.weight, 385 | mode='fan_out', 386 | nonlinearity='relu') 387 | if m.bias is not None: 388 | m.bias.data.zero_() 389 | elif isinstance(m, nn.BatchNorm2d): 390 | m.weight.data.fill_(1) 391 | m.bias.data.zero_() 392 | elif isinstance(m, nn.Linear): 393 | nn.init.kaiming_normal_(m.weight, 394 | mode='fan_out', 395 | nonlinearity='relu') 396 | if m.bias is not None: 397 | m.bias.data.zero_() 398 | 399 | def forward(self, x): 400 | out = self.conv1(x) 401 | out = self.conv2_dw(out) 402 | out = self.conv_23(out) 403 | out = self.conv_3(out) 404 | out = self.conv_34(out) 405 | out = self.conv_4(out) 406 | out = self.conv_45(out) 407 | out = self.conv_5(out) 408 | conv_features = self.conv_6_sep(out) 409 | out = self.output_layer(conv_features) 410 | return out 411 | --------------------------------------------------------------------------------
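A rough usage sketch of the `utils/arcface` factory above: the checkpoint filename, the choice of the `r50` variant, and the input normalization are assumptions rather than something the code prescribes; only the 112x112 input size is fixed by the architecture.
```
import os
import torch
from utils.arcface import get_model

# Build an iresnet50 backbone with 512-d embeddings (the variant is an assumption).
backbone = get_model("r50", fp16=False, num_features=512)

# Placeholder path for the downloaded ArcFace weights; adjust to the actual file.
ckpt = "pretrained_models/arcface_r50.pth"
if os.path.isfile(ckpt):
    backbone.load_state_dict(torch.load(ckpt, map_location="cpu"), strict=False)
backbone.eval()

# The IResNet head expects 112x112 inputs (it flattens a 7x7 feature map before the fc layer).
with torch.no_grad():
    face = torch.randn(1, 3, 112, 112)   # stand-in for a cropped face in [-1, 1]
    embedding = backbone(face)           # (1, 512) identity embedding
print(embedding.shape)
```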