├── LICENSE ├── README.md ├── exp3D ├── cfgs │ ├── nvs_eval │ │ ├── eval_nvs_shapenet_cars.yaml │ │ ├── eval_nvs_shapenet_chairs.yaml │ │ └── eval_nvs_shapenet_lamps.yaml │ ├── nvs_shapenet_cars.yaml │ ├── nvs_shapenet_chairs.yaml │ └── nvs_shapenet_lamps.yaml ├── datasets │ ├── __init__.py │ ├── celeba.py │ ├── datasets.py │ ├── imagenette.py │ ├── imgrec_dataset.py │ └── learnit_shapenet.py ├── models │ ├── __init__.py │ ├── blocks.py │ ├── hier_model.py │ ├── models.py │ ├── op │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── fused_act.cpython-37.pyc │ │ │ ├── fused_act.cpython-38.pyc │ │ │ ├── upfirdn2d.cpython-37.pyc │ │ │ └── upfirdn2d.cpython-38.pyc │ │ ├── fused_act.py │ │ ├── fused_bias_act.cpp │ │ ├── fused_bias_act_kernel.cu │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.py │ │ └── upfirdn2d_kernel.cu │ ├── tokenizers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── imgrec_tokenizer.cpython-38.pyc │ │ │ └── nvs_tokenizer.cpython-38.pyc │ │ └── nvs_tokenizer.py │ ├── transformer.py │ └── versatile_np.py ├── run_trainer.py ├── trainers │ ├── __init__.py │ ├── base_trainer.py │ ├── nvs_trainer.py │ └── trainers.py └── utils │ ├── __init__.py │ ├── common.py │ └── geometry.py └── imgs └── framework.png /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Zongyu Guo 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Versatile-NP 2 | 3 | This repository contains the official implementation of the following paper: 4 | 5 | [**Versatile Neural Processes for Learning Implicit Neural Representations**](https://arxiv.org/abs/2301.08883), ICLR 2023 6 | 7 | 8 | 9 | ## Reproducing 3D Experiments 10 | 11 | The code for the 3D experiments follows the pipeline of [Trans-INR](https://github.com/yinboc/trans-inr). 12 | 13 | ### Environment 14 | - Python 3 15 | - PyTorch 1.7.1 16 | - pyyaml, numpy, tqdm, imageio, TensorboardX, einops 17 | 18 | ### Data 19 | 20 | `mkdir data` and put the dataset folders in it. 21 | 22 | - **View synthesis**: download the data from [Google Drive](https://drive.google.com/drive/folders/1lRfg-Ov1dd3ldke9Gv9dyzGGTxiFOhIs) (provided by [learnit](https://www.matthewtancik.com/learnit)), put it in a folder named `learnit_shapenet`, unzip the category folders, and rename them to `chairs`, `cars`, and `lamps` accordingly. 23 | 24 | ### Training 25 | 26 | `cd exp3D` 27 | `CUDA_VISIBLE_DEVICES=[GPU] python run_trainer.py --cfg [CONFIG] --load-root [DATADIR]` 28 | 29 | Configs are in `cfgs/`. Four 3090 Ti or four 32GB V100 GPUs are recommended for training. 30 | 31 | ### Evaluation 32 | 33 | For view synthesis, run on a single GPU with the configs in `cfgs/nvs_eval`. 34 | 35 | ### Checkpoint models 36 | 37 | The pretrained checkpoints can be found on [Google Drive](https://drive.google.com/drive/folders/16_ZrgYLH2oiV0uC6OBwnI3op-24nS0RK?usp=share_link). 38 | 39 | | | Cars | Lamps | Chairs | 40 | | :---- | :---- | :---- | :---- | 41 | | PSNR (dB) | 24.21 | 24.10 | 19.54 | 42 | 43 | Since the code has been reorganized with unified training settings, the results for lamps and chairs are slightly better than those in our initial OpenReview submission, while the result for cars is slightly lower. The performance on cars could be further improved by adjusting the annealing strategy of the beta coefficient. 44 | 45 | ## Reproducing 1D Experiments 46 | 47 | The code for the 1D toy examples can be found in the [supplementary material of our OpenReview submission](https://openreview.net/forum?id=2nLeOOfAjK).
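The YAML configs below are consumed through the small registry helpers in `exp3D/datasets/datasets.py` and `exp3D/models/models.py`. As a rough, hedged illustration (this script is not part of the repository), the sketch below builds the test dataset and the model from one of the evaluation configs; it assumes the working directory is `exp3D/`, that the data sits under `data/` as described above, that `$load_root$` can simply be string-substituted the way `run_trainer.py` is presumed to handle `--load-root`, and that `versatile_np` accepts the sub-module specs shown in the config.

```python
# Illustrative sketch only (not a file in this repository); see assumptions above.
import yaml

import datasets  # registers 'learnit_shapenet' etc. on import (exp3D/datasets/__init__.py)
import models    # registers 'versatile_np' etc.; pulls in models/op, which JIT-compiles the CUDA extensions

with open('cfgs/nvs_eval/eval_nvs_shapenet_cars.yaml') as f:
    # '$load_root$' is replaced by hand here, standing in for the --load-root argument
    cfg = yaml.safe_load(f.read().replace('$load_root$', 'data'))

test_set = datasets.make(cfg['test_dataset'])  # LearnitShapenet(root_path=..., category='cars', split='test', ...)
model = models.make(cfg['model'])              # versatile_np built from the tokenizer / attender / hier_model specs

print(len(test_set), sum(p.numel() for p in model.parameters()))
```

The same `make` helpers read only the `name`/`args` fields of each spec, so the layout above is shared by all training and evaluation configs.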
-------------------------------------------------------------------------------- /exp3D/cfgs/nvs_eval/eval_nvs_shapenet_cars.yaml: -------------------------------------------------------------------------------- 1 | trainer: nvs_evaluator 2 | 3 | test_dataset: 4 | name: learnit_shapenet 5 | args: 6 | root_path: $load_root$/learnit_shapenet 7 | category: cars 8 | split: test 9 | n_support: 1 10 | n_query: 1 11 | repeat: 100 12 | loader: 13 | batch_size: 1 14 | num_workers: 1 15 | 16 | model: 17 | name: versatile_np 18 | args: 19 | tokenizer: 20 | name: nvs_tokenizer 21 | args: {input_size: 128, patch_size: 8} 22 | self_attender: 23 | name: self_attender 24 | args: {dim: 512, depth: 6, n_head: 8, head_dim: 64, ff_dim: 1024} 25 | cross_attender: 26 | name: cross_attender 27 | args: {dim: 512, depth: 3, n_head: 4, head_dim: 128, ff_dim: 512} 28 | hierarchical_model: 29 | name: hier_model 30 | args: {depth: 4, dim_y: 512, dim_hid: 512, dim_lat: 64} 31 | 32 | eval_model: save/nvs_shapenet_cars/epoch-last.pth 33 | 34 | train_points_per_ray: 128 35 | render_ray_batch: 1024 36 | -------------------------------------------------------------------------------- /exp3D/cfgs/nvs_eval/eval_nvs_shapenet_chairs.yaml: -------------------------------------------------------------------------------- 1 | trainer: nvs_evaluator 2 | 3 | test_dataset: 4 | name: learnit_shapenet 5 | args: 6 | root_path: $load_root$/learnit_shapenet 7 | category: chairs 8 | split: test 9 | n_support: 1 10 | n_query: 1 11 | repeat: 100 12 | loader: 13 | batch_size: 1 14 | num_workers: 1 15 | 16 | model: 17 | name: versatile_np 18 | args: 19 | tokenizer: 20 | name: nvs_tokenizer 21 | args: {input_size: 128, patch_size: 8} 22 | self_attender: 23 | name: self_attender 24 | args: {dim: 512, depth: 6, n_head: 8, head_dim: 64, ff_dim: 1024} 25 | cross_attender: 26 | name: cross_attender 27 | args: {dim: 512, depth: 3, n_head: 4, head_dim: 128, ff_dim: 512} 28 | hierarchical_model: 29 | name: hier_model 30 | args: {depth: 4, dim_y: 512, dim_hid: 512, dim_lat: 64} 31 | 32 | eval_model: save/nvs_shapenet_chairs/epoch-last.pth 33 | 34 | train_points_per_ray: 128 35 | render_ray_batch: 1024 36 | -------------------------------------------------------------------------------- /exp3D/cfgs/nvs_eval/eval_nvs_shapenet_lamps.yaml: -------------------------------------------------------------------------------- 1 | trainer: nvs_evaluator 2 | 3 | test_dataset: 4 | name: learnit_shapenet 5 | args: 6 | root_path: $load_root$/learnit_shapenet 7 | category: lamps 8 | split: test 9 | n_support: 1 10 | n_query: 1 11 | repeat: 100 12 | loader: 13 | batch_size: 1 14 | num_workers: 1 15 | 16 | model: 17 | name: versatile_np 18 | args: 19 | tokenizer: 20 | name: nvs_tokenizer 21 | args: {input_size: 128, patch_size: 8} 22 | self_attender: 23 | name: self_attender 24 | args: {dim: 512, depth: 6, n_head: 8, head_dim: 64, ff_dim: 1024} 25 | cross_attender: 26 | name: cross_attender 27 | args: {dim: 512, depth: 3, n_head: 4, head_dim: 128, ff_dim: 512} 28 | hierarchical_model: 29 | name: hier_model 30 | args: {depth: 4, dim_y: 512, dim_hid: 512, dim_lat: 64} 31 | 32 | eval_model: save/nvs_shapenet_lamps/epoch-last.pth 33 | 34 | train_points_per_ray: 128 35 | render_ray_batch: 4096 36 | -------------------------------------------------------------------------------- /exp3D/cfgs/nvs_shapenet_cars.yaml: -------------------------------------------------------------------------------- 1 | trainer: nvs_trainer 2 | 3 | train_dataset: 4 | name: learnit_shapenet 5 | args: 6 
| root_path: $load_root$/learnit_shapenet 7 | category: cars 8 | split: train 9 | views_rng: [0, 25] 10 | n_support: 1 11 | n_query: 1 12 | repeat: 2 13 | loader: 14 | batch_size: 32 15 | num_workers: 8 16 | 17 | test_dataset: 18 | name: learnit_shapenet 19 | args: 20 | root_path: $load_root$/learnit_shapenet 21 | category: cars 22 | split: test 23 | n_support: 1 24 | n_query: 1 25 | repeat: 100 26 | loader: 27 | batch_size: 32 28 | num_workers: 8 29 | 30 | model: 31 | name: versatile_np 32 | args: 33 | tokenizer: 34 | name: nvs_tokenizer 35 | args: {input_size: 128, patch_size: 8} 36 | self_attender: 37 | name: self_attender 38 | args: {dim: 512, depth: 6, n_head: 8, head_dim: 64, ff_dim: 1024} 39 | cross_attender: 40 | name: cross_attender 41 | args: {dim: 512, depth: 3, n_head: 4, head_dim: 128, ff_dim: 512} 42 | hierarchical_model: 43 | name: hier_model 44 | args: {depth: 4, dim_y: 512, dim_hid: 512, dim_lat: 64} 45 | 46 | train_points_per_ray: 128 47 | train_n_rays: 128 48 | render_ray_batch: 1024 49 | 50 | # resume_model: ./save/nvs_shapenet_chairs/epoch-last.pth 51 | 52 | optimizer: 53 | name: adam 54 | args: {lr: 1.e-4} 55 | max_epoch: 1000 56 | save_epoch: 100 57 | adaptive_sample_epoch: 1 58 | eval_epoch: 10 59 | 60 | Lambda: 0.001 -------------------------------------------------------------------------------- /exp3D/cfgs/nvs_shapenet_chairs.yaml: -------------------------------------------------------------------------------- 1 | trainer: nvs_trainer 2 | 3 | train_dataset: 4 | name: learnit_shapenet 5 | args: 6 | root_path: $load_root$/learnit_shapenet 7 | category: chairs 8 | split: train 9 | views_rng: [0, 25] 10 | n_support: 1 11 | n_query: 1 12 | repeat: 1 13 | loader: 14 | batch_size: 32 15 | num_workers: 8 16 | 17 | test_dataset: 18 | name: learnit_shapenet 19 | args: 20 | root_path: $load_root$/learnit_shapenet 21 | category: chairs 22 | split: test 23 | n_support: 1 24 | n_query: 1 25 | repeat: 100 26 | loader: 27 | batch_size: 32 28 | num_workers: 8 29 | 30 | model: 31 | name: versatile_np 32 | args: 33 | tokenizer: 34 | name: nvs_tokenizer 35 | args: {input_size: 128, patch_size: 8} 36 | self_attender: 37 | name: self_attender 38 | args: {dim: 512, depth: 6, n_head: 8, head_dim: 64, ff_dim: 1024} 39 | cross_attender: 40 | name: cross_attender 41 | args: {dim: 512, depth: 3, n_head: 4, head_dim: 128, ff_dim: 512} 42 | hierarchical_model: 43 | name: hier_model 44 | args: {depth: 4, dim_y: 512, dim_hid: 512, dim_lat: 64} 45 | 46 | train_points_per_ray: 128 47 | train_n_rays: 128 48 | render_ray_batch: 1024 49 | 50 | # resume_model: ./save/nvs_shapenet_chairs/epoch-last.pth 51 | 52 | optimizer: 53 | name: adam 54 | args: {lr: 1.e-4} 55 | max_epoch: 1000 56 | save_epoch: 100 57 | adaptive_sample_epoch: 1 58 | eval_epoch: 10 59 | 60 | Lambda: 0.001 -------------------------------------------------------------------------------- /exp3D/cfgs/nvs_shapenet_lamps.yaml: -------------------------------------------------------------------------------- 1 | trainer: nvs_trainer 2 | 3 | train_dataset: 4 | name: learnit_shapenet 5 | args: 6 | root_path: $load_root$/learnit_shapenet 7 | category: lamps 8 | split: train 9 | views_rng: [0, 25] 10 | n_support: 1 11 | n_query: 1 12 | repeat: 3 13 | loader: 14 | batch_size: 32 15 | num_workers: 8 16 | 17 | test_dataset: 18 | name: learnit_shapenet 19 | args: 20 | root_path: $load_root$/learnit_shapenet 21 | category: lamps 22 | split: test 23 | n_support: 1 24 | n_query: 1 25 | repeat: 100 26 | loader: 27 | batch_size: 32 28 | num_workers: 8 
29 | 30 | model: 31 | name: versatile_np 32 | args: 33 | tokenizer: 34 | name: nvs_tokenizer 35 | args: {input_size: 128, patch_size: 8} 36 | self_attender: 37 | name: self_attender 38 | args: {dim: 512, depth: 6, n_head: 8, head_dim: 64, ff_dim: 1024} 39 | cross_attender: 40 | name: cross_attender 41 | args: {dim: 512, depth: 3, n_head: 4, head_dim: 128, ff_dim: 512} 42 | hierarchical_model: 43 | name: hier_model 44 | args: {depth: 4, dim_y: 512, dim_hid: 512, dim_lat: 64} 45 | 46 | train_points_per_ray: 128 47 | train_n_rays: 128 48 | render_ray_batch: 1024 49 | 50 | # resume_model: ./save/nvs_shapenet_lamps/epoch-last.pth 51 | 52 | optimizer: 53 | name: adam 54 | args: {lr: 1.e-4} 55 | max_epoch: 1000 56 | save_epoch: 100 57 | adaptive_sample_epoch: 1 58 | eval_epoch: 10 59 | 60 | Lambda: 0.001 -------------------------------------------------------------------------------- /exp3D/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import register, make 2 | from . import imgrec_dataset, celeba, imagenette 3 | from . import learnit_shapenet 4 | -------------------------------------------------------------------------------- /exp3D/datasets/celeba.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | 4 | from torch.utils.data import Dataset 5 | 6 | from datasets import register 7 | 8 | 9 | @register('celeba') 10 | class Celeba(Dataset): 11 | 12 | def __init__(self, root_path, split): 13 | if split == 'train': 14 | s, t = 1, 162770 15 | elif split == 'val': 16 | s, t = 162771, 182637 17 | elif split == 'test': 18 | s, t = 182638, 202599 19 | self.data = [] 20 | for i in range(s, t + 1): 21 | path = os.path.join(root_path, 'img_align_celeba', 'img_align_celeba', f'{i:06}.jpg') 22 | self.data.append(path) 23 | 24 | def __len__(self): 25 | return len(self.data) 26 | 27 | def __getitem__(self, idx): 28 | return Image.open(self.data[idx]) 29 | -------------------------------------------------------------------------------- /exp3D/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | datasets = {} 5 | 6 | 7 | def register(name): 8 | def decorator(cls): 9 | datasets[name] = cls 10 | return cls 11 | return decorator 12 | 13 | 14 | def make(dataset_spec, args=None): 15 | if args is not None: 16 | dataset_args = copy.deepcopy(dataset_spec['args']) 17 | dataset_args.update(args) 18 | else: 19 | dataset_args = dataset_spec['args'] 20 | dataset = datasets[dataset_spec['name']](**dataset_args) 21 | return dataset 22 | -------------------------------------------------------------------------------- /exp3D/datasets/imagenette.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | 7 | from datasets import register 8 | 9 | 10 | @register('imagenette') 11 | class Imagenette(Dataset): 12 | 13 | def __init__(self, root_path, split, augment): 14 | root_path = os.path.join(root_path, split) 15 | classes = sorted(os.listdir(root_path)) 16 | self.data = [] 17 | for c in classes: 18 | filenames = sorted(os.listdir(os.path.join(root_path, c))) 19 | for f in filenames: 20 | self.data.append(os.path.join(root_path, c, f)) 21 | if augment == 'none': 22 | self.transform = transforms.Compose([]) 23 | elif augment == 'random_crop_178': 24 | 
self.transform = transforms.Compose([ 25 | transforms.RandomCrop(178), 26 | transforms.RandomHorizontalFlip(), 27 | ]) 28 | 29 | def __len__(self): 30 | return len(self.data) 31 | 32 | def __getitem__(self, idx): 33 | return self.transform(Image.open(self.data[idx]).convert('RGB')) 34 | -------------------------------------------------------------------------------- /exp3D/datasets/imgrec_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torchvision import transforms 3 | 4 | import datasets 5 | from datasets import register 6 | 7 | 8 | @register('imgrec_dataset') 9 | class ImgrecDataset(Dataset): 10 | 11 | def __init__(self, imageset, resize): 12 | self.imageset = datasets.make(imageset) 13 | self.transform = transforms.Compose([ 14 | transforms.Resize(resize), 15 | transforms.CenterCrop(resize), 16 | transforms.ToTensor(), 17 | ]) 18 | 19 | def __len__(self): 20 | return len(self.imageset) 21 | 22 | def __getitem__(self, idx): 23 | x = self.transform(self.imageset[idx]) 24 | return {'inp': x, 'gt': x} 25 | -------------------------------------------------------------------------------- /exp3D/datasets/learnit_shapenet.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/tancik/learnit/blob/main/Experiments/shapenet.ipynb 2 | import os 3 | import json 4 | 5 | import imageio 6 | import numpy as np 7 | import torch 8 | import einops 9 | from torch.utils.data import Dataset 10 | 11 | from datasets import register 12 | 13 | 14 | @register('learnit_shapenet') 15 | class LearnitShapenet(Dataset): 16 | 17 | def __init__(self, root_path, category, split, n_support, n_query, views_rng=None, repeat=1): 18 | with open(os.path.join(root_path, category[:-len('s')] + '_splits.json'), 'r') as f: 19 | obj_ids = json.load(f)[split] 20 | _data = [os.path.join(root_path, category, _) for _ in obj_ids] 21 | self.data = [] 22 | for x in _data: 23 | if os.path.exists(os.path.join(x, 'transforms.json')): 24 | self.data.append(x) 25 | else: 26 | print(f'Missing obj at {x}, skipped.') 27 | self.n_support = n_support 28 | self.n_query = n_query 29 | self.views_rng = views_rng 30 | self.repeat = repeat 31 | 32 | def __len__(self): 33 | return len(self.data) * self.repeat 34 | 35 | def __getitem__(self, idx): 36 | idx %= len(self.data) 37 | 38 | train_ex_dir = self.data[idx] 39 | with open(os.path.join(train_ex_dir, 'transforms.json'), 'r') as fp: 40 | meta = json.load(fp) 41 | camera_angle_x = float(meta['camera_angle_x']) 42 | frames = meta['frames'] 43 | if self.views_rng is not None: 44 | frames = frames[self.views_rng[0]: self.views_rng[1]] 45 | 46 | frames = np.random.choice(frames, self.n_support + self.n_query, replace=False) 47 | 48 | imgs = [] 49 | poses = [] 50 | for frame in frames: 51 | fname = os.path.join(train_ex_dir, os.path.basename(frame['file_path']) + '.png') 52 | imgs.append(imageio.imread(fname)) 53 | poses.append(np.array(frame['transform_matrix'])) 54 | H, W = imgs[0].shape[:2] 55 | assert H == W 56 | focal = .5 * W / np.tan(.5 * camera_angle_x) 57 | imgs = (np.array(imgs) / 255.).astype(np.float32) 58 | imgs = imgs[..., :3] * imgs[..., -1:] + (1 - imgs[..., -1:]) 59 | poses = np.array(poses).astype(np.float32) 60 | 61 | imgs = einops.rearrange(torch.from_numpy(imgs), 'n h w c -> n c h w') 62 | poses = torch.from_numpy(poses)[:, :3, :4] 63 | focal = torch.ones(len(poses), 2) * float(focal) 64 | t = self.n_support 65 | return { 66 | 
'support_imgs': imgs[:t], 67 | 'support_poses': poses[:t], 68 | 'support_focals': focal[:t], 69 | 'query_imgs': imgs[t:], 70 | 'query_poses': poses[t:], 71 | 'query_focals': focal[t:], 72 | 'near': 2, 73 | 'far': 6, 74 | } 75 | -------------------------------------------------------------------------------- /exp3D/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import register, make 2 | from . import versatile_np 3 | from . import hier_model 4 | from . import tokenizers 5 | from . import transformer 6 | -------------------------------------------------------------------------------- /exp3D/models/blocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | import models 9 | from models.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d 10 | # from op import FusedLeakyReLU, fused_leaky_relu 11 | 12 | class PixelNorm(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def forward(self, input): 17 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 18 | 19 | 20 | def make_kernel(k): 21 | k = torch.tensor(k, dtype=torch.float32) 22 | 23 | if k.ndim == 1: 24 | k = k[None, :] * k[:, None] 25 | 26 | k /= k.sum() 27 | 28 | return k 29 | 30 | 31 | class Upsample(nn.Module): 32 | def __init__(self, kernel, factor=2): 33 | super().__init__() 34 | 35 | self.factor = factor 36 | kernel = make_kernel(kernel) * (factor ** 2) 37 | self.register_buffer('kernel', kernel) 38 | 39 | p = kernel.shape[0] - factor 40 | 41 | pad0 = (p + 1) // 2 + factor - 1 42 | pad1 = p // 2 43 | 44 | self.pad = (pad0, pad1) 45 | 46 | def forward(self, input): 47 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 48 | return out 49 | 50 | 51 | class Downsample(nn.Module): 52 | def __init__(self, kernel, factor=2): 53 | super().__init__() 54 | 55 | self.factor = factor 56 | kernel = make_kernel(kernel) 57 | self.register_buffer('kernel', kernel) 58 | 59 | p = kernel.shape[0] - factor 60 | 61 | pad0 = (p + 1) // 2 62 | pad1 = p // 2 63 | 64 | self.pad = (pad0, pad1) 65 | 66 | def forward(self, input): 67 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 68 | return out 69 | 70 | 71 | class Blur(nn.Module): 72 | def __init__(self, kernel, pad, upsample_factor=1): 73 | super().__init__() 74 | 75 | kernel = make_kernel(kernel) 76 | 77 | if upsample_factor > 1: 78 | kernel = kernel * (upsample_factor ** 2) 79 | 80 | self.register_buffer('kernel', kernel) 81 | 82 | self.pad = pad 83 | 84 | def forward(self, input): 85 | out = upfirdn2d(input, self.kernel, pad=self.pad) 86 | return out 87 | 88 | 89 | class EqualConv2d(nn.Module): 90 | def __init__( 91 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 92 | ): 93 | super().__init__() 94 | 95 | self.weight = nn.Parameter( 96 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 97 | ) 98 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 99 | 100 | self.stride = stride 101 | self.padding = padding 102 | 103 | if bias: 104 | self.bias = nn.Parameter(torch.zeros(out_channel)) 105 | 106 | else: 107 | self.bias = None 108 | 109 | def forward(self, input): 110 | out = F.conv2d( 111 | input, 112 | self.weight * self.scale, 113 | bias=self.bias, 114 | stride=self.stride, 115 | padding=self.padding, 116 | ) 117 | return 
out 118 | 119 | def __repr__(self): 120 | return ( 121 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' 122 | f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' 123 | ) 124 | 125 | 126 | class EqualLinear(nn.Module): 127 | def __init__( 128 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 129 | ): 130 | super().__init__() 131 | 132 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 133 | 134 | if bias: 135 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 136 | 137 | else: 138 | self.bias = None 139 | 140 | self.activation = activation 141 | 142 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 143 | self.lr_mul = lr_mul 144 | 145 | def forward(self, input): 146 | if self.activation: 147 | out = F.linear(input, self.weight * self.scale) 148 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 149 | 150 | else: 151 | out = F.linear( 152 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 153 | ) 154 | 155 | return out 156 | 157 | def __repr__(self): 158 | return ( 159 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' 160 | ) 161 | 162 | 163 | class ScaledLeakyReLU(nn.Module): 164 | def __init__(self, negative_slope=0.2): 165 | super().__init__() 166 | 167 | self.negative_slope = negative_slope 168 | 169 | def forward(self, input): 170 | out = F.leaky_relu(input, negative_slope=self.negative_slope) 171 | return out * math.sqrt(2) 172 | 173 | 174 | class ModulatedConv2d(nn.Module): 175 | def __init__( 176 | self, 177 | in_channel, 178 | out_channel, 179 | kernel_size, 180 | style_dim, 181 | demodulate=True, 182 | upsample=False, 183 | downsample=False, 184 | blur_kernel=[1, 3, 3, 1], 185 | ): 186 | super().__init__() 187 | 188 | self.eps = 1e-8 189 | self.kernel_size = kernel_size 190 | self.in_channel = in_channel 191 | self.out_channel = out_channel 192 | self.upsample = upsample 193 | self.downsample = downsample 194 | 195 | if upsample: 196 | factor = 2 197 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 198 | pad0 = (p + 1) // 2 + factor - 1 199 | pad1 = p // 2 + 1 200 | 201 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 202 | 203 | if downsample: 204 | factor = 2 205 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 206 | pad0 = (p + 1) // 2 207 | pad1 = p // 2 208 | 209 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 210 | 211 | fan_in = in_channel * kernel_size ** 2 212 | self.scale = 1 / math.sqrt(fan_in) 213 | self.padding = kernel_size // 2 214 | 215 | self.weight = nn.Parameter( 216 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 217 | ) 218 | 219 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 220 | 221 | self.demodulate = demodulate 222 | 223 | def __repr__(self): 224 | return ( 225 | f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' 226 | f'upsample={self.upsample}, downsample={self.downsample})' 227 | ) 228 | 229 | def forward(self, input, style): 230 | batch, in_channel, height, width = input.shape 231 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 232 | weight = self.scale * self.weight * style 233 | if self.demodulate: 234 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 235 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 236 | 237 | weight = weight.view( 238 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 239 | ) 240 | 241 | if 
self.upsample: 242 | input = input.view(1, batch * in_channel, height, width) 243 | weight = weight.view( 244 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 245 | ) 246 | weight = weight.transpose(1, 2).reshape( 247 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 248 | ) 249 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 250 | _, _, height, width = out.shape 251 | out = out.view(batch, self.out_channel, height, width) 252 | out = self.blur(out) 253 | 254 | elif self.downsample: 255 | input = self.blur(input) 256 | _, _, height, width = input.shape 257 | input = input.view(1, batch * in_channel, height, width) 258 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 259 | _, _, height, width = out.shape 260 | out = out.view(batch, self.out_channel, height, width) 261 | 262 | else: 263 | input = input.view(1, batch * in_channel, height, width) 264 | out = F.conv2d(input, weight, padding=self.padding, groups=batch) 265 | _, _, height, width = out.shape 266 | out = out.view(batch, self.out_channel, height, width) 267 | 268 | return out 269 | 270 | 271 | class NoiseInjection(nn.Module): 272 | def __init__(self): 273 | super().__init__() 274 | 275 | self.weight = nn.Parameter(torch.zeros(1)) 276 | 277 | def forward(self, image, noise=None): 278 | if noise is None: 279 | batch, _, height, width = image.shape 280 | noise = image.new_empty(batch, 1, height, width).normal_() 281 | 282 | return image + self.weight * noise 283 | 284 | 285 | class ConstantInput(nn.Module): 286 | def __init__(self, channel, size=4): 287 | super().__init__() 288 | 289 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 290 | 291 | def forward(self, batch_size): 292 | out = self.input.repeat(batch_size, 1, 1, 1) 293 | return out 294 | 295 | 296 | class StyledConv(nn.Module): 297 | def __init__( 298 | self, 299 | in_channel, 300 | out_channel, 301 | kernel_size, 302 | style_dim, 303 | upsample=False, 304 | blur_kernel=[1, 3, 3, 1], 305 | demodulate=True, 306 | activation=None, 307 | downsample=False, 308 | ): 309 | super().__init__() 310 | 311 | self.conv = ModulatedConv2d( 312 | in_channel, 313 | out_channel, 314 | kernel_size, 315 | style_dim, 316 | upsample=upsample, 317 | blur_kernel=blur_kernel, 318 | demodulate=demodulate, 319 | downsample=downsample, 320 | ) 321 | 322 | self.activation = activation 323 | self.noise = NoiseInjection() 324 | if activation == 'sinrelu': 325 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 326 | self.activate = ScaledLeakyReLUSin() 327 | elif activation == 'sin': 328 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 329 | self.activate = SinActivation() 330 | else: 331 | self.activate = FusedLeakyReLU(out_channel) 332 | 333 | def forward(self, input, style, noise=None): 334 | out = self.conv(input, style) 335 | # out = self.noise(out, noise=noise) 336 | if self.activation == 'sinrelu' or self.activation == 'sin': 337 | out = out + self.bias 338 | out = self.activate(out) 339 | 340 | return out 341 | 342 | 343 | class ToRGB(nn.Module): 344 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): 345 | super().__init__() 346 | 347 | self.upsample = upsample 348 | if upsample: 349 | self.upsample = Upsample(blur_kernel) 350 | 351 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) 352 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 353 | 354 | def forward(self, input, style, skip=None): 355 | out = 
self.conv(input, style) 356 | out = out + self.bias 357 | 358 | if skip is not None: 359 | if self.upsample: 360 | skip = self.upsample(skip) 361 | out = out + skip 362 | return out 363 | 364 | 365 | class ToMixLogistic(nn.Module): 366 | def __init__(self, in_channel, out_channel, style_dim): 367 | # def __init__(self, in_channel, out_channel, style_dim, upsample=False, blur_kernel=[1, 3, 3, 1]): 368 | super().__init__() 369 | # self.upsample = upsample 370 | # if upsample: 371 | # self.upsample = Upsample(blur_kernel) 372 | self.conv = ModulatedConv2d(in_channel, out_channel, 1, style_dim, demodulate=False) 373 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 374 | 375 | def forward(self, input, style, skip=None): 376 | out = self.conv(input, style) 377 | out = out + self.bias 378 | if skip is not None: 379 | # if self.upsample: 380 | # skip = self.upsample(skip) 381 | out = out + skip 382 | return out 383 | 384 | class ToMixLogisticNoCond(nn.Module): 385 | def __init__(self, in_channel, out_channel): 386 | # def __init__(self, in_channel, out_channel, style_dim, upsample=False, blur_kernel=[1, 3, 3, 1]): 387 | super().__init__() 388 | # self.upsample = upsample 389 | # if upsample: 390 | # self.upsample = Upsample(blur_kernel) 391 | self.conv = nn.Conv2d(in_channel, out_channel, 1) 392 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 393 | 394 | def forward(self, input, skip=None): 395 | out = self.conv(input) 396 | out = out + self.bias 397 | if skip is not None: 398 | # if self.upsample: 399 | # skip = self.upsample(skip) 400 | out = out + skip 401 | return out 402 | 403 | 404 | class EqualConvTranspose2d(nn.Module): 405 | def __init__( 406 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 407 | ): 408 | super().__init__() 409 | 410 | self.weight = nn.Parameter( 411 | torch.randn(in_channel, out_channel, kernel_size, kernel_size) 412 | ) 413 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 414 | 415 | self.stride = stride 416 | self.padding = padding 417 | 418 | if bias: 419 | self.bias = nn.Parameter(torch.zeros(out_channel)) 420 | 421 | else: 422 | self.bias = None 423 | 424 | def forward(self, input): 425 | out = F.conv_transpose2d( 426 | input, 427 | self.weight * self.scale, 428 | bias=self.bias, 429 | stride=self.stride, 430 | padding=self.padding, 431 | ) 432 | 433 | return out 434 | 435 | def __repr__(self): 436 | return ( 437 | f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]}," 438 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 439 | ) 440 | 441 | 442 | class ConvLayer(nn.Sequential): 443 | def __init__( 444 | self, 445 | in_channel, 446 | out_channel, 447 | kernel_size, 448 | downsample=False, 449 | blur_kernel=[1, 3, 3, 1], 450 | bias=True, 451 | activate=True, 452 | upsample=False, 453 | padding="zero", 454 | ): 455 | layers = [] 456 | 457 | self.padding = 0 458 | stride = 1 459 | 460 | if downsample: 461 | factor = 2 462 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 463 | pad0 = (p + 1) // 2 464 | pad1 = p // 2 465 | 466 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 467 | 468 | stride = 2 469 | 470 | if upsample: 471 | layers.append( 472 | EqualConvTranspose2d( 473 | in_channel, 474 | out_channel, 475 | kernel_size, 476 | padding=0, 477 | stride=2, 478 | bias=bias and not activate, 479 | ) 480 | ) 481 | 482 | factor = 2 483 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 484 | pad0 = (p + 1) // 2 + factor - 1 485 | pad1 = p // 2 + 1 
486 | 487 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 488 | 489 | else: 490 | if not downsample: 491 | if padding == "zero": 492 | self.padding = (kernel_size - 1) // 2 493 | 494 | elif padding == "reflect": 495 | padding = (kernel_size - 1) // 2 496 | 497 | if padding > 0: 498 | layers.append(nn.ReflectionPad2d(padding)) 499 | 500 | self.padding = 0 501 | 502 | elif padding != "valid": 503 | raise ValueError('Padding should be "zero", "reflect", or "valid"') 504 | 505 | layers.append( 506 | EqualConv2d( 507 | in_channel, 508 | out_channel, 509 | kernel_size, 510 | padding=self.padding, 511 | stride=stride, 512 | bias=bias and not activate, 513 | ) 514 | ) 515 | 516 | if activate: 517 | if bias: 518 | layers.append(FusedLeakyReLU(out_channel)) 519 | 520 | else: 521 | layers.append(ScaledLeakyReLU(0.2)) 522 | 523 | super().__init__(*layers) 524 | 525 | 526 | class ResBlock(nn.Module): 527 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], kernel_size=3, downsample=True): 528 | super().__init__() 529 | 530 | self.conv1 = ConvLayer(in_channel, in_channel, kernel_size) 531 | self.conv2 = ConvLayer(in_channel, out_channel, kernel_size, downsample=downsample) 532 | 533 | self.skip = ConvLayer( 534 | in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False 535 | ) 536 | 537 | def forward(self, input): 538 | out = self.conv1(input) 539 | out = self.conv2(out) 540 | 541 | skip = self.skip(input) 542 | out = (out + skip) / math.sqrt(2) 543 | 544 | return out 545 | 546 | 547 | class ConLinear(nn.Module): 548 | def __init__(self, ch_in, ch_out, is_first=False, bias=True): 549 | super(ConLinear, self).__init__() 550 | self.conv = nn.Conv2d(ch_in, ch_out, kernel_size=1, padding=0, bias=bias) 551 | if is_first: 552 | nn.init.uniform_(self.conv.weight, -np.sqrt(9 / ch_in), np.sqrt(9 / ch_in)) 553 | else: 554 | nn.init.uniform_(self.conv.weight, -np.sqrt(3 / ch_in), np.sqrt(3 / ch_in)) 555 | 556 | def forward(self, x): 557 | return self.conv(x) 558 | 559 | 560 | class SinActivation(nn.Module): 561 | def __init__(self,): 562 | super(SinActivation, self).__init__() 563 | 564 | def forward(self, x): 565 | return torch.sin(x) 566 | 567 | class SinActivationBias(nn.Module): 568 | def __init__(self, out_channel): 569 | super(SinActivationBias, self).__init__() 570 | self.bias = nn.Parameter(torch.zeros(1, 1, out_channel)) 571 | 572 | def forward(self, x): 573 | return torch.sin(x + self.bias) 574 | 575 | 576 | class LFF(nn.Module): 577 | def __init__(self, hidden_size, ): 578 | super(LFF, self).__init__() 579 | self.ffm = ConLinear(1, hidden_size, is_first=True) 580 | self.activation = SinActivation() 581 | 582 | def forward(self, x): 583 | x = self.ffm(x) 584 | x = self.activation(x) 585 | return x 586 | 587 | 588 | class ScaledLeakyReLUSin(nn.Module): 589 | def __init__(self, negative_slope=0.2): 590 | super().__init__() 591 | 592 | self.negative_slope = negative_slope 593 | 594 | def forward(self, input): 595 | out_lr = F.leaky_relu(input[:, ::2], negative_slope=self.negative_slope) 596 | out_sin = torch.sin(input[:, 1::2]) 597 | out = torch.cat([out_lr, out_sin], 1) 598 | return out * math.sqrt(2) 599 | 600 | 601 | class StyledResBlock(nn.Module): 602 | def __init__(self, in_channel, out_channel, kernel_size, style_dim, blur_kernel=[1, 3, 3, 1], demodulate=True, 603 | activation=None, upsample=False, downsample=False): 604 | super().__init__() 605 | 606 | self.conv1 = StyledConv(in_channel, out_channel, kernel_size, style_dim, 607 | demodulate=demodulate, 
activation=activation) 608 | self.conv2 = StyledConv(out_channel, out_channel, kernel_size, style_dim, 609 | demodulate=demodulate, activation=activation, 610 | upsample=upsample, downsample=downsample) 611 | 612 | if downsample or in_channel != out_channel or upsample: 613 | self.skip = ConvLayer( 614 | in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False, upsample=upsample, 615 | ) 616 | else: 617 | self.skip = None 618 | 619 | def forward(self, input, latent): 620 | out = self.conv1(input, latent) 621 | out = self.conv2(out, latent) 622 | 623 | if self.skip is not None: 624 | skip = self.skip(input) 625 | else: 626 | skip = input 627 | 628 | out = (out + skip) / math.sqrt(2) 629 | 630 | return out 631 | -------------------------------------------------------------------------------- /exp3D/models/hier_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.distributions import kl_divergence, Normal 7 | import einops 8 | 9 | from models import register 10 | from utils import rendering 11 | 12 | from .blocks import StyledConv, ModulatedConv2d 13 | 14 | def build_mlp(dim_in, dim_hid, dim_out, depth): 15 | modules = [nn.Linear(dim_in, dim_hid), nn.ReLU(True)] 16 | for _ in range(depth-2): 17 | modules.append(nn.Linear(dim_hid, dim_hid)) 18 | modules.append(nn.ReLU(True)) 19 | modules.append(nn.Linear(dim_hid, dim_out)) 20 | return nn.Sequential(*modules) 21 | 22 | 23 | @register('hier_model') 24 | class Hierarchical_Model(nn.Module): 25 | def __init__(self, depth, dim_y, dim_hid, dim_lat): 26 | super().__init__() 27 | self.layers = nn.ModuleList() 28 | for _ in range(depth): 29 | self.layers.append(ModBlock(dim_y=dim_y, dim_hid=dim_hid, dim_lat=dim_lat)) 30 | 31 | def forward(self, y, x_tgt, y_tgt, rays_o, z_vals, training=True): 32 | kls = 0 33 | for layer in self.layers: 34 | y, kld = layer(y, x_tgt, y_tgt, rays_o, z_vals, training=training) 35 | kls += kld 36 | return y, kls 37 | 38 | 39 | class ModBlock(nn.Module): 40 | def __init__(self, dim_y, dim_hid, dim_lat): 41 | super().__init__() 42 | self.input_mlp = build_mlp(dim_y, dim_hid, 4, depth=3) 43 | 44 | self.merge_p = build_mlp(3, dim_hid, dim_hid, depth=2) 45 | self.merge_q = build_mlp(6, dim_hid, dim_hid, depth=2) 46 | self.latent_encoder_p = build_mlp(dim_hid, dim_hid, dim_lat * 2, depth=2) 47 | self.latent_encoder_q = build_mlp(dim_hid, dim_hid, dim_lat * 2, depth=2) 48 | self.latent_decoder = build_mlp(dim_lat, dim_hid, dim_hid, depth=2) 49 | 50 | self.unmod_mlp = build_mlp(dim_y * 2, dim_hid, dim_y, depth=2) 51 | self.mod_conv1 = StyledConv(dim_y, dim_hid, 1, dim_hid, demodulate=True)# with built-in activation function 52 | self.mod_conv2 = ModulatedConv2d(dim_hid, dim_y, 1, dim_hid, demodulate=True) # without built-in activation function 53 | self.mod_conv2.weight.data *= np.sqrt(1 / 6) 54 | 55 | def forward(self, y, x_tgt, y_tgt, rays_o, z_vals, training=True): 56 | # 3D points -> 2D image pixels 57 | y_rgb = rendering(self.input_mlp(y), rays_o=rays_o, z_vals=z_vals) 58 | 59 | # variational inference in Neural Process (with an average pooling across all points) 60 | # if not training, y_tgt is not used 61 | z, kld = self.forward_latent(y_rgb, y_tgt, training=training) 62 | 63 | # unmodulated and modulated layers leveraging latent variables 64 | y = self.forward_mlps(y, x_tgt, latent=z) 65 | return y, kld 66 | 67 | def 
forward_latent(self, y_rgb, y_tgt, training): 68 | z_prior = self.merge_p(y_rgb).mean(dim=-2, keepdim=True) 69 | dist_prior = self.normal_distribution(self.latent_encoder_p(z_prior)) 70 | 71 | if training: 72 | z_posterior = self.merge_q(torch.cat([y_rgb, y_tgt], dim=-1)).mean(dim=-2, keepdim=True) 73 | dist_posterior = self.normal_distribution(self.latent_encoder_q(z_posterior)) 74 | z = dist_posterior.rsample() 75 | kld = kl_divergence(dist_posterior, dist_prior).sum(-1) 76 | else: 77 | z = dist_prior.rsample() 78 | kld = torch.zeros_like(z) 79 | 80 | z = self.latent_decoder(z) 81 | return z, kld 82 | 83 | def forward_mlps(self, y, x_tgt, latent): 84 | b, num_tgt, c = y.shape 85 | 86 | # Two unmodulated MLPs (residual structure) 87 | # We enhance feature y by combining x_tgt, which can actually be omitted. 88 | y = y + self.unmod_mlp(torch.cat([x_tgt, y], dim=-1)) 89 | 90 | # Two modulated MLPs (residual structure) 91 | # Modulated MLP is implemented in the 2D form. 92 | y_res = einops.rearrange(y, 'b n c -> b c n 1').contiguous() 93 | y_res = self.mod_conv1(y_res, latent) 94 | y_res = self.mod_conv2(y_res, latent) 95 | y_res = einops.rearrange(y_res, 'b c n 1 -> b n c').contiguous() 96 | 97 | y = y + y_res 98 | return y 99 | 100 | @staticmethod 101 | def normal_distribution(input): 102 | mean, var = input.chunk(2, dim=-1) 103 | var = 0.1 + 0.9 * torch.sigmoid(var) 104 | dist = Normal(mean, var) 105 | return dist 106 | 107 | -------------------------------------------------------------------------------- /exp3D/models/models.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | models = {} 5 | 6 | 7 | def register(name): 8 | def decorator(cls): 9 | models[name] = cls 10 | return cls 11 | return decorator 12 | 13 | 14 | def make(model_spec, args=None, load_sd=False): 15 | if args is not None: 16 | model_args = copy.deepcopy(model_spec['args']) 17 | model_args.update(args) 18 | else: 19 | model_args = model_spec['args'] 20 | model = models[model_spec['name']](**model_args) 21 | if load_sd: 22 | model.load_state_dict(model_spec['sd']) 23 | return model 24 | -------------------------------------------------------------------------------- /exp3D/models/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d -------------------------------------------------------------------------------- /exp3D/models/op/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/op/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /exp3D/models/op/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/op/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /exp3D/models/op/__pycache__/fused_act.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/op/__pycache__/fused_act.cpython-37.pyc 
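Returning to `exp3D/models/hier_model.py` above: each `ModBlock` infers one latent variable per layer — a prior computed from the current rendered prediction along the query rays and, at training time, a posterior that additionally conditions on the ground-truth target pixels — and accumulates the KL between the two. The following minimal, standalone sketch (not part of the repository; random tensors stand in for the `latent_encoder_p` / `latent_encoder_q` outputs, and the `Lambda` value in the configs is presumed to weight this accumulated KL in the training loss) reproduces just that latent step.

```python
# Standalone sketch of the per-layer latent step in ModBlock.forward_latent (assumptions noted above).
import torch
from torch.distributions import Normal, kl_divergence

def normal_distribution(stats):
    # same parameterization as ModBlock.normal_distribution:
    # scale squashed to (0.1, 1.0) for stability
    mean, var = stats.chunk(2, dim=-1)
    return Normal(mean, 0.1 + 0.9 * torch.sigmoid(var))

dim_lat = 64                                      # matches dim_lat in the configs
prior_stats = torch.randn(1, 1, dim_lat * 2)      # stands in for latent_encoder_p(...)
posterior_stats = torch.randn(1, 1, dim_lat * 2)  # stands in for latent_encoder_q(...)

p, q = normal_distribution(prior_stats), normal_distribution(posterior_stats)
z = q.rsample()                      # reparameterized sample used at training time;
                                     # at test time the code samples from the prior instead
kld = kl_divergence(q, p).sum(-1)    # summed over the latent dimension, as in forward_latent
print(z.shape, kld.shape)
```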
-------------------------------------------------------------------------------- /exp3D/models/op/__pycache__/fused_act.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/op/__pycache__/fused_act.cpython-38.pyc -------------------------------------------------------------------------------- /exp3D/models/op/__pycache__/upfirdn2d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/op/__pycache__/upfirdn2d.cpython-37.pyc -------------------------------------------------------------------------------- /exp3D/models/op/__pycache__/upfirdn2d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/op/__pycache__/upfirdn2d.cpython-38.pyc -------------------------------------------------------------------------------- /exp3D/models/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | fused = load( 11 | 'fused', 12 | sources=[ 13 | os.path.join(module_path, 'fused_bias_act.cpp'), 14 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 15 | ], 16 | ) 17 | 18 | 19 | class FusedLeakyReLUFunctionBackward(Function): 20 | @staticmethod 21 | def forward(ctx, grad_output, out, negative_slope, scale): 22 | ctx.save_for_backward(out) 23 | ctx.negative_slope = negative_slope 24 | ctx.scale = scale 25 | 26 | empty = grad_output.new_empty(0) 27 | 28 | grad_input = fused.fused_bias_act( 29 | grad_output, empty, out, 3, 1, negative_slope, scale 30 | ) 31 | 32 | dim = [0] 33 | 34 | if grad_input.ndim > 2: 35 | dim += list(range(2, grad_input.ndim)) 36 | 37 | grad_bias = grad_input.sum(dim).detach() 38 | 39 | return grad_input, grad_bias 40 | 41 | @staticmethod 42 | def backward(ctx, gradgrad_input, gradgrad_bias): 43 | out, = ctx.saved_tensors 44 | gradgrad_out = fused.fused_bias_act( 45 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 46 | ) 47 | 48 | return gradgrad_out, None, None, None 49 | 50 | 51 | class FusedLeakyReLUFunction(Function): 52 | @staticmethod 53 | def forward(ctx, input, bias, negative_slope, scale): 54 | empty = input.new_empty(0) 55 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 56 | ctx.save_for_backward(out) 57 | ctx.negative_slope = negative_slope 58 | ctx.scale = scale 59 | 60 | return out 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | out, = ctx.saved_tensors 65 | 66 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 67 | grad_output, out, ctx.negative_slope, ctx.scale 68 | ) 69 | 70 | return grad_input, grad_bias, None, None 71 | 72 | 73 | class FusedLeakyReLU(nn.Module): 74 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 75 | super().__init__() 76 | 77 | self.bias = nn.Parameter(torch.zeros(channel)) 78 | self.negative_slope = negative_slope 79 | self.scale = scale 80 | 81 | def forward(self, input): 82 | return fused_leaky_relu(input, self.bias, self.negative_slope, 
self.scale) 83 | 84 | 85 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 86 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 87 | -------------------------------------------------------------------------------- /exp3D/models/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /exp3D/models/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /exp3D/models/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /exp3D/models/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load 6 | 7 | 8 | module_path = os.path.dirname(__file__) 9 | 10 | upfirdn2d_op = load( 11 | 'upfirdn2d', 12 | sources=[ 13 | os.path.join(module_path, 'upfirdn2d.cpp'), 14 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 15 | ], 16 | ) 17 | 18 | class UpFirDn2dBackward(Function): 19 | @staticmethod 20 | def forward( 21 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 22 | ): 23 | 24 | up_x, up_y = up 25 | down_x, down_y = down 26 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 27 | 28 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 29 | 30 | grad_input = upfirdn2d_op.upfirdn2d( 31 | grad_output, 32 | grad_kernel, 33 | down_x, 34 | down_y, 35 | up_x, 36 | up_y, 37 | g_pad_x0, 38 | g_pad_x1, 39 | g_pad_y0, 40 | g_pad_y1, 41 | ) 42 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 43 | 44 | ctx.save_for_backward(kernel) 45 | 46 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 47 | 48 | ctx.up_x = up_x 49 | ctx.up_y = up_y 50 | ctx.down_x = down_x 51 | ctx.down_y = down_y 52 | ctx.pad_x0 = pad_x0 53 | ctx.pad_x1 = pad_x1 54 | ctx.pad_y0 = pad_y0 55 | ctx.pad_y1 = pad_y1 56 | ctx.in_size = in_size 57 | ctx.out_size = out_size 58 | 59 | return grad_input 60 | 61 | @staticmethod 62 | def backward(ctx, gradgrad_input): 63 | kernel, = ctx.saved_tensors 64 | 65 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], 
ctx.in_size[3], 1) 66 | 67 | gradgrad_out = upfirdn2d_op.upfirdn2d( 68 | gradgrad_input, 69 | kernel, 70 | ctx.up_x, 71 | ctx.up_y, 72 | ctx.down_x, 73 | ctx.down_y, 74 | ctx.pad_x0, 75 | ctx.pad_x1, 76 | ctx.pad_y0, 77 | ctx.pad_y1, 78 | ) 79 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 80 | gradgrad_out = gradgrad_out.view( 81 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 82 | ) 83 | 84 | return gradgrad_out, None, None, None, None, None, None, None, None 85 | 86 | 87 | class UpFirDn2d(Function): 88 | @staticmethod 89 | def forward(ctx, input, kernel, up, down, pad): 90 | up_x, up_y = up 91 | down_x, down_y = down 92 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 93 | 94 | kernel_h, kernel_w = kernel.shape 95 | batch, channel, in_h, in_w = input.shape 96 | ctx.in_size = input.shape 97 | 98 | input = input.reshape(-1, in_h, in_w, 1) 99 | 100 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 101 | 102 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 103 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 104 | ctx.out_size = (out_h, out_w) 105 | 106 | ctx.up = (up_x, up_y) 107 | ctx.down = (down_x, down_y) 108 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 109 | 110 | g_pad_x0 = kernel_w - pad_x0 - 1 111 | g_pad_y0 = kernel_h - pad_y0 - 1 112 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 113 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 114 | 115 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 116 | 117 | out = upfirdn2d_op.upfirdn2d( 118 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 119 | ) 120 | # out = out.view(major, out_h, out_w, minor) 121 | out = out.view(-1, channel, out_h, out_w) 122 | 123 | return out 124 | 125 | @staticmethod 126 | def backward(ctx, grad_output): 127 | kernel, grad_kernel = ctx.saved_tensors 128 | 129 | grad_input = UpFirDn2dBackward.apply( 130 | grad_output, 131 | kernel, 132 | grad_kernel, 133 | ctx.up, 134 | ctx.down, 135 | ctx.pad, 136 | ctx.g_pad, 137 | ctx.in_size, 138 | ctx.out_size, 139 | ) 140 | 141 | return grad_input, None, None, None, None 142 | 143 | 144 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 145 | out = UpFirDn2d.apply( 146 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 147 | ) 148 | 149 | return out 150 | 151 | 152 | def upfirdn2d_native( 153 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 154 | ): 155 | _, in_h, in_w, minor = input.shape 156 | kernel_h, kernel_w = kernel.shape 157 | 158 | out = input.view(-1, in_h, 1, in_w, 1, minor) 159 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 160 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 161 | 162 | out = F.pad( 163 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 164 | ) 165 | out = out[ 166 | :, 167 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 168 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 169 | :, 170 | ] 171 | 172 | out = out.permute(0, 3, 1, 2) 173 | out = out.reshape( 174 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 175 | ) 176 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 177 | out = F.conv2d(out, w) 178 | out = out.reshape( 179 | -1, 180 | minor, 181 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 182 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 183 | ) 184 | out = out.permute(0, 2, 3, 1) 185 | 186 | return out[:, ::down_y, 
::down_x, :] 187 | 188 | -------------------------------------------------------------------------------- /exp3D/models/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 19 | int c = a / b; 20 | 21 | if (c * b > a) { 22 | c--; 23 | } 24 | 25 | return c; 26 | } 27 | 28 | 29 | struct UpFirDn2DKernelParams { 30 | int up_x; 31 | int up_y; 32 | int down_x; 33 | int down_y; 34 | int pad_x0; 35 | int pad_x1; 36 | int pad_y0; 37 | int pad_y1; 38 | 39 | int major_dim; 40 | int in_h; 41 | int in_w; 42 | int minor_dim; 43 | int kernel_h; 44 | int kernel_w; 45 | int out_h; 46 | int out_w; 47 | int loop_major; 48 | int loop_x; 49 | }; 50 | 51 | 52 | template 53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { 54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 56 | 57 | __shared__ volatile float sk[kernel_h][kernel_w]; 58 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 59 | 60 | int minor_idx = blockIdx.x; 61 | int tile_out_y = minor_idx / p.minor_dim; 62 | minor_idx -= tile_out_y * p.minor_dim; 63 | tile_out_y *= tile_out_h; 64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 65 | int major_idx_base = blockIdx.z * p.loop_major; 66 | 67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { 68 | return; 69 | } 70 | 71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { 72 | int ky = tap_idx / kernel_w; 73 | int kx = tap_idx - ky * kernel_w; 74 | scalar_t v = 0.0; 75 | 76 | if (kx < p.kernel_w & ky < p.kernel_h) { 77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 78 | } 79 | 80 | sk[ky][kx] = v; 81 | } 82 | 83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { 84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { 85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 87 | int tile_in_x = floor_div(tile_mid_x, up_x); 88 | int tile_in_y = floor_div(tile_mid_y, up_y); 89 | 90 | __syncthreads(); 91 | 92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { 93 | int rel_in_y = in_idx / tile_in_w; 94 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 95 | int in_x = rel_in_x + tile_in_x; 96 | int in_y = rel_in_y + tile_in_y; 97 | 98 | scalar_t v = 0.0; 99 | 100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; 102 | } 103 | 104 | sx[rel_in_y][rel_in_x] = v; 105 | } 106 | 107 | __syncthreads(); 108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { 109 | int rel_out_y = out_idx / tile_out_w; 110 | 
int rel_out_x = out_idx - rel_out_y * tile_out_w; 111 | int out_x = rel_out_x + tile_out_x; 112 | int out_y = rel_out_y + tile_out_y; 113 | 114 | int mid_x = tile_mid_x + rel_out_x * down_x; 115 | int mid_y = tile_mid_y + rel_out_y * down_y; 116 | int in_x = floor_div(mid_x, up_x); 117 | int in_y = floor_div(mid_y, up_y); 118 | int rel_in_x = in_x - tile_in_x; 119 | int rel_in_y = in_y - tile_in_y; 120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 122 | 123 | scalar_t v = 0.0; 124 | 125 | #pragma unroll 126 | for (int y = 0; y < kernel_h / up_y; y++) 127 | #pragma unroll 128 | for (int x = 0; x < kernel_w / up_x; x++) 129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; 130 | 131 | if (out_x < p.out_w & out_y < p.out_h) { 132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; 133 | } 134 | } 135 | } 136 | } 137 | } 138 | 139 | 140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 141 | int up_x, int up_y, int down_x, int down_y, 142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 143 | int curDevice = -1; 144 | cudaGetDevice(&curDevice); 145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 146 | 147 | UpFirDn2DKernelParams p; 148 | 149 | auto x = input.contiguous(); 150 | auto k = kernel.contiguous(); 151 | 152 | p.major_dim = x.size(0); 153 | p.in_h = x.size(1); 154 | p.in_w = x.size(2); 155 | p.minor_dim = x.size(3); 156 | p.kernel_h = k.size(0); 157 | p.kernel_w = k.size(1); 158 | p.up_x = up_x; 159 | p.up_y = up_y; 160 | p.down_x = down_x; 161 | p.down_y = down_y; 162 | p.pad_x0 = pad_x0; 163 | p.pad_x1 = pad_x1; 164 | p.pad_y0 = pad_y0; 165 | p.pad_y1 = pad_y1; 166 | 167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; 168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; 169 | 170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 171 | 172 | int mode = -1; 173 | 174 | int tile_out_h; 175 | int tile_out_w; 176 | 177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 178 | mode = 1; 179 | tile_out_h = 16; 180 | tile_out_w = 64; 181 | } 182 | 183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { 184 | mode = 2; 185 | tile_out_h = 16; 186 | tile_out_w = 64; 187 | } 188 | 189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 190 | mode = 3; 191 | tile_out_h = 16; 192 | tile_out_w = 64; 193 | } 194 | 195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { 196 | mode = 4; 197 | tile_out_h = 16; 198 | tile_out_w = 64; 199 | } 200 | 201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { 202 | mode = 5; 203 | tile_out_h = 8; 204 | tile_out_w = 32; 205 | } 206 | 207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { 208 | mode = 6; 209 | tile_out_h = 8; 210 | tile_out_w = 32; 211 | } 212 | 213 | dim3 block_size; 214 | dim3 grid_size; 215 | 216 | if (tile_out_h > 0 && tile_out_w) { 217 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 218 | p.loop_x = 1; 219 | block_size = dim3(32 * 8, 1, 1); 220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 221 | 
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 222 | (p.major_dim - 1) / p.loop_major + 1); 223 | } 224 | 225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 226 | switch (mode) { 227 | case 1: 228 | upfirdn2d_kernel<<>>( 229 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 230 | ); 231 | 232 | break; 233 | 234 | case 2: 235 | upfirdn2d_kernel<<>>( 236 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 237 | ); 238 | 239 | break; 240 | 241 | case 3: 242 | upfirdn2d_kernel<<>>( 243 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 244 | ); 245 | 246 | break; 247 | 248 | case 4: 249 | upfirdn2d_kernel<<>>( 250 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 251 | ); 252 | 253 | break; 254 | 255 | case 5: 256 | upfirdn2d_kernel<<>>( 257 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 258 | ); 259 | 260 | break; 261 | 262 | case 6: 263 | upfirdn2d_kernel<<>>( 264 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 265 | ); 266 | 267 | break; 268 | } 269 | }); 270 | 271 | return out; 272 | } -------------------------------------------------------------------------------- /exp3D/models/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from . import nvs_tokenizer 2 | -------------------------------------------------------------------------------- /exp3D/models/tokenizers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/tokenizers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /exp3D/models/tokenizers/__pycache__/imgrec_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/tokenizers/__pycache__/imgrec_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /exp3D/models/tokenizers/__pycache__/nvs_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/tokenizers/__pycache__/nvs_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /exp3D/models/tokenizers/nvs_tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import einops 5 | 6 | from models import register 7 | from utils import poses_to_rays 8 | 9 | @register('nvs_tokenizer') 10 | class NvsTokenizer(nn.Module): 11 | 12 | def __init__(self, input_size, patch_size, dim, padding=0, img_channels=3): 13 | super().__init__() 14 | if isinstance(input_size, int): 15 | input_size = (input_size, input_size) 16 | if isinstance(patch_size, int): 17 | patch_size = (patch_size, patch_size) 18 | if isinstance(padding, int): 19 | padding = (padding, padding) 20 | self.patch_size = patch_size 21 | self.padding = padding 22 | self.prefc = nn.Linear(patch_size[0] * patch_size[1] * (img_channels + 3 + 3), dim) 23 | 24 | def forward(self, data): 25 | imgs = data['support_imgs'] 26 | B = imgs.shape[0] 27 | H, W = imgs.shape[-2:] 28 | rays_o, rays_d = poses_to_rays(data['support_poses'], H, W, 
data['support_focals']) 29 | rays_o = einops.rearrange(rays_o, 'b n h w c -> b n c h w') 30 | rays_d = einops.rearrange(rays_d, 'b n h w c -> b n c h w') 31 | 32 | x = torch.cat([imgs, rays_o, rays_d], dim=2) 33 | x = einops.rearrange(x, 'b n d h w -> (b n) d h w') 34 | p = self.patch_size 35 | x = F.unfold(x, p, stride=p, padding=self.padding) 36 | x = einops.rearrange(x, '(b n) ppd l -> b (n l) ppd', b=B) 37 | 38 | x = self.prefc(x) 39 | return x 40 | -------------------------------------------------------------------------------- /exp3D/models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import einops 5 | 6 | from models import register 7 | 8 | 9 | class Attention(nn.Module): 10 | 11 | def __init__(self, dim, n_head, head_dim, dropout=0.): 12 | super().__init__() 13 | self.n_head = n_head 14 | inner_dim = n_head * head_dim 15 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 16 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 17 | self.scale = head_dim ** -0.5 18 | self.to_out = nn.Sequential( 19 | nn.Linear(inner_dim, dim), 20 | nn.Dropout(dropout), 21 | ) 22 | 23 | def forward(self, fr, to=None): 24 | if to is None: 25 | to = fr 26 | q = self.to_q(fr) 27 | k, v = self.to_kv(to).chunk(2, dim=-1) 28 | q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.n_head), [q, k, v]) 29 | 30 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 31 | attn = F.softmax(dots, dim=-1) # b h n n 32 | out = torch.matmul(attn, v) 33 | out = einops.rearrange(out, 'b h n d -> b n (h d)') 34 | return self.to_out(out) 35 | 36 | 37 | class FeedForward(nn.Module): 38 | 39 | def __init__(self, dim, ff_dim, dropout=0.): 40 | super().__init__() 41 | self.net = nn.Sequential( 42 | nn.Linear(dim, ff_dim), 43 | nn.GELU(), 44 | nn.Dropout(dropout), 45 | nn.Linear(ff_dim, dim), 46 | nn.Dropout(dropout), 47 | ) 48 | 49 | def forward(self, x): 50 | return self.net(x) 51 | 52 | 53 | class FeedForwardReLU(nn.Module): 54 | 55 | def __init__(self, dim, ff_dim, dropout=0.): 56 | super().__init__() 57 | self.net = nn.Sequential( 58 | nn.Linear(dim, ff_dim), 59 | nn.ReLU(), 60 | nn.Dropout(dropout), 61 | nn.Linear(ff_dim, dim), 62 | nn.Dropout(dropout), 63 | ) 64 | 65 | def forward(self, x): 66 | return self.net(x) 67 | 68 | class PreNorm(nn.Module): 69 | 70 | def __init__(self, dim, fn): 71 | super().__init__() 72 | self.norm = nn.LayerNorm(dim) 73 | self.fn = fn 74 | 75 | def forward(self, x): 76 | return self.fn(self.norm(x)) 77 | 78 | 79 | @register('self_attender') 80 | class Self_Attender(nn.Module): 81 | 82 | def __init__(self, dim, depth, n_head, head_dim, ff_dim, dropout=0.0): 83 | super().__init__() 84 | self.layers = nn.ModuleList() 85 | for _ in range(depth): 86 | self.layers.append(nn.ModuleList([ 87 | PreNorm(dim, Attention(dim, n_head, head_dim, dropout=dropout)), 88 | PreNorm(dim, FeedForward(dim, ff_dim, dropout=dropout)), 89 | ])) 90 | 91 | def forward(self, x): 92 | for norm_attn, norm_ff in self.layers: 93 | x = x + norm_attn(x) 94 | x = x + norm_ff(x) 95 | return x 96 | 97 | 98 | @register('cross_attender') 99 | class Cross_Attender(nn.Module): 100 | 101 | def __init__(self, dim, depth, n_head, head_dim, ff_dim, dropout=0.0): 102 | super().__init__() 103 | self.layers = nn.ModuleList() 104 | for _ in range(depth): 105 | self.layers.append(nn.ModuleList([ 106 | # Attention(dim, n_head, head_dim, dropout=dropout), 107 | 
nn.MultiheadAttention(dim, n_head, dropout=dropout), 108 | nn.LayerNorm(dim), 109 | FeedForwardReLU(dim, ff_dim, dropout=dropout), 110 | nn.LayerNorm(dim) 111 | ])) 112 | self.mlp_enhance = build_mlp(dim, dim, dim, 2) 113 | 114 | def with_pos_embed(self, tensor, pos=None): 115 | return tensor if pos is None else tensor + pos 116 | 117 | def forward(self, x_tgt, context): 118 | x = 0 119 | for attender, norm1, ff, norm2 in self.layers: 120 | x_attn = attender(query=self.with_pos_embed(x, x_tgt).transpose(0,1), 121 | key=context.transpose(0,1), 122 | value=context.transpose(0,1))[0] 123 | x = norm1(x + x_attn.transpose(0,1)) 124 | x_ff = ff(x) 125 | x = norm2(x + x_ff) 126 | x = self.mlp_enhance(x) 127 | return x 128 | 129 | 130 | def build_mlp(dim_in, dim_hid, dim_out, depth): 131 | modules = [nn.Linear(dim_in, dim_hid), nn.ReLU(True)] 132 | for _ in range(depth-2): 133 | modules.append(nn.Linear(dim_hid, dim_hid)) 134 | modules.append(nn.ReLU(True)) 135 | modules.append(nn.Linear(dim_hid, dim_out)) 136 | return nn.Sequential(*modules) -------------------------------------------------------------------------------- /exp3D/models/versatile_np.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import einops 7 | 8 | import models 9 | from models import register 10 | from utils import rendering 11 | 12 | 13 | def build_mlp(dim_in, dim_hid, dim_out, depth): 14 | modules = [nn.Linear(dim_in, dim_hid), nn.ReLU(True)] 15 | for _ in range(depth-2): 16 | modules.append(nn.Linear(dim_hid, dim_hid)) 17 | modules.append(nn.ReLU(True)) 18 | modules.append(nn.Linear(dim_hid, dim_out)) 19 | return nn.Sequential(*modules) 20 | 21 | 22 | @register('versatile_np') 23 | class Versatile_NP(nn.Module): 24 | 25 | def __init__(self, tokenizer, self_attender, cross_attender, hierarchical_model, pe_dim=128): 26 | super().__init__() 27 | dim_y = hierarchical_model['args']['dim_y'] 28 | self.pe_dim = pe_dim # dimension of positional embedding 29 | 30 | self.tokenizer = models.make(tokenizer, args={'dim': dim_y}) 31 | self.self_attn = models.make(self_attender) 32 | self.cross_attn = models.make(cross_attender) 33 | self.hier_model = models.make(hierarchical_model) 34 | 35 | self.embed_input = build_mlp(pe_dim * 3, dim_y, dim_y, 2) 36 | self.embed_last = build_mlp(dim_y, dim_y, 4, 2) 37 | 38 | def forward(self, data, x_tgt, rays_o, z_vals, y_tgt, is_train=True): 39 | data_tokens = self.tokenizer(data) 40 | context_tokens = self.self_attn(data_tokens) 41 | x_tgt = self.coord_embedding(x_tgt) 42 | y_middle = self.cross_attn(x_tgt, context_tokens) 43 | 44 | y_middle, kls = self.hier_model(y_middle, x_tgt, y_tgt, rays_o, z_vals, training=is_train) 45 | y_pred = rendering(self.embed_last(y_middle), rays_o = rays_o, z_vals = z_vals) 46 | return y_pred, kls 47 | 48 | def coord_embedding(self, x): 49 | w = 2 ** torch.linspace(0, 8, self.pe_dim // 2, device=x.device) 50 | x = torch.matmul(x.unsqueeze(-1), w.unsqueeze(0)).view(*x.shape[:-1], -1) 51 | x = torch.cat([torch.cos(x), torch.sin(x)], dim=-1) 52 | x = self.embed_input(x) 53 | return x -------------------------------------------------------------------------------- /exp3D/run_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a cfg object according to a cfg file and args, then spawn Trainer(rank, cfg). 
3 | """ 4 | 5 | import argparse 6 | import os 7 | 8 | import yaml 9 | import numpy as np 10 | import torch 11 | import torch.multiprocessing as mp 12 | 13 | import utils 14 | import trainers 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--cfg') 20 | parser.add_argument('--seed', type=int, default=42) 21 | parser.add_argument('--load-root', default='data') 22 | parser.add_argument('--save-root', default='save') 23 | parser.add_argument('--Lambda', type=float, default=0.001) 24 | parser.add_argument('--name', '-n', default=None) 25 | parser.add_argument('--tag', default=None) 26 | parser.add_argument('--cudnn', action='store_true') 27 | parser.add_argument('--port-offset', '-p', type=int, default=0) 28 | parser.add_argument('--wandb-upload', '-w', action='store_true') 29 | args = parser.parse_args() 30 | 31 | return args 32 | 33 | 34 | def make_cfg(args): 35 | with open(args.cfg, 'r') as f: 36 | cfg = yaml.load(f, Loader=yaml.FullLoader) 37 | 38 | def translate_cfg_(d): 39 | for k, v in d.items(): 40 | if isinstance(v, dict): 41 | translate_cfg_(v) 42 | elif isinstance(v, str): 43 | d[k] = v.replace('$load_root$', args.load_root) 44 | 45 | if cfg['Lambda'] != args.Lambda: 46 | cfg['Lambda'] = args.Lambda 47 | 48 | translate_cfg_(cfg) 49 | 50 | if args.name is None: 51 | exp_name = os.path.basename(args.cfg).split('.')[0] 52 | else: 53 | exp_name = args.name 54 | if args.tag is not None: 55 | exp_name += '_' + args.tag 56 | 57 | env = dict() 58 | env['exp_name'] = exp_name 59 | env['save_dir'] = os.path.join(args.save_root, exp_name) 60 | env['tot_gpus'] = torch.cuda.device_count() 61 | env['cudnn'] = args.cudnn 62 | env['port'] = str(29600 + args.port_offset) 63 | env['wandb_upload'] = args.wandb_upload 64 | cfg['env'] = env 65 | 66 | return cfg 67 | 68 | 69 | def main_worker(rank, cfg): 70 | trainer = trainers.trainers_dict[cfg['trainer']](rank, cfg) 71 | trainer.run() 72 | 73 | 74 | def main(): 75 | args = parse_args() 76 | 77 | torch.manual_seed(args.seed) 78 | np.random.seed(args.seed) 79 | 80 | cfg = make_cfg(args) 81 | utils.ensure_path(cfg['env']['save_dir']) 82 | 83 | if cfg['env']['tot_gpus'] > 1: 84 | mp.spawn(main_worker, args=(cfg,), nprocs=cfg['env']['tot_gpus']) 85 | else: 86 | main_worker(0, cfg) 87 | 88 | 89 | if __name__ == '__main__': 90 | main() 91 | -------------------------------------------------------------------------------- /exp3D/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainers import register, trainers_dict 2 | from . import base_trainer 3 | from . import imgrec_trainer 4 | from . import nvs_trainer 5 | -------------------------------------------------------------------------------- /exp3D/trainers/base_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | A basic trainer. 3 | 4 | The general procedure in run() is: 5 | make_datasets() 6 | create . train_loader, test_loader, dist_samplers 7 | make_model() 8 | create . model_ddp, model 9 | train() 10 | create . optimizer, epoch, log_buffer 11 | for epoch = 1 ... 
max_epoch: 12 | adjust_learning_rate() 13 | train_epoch() 14 | train_step() 15 | evaluate_epoch() 16 | evaluate_step() 17 | save_checkpoint() 18 | """ 19 | 20 | import os 21 | import os.path as osp 22 | import time 23 | 24 | import yaml 25 | import torch 26 | import torch.nn as nn 27 | import torch.backends.cudnn as cudnn 28 | import torch.distributed as dist 29 | import wandb 30 | from tqdm import tqdm 31 | from torch.utils.data import DataLoader 32 | from torch.utils.data.distributed import DistributedSampler 33 | from torch.nn.parallel import DistributedDataParallel 34 | 35 | import datasets 36 | import models 37 | import utils 38 | from trainers import register 39 | 40 | 41 | @register('base_trainer') 42 | class BaseTrainer(): 43 | 44 | def __init__(self, rank, cfg): 45 | self.rank = rank 46 | self.cfg = cfg 47 | self.is_master = (rank == 0) 48 | 49 | env = cfg['env'] 50 | self.tot_gpus = env['tot_gpus'] 51 | self.distributed = (env['tot_gpus'] > 1) 52 | 53 | # Setup log, tensorboard, wandb 54 | if self.is_master: 55 | logger, writer = utils.set_save_dir(env['save_dir'], replace=False) 56 | with open(osp.join(env['save_dir'], 'cfg.yaml'), 'w') as f: 57 | yaml.dump(cfg, f, sort_keys=False) 58 | 59 | self.log = logger.info 60 | 61 | self.enable_tb = True 62 | self.writer = writer 63 | 64 | if env['wandb_upload']: 65 | self.enable_wandb = True 66 | with open('wandb.yaml', 'r') as f: 67 | wandb_cfg = yaml.load(f, Loader=yaml.FullLoader) 68 | os.environ['WANDB_DIR'] = env['save_dir'] 69 | os.environ['WANDB_NAME'] = env['exp_name'] 70 | os.environ['WANDB_API_KEY'] = wandb_cfg['api_key'] 71 | wandb.init(project=wandb_cfg['project'], entity=wandb_cfg['entity'], config=cfg) 72 | else: 73 | self.enable_wandb = False 74 | else: 75 | self.log = lambda *args, **kwargs: None 76 | self.enable_tb = False 77 | self.enable_wandb = False 78 | 79 | # Setup distributed devices 80 | torch.cuda.set_device(rank) 81 | self.device = torch.device('cuda', torch.cuda.current_device()) 82 | 83 | if self.distributed: 84 | dist_url = f"tcp://localhost:{env['port']}" 85 | dist.init_process_group(backend='nccl', init_method=dist_url, 86 | world_size=self.tot_gpus, rank=rank) 87 | self.log(f'Distributed training enabled.') 88 | 89 | cudnn.benchmark = env['cudnn'] 90 | 91 | self.log(f'Environment setup done.') 92 | 93 | def run(self): 94 | self.make_datasets() 95 | 96 | if self.cfg.get('eval_model') is not None: 97 | print('Load model:', self.cfg['eval_model']) 98 | model_spec = torch.load(self.cfg['eval_model'])['model'] 99 | self.make_model(model_spec, load_sd=True) 100 | self.epoch = 0 101 | self.log_buffer = [] 102 | self.t_data, self.t_model = 0, 0 103 | self.evaluate_epoch() 104 | self.log(', '.join(self.log_buffer)) 105 | elif self.cfg.get('resume_model') is not None: 106 | print('Load model:', self.cfg['resume_model']) 107 | model_spec = torch.load(self.cfg['resume_model'], map_location='cpu')['model'] 108 | self.make_model(model_spec, load_sd=True) 109 | self.train() 110 | else: 111 | self.make_model() 112 | self.train() 113 | 114 | if self.enable_tb: 115 | self.writer.close() 116 | if self.enable_wandb: 117 | wandb.finish() 118 | 119 | def make_datasets(self): 120 | """ 121 | By default, train dataset performs shuffle and drop_last. 122 | Distributed sampler will extend the dataset with a prefix to make the length divisible by tot_gpus, samplers should be stored in .dist_samplers. 
123 | 124 | Cfg example: 125 | 126 | train/test_dataset: 127 | name: 128 | args: 129 | loader: {batch_size: , num_workers: } 130 | """ 131 | cfg = self.cfg 132 | self.dist_samplers = [] 133 | 134 | def make_distributed_loader(dataset, batch_size, num_workers, shuffle=False, drop_last=False): 135 | sampler = DistributedSampler(dataset, shuffle=shuffle) if self.distributed else None 136 | loader = DataLoader( 137 | dataset, 138 | batch_size // self.tot_gpus, 139 | drop_last=drop_last, 140 | sampler=sampler, 141 | shuffle=(shuffle and (sampler is None)), 142 | num_workers=num_workers // self.tot_gpus, 143 | pin_memory=True) 144 | return loader, sampler 145 | 146 | if cfg.get('train_dataset') is not None: 147 | train_dataset = datasets.make(cfg['train_dataset']) 148 | self.log(f'Train dataset: len={len(train_dataset)}') 149 | l = cfg['train_dataset']['loader'] 150 | self.train_loader, train_sampler = make_distributed_loader( 151 | train_dataset, l['batch_size'], l['num_workers'], shuffle=True, drop_last=True) 152 | self.dist_samplers.append(train_sampler) 153 | 154 | if cfg.get('test_dataset') is not None: 155 | test_dataset = datasets.make(cfg['test_dataset']) 156 | self.log(f'Test dataset: len={len(test_dataset)}') 157 | l = cfg['test_dataset']['loader'] 158 | self.test_loader, test_sampler = make_distributed_loader( 159 | test_dataset, l['batch_size'], l['num_workers'], shuffle=False, drop_last=False) 160 | self.dist_samplers.append(test_sampler) 161 | 162 | def make_model(self, model_spec=None, load_sd=False): 163 | if model_spec is None: 164 | model_spec = self.cfg['model'] 165 | model = models.make(model_spec, load_sd=load_sd) 166 | self.log(f'Model: #params={utils.compute_num_params(model)}') 167 | 168 | if self.distributed: 169 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 170 | model.cuda() 171 | model_ddp = DistributedDataParallel(model, device_ids=[self.rank], find_unused_parameters=True) 172 | else: 173 | model.cuda() 174 | model_ddp = model 175 | self.model = model 176 | self.model_ddp = model_ddp 177 | 178 | def train(self): 179 | """ 180 | For epochs perform training, evaluation, and visualization. 181 | Note that ave_scalars update ignores the actual current batch_size. 
182 | """ 183 | cfg = self.cfg 184 | 185 | self.optimizer = utils.make_optimizer(self.model_ddp.parameters(), cfg['optimizer']) 186 | 187 | max_epoch = cfg['max_epoch'] 188 | eval_epoch = cfg.get('eval_epoch', max_epoch + 1) 189 | vis_epoch = cfg.get('vis_epoch', max_epoch + 1) 190 | save_epoch = cfg.get('save_epoch', max_epoch + 1) 191 | epoch_timer = utils.EpochTimer(max_epoch) 192 | 193 | for epoch in range(1, max_epoch + 1): 194 | self.epoch = epoch 195 | self.log_buffer = [f'Epoch {epoch}'] 196 | 197 | if self.distributed: 198 | for sampler in self.dist_samplers: 199 | sampler.set_epoch(epoch) 200 | 201 | self.adjust_learning_rate() 202 | 203 | self.t_data, self.t_model = 0, 0 204 | self.train_epoch() 205 | 206 | if epoch % eval_epoch == 0: 207 | self.evaluate_epoch() 208 | 209 | if epoch % save_epoch == 0: 210 | self.save_checkpoint(f'epoch-{epoch}.pth') 211 | self.save_checkpoint('epoch-last.pth') 212 | 213 | epoch_time, tot_time, est_time = epoch_timer.epoch_done() 214 | t_data_ratio = self.t_data / (self.t_data + self.t_model) 215 | self.log_buffer.append(f'{epoch_time} (d {t_data_ratio:.2f}) {tot_time}/{est_time}') 216 | self.log(', '.join(self.log_buffer)) 217 | 218 | def adjust_learning_rate(self): 219 | base_lr = self.cfg['optimizer']['args']['lr'] 220 | for param_group in self.optimizer.param_groups: 221 | param_group['lr'] = base_lr 222 | self.log_temp_scalar('lr', self.optimizer.param_groups[0]['lr']) 223 | 224 | def log_temp_scalar(self, k, v, t=None): 225 | if t is None: 226 | t = self.epoch 227 | if self.enable_tb: 228 | self.writer.add_scalar(k, v, global_step=t) 229 | if self.enable_wandb: 230 | wandb.log({k: v}, step=t) 231 | 232 | def dist_all_reduce_mean_(self, x): 233 | dist.all_reduce(x, op=dist.ReduceOp.SUM) 234 | x.div_(self.tot_gpus) 235 | 236 | def sync_ave_scalars_(self, ave_scalars): 237 | for k in ave_scalars.keys(): 238 | x = torch.tensor(ave_scalars[k].item(), dtype=torch.float32, device=self.device) 239 | self.dist_all_reduce_mean_(x) 240 | ave_scalars[k].v = x.item() 241 | ave_scalars[k].n *= self.tot_gpus 242 | 243 | def train_step(self, data): 244 | data = {k: v.cuda() for k, v in data.items()} 245 | loss = self.model_ddp(data) 246 | self.optimizer.zero_grad() 247 | loss.backward() 248 | self.optimizer.step() 249 | return {'loss': loss.item()} 250 | 251 | def train_epoch(self): 252 | self.model_ddp.train() 253 | ave_scalars = dict() 254 | 255 | pbar = self.train_loader 256 | if self.is_master: 257 | pbar = tqdm(pbar, desc='train', leave=False) 258 | 259 | t1 = time.time() 260 | for data in pbar: 261 | t0 = time.time() 262 | self.t_data += t0 - t1 263 | ret = self.train_step(data) 264 | self.t_model += time.time() - t0 265 | 266 | B = len(next(iter(data.values()))) 267 | for k, v in ret.items(): 268 | if ave_scalars.get(k) is None: 269 | ave_scalars[k] = utils.Averager() 270 | ave_scalars[k].add(v, n=B) 271 | 272 | if self.is_master: 273 | pbar.set_description(desc=f'train: loss={ret["loss"]:.4f}') 274 | t1 = time.time() 275 | 276 | if self.distributed: 277 | self.sync_ave_scalars_(ave_scalars) 278 | 279 | logtext = 'train:' 280 | for k, v in ave_scalars.items(): 281 | logtext += f' {k}={v.item():.4f}' 282 | self.log_temp_scalar('train/' + k, v.item()) 283 | self.log_buffer.append(logtext) 284 | 285 | def evaluate_step(self, data): 286 | data = {k: v.cuda() for k, v in data.items()} 287 | with torch.no_grad(): 288 | loss = self.model_ddp(data) 289 | return {'loss': loss.item()} 290 | 291 | def evaluate_epoch(self): 292 | self.model_ddp.eval() 293 | 
ave_scalars = dict() 294 | 295 | pbar = self.test_loader 296 | if self.is_master: 297 | pbar = tqdm(pbar, desc='eval', leave=False) 298 | 299 | t1 = time.time() 300 | for data in pbar: 301 | t0 = time.time() 302 | self.t_data += t0 - t1 303 | ret = self.evaluate_step(data) 304 | self.t_model += time.time() - t0 305 | 306 | B = len(next(iter(data.values()))) 307 | for k, v in ret.items(): 308 | if ave_scalars.get(k) is None: 309 | ave_scalars[k] = utils.Averager() 310 | ave_scalars[k].add(v, n=B) 311 | 312 | t1 = time.time() 313 | 314 | if self.distributed: 315 | self.sync_ave_scalars_(ave_scalars) 316 | 317 | logtext = 'eval:' 318 | for k, v in ave_scalars.items(): 319 | logtext += f' {k}={v.item():.4f}' 320 | self.log_temp_scalar('test/' + k, v.item()) 321 | self.log_buffer.append(logtext) 322 | 323 | 324 | def save_checkpoint(self, filename): 325 | if not self.is_master: 326 | return 327 | model_spec = self.cfg['model'] 328 | model_spec['sd'] = self.model.state_dict() 329 | optimizer_spec = self.cfg['optimizer'] 330 | optimizer_spec['sd'] = self.optimizer.state_dict() 331 | checkpoint = { 332 | 'model': model_spec, 333 | 'optimizer': optimizer_spec, 334 | 'epoch': self.epoch, 335 | 'cfg': self.cfg, 336 | } 337 | torch.save(checkpoint, osp.join(self.cfg['env']['save_dir'], filename)) 338 | -------------------------------------------------------------------------------- /exp3D/trainers/nvs_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torchvision 5 | import numpy as np 6 | import einops 7 | import wandb 8 | from tqdm import tqdm 9 | 10 | import utils 11 | from .base_trainer import BaseTrainer 12 | from trainers import register 13 | from utils import poses_to_rays, get_coord, volume_rendering, batched_volume_rendering 14 | 15 | 16 | @register('nvs_trainer') 17 | class NvsTrainer(BaseTrainer): 18 | 19 | def make_datasets(self): 20 | super().make_datasets() 21 | 22 | def get_vislist(dataset, n_vis=8): 23 | ids = torch.arange(n_vis) * (len(dataset) // n_vis) 24 | return [dataset[i] for i in ids] 25 | 26 | if hasattr(self, 'train_loader'): 27 | np.random.seed(0) 28 | self.vislist_train = get_vislist(self.train_loader.dataset) 29 | if hasattr(self, 'test_loader'): 30 | np.random.seed(0) 31 | self.vislist_test = get_vislist(self.test_loader.dataset) 32 | 33 | def adjust_learning_rate(self): 34 | base_lr = self.cfg['optimizer']['args']['lr'] 35 | if self.epoch <= round(self.cfg['max_epoch'] * 0.8): 36 | lr = base_lr 37 | else: 38 | lr = base_lr * 0.1 39 | for param_group in self.optimizer.param_groups: 40 | param_group['lr'] = lr 41 | self.log_temp_scalar('lr', lr) 42 | 43 | def _adaptive_sample_rays(self, rays_o, rays_d, gt, n_sample): 44 | B = rays_o.shape[0] 45 | inds = [] 46 | fg_n_sample = n_sample // 2 47 | for i in range(B): 48 | fg = ((gt[i].min(dim=-1).values < 1).nonzero().view(-1)).cpu().numpy() 49 | if fg_n_sample <= len(fg): 50 | fg = np.random.choice(fg, fg_n_sample, replace=False) 51 | else: 52 | fg = np.concatenate([fg, np.random.choice(fg, fg_n_sample - len(fg), replace=True)], axis=0) 53 | rd = np.random.choice(rays_o.shape[1], n_sample - fg_n_sample, replace=False) 54 | inds.append(np.concatenate([fg, rd], axis=0)) 55 | 56 | def subselect(x, inds): 57 | t = torch.empty(B, len(inds[0]), 3, dtype=x.dtype, device=x.device) 58 | for i in range(B): 59 | t[i] = x[i][inds[i], :] 60 | return t 61 | 62 | return subselect(rays_o, inds), subselect(rays_d, inds), subselect(gt, inds) 63 | 64 | def 
_iter_step(self, data, is_train): 65 | data = {k: v.cuda() for k, v in data.items()} 66 | query_imgs = data.pop('query_imgs') 67 | 68 | B, _, _, H, W = query_imgs.shape 69 | rays_o, rays_d = poses_to_rays(data['query_poses'], H, W, data['query_focals']) 70 | 71 | gt = einops.rearrange(query_imgs, 'b n c h w -> b (n h w) c') 72 | rays_o = einops.rearrange(rays_o, 'b n h w c -> b (n h w) c') 73 | rays_d = einops.rearrange(rays_d, 'b n h w c -> b (n h w) c') 74 | 75 | n_sample = self.cfg['train_n_rays'] 76 | if is_train and self.epoch <= self.cfg.get('adaptive_sample_epoch', 0): 77 | rays_o, rays_d, gt = self._adaptive_sample_rays(rays_o, rays_d, gt, n_sample) 78 | else: 79 | ray_ids = np.random.choice(rays_o.shape[1], n_sample, replace=False) 80 | rays_o, rays_d, gt = map((lambda _: _[:, ray_ids, :]), [rays_o, rays_d, gt]) 81 | 82 | # x_tgt is the world coordinate of target points 83 | x_tgt, z_vals = get_coord( 84 | rays_o, rays_d, 85 | near=data['near'][0], 86 | far=data['far'][0], 87 | points_per_ray=self.cfg['train_points_per_ray'], 88 | rand=is_train, 89 | ) 90 | 91 | pred, loss_kl = self.model_ddp(data, x_tgt, rays_o, z_vals, gt, is_train=is_train) 92 | 93 | loss_mse = ((pred - gt) ** 2).view(B, -1).mean(dim=-1) 94 | loss_kl = loss_kl.view(B, -1).mean() 95 | 96 | # annealing beta coefficient 97 | if self.cfg.get('resume_model') is not None: 98 | beta = self.cfg['Lambda'] 99 | else: 100 | beta = self.cfg['Lambda'] * min(1.0, self.epoch / 50) 101 | 102 | loss = loss_mse.mean() + loss_kl * beta 103 | psnr = (-10 * torch.log10(loss_mse)).mean() 104 | 105 | if is_train: 106 | self.optimizer.zero_grad() 107 | loss.backward() 108 | self.optimizer.step() 109 | 110 | return {'loss': loss.item(), 'psnr': psnr.item(), 'loss_kl': loss_kl.item()} 111 | 112 | def train_step(self, data): 113 | return self._iter_step(data, is_train=True) 114 | 115 | def evaluate_step(self, data): 116 | with torch.no_grad(): 117 | return self._iter_step(data, is_train=False) 118 | 119 | 120 | @register('nvs_evaluator') 121 | class NvsEvaluator(NvsTrainer): 122 | 123 | def _iter_step(self, data, is_train, step=0): 124 | assert not is_train 125 | data = {k: v.cuda() for k, v in data.items()} 126 | 127 | query_imgs = data.pop('query_imgs') 128 | B, N, _, H, W = query_imgs.shape 129 | 130 | with torch.no_grad(): 131 | B, _, _, H, W = query_imgs.shape 132 | rays_o, rays_d = poses_to_rays(data['query_poses'], H, W, data['query_focals']) 133 | 134 | gt = einops.rearrange(query_imgs, 'b n c h w -> b (n h w) c') 135 | rays_o = einops.rearrange(rays_o, 'b n h w c -> b (n h w) c') 136 | rays_d = einops.rearrange(rays_d, 'b n h w c -> b (n h w) c') 137 | n_sample = self.cfg['render_ray_batch'] 138 | pred = [] 139 | for i in range(0, rays_o.shape[1], n_sample): 140 | rays_o_, rays_d_, gt_ = map((lambda _: _[:, i: i + n_sample, :]), [rays_o, rays_d, gt]) 141 | x_tgt, z_vals = get_coord( 142 | rays_o_, rays_d_, 143 | near=data['near'][0], 144 | far=data['far'][0], 145 | points_per_ray=self.cfg['train_points_per_ray'], 146 | rand=is_train, 147 | ) 148 | pred_, loss_kl = self.model_ddp(data, x_tgt, rays_o_, z_vals, gt_, is_train=False) 149 | pred.append(pred_) 150 | 151 | 152 | pred = torch.cat(pred, dim=1) 153 | pred = torch.clamp(pred, min=0.0, max=1.0) 154 | gt = torch.clamp(gt, min=0.0, max=1.0) 155 | 156 | ref = data['support_imgs'][:,0].permute(0, 2, 3, 1) 157 | # save_img = torch.cat([ref.view(B, 128, 128, 3), pred.view(B, 128, 128, 3), gt.view(B, 128, 128, 3)], dim=2).view(-1, 384, 3) 158 | # imsave('save_imgs/save_img' + 
str(step) + '.png', np.squeeze(save_img.cpu().numpy() * 255).astype(np.uint8)) 159 | 160 | loss_kl = loss_kl.view(B, -1).mean(-1) 161 | mses = ((pred - gt)**2).view(B * N, -1).mean(dim=-1) 162 | psnr = (-10 * torch.log10(mses)).mean() 163 | 164 | return {'psnr': psnr.item(), 'loss_kl': loss_kl.mean().item()} 165 | 166 | def evaluate_epoch(self): 167 | self.model_ddp.eval() 168 | ave_scalars = dict() 169 | 170 | pbar = self.test_loader 171 | if self.is_master: 172 | pbar = tqdm(pbar, desc='eval', leave=False) 173 | 174 | t1 = time.time() 175 | for data in pbar: 176 | t0 = time.time() 177 | self.t_data += t0 - t1 178 | ret = self.evaluate_step(data) 179 | self.t_model += time.time() - t0 180 | 181 | B = len(next(iter(data.values()))) 182 | for k, v in ret.items(): 183 | if ave_scalars.get(k) is None: 184 | ave_scalars[k] = utils.Averager() 185 | ave_scalars[k].add(v, n=B) 186 | 187 | # --The only thing added-- # 188 | if self.is_master: 189 | pbar.set_description(desc=f'eval: psnr={ave_scalars["psnr"].item():.2f}') 190 | # ------------------------ # 191 | t1 = time.time() 192 | 193 | if self.distributed: 194 | self.sync_ave_scalars_(ave_scalars) 195 | 196 | logtext = 'eval:' 197 | for k, v in ave_scalars.items(): 198 | logtext += f' {k}={v.item():.4f}' 199 | self.log_temp_scalar('test/' + k, v.item()) 200 | self.log_buffer.append(logtext) -------------------------------------------------------------------------------- /exp3D/trainers/trainers.py: -------------------------------------------------------------------------------- 1 | trainers_dict = {} 2 | 3 | 4 | def register(name): 5 | def decorator(cls): 6 | trainers_dict[name] = cls 7 | return cls 8 | return decorator 9 | -------------------------------------------------------------------------------- /exp3D/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | from .geometry import * 3 | -------------------------------------------------------------------------------- /exp3D/utils/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | import logging 5 | 6 | import numpy as np 7 | from torch.optim import SGD, Adam 8 | from tensorboardX import SummaryWriter 9 | 10 | 11 | def ensure_path(path, replace=True): 12 | basename = os.path.basename(path.rstrip('/')) 13 | if os.path.exists(path): 14 | if replace and (basename.startswith('_') or input('{} exists, replace? 
(y/[n]): '.format(path)) == 'y'): 15 | shutil.rmtree(path) 16 | os.makedirs(path) 17 | else: 18 | os.makedirs(path) 19 | 20 | 21 | def set_logger(file_path): 22 | logger = logging.getLogger() 23 | logger.setLevel('INFO') 24 | stream_handler = logging.StreamHandler() 25 | file_handler = logging.FileHandler(file_path, 'w') 26 | formatter = logging.Formatter('[%(asctime)s] %(message)s', '%m-%d %H:%M:%S') 27 | for handler in [stream_handler, file_handler]: 28 | handler.setFormatter(formatter) 29 | handler.setLevel('INFO') 30 | logger.addHandler(handler) 31 | return logger 32 | 33 | 34 | def set_save_dir(save_dir, replace=True): 35 | ensure_path(save_dir, replace=replace) 36 | logger = set_logger(os.path.join(save_dir, 'log.txt')) 37 | writer = SummaryWriter(os.path.join(save_dir, 'tensorboard')) 38 | return logger, writer 39 | 40 | 41 | def compute_num_params(model, text=True): 42 | tot = int(sum([np.prod(p.shape) for p in model.parameters()])) 43 | if text: 44 | if tot >= 1e6: 45 | return '{:.1f}M'.format(tot / 1e6) 46 | elif tot >= 1e3: 47 | return '{:.1f}K'.format(tot / 1e3) 48 | else: 49 | return str(tot) 50 | else: 51 | return tot 52 | 53 | 54 | def make_optimizer(params, optimizer_spec, load_sd=False): 55 | optimizer = { 56 | 'sgd': SGD, 57 | 'adam': Adam 58 | }[optimizer_spec['name']](params, **optimizer_spec['args']) 59 | if load_sd: 60 | optimizer.load_state_dict(optimizer_spec['sd']) 61 | return optimizer 62 | 63 | 64 | class Averager(): 65 | 66 | def __init__(self): 67 | self.n = 0.0 68 | self.v = 0.0 69 | 70 | def add(self, v, n=1.0): 71 | self.v = (self.v * self.n + v * n) / (self.n + n) 72 | self.n += n 73 | 74 | def item(self): 75 | return self.v 76 | 77 | 78 | class EpochTimer(): 79 | 80 | def __init__(self, max_epoch): 81 | self.max_epoch = max_epoch 82 | self.epoch = 0 83 | self.t_start = time.time() 84 | self.t_last = self.t_start 85 | 86 | def epoch_done(self): 87 | t_cur = time.time() 88 | self.epoch += 1 89 | epoch_time = t_cur - self.t_last 90 | tot_time = t_cur - self.t_start 91 | est_time = tot_time / self.epoch * self.max_epoch 92 | self.t_last = t_cur 93 | return time_text(epoch_time), time_text(tot_time), time_text(est_time) 94 | 95 | 96 | def time_text(secs): 97 | if secs >= 3600: 98 | return f'{secs / 3600:.1f}h' 99 | elif secs >= 60: 100 | return f'{secs / 60:.1f}m' 101 | else: 102 | return f'{secs:.1f}s' 103 | -------------------------------------------------------------------------------- /exp3D/utils/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import einops 4 | 5 | 6 | def make_coord_grid(shape, range, device=None): 7 | """ 8 | Args: 9 | shape: tuple 10 | range: [minv, maxv] or [[minv_1, maxv_1], ..., [minv_d, maxv_d]] for each dim 11 | Returns: 12 | grid: shape (*shape, ) 13 | """ 14 | l_lst = [] 15 | for i, s in enumerate(shape): 16 | l = (0.5 + torch.arange(s, device=device)) / s 17 | if isinstance(range[0], list) or isinstance(range[0], tuple): 18 | minv, maxv = range[i] 19 | else: 20 | minv, maxv = range 21 | l = minv + (maxv - minv) * l 22 | l_lst.append(l) 23 | grid = torch.meshgrid(*l_lst) 24 | grid = torch.stack(grid, dim=-1) 25 | return grid 26 | 27 | 28 | def poses_to_rays(poses, image_h, image_w, focal): 29 | """ 30 | Pose columns are: 3 camera axes specified in world coordinate + 1 camera position. 31 | Camera: x-axis right, y-axis up, z-axis inward. 32 | Focal is in pixel-scale. 33 | 34 | Args: 35 | poses: (... 3 4) 36 | focal: (... 
2) 37 | Returns: 38 | rays_o, rays_d: shape (... image_h image_w 3) 39 | """ 40 | device = poses.device 41 | bshape = poses.shape[:-2] 42 | poses = poses.view(-1, 3, 4) 43 | focal = focal.view(-1, 2) 44 | bsize = poses.shape[0] 45 | 46 | y, x = torch.meshgrid(torch.arange(image_w, device=device), torch.arange(image_h, device=device)) # h w 47 | x, y = x + 0.5, y + 0.5 # modified to + 0.5 48 | x, y = x.unsqueeze(0), y.unsqueeze(0) # h w -> 1 h w 49 | focal = focal.unsqueeze(1).unsqueeze(1) # b 2 -> b 1 1 2 50 | dirs = torch.stack([ 51 | (x - image_w / 2) / focal[..., 0], 52 | -(y - image_h / 2) / focal[..., 1], 53 | -torch.ones(bsize, image_h, image_w, device=device) 54 | ], dim=-1) # b h w 3 55 | 56 | poses = poses.unsqueeze(1).unsqueeze(1) # b 3 4 -> b 1 1 3 4 57 | rays_o = poses[..., -1].repeat(1, image_h, image_w, 1) # b h w 3 58 | rays_d = (dirs.unsqueeze(-2) * poses[..., :3]).sum(dim=-1) # b h w 3 59 | 60 | rays_o = rays_o.view(*bshape, *rays_o.shape[1:]) 61 | rays_d = rays_d.view(*bshape, *rays_d.shape[1:]) 62 | return rays_o, rays_d 63 | 64 | 65 | def get_coord(rays_o, rays_d, near, far, points_per_ray, rand): 66 | """ 67 | Args: 68 | rays_o, rays_d: shape (b ... 3) 69 | Returns: 70 | pts_flat, z_vals 71 | """ 72 | B = rays_o.shape[0] 73 | rays_o = rays_o.view(B, -1, 3) 74 | rays_d = rays_d.view(B, -1, 3) 75 | n_rays = rays_o.shape[1] 76 | device = rays_o.device 77 | 78 | # Compute 3D query points 79 | z_vals = torch.linspace(near, far, points_per_ray, device=device) 80 | z_vals = einops.repeat(z_vals, 'p -> n p', n=n_rays) 81 | if rand: 82 | d = (far - near) / (points_per_ray - 1) # modified as points_per_ray - 1 83 | z_vals = z_vals + torch.rand(n_rays, points_per_ray, device=device) * d 84 | 85 | pts = rays_o.view(B, n_rays, 1, 3) + rays_d.view(B, n_rays, 1, 3) * z_vals.view(1, n_rays, points_per_ray, 1) 86 | pts_flat = einops.rearrange(pts, 'b n p d -> b (n p) d') 87 | return pts_flat, z_vals 88 | 89 | 90 | def rendering(raw, rays_o, z_vals): 91 | B = raw.shape[0] 92 | query_shape = rays_o.shape[1: -1] 93 | n_rays = rays_o.shape[1] 94 | raw = einops.rearrange(raw, 'b (n p) c -> b n p c', n=n_rays) 95 | 96 | # Compute opacities and colors 97 | rgb, sigma_a = raw[..., :3], raw[..., 3] 98 | rgb = torch.sigmoid(rgb) # b n p 3 99 | sigma_a = F.relu(sigma_a) # b n p 100 | 101 | # Do volume rendering 102 | dists = torch.cat([z_vals[:, 1:] - z_vals[:, :-1], torch.ones_like(z_vals[:, -1:]) * 1e-3], dim=-1) # n p 103 | alpha = 1. - torch.exp(-sigma_a * dists) # b n p 104 | trans = torch.clamp(1. - alpha + 1e-10, max=1.) # b n p 105 | trans = torch.cat([torch.ones_like(trans[..., :1]), trans[..., :-1]], dim=-1) 106 | weights = alpha * torch.cumprod(trans, dim=-1) # b n p 107 | 108 | rgb_map = torch.sum(weights.unsqueeze(-1) * rgb, dim=-2) 109 | acc_map = torch.sum(weights, dim=-1) 110 | rgb_map = rgb_map + (1. - acc_map).unsqueeze(-1) # white background 111 | # depth_map = torch.sum(weights * z_vals, dim=-1) 112 | 113 | rgb_map = rgb_map.view(B, *query_shape, 3) 114 | return rgb_map 115 | 116 | 117 | def volume_rendering(nerf, rays_o, rays_d, near, far, points_per_ray, use_viewdirs, rand): 118 | """ 119 | Args: 120 | rays_o, rays_d: shape (b ... 3) 121 | Returns: 122 | pred: (b ... 
3) 123 | """ 124 | pts_flat, z_vals = get_coord(rays_o, rays_d, near, far, points_per_ray, rand) 125 | 126 | if not use_viewdirs: 127 | raw = nerf(pts_flat) 128 | else: 129 | viewdirs = einops.repeat(rays_d, 'b n d -> b n p d', p=points_per_ray) 130 | raw = nerf(pts_flat, viewdirs=viewdirs) 131 | 132 | rgb_map = rendering(raw, rays_o, z_vals) 133 | return rgb_map 134 | 135 | 136 | def batched_volume_rendering(nerf, rays_o, rays_d, *args, batch_size=None, **kwargs): 137 | """ 138 | Args: 139 | rays_o, rays_d: (b ... 3) 140 | Returns: 141 | pred: (b ... 3) 142 | """ 143 | B = rays_o.shape[0] 144 | query_shape = rays_o.shape[1: -1] 145 | rays_o = rays_o.view(B, -1, 3) 146 | rays_d = rays_d.view(B, -1, 3) 147 | 148 | ret = [] 149 | ll = 0 150 | while ll < rays_o.shape[1]: 151 | rr = min(ll + batch_size, rays_o.shape[1]) 152 | rays_o_ = rays_o[:, ll: rr, :] 153 | rays_d_ = rays_d[:, ll: rr, :] 154 | ret.append(volume_rendering(nerf, rays_o_, rays_d_, *args, **kwargs)) 155 | ll = rr 156 | ret = torch.cat(ret, dim=1) 157 | 158 | ret = ret.view(B, *query_shape, 3) 159 | return ret 160 | 161 | -------------------------------------------------------------------------------- /imgs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/imgs/framework.png --------------------------------------------------------------------------------
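
The custom ops under `exp3D/models/op` (`fused_bias_act`, `upfirdn2d`) are JIT-compiled via `torch.utils.cpp_extension.load` on first import, so a CUDA toolkit with an `nvcc` matching the installed PyTorch is required. Below is a minimal usage sketch for the `upfirdn2d` Python wrapper defined above, performing a StyleGAN2-style 2x upsample with a normalized `[1, 3, 3, 1]` blur kernel; the gain and padding values follow the common StyleGAN2 convention and are our own choice for illustration, not something this repository pins down.

```python
import torch
from models.op.upfirdn2d import upfirdn2d   # run from inside exp3D/; the import triggers the JIT build

# Separable [1, 3, 3, 1] blur kernel, normalized to sum to 1.
k = torch.tensor([1., 3., 3., 1.])
kernel = k[None, :] * k[:, None]
kernel = kernel / kernel.sum()

x = torch.randn(1, 8, 32, 32, device='cuda')

# 2x upsample + FIR filter: a gain of up**2 preserves signal magnitude,
# and pad = (2, 1) is the usual choice for a 4-tap kernel and factor 2.
y = upfirdn2d(x, kernel.cuda() * 4, up=2, down=1, pad=(2, 1))
print(y.shape)   # torch.Size([1, 8, 64, 64])
```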
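
`NvsTokenizer.forward` in `exp3D/models/tokenizers/nvs_tokenizer.py` stacks the support-view RGB with the per-pixel ray origins and directions (9 channels), cuts the result into patches with `F.unfold`, and lets a linear layer turn each flattened patch into a token. A shape walk-through of that path with random stand-in tensors (the sizes are illustrative, not mandated by the configs):

```python
import torch
import torch.nn.functional as F
import einops

B, N, H, W, P = 2, 1, 128, 128, 8          # batch, support views, image size, patch size
imgs   = torch.rand(B, N, 3, H, W)
rays_o = torch.rand(B, N, 3, H, W)
rays_d = torch.rand(B, N, 3, H, W)

x = torch.cat([imgs, rays_o, rays_d], dim=2)                # (B, N, 9, H, W)
x = einops.rearrange(x, 'b n d h w -> (b n) d h w')
x = F.unfold(x, P, stride=P)                                # (B*N, 9*P*P, (H/P)*(W/P)) = (2, 576, 256)
x = einops.rearrange(x, '(b n) ppd l -> b (n l) ppd', b=B)  # (B, N*256, 576)
# prefc (a linear layer) then maps each 576-dim patch to the transformer width.
```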
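
`Versatile_NP.coord_embedding` in `exp3D/models/versatile_np.py` lifts each 3D query coordinate to Fourier features with frequencies 2^0 ... 2^8 before the `embed_input` MLP. A self-contained sketch of the same encoding; `fourier_encode` is a name introduced here for illustration and is not part of the repository:

```python
import torch

def fourier_encode(x, pe_dim=128, max_log2_freq=8):
    # x: (..., d) coordinates -> (..., d * pe_dim) features,
    # using frequencies 2**0 .. 2**max_log2_freq as in coord_embedding.
    w = 2 ** torch.linspace(0, max_log2_freq, pe_dim // 2, device=x.device)
    proj = torch.matmul(x.unsqueeze(-1), w.unsqueeze(0))   # (..., d, pe_dim // 2)
    proj = proj.view(*x.shape[:-1], -1)                    # (..., d * pe_dim // 2)
    return torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)

pts = torch.rand(2, 4096, 3)        # e.g. flattened ray sample points
feat = fourier_encode(pts)          # (2, 4096, 384) = (2, 4096, 3 * pe_dim)
```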
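
`utils/geometry.py` first turns camera poses into per-pixel rays and then samples points along each ray; `nvs_trainer.py` flattens these before feeding the model. The sketch below traces the shapes end to end; the identity pose, the focal length, and the near/far bounds are placeholder values chosen only for illustration:

```python
import torch
from utils import poses_to_rays, get_coord   # run from inside exp3D/

B, N, H, W = 2, 1, 128, 128                          # batch, views, image size
poses = torch.eye(4)[:3].repeat(B, N, 1, 1)          # toy camera-to-world (3x4) matrices
focals = torch.full((B, N, 2), 131.25)               # focal length in pixels (example value)

rays_o, rays_d = poses_to_rays(poses, H, W, focals)  # (B, N, H, W, 3) each
rays_o = rays_o.view(B, -1, 3)
rays_d = rays_d.view(B, -1, 3)

# 64 jittered samples per ray between near=2 and far=6.
pts, z_vals = get_coord(rays_o, rays_d, near=2.0, far=6.0,
                        points_per_ray=64, rand=True)
# pts:    (B, N*H*W*64, 3) flattened query coordinates fed to the model
# z_vals: (N*H*W, 64) sample depths reused later by rendering()
```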
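
The `rendering` function composites per-point predictions with standard alpha compositing onto a white background. A single-ray toy version of the same arithmetic, useful as a sanity check of the weight computation (densities and colors are made up):

```python
import torch
import torch.nn.functional as F

z_vals = torch.linspace(2.0, 6.0, 5)                 # 5 samples along one ray
sigma  = torch.tensor([0.0, 0.5, 3.0, 0.2, 0.0])     # toy densities
rgb    = torch.rand(5, 3)                            # toy raw per-point colors

dists = torch.cat([z_vals[1:] - z_vals[:-1], torch.full((1,), 1e-3)])
alpha = 1.0 - torch.exp(-F.relu(sigma) * dists)      # opacity per sample
trans = torch.clamp(1.0 - alpha + 1e-10, max=1.0)
trans = torch.cat([torch.ones(1), trans[:-1]])       # shifted transmittance
weights = alpha * torch.cumprod(trans, dim=0)

rgb_map = (weights.unsqueeze(-1) * torch.sigmoid(rgb)).sum(dim=0)
acc_map = weights.sum()
rgb_map = rgb_map + (1.0 - acc_map)                  # composite onto white background
```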
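
During the first `adaptive_sample_epoch` epochs, `NvsTrainer._adaptive_sample_rays` oversamples foreground rays (pixels that are not pure white) so that half of each ray batch hits the object. A compact sketch in the spirit of that selection logic for a single image; the random tensors are stand-ins, and the short-foreground case is handled with `replace=True` rather than the repository's concatenation:

```python
import numpy as np
import torch

gt = torch.rand(16384, 3)      # ground-truth colors of all candidate rays (random stand-in)
n_sample = 512

# Foreground indices: any channel below 1 means the pixel is not pure white.
fg = (gt.min(dim=-1).values < 1.0).nonzero().view(-1).numpy()
fg = np.random.choice(fg, n_sample // 2, replace=len(fg) < n_sample // 2)

# Fill the remaining half with uniformly random rays.
rd = np.random.choice(gt.shape[0], n_sample - n_sample // 2, replace=False)
inds = np.concatenate([fg, rd], axis=0)
picked = gt[torch.from_numpy(inds).long()]   # (512, 3) selected ray targets
```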
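
The training objective in `NvsTrainer._iter_step` is the per-scene MSE plus a KL term whose weight is annealed linearly over the first 50 epochs (and kept at `Lambda` when resuming), with PSNR reported as -10·log10(MSE). A standalone sketch of the schedule and the logged quantities; `warmup_epochs` is a name introduced here for illustration, whereas the trainer hard-codes 50, and the tensors are random stand-ins:

```python
import torch

def kl_beta(epoch, lam=0.001, warmup_epochs=50, resumed=False):
    # Linear warm-up of the KL weight, matching NvsTrainer._iter_step.
    return lam if resumed else lam * min(1.0, epoch / warmup_epochs)

pred = torch.rand(2, 4096, 3)      # predicted ray colors
gt   = torch.rand(2, 4096, 3)      # ground-truth ray colors
kl   = torch.rand(2, 4096)         # per-point KL terms from the hierarchical model

loss_mse = ((pred - gt) ** 2).view(2, -1).mean(dim=-1)   # per-scene MSE
loss_kl  = kl.view(2, -1).mean()
loss     = loss_mse.mean() + kl_beta(epoch=10) * loss_kl
psnr     = (-10 * torch.log10(loss_mse)).mean()          # logged metric
```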