├── LICENSE
├── README.md
├── exp3D
│   ├── cfgs
│   │   ├── nvs_eval
│   │   │   ├── eval_nvs_shapenet_cars.yaml
│   │   │   ├── eval_nvs_shapenet_chairs.yaml
│   │   │   └── eval_nvs_shapenet_lamps.yaml
│   │   ├── nvs_shapenet_cars.yaml
│   │   ├── nvs_shapenet_chairs.yaml
│   │   └── nvs_shapenet_lamps.yaml
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── celeba.py
│   │   ├── datasets.py
│   │   ├── imagenette.py
│   │   ├── imgrec_dataset.py
│   │   └── learnit_shapenet.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── blocks.py
│   │   ├── hier_model.py
│   │   ├── models.py
│   │   ├── op
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── fused_act.cpython-37.pyc
│   │   │   │   ├── fused_act.cpython-38.pyc
│   │   │   │   ├── upfirdn2d.cpython-37.pyc
│   │   │   │   └── upfirdn2d.cpython-38.pyc
│   │   │   ├── fused_act.py
│   │   │   ├── fused_bias_act.cpp
│   │   │   ├── fused_bias_act_kernel.cu
│   │   │   ├── upfirdn2d.cpp
│   │   │   ├── upfirdn2d.py
│   │   │   └── upfirdn2d_kernel.cu
│   │   ├── tokenizers
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── imgrec_tokenizer.cpython-38.pyc
│   │   │   │   └── nvs_tokenizer.cpython-38.pyc
│   │   │   └── nvs_tokenizer.py
│   │   ├── transformer.py
│   │   └── versatile_np.py
│   ├── run_trainer.py
│   ├── trainers
│   │   ├── __init__.py
│   │   ├── base_trainer.py
│   │   ├── nvs_trainer.py
│   │   └── trainers.py
│   └── utils
│       ├── __init__.py
│       ├── common.py
│       └── geometry.py
└── imgs
    └── framework.png
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2023, Zongyu Guo
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Versatile-NP
2 |
3 | This repository contains the official implementation for the following paper:
4 |
5 | [**Versatile Neural Processes for Learning Implicit Neural Representations**](https://arxiv.org/abs/2301.08883), ICLR 2023
6 |
7 |
8 |
9 | ## Reproducing 3D Experiments
10 |
11 | The code for the 3D experiments is built on the codebase of [Trans-INR](https://github.com/yinboc/trans-inr).
12 |
13 | ### Environment
14 | - Python 3
15 | - PyTorch 1.7.1
16 | - pyyaml, numpy, tqdm, imageio, TensorboardX, einops
17 |
18 | ### Data
19 |
20 | Run `mkdir data` and put the dataset folders in it.
21 |
22 | - **View synthesis**: download the data from [google drive](https://drive.google.com/drive/folders/1lRfg-Ov1dd3ldke9Gv9dyzGGTxiFOhIs) (provided by [learnit](https://www.matthewtancik.com/learnit)) and put it in a folder named `learnit_shapenet`; unzip the category folders and rename them to `chairs`, `cars`, and `lamps` accordingly.
23 |
24 | ### Training
25 |
26 | `cd exp3D`
27 | `CUDA_VISIBLE_DEVICES=[GPU] python run_trainer.py --cfg [CONFIG] --load-root [DATADIR]`
28 |
29 | Configs are in `cfgs/`. Four RTX 3090 Ti GPUs or four 32 GB V100 GPUs are recommended for training.
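For example, assuming the `data` folder from the previous step sits at `../data` relative to `exp3D`, a four-GPU run on the cars config would presumably be `CUDA_VISIBLE_DEVICES=0,1,2,3 python run_trainer.py --cfg cfgs/nvs_shapenet_cars.yaml --load-root ../data`; the `--load-root` argument appears to fill the `$load_root$` placeholder used in the configs (paths here are illustrative).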
30 |
31 | ### Evaluation
32 |
33 | For view synthesis, run on a single GPU with the configs in `cfgs/nvs_eval`.
34 |
35 | ### Checkpoint models
36 |
37 | The pretrained checkpoints can be found on [Google Drive](https://drive.google.com/drive/folders/16_ZrgYLH2oiV0uC6OBwnI3op-24nS0RK?usp=share_link).
38 |
39 | | | Cars | Lamps | Chairs |
40 | | :---- | :---- | :---- | :---- |
41 | | PSNR (dB) | 24.21 | 24.10 | 19.54 |
42 |
43 | Since the code has been reorganized with unified training settings, the performance on lamps and chairs is slightly better than in our initial OpenReview submission, while the performance on cars is slightly lower. Adjusting the annealing strategy of the beta coefficient could further improve the results on cars.
44 |
45 | ## Reproducing 1D Experiments
46 |
47 | The code for the 1D toy examples can be found in the [supplementary material of our OpenReview submission](https://openreview.net/forum?id=2nLeOOfAjK).
--------------------------------------------------------------------------------
/exp3D/cfgs/nvs_eval/eval_nvs_shapenet_cars.yaml:
--------------------------------------------------------------------------------
1 | trainer: nvs_evaluator
2 |
3 | test_dataset:
4 | name: learnit_shapenet
5 | args:
6 | root_path: $load_root$/learnit_shapenet
7 | category: cars
8 | split: test
9 | n_support: 1
10 | n_query: 1
11 | repeat: 100
12 | loader:
13 | batch_size: 1
14 | num_workers: 1
15 |
16 | model:
17 | name: versatile_np
18 | args:
19 | tokenizer:
20 | name: nvs_tokenizer
21 | args: {input_size: 128, patch_size: 8}
22 | self_attender:
23 | name: self_attender
24 | args: {dim: 512, depth: 6, n_head: 8, head_dim: 64, ff_dim: 1024}
25 | cross_attender:
26 | name: cross_attender
27 | args: {dim: 512, depth: 3, n_head: 4, head_dim: 128, ff_dim: 512}
28 | hierarchical_model:
29 | name: hier_model
30 | args: {depth: 4, dim_y: 512, dim_hid: 512, dim_lat: 64}
31 |
32 | eval_model: save/nvs_shapenet_cars/epoch-last.pth
33 |
34 | train_points_per_ray: 128
35 | render_ray_batch: 1024
36 |
--------------------------------------------------------------------------------
/exp3D/cfgs/nvs_eval/eval_nvs_shapenet_chairs.yaml:
--------------------------------------------------------------------------------
1 | trainer: nvs_evaluator
2 |
3 | test_dataset:
4 | name: learnit_shapenet
5 | args:
6 | root_path: $load_root$/learnit_shapenet
7 | category: chairs
8 | split: test
9 | n_support: 1
10 | n_query: 1
11 | repeat: 100
12 | loader:
13 | batch_size: 1
14 | num_workers: 1
15 |
16 | model:
17 | name: versatile_np
18 | args:
19 | tokenizer:
20 | name: nvs_tokenizer
21 | args: {input_size: 128, patch_size: 8}
22 | self_attender:
23 | name: self_attender
24 | args: {dim: 512, depth: 6, n_head: 8, head_dim: 64, ff_dim: 1024}
25 | cross_attender:
26 | name: cross_attender
27 | args: {dim: 512, depth: 3, n_head: 4, head_dim: 128, ff_dim: 512}
28 | hierarchical_model:
29 | name: hier_model
30 | args: {depth: 4, dim_y: 512, dim_hid: 512, dim_lat: 64}
31 |
32 | eval_model: save/nvs_shapenet_chairs/epoch-last.pth
33 |
34 | train_points_per_ray: 128
35 | render_ray_batch: 1024
36 |
--------------------------------------------------------------------------------
/exp3D/cfgs/nvs_eval/eval_nvs_shapenet_lamps.yaml:
--------------------------------------------------------------------------------
1 | trainer: nvs_evaluator
2 |
3 | test_dataset:
4 | name: learnit_shapenet
5 | args:
6 | root_path: $load_root$/learnit_shapenet
7 | category: lamps
8 | split: test
9 | n_support: 1
10 | n_query: 1
11 | repeat: 100
12 | loader:
13 | batch_size: 1
14 | num_workers: 1
15 |
16 | model:
17 | name: versatile_np
18 | args:
19 | tokenizer:
20 | name: nvs_tokenizer
21 | args: {input_size: 128, patch_size: 8}
22 | self_attender:
23 | name: self_attender
24 | args: {dim: 512, depth: 6, n_head: 8, head_dim: 64, ff_dim: 1024}
25 | cross_attender:
26 | name: cross_attender
27 | args: {dim: 512, depth: 3, n_head: 4, head_dim: 128, ff_dim: 512}
28 | hierarchical_model:
29 | name: hier_model
30 | args: {depth: 4, dim_y: 512, dim_hid: 512, dim_lat: 64}
31 |
32 | eval_model: save/nvs_shapenet_lamps/epoch-last.pth
33 |
34 | train_points_per_ray: 128
35 | render_ray_batch: 4096
36 |
--------------------------------------------------------------------------------
/exp3D/cfgs/nvs_shapenet_cars.yaml:
--------------------------------------------------------------------------------
1 | trainer: nvs_trainer
2 |
3 | train_dataset:
4 | name: learnit_shapenet
5 | args:
6 | root_path: $load_root$/learnit_shapenet
7 | category: cars
8 | split: train
9 | views_rng: [0, 25]
10 | n_support: 1
11 | n_query: 1
12 | repeat: 2
13 | loader:
14 | batch_size: 32
15 | num_workers: 8
16 |
17 | test_dataset:
18 | name: learnit_shapenet
19 | args:
20 | root_path: $load_root$/learnit_shapenet
21 | category: cars
22 | split: test
23 | n_support: 1
24 | n_query: 1
25 | repeat: 100
26 | loader:
27 | batch_size: 32
28 | num_workers: 8
29 |
30 | model:
31 | name: versatile_np
32 | args:
33 | tokenizer:
34 | name: nvs_tokenizer
35 | args: {input_size: 128, patch_size: 8}
36 | self_attender:
37 | name: self_attender
38 | args: {dim: 512, depth: 6, n_head: 8, head_dim: 64, ff_dim: 1024}
39 | cross_attender:
40 | name: cross_attender
41 | args: {dim: 512, depth: 3, n_head: 4, head_dim: 128, ff_dim: 512}
42 | hierarchical_model:
43 | name: hier_model
44 | args: {depth: 4, dim_y: 512, dim_hid: 512, dim_lat: 64}
45 |
46 | train_points_per_ray: 128
47 | train_n_rays: 128
48 | render_ray_batch: 1024
49 |
50 | # resume_model: ./save/nvs_shapenet_chairs/epoch-last.pth
51 |
52 | optimizer:
53 | name: adam
54 | args: {lr: 1.e-4}
55 | max_epoch: 1000
56 | save_epoch: 100
57 | adaptive_sample_epoch: 1
58 | eval_epoch: 10
59 |
60 | Lambda: 0.001
--------------------------------------------------------------------------------
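The trainer itself (`trainers/nvs_trainer.py`) is not included in this excerpt, so the following is only a sketch of how the `optimizer` and `Lambda` entries in the config above are plausibly consumed: the optimizer spec maps onto `torch.optim.Adam`, and `Lambda` reads as a beta-style weight on the KL term returned by the hierarchical model. The helper names and the exact loss form below are assumptions, not the repository's code.

```python
import torch

# Sketch only: assumed mapping from the YAML optimizer spec to torch.optim.
optimizer_spec = {'name': 'adam', 'args': {'lr': 1.e-4}}
optimizers = {'adam': torch.optim.Adam}

def make_optimizer(params, spec):
    return optimizers[spec['name']](params, **spec['args'])

def weighted_loss(recon_loss, kls, Lambda=0.001):
    # hier_model.forward returns (y, kls); a beta-VAE-style objective would be
    # reconstruction + Lambda * KL, with Lambda possibly annealed over epochs.
    return recon_loss + Lambda * kls.mean()
```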
/exp3D/cfgs/nvs_shapenet_chairs.yaml:
--------------------------------------------------------------------------------
1 | trainer: nvs_trainer
2 |
3 | train_dataset:
4 | name: learnit_shapenet
5 | args:
6 | root_path: $load_root$/learnit_shapenet
7 | category: chairs
8 | split: train
9 | views_rng: [0, 25]
10 | n_support: 1
11 | n_query: 1
12 | repeat: 1
13 | loader:
14 | batch_size: 32
15 | num_workers: 8
16 |
17 | test_dataset:
18 | name: learnit_shapenet
19 | args:
20 | root_path: $load_root$/learnit_shapenet
21 | category: chairs
22 | split: test
23 | n_support: 1
24 | n_query: 1
25 | repeat: 100
26 | loader:
27 | batch_size: 32
28 | num_workers: 8
29 |
30 | model:
31 | name: versatile_np
32 | args:
33 | tokenizer:
34 | name: nvs_tokenizer
35 | args: {input_size: 128, patch_size: 8}
36 | self_attender:
37 | name: self_attender
38 | args: {dim: 512, depth: 6, n_head: 8, head_dim: 64, ff_dim: 1024}
39 | cross_attender:
40 | name: cross_attender
41 | args: {dim: 512, depth: 3, n_head: 4, head_dim: 128, ff_dim: 512}
42 | hierarchical_model:
43 | name: hier_model
44 | args: {depth: 4, dim_y: 512, dim_hid: 512, dim_lat: 64}
45 |
46 | train_points_per_ray: 128
47 | train_n_rays: 128
48 | render_ray_batch: 1024
49 |
50 | # resume_model: ./save/nvs_shapenet_chairs/epoch-last.pth
51 |
52 | optimizer:
53 | name: adam
54 | args: {lr: 1.e-4}
55 | max_epoch: 1000
56 | save_epoch: 100
57 | adaptive_sample_epoch: 1
58 | eval_epoch: 10
59 |
60 | Lambda: 0.001
--------------------------------------------------------------------------------
/exp3D/cfgs/nvs_shapenet_lamps.yaml:
--------------------------------------------------------------------------------
1 | trainer: nvs_trainer
2 |
3 | train_dataset:
4 | name: learnit_shapenet
5 | args:
6 | root_path: $load_root$/learnit_shapenet
7 | category: lamps
8 | split: train
9 | views_rng: [0, 25]
10 | n_support: 1
11 | n_query: 1
12 | repeat: 3
13 | loader:
14 | batch_size: 32
15 | num_workers: 8
16 |
17 | test_dataset:
18 | name: learnit_shapenet
19 | args:
20 | root_path: $load_root$/learnit_shapenet
21 | category: lamps
22 | split: test
23 | n_support: 1
24 | n_query: 1
25 | repeat: 100
26 | loader:
27 | batch_size: 32
28 | num_workers: 8
29 |
30 | model:
31 | name: versatile_np
32 | args:
33 | tokenizer:
34 | name: nvs_tokenizer
35 | args: {input_size: 128, patch_size: 8}
36 | self_attender:
37 | name: self_attender
38 | args: {dim: 512, depth: 6, n_head: 8, head_dim: 64, ff_dim: 1024}
39 | cross_attender:
40 | name: cross_attender
41 | args: {dim: 512, depth: 3, n_head: 4, head_dim: 128, ff_dim: 512}
42 | hierarchical_model:
43 | name: hier_model
44 | args: {depth: 4, dim_y: 512, dim_hid: 512, dim_lat: 64}
45 |
46 | train_points_per_ray: 128
47 | train_n_rays: 128
48 | render_ray_batch: 1024
49 |
50 | # resume_model: ./save/nvs_shapenet_lamps/epoch-last.pth
51 |
52 | optimizer:
53 | name: adam
54 | args: {lr: 1.e-4}
55 | max_epoch: 1000
56 | save_epoch: 100
57 | adaptive_sample_epoch: 1
58 | eval_epoch: 10
59 |
60 | Lambda: 0.001
--------------------------------------------------------------------------------
/exp3D/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import register, make
2 | from . import imgrec_dataset, celeba, imagenette
3 | from . import learnit_shapenet
4 |
--------------------------------------------------------------------------------
/exp3D/datasets/celeba.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 |
4 | from torch.utils.data import Dataset
5 |
6 | from datasets import register
7 |
8 |
9 | @register('celeba')
10 | class Celeba(Dataset):
11 |
12 | def __init__(self, root_path, split):
13 | if split == 'train':
14 | s, t = 1, 162770
15 | elif split == 'val':
16 | s, t = 162771, 182637
17 | elif split == 'test':
18 | s, t = 182638, 202599
19 | self.data = []
20 | for i in range(s, t + 1):
21 | path = os.path.join(root_path, 'img_align_celeba', 'img_align_celeba', f'{i:06}.jpg')
22 | self.data.append(path)
23 |
24 | def __len__(self):
25 | return len(self.data)
26 |
27 | def __getitem__(self, idx):
28 | return Image.open(self.data[idx])
29 |
--------------------------------------------------------------------------------
/exp3D/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 |
4 | datasets = {}
5 |
6 |
7 | def register(name):
8 | def decorator(cls):
9 | datasets[name] = cls
10 | return cls
11 | return decorator
12 |
13 |
14 | def make(dataset_spec, args=None):
15 | if args is not None:
16 | dataset_args = copy.deepcopy(dataset_spec['args'])
17 | dataset_args.update(args)
18 | else:
19 | dataset_args = dataset_spec['args']
20 | dataset = datasets[dataset_spec['name']](**dataset_args)
21 | return dataset
22 |
--------------------------------------------------------------------------------
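A minimal sketch of how the registry above is meant to be used, mirroring the `train_dataset`/`test_dataset` blocks in the YAML configs; the class and spec values here are illustrative only, and the snippet assumes it is run from within `exp3D` so that the `datasets` package is importable.

```python
from torch.utils.data import Dataset
from datasets import register, make

@register('toy_dataset')            # the decorator stores the class under a string key
class ToyDataset(Dataset):
    def __init__(self, n):
        self.n = n
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        return idx

# A spec dict shaped like a parsed YAML config; explicit args override the spec's args.
spec = {'name': 'toy_dataset', 'args': {'n': 10}}
dataset = make(spec, args={'n': 20})
assert len(dataset) == 20
```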
/exp3D/datasets/imagenette.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 |
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms
6 |
7 | from datasets import register
8 |
9 |
10 | @register('imagenette')
11 | class Imagenette(Dataset):
12 |
13 | def __init__(self, root_path, split, augment):
14 | root_path = os.path.join(root_path, split)
15 | classes = sorted(os.listdir(root_path))
16 | self.data = []
17 | for c in classes:
18 | filenames = sorted(os.listdir(os.path.join(root_path, c)))
19 | for f in filenames:
20 | self.data.append(os.path.join(root_path, c, f))
21 | if augment == 'none':
22 | self.transform = transforms.Compose([])
23 | elif augment == 'random_crop_178':
24 | self.transform = transforms.Compose([
25 | transforms.RandomCrop(178),
26 | transforms.RandomHorizontalFlip(),
27 | ])
28 |
29 | def __len__(self):
30 | return len(self.data)
31 |
32 | def __getitem__(self, idx):
33 | return self.transform(Image.open(self.data[idx]).convert('RGB'))
34 |
--------------------------------------------------------------------------------
/exp3D/datasets/imgrec_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from torchvision import transforms
3 |
4 | import datasets
5 | from datasets import register
6 |
7 |
8 | @register('imgrec_dataset')
9 | class ImgrecDataset(Dataset):
10 |
11 | def __init__(self, imageset, resize):
12 | self.imageset = datasets.make(imageset)
13 | self.transform = transforms.Compose([
14 | transforms.Resize(resize),
15 | transforms.CenterCrop(resize),
16 | transforms.ToTensor(),
17 | ])
18 |
19 | def __len__(self):
20 | return len(self.imageset)
21 |
22 | def __getitem__(self, idx):
23 | x = self.transform(self.imageset[idx])
24 | return {'inp': x, 'gt': x}
25 |
--------------------------------------------------------------------------------
/exp3D/datasets/learnit_shapenet.py:
--------------------------------------------------------------------------------
1 | # Reference: https://github.com/tancik/learnit/blob/main/Experiments/shapenet.ipynb
2 | import os
3 | import json
4 |
5 | import imageio
6 | import numpy as np
7 | import torch
8 | import einops
9 | from torch.utils.data import Dataset
10 |
11 | from datasets import register
12 |
13 |
14 | @register('learnit_shapenet')
15 | class LearnitShapenet(Dataset):
16 |
17 | def __init__(self, root_path, category, split, n_support, n_query, views_rng=None, repeat=1):
18 | with open(os.path.join(root_path, category[:-len('s')] + '_splits.json'), 'r') as f:
19 | obj_ids = json.load(f)[split]
20 | _data = [os.path.join(root_path, category, _) for _ in obj_ids]
21 | self.data = []
22 | for x in _data:
23 | if os.path.exists(os.path.join(x, 'transforms.json')):
24 | self.data.append(x)
25 | else:
26 | print(f'Missing obj at {x}, skipped.')
27 | self.n_support = n_support
28 | self.n_query = n_query
29 | self.views_rng = views_rng
30 | self.repeat = repeat
31 |
32 | def __len__(self):
33 | return len(self.data) * self.repeat
34 |
35 | def __getitem__(self, idx):
36 | idx %= len(self.data)
37 |
38 | train_ex_dir = self.data[idx]
39 | with open(os.path.join(train_ex_dir, 'transforms.json'), 'r') as fp:
40 | meta = json.load(fp)
41 | camera_angle_x = float(meta['camera_angle_x'])
42 | frames = meta['frames']
43 | if self.views_rng is not None:
44 | frames = frames[self.views_rng[0]: self.views_rng[1]]
45 |
46 | frames = np.random.choice(frames, self.n_support + self.n_query, replace=False)
47 |
48 | imgs = []
49 | poses = []
50 | for frame in frames:
51 | fname = os.path.join(train_ex_dir, os.path.basename(frame['file_path']) + '.png')
52 | imgs.append(imageio.imread(fname))
53 | poses.append(np.array(frame['transform_matrix']))
54 | H, W = imgs[0].shape[:2]
55 | assert H == W
56 | focal = .5 * W / np.tan(.5 * camera_angle_x)
57 | imgs = (np.array(imgs) / 255.).astype(np.float32)
58 | imgs = imgs[..., :3] * imgs[..., -1:] + (1 - imgs[..., -1:])
59 | poses = np.array(poses).astype(np.float32)
60 |
61 | imgs = einops.rearrange(torch.from_numpy(imgs), 'n h w c -> n c h w')
62 | poses = torch.from_numpy(poses)[:, :3, :4]
63 | focal = torch.ones(len(poses), 2) * float(focal)
64 | t = self.n_support
65 | return {
66 | 'support_imgs': imgs[:t],
67 | 'support_poses': poses[:t],
68 | 'support_focals': focal[:t],
69 | 'query_imgs': imgs[t:],
70 | 'query_poses': poses[t:],
71 | 'query_focals': focal[t:],
72 | 'near': 2,
73 | 'far': 6,
74 | }
75 |
--------------------------------------------------------------------------------
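The dataset above returns per-view poses, per-view focal lengths, and a fixed `near`/`far` range; ray generation itself lives in `utils/geometry.py`, which is not part of this excerpt. The sketch below is only the standard pinhole, NeRF-style ray construction that such poses and focals are typically fed into, not the repository's actual implementation.

```python
import torch

def get_rays_sketch(pose, focal, H, W):
    # pose: (3, 4) camera-to-world matrix; focal: scalar focal length in pixels.
    j, i = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                          torch.arange(W, dtype=torch.float32))  # 'ij' ordering
    dirs = torch.stack([(i - 0.5 * W) / focal,
                        -(j - 0.5 * H) / focal,
                        -torch.ones_like(i)], dim=-1)             # camera-space directions
    rays_d = (dirs[..., None, :] * pose[:3, :3]).sum(-1)          # rotate into world space
    rays_o = pose[:3, -1].expand(rays_d.shape)                    # same origin for all rays
    return rays_o, rays_d                                         # each of shape (H, W, 3)
```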
/exp3D/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import register, make
2 | from . import versatile_np
3 | from . import hier_model
4 | from . import tokenizers
5 | from . import transformer
6 |
--------------------------------------------------------------------------------
/exp3D/models/blocks.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 |
8 | import models
9 | from models.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
10 | # from op import FusedLeakyReLU, fused_leaky_relu
11 |
12 | class PixelNorm(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def forward(self, input):
17 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
18 |
19 |
20 | def make_kernel(k):
21 | k = torch.tensor(k, dtype=torch.float32)
22 |
23 | if k.ndim == 1:
24 | k = k[None, :] * k[:, None]
25 |
26 | k /= k.sum()
27 |
28 | return k
29 |
30 |
31 | class Upsample(nn.Module):
32 | def __init__(self, kernel, factor=2):
33 | super().__init__()
34 |
35 | self.factor = factor
36 | kernel = make_kernel(kernel) * (factor ** 2)
37 | self.register_buffer('kernel', kernel)
38 |
39 | p = kernel.shape[0] - factor
40 |
41 | pad0 = (p + 1) // 2 + factor - 1
42 | pad1 = p // 2
43 |
44 | self.pad = (pad0, pad1)
45 |
46 | def forward(self, input):
47 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
48 | return out
49 |
50 |
51 | class Downsample(nn.Module):
52 | def __init__(self, kernel, factor=2):
53 | super().__init__()
54 |
55 | self.factor = factor
56 | kernel = make_kernel(kernel)
57 | self.register_buffer('kernel', kernel)
58 |
59 | p = kernel.shape[0] - factor
60 |
61 | pad0 = (p + 1) // 2
62 | pad1 = p // 2
63 |
64 | self.pad = (pad0, pad1)
65 |
66 | def forward(self, input):
67 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
68 | return out
69 |
70 |
71 | class Blur(nn.Module):
72 | def __init__(self, kernel, pad, upsample_factor=1):
73 | super().__init__()
74 |
75 | kernel = make_kernel(kernel)
76 |
77 | if upsample_factor > 1:
78 | kernel = kernel * (upsample_factor ** 2)
79 |
80 | self.register_buffer('kernel', kernel)
81 |
82 | self.pad = pad
83 |
84 | def forward(self, input):
85 | out = upfirdn2d(input, self.kernel, pad=self.pad)
86 | return out
87 |
88 |
89 | class EqualConv2d(nn.Module):
90 | def __init__(
91 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
92 | ):
93 | super().__init__()
94 |
95 | self.weight = nn.Parameter(
96 | torch.randn(out_channel, in_channel, kernel_size, kernel_size)
97 | )
98 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
99 |
100 | self.stride = stride
101 | self.padding = padding
102 |
103 | if bias:
104 | self.bias = nn.Parameter(torch.zeros(out_channel))
105 |
106 | else:
107 | self.bias = None
108 |
109 | def forward(self, input):
110 | out = F.conv2d(
111 | input,
112 | self.weight * self.scale,
113 | bias=self.bias,
114 | stride=self.stride,
115 | padding=self.padding,
116 | )
117 | return out
118 |
119 | def __repr__(self):
120 | return (
121 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
122 | f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
123 | )
124 |
125 |
126 | class EqualLinear(nn.Module):
127 | def __init__(
128 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
129 | ):
130 | super().__init__()
131 |
132 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
133 |
134 | if bias:
135 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
136 |
137 | else:
138 | self.bias = None
139 |
140 | self.activation = activation
141 |
142 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul
143 | self.lr_mul = lr_mul
144 |
145 | def forward(self, input):
146 | if self.activation:
147 | out = F.linear(input, self.weight * self.scale)
148 | out = fused_leaky_relu(out, self.bias * self.lr_mul)
149 |
150 | else:
151 | out = F.linear(
152 | input, self.weight * self.scale, bias=self.bias * self.lr_mul
153 | )
154 |
155 | return out
156 |
157 | def __repr__(self):
158 | return (
159 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
160 | )
161 |
162 |
163 | class ScaledLeakyReLU(nn.Module):
164 | def __init__(self, negative_slope=0.2):
165 | super().__init__()
166 |
167 | self.negative_slope = negative_slope
168 |
169 | def forward(self, input):
170 | out = F.leaky_relu(input, negative_slope=self.negative_slope)
171 | return out * math.sqrt(2)
172 |
173 |
174 | class ModulatedConv2d(nn.Module):
175 | def __init__(
176 | self,
177 | in_channel,
178 | out_channel,
179 | kernel_size,
180 | style_dim,
181 | demodulate=True,
182 | upsample=False,
183 | downsample=False,
184 | blur_kernel=[1, 3, 3, 1],
185 | ):
186 | super().__init__()
187 |
188 | self.eps = 1e-8
189 | self.kernel_size = kernel_size
190 | self.in_channel = in_channel
191 | self.out_channel = out_channel
192 | self.upsample = upsample
193 | self.downsample = downsample
194 |
195 | if upsample:
196 | factor = 2
197 | p = (len(blur_kernel) - factor) - (kernel_size - 1)
198 | pad0 = (p + 1) // 2 + factor - 1
199 | pad1 = p // 2 + 1
200 |
201 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
202 |
203 | if downsample:
204 | factor = 2
205 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
206 | pad0 = (p + 1) // 2
207 | pad1 = p // 2
208 |
209 | self.blur = Blur(blur_kernel, pad=(pad0, pad1))
210 |
211 | fan_in = in_channel * kernel_size ** 2
212 | self.scale = 1 / math.sqrt(fan_in)
213 | self.padding = kernel_size // 2
214 |
215 | self.weight = nn.Parameter(
216 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
217 | )
218 |
219 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
220 |
221 | self.demodulate = demodulate
222 |
223 | def __repr__(self):
224 | return (
225 | f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
226 | f'upsample={self.upsample}, downsample={self.downsample})'
227 | )
228 |
229 | def forward(self, input, style):
230 | batch, in_channel, height, width = input.shape
231 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
232 | weight = self.scale * self.weight * style
233 | if self.demodulate:
234 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
235 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
236 |
237 | weight = weight.view(
238 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
239 | )
240 |
241 | if self.upsample:
242 | input = input.view(1, batch * in_channel, height, width)
243 | weight = weight.view(
244 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
245 | )
246 | weight = weight.transpose(1, 2).reshape(
247 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
248 | )
249 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
250 | _, _, height, width = out.shape
251 | out = out.view(batch, self.out_channel, height, width)
252 | out = self.blur(out)
253 |
254 | elif self.downsample:
255 | input = self.blur(input)
256 | _, _, height, width = input.shape
257 | input = input.view(1, batch * in_channel, height, width)
258 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
259 | _, _, height, width = out.shape
260 | out = out.view(batch, self.out_channel, height, width)
261 |
262 | else:
263 | input = input.view(1, batch * in_channel, height, width)
264 | out = F.conv2d(input, weight, padding=self.padding, groups=batch)
265 | _, _, height, width = out.shape
266 | out = out.view(batch, self.out_channel, height, width)
267 |
268 | return out
269 |
270 |
271 | class NoiseInjection(nn.Module):
272 | def __init__(self):
273 | super().__init__()
274 |
275 | self.weight = nn.Parameter(torch.zeros(1))
276 |
277 | def forward(self, image, noise=None):
278 | if noise is None:
279 | batch, _, height, width = image.shape
280 | noise = image.new_empty(batch, 1, height, width).normal_()
281 |
282 | return image + self.weight * noise
283 |
284 |
285 | class ConstantInput(nn.Module):
286 | def __init__(self, channel, size=4):
287 | super().__init__()
288 |
289 | self.input = nn.Parameter(torch.randn(1, channel, size, size))
290 |
291 | def forward(self, batch_size):
292 | out = self.input.repeat(batch_size, 1, 1, 1)
293 | return out
294 |
295 |
296 | class StyledConv(nn.Module):
297 | def __init__(
298 | self,
299 | in_channel,
300 | out_channel,
301 | kernel_size,
302 | style_dim,
303 | upsample=False,
304 | blur_kernel=[1, 3, 3, 1],
305 | demodulate=True,
306 | activation=None,
307 | downsample=False,
308 | ):
309 | super().__init__()
310 |
311 | self.conv = ModulatedConv2d(
312 | in_channel,
313 | out_channel,
314 | kernel_size,
315 | style_dim,
316 | upsample=upsample,
317 | blur_kernel=blur_kernel,
318 | demodulate=demodulate,
319 | downsample=downsample,
320 | )
321 |
322 | self.activation = activation
323 | self.noise = NoiseInjection()
324 | if activation == 'sinrelu':
325 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
326 | self.activate = ScaledLeakyReLUSin()
327 | elif activation == 'sin':
328 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
329 | self.activate = SinActivation()
330 | else:
331 | self.activate = FusedLeakyReLU(out_channel)
332 |
333 | def forward(self, input, style, noise=None):
334 | out = self.conv(input, style)
335 | # out = self.noise(out, noise=noise)
336 | if self.activation == 'sinrelu' or self.activation == 'sin':
337 | out = out + self.bias
338 | out = self.activate(out)
339 |
340 | return out
341 |
342 |
343 | class ToRGB(nn.Module):
344 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
345 | super().__init__()
346 |
347 | self.upsample = upsample
348 | if upsample:
349 | self.upsample = Upsample(blur_kernel)
350 |
351 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
352 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
353 |
354 | def forward(self, input, style, skip=None):
355 | out = self.conv(input, style)
356 | out = out + self.bias
357 |
358 | if skip is not None:
359 | if self.upsample:
360 | skip = self.upsample(skip)
361 | out = out + skip
362 | return out
363 |
364 |
365 | class ToMixLogistic(nn.Module):
366 | def __init__(self, in_channel, out_channel, style_dim):
367 | # def __init__(self, in_channel, out_channel, style_dim, upsample=False, blur_kernel=[1, 3, 3, 1]):
368 | super().__init__()
369 | # self.upsample = upsample
370 | # if upsample:
371 | # self.upsample = Upsample(blur_kernel)
372 | self.conv = ModulatedConv2d(in_channel, out_channel, 1, style_dim, demodulate=False)
373 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
374 |
375 | def forward(self, input, style, skip=None):
376 | out = self.conv(input, style)
377 | out = out + self.bias
378 | if skip is not None:
379 | # if self.upsample:
380 | # skip = self.upsample(skip)
381 | out = out + skip
382 | return out
383 |
384 | class ToMixLogisticNoCond(nn.Module):
385 | def __init__(self, in_channel, out_channel):
386 | # def __init__(self, in_channel, out_channel, style_dim, upsample=False, blur_kernel=[1, 3, 3, 1]):
387 | super().__init__()
388 | # self.upsample = upsample
389 | # if upsample:
390 | # self.upsample = Upsample(blur_kernel)
391 | self.conv = nn.Conv2d(in_channel, out_channel, 1)
392 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
393 |
394 | def forward(self, input, skip=None):
395 | out = self.conv(input)
396 | out = out + self.bias
397 | if skip is not None:
398 | # if self.upsample:
399 | # skip = self.upsample(skip)
400 | out = out + skip
401 | return out
402 |
403 |
404 | class EqualConvTranspose2d(nn.Module):
405 | def __init__(
406 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
407 | ):
408 | super().__init__()
409 |
410 | self.weight = nn.Parameter(
411 | torch.randn(in_channel, out_channel, kernel_size, kernel_size)
412 | )
413 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
414 |
415 | self.stride = stride
416 | self.padding = padding
417 |
418 | if bias:
419 | self.bias = nn.Parameter(torch.zeros(out_channel))
420 |
421 | else:
422 | self.bias = None
423 |
424 | def forward(self, input):
425 | out = F.conv_transpose2d(
426 | input,
427 | self.weight * self.scale,
428 | bias=self.bias,
429 | stride=self.stride,
430 | padding=self.padding,
431 | )
432 |
433 | return out
434 |
435 | def __repr__(self):
436 | return (
437 | f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]},"
438 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
439 | )
440 |
441 |
442 | class ConvLayer(nn.Sequential):
443 | def __init__(
444 | self,
445 | in_channel,
446 | out_channel,
447 | kernel_size,
448 | downsample=False,
449 | blur_kernel=[1, 3, 3, 1],
450 | bias=True,
451 | activate=True,
452 | upsample=False,
453 | padding="zero",
454 | ):
455 | layers = []
456 |
457 | self.padding = 0
458 | stride = 1
459 |
460 | if downsample:
461 | factor = 2
462 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
463 | pad0 = (p + 1) // 2
464 | pad1 = p // 2
465 |
466 | layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
467 |
468 | stride = 2
469 |
470 | if upsample:
471 | layers.append(
472 | EqualConvTranspose2d(
473 | in_channel,
474 | out_channel,
475 | kernel_size,
476 | padding=0,
477 | stride=2,
478 | bias=bias and not activate,
479 | )
480 | )
481 |
482 | factor = 2
483 | p = (len(blur_kernel) - factor) - (kernel_size - 1)
484 | pad0 = (p + 1) // 2 + factor - 1
485 | pad1 = p // 2 + 1
486 |
487 | layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
488 |
489 | else:
490 | if not downsample:
491 | if padding == "zero":
492 | self.padding = (kernel_size - 1) // 2
493 |
494 | elif padding == "reflect":
495 | padding = (kernel_size - 1) // 2
496 |
497 | if padding > 0:
498 | layers.append(nn.ReflectionPad2d(padding))
499 |
500 | self.padding = 0
501 |
502 | elif padding != "valid":
503 | raise ValueError('Padding should be "zero", "reflect", or "valid"')
504 |
505 | layers.append(
506 | EqualConv2d(
507 | in_channel,
508 | out_channel,
509 | kernel_size,
510 | padding=self.padding,
511 | stride=stride,
512 | bias=bias and not activate,
513 | )
514 | )
515 |
516 | if activate:
517 | if bias:
518 | layers.append(FusedLeakyReLU(out_channel))
519 |
520 | else:
521 | layers.append(ScaledLeakyReLU(0.2))
522 |
523 | super().__init__(*layers)
524 |
525 |
526 | class ResBlock(nn.Module):
527 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], kernel_size=3, downsample=True):
528 | super().__init__()
529 |
530 | self.conv1 = ConvLayer(in_channel, in_channel, kernel_size)
531 | self.conv2 = ConvLayer(in_channel, out_channel, kernel_size, downsample=downsample)
532 |
533 | self.skip = ConvLayer(
534 | in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False
535 | )
536 |
537 | def forward(self, input):
538 | out = self.conv1(input)
539 | out = self.conv2(out)
540 |
541 | skip = self.skip(input)
542 | out = (out + skip) / math.sqrt(2)
543 |
544 | return out
545 |
546 |
547 | class ConLinear(nn.Module):
548 | def __init__(self, ch_in, ch_out, is_first=False, bias=True):
549 | super(ConLinear, self).__init__()
550 | self.conv = nn.Conv2d(ch_in, ch_out, kernel_size=1, padding=0, bias=bias)
551 | if is_first:
552 | nn.init.uniform_(self.conv.weight, -np.sqrt(9 / ch_in), np.sqrt(9 / ch_in))
553 | else:
554 | nn.init.uniform_(self.conv.weight, -np.sqrt(3 / ch_in), np.sqrt(3 / ch_in))
555 |
556 | def forward(self, x):
557 | return self.conv(x)
558 |
559 |
560 | class SinActivation(nn.Module):
561 | def __init__(self,):
562 | super(SinActivation, self).__init__()
563 |
564 | def forward(self, x):
565 | return torch.sin(x)
566 |
567 | class SinActivationBias(nn.Module):
568 | def __init__(self, out_channel):
569 | super(SinActivationBias, self).__init__()
570 | self.bias = nn.Parameter(torch.zeros(1, 1, out_channel))
571 |
572 | def forward(self, x):
573 | return torch.sin(x + self.bias)
574 |
575 |
576 | class LFF(nn.Module):
577 | def __init__(self, hidden_size, ):
578 | super(LFF, self).__init__()
579 | self.ffm = ConLinear(1, hidden_size, is_first=True)
580 | self.activation = SinActivation()
581 |
582 | def forward(self, x):
583 | x = self.ffm(x)
584 | x = self.activation(x)
585 | return x
586 |
587 |
588 | class ScaledLeakyReLUSin(nn.Module):
589 | def __init__(self, negative_slope=0.2):
590 | super().__init__()
591 |
592 | self.negative_slope = negative_slope
593 |
594 | def forward(self, input):
595 | out_lr = F.leaky_relu(input[:, ::2], negative_slope=self.negative_slope)
596 | out_sin = torch.sin(input[:, 1::2])
597 | out = torch.cat([out_lr, out_sin], 1)
598 | return out * math.sqrt(2)
599 |
600 |
601 | class StyledResBlock(nn.Module):
602 | def __init__(self, in_channel, out_channel, kernel_size, style_dim, blur_kernel=[1, 3, 3, 1], demodulate=True,
603 | activation=None, upsample=False, downsample=False):
604 | super().__init__()
605 |
606 | self.conv1 = StyledConv(in_channel, out_channel, kernel_size, style_dim,
607 | demodulate=demodulate, activation=activation)
608 | self.conv2 = StyledConv(out_channel, out_channel, kernel_size, style_dim,
609 | demodulate=demodulate, activation=activation,
610 | upsample=upsample, downsample=downsample)
611 |
612 | if downsample or in_channel != out_channel or upsample:
613 | self.skip = ConvLayer(
614 | in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False, upsample=upsample,
615 | )
616 | else:
617 | self.skip = None
618 |
619 | def forward(self, input, latent):
620 | out = self.conv1(input, latent)
621 | out = self.conv2(out, latent)
622 |
623 | if self.skip is not None:
624 | skip = self.skip(input)
625 | else:
626 | skip = input
627 |
628 | out = (out + skip) / math.sqrt(2)
629 |
630 | return out
631 |
--------------------------------------------------------------------------------
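These blocks come from the StyleGAN2 generator family; in this repository `StyledConv`/`ModulatedConv2d` are reused as latent-modulated MLP layers by viewing a token sequence as an `n x 1` "image" (see `hier_model.py`). A small illustrative shape check under that reading, assuming the CUDA ops are buildable (importing `models.blocks` JIT-compiles them even though this particular path stays on the CPU):

```python
import torch
import einops
from models.blocks import ModulatedConv2d

# Illustrative shapes: batch 2, 100 tokens of width 512, latent/style width 512.
conv = ModulatedConv2d(in_channel=512, out_channel=512, kernel_size=1, style_dim=512)
y = torch.randn(2, 100, 512)
latent = torch.randn(2, 512)
y2d = einops.rearrange(y, 'b n c -> b c n 1')     # token sequence viewed as a (c, n, 1) image
out = conv(y2d, latent)                           # 1x1 modulated conv acts as a per-token MLP
out = einops.rearrange(out, 'b c n 1 -> b n c')   # back to (b, n, c); here (2, 100, 512)
```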
/exp3D/models/hier_model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.distributions import kl_divergence, Normal
7 | import einops
8 |
9 | from models import register
10 | from utils import rendering
11 |
12 | from .blocks import StyledConv, ModulatedConv2d
13 |
14 | def build_mlp(dim_in, dim_hid, dim_out, depth):
15 | modules = [nn.Linear(dim_in, dim_hid), nn.ReLU(True)]
16 | for _ in range(depth-2):
17 | modules.append(nn.Linear(dim_hid, dim_hid))
18 | modules.append(nn.ReLU(True))
19 | modules.append(nn.Linear(dim_hid, dim_out))
20 | return nn.Sequential(*modules)
21 |
22 |
23 | @register('hier_model')
24 | class Hierarchical_Model(nn.Module):
25 | def __init__(self, depth, dim_y, dim_hid, dim_lat):
26 | super().__init__()
27 | self.layers = nn.ModuleList()
28 | for _ in range(depth):
29 | self.layers.append(ModBlock(dim_y=dim_y, dim_hid=dim_hid, dim_lat=dim_lat))
30 |
31 | def forward(self, y, x_tgt, y_tgt, rays_o, z_vals, training=True):
32 | kls = 0
33 | for layer in self.layers:
34 | y, kld = layer(y, x_tgt, y_tgt, rays_o, z_vals, training=training)
35 | kls += kld
36 | return y, kls
37 |
38 |
39 | class ModBlock(nn.Module):
40 | def __init__(self, dim_y, dim_hid, dim_lat):
41 | super().__init__()
42 | self.input_mlp = build_mlp(dim_y, dim_hid, 4, depth=3)
43 |
44 | self.merge_p = build_mlp(3, dim_hid, dim_hid, depth=2)
45 | self.merge_q = build_mlp(6, dim_hid, dim_hid, depth=2)
46 | self.latent_encoder_p = build_mlp(dim_hid, dim_hid, dim_lat * 2, depth=2)
47 | self.latent_encoder_q = build_mlp(dim_hid, dim_hid, dim_lat * 2, depth=2)
48 | self.latent_decoder = build_mlp(dim_lat, dim_hid, dim_hid, depth=2)
49 |
50 | self.unmod_mlp = build_mlp(dim_y * 2, dim_hid, dim_y, depth=2)
51 | self.mod_conv1 = StyledConv(dim_y, dim_hid, 1, dim_hid, demodulate=True)# with built-in activation function
52 | self.mod_conv2 = ModulatedConv2d(dim_hid, dim_y, 1, dim_hid, demodulate=True) # without built-in activation function
53 | self.mod_conv2.weight.data *= np.sqrt(1 / 6)
54 |
55 | def forward(self, y, x_tgt, y_tgt, rays_o, z_vals, training=True):
56 | # 3D points -> 2D image pixels
57 | y_rgb = rendering(self.input_mlp(y), rays_o=rays_o, z_vals=z_vals)
58 |
59 | # variational inference in Neural Process (with an average pooling across all points)
60 | # if not training, y_tgt is not used
61 | z, kld = self.forward_latent(y_rgb, y_tgt, training=training)
62 |
63 | # unmodulated and modulated layers leveraging latent variables
64 | y = self.forward_mlps(y, x_tgt, latent=z)
65 | return y, kld
66 |
67 | def forward_latent(self, y_rgb, y_tgt, training):
68 | z_prior = self.merge_p(y_rgb).mean(dim=-2, keepdim=True)
69 | dist_prior = self.normal_distribution(self.latent_encoder_p(z_prior))
70 |
71 | if training:
72 | z_posterior = self.merge_q(torch.cat([y_rgb, y_tgt], dim=-1)).mean(dim=-2, keepdim=True)
73 | dist_posterior = self.normal_distribution(self.latent_encoder_q(z_posterior))
74 | z = dist_posterior.rsample()
75 | kld = kl_divergence(dist_posterior, dist_prior).sum(-1)
76 | else:
77 | z = dist_prior.rsample()
78 | kld = torch.zeros_like(z)
79 |
80 | z = self.latent_decoder(z)
81 | return z, kld
82 |
83 | def forward_mlps(self, y, x_tgt, latent):
84 | b, num_tgt, c = y.shape
85 |
86 | # Two unmodulated MLPs (residual structure)
87 | # We enhance feature y by combining x_tgt, which can actually be omitted.
88 | y = y + self.unmod_mlp(torch.cat([x_tgt, y], dim=-1))
89 |
90 | # Two modulated MLPs (residual structure)
91 | # Modulated MLP is implemented in the 2D form.
92 | y_res = einops.rearrange(y, 'b n c -> b c n 1').contiguous()
93 | y_res = self.mod_conv1(y_res, latent)
94 | y_res = self.mod_conv2(y_res, latent)
95 | y_res = einops.rearrange(y_res, 'b c n 1 -> b n c').contiguous()
96 |
97 | y = y + y_res
98 | return y
99 |
100 | @staticmethod
101 | def normal_distribution(input):
102 | mean, var = input.chunk(2, dim=-1)
103 | var = 0.1 + 0.9 * torch.sigmoid(var)
104 | dist = Normal(mean, var)
105 | return dist
106 |
107 |
--------------------------------------------------------------------------------
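Each `ModBlock` above infers one latent per scene: features are rendered to RGB, pooled over all target points, and mapped to a diagonal Gaussian whose standard deviation is bounded to `(0.1, 1.0)` by `0.1 + 0.9 * sigmoid(.)`. A minimal standalone sketch of that prior/posterior step (shapes illustrative, rendering omitted):

```python
import torch
from torch.distributions import Normal, kl_divergence

def normal_from(params):                              # same parameterization as ModBlock
    mean, var = params.chunk(2, dim=-1)
    return Normal(mean, 0.1 + 0.9 * torch.sigmoid(var))

b, dim_lat = 2, 64
prior_params     = torch.randn(b, 1, 2 * dim_lat)     # from pooled prior features
posterior_params = torch.randn(b, 1, 2 * dim_lat)     # from pooled prior + target features
p, q = normal_from(prior_params), normal_from(posterior_params)
z = q.rsample()                                        # reparameterized sample, (b, 1, dim_lat)
kld = kl_divergence(q, p).sum(-1)                      # summed over the latent dimension
```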
/exp3D/models/models.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 |
4 | models = {}
5 |
6 |
7 | def register(name):
8 | def decorator(cls):
9 | models[name] = cls
10 | return cls
11 | return decorator
12 |
13 |
14 | def make(model_spec, args=None, load_sd=False):
15 | if args is not None:
16 | model_args = copy.deepcopy(model_spec['args'])
17 | model_args.update(args)
18 | else:
19 | model_args = model_spec['args']
20 | model = models[model_spec['name']](**model_args)
21 | if load_sd:
22 | model.load_state_dict(model_spec['sd'])
23 | return model
24 |
--------------------------------------------------------------------------------
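As with the dataset registry, `make` builds a model from a spec dict; the nested specs in the YAML configs (tokenizer, self_attender, cross_attender, hierarchical_model) are presumably resolved inside `versatile_np`, and the `load_sd` path suggests saved checkpoints carry a spec together with a state dict under `'sd'`. A hedged sketch of loading a checkpoint under that assumption (the `'model'` key and file path are guesses, not verified against the trainer code):

```python
import torch
import models

# Assumption: a checkpoint such as save/nvs_shapenet_cars/epoch-last.pth stores a dict
# with a 'model' entry of the form {'name': ..., 'args': ..., 'sd': state_dict}.
ckpt = torch.load('save/nvs_shapenet_cars/epoch-last.pth', map_location='cpu')
model = models.make(ckpt['model'], load_sd=True)
model.eval()
```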
/exp3D/models/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
--------------------------------------------------------------------------------
/exp3D/models/op/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/op/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/exp3D/models/op/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/op/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/exp3D/models/op/__pycache__/fused_act.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/op/__pycache__/fused_act.cpython-37.pyc
--------------------------------------------------------------------------------
/exp3D/models/op/__pycache__/fused_act.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/op/__pycache__/fused_act.cpython-38.pyc
--------------------------------------------------------------------------------
/exp3D/models/op/__pycache__/upfirdn2d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/op/__pycache__/upfirdn2d.cpython-37.pyc
--------------------------------------------------------------------------------
/exp3D/models/op/__pycache__/upfirdn2d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/op/__pycache__/upfirdn2d.cpython-38.pyc
--------------------------------------------------------------------------------
/exp3D/models/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Function
6 | from torch.utils.cpp_extension import load
7 |
8 |
9 | module_path = os.path.dirname(__file__)
10 | fused = load(
11 | 'fused',
12 | sources=[
13 | os.path.join(module_path, 'fused_bias_act.cpp'),
14 | os.path.join(module_path, 'fused_bias_act_kernel.cu'),
15 | ],
16 | )
17 |
18 |
19 | class FusedLeakyReLUFunctionBackward(Function):
20 | @staticmethod
21 | def forward(ctx, grad_output, out, negative_slope, scale):
22 | ctx.save_for_backward(out)
23 | ctx.negative_slope = negative_slope
24 | ctx.scale = scale
25 |
26 | empty = grad_output.new_empty(0)
27 |
28 | grad_input = fused.fused_bias_act(
29 | grad_output, empty, out, 3, 1, negative_slope, scale
30 | )
31 |
32 | dim = [0]
33 |
34 | if grad_input.ndim > 2:
35 | dim += list(range(2, grad_input.ndim))
36 |
37 | grad_bias = grad_input.sum(dim).detach()
38 |
39 | return grad_input, grad_bias
40 |
41 | @staticmethod
42 | def backward(ctx, gradgrad_input, gradgrad_bias):
43 | out, = ctx.saved_tensors
44 | gradgrad_out = fused.fused_bias_act(
45 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
46 | )
47 |
48 | return gradgrad_out, None, None, None
49 |
50 |
51 | class FusedLeakyReLUFunction(Function):
52 | @staticmethod
53 | def forward(ctx, input, bias, negative_slope, scale):
54 | empty = input.new_empty(0)
55 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
56 | ctx.save_for_backward(out)
57 | ctx.negative_slope = negative_slope
58 | ctx.scale = scale
59 |
60 | return out
61 |
62 | @staticmethod
63 | def backward(ctx, grad_output):
64 | out, = ctx.saved_tensors
65 |
66 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
67 | grad_output, out, ctx.negative_slope, ctx.scale
68 | )
69 |
70 | return grad_input, grad_bias, None, None
71 |
72 |
73 | class FusedLeakyReLU(nn.Module):
74 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
75 | super().__init__()
76 |
77 | self.bias = nn.Parameter(torch.zeros(channel))
78 | self.negative_slope = negative_slope
79 | self.scale = scale
80 |
81 | def forward(self, input):
82 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
83 |
84 |
85 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
86 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
87 |
--------------------------------------------------------------------------------
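`fused_leaky_relu` above is the StyleGAN2 fused bias-add + leaky ReLU + gain op (`act=3`, gain `sqrt(2)` by default), JIT-compiled from the CUDA sources below. For reference, a non-fused PyTorch equivalent of the forward computation is roughly the following sketch:

```python
import torch
import torch.nn.functional as F

def fused_leaky_relu_reference(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # Broadcast the per-channel bias over every dim except the channel dim (dim 1),
    # then apply leaky ReLU and the constant gain -- the same math the CUDA op fuses.
    shape = [1, -1] + [1] * (x.ndim - 2)
    return F.leaky_relu(x + bias.view(*shape), negative_slope) * scale
```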
/exp3D/models/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5 | int act, int grad, float alpha, float scale);
6 |
7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10 |
11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12 | int act, int grad, float alpha, float scale) {
13 | CHECK_CUDA(input);
14 | CHECK_CUDA(bias);
15 |
16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17 | }
18 |
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21 | }
--------------------------------------------------------------------------------
/exp3D/models/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 | y.data_ptr<scalar_t>(),
82 | x.data_ptr<scalar_t>(),
83 | b.data_ptr<scalar_t>(),
84 | ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
--------------------------------------------------------------------------------
/exp3D/models/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 | int up_x, int up_y, int down_x, int down_y,
6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 |
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 | int up_x, int up_y, int down_x, int down_y,
14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 | CHECK_CUDA(input);
16 | CHECK_CUDA(kernel);
17 |
18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
--------------------------------------------------------------------------------
/exp3D/models/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch.autograd import Function
4 | from torch.nn import functional as F
5 | from torch.autograd import Function
6 | from torch.utils.cpp_extension import load
7 |
9 |
10 | upfirdn2d_op = load(
11 | 'upfirdn2d',
12 | sources=[
13 | os.path.join(module_path, 'upfirdn2d.cpp'),
14 | os.path.join(module_path, 'upfirdn2d_kernel.cu'),
15 | ],
16 | )
17 |
18 | class UpFirDn2dBackward(Function):
19 | @staticmethod
20 | def forward(
21 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
22 | ):
23 |
24 | up_x, up_y = up
25 | down_x, down_y = down
26 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
27 |
28 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
29 |
30 | grad_input = upfirdn2d_op.upfirdn2d(
31 | grad_output,
32 | grad_kernel,
33 | down_x,
34 | down_y,
35 | up_x,
36 | up_y,
37 | g_pad_x0,
38 | g_pad_x1,
39 | g_pad_y0,
40 | g_pad_y1,
41 | )
42 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
43 |
44 | ctx.save_for_backward(kernel)
45 |
46 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
47 |
48 | ctx.up_x = up_x
49 | ctx.up_y = up_y
50 | ctx.down_x = down_x
51 | ctx.down_y = down_y
52 | ctx.pad_x0 = pad_x0
53 | ctx.pad_x1 = pad_x1
54 | ctx.pad_y0 = pad_y0
55 | ctx.pad_y1 = pad_y1
56 | ctx.in_size = in_size
57 | ctx.out_size = out_size
58 |
59 | return grad_input
60 |
61 | @staticmethod
62 | def backward(ctx, gradgrad_input):
63 | kernel, = ctx.saved_tensors
64 |
65 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
66 |
67 | gradgrad_out = upfirdn2d_op.upfirdn2d(
68 | gradgrad_input,
69 | kernel,
70 | ctx.up_x,
71 | ctx.up_y,
72 | ctx.down_x,
73 | ctx.down_y,
74 | ctx.pad_x0,
75 | ctx.pad_x1,
76 | ctx.pad_y0,
77 | ctx.pad_y1,
78 | )
79 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
80 | gradgrad_out = gradgrad_out.view(
81 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
82 | )
83 |
84 | return gradgrad_out, None, None, None, None, None, None, None, None
85 |
86 |
87 | class UpFirDn2d(Function):
88 | @staticmethod
89 | def forward(ctx, input, kernel, up, down, pad):
90 | up_x, up_y = up
91 | down_x, down_y = down
92 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
93 |
94 | kernel_h, kernel_w = kernel.shape
95 | batch, channel, in_h, in_w = input.shape
96 | ctx.in_size = input.shape
97 |
98 | input = input.reshape(-1, in_h, in_w, 1)
99 |
100 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
101 |
102 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
103 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
104 | ctx.out_size = (out_h, out_w)
105 |
106 | ctx.up = (up_x, up_y)
107 | ctx.down = (down_x, down_y)
108 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
109 |
110 | g_pad_x0 = kernel_w - pad_x0 - 1
111 | g_pad_y0 = kernel_h - pad_y0 - 1
112 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
113 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
114 |
115 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
116 |
117 | out = upfirdn2d_op.upfirdn2d(
118 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
119 | )
120 | # out = out.view(major, out_h, out_w, minor)
121 | out = out.view(-1, channel, out_h, out_w)
122 |
123 | return out
124 |
125 | @staticmethod
126 | def backward(ctx, grad_output):
127 | kernel, grad_kernel = ctx.saved_tensors
128 |
129 | grad_input = UpFirDn2dBackward.apply(
130 | grad_output,
131 | kernel,
132 | grad_kernel,
133 | ctx.up,
134 | ctx.down,
135 | ctx.pad,
136 | ctx.g_pad,
137 | ctx.in_size,
138 | ctx.out_size,
139 | )
140 |
141 | return grad_input, None, None, None, None
142 |
143 |
144 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
145 | out = UpFirDn2d.apply(
146 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
147 | )
148 |
149 | return out
150 |
151 |
152 | def upfirdn2d_native(
153 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
154 | ):
155 | _, in_h, in_w, minor = input.shape
156 | kernel_h, kernel_w = kernel.shape
157 |
158 | out = input.view(-1, in_h, 1, in_w, 1, minor)
159 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
160 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
161 |
162 | out = F.pad(
163 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
164 | )
165 | out = out[
166 | :,
167 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
168 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
169 | :,
170 | ]
171 |
172 | out = out.permute(0, 3, 1, 2)
173 | out = out.reshape(
174 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
175 | )
176 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
177 | out = F.conv2d(out, w)
178 | out = out.reshape(
179 | -1,
180 | minor,
181 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
182 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
183 | )
184 | out = out.permute(0, 2, 3, 1)
185 |
186 | return out[:, ::down_y, ::down_x, :]
187 |
188 |
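A minimal usage sketch (not part of the repository) for the filtering above: it exercises the pure-PyTorch `upfirdn2d_native` reference path defined in this file, so no CUDA extension is needed; the kernel and padding values are illustrative.

```python
import torch

# Assumes upfirdn2d_native from this file is in scope.
# Separable binomial blur kernel, normalized to sum to 1 (illustrative choice).
k = torch.tensor([1., 3., 3., 1.])
k = k[None, :] * k[:, None]
k = k / k.sum()

# upfirdn2d_native expects a (batch*channel, H, W, minor) layout.
x = torch.randn(1, 16, 16, 1)

# Upsample by 2 with the 4-tap kernel; pad=(2, 1) makes the output exactly 2x
# the input: out_h = (16*2 + 2 + 1 - 4) // 1 + 1 = 32.
out = upfirdn2d_native(x, k * 4, up_x=2, up_y=2, down_x=1, down_y=1,
                       pad_x0=2, pad_x1=1, pad_y0=2, pad_y1=1)
print(out.shape)  # torch.Size([1, 32, 32, 1])
```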
--------------------------------------------------------------------------------
/exp3D/models/op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
19 | int c = a / b;
20 |
21 | if (c * b > a) {
22 | c--;
23 | }
24 |
25 | return c;
26 | }
27 |
28 |
29 | struct UpFirDn2DKernelParams {
30 | int up_x;
31 | int up_y;
32 | int down_x;
33 | int down_y;
34 | int pad_x0;
35 | int pad_x1;
36 | int pad_y0;
37 | int pad_y1;
38 |
39 | int major_dim;
40 | int in_h;
41 | int in_w;
42 | int minor_dim;
43 | int kernel_h;
44 | int kernel_w;
45 | int out_h;
46 | int out_w;
47 | int loop_major;
48 | int loop_x;
49 | };
50 |
51 |
52 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
56 |
57 | __shared__ volatile float sk[kernel_h][kernel_w];
58 | __shared__ volatile float sx[tile_in_h][tile_in_w];
59 |
60 | int minor_idx = blockIdx.x;
61 | int tile_out_y = minor_idx / p.minor_dim;
62 | minor_idx -= tile_out_y * p.minor_dim;
63 | tile_out_y *= tile_out_h;
64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
65 | int major_idx_base = blockIdx.z * p.loop_major;
66 |
67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
68 | return;
69 | }
70 |
71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
72 | int ky = tap_idx / kernel_w;
73 | int kx = tap_idx - ky * kernel_w;
74 | scalar_t v = 0.0;
75 |
76 | if (kx < p.kernel_w & ky < p.kernel_h) {
77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
78 | }
79 |
80 | sk[ky][kx] = v;
81 | }
82 |
83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
87 | int tile_in_x = floor_div(tile_mid_x, up_x);
88 | int tile_in_y = floor_div(tile_mid_y, up_y);
89 |
90 | __syncthreads();
91 |
92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
93 | int rel_in_y = in_idx / tile_in_w;
94 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
95 | int in_x = rel_in_x + tile_in_x;
96 | int in_y = rel_in_y + tile_in_y;
97 |
98 | scalar_t v = 0.0;
99 |
100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
102 | }
103 |
104 | sx[rel_in_y][rel_in_x] = v;
105 | }
106 |
107 | __syncthreads();
108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
109 | int rel_out_y = out_idx / tile_out_w;
110 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
111 | int out_x = rel_out_x + tile_out_x;
112 | int out_y = rel_out_y + tile_out_y;
113 |
114 | int mid_x = tile_mid_x + rel_out_x * down_x;
115 | int mid_y = tile_mid_y + rel_out_y * down_y;
116 | int in_x = floor_div(mid_x, up_x);
117 | int in_y = floor_div(mid_y, up_y);
118 | int rel_in_x = in_x - tile_in_x;
119 | int rel_in_y = in_y - tile_in_y;
120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
122 |
123 | scalar_t v = 0.0;
124 |
125 | #pragma unroll
126 | for (int y = 0; y < kernel_h / up_y; y++)
127 | #pragma unroll
128 | for (int x = 0; x < kernel_w / up_x; x++)
129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
130 |
131 | if (out_x < p.out_w & out_y < p.out_h) {
132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
133 | }
134 | }
135 | }
136 | }
137 | }
138 |
139 |
140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
141 | int up_x, int up_y, int down_x, int down_y,
142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
143 | int curDevice = -1;
144 | cudaGetDevice(&curDevice);
145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
146 |
147 | UpFirDn2DKernelParams p;
148 |
149 | auto x = input.contiguous();
150 | auto k = kernel.contiguous();
151 |
152 | p.major_dim = x.size(0);
153 | p.in_h = x.size(1);
154 | p.in_w = x.size(2);
155 | p.minor_dim = x.size(3);
156 | p.kernel_h = k.size(0);
157 | p.kernel_w = k.size(1);
158 | p.up_x = up_x;
159 | p.up_y = up_y;
160 | p.down_x = down_x;
161 | p.down_y = down_y;
162 | p.pad_x0 = pad_x0;
163 | p.pad_x1 = pad_x1;
164 | p.pad_y0 = pad_y0;
165 | p.pad_y1 = pad_y1;
166 |
167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
169 |
170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
171 |
172 | int mode = -1;
173 |
174 | int tile_out_h;
175 | int tile_out_w;
176 |
177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
178 | mode = 1;
179 | tile_out_h = 16;
180 | tile_out_w = 64;
181 | }
182 |
183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
184 | mode = 2;
185 | tile_out_h = 16;
186 | tile_out_w = 64;
187 | }
188 |
189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
190 | mode = 3;
191 | tile_out_h = 16;
192 | tile_out_w = 64;
193 | }
194 |
195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
196 | mode = 4;
197 | tile_out_h = 16;
198 | tile_out_w = 64;
199 | }
200 |
201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
202 | mode = 5;
203 | tile_out_h = 8;
204 | tile_out_w = 32;
205 | }
206 |
207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
208 | mode = 6;
209 | tile_out_h = 8;
210 | tile_out_w = 32;
211 | }
212 |
213 | dim3 block_size;
214 | dim3 grid_size;
215 |
216 | if (tile_out_h > 0 && tile_out_w) {
217 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
218 | p.loop_x = 1;
219 | block_size = dim3(32 * 8, 1, 1);
220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
222 | (p.major_dim - 1) / p.loop_major + 1);
223 | }
224 |
225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
226 | switch (mode) {
227 | case 1:
228 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
229 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
230 | );
231 |
232 | break;
233 |
234 | case 2:
235 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
236 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
237 | );
238 |
239 | break;
240 |
241 | case 3:
242 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
243 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
244 | );
245 |
246 | break;
247 |
248 | case 4:
249 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
250 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
251 | );
252 |
253 | break;
254 |
255 | case 5:
256 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
257 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
258 | );
259 |
260 | break;
261 |
262 | case 6:
263 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32><<<grid_size, block_size, 0, stream>>>(
264 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
265 | );
266 |
267 | break;
268 | }
269 | });
270 |
271 | return out;
272 | }
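A quick consistency note (illustrative sizes, not from the configs): the output-size formula used here, `out_h = (in_h*up + pad0 + pad1 - kernel_h + down) / down` with integer division, agrees with the Python-side formula in `UpFirDn2d.forward`, `(in_h*up + pad0 + pad1 - kernel_h) // down + 1`, whenever the numerator is non-negative.

```python
# Sanity check of the two equivalent output-size formulas (illustrative sizes).
for in_h, up, down, k, p0, p1 in [(16, 2, 1, 4, 2, 1), (32, 1, 2, 4, 1, 1)]:
    cuda_side = (in_h * up + p0 + p1 - k + down) // down      # formula in this .cu file
    python_side = (in_h * up + p0 + p1 - k) // down + 1       # formula in upfirdn2d.py
    assert cuda_side == python_side, (cuda_side, python_side)
```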
--------------------------------------------------------------------------------
/exp3D/models/tokenizers/__init__.py:
--------------------------------------------------------------------------------
1 | from . import nvs_tokenizer
2 |
--------------------------------------------------------------------------------
/exp3D/models/tokenizers/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/tokenizers/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/exp3D/models/tokenizers/__pycache__/imgrec_tokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/tokenizers/__pycache__/imgrec_tokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/exp3D/models/tokenizers/__pycache__/nvs_tokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/exp3D/models/tokenizers/__pycache__/nvs_tokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/exp3D/models/tokenizers/nvs_tokenizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import einops
5 |
6 | from models import register
7 | from utils import poses_to_rays
8 |
9 | @register('nvs_tokenizer')
10 | class NvsTokenizer(nn.Module):
11 |
12 | def __init__(self, input_size, patch_size, dim, padding=0, img_channels=3):
13 | super().__init__()
14 | if isinstance(input_size, int):
15 | input_size = (input_size, input_size)
16 | if isinstance(patch_size, int):
17 | patch_size = (patch_size, patch_size)
18 | if isinstance(padding, int):
19 | padding = (padding, padding)
20 | self.patch_size = patch_size
21 | self.padding = padding
22 | self.prefc = nn.Linear(patch_size[0] * patch_size[1] * (img_channels + 3 + 3), dim)
23 |
24 | def forward(self, data):
25 | imgs = data['support_imgs']
26 | B = imgs.shape[0]
27 | H, W = imgs.shape[-2:]
28 | rays_o, rays_d = poses_to_rays(data['support_poses'], H, W, data['support_focals'])
29 | rays_o = einops.rearrange(rays_o, 'b n h w c -> b n c h w')
30 | rays_d = einops.rearrange(rays_d, 'b n h w c -> b n c h w')
31 |
32 | x = torch.cat([imgs, rays_o, rays_d], dim=2)
33 | x = einops.rearrange(x, 'b n d h w -> (b n) d h w')
34 | p = self.patch_size
35 | x = F.unfold(x, p, stride=p, padding=self.padding)
36 | x = einops.rearrange(x, '(b n) ppd l -> b (n l) ppd', b=B)
37 |
38 | x = self.prefc(x)
39 | return x
40 |
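The tokenizer above concatenates each support image with its per-pixel ray origins and directions (3 + 3 extra channels), cuts the result into non-overlapping patches, and projects every patch to a `dim`-dimensional token. A short shape walk-through with illustrative values (the actual resolution and patch size come from the yaml configs):

```python
# Illustrative shape arithmetic for NvsTokenizer (values are assumptions, not from the configs).
H, W, P, C = 128, 128, 8, 3        # image size, patch size, RGB channels
feat = C + 3 + 3                   # RGB + ray origin + ray direction = 9 channels per pixel

tokens_per_view = (H // P) * (W // P)   # F.unfold with kernel=stride=P -> 256 patches per view
prefc_in = feat * P * P                 # flattened patch length -> 576, prefc's in_features
print(tokens_per_view, prefc_in)        # 256 576
```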
--------------------------------------------------------------------------------
/exp3D/models/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import einops
5 |
6 | from models import register
7 |
8 |
9 | class Attention(nn.Module):
10 |
11 | def __init__(self, dim, n_head, head_dim, dropout=0.):
12 | super().__init__()
13 | self.n_head = n_head
14 | inner_dim = n_head * head_dim
15 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
16 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
17 | self.scale = head_dim ** -0.5
18 | self.to_out = nn.Sequential(
19 | nn.Linear(inner_dim, dim),
20 | nn.Dropout(dropout),
21 | )
22 |
23 | def forward(self, fr, to=None):
24 | if to is None:
25 | to = fr
26 | q = self.to_q(fr)
27 | k, v = self.to_kv(to).chunk(2, dim=-1)
28 | q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.n_head), [q, k, v])
29 |
30 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
31 | attn = F.softmax(dots, dim=-1) # b h n n
32 | out = torch.matmul(attn, v)
33 | out = einops.rearrange(out, 'b h n d -> b n (h d)')
34 | return self.to_out(out)
35 |
36 |
37 | class FeedForward(nn.Module):
38 |
39 | def __init__(self, dim, ff_dim, dropout=0.):
40 | super().__init__()
41 | self.net = nn.Sequential(
42 | nn.Linear(dim, ff_dim),
43 | nn.GELU(),
44 | nn.Dropout(dropout),
45 | nn.Linear(ff_dim, dim),
46 | nn.Dropout(dropout),
47 | )
48 |
49 | def forward(self, x):
50 | return self.net(x)
51 |
52 |
53 | class FeedForwardReLU(nn.Module):
54 |
55 | def __init__(self, dim, ff_dim, dropout=0.):
56 | super().__init__()
57 | self.net = nn.Sequential(
58 | nn.Linear(dim, ff_dim),
59 | nn.ReLU(),
60 | nn.Dropout(dropout),
61 | nn.Linear(ff_dim, dim),
62 | nn.Dropout(dropout),
63 | )
64 |
65 | def forward(self, x):
66 | return self.net(x)
67 |
68 | class PreNorm(nn.Module):
69 |
70 | def __init__(self, dim, fn):
71 | super().__init__()
72 | self.norm = nn.LayerNorm(dim)
73 | self.fn = fn
74 |
75 | def forward(self, x):
76 | return self.fn(self.norm(x))
77 |
78 |
79 | @register('self_attender')
80 | class Self_Attender(nn.Module):
81 |
82 | def __init__(self, dim, depth, n_head, head_dim, ff_dim, dropout=0.0):
83 | super().__init__()
84 | self.layers = nn.ModuleList()
85 | for _ in range(depth):
86 | self.layers.append(nn.ModuleList([
87 | PreNorm(dim, Attention(dim, n_head, head_dim, dropout=dropout)),
88 | PreNorm(dim, FeedForward(dim, ff_dim, dropout=dropout)),
89 | ]))
90 |
91 | def forward(self, x):
92 | for norm_attn, norm_ff in self.layers:
93 | x = x + norm_attn(x)
94 | x = x + norm_ff(x)
95 | return x
96 |
97 |
98 | @register('cross_attender')
99 | class Cross_Attender(nn.Module):
100 |
101 | def __init__(self, dim, depth, n_head, head_dim, ff_dim, dropout=0.0):
102 | super().__init__()
103 | self.layers = nn.ModuleList()
104 | for _ in range(depth):
105 | self.layers.append(nn.ModuleList([
106 | # Attention(dim, n_head, head_dim, dropout=dropout),
107 | nn.MultiheadAttention(dim, n_head, dropout=dropout),
108 | nn.LayerNorm(dim),
109 | FeedForwardReLU(dim, ff_dim, dropout=dropout),
110 | nn.LayerNorm(dim)
111 | ]))
112 | self.mlp_enhance = build_mlp(dim, dim, dim, 2)
113 |
114 | def with_pos_embed(self, tensor, pos=None):
115 | return tensor if pos is None else tensor + pos
116 |
117 | def forward(self, x_tgt, context):
118 | x = 0
119 | for attender, norm1, ff, norm2 in self.layers:
120 | x_attn = attender(query=self.with_pos_embed(x, x_tgt).transpose(0,1),
121 | key=context.transpose(0,1),
122 | value=context.transpose(0,1))[0]
123 | x = norm1(x + x_attn.transpose(0,1))
124 | x_ff = ff(x)
125 | x = norm2(x + x_ff)
126 | x = self.mlp_enhance(x)
127 | return x
128 |
129 |
130 | def build_mlp(dim_in, dim_hid, dim_out, depth):
131 | modules = [nn.Linear(dim_in, dim_hid), nn.ReLU(True)]
132 | for _ in range(depth-2):
133 | modules.append(nn.Linear(dim_hid, dim_hid))
134 | modules.append(nn.ReLU(True))
135 | modules.append(nn.Linear(dim_hid, dim_out))
136 | return nn.Sequential(*modules)
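A toy-size sketch of how the two attenders are used together (dims are illustrative and smaller than in the configs; it assumes `Self_Attender` and `Cross_Attender` from this file are in scope, bypassing the `models` registry): the self-attender refines the context tokens, and the cross-attender lets embedded target coordinates attend to them.

```python
import torch

# Assumes Self_Attender and Cross_Attender from this file are in scope; toy dims.
dim, depth, n_head, head_dim, ff_dim = 64, 2, 4, 16, 128
self_attn = Self_Attender(dim, depth, n_head, head_dim, ff_dim)
cross_attn = Cross_Attender(dim, depth, n_head, head_dim, ff_dim)

context = torch.randn(2, 256, dim)    # b=2 scenes, 256 context tokens
x_tgt = torch.randn(2, 1024, dim)     # 1024 embedded target coordinates per scene

context = self_attn(context)          # (2, 256, 64): same shape, refined tokens
y_mid = cross_attn(x_tgt, context)    # (2, 1024, 64): one feature per target coordinate
print(y_mid.shape)
```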
--------------------------------------------------------------------------------
/exp3D/models/versatile_np.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import einops
7 |
8 | import models
9 | from models import register
10 | from utils import rendering
11 |
12 |
13 | def build_mlp(dim_in, dim_hid, dim_out, depth):
14 | modules = [nn.Linear(dim_in, dim_hid), nn.ReLU(True)]
15 | for _ in range(depth-2):
16 | modules.append(nn.Linear(dim_hid, dim_hid))
17 | modules.append(nn.ReLU(True))
18 | modules.append(nn.Linear(dim_hid, dim_out))
19 | return nn.Sequential(*modules)
20 |
21 |
22 | @register('versatile_np')
23 | class Versatile_NP(nn.Module):
24 |
25 | def __init__(self, tokenizer, self_attender, cross_attender, hierarchical_model, pe_dim=128):
26 | super().__init__()
27 | dim_y = hierarchical_model['args']['dim_y']
28 | self.pe_dim = pe_dim # dimension of positional embedding
29 |
30 | self.tokenizer = models.make(tokenizer, args={'dim': dim_y})
31 | self.self_attn = models.make(self_attender)
32 | self.cross_attn = models.make(cross_attender)
33 | self.hier_model = models.make(hierarchical_model)
34 |
35 | self.embed_input = build_mlp(pe_dim * 3, dim_y, dim_y, 2)
36 | self.embed_last = build_mlp(dim_y, dim_y, 4, 2)
37 |
38 | def forward(self, data, x_tgt, rays_o, z_vals, y_tgt, is_train=True):
39 | data_tokens = self.tokenizer(data)
40 | context_tokens = self.self_attn(data_tokens)
41 | x_tgt = self.coord_embedding(x_tgt)
42 | y_middle = self.cross_attn(x_tgt, context_tokens)
43 |
44 | y_middle, kls = self.hier_model(y_middle, x_tgt, y_tgt, rays_o, z_vals, training=is_train)
45 | y_pred = rendering(self.embed_last(y_middle), rays_o = rays_o, z_vals = z_vals)
46 | return y_pred, kls
47 |
48 | def coord_embedding(self, x):
49 | w = 2 ** torch.linspace(0, 8, self.pe_dim // 2, device=x.device)
50 | x = torch.matmul(x.unsqueeze(-1), w.unsqueeze(0)).view(*x.shape[:-1], -1)
51 | x = torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
52 | x = self.embed_input(x)
53 | return x
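`coord_embedding` is a Fourier-feature positional encoding: each of the 3 coordinate components is multiplied by the frequencies `2 ** linspace(0, 8, pe_dim/2)` and passed through cos and sin, giving `3 * pe_dim` features that `embed_input` then projects to `dim_y`. A standalone sketch of that encoding step (the `pe_dim` value is illustrative):

```python
import torch

def fourier_features(x, pe_dim=128):
    # x: (..., 3) world coordinates -> (..., 3 * pe_dim) features,
    # mirroring the encoding inside Versatile_NP.coord_embedding (before embed_input).
    w = 2 ** torch.linspace(0, 8, pe_dim // 2, device=x.device)   # (pe_dim/2,) frequencies
    xw = torch.matmul(x.unsqueeze(-1), w.unsqueeze(0))            # (..., 3, pe_dim/2)
    xw = xw.view(*x.shape[:-1], -1)                               # (..., 3*pe_dim/2)
    return torch.cat([torch.cos(xw), torch.sin(xw)], dim=-1)      # (..., 3*pe_dim)

pts = torch.randn(2, 4096, 3)          # e.g. flattened ray sample points
print(fourier_features(pts).shape)     # torch.Size([2, 4096, 384])
```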
--------------------------------------------------------------------------------
/exp3D/run_trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a cfg object according to a cfg file and args, then spawn Trainer(rank, cfg).
3 | """
4 |
5 | import argparse
6 | import os
7 |
8 | import yaml
9 | import numpy as np
10 | import torch
11 | import torch.multiprocessing as mp
12 |
13 | import utils
14 | import trainers
15 |
16 |
17 | def parse_args():
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--cfg')
20 | parser.add_argument('--seed', type=int, default=42)
21 | parser.add_argument('--load-root', default='data')
22 | parser.add_argument('--save-root', default='save')
23 | parser.add_argument('--Lambda', type=float, default=0.001)
24 | parser.add_argument('--name', '-n', default=None)
25 | parser.add_argument('--tag', default=None)
26 | parser.add_argument('--cudnn', action='store_true')
27 | parser.add_argument('--port-offset', '-p', type=int, default=0)
28 | parser.add_argument('--wandb-upload', '-w', action='store_true')
29 | args = parser.parse_args()
30 |
31 | return args
32 |
33 |
34 | def make_cfg(args):
35 | with open(args.cfg, 'r') as f:
36 | cfg = yaml.load(f, Loader=yaml.FullLoader)
37 |
38 | def translate_cfg_(d):
39 | for k, v in d.items():
40 | if isinstance(v, dict):
41 | translate_cfg_(v)
42 | elif isinstance(v, str):
43 | d[k] = v.replace('$load_root$', args.load_root)
44 |
45 | if cfg['Lambda'] != args.Lambda:
46 | cfg['Lambda'] = args.Lambda
47 |
48 | translate_cfg_(cfg)
49 |
50 | if args.name is None:
51 | exp_name = os.path.basename(args.cfg).split('.')[0]
52 | else:
53 | exp_name = args.name
54 | if args.tag is not None:
55 | exp_name += '_' + args.tag
56 |
57 | env = dict()
58 | env['exp_name'] = exp_name
59 | env['save_dir'] = os.path.join(args.save_root, exp_name)
60 | env['tot_gpus'] = torch.cuda.device_count()
61 | env['cudnn'] = args.cudnn
62 | env['port'] = str(29600 + args.port_offset)
63 | env['wandb_upload'] = args.wandb_upload
64 | cfg['env'] = env
65 |
66 | return cfg
67 |
68 |
69 | def main_worker(rank, cfg):
70 | trainer = trainers.trainers_dict[cfg['trainer']](rank, cfg)
71 | trainer.run()
72 |
73 |
74 | def main():
75 | args = parse_args()
76 |
77 | torch.manual_seed(args.seed)
78 | np.random.seed(args.seed)
79 |
80 | cfg = make_cfg(args)
81 | utils.ensure_path(cfg['env']['save_dir'])
82 |
83 | if cfg['env']['tot_gpus'] > 1:
84 | mp.spawn(main_worker, args=(cfg,), nprocs=cfg['env']['tot_gpus'])
85 | else:
86 | main_worker(0, cfg)
87 |
88 |
89 | if __name__ == '__main__':
90 | main()
91 |
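`make_cfg` walks the loaded YAML and replaces the `$load_root$` placeholder in every string value with `--load-root`, then attaches the `env` dict. A tiny standalone illustration of just the placeholder substitution (the config keys below are illustrative, not copied from the repo's cfgs):

```python
import yaml

cfg_text = """
train_dataset:
  name: some_dataset                      # illustrative key/value, not from the repo's cfgs
  args: {root_path: $load_root$/shapenet}
"""
cfg = yaml.safe_load(cfg_text)

def translate_cfg_(d, load_root):
    # same recursion as in make_cfg: rewrite every string value in place
    for k, v in d.items():
        if isinstance(v, dict):
            translate_cfg_(v, load_root)
        elif isinstance(v, str):
            d[k] = v.replace('$load_root$', load_root)

translate_cfg_(cfg, 'data')
print(cfg['train_dataset']['args']['root_path'])   # data/shapenet
```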
--------------------------------------------------------------------------------
/exp3D/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainers import register, trainers_dict
2 | from . import base_trainer
3 | from . import nvs_trainer
5 |
--------------------------------------------------------------------------------
/exp3D/trainers/base_trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | A basic trainer.
3 |
4 | The general procedure in run() is:
5 | make_datasets()
6 | create . train_loader, test_loader, dist_samplers
7 | make_model()
8 | create . model_ddp, model
9 | train()
10 | create . optimizer, epoch, log_buffer
11 | for epoch = 1 ... max_epoch:
12 | adjust_learning_rate()
13 | train_epoch()
14 | train_step()
15 | evaluate_epoch()
16 | evaluate_step()
17 | save_checkpoint()
18 | """
19 |
20 | import os
21 | import os.path as osp
22 | import time
23 |
24 | import yaml
25 | import torch
26 | import torch.nn as nn
27 | import torch.backends.cudnn as cudnn
28 | import torch.distributed as dist
29 | import wandb
30 | from tqdm import tqdm
31 | from torch.utils.data import DataLoader
32 | from torch.utils.data.distributed import DistributedSampler
33 | from torch.nn.parallel import DistributedDataParallel
34 |
35 | import datasets
36 | import models
37 | import utils
38 | from trainers import register
39 |
40 |
41 | @register('base_trainer')
42 | class BaseTrainer():
43 |
44 | def __init__(self, rank, cfg):
45 | self.rank = rank
46 | self.cfg = cfg
47 | self.is_master = (rank == 0)
48 |
49 | env = cfg['env']
50 | self.tot_gpus = env['tot_gpus']
51 | self.distributed = (env['tot_gpus'] > 1)
52 |
53 | # Setup log, tensorboard, wandb
54 | if self.is_master:
55 | logger, writer = utils.set_save_dir(env['save_dir'], replace=False)
56 | with open(osp.join(env['save_dir'], 'cfg.yaml'), 'w') as f:
57 | yaml.dump(cfg, f, sort_keys=False)
58 |
59 | self.log = logger.info
60 |
61 | self.enable_tb = True
62 | self.writer = writer
63 |
64 | if env['wandb_upload']:
65 | self.enable_wandb = True
66 | with open('wandb.yaml', 'r') as f:
67 | wandb_cfg = yaml.load(f, Loader=yaml.FullLoader)
68 | os.environ['WANDB_DIR'] = env['save_dir']
69 | os.environ['WANDB_NAME'] = env['exp_name']
70 | os.environ['WANDB_API_KEY'] = wandb_cfg['api_key']
71 | wandb.init(project=wandb_cfg['project'], entity=wandb_cfg['entity'], config=cfg)
72 | else:
73 | self.enable_wandb = False
74 | else:
75 | self.log = lambda *args, **kwargs: None
76 | self.enable_tb = False
77 | self.enable_wandb = False
78 |
79 | # Setup distributed devices
80 | torch.cuda.set_device(rank)
81 | self.device = torch.device('cuda', torch.cuda.current_device())
82 |
83 | if self.distributed:
84 | dist_url = f"tcp://localhost:{env['port']}"
85 | dist.init_process_group(backend='nccl', init_method=dist_url,
86 | world_size=self.tot_gpus, rank=rank)
87 | self.log(f'Distributed training enabled.')
88 |
89 | cudnn.benchmark = env['cudnn']
90 |
91 | self.log(f'Environment setup done.')
92 |
93 | def run(self):
94 | self.make_datasets()
95 |
96 | if self.cfg.get('eval_model') is not None:
97 | print('Load model:', self.cfg['eval_model'])
98 | model_spec = torch.load(self.cfg['eval_model'])['model']
99 | self.make_model(model_spec, load_sd=True)
100 | self.epoch = 0
101 | self.log_buffer = []
102 | self.t_data, self.t_model = 0, 0
103 | self.evaluate_epoch()
104 | self.log(', '.join(self.log_buffer))
105 | elif self.cfg.get('resume_model') is not None:
106 | print('Load model:', self.cfg['resume_model'])
107 | model_spec = torch.load(self.cfg['resume_model'], map_location='cpu')['model']
108 | self.make_model(model_spec, load_sd=True)
109 | self.train()
110 | else:
111 | self.make_model()
112 | self.train()
113 |
114 | if self.enable_tb:
115 | self.writer.close()
116 | if self.enable_wandb:
117 | wandb.finish()
118 |
119 | def make_datasets(self):
120 | """
121 | By default, train dataset performs shuffle and drop_last.
122 | The distributed sampler will extend the dataset with a prefix of itself so that its length is divisible by tot_gpus; the samplers are stored in .dist_samplers.
123 |
124 | Cfg example:
125 |
126 | train/test_dataset:
127 | name:
128 | args:
129 | loader: {batch_size: , num_workers: }
130 | """
131 | cfg = self.cfg
132 | self.dist_samplers = []
133 |
134 | def make_distributed_loader(dataset, batch_size, num_workers, shuffle=False, drop_last=False):
135 | sampler = DistributedSampler(dataset, shuffle=shuffle) if self.distributed else None
136 | loader = DataLoader(
137 | dataset,
138 | batch_size // self.tot_gpus,
139 | drop_last=drop_last,
140 | sampler=sampler,
141 | shuffle=(shuffle and (sampler is None)),
142 | num_workers=num_workers // self.tot_gpus,
143 | pin_memory=True)
144 | return loader, sampler
145 |
146 | if cfg.get('train_dataset') is not None:
147 | train_dataset = datasets.make(cfg['train_dataset'])
148 | self.log(f'Train dataset: len={len(train_dataset)}')
149 | l = cfg['train_dataset']['loader']
150 | self.train_loader, train_sampler = make_distributed_loader(
151 | train_dataset, l['batch_size'], l['num_workers'], shuffle=True, drop_last=True)
152 | self.dist_samplers.append(train_sampler)
153 |
154 | if cfg.get('test_dataset') is not None:
155 | test_dataset = datasets.make(cfg['test_dataset'])
156 | self.log(f'Test dataset: len={len(test_dataset)}')
157 | l = cfg['test_dataset']['loader']
158 | self.test_loader, test_sampler = make_distributed_loader(
159 | test_dataset, l['batch_size'], l['num_workers'], shuffle=False, drop_last=False)
160 | self.dist_samplers.append(test_sampler)
161 |
162 | def make_model(self, model_spec=None, load_sd=False):
163 | if model_spec is None:
164 | model_spec = self.cfg['model']
165 | model = models.make(model_spec, load_sd=load_sd)
166 | self.log(f'Model: #params={utils.compute_num_params(model)}')
167 |
168 | if self.distributed:
169 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
170 | model.cuda()
171 | model_ddp = DistributedDataParallel(model, device_ids=[self.rank], find_unused_parameters=True)
172 | else:
173 | model.cuda()
174 | model_ddp = model
175 | self.model = model
176 | self.model_ddp = model_ddp
177 |
178 | def train(self):
179 | """
180 | For epochs perform training, evaluation, and visualization.
181 | Note that ave_scalars update ignores the actual current batch_size.
182 | """
183 | cfg = self.cfg
184 |
185 | self.optimizer = utils.make_optimizer(self.model_ddp.parameters(), cfg['optimizer'])
186 |
187 | max_epoch = cfg['max_epoch']
188 | eval_epoch = cfg.get('eval_epoch', max_epoch + 1)
189 | vis_epoch = cfg.get('vis_epoch', max_epoch + 1)
190 | save_epoch = cfg.get('save_epoch', max_epoch + 1)
191 | epoch_timer = utils.EpochTimer(max_epoch)
192 |
193 | for epoch in range(1, max_epoch + 1):
194 | self.epoch = epoch
195 | self.log_buffer = [f'Epoch {epoch}']
196 |
197 | if self.distributed:
198 | for sampler in self.dist_samplers:
199 | sampler.set_epoch(epoch)
200 |
201 | self.adjust_learning_rate()
202 |
203 | self.t_data, self.t_model = 0, 0
204 | self.train_epoch()
205 |
206 | if epoch % eval_epoch == 0:
207 | self.evaluate_epoch()
208 |
209 | if epoch % save_epoch == 0:
210 | self.save_checkpoint(f'epoch-{epoch}.pth')
211 | self.save_checkpoint('epoch-last.pth')
212 |
213 | epoch_time, tot_time, est_time = epoch_timer.epoch_done()
214 | t_data_ratio = self.t_data / (self.t_data + self.t_model)
215 | self.log_buffer.append(f'{epoch_time} (d {t_data_ratio:.2f}) {tot_time}/{est_time}')
216 | self.log(', '.join(self.log_buffer))
217 |
218 | def adjust_learning_rate(self):
219 | base_lr = self.cfg['optimizer']['args']['lr']
220 | for param_group in self.optimizer.param_groups:
221 | param_group['lr'] = base_lr
222 | self.log_temp_scalar('lr', self.optimizer.param_groups[0]['lr'])
223 |
224 | def log_temp_scalar(self, k, v, t=None):
225 | if t is None:
226 | t = self.epoch
227 | if self.enable_tb:
228 | self.writer.add_scalar(k, v, global_step=t)
229 | if self.enable_wandb:
230 | wandb.log({k: v}, step=t)
231 |
232 | def dist_all_reduce_mean_(self, x):
233 | dist.all_reduce(x, op=dist.ReduceOp.SUM)
234 | x.div_(self.tot_gpus)
235 |
236 | def sync_ave_scalars_(self, ave_scalars):
237 | for k in ave_scalars.keys():
238 | x = torch.tensor(ave_scalars[k].item(), dtype=torch.float32, device=self.device)
239 | self.dist_all_reduce_mean_(x)
240 | ave_scalars[k].v = x.item()
241 | ave_scalars[k].n *= self.tot_gpus
242 |
243 | def train_step(self, data):
244 | data = {k: v.cuda() for k, v in data.items()}
245 | loss = self.model_ddp(data)
246 | self.optimizer.zero_grad()
247 | loss.backward()
248 | self.optimizer.step()
249 | return {'loss': loss.item()}
250 |
251 | def train_epoch(self):
252 | self.model_ddp.train()
253 | ave_scalars = dict()
254 |
255 | pbar = self.train_loader
256 | if self.is_master:
257 | pbar = tqdm(pbar, desc='train', leave=False)
258 |
259 | t1 = time.time()
260 | for data in pbar:
261 | t0 = time.time()
262 | self.t_data += t0 - t1
263 | ret = self.train_step(data)
264 | self.t_model += time.time() - t0
265 |
266 | B = len(next(iter(data.values())))
267 | for k, v in ret.items():
268 | if ave_scalars.get(k) is None:
269 | ave_scalars[k] = utils.Averager()
270 | ave_scalars[k].add(v, n=B)
271 |
272 | if self.is_master:
273 | pbar.set_description(desc=f'train: loss={ret["loss"]:.4f}')
274 | t1 = time.time()
275 |
276 | if self.distributed:
277 | self.sync_ave_scalars_(ave_scalars)
278 |
279 | logtext = 'train:'
280 | for k, v in ave_scalars.items():
281 | logtext += f' {k}={v.item():.4f}'
282 | self.log_temp_scalar('train/' + k, v.item())
283 | self.log_buffer.append(logtext)
284 |
285 | def evaluate_step(self, data):
286 | data = {k: v.cuda() for k, v in data.items()}
287 | with torch.no_grad():
288 | loss = self.model_ddp(data)
289 | return {'loss': loss.item()}
290 |
291 | def evaluate_epoch(self):
292 | self.model_ddp.eval()
293 | ave_scalars = dict()
294 |
295 | pbar = self.test_loader
296 | if self.is_master:
297 | pbar = tqdm(pbar, desc='eval', leave=False)
298 |
299 | t1 = time.time()
300 | for data in pbar:
301 | t0 = time.time()
302 | self.t_data += t0 - t1
303 | ret = self.evaluate_step(data)
304 | self.t_model += time.time() - t0
305 |
306 | B = len(next(iter(data.values())))
307 | for k, v in ret.items():
308 | if ave_scalars.get(k) is None:
309 | ave_scalars[k] = utils.Averager()
310 | ave_scalars[k].add(v, n=B)
311 |
312 | t1 = time.time()
313 |
314 | if self.distributed:
315 | self.sync_ave_scalars_(ave_scalars)
316 |
317 | logtext = 'eval:'
318 | for k, v in ave_scalars.items():
319 | logtext += f' {k}={v.item():.4f}'
320 | self.log_temp_scalar('test/' + k, v.item())
321 | self.log_buffer.append(logtext)
322 |
323 |
324 | def save_checkpoint(self, filename):
325 | if not self.is_master:
326 | return
327 | model_spec = self.cfg['model']
328 | model_spec['sd'] = self.model.state_dict()
329 | optimizer_spec = self.cfg['optimizer']
330 | optimizer_spec['sd'] = self.optimizer.state_dict()
331 | checkpoint = {
332 | 'model': model_spec,
333 | 'optimizer': optimizer_spec,
334 | 'epoch': self.epoch,
335 | 'cfg': self.cfg,
336 | }
337 | torch.save(checkpoint, osp.join(self.cfg['env']['save_dir'], filename))
338 |
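BaseTrainer.run()/train() drive the whole loop; concrete trainers normally override only `train_step` / `evaluate_step` (and optionally `adjust_learning_rate`), which is exactly how NvsTrainer specializes it in nvs_trainer.py. A minimal, hypothetical subclass sketch (the trainer name and the 'inp' / 'gt' batch keys are illustrative, not part of this repo):

```python
import torch

from trainers import register
from .base_trainer import BaseTrainer


@register('l1_trainer')          # hypothetical trainer, shown only to illustrate the pattern
class L1Trainer(BaseTrainer):

    def train_step(self, data):
        data = {k: v.cuda() for k, v in data.items()}
        pred = self.model_ddp(data['inp'])            # 'inp' / 'gt' keys are illustrative
        loss = torch.abs(pred - data['gt']).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {'loss': loss.item()}

    def evaluate_step(self, data):
        data = {k: v.cuda() for k, v in data.items()}
        with torch.no_grad():
            pred = self.model_ddp(data['inp'])
            loss = torch.abs(pred - data['gt']).mean()
        return {'loss': loss.item()}
```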
--------------------------------------------------------------------------------
/exp3D/trainers/nvs_trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 | import torchvision
5 | import numpy as np
6 | import einops
7 | import wandb
8 | from tqdm import tqdm
9 |
10 | import utils
11 | from .base_trainer import BaseTrainer
12 | from trainers import register
13 | from utils import poses_to_rays, get_coord, volume_rendering, batched_volume_rendering
14 |
15 |
16 | @register('nvs_trainer')
17 | class NvsTrainer(BaseTrainer):
18 |
19 | def make_datasets(self):
20 | super().make_datasets()
21 |
22 | def get_vislist(dataset, n_vis=8):
23 | ids = torch.arange(n_vis) * (len(dataset) // n_vis)
24 | return [dataset[i] for i in ids]
25 |
26 | if hasattr(self, 'train_loader'):
27 | np.random.seed(0)
28 | self.vislist_train = get_vislist(self.train_loader.dataset)
29 | if hasattr(self, 'test_loader'):
30 | np.random.seed(0)
31 | self.vislist_test = get_vislist(self.test_loader.dataset)
32 |
33 | def adjust_learning_rate(self):
34 | base_lr = self.cfg['optimizer']['args']['lr']
35 | if self.epoch <= round(self.cfg['max_epoch'] * 0.8):
36 | lr = base_lr
37 | else:
38 | lr = base_lr * 0.1
39 | for param_group in self.optimizer.param_groups:
40 | param_group['lr'] = lr
41 | self.log_temp_scalar('lr', lr)
42 |
43 | def _adaptive_sample_rays(self, rays_o, rays_d, gt, n_sample):
44 | B = rays_o.shape[0]
45 | inds = []
46 | fg_n_sample = n_sample // 2
47 | for i in range(B):
48 | fg = ((gt[i].min(dim=-1).values < 1).nonzero().view(-1)).cpu().numpy()
49 | if fg_n_sample <= len(fg):
50 | fg = np.random.choice(fg, fg_n_sample, replace=False)
51 | else:
52 | fg = np.concatenate([fg, np.random.choice(fg, fg_n_sample - len(fg), replace=True)], axis=0)
53 | rd = np.random.choice(rays_o.shape[1], n_sample - fg_n_sample, replace=False)
54 | inds.append(np.concatenate([fg, rd], axis=0))
55 |
56 | def subselect(x, inds):
57 | t = torch.empty(B, len(inds[0]), 3, dtype=x.dtype, device=x.device)
58 | for i in range(B):
59 | t[i] = x[i][inds[i], :]
60 | return t
61 |
62 | return subselect(rays_o, inds), subselect(rays_d, inds), subselect(gt, inds)
63 |
64 | def _iter_step(self, data, is_train):
65 | data = {k: v.cuda() for k, v in data.items()}
66 | query_imgs = data.pop('query_imgs')
67 |
68 | B, _, _, H, W = query_imgs.shape
69 | rays_o, rays_d = poses_to_rays(data['query_poses'], H, W, data['query_focals'])
70 |
71 | gt = einops.rearrange(query_imgs, 'b n c h w -> b (n h w) c')
72 | rays_o = einops.rearrange(rays_o, 'b n h w c -> b (n h w) c')
73 | rays_d = einops.rearrange(rays_d, 'b n h w c -> b (n h w) c')
74 |
75 | n_sample = self.cfg['train_n_rays']
76 | if is_train and self.epoch <= self.cfg.get('adaptive_sample_epoch', 0):
77 | rays_o, rays_d, gt = self._adaptive_sample_rays(rays_o, rays_d, gt, n_sample)
78 | else:
79 | ray_ids = np.random.choice(rays_o.shape[1], n_sample, replace=False)
80 | rays_o, rays_d, gt = map((lambda _: _[:, ray_ids, :]), [rays_o, rays_d, gt])
81 |
82 | # x_tgt is the world coordinate of target points
83 | x_tgt, z_vals = get_coord(
84 | rays_o, rays_d,
85 | near=data['near'][0],
86 | far=data['far'][0],
87 | points_per_ray=self.cfg['train_points_per_ray'],
88 | rand=is_train,
89 | )
90 |
91 | pred, loss_kl = self.model_ddp(data, x_tgt, rays_o, z_vals, gt, is_train=is_train)
92 |
93 | loss_mse = ((pred - gt) ** 2).view(B, -1).mean(dim=-1)
94 | loss_kl = loss_kl.view(B, -1).mean()
95 |
96 | # annealing beta coefficient
97 | if self.cfg.get('resume_model') is not None:
98 | beta = self.cfg['Lambda']
99 | else:
100 | beta = self.cfg['Lambda'] * min(1.0, self.epoch / 50)
101 |
102 | loss = loss_mse.mean() + loss_kl * beta
103 | psnr = (-10 * torch.log10(loss_mse)).mean()
104 |
105 | if is_train:
106 | self.optimizer.zero_grad()
107 | loss.backward()
108 | self.optimizer.step()
109 |
110 | return {'loss': loss.item(), 'psnr': psnr.item(), 'loss_kl': loss_kl.item()}
111 |
112 | def train_step(self, data):
113 | return self._iter_step(data, is_train=True)
114 |
115 | def evaluate_step(self, data):
116 | with torch.no_grad():
117 | return self._iter_step(data, is_train=False)
118 |
119 |
120 | @register('nvs_evaluator')
121 | class NvsEvaluator(NvsTrainer):
122 |
123 | def _iter_step(self, data, is_train, step=0):
124 | assert not is_train
125 | data = {k: v.cuda() for k, v in data.items()}
126 |
127 | query_imgs = data.pop('query_imgs')
128 | B, N, _, H, W = query_imgs.shape
129 |
130 | with torch.no_grad():
131 | B, _, _, H, W = query_imgs.shape
132 | rays_o, rays_d = poses_to_rays(data['query_poses'], H, W, data['query_focals'])
133 |
134 | gt = einops.rearrange(query_imgs, 'b n c h w -> b (n h w) c')
135 | rays_o = einops.rearrange(rays_o, 'b n h w c -> b (n h w) c')
136 | rays_d = einops.rearrange(rays_d, 'b n h w c -> b (n h w) c')
137 | n_sample = self.cfg['render_ray_batch']
138 | pred = []
139 | for i in range(0, rays_o.shape[1], n_sample):
140 | rays_o_, rays_d_, gt_ = map((lambda _: _[:, i: i + n_sample, :]), [rays_o, rays_d, gt])
141 | x_tgt, z_vals = get_coord(
142 | rays_o_, rays_d_,
143 | near=data['near'][0],
144 | far=data['far'][0],
145 | points_per_ray=self.cfg['train_points_per_ray'],
146 | rand=is_train,
147 | )
148 | pred_, loss_kl = self.model_ddp(data, x_tgt, rays_o_, z_vals, gt_, is_train=False)
149 | pred.append(pred_)
150 |
151 |
152 | pred = torch.cat(pred, dim=1)
153 | pred = torch.clamp(pred, min=0.0, max=1.0)
154 | gt = torch.clamp(gt, min=0.0, max=1.0)
155 |
156 | ref = data['support_imgs'][:,0].permute(0, 2, 3, 1)
157 | # save_img = torch.cat([ref.view(B, 128, 128, 3), pred.view(B, 128, 128, 3), gt.view(B, 128, 128, 3)], dim=2).view(-1, 384, 3)
158 | # imsave('save_imgs/save_img' + str(step) + '.png', np.squeeze(save_img.cpu().numpy() * 255).astype(np.uint8))
159 |
160 | loss_kl = loss_kl.view(B, -1).mean(-1)
161 | mses = ((pred - gt)**2).view(B * N, -1).mean(dim=-1)
162 | psnr = (-10 * torch.log10(mses)).mean()
163 |
164 | return {'psnr': psnr.item(), 'loss_kl': loss_kl.mean().item()}
165 |
166 | def evaluate_epoch(self):
167 | self.model_ddp.eval()
168 | ave_scalars = dict()
169 |
170 | pbar = self.test_loader
171 | if self.is_master:
172 | pbar = tqdm(pbar, desc='eval', leave=False)
173 |
174 | t1 = time.time()
175 | for data in pbar:
176 | t0 = time.time()
177 | self.t_data += t0 - t1
178 | ret = self.evaluate_step(data)
179 | self.t_model += time.time() - t0
180 |
181 | B = len(next(iter(data.values())))
182 | for k, v in ret.items():
183 | if ave_scalars.get(k) is None:
184 | ave_scalars[k] = utils.Averager()
185 | ave_scalars[k].add(v, n=B)
186 |
187 | # --The only thing added-- #
188 | if self.is_master:
189 | pbar.set_description(desc=f'eval: psnr={ave_scalars["psnr"].item():.2f}')
190 | # ------------------------ #
191 | t1 = time.time()
192 |
193 | if self.distributed:
194 | self.sync_ave_scalars_(ave_scalars)
195 |
196 | logtext = 'eval:'
197 | for k, v in ave_scalars.items():
198 | logtext += f' {k}={v.item():.4f}'
199 | self.log_temp_scalar('test/' + k, v.item())
200 | self.log_buffer.append(logtext)
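The KL weight `beta` in `_iter_step` is annealed linearly from 0 to `Lambda` over the first 50 epochs and held at `Lambda` afterwards (or immediately, when resuming). A tiny sketch of that schedule using the default `--Lambda 0.001` from run_trainer.py:

```python
# KL-weight schedule used in _iter_step (Lambda is the run_trainer.py default).
Lambda = 0.001

def beta_at(epoch, resumed=False):
    return Lambda if resumed else Lambda * min(1.0, epoch / 50)

print([round(beta_at(e), 6) for e in (1, 25, 50, 100)])
# [2e-05, 0.0005, 0.001, 0.001]
```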
--------------------------------------------------------------------------------
/exp3D/trainers/trainers.py:
--------------------------------------------------------------------------------
1 | trainers_dict = {}
2 |
3 |
4 | def register(name):
5 | def decorator(cls):
6 | trainers_dict[name] = cls
7 | return cls
8 | return decorator
9 |
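`register` is a decorator factory: `@register('nvs_trainer')` stores the class in `trainers_dict` under that name, and run_trainer.py later looks it up with `trainers_dict[cfg['trainer']](rank, cfg)`. A self-contained illustration of the same pattern (DemoTrainer is hypothetical):

```python
# Standalone illustration of the registry pattern used across this repo (DemoTrainer is hypothetical).
registry = {}

def register(name):
    def decorator(cls):
        registry[name] = cls
        return cls
    return decorator

@register('demo_trainer')
class DemoTrainer:
    def __init__(self, rank, cfg):
        self.rank, self.cfg = rank, cfg

trainer = registry['demo_trainer'](0, {'trainer': 'demo_trainer'})
print(type(trainer).__name__)   # DemoTrainer
```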
--------------------------------------------------------------------------------
/exp3D/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import *
2 | from .geometry import *
3 |
--------------------------------------------------------------------------------
/exp3D/utils/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 | import logging
5 |
6 | import numpy as np
7 | from torch.optim import SGD, Adam
8 | from tensorboardX import SummaryWriter
9 |
10 |
11 | def ensure_path(path, replace=True):
12 | basename = os.path.basename(path.rstrip('/'))
13 | if os.path.exists(path):
14 | if replace and (basename.startswith('_') or input('{} exists, replace? (y/[n]): '.format(path)) == 'y'):
15 | shutil.rmtree(path)
16 | os.makedirs(path)
17 | else:
18 | os.makedirs(path)
19 |
20 |
21 | def set_logger(file_path):
22 | logger = logging.getLogger()
23 | logger.setLevel('INFO')
24 | stream_handler = logging.StreamHandler()
25 | file_handler = logging.FileHandler(file_path, 'w')
26 | formatter = logging.Formatter('[%(asctime)s] %(message)s', '%m-%d %H:%M:%S')
27 | for handler in [stream_handler, file_handler]:
28 | handler.setFormatter(formatter)
29 | handler.setLevel('INFO')
30 | logger.addHandler(handler)
31 | return logger
32 |
33 |
34 | def set_save_dir(save_dir, replace=True):
35 | ensure_path(save_dir, replace=replace)
36 | logger = set_logger(os.path.join(save_dir, 'log.txt'))
37 | writer = SummaryWriter(os.path.join(save_dir, 'tensorboard'))
38 | return logger, writer
39 |
40 |
41 | def compute_num_params(model, text=True):
42 | tot = int(sum([np.prod(p.shape) for p in model.parameters()]))
43 | if text:
44 | if tot >= 1e6:
45 | return '{:.1f}M'.format(tot / 1e6)
46 | elif tot >= 1e3:
47 | return '{:.1f}K'.format(tot / 1e3)
48 | else:
49 | return str(tot)
50 | else:
51 | return tot
52 |
53 |
54 | def make_optimizer(params, optimizer_spec, load_sd=False):
55 | optimizer = {
56 | 'sgd': SGD,
57 | 'adam': Adam
58 | }[optimizer_spec['name']](params, **optimizer_spec['args'])
59 | if load_sd:
60 | optimizer.load_state_dict(optimizer_spec['sd'])
61 | return optimizer
62 |
63 |
64 | class Averager():
65 |
66 | def __init__(self):
67 | self.n = 0.0
68 | self.v = 0.0
69 |
70 | def add(self, v, n=1.0):
71 | self.v = (self.v * self.n + v * n) / (self.n + n)
72 | self.n += n
73 |
74 | def item(self):
75 | return self.v
76 |
77 |
78 | class EpochTimer():
79 |
80 | def __init__(self, max_epoch):
81 | self.max_epoch = max_epoch
82 | self.epoch = 0
83 | self.t_start = time.time()
84 | self.t_last = self.t_start
85 |
86 | def epoch_done(self):
87 | t_cur = time.time()
88 | self.epoch += 1
89 | epoch_time = t_cur - self.t_last
90 | tot_time = t_cur - self.t_start
91 | est_time = tot_time / self.epoch * self.max_epoch
92 | self.t_last = t_cur
93 | return time_text(epoch_time), time_text(tot_time), time_text(est_time)
94 |
95 |
96 | def time_text(secs):
97 | if secs >= 3600:
98 | return f'{secs / 3600:.1f}h'
99 | elif secs >= 60:
100 | return f'{secs / 60:.1f}m'
101 | else:
102 | return f'{secs:.1f}s'
103 |
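`Averager` keeps a sample-weighted running mean; the trainers call `add(value, n=batch_size)` each iteration so the epoch average weights batches by their size. A two-step example (with the class above in scope; the numbers are illustrative):

```python
# Assumes the Averager class above is in scope; values are illustrative.
avg = Averager()
avg.add(1.0, n=2)    # batch of 2 with loss 1.0
avg.add(4.0, n=6)    # batch of 6 with loss 4.0
print(avg.item())    # 3.25 == (1.0*2 + 4.0*6) / 8
```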
--------------------------------------------------------------------------------
/exp3D/utils/geometry.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import einops
4 |
5 |
6 | def make_coord_grid(shape, range, device=None):
7 | """
8 | Args:
9 | shape: tuple
10 | range: [minv, maxv] or [[minv_1, maxv_1], ..., [minv_d, maxv_d]] for each dim
11 | Returns:
12 | grid: shape (*shape, len(shape))
13 | """
14 | l_lst = []
15 | for i, s in enumerate(shape):
16 | l = (0.5 + torch.arange(s, device=device)) / s
17 | if isinstance(range[0], list) or isinstance(range[0], tuple):
18 | minv, maxv = range[i]
19 | else:
20 | minv, maxv = range
21 | l = minv + (maxv - minv) * l
22 | l_lst.append(l)
23 | grid = torch.meshgrid(*l_lst)
24 | grid = torch.stack(grid, dim=-1)
25 | return grid
26 |
27 |
28 | def poses_to_rays(poses, image_h, image_w, focal):
29 | """
30 | Pose columns are: 3 camera axes specified in world coordinate + 1 camera position.
31 | Camera: x-axis right, y-axis up, z-axis inward.
32 | Focal is in pixel-scale.
33 |
34 | Args:
35 | poses: (... 3 4)
36 | focal: (... 2)
37 | Returns:
38 | rays_o, rays_d: shape (... image_h image_w 3)
39 | """
40 | device = poses.device
41 | bshape = poses.shape[:-2]
42 | poses = poses.view(-1, 3, 4)
43 | focal = focal.view(-1, 2)
44 | bsize = poses.shape[0]
45 |
46 | y, x = torch.meshgrid(torch.arange(image_w, device=device), torch.arange(image_h, device=device)) # h w
47 | x, y = x + 0.5, y + 0.5 # modified to + 0.5
48 | x, y = x.unsqueeze(0), y.unsqueeze(0) # h w -> 1 h w
49 | focal = focal.unsqueeze(1).unsqueeze(1) # b 2 -> b 1 1 2
50 | dirs = torch.stack([
51 | (x - image_w / 2) / focal[..., 0],
52 | -(y - image_h / 2) / focal[..., 1],
53 | -torch.ones(bsize, image_h, image_w, device=device)
54 | ], dim=-1) # b h w 3
55 |
56 | poses = poses.unsqueeze(1).unsqueeze(1) # b 3 4 -> b 1 1 3 4
57 | rays_o = poses[..., -1].repeat(1, image_h, image_w, 1) # b h w 3
58 | rays_d = (dirs.unsqueeze(-2) * poses[..., :3]).sum(dim=-1) # b h w 3
59 |
60 | rays_o = rays_o.view(*bshape, *rays_o.shape[1:])
61 | rays_d = rays_d.view(*bshape, *rays_d.shape[1:])
62 | return rays_o, rays_d
63 |
64 |
65 | def get_coord(rays_o, rays_d, near, far, points_per_ray, rand):
66 | """
67 | Args:
68 | rays_o, rays_d: shape (b ... 3)
69 | Returns:
70 | pts_flat, z_vals
71 | """
72 | B = rays_o.shape[0]
73 | rays_o = rays_o.view(B, -1, 3)
74 | rays_d = rays_d.view(B, -1, 3)
75 | n_rays = rays_o.shape[1]
76 | device = rays_o.device
77 |
78 | # Compute 3D query points
79 | z_vals = torch.linspace(near, far, points_per_ray, device=device)
80 | z_vals = einops.repeat(z_vals, 'p -> n p', n=n_rays)
81 | if rand:
82 | d = (far - near) / (points_per_ray - 1) # modified as points_per_ray - 1
83 | z_vals = z_vals + torch.rand(n_rays, points_per_ray, device=device) * d
84 |
85 | pts = rays_o.view(B, n_rays, 1, 3) + rays_d.view(B, n_rays, 1, 3) * z_vals.view(1, n_rays, points_per_ray, 1)
86 | pts_flat = einops.rearrange(pts, 'b n p d -> b (n p) d')
87 | return pts_flat, z_vals
88 |
89 |
90 | def rendering(raw, rays_o, z_vals):
91 | B = raw.shape[0]
92 | query_shape = rays_o.shape[1: -1]
93 | n_rays = rays_o.shape[1]
94 | raw = einops.rearrange(raw, 'b (n p) c -> b n p c', n=n_rays)
95 |
96 | # Compute opacities and colors
97 | rgb, sigma_a = raw[..., :3], raw[..., 3]
98 | rgb = torch.sigmoid(rgb) # b n p 3
99 | sigma_a = F.relu(sigma_a) # b n p
100 |
101 | # Do volume rendering
102 | dists = torch.cat([z_vals[:, 1:] - z_vals[:, :-1], torch.ones_like(z_vals[:, -1:]) * 1e-3], dim=-1) # n p
103 | alpha = 1. - torch.exp(-sigma_a * dists) # b n p
104 | trans = torch.clamp(1. - alpha + 1e-10, max=1.) # b n p
105 | trans = torch.cat([torch.ones_like(trans[..., :1]), trans[..., :-1]], dim=-1)
106 | weights = alpha * torch.cumprod(trans, dim=-1) # b n p
107 |
108 | rgb_map = torch.sum(weights.unsqueeze(-1) * rgb, dim=-2)
109 | acc_map = torch.sum(weights, dim=-1)
110 | rgb_map = rgb_map + (1. - acc_map).unsqueeze(-1) # white background
111 | # depth_map = torch.sum(weights * z_vals, dim=-1)
112 |
113 | rgb_map = rgb_map.view(B, *query_shape, 3)
114 | return rgb_map
115 |
116 |
117 | def volume_rendering(nerf, rays_o, rays_d, near, far, points_per_ray, use_viewdirs, rand):
118 | """
119 | Args:
120 | rays_o, rays_d: shape (b ... 3)
121 | Returns:
122 | pred: (b ... 3)
123 | """
124 | pts_flat, z_vals = get_coord(rays_o, rays_d, near, far, points_per_ray, rand)
125 |
126 | if not use_viewdirs:
127 | raw = nerf(pts_flat)
128 | else:
129 | viewdirs = einops.repeat(rays_d, 'b n d -> b n p d', p=points_per_ray)
130 | raw = nerf(pts_flat, viewdirs=viewdirs)
131 |
132 | rgb_map = rendering(raw, rays_o, z_vals)
133 | return rgb_map
134 |
135 |
136 | def batched_volume_rendering(nerf, rays_o, rays_d, *args, batch_size=None, **kwargs):
137 | """
138 | Args:
139 | rays_o, rays_d: (b ... 3)
140 | Returns:
141 | pred: (b ... 3)
142 | """
143 | B = rays_o.shape[0]
144 | query_shape = rays_o.shape[1: -1]
145 | rays_o = rays_o.view(B, -1, 3)
146 | rays_d = rays_d.view(B, -1, 3)
147 |
148 | ret = []
149 | ll = 0
150 | while ll < rays_o.shape[1]:
151 | rr = min(ll + batch_size, rays_o.shape[1])
152 | rays_o_ = rays_o[:, ll: rr, :]
153 | rays_d_ = rays_d[:, ll: rr, :]
154 | ret.append(volume_rendering(nerf, rays_o_, rays_d_, *args, **kwargs))
155 | ll = rr
156 | ret = torch.cat(ret, dim=1)
157 |
158 | ret = ret.view(B, *query_shape, 3)
159 | return ret
160 |
161 |
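`rendering` is standard volume rendering: per sample, opacity `alpha_i = 1 - exp(-sigma_i * delta_i)`, compositing weight `w_i = alpha_i * prod_{j<i}(1 - alpha_j)`, the RGB map is the weighted sum of the sigmoided colors, and `1 - sum_i w_i` is added as a white background. A small numeric check (toy values) that the cumprod formulation above matches the explicit product:

```python
import torch

# Toy values: one ray with four samples.
sigma = torch.tensor([[0.0, 2.0, 5.0, 0.5]])
dists = torch.tensor([[0.1, 0.1, 0.1, 1e-3]])

# Same computation as in rendering():
alpha = 1. - torch.exp(-sigma * dists)
trans = torch.clamp(1. - alpha + 1e-10, max=1.)
trans = torch.cat([torch.ones_like(trans[..., :1]), trans[..., :-1]], dim=-1)
weights = alpha * torch.cumprod(trans, dim=-1)

# Explicit reference: w_i = alpha_i * prod_{j<i} (1 - alpha_j)
ref = torch.zeros_like(weights)
acc = 1.0
for i in range(weights.shape[-1]):
    ref[0, i] = alpha[0, i] * acc
    acc = acc * (1.0 - alpha[0, i])

print(torch.allclose(weights, ref, atol=1e-6))   # True
```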
--------------------------------------------------------------------------------
/imgs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZongyuGuo/Versatile-NP/ead3b421b06f1edebfdb2d125026e9e053306e2d/imgs/framework.png
--------------------------------------------------------------------------------