├── vgg.pth
├── images
│ ├── SFU_AI.png
│ ├── APEX_lab.png
│ └── papr_video_cover.png
├── .gitignore
├── configs
│ ├── nerfsyn
│ │ ├── mic.yml
│ │ ├── ficus.yml
│ │ ├── chair.yml
│ │ ├── materials.yml
│ │ ├── hotdog.yml
│ │ ├── ship.yml
│ │ ├── drums.yml
│ │ └── lego.yml
│ ├── t2
│ │ ├── Truck.yml
│ │ ├── Barn.yml
│ │ ├── Family.yml
│ │ ├── Caterpillar.yml
│ │ ├── Ignatius.yml
│ │ └── Caterpillar_exposure_control.yml
│ └── default.yml
├── papr.yml
├── requirements.txt
├── dataset
│ ├── __init__.py
│ ├── load_nerfsyn.py
│ ├── load_t2.py
│ ├── dataset.py
│ └── utils.py
├── LICENSE
├── models
│ ├── __init__.py
│ ├── renderer.py
│ ├── mlp.py
│ ├── lpips.py
│ ├── unet.py
│ ├── attn.py
│ ├── utils.py
│ └── model.py
├── README.md
├── demo.ipynb
├── exposure_control_finetune.py
├── train.py
├── utils.py
└── test.py
/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zvict/papr/HEAD/vgg.pth
--------------------------------------------------------------------------------
/images/SFU_AI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zvict/papr/HEAD/images/SFU_AI.png
--------------------------------------------------------------------------------
/images/APEX_lab.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zvict/papr/HEAD/images/APEX_lab.png
--------------------------------------------------------------------------------
/images/papr_video_cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zvict/papr/HEAD/images/papr_video_cover.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | data
3 | .vscode/
4 | data/
5 | experiments/
6 | checkpoints/
7 | *.pth
8 | *.pt
9 | *.out
10 | *.mp4
11 | *.tar.gz
12 | !vgg.pth
--------------------------------------------------------------------------------
/configs/nerfsyn/mic.yml:
--------------------------------------------------------------------------------
1 | index: "mic"
2 | dataset:
3 | path: "./data/nerf_synthetic/mic/"
4 | training:
5 | add_start: 10000
6 | add_stop: 80000
7 | add_num: 500
8 | eval:
9 | dataset:
10 | path: "./data/nerf_synthetic/mic/"
11 | test:
12 | # load_path: "checkpoints/mic.pth"
13 | datasets:
14 | - name: "testset"
15 | path: "./data/nerf_synthetic/mic/"
--------------------------------------------------------------------------------
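Each scene file above overrides only the fields it changes relative to `configs/default.yml` (the demo notebook describes this override behaviour). The sketch below illustrates the kind of recursive merge such nested overrides imply; `deep_update` and the inline dictionaries are illustrative stand-ins, not the repo's actual config loader.

```python
# Illustrative deep merge of a scene override into the defaults (hypothetical helper).
def deep_update(base: dict, override: dict) -> dict:
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = deep_update(out[key], value)
        else:
            out[key] = value
    return out

defaults = {"dataset": {"type": "synthetic", "coord_scale": 10.0}}       # from default.yml
overrides = {"dataset": {"path": "./data/nerf_synthetic/mic/"},          # from mic.yml
             "training": {"add_start": 10000, "add_stop": 80000, "add_num": 500}}

merged = deep_update(defaults, overrides)
print(merged["dataset"])   # defaults kept, scene path added
print(merged["training"])  # scene-specific training schedule
```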
/configs/nerfsyn/ficus.yml:
--------------------------------------------------------------------------------
1 | index: "ficus"
2 | dataset:
3 | path: "./data/nerf_synthetic/ficus/"
4 | training:
5 | add_start: 10000
6 | add_stop: 90000
7 | add_num: 500
8 | eval:
9 | dataset:
10 | path: "./data/nerf_synthetic/ficus/"
11 | test:
12 | # load_path: "checkpoints/ficus.pth"
13 | datasets:
14 | - name: "testset"
15 | path: "./data/nerf_synthetic/ficus/"
--------------------------------------------------------------------------------
/configs/nerfsyn/chair.yml:
--------------------------------------------------------------------------------
1 | index: "chair"
2 | dataset:
3 | path: "./data/nerf_synthetic/chair/"
4 | geoms:
5 | points:
6 | init_num: 10000
7 | training:
8 | add_start: 10000
9 | add_stop: 50000
10 | eval:
11 | dataset:
12 | path: "./data/nerf_synthetic/chair/"
13 | test:
14 | # load_path: "checkpoints/chair.pth"
15 | datasets:
16 | - name: "testset"
17 | path: "./data/nerf_synthetic/chair/"
--------------------------------------------------------------------------------
/configs/nerfsyn/materials.yml:
--------------------------------------------------------------------------------
1 | index: "materials"
2 | dataset:
3 | path: "./data/nerf_synthetic/materials/"
4 | geoms:
5 | point_feats:
6 | dim: 128
7 | training:
8 | add_start: 10000
9 | add_stop: 80000
10 | add_num: 500
11 | eval:
12 | dataset:
13 | path: "./data/nerf_synthetic/materials/"
14 | test:
15 | # load_path: "checkpoints/materials.pth"
16 | datasets:
17 | - name: "testset"
18 | path: "./data/nerf_synthetic/materials/"
--------------------------------------------------------------------------------
/configs/nerfsyn/hotdog.yml:
--------------------------------------------------------------------------------
1 | index: "hotdog"
2 | dataset:
3 | path: "./data/nerf_synthetic/hotdog/"
4 | geoms:
5 | points:
6 | select_k: 30
7 | init_num: 10000
8 | training:
9 | add_start: 10000
10 | add_stop: 80000
11 | add_num: 500
12 | eval:
13 | dataset:
14 | path: "./data/nerf_synthetic/hotdog/"
15 | test:
16 | # load_path: "checkpoints/hotdog.pth"
17 | datasets:
18 | - name: "testset"
19 | path: "./data/nerf_synthetic/hotdog/"
--------------------------------------------------------------------------------
/papr.yml:
--------------------------------------------------------------------------------
1 | name: papr
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - python=3.9
9 | - pip=23.3.1
10 | - pytorch-cuda=11.8
11 | - pytorch==2.1.2
12 | - torchaudio==2.1.2
13 | - torchvision==0.16.2
14 | - matplotlib
15 | - imageio
16 | - scipy
17 | - pip:
18 | - fsspec==2024.3.1
19 | - lpips==0.1.4
20 | - tqdm==4.66.2
21 | - imageio-ffmpeg==0.4.9
22 | - scikit-image==0.22.0
23 |
--------------------------------------------------------------------------------
/configs/nerfsyn/ship.yml:
--------------------------------------------------------------------------------
1 | index: "ship"
2 | dataset:
3 | path: "./data/nerf_synthetic/ship/"
4 | geoms:
5 | points:
6 | init_num: 10000
7 | training:
8 | add_start: 10000
9 | add_stop: 80000
10 | add_num: 500
11 | lr:
12 | points:
13 | base_lr: 3.0e-3
14 | eval:
15 | dataset:
16 | path: "./data/nerf_synthetic/ship/"
17 | test:
18 | # load_path: "checkpoints/ship.pth"
19 | datasets:
20 | - name: "testset"
21 | path: "./data/nerf_synthetic/ship/"
--------------------------------------------------------------------------------
/configs/nerfsyn/drums.yml:
--------------------------------------------------------------------------------
1 | index: "drums"
2 | dataset:
3 | path: "./data/nerf_synthetic/drums/"
4 | models:
5 | attn:
6 | embed:
7 | value:
8 | skip_layers: [5]
9 | training:
10 | add_start: 10000
11 | add_stop: 40000
12 | lr:
13 | generator:
14 | base_lr: 2.0e-4
15 | eval:
16 | dataset:
17 | path: "./data/nerf_synthetic/drums/"
18 | test:
19 | # load_path: "checkpoints/drums.pth"
20 | datasets:
21 | - name: "testset"
22 | path: "./data/nerf_synthetic/drums/"
--------------------------------------------------------------------------------
/configs/t2/Truck.yml:
--------------------------------------------------------------------------------
1 | index: "Truck"
2 | dataset:
3 | coord_scale: 40.0
4 | type: "t2"
5 | path: "./data/tanks_temples/Truck/"
6 | factor: 2
7 | geoms:
8 | points:
9 | init_scale: [1.0, 1.0, 1.0]
10 | num: 5000
11 | constant: 4.0
12 | training:
13 | add_start: 20000
14 | add_stop: 60000
15 | lr:
16 | points:
17 | base_lr: 8.0e-3
18 | eval:
19 | dataset:
20 | type: "t2"
21 | path: "./data/tanks_temples/Truck/"
22 | factor: 2
23 | img_idx: 0
24 | test:
25 |   # load_path: "checkpoints/Truck.pth"
26 | datasets:
27 | - name: "testset"
28 | type: "t2"
29 | path: "./data/tanks_temples/Truck/"
30 | factor: 2
--------------------------------------------------------------------------------
/configs/t2/Barn.yml:
--------------------------------------------------------------------------------
1 | index: "Barn"
2 | dataset:
3 | coord_scale: 30.0
4 | type: "t2"
5 | path: "./data/tanks_temples/Barn/"
6 | factor: 2
7 | patches:
8 | height: 180
9 | width: 180
10 | geoms:
11 | points:
12 | init_scale: [1.8, 1.8, 1.8]
13 | init_num: 5000
14 | training:
15 | add_start: 10000
16 | add_stop: 35000
17 | lr:
18 | points:
19 | base_lr: 1.0e-2
20 | eval:
21 | dataset:
22 | type: "t2"
23 | path: "./data/tanks_temples/Barn/"
24 | factor: 2
25 | img_idx: 0
26 | test:
27 | # load_path: "checkpoints/Barn.pth"
28 | datasets:
29 | - name: "testset"
30 | type: "t2"
31 | path: "./data/tanks_temples/Barn/"
32 | factor: 2
--------------------------------------------------------------------------------
/configs/t2/Family.yml:
--------------------------------------------------------------------------------
1 | index: "Family"
2 | dataset:
3 | coord_scale: 40.0
4 | type: "t2"
5 | path: "./data/tanks_temples/Family/"
6 | factor: 2
7 | geoms:
8 | points:
9 | init_scale: [0.3, 0.3, 0.3]
10 | init_num: 5000
11 | models:
12 | attn:
13 | embed:
14 | value:
15 | skip_layers: [5]
16 | training:
17 | add_start: 10000
18 | add_stop: 80000
19 | add_num: 500
20 | eval:
21 | dataset:
22 | type: "t2"
23 | path: "./data/tanks_temples/Family/"
24 | factor: 2
25 | img_idx: 0
26 | test:
27 | # load_path: "checkpoints/Family.pth"
28 | datasets:
29 | - name: "testset"
30 | type: "t2"
31 | path: "./data/tanks_temples/Family/"
32 | factor: 2
--------------------------------------------------------------------------------
/configs/nerfsyn/lego.yml:
--------------------------------------------------------------------------------
1 | index: "lego"
2 | dataset:
3 | path: "./data/nerf_synthetic/lego/"
4 | geoms:
5 | background:
6 | constant: 3.0
7 | models:
8 | attn:
9 | embed:
10 | key:
11 | ff_act: "leakyrelu"
12 | query:
13 | ff_act: "leakyrelu"
14 | value:
15 | ff_act: "leakyrelu"
16 | skip_layers: [5]
17 | training:
18 | prune_thresh_list: [0.0, 0.2]
19 | prune_steps_list: [40000]
20 | add_start: 20000
21 | lr:
22 | points:
23 | base_lr: 3.0e-3
24 | eval:
25 | dataset:
26 | path: "./data/nerf_synthetic/lego/"
27 | test:
28 | # load_path: "checkpoints/lego.pth"
29 | datasets:
30 | - name: "testset"
31 | path: "./data/nerf_synthetic/lego/"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | certifi==2023.11.17
2 | charset-normalizer==3.3.2
3 | contourpy==1.2.0
4 | cycler==0.12.1
5 | fonttools==4.47.2
6 | idna==3.6
7 | imageio==2.33.1
8 | imageio-ffmpeg==0.4.9
9 | importlib-resources==6.1.1
10 | kiwisolver==1.4.5
11 | lazy-loader==0.3
12 | lpips==0.1.4
13 | matplotlib==3.7.2
14 | networkx==3.2.1
15 | numpy==1.25.2
16 | packaging==23.2
17 | Pillow==10.1.0
18 | pyparsing==3.0.9
19 | python-dateutil==2.8.2
20 | PyWavelets==1.3.0
21 | PyYAML==6.0.1
22 | requests==2.31.0
23 | scikit-image==0.19.2
24 | scipy==1.11.2
25 | six==1.16.0
26 | tifffile==2024.1.30
27 | torch==1.13.1+cu117
28 | torchaudio==0.13.1+cu117
29 | torchvision==0.14.1+cu117
30 | tqdm==4.66.1
31 | typing-extensions==4.9.0
32 | urllib3==2.1.0
33 | zipp==3.17.0
34 |
--------------------------------------------------------------------------------
/configs/t2/Caterpillar.yml:
--------------------------------------------------------------------------------
1 | index: "Caterpillar"
2 | use_amp: false
3 | dataset:
4 | coord_scale: 30.0
5 | type: "t2"
6 | path: "./data/tanks_temples/Caterpillar/"
7 | factor: 2
8 | patches:
9 | height: 180
10 | width: 180
11 | geoms:
12 | points:
13 | init_scale: [1.0, 1.0, 1.0]
14 | init_num: 5000
15 | background:
16 | constant: 4.0
17 | models:
18 | attn:
19 | embed:
20 | k_L: [4, 4, 4]
21 | q_L: [4]
22 | v_L: [4, 4]
23 | training:
24 | add_start: 10000
25 | add_stop: 80000
26 | add_num: 500
27 | lr:
28 | points:
29 | base_lr: 6.0e-3
30 | eval:
31 | dataset:
32 | type: "t2"
33 | path: "./data/tanks_temples/Caterpillar/"
34 | factor: 2
35 | img_idx: 0
36 | test:
37 | # load_path: "checkpoints/Caterpillar.pth"
38 | datasets:
39 | - name: "testset"
40 | type: "t2"
41 | path: "./data/tanks_temples/Caterpillar/"
42 | factor: 2
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import RINDataset
2 | from torch.utils.data import DataLoader
3 |
4 |
5 | def get_traindataset(args):
6 | return RINDataset(args, mode='train')
7 |
8 |
9 | def get_trainloader(dataset, args):
10 | return DataLoader(dataset, batch_size=args.batch_size, shuffle=args.shuffle)
11 |
12 |
13 | def get_testdataset(args):
14 | return RINDataset(args, mode='test')
15 |
16 |
17 | def get_testloader(dataset, args):
18 | return DataLoader(dataset, batch_size=1, shuffle=False)
19 |
20 |
21 | def get_dataset(args, mode):
22 | if mode == 'train':
23 | return get_traindataset(args)
24 | elif mode == 'test':
25 | return get_testdataset(args)
26 | else:
27 | raise ValueError("Unknown mode: {}".format(mode))
28 |
29 |
30 | def get_loader(dataset, args, mode):
31 | if mode == 'train':
32 | return get_trainloader(dataset, args)
33 | elif mode == 'test':
34 | return get_testloader(dataset, args)
35 | else:
36 | raise ValueError("Unknown mode: {}".format(mode))
37 |
--------------------------------------------------------------------------------
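A self-contained sketch of what the two loader factories above produce: training batches use the config's `batch_size`/`shuffle`, while testing always iterates one full view at a time. `ToyRays` is a stand-in with the same per-item tuple layout as `RINDataset` when patch extraction is disabled, used only so the snippet runs without any data on disk.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyRays(Dataset):
    """Stand-in returning (idx, patch_idx, image, ray_dirs, ray_origin) like RINDataset."""
    def __init__(self, n=4, h=8, w=8):
        self.imgs = torch.rand(n, h, w, 3)
        self.rayd = torch.rand(n, h, w, 3)
        self.rayo = torch.rand(n, 3)
    def __len__(self):
        return len(self.imgs)
    def __getitem__(self, idx):
        return idx, 0, self.imgs[idx], self.rayd[idx], self.rayo[idx]

train_loader = DataLoader(ToyRays(), batch_size=2, shuffle=True)   # mirrors get_trainloader
test_loader = DataLoader(ToyRays(), batch_size=1, shuffle=False)   # mirrors get_testloader
idx, patch_idx, img, rayd, rayo = next(iter(train_loader))
print(img.shape, rayd.shape, rayo.shape)  # torch.Size([2, 8, 8, 3]) ... torch.Size([2, 3])
```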
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/configs/t2/Ignatius.yml:
--------------------------------------------------------------------------------
1 | index: "Ignatius"
2 | dataset:
3 | coord_scale: 50.0
4 | type: "t2"
5 | path: "./data/tanks_temples/Ignatius/"
6 | factor: 2
7 | patches:
8 | height: 180
9 | width: 180
10 | geoms:
11 | points:
12 | init_scale: [0.4, 0.4, 0.4]
13 | num: 5000
14 | models:
15 | attn:
16 | embed:
17 | k_L: [4, 4, 4]
18 | q_L: [4]
19 | v_L: [4, 4]
20 | value:
21 | skip_layers: [5]
22 | training:
23 | add_start: 10000
24 | add_stop: 80000
25 | add_num: 500
26 | lr:
27 | attn:
28 | type: "cosine"
29 | warmup: 5000
30 | points:
31 | base_lr: 9.0e-3
32 | points_influ_scores:
33 | type: "cosine"
34 | warmup: 5000
35 | feats:
36 | type: "cosine"
37 | warmup: 5000
38 | generator:
39 | type: "cosine"
40 | warmup: 5000
41 | eval:
42 | dataset:
43 | type: "t2"
44 | path: "./data/tanks_temples/Ignatius/"
45 | factor: 2
46 | img_idx: 0
47 | test:
48 | # load_path: "checkpoints/Ignatius.pth"
49 | datasets:
50 | - name: "testset"
51 | type: "t2"
52 | path: "./data/tanks_temples/Ignatius/"
53 | factor: 2
--------------------------------------------------------------------------------
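Several of the Ignatius learning rates above use `type: "cosine"` with a 5000-step warmup. The snippet below sketches the usual shape of such a schedule (linear warmup followed by cosine decay); the helper name, total step count, and floor value are illustrative and not necessarily identical to the scheduler implemented in this repo.

```python
import math

def cosine_with_warmup(step, base_lr, warmup=5000, total_steps=250000, min_lr=0.0):
    # Generic linear-warmup + cosine-decay curve; constants are illustrative.
    if step < warmup:
        return base_lr * step / max(warmup, 1)
    t = min((step - warmup) / max(total_steps - warmup, 1), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * t))

for s in (0, 2500, 5000, 125000, 250000):
    print(s, round(cosine_with_warmup(s, base_lr=9.0e-3), 6))
```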
/configs/t2/Caterpillar_exposure_control.yml:
--------------------------------------------------------------------------------
1 | index: "Caterpillar_exposure_control3"
2 | load_path: "./checkpoints_new/Caterpillar.pth"
3 | use_amp: false
4 | dataset:
5 | coord_scale: 30.0
6 | type: "t2"
7 | path: "./data/tanks_temples/Caterpillar/"
8 | factor: 2
9 | geoms:
10 | background:
11 | constant: 4.0
12 | exposure_control:
13 | use: true
14 | models:
15 | attn:
16 | embed:
17 | k_L: [4, 4, 4]
18 | q_L: [4]
19 | v_L: [4, 4]
20 | renderer:
21 | generator:
22 | small_unet:
23 | affine_layer: -1
24 | training:
25 | steps: 100000
26 | lr:
27 | lr_factor: 0.2
28 | attn:
29 | type: "none"
30 | warmup: 0
31 | points:
32 | base_lr: 0.0
33 | points_influ_scores:
34 | type: "none"
35 | warmup: 0
36 | feats:
37 | type: "none"
38 | warmup: 0
39 | generator:
40 | type: "none"
41 | warmup: 0
42 | eval:
43 | dataset:
44 | type: "t2"
45 | path: "./data/tanks_temples/Caterpillar/"
46 | factor: 2
47 | img_idx: 0
48 | test:
49 | load_path: "checkpoints/Caterpillar_exposure_control.pth"
50 | datasets:
51 | - name: "testset"
52 | type: "t2"
53 | path: "./data/tanks_temples/Caterpillar/"
54 | factor: 2
--------------------------------------------------------------------------------
/dataset/load_nerfsyn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import imageio
4 | from PIL import Image
5 | import json
6 |
7 |
8 | def load_blender_data(basedir, split='train', factor=1, read_offline=True):
9 | with open(os.path.join(basedir, f'transforms_{split}.json'), 'r') as fp:
10 | meta = json.load(fp)
11 |
12 | poses = []
13 | images = []
14 | image_paths = []
15 |
16 | for i, frame in enumerate(meta['frames']):
17 | img_path = os.path.abspath(os.path.join(basedir, frame['file_path'] + '.png'))
18 | poses.append(np.array(frame['transform_matrix']))
19 | image_paths.append(img_path)
20 |
21 | if read_offline:
22 | img = imageio.imread(img_path)
23 | H, W = img.shape[:2]
24 | if factor > 1:
25 | img = Image.fromarray(img).resize((W//factor, H//factor))
26 | images.append((np.array(img) / 255.).astype(np.float32))
27 | elif i == 0:
28 | img = imageio.imread(img_path)
29 | H, W = img.shape[:2]
30 | if factor > 1:
31 | img = Image.fromarray(img).resize((W//factor, H//factor))
32 | images.append((np.array(img) / 255.).astype(np.float32))
33 |
34 | poses = np.array(poses).astype(np.float32)
35 | images = np.array(images).astype(np.float32)
36 |
37 | H, W = images[0].shape[:2]
38 | camera_angle_x = float(meta['camera_angle_x'])
39 | focal = .5 * W / np.tan(.5 * camera_angle_x)
40 |
41 | return images, poses, [H, W, focal], image_paths
42 |
--------------------------------------------------------------------------------
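`load_blender_data` above derives the focal length from the pinhole relation `focal = 0.5 * W / tan(0.5 * camera_angle_x)`, using the width of the (possibly downsampled) image. A quick worked example; the 800 px width and the `camera_angle_x` value are the typical Blender-export numbers for the NeRF Synthetic scenes, shown here only for illustration.

```python
import numpy as np

W = 800                       # width of the rendered frame (after any resizing)
camera_angle_x = 0.6911112    # horizontal field of view in radians, from transforms_*.json
focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
print(focal)                  # ~1111.1 px; with factor=2 the resized W=400 gives ~555.6 px
```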
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import lpips
4 | from .model import PAPR
5 | from .lpips import LPNet
6 |
7 |
8 | class BasicLoss(nn.Module):
9 | def __init__(self, losses_and_weights):
10 | super(BasicLoss, self).__init__()
11 | self.losses_and_weights = losses_and_weights
12 |
13 | def forward(self, pred, target):
14 | loss = 0
15 | for name_and_weight, loss_func in self.losses_and_weights.items():
16 | name, weight = name_and_weight.split('/')
17 | cur_loss = loss_func(pred, target)
18 | loss += float(weight) * cur_loss
19 | # print(name, weight, cur_loss, loss)
20 | return loss
21 |
22 |
23 | def get_model(args, device='cuda'):
24 | return PAPR(args, device=device)
25 |
26 |
27 | def get_loss(args, bias=1.0):
28 | losses = nn.ModuleDict()
29 | for loss_name, weight in args.items():
30 | if weight > 0:
31 | if loss_name == "mse":
32 | losses[loss_name + "/" +
33 | str(format(weight, '.0e'))] = nn.MSELoss()
34 | print("Using MSE loss, loss weight: ", weight)
35 | elif loss_name == "l1":
36 | losses[loss_name + "/" +
37 | str(format(weight, '.0e'))] = nn.L1Loss()
38 | print("Using L1 loss, loss weight: ", weight)
39 |         elif loss_name == "lpips":
40 |             lpips_loss = LPNet()
41 |             lpips_loss.eval()
42 |             losses[loss_name + "/" + str(format(weight, '.0e'))] = lpips_loss
43 |             print("Using LPIPS loss, loss weight: ", weight)
44 |         elif loss_name == "lpips_alex":
45 |             lpips_loss = lpips.LPIPS()  # avoid shadowing the imported lpips module
46 |             lpips_loss.eval()
47 |             losses[loss_name + "/" + str(format(weight, '.0e'))] = lpips_loss
48 |             print("Using LPIPS AlexNet loss, loss weight: ", weight)
49 | else:
50 | raise NotImplementedError(
51 | 'loss [{:s}] is not supported'.format(loss_name))
52 | return BasicLoss(losses)
53 |
--------------------------------------------------------------------------------
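`get_loss` above encodes each weight into the `ModuleDict` key as `"name/weight"`, and `BasicLoss.forward` parses it back when summing the terms. The snippet below is a self-contained sketch of that convention using only MSE and L1 terms, so it runs without `vgg.pth` or the LPIPS networks.

```python
import torch
import torch.nn as nn

losses = nn.ModuleDict({
    "mse/1e+00": nn.MSELoss(),   # key format "<name>/<weight>", as produced by get_loss
    "l1/1e-01": nn.L1Loss(),
})

pred, target = torch.rand(1, 8, 8, 3), torch.rand(1, 8, 8, 3)
total = 0.0
for key, fn in losses.items():
    name, weight = key.split("/")
    total = total + float(weight) * fn(pred, target)   # same accumulation as BasicLoss.forward
print(total)
```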
/models/renderer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .unet import SmallUNet
3 | from .mlp import MLP
4 |
5 |
6 | class MLPGenerator(torch.nn.Module):
7 | def __init__(self, inp_dim=2, num_layers=3, num_channels=128, out_dim=2, act_type="leakyrelu", last_act_type="none",
8 | use_wn=True, a=1., b=1., trainable=False, skip_layers=[], bias=True, half_layers=[], residual_layers=[],
9 | residual_dims=[]):
10 | super(MLPGenerator, self).__init__()
11 | self.mlp = MLP(inp_dim=inp_dim, num_layers=num_layers, num_channels=num_channels, out_dim=out_dim,
12 | act_type=act_type, last_act_type=last_act_type, use_wn=use_wn, a=a, b=b, trainable=trainable,
13 | skip_layers=skip_layers, bias=bias, half_layers=half_layers, residual_layers=residual_layers,
14 | residual_dims=residual_dims)
15 |
16 | def forward(self, x, residuals=[], gamma=None, beta=None): # (N, C, H, W)
17 | return self.mlp(x.permute(0, 2, 3, 1), residuals).permute(0, 3, 1, 2)
18 |
19 |
20 |
21 | def get_generator(args, in_c, out_c, use_amp=False, amp_dtype=torch.float16):
22 | if args.type == "small-unet":
23 | opt = args.small_unet
24 | return SmallUNet(in_c, out_c, bilinear=opt.bilinear, single=opt.single, norm=opt.norm, last_act=opt.last_act,
25 | use_amp=use_amp, amp_dtype=amp_dtype, affine_layer=opt.affine_layer)
26 | elif args.type == "mlp":
27 | opt = args.mlp
28 | return MLPGenerator(inp_dim=in_c, num_layers=opt.num_layers, num_channels=opt.num_channels, out_dim=out_c,
29 | act_type=opt.act_type, last_act_type=opt.last_act_type, use_wn=opt.use_wn, a=opt.act_a, b=opt.act_b,
30 | trainable=opt.act_trainable, skip_layers=opt.skip_layers, bias=opt.bias, half_layers=opt.half_layers,
31 | residual_layers=opt.residual_layers, residual_dims=opt.residual_dims)
32 | else:
33 |         raise NotImplementedError(
34 |             'generator type [{:s}] is not supported'.format(args.type))
35 |
--------------------------------------------------------------------------------
/dataset/load_t2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import imageio
4 | from PIL import Image
5 |
6 | blender2opencv = np.array(
7 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
8 |
9 |
10 | def get_instrinsic(filepath):
11 | try:
12 | intrinsic = np.loadtxt(filepath).astype(np.float32)[:3, :3]
13 | return intrinsic
14 | except ValueError:
15 | pass
16 |
17 | # Get camera intrinsics
18 | with open(filepath, 'r') as file:
19 | f, cx, cy, _ = map(float, file.readline().split())
20 | fy = fx = f
21 |
22 | # Build the intrinsic matrices
23 | intrinsic = np.array([[fx, 0., cx],
24 | [0., fy, cy],
25 | [0., 0, 1]])
26 | return intrinsic
27 |
28 |
29 | def load_t2_data(basedir, factor=1, split="train", read_offline=True, tgtH=1280, tgtW=2176):
30 | colordir = os.path.join(basedir, "rgb")
31 | posedir = os.path.join(basedir, "pose")
32 | train_image_paths = [f for f in os.listdir(colordir) if os.path.isfile(
33 | os.path.join(colordir, f)) and f.startswith("0")]
34 | test_image_paths = [f for f in os.listdir(colordir) if os.path.isfile(
35 | os.path.join(colordir, f)) and f.startswith("1")]
36 |
37 | if split == "train":
38 | image_paths = train_image_paths
39 | elif split == "test":
40 | image_paths = test_image_paths
41 | else:
42 | raise ValueError("Unknown split: {}".format(split))
43 |
44 | image_paths = sorted(image_paths, key=lambda x: int(
45 | x.split(".")[0].split("_")[-1]))
46 |
47 | images = []
48 | poses = []
49 | out_image_paths = []
50 |
51 | intrinsic = get_instrinsic(os.path.join(basedir, "intrinsics.txt"))
52 | fx, _, cx = intrinsic[0]
53 | _, fy, cy = intrinsic[1]
54 |
55 | for i, img_path in enumerate(image_paths):
56 | image_path = os.path.abspath(os.path.join(colordir, img_path))
57 | out_image_paths.append(image_path)
58 |
59 | if read_offline:
60 | image = imageio.imread(image_path)
61 | H, W = image.shape[:2]
62 | if factor != 1:
63 | image = Image.fromarray(image).resize(
64 | (tgtW // factor, tgtH // factor))
65 | images.append((np.array(image) / 255.).astype(np.float32))
66 | elif i == 0:
67 | image = imageio.imread(image_path)
68 | H, W = image.shape[:2]
69 | if factor != 1:
70 | image = Image.fromarray(image).resize(
71 | (tgtW // factor, tgtH // factor))
72 | images.append((np.array(image) / 255.).astype(np.float32))
73 |
74 | pose_path = os.path.join(posedir, img_path.replace(".png", ".txt"))
75 | pose = np.loadtxt(pose_path).astype(np.float32)
76 | pose = pose @ blender2opencv
77 | poses.append(pose)
78 |
79 | images = np.stack(images, 0)
80 | poses = np.stack(poses, 0)
81 |
82 | realH, realW = images.shape[1:3]
83 | fx = fx * (realW / W)
84 | fy = fy * (realH / H)
85 |
86 | return images, poses, [realH, realW, fx, fy], out_image_paths
87 |
88 |
89 | if __name__ == '__main__':
90 | images, poses, [realH, realW, fx, fy], out_image_paths = load_t2_data('../data/tanks_temples/Family/', factor=1, split="train", read_offline=False, tgtH=1280, tgtW=2176)
91 | # print(out_image_paths)
92 | for i, path in enumerate(out_image_paths):
93 | # print(path)
94 | if '0_0063_00000100' in path:
95 | print(i, path)
96 |
--------------------------------------------------------------------------------
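`load_t2_data` above resizes every frame to `tgtW // factor` by `tgtH // factor` and then rescales the focal lengths from `intrinsics.txt` by the ratio between the resized and original sizes. A small numeric sketch of that rescaling; the source resolution and focal value are illustrative, not taken from a particular scene.

```python
W, H = 2176, 1280          # original frame size read from disk (illustrative)
fx = fy = 1160.0           # focal length from intrinsics.txt (illustrative)
factor = 2
realW, realH = 2176 // factor, 1280 // factor   # resize target used in load_t2_data
print(realW, realH, fx * (realW / W), fy * (realH / H))   # 1088 640 580.0 580.0
```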
/models/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn.utils import weight_norm
4 | from torch import autocast
5 | from .utils import activation_func
6 |
7 |
8 | def get_mapping_mlp(args, use_amp=False, amp_dtype=torch.float16):
9 | return MappingMLP(args.mapping_mlp, inp_dim=args.shading_code_dim, out_dim=args.mapping_mlp.out_dim, use_amp=use_amp, amp_dtype=amp_dtype)
10 |
11 |
12 | class MLP(nn.Module):
13 | def __init__(self, inp_dim=2, num_layers=3, num_channels=128, out_dim=2, act_type="leakyrelu", last_act_type="none",
14 | use_wn=True, a=1., b=1., trainable=False, skip_layers=[], bias=True, half_layers=[], residual_layers=[],
15 | residual_dims=[]):
16 | super(MLP, self).__init__()
17 | self.skip_layers = skip_layers
18 | self.residual_layers = residual_layers
19 | self.residual_dims = residual_dims
20 | assert len(residual_dims) == len(residual_layers)
21 | wn = weight_norm if use_wn else lambda x, **kwargs: x
22 | layers = [nn.Identity()]
23 | for i in range(num_layers):
24 | cur_inp = inp_dim if i == 0 else num_channels
25 | cur_out = out_dim if i == num_layers - 1 else num_channels
26 | if (i+1) in half_layers:
27 | cur_out = cur_out // 2
28 | if i in half_layers:
29 | cur_inp = cur_inp // 2
30 | if i in self.skip_layers:
31 | cur_inp += inp_dim
32 | if i in self.residual_layers:
33 | cur_inp += self.residual_dims[residual_layers.index(i)]
34 | layers.append(
35 | wn(nn.Linear(cur_inp, cur_out, bias=bias), name='weight'))
36 | layers.append(activation_func(act_type=act_type,
37 | num_channels=cur_out, a=a, b=b, trainable=trainable))
38 | layers[-1] = activation_func(act_type=last_act_type,
39 | num_channels=out_dim, a=a, b=b, trainable=trainable)
40 | assert len(layers) == 2 * num_layers + 1
41 | self.model = nn.ModuleList(layers)
42 |
43 | for p in self.model.parameters():
44 | if p.dim() > 1:
45 | nn.init.xavier_uniform_(p)
46 |
47 | def forward(self, x, residuals=[]):
48 | skip_layers = [i*2+1 for i in self.skip_layers]
49 | residual_layers = [i*2+1 for i in self.residual_layers]
50 | assert len(residuals) == len(self.residual_layers)
51 | # print(skip_layers)
52 | inp = x
53 | for i, layer in enumerate(self.model):
54 | if i in skip_layers:
55 | x = torch.cat([x, inp], dim=-1)
56 | if i in residual_layers:
57 | x = torch.cat([x, residuals[residual_layers.index(i)]], dim=-1)
58 | x = layer(x)
59 | return x
60 |
61 |
62 | class MappingMLP(nn.Module):
63 | def __init__(self, args, inp_dim=2, out_dim=2, use_amp=False, amp_dtype=torch.float16):
64 | super(MappingMLP, self).__init__()
65 | self.args = args
66 | self.inp_dim = inp_dim
67 | self.out_dim = out_dim
68 | self.use_amp = use_amp
69 | self.amp_dtype = amp_dtype
70 | self.model = MLP(inp_dim=inp_dim, num_layers=args.num_layers, num_channels=args.dim, out_dim=out_dim,
71 | act_type=args.act, last_act_type=args.last_act, use_wn=args.use_wn)
72 | print("Mapping MLP:\n", self.model)
73 |
74 | def forward(self, x):
75 |
76 | with autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp):
77 | out = self.model(x)
78 | return out
79 |
--------------------------------------------------------------------------------
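A short usage sketch of the `MLP` class above, assuming the papr environment is installed and the repo root is the working directory so `models.utils.activation_func` resolves. With `skip_layers=[2]`, the network input is re-concatenated before the third linear layer, which is why `forward` maps layer indices through `i * 2 + 1`.

```python
import torch
from models.mlp import MLP   # run from the papr repo root

net = MLP(inp_dim=16, num_layers=4, num_channels=64, out_dim=3,
          act_type="leakyrelu", last_act_type="none", use_wn=True,
          skip_layers=[2])    # input re-concatenated before the 3rd linear layer
x = torch.rand(2, 100, 16)    # the MLP acts on the last dimension
print(net(x).shape)           # torch.Size([2, 100, 3])
```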
/models/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models as tv
4 | from collections import namedtuple
5 | import os
6 |
7 |
8 | class vgg16(nn.Module):
9 | def __init__(self, requires_grad=False, pretrained=True):
10 | super(vgg16, self).__init__()
11 | vgg_pretrained_features = tv.vgg16(weights=tv.VGG16_Weights.IMAGENET1K_V1 if pretrained else None).features
12 | self.slice1 = nn.Sequential()
13 | self.slice2 = nn.Sequential()
14 | self.slice3 = nn.Sequential()
15 | self.slice4 = nn.Sequential()
16 | self.slice5 = nn.Sequential()
17 | self.N_slices = 5
18 | for x in range(4):
19 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
20 | for x in range(4, 9):
21 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
22 | for x in range(9, 16):
23 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
24 | for x in range(16, 23):
25 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
26 | for x in range(23, 30):
27 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
28 | if not requires_grad:
29 | for param in self.parameters():
30 | param.requires_grad = False
31 |
32 | def forward(self, x):
33 | h = self.slice1(x)
34 | h_relu1_2 = h
35 | h = self.slice2(h)
36 | h_relu2_2 = h
37 | h = self.slice3(h)
38 | h_relu3_3 = h
39 | h = self.slice4(h)
40 | h_relu4_3 = h
41 | h = self.slice5(h)
42 | h_relu5_3 = h
43 | vgg_outputs = namedtuple(
44 | "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
45 | out = vgg_outputs(h_relu1_2, h_relu2_2,
46 | h_relu3_3, h_relu4_3, h_relu5_3)
47 |
48 | return out
49 |
50 |
51 | class ScalingLayer(nn.Module):
52 | # For rescaling the input to vgg16
53 | def __init__(self):
54 | super(ScalingLayer, self).__init__()
55 | self.register_buffer('shift', torch.Tensor(
56 | [-.030, -.088, -.188])[None, :, None, None])
57 | self.register_buffer('scale', torch.Tensor(
58 | [.458, .448, .450])[None, :, None, None])
59 |
60 | def forward(self, inp):
61 | return (inp - self.shift) / self.scale
62 |
63 |
64 | def normalize_tensor(in_feat, eps=1e-10):
65 | norm_factor = torch.sqrt(
66 | torch.sum(in_feat ** 2, dim=1, keepdim=True) + eps)
67 | return in_feat / (norm_factor + eps)
68 |
69 |
70 | def spatial_average(in_tens, keepdim=True):
71 | return in_tens.mean([2, 3], keepdim=keepdim)
72 |
73 |
74 | class NetLinLayer(nn.Module):
75 | ''' A single linear layer used as placeholder for LPIPS learnt weights '''
76 |
77 | def __init__(self):
78 | super(NetLinLayer, self).__init__()
79 | self.weight = None
80 |
81 | def forward(self, inp):
82 | out = torch.sum(self.weight * inp, 1, keepdim=True)
83 | return out
84 |
85 |
86 | class LPNet(nn.Module):
87 | def __init__(self):
88 | super(LPNet, self).__init__()
89 |
90 | self.scaling_layer = ScalingLayer()
91 | self.net = vgg16(pretrained=True, requires_grad=False)
92 | self.L = 5
93 | self.lins = [NetLinLayer() for _ in range(self.L)]
94 | self.lins = nn.ModuleList(self.lins)
95 | model_path = os.path.abspath(
96 | os.path.join('.', 'vgg.pth'))
97 | print('Loading model from: %s' % model_path)
98 | weights = torch.load(model_path, map_location='cpu')
99 | for ll in range(self.L):
100 | self.lins[ll].weight = nn.Parameter(
101 | weights["lin%d.model.1.weight" % ll])
102 |
103 | def forward(self, in0, in1):
104 | in0 = in0.permute(0, 3, 1, 2)
105 | in1 = in1.permute(0, 3, 1, 2)
106 | in0 = 2 * in0 - 1
107 | in0_input = self.scaling_layer(in0)
108 | in1 = 2 * in1 - 1
109 | in1_input = self.scaling_layer(in1)
110 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
111 | feats0, feats1, diffs = {}, {}, {}
112 |
113 | for kk in range(self.L):
114 | feats0[kk], feats1[kk] = normalize_tensor(
115 | outs0[kk]), normalize_tensor(outs1[kk])
116 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
117 |
118 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
119 | for kk in range(self.L)]
120 |
121 | val = res[0]
122 | for ll in range(1, self.L):
123 | val += res[ll]
124 |
125 | return val.squeeze().mean()
126 |
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import numpy as np
4 | import imageio
5 | from PIL import Image
6 | from copy import deepcopy
7 | from .utils import load_meta_data, get_rays, extract_patches
8 |
9 |
10 | class RINDataset(Dataset):
11 | """ Ray Image Normal Dataset """
12 |
13 | def __init__(self, args, mode='train'):
14 | self.args = args
15 | images, c2w, H, W, focal_x, focal_y, image_paths = load_meta_data(args, mode=mode)
16 | num_imgs = len(image_paths)
17 |
18 | self.num_imgs = num_imgs
19 | coord_scale = args.coord_scale
20 | if coord_scale != 1:
21 | scaling_matrix = torch.tensor([[coord_scale, 0, 0, 0],
22 | [0, coord_scale, 0, 0],
23 | [0, 0, coord_scale, 0],
24 | [0, 0, 0, 1]], dtype=torch.float32)
25 | c2w = torch.matmul(scaling_matrix, c2w)
26 | print("c2w: ", c2w.shape)
27 |
28 | self.H = H
29 | self.W = W
30 | self.focal_x = focal_x
31 | self.focal_y = focal_y
32 | self.c2w = c2w # (N, 4, 4)
33 | self.image_paths = image_paths
34 | self.images = images # (N, H, W, C) or None
35 |
36 | if self.args.read_offline:
37 | rays_o, rays_d = get_rays(H, W, focal_x, focal_y, c2w)
38 | self.rayd = rays_d # (N, H, W, 3)
39 | self.rayo = rays_o # (N, 3)
40 |
41 | if self.args.extract_patch == True and self.args.extract_online == False and self.args.read_offline == True:
42 | img_patches, rayd_patches, rayo_patches, num_patches = extract_patches(images, rays_o, rays_d, args)
43 | # (N, n_patches, patch_height, patch_width, C) or None
44 | self.img_patches = img_patches
45 | # (N, n_patches, patch_height, patch_width, 3)
46 | self.rayd_patches = rayd_patches
47 | self.rayo_patches = rayo_patches # (N, n_patches, 3)
48 | self.num_patches = num_patches
49 |
50 | def _read_image_from_path(self, image_idx):
51 | image_path = self.image_paths[image_idx]
52 | image = imageio.imread(image_path)
53 | image = Image.fromarray(image).resize((self.W, self.H))
54 | image = (np.array(image) / 255.).astype(np.float32)
55 |
56 | if self.args.white_bg and image.shape[-1] == 4:
57 | image = image[..., :3] * image[..., -1:] + (1. - image[..., -1:])
58 | elif not self.args.white_bg:
59 | image = image[..., :3]
60 | mask = image.sum(-1) == 3.0
61 | image[mask] = 0.
62 |
63 | image = torch.from_numpy(image).float()
64 |
65 | rayo, rayd = get_rays(self.H, self.W, self.focal_x, self.focal_y, self.c2w[image_idx:image_idx+1])
66 |
67 | return image, rayo, rayd
68 |
69 | def __len__(self):
70 | if self.args.extract_patch == True and self.args.extract_online == False and self.args.read_offline == True:
71 | return self.num_imgs * self.num_patches
72 | else:
73 | return self.num_imgs
74 |
75 | def __getitem__(self, idx):
76 | if self.args.extract_patch == True and self.args.extract_online == False and self.args.read_offline == True:
77 | img_idx = idx // self.num_patches
78 | patch_idx = idx % self.num_patches
79 | return img_idx, patch_idx, \
80 | self.img_patches[img_idx, patch_idx] if self.img_patches is not None else 0, \
81 | self.rayd_patches[img_idx, patch_idx], \
82 | self.rayo_patches[img_idx, patch_idx]
83 |
84 | elif self.args.extract_patch == True and self.args.extract_online == True:
85 | img_idx = idx
86 | args = self.args
87 | args.patches.max_patches = 1
88 | if self.args.read_offline:
89 | img_patches, rayd_patches, rayo_patches, _ = extract_patches(self.images[img_idx:img_idx+1],
90 | self.rayo[img_idx:img_idx+1],
91 | self.rayd[img_idx:img_idx+1],
92 | args)
93 | else:
94 | image, rayo, rayd = self._read_image_from_path(img_idx)
95 | img_patches, rayd_patches, rayo_patches, _ = extract_patches(image[None, ...], rayo, rayd, args)
96 |
97 | return img_idx, 0, \
98 | img_patches[0, 0] if img_patches is not None else 0, \
99 | rayd_patches[0, 0], \
100 | rayo_patches[0, 0]
101 | else:
102 | if self.args.read_offline:
103 | return idx, 0, self.images[idx] if self.images is not None else 0, \
104 | self.rayd[idx], self.rayo[idx]
105 | else:
106 | image, rayo, rayd = self._read_image_from_path(idx)
107 | return idx, 0, image, rayd.squeeze(0), rayo.squeeze(0)
108 |
109 | def get_full_img(self, img_idx):
110 | if self.args.read_offline:
111 | return self.images[img_idx].unsqueeze(0) if self.images is not None else None, \
112 | self.rayd[img_idx].unsqueeze(0), self.rayo[img_idx].unsqueeze(0)
113 | else:
114 | image, rayo, rayd = self._read_image_from_path(img_idx)
115 | return image[None, ...], rayd, rayo
116 |
117 | def get_c2w(self, img_idx):
118 | return self.c2w[img_idx]
119 |
120 | def get_new_rays(self, c2w):
121 | return get_rays(self.H, self.W, self.focal_x, self.focal_y, c2w)
122 |
--------------------------------------------------------------------------------
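When offline patch extraction is enabled, `__len__` above reports `num_imgs * num_patches` and `__getitem__` decodes the flat index back into an (image, patch) pair. A tiny sketch of that mapping with illustrative sizes:

```python
num_imgs, num_patches = 100, 4   # illustrative sizes
for idx in (0, 5, 399):
    img_idx, patch_idx = idx // num_patches, idx % num_patches
    print(idx, "->", (img_idx, patch_idx))   # 0 -> (0, 0), 5 -> (1, 1), 399 -> (99, 3)
```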
/README.md:
--------------------------------------------------------------------------------
1 | # PAPR: Proximity Attention Point Rendering (NeurIPS 2023 Spotlight 🤩)
2 | [Yanshu Zhang*](https://zvict.github.io/), [Shichong Peng*](https://sites.google.com/view/niopeng/home), [Alireza Moazeni](https://amoazeni75.github.io/), [Ke Li](https://www.sfu.ca/~keli/) (* denotes equal contribution)
3 |
4 | 
5 |
6 | [Project Sites](https://zvict.github.io/papr)
7 | | [Paper](https://arxiv.org/abs/2307.11086) |
8 | Primary contact: [Yanshu Zhang](https://zvict.github.io/)
9 |
10 | Proximity Attention Point Rendering (PAPR) is a new method for joint novel view synthesis and 3D reconstruction. It simultaneously learns from scratch an accurate point cloud representation of the scene surface, and an attention-based neural network that renders the point cloud from novel views.
11 |
12 |
13 |
14 | [](https://youtu.be/1atBGH_pDHY)
15 |
16 | ## BibTeX
17 | PAPR: Proximity Attention Point Rendering.
18 | ```
19 | @inproceedings{zhang2023papr,
20 | title={PAPR: Proximity Attention Point Rendering},
21 | author={Yanshu Zhang and Shichong Peng and Seyed Alireza Moazenipourasil and Ke Li},
22 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
23 | year={2023}
24 | }
25 | ```
26 |
27 | ## Installation
28 | ```
29 | git clone git@github.com:zvict/papr.git # or 'git clone https://github.com/zvict/papr'
30 | cd papr
31 | conda env create -f papr.yml
32 | conda activate papr
33 | ```
34 | Or use a virtual environment with `python=3.9`:
35 | ```
36 | python -m venv path/to/<env_name>
37 | source path/to/<env_name>/bin/activate
38 | pip install -r requirements.txt
39 | ```
40 |
41 | ## Data Preparation
42 |
43 | Expected dataset structure in the source path location:
44 | ```
45 | papr
46 | ├── data
47 | │ ├── nerf_synthetic
48 | │ │ ├── chair
49 | │ │ │ ├── train
50 | │ │ │ ├── val
51 | │ │ │ ├── test
52 | │ │ │ ├── transforms_train.json
53 | │ │ │ ├── transforms_val.json
54 | │ │ │ ├── transforms_test.json
55 | │ │ ├── ...
56 | │ ├── tanks_temples
57 | │ │ ├── Barn
58 | │ │ │ ├── pose
59 | │ │ │ ├── rgb
60 | │ │ │ ├── intrinsics.txt
61 | │ │ ├── ...
62 | ```
63 | ### NeRF Synthetic
64 | Download NeRF Synthetic Dataset from [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) and put it under `data/nerf_synthetic/`
65 |
66 |
67 | ### Tanks & Temples
68 | Download [Tanks&Temples](https://www.tanksandtemples.org/) from [here](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip) and put it under:
69 | `data/tanks_temples/`
70 |
71 | ### Use your own data
72 | You can refer to this [issue](https://github.com/zvict/papr/issues/3#issuecomment-1907260683) for the instructions on how to prepare the dataset.
73 |
74 | You need to create a new configuration file for your own dataset, and put it under `configs`. The parameter `dataset.type` in the configuration file specifies the type of the dataset. If your dataset is in the same format as the NeRF Synthetic dataset, you can directly set `dataset.type` to `"synthetic"`. Otherwise, you need to implement your own python script to load the dataset under the `dataset` folder, and add it in the function `load_meta_data` in `dataset/utils.py`.
75 |
76 | Most default parameters in `configs/default.yml` are general and can be used for your own dataset. You can specify the parameters that are specific to your dataset in the configuration file you created, similar to the configuration files for the NeRF Synthetic dataset and the Tanks and Temples dataset.
77 |
78 | ## Overview
79 |
80 | The codebase has two main components: the data loading code in `dataset/` and the models in `models/`. The class `PAPR` in `models/model.py` defines our main model. All the configurations are in `configs/`, and `configs/demo.yml` is a demo configuration with comments on the important arguments.
81 |
82 | We provide a notebook `demo.ipynb` to demonstrate how to train and test the model with the demo configuration file, as well as how to use exposure control to improve the rendering quality of real-world scenes captured with auto-exposure turned on.
83 |
84 | ## Training
85 | ```
86 | python train.py --opt configs/nerfsyn/chair.yml
87 | ```
88 |
89 | ## Finetuning with [cIMLE](https://arxiv.org/abs/2004.03590) (Optional)
90 |
91 | For real-world scenes where exposure can change between views, we can introduce an additional latent code input into our model and finetune the model using a technique called [conditional Implicit Maximum Likelihood Estimation (cIMLE)](https://arxiv.org/abs/2004.03590) to control the exposure level of the rendered image, as described in Section 4.4 and Appendix A.8 of the paper. Finetuning with exposure control requires a pre-trained model, which you can obtain by running `train.py` with the default configurations. We provide a demo configuration file for the Caterpillar scene from the Tanks and Temples dataset at `configs/t2/Caterpillar_exposure_control.yml`.
92 |
93 | To finetune a pre-trained model with exposure control, run:
94 | ```
95 | python exposure_control_finetune.py --opt configs/t2/Caterpillar_exposure_control.yml
96 | ```
97 |
98 | ## Evaluation
99 | To evaluate your trained model without the finetuning for exposure control, run:
100 | ```
101 | python test.py --opt configs/nerfsyn/chair.yml
102 | ```
103 | This gives you rendered images and metrics on the test set.
104 |
105 | With a finetuned model, you can render all the test views with a single random exposure level by running:
106 | ```
107 | python test.py --opt configs/t2/Caterpillar_exposure_control.yml --exp
108 | ```
109 | To generate images with different random exposure levels for a single view, run:
110 | ```
111 | python test.py --opt configs/t2/Caterpillar_exposure_control.yml --exp --random --view 0
112 | ```
113 | Note that during testing, the scale of the latent codes should be increased to generate images with more diverse exposures, for example,
114 | ```
115 | python test.py --opt configs/t2/Caterpillar_exposure_control.yml --exp --random --view 0 --scale 8
116 | ```
117 | Once you have generated images with different exposure levels, you can interpolate between two picked exposure levels by specifying their indices, for example:
118 | ```
119 | python test.py --opt configs/t2/Caterpillar_exposure_control.yml --exp --intrp --view 0 --start_index 0 --end_index 1
120 | ```
121 |
122 | ## Pretrained Models
123 |
124 | We provide pretrained models on NeRF Synthetic and Tanks&Temples datasets here (without finetuning): [Google Drive](https://drive.google.com/drive/folders/1HSNlMu6Uup9o5hqi7T0hgDf63yR9W90s?usp=sharing). We also provide a pre-trained model with exposure control on the Caterpillar scene in the Google Drive. To load the pretrained models, please put them under `checkpoints/`, and change the `test.load_path` in the config file.
125 |
126 | ## Acknowledgement
127 | This research was enabled in part by support provided by NSERC, the BC DRI Group and the Digital Research Alliance of Canada.
128 |
--------------------------------------------------------------------------------
/dataset/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import math
4 | from .load_t2 import load_t2_data
5 | from .load_nerfsyn import load_blender_data
6 |
7 |
8 | def cam_to_world(coords, c2w, vector=True):
9 | """
10 | coords: [N, H, W, 3] or [H, W, 3] or [K, 3]
11 | c2w: [N, 4, 4] or [4, 4]
12 | """
13 | if vector: # Convert to homogeneous coordinates
14 | coords = torch.cat([coords, torch.zeros_like(coords[..., :1])], -1)
15 | else:
16 | coords = torch.cat([coords, torch.ones_like(coords[..., :1])], -1)
17 |
18 | if coords.ndim == 5:
19 | assert c2w.ndim == 2
20 | B, H, W, N, _ = coords.shape
21 | transformed_coords = torch.sum(
22 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 1, 1, 4, 4), -1) # [B, H, W, N, 3]
23 | elif coords.ndim == 4:
24 | assert c2w.ndim == 3
25 | _, H, W, _ = coords.shape
26 | N = c2w.shape[0]
27 | transformed_coords = torch.sum(
28 | coords.unsqueeze(-2) * c2w.reshape(N, 1, 1, 4, 4), -1) # [N, H, W, 4]
29 | elif coords.ndim == 3:
30 | assert c2w.ndim == 2
31 | H, W, _ = coords.shape
32 | transformed_coords = torch.sum(
33 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 4, 4), -1) # [H, W, 4]
34 | elif coords.ndim == 2:
35 | assert c2w.ndim == 2
36 | K, _ = coords.shape
37 | transformed_coords = torch.sum(
38 | coords.unsqueeze(-2) * c2w.reshape(1, 4, 4), -1) # [K, 4]
39 | else:
40 | raise ValueError('Wrong dimension of coords')
41 | return transformed_coords[..., :3]
42 |
43 |
44 | def world_to_cam(coords, c2w, vector=True):
45 | """
46 | coords: [N, H, W, 3] or [H, W, 3] or [K, 3]
47 | c2w: [N, 4, 4] or [4, 4]
48 | """
49 | if vector: # Convert to homogeneous coordinates
50 | coords = torch.cat([coords, torch.zeros_like(coords[..., :1])], -1)
51 | else:
52 | coords = torch.cat([coords, torch.ones_like(coords[..., :1])], -1)
53 |
54 | c2w = torch.inverse(c2w)
55 | if coords.ndim == 5:
56 | assert c2w.ndim == 2
57 | B, H, W, N, _ = coords.shape
58 | transformed_coords = torch.sum(
59 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 1, 1, 4, 4), -1) # [B, H, W, N, 3]
60 | elif coords.ndim == 4:
61 | assert c2w.ndim == 3
62 | _, H, W, _ = coords.shape
63 | N = c2w.shape[0]
64 | transformed_coords = torch.sum(
65 | coords.unsqueeze(-2) * c2w.reshape(N, 1, 1, 4, 4), -1) # [N, H, W, 4]
66 | elif coords.ndim == 3:
67 | assert c2w.ndim == 2
68 | H, W, _ = coords.shape
69 | transformed_coords = torch.sum(
70 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 4, 4), -1) # [H, W, 4]
71 | elif coords.ndim == 2:
72 | assert c2w.ndim == 2
73 | K, _ = coords.shape
74 | transformed_coords = torch.sum(
75 | coords.unsqueeze(-2) * c2w.reshape(1, 4, 4), -1) # [K, 4]
76 | else:
77 | raise ValueError('Wrong dimension of coords')
78 | return transformed_coords[..., :3]
79 |
80 |
81 | def get_rays(H, W, focal_x, focal_y, c2w, fineness=1):
82 | N = c2w.shape[0]
83 | width = torch.linspace(
84 | 0, W / focal_x, steps=int(W / fineness) + 1, dtype=torch.float32)
85 | height = torch.linspace(
86 | 0, H / focal_y, steps=int(H / fineness) + 1, dtype=torch.float32)
87 | y, x = torch.meshgrid(height, width, indexing='ij')
88 | pixel_size_x = width[1] - width[0]
89 | pixel_size_y = height[1] - height[0]
90 | x = (x - W / focal_x / 2 + pixel_size_x / 2)[:-1, :-1]
91 | y = -(y - H / focal_y / 2 + pixel_size_y / 2)[:-1, :-1]
92 | # [H, W, 3], vectors, since the camera is at the origin
93 | dirs_d = torch.stack([x, y, -torch.ones_like(x)], -1)
94 | rays_d = cam_to_world(dirs_d.unsqueeze(0), c2w) # [N, H, W, 3]
95 | rays_o = c2w[:, :3, -1] # [N, 3]
96 | return rays_o, rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
97 |
98 |
99 | def extract_patches(imgs, rays_o, rays_d, args):
100 | patch_opt = args.patches
101 | N, H, W, C = imgs.shape
102 |
103 | num_patches = patch_opt.max_patches
104 | rayd_patches = np.zeros((N, num_patches, patch_opt.height, patch_opt.width, 3), dtype=np.float32)
105 | rayo_patches = np.zeros((N, num_patches, 3), dtype=np.float32)
106 | img_patches = np.zeros((N, num_patches, patch_opt.height, patch_opt.width, C), dtype=np.float32)
107 |
108 | for i in range(N):
109 | for n_patch in range(num_patches):
110 | start_height = np.random.randint(0, H - patch_opt.height)
111 | start_width = np.random.randint(0, W - patch_opt.width)
112 | end_height = start_height + patch_opt.height
113 | end_width = start_width + patch_opt.width
114 | rayd_patches[i, n_patch, :, :] = rays_d[i, start_height:end_height, start_width:end_width]
115 | rayo_patches[i, n_patch, :] = rays_o[i, :]
116 | img_patches[i, n_patch, :, :] = imgs[i, start_height:end_height, start_width:end_width]
117 |
118 | return img_patches, rayd_patches, rayo_patches, num_patches
119 |
120 |
121 | def load_meta_data(args, mode="train"):
122 | """
123 | 0 -----------> W
124 | |
125 | |
126 | |
127 | ⬇
128 | H
129 | [H, W, 4]
130 | """
131 | image_paths = None
132 |
133 | if args.type == "synthetic":
134 | images, poses, hwf, image_paths = load_blender_data(
135 | args.path, split=mode, factor=args.factor, read_offline=args.read_offline)
136 | print('Loaded blender', images.shape, hwf, args.path)
137 |
138 | H, W, focal = hwf
139 | hwf = [H, W, focal, focal]
140 |
141 | if args.white_bg:
142 | images = images[..., :3] * \
143 | images[..., -1:] + (1. - images[..., -1:])
144 | else:
145 | images = images[..., :3]
146 |
147 | elif args.type == "t2":
148 | images, poses, hwf, image_paths = load_t2_data(
149 | args.path, factor=args.factor, split=mode, read_offline=args.read_offline)
150 | print('Loaded t2', images.shape, hwf, args.path,
151 | images.min(), images.max(), images[0, 10, 10, :])
152 |
153 | if args.white_bg and images.shape[-1] == 4:
154 | images = images[..., :3] * \
155 | images[..., -1:] + (1. - images[..., -1:])
156 | elif not args.white_bg:
157 | images = images[..., :3]
158 | mask = images.sum(-1) == 3.0
159 | images[mask] = 0.
160 |
161 | else:
162 | raise ValueError("Unknown dataset type: {}".format(args.type))
163 |
164 | H, W, focal_x, focal_y = hwf
165 |
166 | images = torch.from_numpy(images).float()
167 | poses = torch.from_numpy(poses).float()
168 |
169 | return images, poses, H, W, focal_x, focal_y, image_paths
170 |
171 |
172 | def rgb2norm(img):
173 | norm_vec = np.stack([img[..., 0] * 2.0 / 255.0 - 1.0,
174 | img[..., 1] * 2.0 / 255.0 - 1.0,
175 | img[..., 2] * 2.0 / 255.0 - 1.0,
176 | img[..., 3] / 255.0], axis=-1)
177 | return norm_vec
178 |
--------------------------------------------------------------------------------
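As the README and the demo notebook note, supporting a new dataset type means writing a loader under `dataset/` and wiring it into `load_meta_data` above. The stub below is a hypothetical example of the interface such a loader needs to satisfy; the `"custom"` type name, `load_custom_data`, and every shape or value in it are placeholders.

```python
import numpy as np

def load_custom_data(basedir, factor=1, split="train", read_offline=True):
    # Hypothetical loader: return images (N, H, W, 3), camera-to-world poses (N, 4, 4),
    # [H, W, focal_x, focal_y], and the list of image paths.
    N, H, W = 2, 64, 64
    images = np.random.rand(N, H, W, 3).astype(np.float32)
    poses = np.tile(np.eye(4, dtype=np.float32), (N, 1, 1))
    image_paths = [f"{basedir}/rgb/{i:04d}.png" for i in range(N)]
    return images, poses, [H, W, 60.0, 60.0], image_paths

# Inside load_meta_data, a new branch would dispatch on the config's dataset.type, e.g.:
#
#     elif args.type == "custom":
#         images, poses, hwf, image_paths = load_custom_data(
#             args.path, factor=args.factor, split=mode, read_offline=args.read_offline)

imgs, poses, hwf, paths = load_custom_data("./data/my_scene")
print(imgs.shape, poses.shape, hwf)
```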
/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 0. Environment setup and dataset preparation\n",
8 | "\n",
9 | "Please follow the instructions in the README file to setup the environment and prepare the dataset."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "# 1. Configurations\n",
17 | "\n",
18 | "We povide the default values of the parameters in `configs/default.yml`, with comments explaining the meaning of most parameters. For each individual scene, we can override the default values by providing a separate configuration file in the `configs` directory. For example, the configuration file for the Lego scene from the NeRF Synthetic dataset is `configs/nerfsyn/lego.yml`. "
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "# 2. Novel view synthesis using PAPR\n",
26 | "\n",
27 | "## 2.1 Training\n",
28 | "\n",
29 | "To synthesize novel views of a scene, we need to train a PAPR model. We provide the configuration files for each scene in the NeRF Synthetic dataset and the Tanks and Temples dataset under the `configs` directory. \n",
30 | "\n",
31 | "The configuration files specify the training parameters and the dataset paths. Detailed comments are provided in `configs/demo.yml` to help you understand the parameters.\n",
32 | "\n",
33 | "To train a PAPR model, on the lego scene of NeRF Synthetic dataset for example, run the following command:"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": null,
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "!python train.py --opt configs/nerfsyn/lego.yml"
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "metadata": {},
48 | "source": [
49 | "You can find the trained model in the `experiments` folder. You can specify the folder name using the `index` parameter in the configuration files. You can also find the training logs and middle plots in the folder during training.\n",
50 | "\n",
51 | "## 2.2 Testing\n",
52 | "\n",
53 | "Once the training is finished, you can synthesize images from the test views by running:"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "!python test.py --opt configs/nerfsyn/lego.yml"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "metadata": {},
68 | "source": [
69 | "The generated images and evaluation metrics will be saved under `experiments/{index}/test`."
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {},
75 | "source": [
76 | "# 3. Improve the performance for in-the-wild data with exposure control\n",
77 | "\n",
78 | "Whenever people take real life photos with auto-exposure turned on, the exposure of the images can vary significantly across different views, for example the Tanks and Temples dataset. To improve the performance of PAPR on in-the-wild data, we propose a simple yet effective method to control the exposure of the rendered images. We introduce an additional latent code input into our model and finetune it using a technique called [conditional Implicit Maximum Likelihood Estimation (cIMLE)](https://arxiv.org/abs/2004.03590). During test time, we can manipulate the exposure of the rendered image by changing the latent code input.\n",
79 | "\n",
80 | "One can refer to Section 4.4 and Appendix A.8 in the paper for more details.\n",
81 | "\n",
82 | "## 3.1 Finetuning\n",
83 | "\n",
84 | "To finetune the model with exposure control, we need a pretrained model generated by following the instructions in [Section 1.1](##-1.1-Training). We provide a exposure control configuration file for the Caterpillar scene from the Tanks and Temples dataset in `configs/t2/Caterpillar_exposure_control.yml`. You need to specify the path to the pretrained model by setting the `load_path` parameter in the configuration file.\n",
85 | "\n",
86 | "To finetune the model with exposure control, run the following command:\n",
87 | "\n"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "!python exposure_control_finetune.py --opt configs/t2/Caterpillar_exposure_control.yml"
97 | ]
98 | },
99 | {
100 | "cell_type": "markdown",
101 | "metadata": {},
102 | "source": [
103 | "## 3.2 Testing\n",
104 | "\n",
105 | "To render all the test views with a single random exposure level, run:"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "!python test.py --opt configs/t2/Caterpillar_exposure_control.yml --exp"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "metadata": {},
120 | "source": [
121 | "## 3.3 Generating images with different exposure levels\n",
122 | "\n",
123 | "You can generate images with different exposure levels by changing the latent code input for a given view. We recommend to increase the `shading_code_scale` to generate images with more diverse exposure levels. You can either change the parameter in the configuration file or specify it in the command line. For example, to generate images with different exposure levels for the Caterpillar scene, run:"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "!python test.py --opt configs/t2/Caterpillar_exposure_control.yml --exp --random --scale 8.0 --num_samples 20 --view 0"
133 | ]
134 | },
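{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why does a larger `shading_code_scale` give more varied exposures? The shading codes are drawn as Gaussian noise multiplied by this scale (see how the shading codes are initialized in `exposure_control_finetune.py`), so a larger scale spreads the sampled codes, and hence the resulting exposure levels, further apart. The next cell is only a quick standalone illustration of that spread; it is not part of the rendering pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"code_dim, num_samples = 128, 20\n",
"\n",
"# Codes are Gaussian noise times shading_code_scale, as in exposure_control_finetune.py.\n",
"for scale in (1.0, 8.0):\n",
"    codes = torch.randn(num_samples, code_dim) * scale\n",
"    print(f\"scale={scale}: mean code norm {codes.norm(dim=-1).mean().item():.1f}\")"
]
},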
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {},
138 | "source": [
139 | "## 3.4 Generating images with continuous exposure changes"
140 | ]
141 | },
142 | {
143 | "cell_type": "markdown",
144 | "metadata": {},
145 | "source": [
146 | "After you generate the images with different exposure levels, you can interpolate the latent codes to generate images with intermediate exposure levels. For example, to generate images with intermediate exposure levels for the Caterpillar scene, run:"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "metadata": {},
153 | "outputs": [],
154 | "source": [
155 | "!python test.py --opt configs/t2/Caterpillar_exposure_control.yml --exp --intrp --scale 8 --num_samples 20 --view 0 --start_index 8 --end_index 11 --num_intrp 10"
156 | ]
157 | },
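{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, the interpolation takes the latent codes of two of the generated samples (indices 8 and 11 here, matching `--start_index` and `--end_index` above) and blends them with weights sweeping from 0 to 1, rendering one image per blended code. The next cell is a minimal sketch assuming simple linear blending between the two codes; the exact scheme used by `test.py` may differ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"code_dim, scale = 128, 8.0\n",
"codes = torch.randn(20, code_dim) * scale        # stand-in for the 20 sampled codes\n",
"start, end, num_intrp = codes[8], codes[11], 10  # --start_index 8 --end_index 11 --num_intrp 10\n",
"\n",
"# Each blended code would be fed to the model to render one frame of a smooth exposure sweep.\n",
"weights = torch.linspace(0.0, 1.0, num_intrp)\n",
"intrp_codes = torch.stack([(1 - w) * start + w * end for w in weights])\n",
"print(intrp_codes.shape)  # (num_intrp, code_dim)"
]
},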
158 | {
159 | "cell_type": "markdown",
160 | "metadata": {},
161 | "source": [
162 | "# 4. Use PAPR for your own dataset\n",
163 | "\n",
164 | "## 4.1 Prepare the dataset\n",
165 | "\n",
166 | "You can refer to this [issue](https://github.com/zvict/papr/issues/3#issuecomment-1907260683) for the instructions on how to prepare the dataset.\n",
167 | "\n",
168 | "## 4.2 Configuration\n",
169 | "\n",
170 | "You need to create a new configuration file for your own dataset, and put it under `configs`. The parameter `dataset.type` in the configuration file specifies the type of the dataset. If your dataset is in the same format as the NeRF Synthetic dataset, you can directly set `dataset.type` to `\"synthetic\"`. Otherwise, you need to implement your own python script to load the dataset under the `dataset` folder, and add it in the function `load_meta_data` in `dataset/utils.py`.\n",
171 | "\n",
172 | "Most default parameters in `configs/default.yml` are general and can be used for your own dataset. You can specify the parameters that are specific to your dataset in the configuration file you created, similar to the configuration files for the NeRF Synthetic dataset and the Tanks and Temples dataset."
173 | ]
174 | }
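,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training and testing scripts load `configs/default.yml`, deep-copy it, and then recursively merge your scene configuration on top of it (see, e.g., the `__main__` block of `exposure_control_finetune.py`), which is why your file only needs the keys that differ from the defaults. The next cell is a small self-contained sketch of that merge; `merge` is a local re-implementation of the idea behind `update_dict`, and `my_scene` is a hypothetical example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"import yaml\n",
"\n",
"# Load the shared defaults, then overlay only the scene-specific overrides.\n",
"with open(\"configs/default.yml\") as f:\n",
"    default_cfg = yaml.safe_load(f)\n",
"\n",
"override = {  # what a hypothetical configs/my_scene.yml could contain\n",
"    \"index\": \"my_scene\",\n",
"    \"dataset\": {\"type\": \"synthetic\", \"path\": \"./data/my_scene/\"},\n",
"}\n",
"\n",
"def merge(base, new):  # recursive dict merge, same idea as update_dict in utils.py\n",
"    for k, v in new.items():\n",
"        if isinstance(v, dict) and isinstance(base.get(k), dict):\n",
"            merge(base[k], v)\n",
"        else:\n",
"            base[k] = v\n",
"\n",
"cfg = copy.deepcopy(default_cfg)\n",
"merge(cfg, override)\n",
"print(cfg[\"index\"], cfg[\"dataset\"][\"type\"], cfg[\"dataset\"][\"path\"])"
]
}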
175 | ],
176 | "metadata": {
177 | "kernelspec": {
178 | "display_name": "base",
179 | "language": "python",
180 | "name": "python3"
181 | },
182 | "language_info": {
183 | "codemirror_mode": {
184 | "name": "ipython",
185 | "version": 3
186 | },
187 | "file_extension": ".py",
188 | "mimetype": "text/x-python",
189 | "name": "python",
190 | "nbconvert_exporter": "python",
191 | "pygments_lexer": "ipython3",
192 | "version": "3.12.2"
193 | }
194 | },
195 | "nbformat": 4,
196 | "nbformat_minor": 2
197 | }
198 |
--------------------------------------------------------------------------------
/configs/default.yml:
--------------------------------------------------------------------------------
1 | index: "lego" # this is the name of the experiment
2 | load_path: "" # path to load the model from
3 | save_dir: "./experiments" # folder to save the model to
4 | seed: 1 # random seed
5 | eps: 1.0e-6 # epsilon for numerical stability
6 | use_amp: true # use automatic mixed precision
7 | amp_dtype: "float16"
8 | scaler_min_scale: -1.0 # lower bound on the AMP grad scaler, for numerical stability; values <= 0 disable the bound
9 | max_num_pts: 30000 # upper bound of total number of points, set to -1 to disable
10 | dataset:
11 | mode: "train" # train or test set
12 | coord_scale: 10.0 # scale the global coordinates by this factor; a larger scale helps recover geometric detail in the point cloud, but larger is not always better
13 | type: "synthetic" # synthetic (nerf synthetic) or t2 (tanks and temples)
14 | white_bg: true # use white or black background
15 | path: "./data/nerf_synthetic/lego/"
16 | factor: 1 # downsample the target images by this factor
17 | batch_size: 1
18 | shuffle: true
19 | extract_patch: true # extract patches from the target images
20 | extract_online: true # set to false if you have enough memory to load all the patches into memory before training
21 | read_offline: false # set to true if you have enough memory to load all the images into memory before training
22 | patches:
23 | height: 160 # patch height
24 | width: 160 # patch width
25 | max_patches: 10 # maximum number of patches to extract from each image (if extract_online is false)
26 | geoms:
27 | points:
28 | select_k: 20 # number of top-k nearby points to select from the point cloud for each ray
29 | select_k_type: "d2r" # select k points based on their distances to the rays (d2r)
30 | select_k_sorted: false # sort the indices of selected points by select_k_type
31 | load_path: "" # path to load the point cloud from
32 | init_type: "cube" # initialize the point cloud in a cube or on a sphere ("sphere")
33 | init_scale: [1.2, 1.2, 1.2] # initial scale of the point cloud, normalized by the coord_scale
34 | init_center: [0.0, 0.0, 0.0] # initial center of the point cloud, normalized by the coord_scale
35 | init_num: 3000 # initial number of points in the point cloud
36 | influ_init_val: 0.0 # initial influence score of each point
37 | add_type: "random" # add new points by interpolating add_k randomly chosen points (or "mean", ...)
38 | add_k: 3 # number of points to interpolate when adding new points
39 | add_sample_type: "top-knn-std" # determine where to add new points by sparsity ("top-knn-std") or ...
40 | add_sample_k: 10 # number of points to consider when measuring sparsity or ...
41 | background:
42 | learnable: false # learn the background color or not
43 | init_color: [1.0, 1.0, 1.0] # initial background color, normalized by 255
44 | constant: 5.0 # constant background score
45 | point_feats:
46 | dim: 64 # dimension of the point features
47 | use_inv: true # use as a value feature in the attention layer
48 | use_ink: false # use as a key feature in the attention layer
49 | use_inq: false # use as a query feature in the attention layer
50 | exposure_control:
51 | use: false
52 | shading_code_dim: 128 # dimension of the shading codes
53 | shading_code_scale: 1.0 # scale the shading codes by this factor, use 8.0 for random generation
54 | shading_code_num_samples: 20 # number of samples to generate the shading codes
55 | shading_code_resample_iter: 10000 # number of iterations to resample the shading codes
56 | shading_code_resample_size: 200 # number of samples to resample the shading codes
57 | shading_code_resample_select_by: "psnr" # select the best samples by "psnr" or "lpips"
58 | mapping_mlp:
59 | num_layers: 8
60 | dim: 256
61 | act: "relu"
62 | last_act: "relu+1"
63 | use_wn: false
64 | out_dim: 64
65 | models:
66 | use_renderer: true # whether to use the UNet renderer; if false, the predicted RGB is the fused output of the value embedding MLP
67 | last_act: "none" # last activation function to normalize the predicted rgb
68 | normalize_topk_attn: true # normalize the top-k attention weights after softmax
69 | attn:
70 | k_type: 1
71 | q_type: 1
72 | v_type: 1
73 | d_model: 256
74 | score_act: "relu"
75 | embed:
76 | embed_type: 1 # 1: positional encoding with input itself, 2: positional encoding without input itself
77 | k_L: [6, 6, 6] # order of the positional encoding for each feature in k
78 | q_L: [6] # order of the positional encoding for each feature in q
79 | v_L: [6, 6] # order of the positional encoding for each feature in v
80 | pe_factor: 2.0
81 | pe_mult_factor: 1.0
82 | key:
83 | d_ff: 256 # dimension of the hidden layer in the embedding MLP
84 | d_ff_out: 256 # dimension of the output of the embedding MLP
85 | n_ff_layer: 5 # number of layers in the embedding MLP
86 | ff_act: "relu"
87 | ff_act_a: 1.0
88 | ff_act_b: 1.0
89 | ff_act_trainable: false
90 | ff_last_act: "none" # last activation function in the embedding MLP
91 | norm: "layernorm"
92 | dropout_ff: 0.0
93 | use_wn: false
94 | residual_ff: false
95 | skip_layers: []
96 | half_layers: []
97 | residual_layers: []
98 | residual_dims: []
99 | query:
100 | d_ff: 256
101 | d_ff_out: 256
102 | n_ff_layer: 5
103 | ff_act: "relu"
104 | ff_act_a: 1.0
105 | ff_act_b: 1.0
106 | ff_act_trainable: false
107 | ff_last_act: "none"
108 | norm: "layernorm"
109 | dropout_ff: 0.0
110 | use_wn: false
111 | residual_ff: false
112 | skip_layers: []
113 | half_layers: []
114 | residual_layers: []
115 | residual_dims: []
116 | value:
117 | d_ff: 256
118 | d_ff_out: 32
119 | n_ff_layer: 8
120 | ff_act: "relu"
121 | ff_act_a: 1.0
122 | ff_act_b: 1.0
123 | ff_act_trainable: false
124 | ff_last_act: "none"
125 | norm: "none"
126 | dropout_ff: 0.0
127 | use_wn: false
128 | residual_ff: false
129 | skip_layers: []
130 | half_layers: []
131 | residual_layers: []
132 | residual_dims: []
133 | renderer: # the UNet
134 | generator:
135 | type: "small-unet"
136 | small_unet:
137 | bilinear: false
138 | norm: "none"
139 | single: true
140 | last_act: "none"
141 | affine_layer: -1
142 | training:
143 | steps: 250000 # number of training steps
144 | prune_steps: 500 # interval to prune the point cloud
145 | prune_start: 10000 # start pruning after this step
146 | prune_stop: 150000 # stop pruning after this step
147 | prune_thresh: 0.0 # influence score threshold to prune the point cloud
148 | prune_thresh_list: [] # change the threshold at steps in prune_steps_list, if empty, then use prune_thresh
149 | prune_steps_list: []
150 | prune_type: "<" # prune points with influence scores smaller than the threshold ("<") or larger than the threshold (">")
151 | add_steps: 1000 # interval to add new points
152 | add_start: 20000 # start adding new points after this step
153 | add_stop: 70000 # stop adding new points after this step
154 | add_num: 1000 # number of points to add each time
155 | add_num_list: []
156 | add_steps_list: []
157 | exclude_keys: [] # not loading these parameters when loading the model
158 | fix_keys: [ # fix these parameters when training the model
159 | # "points",
160 | # "pc_norms",
161 | # # "norm_mlp",
162 | # "attn",
163 | # "pc_feats",
164 | # "attn_mlp",
165 | # # "renderer",
166 | # # "bkg_feats",
167 | # "bkg_points",
168 | # "points_influ_scores"
169 | ]
170 | losses: # loss weights
171 | mse: 1.0
172 | lpips: 1.0e-2
173 | lpips_alex: 0.0
174 | lr:
175 | lr_factor: 1.0 # learning rate factor, multiply all the learning rates by this factor
176 | mapping_mlp:
177 | type: "none"
178 | base_lr: 1.0e-6
179 | factor: 1
180 | warmup: 0
181 | weight_decay: 0
182 | attn:
183 | type: "cosine-hlfperiod"
184 | base_lr: 3.0e-4
185 | factor: 1
186 | warmup: 10000
187 | weight_decay: 0
188 | points:
189 | type: "cosine"
190 | base_lr: 2.0e-3
191 | factor: 1
192 | warmup: 0
193 | weight_decay: 0
194 | bkg_feats:
195 | type: "none"
196 | base_lr: 0.0
197 | factor: 1
198 | warmup: 10000
199 | weight_decay: 0
200 | points_influ_scores:
201 | type: "cosine-hlfperiod"
202 | base_lr: 1.0e-3
203 | factor: 1
204 | warmup: 10000
205 | weight_decay: 0
206 | feats:
207 | type: "cosine-hlfperiod"
208 | base_lr: 1.0e-3
209 | factor: 1
210 | warmup: 10000
211 | weight_decay: 0
212 | generator:
213 | type: "cosine-hlfperiod"
214 | base_lr: 1.0e-4
215 | factor: 1
216 | warmup: 10000
217 | weight_decay: 0
218 | eval:
219 | dataset:
220 | name: "testset"
221 | mode: "test"
222 | extract_patch: false
223 | type: "synthetic"
224 | white_bg: true
225 | path: "./data/nerf_synthetic/lego/"
226 | factor: 1
227 | num_workers: 0
228 | num_slices: -1
229 | step: 5000 # evaluate the model every this number of steps
230 | img_idx: 50 # index of the image to evaluate
231 | max_height: 100 # maximum tile height per pass when rendering a full image; large images are rendered in multiple passes due to memory constraints
232 | max_width: 100
233 | save_fig: true # save the log images during training
234 | test:
235 | load_path: "" # path to load the model from for testing
236 | save_fig: true # save the log images during testing
237 | save_video: false # save the video during testing
238 | max_height: 100
239 | max_width: 100
240 | datasets:
241 | - name: "testset"
242 | mode: "test"
243 | extract_patch: false
244 | type: "synthetic"
245 | white_bg: true
246 | path: "./data/nerf_synthetic/lego/"
247 | factor: 1
248 | num_workers: 0
249 | num_slices: -1
250 | plots: # videos that visualize different components of the model
251 | pcrgb: true
252 | featattn: false
--------------------------------------------------------------------------------
/models/unet.py:
--------------------------------------------------------------------------------
1 | # credit: https://github.com/princeton-vl/SNP
2 | """ Parts of the U-Net model """
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch import autocast
8 | from .utils import activation_func
9 |
10 |
11 | class SingleConv(nn.Module):
12 | """(convolution => [BN] => ReLU) * 2"""
13 |
14 | def __init__(self, in_channels, out_channels, mid_channels=None, norm='none'):
15 | super().__init__()
16 | if not mid_channels:
17 | mid_channels = out_channels
18 | if norm == 'instance':
19 | self.double_conv = nn.Sequential(
20 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
21 | nn.InstanceNorm2d(mid_channels),
22 | nn.ReLU(inplace=True),
23 | )
24 | elif norm == 'batch':
25 | self.double_conv = nn.Sequential(
26 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
27 | nn.BatchNorm2d(mid_channels),
28 | nn.ReLU(inplace=True),
29 | )
30 | elif norm == 'none':
31 | self.double_conv = nn.Sequential(
32 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
33 | nn.ReLU(inplace=True),
34 | )
35 | else:
36 | raise NotImplementedError
37 |
38 | def forward(self, x):
39 | return self.double_conv(x)
40 |
41 |
42 | class DoubleConv(nn.Module):
43 | """(convolution => [BN] => ReLU) * 2"""
44 |
45 | def __init__(self, in_channels, out_channels, mid_channels=None, norm='none'):
46 | super().__init__()
47 | if not mid_channels:
48 | mid_channels = out_channels
49 | if norm == 'instance':
50 | self.double_conv = nn.Sequential(
51 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
52 | nn.InstanceNorm2d(mid_channels),
53 | nn.ReLU(inplace=True),
54 | nn.Conv2d(mid_channels, out_channels,
55 | kernel_size=3, padding=1),
56 | nn.InstanceNorm2d(out_channels),
57 | nn.ReLU(inplace=True)
58 | )
59 | elif norm == 'batch':
60 | self.double_conv = nn.Sequential(
61 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
62 | nn.BatchNorm2d(mid_channels),
63 | nn.ReLU(inplace=True),
64 | nn.Conv2d(mid_channels, out_channels,
65 | kernel_size=3, padding=1),
66 | nn.BatchNorm2d(out_channels),
67 | nn.ReLU(inplace=True)
68 | )
69 | elif norm == 'none':
70 | self.double_conv = nn.Sequential(
71 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
72 | nn.ReLU(inplace=True),
73 | nn.Conv2d(mid_channels, out_channels,
74 | kernel_size=3, padding=1),
75 | nn.ReLU(inplace=True)
76 | )
77 |
78 | def forward(self, x):
79 | return self.double_conv(x)
80 |
81 |
82 | class Down(nn.Module):
83 | """Downscaling with maxpool then double conv"""
84 |
85 | def __init__(self, in_channels, out_channels, single=False, norm='none'):
86 | super().__init__()
87 | # print('down norm:', norm)
88 |
89 | if single:
90 | self.maxpool_conv = nn.Sequential(
91 | nn.MaxPool2d(2),
92 | SingleConv(in_channels, out_channels, norm=norm)
93 | )
94 | else:
95 | self.maxpool_conv = nn.Sequential(
96 | nn.MaxPool2d(2),
97 | DoubleConv(in_channels, out_channels, norm=norm)
98 | )
99 |
100 | def forward(self, x):
101 | return self.maxpool_conv(x)
102 |
103 |
104 | class Up(nn.Module):
105 | """Upscaling then double conv"""
106 |
107 | def __init__(self, in_channels, out_channels, bilinear=True, single=False, norm='none'):
108 | super().__init__()
109 |
110 | # if bilinear, use the normal convolutions to reduce the number of channels
111 | if bilinear:
112 | self.up = nn.Upsample(
113 | scale_factor=2, mode='bilinear', align_corners=True)
114 | if single:
115 | self.conv = SingleConv(
116 | in_channels, out_channels, in_channels // 2, norm=norm)
117 | else:
118 | self.conv = DoubleConv(
119 | in_channels, out_channels, in_channels // 2, norm=norm)
120 | else:
121 | self.up = nn.ConvTranspose2d(
122 | in_channels, in_channels // 2, kernel_size=2, stride=2)
123 | if single:
124 | self.conv = SingleConv(in_channels, out_channels, norm=norm)
125 | else:
126 | self.conv = DoubleConv(in_channels, out_channels, norm=norm)
127 |
128 | def forward(self, x1, x2):
129 | x1 = self.up(x1)
130 | # input is CHW
131 | diffY = x2.size()[2] - x1.size()[2]
132 | diffX = x2.size()[3] - x1.size()[3]
133 |
134 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
135 | diffY // 2, diffY - diffY // 2])
136 | # if you have padding issues, see
137 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
138 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
139 | x = torch.cat([x2, x1], dim=1)
140 | return self.conv(x)
141 |
142 |
143 | class UpSample(nn.Module):
144 | """Upscaling then double conv"""
145 |
146 | def __init__(self, in_channels, out_channels, bilinear=True, single=False, norm='none'):
147 | super().__init__()
148 | # print('up norm:', norm)
149 |
150 | # if bilinear, use the normal convolutions to reduce the number of channels
151 | if bilinear:
152 | self.up = nn.Upsample(
153 | scale_factor=2, mode='bilinear', align_corners=True)
154 | if single:
155 | self.conv = SingleConv(
156 | in_channels, out_channels, in_channels // 2, norm=norm)
157 | else:
158 | self.conv = DoubleConv(
159 | in_channels, out_channels, in_channels // 2, norm=norm)
160 | else:
161 | self.up = nn.ConvTranspose2d(
162 | in_channels, in_channels, kernel_size=2, stride=2)
163 | if single:
164 | self.conv = SingleConv(in_channels, out_channels, norm=norm)
165 | else:
166 | self.conv = DoubleConv(in_channels, out_channels, norm=norm)
167 |
168 | def forward(self, x1):
169 | x1 = self.up(x1)
170 | return self.conv(x1)
171 |
172 |
173 | class OutConv(nn.Module):
174 | def __init__(self, in_channels, out_channels):
175 | super(OutConv, self).__init__()
176 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
177 |
178 | def forward(self, x):
179 | return self.conv(x)
180 |
181 |
182 | class SmallUNet(nn.Module):
183 | def __init__(self, n_channels, n_classes, bilinear=False, single=True, norm='none', render_scale=1, last_act='none',
184 | use_amp=False, amp_dtype=torch.float16, affine_layer=-1):
185 | super(SmallUNet, self).__init__()
186 | self.n_channels = n_channels
187 | self.n_classes = n_classes
188 | self.bilinear = bilinear
189 | self.use_amp = use_amp
190 | self.amp_dtype = amp_dtype
191 | self.affine_layer = affine_layer
192 |
193 | assert (render_scale == 1 or render_scale == 2)
194 | self.render_scale = render_scale
195 |
196 | self.inc = SingleConv(n_channels, 128, norm=norm)
197 | self.down1 = Down(128, 256, single=single, norm=norm)
198 | self.down2 = Down(256, 512, single=single, norm=norm)
199 | self.up1 = Up(512, 256, bilinear, single=single, norm=norm)
200 | self.up2 = Up(256, 128, bilinear, single=single, norm=norm)
201 |
202 | if render_scale == 2:
203 | self.up3 = UpSample(128, 128, bilinear, single=False, norm=norm)
204 |
205 | self.outc = OutConv(128, n_classes)
206 | self.last_act = activation_func(last_act)
207 |
208 | def forward(self, x, log=False, gamma=None, beta=None):
209 | if self.affine_layer >= 0:
210 | assert gamma is not None and beta is not None
211 |
212 | with autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp):
213 | if self.affine_layer == 0:
214 | B, C, H, W = x.shape
215 | assert gamma.shape == (C,) and beta.shape == (C,)
216 | # print(gamma.mean(), beta.mean())
217 | x = x * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1)
218 |
219 | x1 = self.inc(x)
220 | if self.affine_layer == 1:
221 | B, C, H, W = x1.shape
222 | assert gamma.shape == (C,) and beta.shape == (C,)
223 | x1 = x1 * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1)
224 |
225 | x2 = self.down1(x1)
226 | if self.affine_layer == 2:
227 | B, C, H, W = x2.shape
228 | assert gamma.shape == (C,) and beta.shape == (C,)
229 | x2 = x2 * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1)
230 |
231 | x3 = self.down2(x2)
232 | if self.affine_layer == 3:
233 | B, C, H, W = x3.shape
234 | assert gamma.shape == (C,) and beta.shape == (C,)
235 | x3 = x3 * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1)
236 |
237 | x = self.up1(x3, x2)
238 | if self.affine_layer == 4:
239 | B, C, H, W = x.shape
240 | assert gamma.shape == (C,) and beta.shape == (C,)
241 | x = x * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1)
242 |
243 | x = self.up2(x, x1)
244 | if self.affine_layer == 5:
245 | B, C, H, W = x.shape
246 | assert gamma.shape == (C,) and beta.shape == (C,)
247 | x = x * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1)
248 |
249 | if self.render_scale == 2:
250 | x = self.up3(x)
251 | logits = self.outc(x)
252 |
253 | logits = self.last_act(logits)
254 |
255 | if log:
256 | print(x1.dtype, x2.dtype, x3.dtype, x.dtype, logits.dtype)
257 |
258 | return logits
259 |
--------------------------------------------------------------------------------
/models/attn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import autocast
4 | import math
5 | from .mlp import MLP
6 | from .utils import PoseEnc, activation_func
7 |
8 |
9 | def get_proximity_attention_layer(args, v_extra_dim=0, k_extra_dim=0, q_extra_dim=0, eps=1e-6, use_amp=False, amp_dtype=torch.float16):
10 | k_dim_map = {
11 | 1: [3, 3, 3],
12 | }
13 | k_dim = k_dim_map[args.k_type]
14 |
15 | q_dim_map = {
16 | 1: [3],
17 | }
18 | q_dim = q_dim_map[args.q_type]
19 |
20 | v_dim_map = {
21 | 1: [3, 3],
22 | }
23 | v_dim = v_dim_map[args.v_type]
24 |
25 | return ProximityAttention(d_k=k_dim, d_q=q_dim, d_v=v_dim, d_model=args.d_model, embed_args=args.embed,
26 | score_act=args.score_act, d_ko=k_extra_dim, d_qo=q_extra_dim,
27 | d_vo=v_extra_dim, eps=eps, use_amp=use_amp, amp_dtype=amp_dtype)
28 |
29 |
30 | class LayerNorm(nn.Module):
31 | "Construct a layernorm module"
32 |
33 | def __init__(self, features, eps=1e-6):
34 | super(LayerNorm, self).__init__()
35 | self.a_2 = nn.Parameter(torch.ones(features))
36 | self.b_2 = nn.Parameter(torch.zeros(features))
37 | self.eps = eps
38 |
39 | def forward(self, x):
40 | mean = x.mean(-1, keepdim=True)
41 | std = x.std(-1, keepdim=True)
42 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
43 |
44 |
45 | def attention(query, key, kernel_type):
46 | """
47 | Compute Attention Scores
48 | query: [batch_size, 1, query_len, d_kq] or [batch_size, query_len, d_kq]
49 | key: [batch_size, 1, seq_len, d_kq] or [batch_size, seq_len, d_kq]
50 | """
51 | d_kq = query.size(-1)
52 |
53 | if kernel_type == "scaled-dot":
54 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_kq)
55 | elif kernel_type == "-scaled-dot":
56 | scores = -torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_kq)
57 | elif kernel_type == "dot":
58 | scores = torch.matmul(query, key.transpose(-2, -1))
59 | elif kernel_type == "-dot":
60 | scores = -torch.matmul(query, key.transpose(-2, -1))
61 | elif kernel_type == "l1-dist":
62 | scores = torch.norm(query.unsqueeze(-2) -
63 | key.unsqueeze(-3), p=1, dim=-1)
64 | elif kernel_type == "-l1-dist":
65 | scores = -torch.norm(query.unsqueeze(-2) -
66 | key.unsqueeze(-3), p=1, dim=-1)
67 | elif kernel_type == "l2-dist":
68 | scores = torch.norm(query.unsqueeze(-2) -
69 | key.unsqueeze(-3), p=2, dim=-1)
70 | elif kernel_type == "-l2-dist":
71 | scores = -torch.norm(query.unsqueeze(-2) -
72 | key.unsqueeze(-3), p=2, dim=-1)
73 | elif kernel_type == "scaled-l2-dist":
74 | scores = torch.norm(query.unsqueeze(-2) -
75 | key.unsqueeze(-3), p=2, dim=-1) / math.sqrt(d_kq)
76 | elif kernel_type == "-scaled-l2-dist":
77 | scores = -torch.norm(query.unsqueeze(-2) -
78 | key.unsqueeze(-3), p=2, dim=-1) / math.sqrt(d_kq)
79 | elif kernel_type == "cosine":
80 | scores = torch.matmul(query, key.transpose(-2, -1)) / (
81 | torch.norm(query, dim=-1, keepdim=True)
82 | * torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1)
83 | )
84 | else:
85 | raise ValueError("Unknown kernel type: {}".format(kernel_type))
86 |
87 | return scores
88 |
89 |
90 | class FeedForward(nn.Module):
91 | "Implements FFN module."
92 |
93 | def __init__(self, d_input, d_output, d_ff, n_layer=2, act="relu", last_act="none", dropout=0.1, norm="layernorm",
94 | residual=False, act_a=1.0, act_b=1.0, act_trainable=False, use_wn=False, eps=1e-6, skip_layers=[],
95 | half_layers=[]):
96 | super(FeedForward, self).__init__()
97 | self.eps = eps
98 | self.d_input = d_input
99 | self.d_output = d_output
100 | if norm == "layernorm":
101 | self.innorm = LayerNorm(d_input, eps)
102 | self.outnorm = LayerNorm(d_output, eps)
103 | elif norm == "none":
104 | self.innorm = nn.Identity()
105 | self.outnorm = nn.Identity()
106 | else:
107 | raise ValueError("Invalid attention norm type")
108 | self.dropout = nn.Dropout(dropout)
109 | self.mlp = MLP(d_input, n_layer, d_ff, d_output, act_type=act, last_act_type=last_act, use_wn=use_wn,
110 | a=act_a, b=act_b, trainable=act_trainable, skip_layers=skip_layers, half_layers=half_layers)
111 | self.residual = residual
112 |
113 | def forward(self, x):
114 | if self.residual and x.shape[-1] == self.d_output:
115 | return self.outnorm(x + self.dropout(self.mlp(self.innorm(x))))
116 | else:
117 | return self.outnorm(self.dropout(self.mlp(self.innorm(x))))
118 |
119 |
120 | class Embeddings(nn.Module):
121 | def __init__(self, d_k, d_q, d_v, d_model, args, d_ko=0, d_qo=0, d_vo=0, eps=1e-6):
122 | super(Embeddings, self).__init__()
123 | self.d_k = d_k
124 | self.d_q = d_q
125 | self.d_v = d_v
126 | self.args = args
127 | self.embed_type = args.embed_type
128 | self.d_model = d_model
129 | self.d_ko = d_ko
130 | self.d_qo = d_qo
131 | self.d_vo = d_vo
132 | self.eps = eps
133 |
134 | self.posenc = PoseEnc(args.pe_factor, args.pe_mult_factor)
135 |
136 | if self.embed_type == 1:
137 | # Positional Encoding with itself
138 | d_k = sum([d + d * 2 * args.k_L[i] for i, d in enumerate(d_k)]) + d_ko
139 | d_q = sum([d + d * 2 * args.q_L[i] for i, d in enumerate(d_q)]) + d_qo
140 | d_v = sum([d + d * 2 * args.v_L[i] for i, d in enumerate(d_v)]) + d_vo
141 |
142 | elif self.embed_type == 2:
143 | # Positional Encoding without itself
144 | d_k = sum([d * 2 * args.k_L[i] for i, d in enumerate(d_k)]) + d_ko
145 | d_q = sum([d * 2 * args.q_L[i] for i, d in enumerate(d_q)]) + d_qo
146 | d_v = sum([d * 2 * args.v_L[i] for i, d in enumerate(d_v)]) + d_vo
147 |
148 | else:
149 | raise ValueError(
150 | 'Unknown embedding type: {}'.format(self.embed_type))
151 |
152 | self.embed_k = FeedForward(d_k, args.key.d_ff_out, args.key.d_ff, args.key.n_ff_layer, args.key.ff_act,
153 | args.key.ff_last_act, args.key.dropout_ff, args.key.norm, args.key.residual_ff,
154 | args.key.ff_act_a, args.key.ff_act_b, args.key.ff_act_trainable, args.key.use_wn, eps,
155 | args.key.skip_layers, args.key.half_layers)
156 | self.embed_q = FeedForward(d_q, args.query.d_ff_out, args.query.d_ff, args.query.n_ff_layer, args.query.ff_act,
157 | args.query.ff_last_act, args.query.dropout_ff, args.query.norm, args.query.residual_ff,
158 | args.query.ff_act_a, args.query.ff_act_b, args.query.ff_act_trainable, args.query.use_wn, eps,
159 | args.query.skip_layers, args.query.half_layers)
160 | self.embed_v = FeedForward(d_v, args.value.d_ff_out, args.value.d_ff, args.value.n_ff_layer, args.value.ff_act,
161 | args.value.ff_last_act, args.value.dropout_ff, args.value.norm, args.value.residual_ff,
162 | args.value.ff_act_a, args.value.ff_act_b, args.value.ff_act_trainable, args.value.use_wn, eps,
163 | args.value.skip_layers, args.value.half_layers)
164 |
165 | def forward(self, k_features, q_features, v_features, k_other=None, q_other=None, v_other=None):
166 | """
167 | k_features: [(B, H, W, N, Dk_i)]
168 | q_features: [(B, H, W, 1, Dq_i)]
169 | v_features: [(B, H, W, N, Dv_i)]
170 | """
171 |
172 | if self.embed_type == 1:
173 | pe_k_features = [self.posenc(f, self.args.k_L[i]) for i, f in enumerate(k_features)]
174 | pe_q_features = [self.posenc(f, self.args.q_L[i]) for i, f in enumerate(q_features)]
175 | pe_v_features = [self.posenc(f, self.args.v_L[i]) for i, f in enumerate(v_features)]
176 |
177 | elif self.embed_type == 2:
178 | pe_k_features = [self.posenc(f, self.args.k_L[i], without_self=True) for i, f in enumerate(k_features)]
179 | pe_q_features = [self.posenc(f, self.args.q_L[i], without_self=True) for i, f in enumerate(q_features)]
180 | pe_v_features = [self.posenc(f, self.args.v_L[i], without_self=True) for i, f in enumerate(v_features)]
181 |
182 | else:
183 | raise ValueError('Unknown embedding type: {}'.format(self.embed_type))
184 |
185 | if self.d_ko > 0: pe_k_features = pe_k_features + k_other
186 | if self.d_qo > 0: pe_q_features = pe_q_features + q_other
187 | if self.d_vo > 0: pe_v_features = pe_v_features + v_other
188 |
189 | k = torch.cat(pe_k_features, dim=-1).flatten(0, 2)
190 | q = torch.cat(pe_q_features, dim=-1).flatten(0, 2)
191 | v = torch.cat(pe_v_features, dim=-1).flatten(0, 2)
192 |
193 | k = self.embed_k(k)
194 | q = self.embed_q(q)
195 | v = self.embed_v(v)
196 |
197 | return k, q, v
198 |
199 |
200 | class AttentionLayer(nn.Module):
201 | def __init__(self, embed_args, d_model, score_act_type):
202 | super(AttentionLayer, self).__init__()
203 | self.d_model = d_model
204 | self.w_k = nn.Linear(embed_args.key.d_ff_out, d_model)
205 | self.w_q = nn.Linear(embed_args.query.d_ff_out, d_model)
206 |
207 | nn.init.xavier_uniform_(self.w_k.weight)
208 | nn.init.xavier_uniform_(self.w_q.weight)
209 |
210 | self.score_act = activation_func(score_act_type)
211 |
212 | def forward(self, key, query, value):
213 | nbatches, nseqv, _ = value.shape
214 | _, nseqk, _ = key.shape
215 | assert nseqv == nseqk
216 |
217 | key = self.w_k(key)
218 | query = self.w_q(query)
219 |
220 | key = key.view(nbatches, -1, 1, self.d_model).transpose(1, 2)
221 | query = query.view(nbatches, -1, 1, self.d_model).transpose(1, 2)
222 |
223 | # [nbatches, nhead, nseq, nseq]
224 | scores = attention(query, key, "scaled-dot")
225 | scores = self.score_act(scores)
226 | return scores
227 |
228 |
229 | class ProximityAttention(nn.Module):
230 | def __init__(self, d_k, d_q, d_v, d_model, embed_args, score_act="none",
231 | d_ko=0, d_qo=0, d_vo=0, eps=1e-6, use_amp=False, amp_dtype=torch.float16):
232 | super(ProximityAttention, self).__init__()
233 | self.eps = eps
234 | self.use_amp = use_amp
235 | self.amp_dtype = amp_dtype
236 |
237 | self.embed = Embeddings(d_k, d_q, d_v, d_model, embed_args, d_ko, d_qo, d_vo, eps)
238 |
239 | self.attention_layer = AttentionLayer(embed_args, d_model, score_act)
240 |
241 | def forward(self, k_features, q_features, v_features, k_other=None, q_other=None, v_other=None, step=-1):
242 | """
243 | k_features: [(H, W, N, Dk_i)]
244 | q_features: [(H, W, 1, Dq_i)] or [(H, W, N, Dq_i)]
245 | v_features: [(H, W, N, Dv_i)]
246 | """
247 |
248 | with autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp):
249 | k, q, v = self.embed(k_features, q_features, v_features, k_other, q_other, v_other)
250 | scores = self.attention_layer(k, q, v)
251 |
252 | return k, q, v, scores
253 |
--------------------------------------------------------------------------------
/exposure_control_finetune.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | import torch
4 | import torch.nn as nn
5 | import os
6 | import shutil
7 | import zipfile
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import random
11 | import copy
12 | import bisect
13 | import time
14 | import sys
15 | import io
16 | import imageio
17 | from tqdm import tqdm
18 | from PIL import Image
19 | from utils import *
20 | from dataset import get_dataset, get_loader
21 | from models import get_model, get_loss
22 |
23 |
24 | def parse_args():
25 | parser = argparse.ArgumentParser(description="PAPR")
26 | parser.add_argument('--opt', type=str, default="", help='Option file path')
27 | parser.add_argument('--resume', type=int, default=0, help='Resume training')
28 | return parser.parse_args()
29 |
30 |
31 | def eval_step(steps, model, device, dataset, eval_dataset, batch, loss_fn, train_out, args, train_losses, eval_losses, eval_psnrs, pt_lrs, attn_lrs):
32 | step = steps[-1]
33 | train_img_idx, _, train_patch, _, _ = batch
34 | train_img, train_rayd, train_rayo = dataset.get_full_img(train_img_idx[0])
35 | img, rayd, rayo = eval_dataset.get_full_img(args.eval.img_idx)
36 | c2w = eval_dataset.get_c2w(args.eval.img_idx)
37 |
38 | print("Before resample shading codes, eval shading codes mean: ", model.eval_shading_codes[args.eval.img_idx].mean())
39 | resample_shading_codes(model.eval_shading_codes, args, model, eval_dataset, args.eval.img_idx, loss_fn, step, full_img=True)
40 | print("After resample shading codes, eval shading codes mean: ", model.eval_shading_codes[args.eval.img_idx].mean())
41 |
42 | N, H, W, _ = rayd.shape
43 | num_pts, _ = model.points.shape
44 |
45 | rayo = rayo.to(device)
46 | rayd = rayd.to(device)
47 | img = img.to(device)
48 | c2w = c2w.to(device)
49 |
50 | topk = min([num_pts, model.select_k])
51 |
52 | selected_points = torch.zeros(1, H, W, topk, 3)
53 |
54 | bkg_seq_len_attn = 0
55 | feat_dim = args.models.attn.embed.value.d_ff_out
56 | if model.bkg_feats is not None:
57 | bkg_seq_len_attn = model.bkg_feats.shape[0]
58 | feature_map = torch.zeros(N, H, W, 1, feat_dim).to(device)
59 | attn = torch.zeros(N, H, W, topk + bkg_seq_len_attn, 1).to(device)
60 |
61 | with torch.no_grad():
62 | cur_shading_code = model.eval_shading_codes[args.eval.img_idx]
63 | cur_affine = model.mapping_mlp(cur_shading_code)
64 | cur_affine_dim = cur_affine.shape[-1]
65 | cur_gamma, cur_beta = cur_affine[:cur_affine_dim // 2], cur_affine[cur_affine_dim // 2:]
66 |
67 | for height_start in range(0, H, args.eval.max_height):
68 | for width_start in range(0, W, args.eval.max_width):
69 | height_end = min(height_start + args.eval.max_height, H)
70 | width_end = min(width_start + args.eval.max_width, W)
71 |
72 | feature_map[:, height_start:height_end, width_start:width_end, :, :], \
73 | attn[:, height_start:height_end, width_start:width_end, :, :] = model.evaluate(rayo, rayd[:, height_start:height_end, width_start:width_end], c2w, step=step)
74 |
75 | selected_points[:, height_start:height_end, width_start:width_end, :, :] = model.selected_points
76 |
77 | foreground_rgb = model.renderer(feature_map.squeeze(-2).permute(0, 3, 1, 2), gamma=cur_gamma, beta=cur_beta).permute(0, 2, 3, 1).unsqueeze(-2) # (N, H, W, 1, 3)
78 |
79 | if model.bkg_feats is not None:
80 | bkg_attn = attn[..., topk:, :]
81 | if args.models.normalize_topk_attn:
82 | rgb = foreground_rgb * (1 - bkg_attn) + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn
83 | else:
84 | rgb = foreground_rgb + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn
85 | rgb = rgb.squeeze(-2)
86 | else:
87 | rgb = foreground_rgb.squeeze(-2)
88 |
89 | rgb = model.last_act(rgb)
90 | rgb = torch.clamp(rgb, 0, 1)
91 |
92 | eval_loss = loss_fn(rgb, img)
93 | eval_psnr = -10. * np.log(((rgb - img)**2).mean().item()) / np.log(10.)
94 |
95 | model.clear_grad()
96 |
97 | eval_losses.append(eval_loss.item())
98 | eval_psnrs.append(eval_psnr.item())
99 |
100 | print("Eval step:", step, "train_loss:", train_losses[-1], "eval_loss:", eval_losses[-1], "eval_psnr:", eval_psnrs[-1])
101 |
102 | log_dir = os.path.join(args.save_dir, args.index)
103 | os.makedirs(log_dir, exist_ok=True)
104 | if args.eval.save_fig:
105 | os.makedirs(os.path.join(log_dir, "train_main_plots"), exist_ok=True)
106 | os.makedirs(os.path.join(log_dir, "train_pcd_plots"), exist_ok=True)
107 |
108 | coord_scale = args.dataset.coord_scale
109 | pt_plot_scale = 1.0 * coord_scale
110 | if "Barn" in args.dataset.path:
111 | pt_plot_scale *= 1.8
112 | if "Family" in args.dataset.path:
113 | pt_plot_scale *= 0.5
114 |
115 | # calculate depth, weighted sum the distances from top K points to image plane
116 | od = -rayo
117 | D = torch.sum(od * rayo)
118 | dists = torch.abs(torch.sum(selected_points.to(od.device) * od, -1) - D) / torch.norm(od)
119 | if model.bkg_feats is not None:
120 | dists = torch.cat([dists, torch.ones(N, H, W, model.bkg_feats.shape[0]).to(dists.device) * 0], dim=-1)
121 | cur_depth = (torch.sum(attn.squeeze(-1).to(od.device) * dists, dim=-1)).detach().cpu()
122 |
123 | train_tgt_rgb = train_img.squeeze().cpu().numpy().astype(np.float32)
124 | train_tgt_patch = train_patch[0].cpu().numpy().astype(np.float32)
125 | train_pred_patch = train_out[0]
126 | test_tgt_rgb = img.squeeze().cpu().numpy().astype(np.float32)
127 | test_pred_rgb = rgb.squeeze().detach().cpu().numpy().astype(np.float32)
128 | points_np = model.points.detach().cpu().numpy()
129 | depth = cur_depth.squeeze().numpy().astype(np.float32)
130 | points_influ_scores_np = None
131 | if model.points_influ_scores is not None:
132 | points_influ_scores_np = model.points_influ_scores.squeeze().detach().cpu().numpy()
133 |
134 | # main plot
135 | main_plot = get_training_main_plot(args.index, steps, train_tgt_rgb, train_tgt_patch, train_pred_patch, test_tgt_rgb, test_pred_rgb, train_losses,
136 | eval_losses, points_np, pt_plot_scale, depth, pt_lrs, attn_lrs, eval_psnrs, points_influ_scores_np)
137 | save_name = os.path.join(log_dir, "train_main_plots", "%s_iter_%d.png" % (args.index, step))
138 | main_plot.save(save_name)
139 |
140 | # point cloud plot
141 | ro = train_rayo.squeeze().detach().cpu().numpy()
142 | rd = train_rayd.squeeze().detach().cpu().numpy()
143 |
144 | pcd_plot = get_training_pcd_plot(args.index, steps[-1], ro, rd, points_np, args.dataset.coord_scale, pt_plot_scale, points_influ_scores_np)
145 | save_name = os.path.join(log_dir, "train_pcd_plots", "%s_iter_%d.png" % (args.index, step))
146 | pcd_plot.save(save_name)
147 |
148 | model.save(step, log_dir)
149 | if step % 50000 == 0:
150 | torch.save(model.state_dict(), os.path.join(log_dir, "model_%d.pth" % step))
151 |
152 | torch.save(torch.tensor(train_losses), os.path.join(log_dir, "train_losses.pth"))
153 | torch.save(torch.tensor(eval_losses), os.path.join(log_dir, "eval_losses.pth"))
154 | torch.save(torch.tensor(eval_psnrs), os.path.join(log_dir, "eval_psnrs.pth"))
155 |
156 | return 0
157 |
158 |
159 | def train_step(step, model, device, dataset, batch, loss_fn, args):
160 | img_idx, _, tgt, rayd, rayo = batch
161 | c2w = dataset.get_c2w(img_idx[0])
162 |
163 | rayo = rayo.to(device)
164 | rayd = rayd.to(device)
165 | tgt = tgt.to(device)
166 | c2w = c2w.to(device)
167 |
168 | shading_code = model.train_shading_codes[img_idx[0]]
169 |
170 | model.clear_grad()
171 | out = model(rayo, rayd, c2w, step, shading_code=shading_code)
172 | out = model.last_act(out)
173 | loss = loss_fn(out, tgt)
174 | model.scaler.scale(loss).backward()
175 | model.step(step)
176 | if args.scaler_min_scale > 0 and model.scaler.get_scale() < args.scaler_min_scale:
177 | model.scaler.update(args.scaler_min_scale)
178 | else:
179 | model.scaler.update()
180 |
181 | return loss.item(), out.detach().cpu().numpy()
182 |
183 |
184 | def train_and_eval(start_step, model, device, dataset, eval_dataset, sample_dataset, losses, args):
185 | trainloader = get_loader(dataset, args.dataset, mode="train")
186 |
187 | loss_fn = get_loss(args.training.losses)
188 | loss_fn = loss_fn.to(device)
189 |
190 | log_dir = os.path.join(args.save_dir, args.index)
191 | os.makedirs(os.path.join(log_dir, "test"), exist_ok=True)
192 | log_dir = os.path.join(log_dir, "test")
193 |
194 | steps = []
195 | train_losses, eval_losses, eval_psnrs = losses
196 | pt_lrs = []
197 | attn_lrs = []
198 |
199 | avg_train_loss = 0.
200 | step = start_step
201 | eval_step_cnt = start_step
202 | pc_frames = []
203 |
204 | model.train_shading_codes = nn.Parameter(torch.randn(len(dataset), args.exposure_control.shading_code_dim, device=device) * args.exposure_control.shading_code_scale, requires_grad=False)
205 | model.eval_shading_codes = nn.Parameter(torch.randn(len(eval_dataset), args.exposure_control.shading_code_dim, device=device) * args.exposure_control.shading_code_scale, requires_grad=False)
206 | print("!!!!! train_shading_codes:", model.train_shading_codes.shape, model.train_shading_codes.min(), model.train_shading_codes.max())
207 | print("!!!!! eval_shading_codes:", model.eval_shading_codes.shape, model.eval_shading_codes.min(), model.eval_shading_codes.max())
208 |
209 | print("Start step:", start_step, "Total steps:", args.training.steps)
210 | start_time = time.time()
211 | while step < args.training.steps:
212 | for _, batch in enumerate(trainloader):
213 | if step % args.exposure_control.shading_code_resample_iter == 0: # Resample shading codes
214 | print("Resampling shading codes")
215 | print("Before resampling:", model.train_shading_codes.shape, model.train_shading_codes.min(), model.train_shading_codes.max())
216 | for img_idx in tqdm(range(len(sample_dataset))):
217 | resample_shading_codes(model.train_shading_codes, args, model, sample_dataset, img_idx, loss_fn, step)
218 | print("After resampling:", model.train_shading_codes.shape, model.train_shading_codes.min(), model.train_shading_codes.max())
219 |
220 | loss, out = train_step(step, model, device, dataset, batch, loss_fn, args)
221 | avg_train_loss += loss
222 | step += 1
223 | eval_step_cnt += 1
224 |
225 | if step % 200 == 0:
226 | time_used = time.time() - start_time
227 | print("Train step:", step, "loss:", loss, "attn_lr:", model.attn_lr, "pts_lr:", model.pts_lr, "scale:", model.scaler.get_scale(), f"time: {time_used:.2f}s")
228 | print(model.mapping_mlp.model.model[7].weight[0, :5])
229 | start_time = time.time()
230 |
231 | if (step % args.eval.step == 0) or (step % 500 == 0 and step < 10000):
232 | train_losses.append(avg_train_loss / eval_step_cnt)
233 | pt_lrs.append(model.pts_lr)
234 | attn_lrs.append(model.attn_lr)
235 | steps.append(step)
236 | eval_step(steps, model, device, dataset, eval_dataset, batch, loss_fn, out, args, train_losses, eval_losses, eval_psnrs, pt_lrs, attn_lrs)
237 | avg_train_loss = 0.
238 | eval_step_cnt = 0
239 |
240 | if ((step - 1) % 200 == 0) and args.eval.save_fig:
241 | coord_scale = args.dataset.coord_scale
242 | pt_plot_scale = 0.8 * coord_scale
243 | if "Barn" in args.dataset.path:
244 | pt_plot_scale *= 1.5
245 | if "Family" in args.dataset.path:
246 | pt_plot_scale *= 0.5
247 |
248 | pc_dir = os.path.join(log_dir, "point_clouds")
249 | os.makedirs(pc_dir, exist_ok=True)
250 |
251 | points_np = model.points.detach().cpu().numpy()
252 | points_influ_scores_np = None
253 | if model.points_influ_scores is not None:
254 | points_influ_scores_np = model.points_influ_scores.squeeze().detach().cpu().numpy()
255 | pcd_plot = get_training_pcd_single_plot(step, points_np, pt_plot_scale, points_influ_scores_np)
256 | pc_frames.append(pcd_plot)
257 |
258 | if step == 1:
259 | pcd_plot.save(os.path.join(pc_dir, "init_pcd.png"))
260 |
261 | if step >= args.training.steps:
262 | break
263 |
264 | if args.eval.save_fig and pc_frames != []:
265 | f = os.path.join(log_dir, f"{args.index}-pc.mp4")
266 | imageio.mimwrite(f, pc_frames, fps=30, quality=10)
267 |
268 | print("Training finished!")
269 |
270 |
271 | def main(args, eval_args, sample_args, resume):
272 | log_dir = os.path.join(args.save_dir, args.index)
273 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
274 |
275 | model = get_model(args, device)
276 | dataset = get_dataset(args.dataset, mode="train")
277 | sample_dataset = get_dataset(sample_args.dataset, mode="train")
278 | eval_dataset = get_dataset(eval_args.dataset, mode="test")
279 | model = model.to(device)
280 |
281 | # if torch.__version__ >= "2.0":
282 | # model = torch.compile(model)
283 |
284 | start_step = 0
285 | losses = [[], [], []]
286 | if resume > 0:
287 | start_step = model.load(log_dir)
288 |
289 | train_losses = torch.load(os.path.join(log_dir, "train_losses.pth")).tolist()
290 | eval_losses = torch.load(os.path.join(log_dir, "eval_losses.pth")).tolist()
291 | eval_psnrs = torch.load(os.path.join(log_dir, "eval_psnrs.pth")).tolist()
292 | losses = [train_losses, eval_losses, eval_psnrs]
293 |
294 | print("!!!!! Resume from step %s" % start_step)
295 | elif args.load_path:
296 | try:
297 | model_state_dict = torch.load(args.load_path)
298 | for step, state_dict in model_state_dict.items():
299 | resume_step = int(step)
300 | model.load_my_state_dict(state_dict)
301 | except:
302 | model_state_dict = torch.load(os.path.join(args.save_dir, args.load_path, "model.pth"))
303 | for step, state_dict in model_state_dict.items():
304 | resume_step = step
305 | model.load_my_state_dict(state_dict)
306 | print("!!!!! Loaded model from %s at step %s" % (args.load_path, resume_step))
307 |
308 | train_and_eval(start_step, model, device, dataset, eval_dataset, sample_dataset, losses, args)
309 | print(torch.cuda.memory_summary())
310 |
311 |
312 | if __name__ == '__main__':
313 |
314 | with open("configs/default.yml", 'r') as f:
315 | default_config = yaml.safe_load(f)
316 |
317 | args = parse_args()
318 | with open(args.opt, 'r') as f:
319 | config = yaml.safe_load(f)
320 |
321 | train_config = copy.deepcopy(default_config)
322 | update_dict(train_config, config)
323 |
324 | sample_config = copy.deepcopy(train_config)
325 | sample_config['dataset']['patches']['height'] = train_config['exposure_control']['shading_code_resample_size']
326 | sample_config['dataset']['patches']['width'] = train_config['exposure_control']['shading_code_resample_size']
327 | sample_config = DictAsMember(sample_config)
328 |
329 | eval_config = copy.deepcopy(train_config)
330 | eval_config['dataset'].update(eval_config['eval']['dataset'])
331 | eval_config = DictAsMember(eval_config)
332 | train_config = DictAsMember(train_config)
333 |
334 | assert train_config.models.use_renderer, "Currently only support using renderer for exposure control"
335 |
336 | log_dir = os.path.join(train_config.save_dir, train_config.index)
337 | os.makedirs(log_dir, exist_ok=True)
338 |
339 | sys.stdout = Logger(os.path.join(log_dir, 'train.log'), sys.stdout)
340 | sys.stderr = Logger(os.path.join(log_dir, 'train_error.log'), sys.stderr)
341 |
342 | shutil.copyfile(__file__, os.path.join(log_dir, os.path.basename(__file__)))
343 | shutil.copyfile(args.opt, os.path.join(log_dir, os.path.basename(args.opt)))
344 |
345 | find_all_python_files_and_zip(".", os.path.join(log_dir, "code.zip"))
346 |
347 | setup_seed(train_config.seed)
348 |
349 | main(train_config, eval_config, sample_config, args.resume)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | import torch
4 | import os
5 | import shutil
6 | import zipfile
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import random
10 | import copy
11 | import bisect
12 | import time
13 | import sys
14 | import io
15 | import imageio
16 | from PIL import Image
17 | from utils import *
18 | from dataset import get_dataset, get_loader
19 | from models import get_model, get_loss
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser(description="PAPR")
24 | parser.add_argument('--opt', type=str, default="", help='Option file path')
25 | parser.add_argument('--resume', type=int, default=0, help='Resume training')
26 | return parser.parse_args()
27 |
28 |
29 | def eval_step(steps, model, device, dataset, eval_dataset, batch, loss_fn, train_out, args, train_losses, eval_losses, eval_psnrs, pt_lrs, attn_lrs):
30 | step = steps[-1]
31 | train_img_idx, _, train_patch, _, _ = batch
32 | # For visualization, use the first image in the batch
33 | first_idx = train_img_idx[0].item() if isinstance(train_img_idx, torch.Tensor) else train_img_idx[0]
34 | train_img, train_rayd, train_rayo = dataset.get_full_img(first_idx)
35 | img, rayd, rayo = eval_dataset.get_full_img(args.eval.img_idx)
36 | c2w = eval_dataset.get_c2w(args.eval.img_idx)
37 |
38 | N, H, W, _ = rayd.shape
39 | num_pts, _ = model.points.shape
40 |
41 | rayo = rayo.to(device)
42 | rayd = rayd.to(device)
43 | img = img.to(device)
44 | c2w = c2w.to(device)
45 |
46 | topk = min([num_pts, model.select_k])
47 |
48 | selected_points = torch.zeros(1, H, W, topk, 3)
49 |
50 | bkg_seq_len_attn = 0
51 | attn_opt = args.models.attn
52 | feat_dim = attn_opt.embed.value.d_ff_out
53 | if model.bkg_feats is not None:
54 | bkg_seq_len_attn = model.bkg_feats.shape[0]
55 | feature_map = torch.zeros(N, H, W, 1, feat_dim).to(device)
56 | attn = torch.zeros(N, H, W, topk + bkg_seq_len_attn, 1).to(device)
57 |
58 | with torch.no_grad():
59 | for height_start in range(0, H, args.eval.max_height):
60 | for width_start in range(0, W, args.eval.max_width):
61 | height_end = min(height_start + args.eval.max_height, H)
62 | width_end = min(width_start + args.eval.max_width, W)
63 |
64 | feature_map[:, height_start:height_end, width_start:width_end, :, :], \
65 | attn[:, height_start:height_end, width_start:width_end, :, :] = model.evaluate(rayo, rayd[:, height_start:height_end, width_start:width_end], c2w, step=step)
66 |
67 | selected_points[:, height_start:height_end, width_start:width_end, :, :] = model.selected_points
68 |
69 | if args.models.use_renderer:
70 | foreground_rgb = model.renderer(feature_map.squeeze(-2).permute(0, 3, 1, 2)).permute(0, 2, 3, 1).unsqueeze(-2) # (N, H, W, 1, 3)
71 | else:
72 | foreground_rgb = feature_map
73 |
74 | if model.bkg_feats is not None:
75 | bkg_attn = attn[..., topk:, :]
76 | if args.models.normalize_topk_attn:
77 | rgb = foreground_rgb * (1 - bkg_attn) + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn
78 | else:
79 | rgb = foreground_rgb + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn
80 | rgb = rgb.squeeze(-2)
81 | else:
82 | rgb = foreground_rgb.squeeze(-2)
83 |
84 | rgb = model.last_act(rgb)
85 | rgb = torch.clamp(rgb, 0, 1)
86 |
87 | eval_loss = loss_fn(rgb, img)
88 | eval_psnr = -10. * np.log(((rgb - img)**2).mean().item()) / np.log(10.)
89 |
90 | model.clear_grad()
91 |
92 | eval_losses.append(eval_loss.item())
93 | eval_psnrs.append(eval_psnr.item())
94 |
95 | print("Eval step:", step, "train_loss:", train_losses[-1], "eval_loss:", eval_losses[-1], "eval_psnr:", eval_psnrs[-1])
96 |
97 | log_dir = os.path.join(args.save_dir, args.index)
98 | os.makedirs(log_dir, exist_ok=True)
99 | if args.eval.save_fig:
100 | os.makedirs(os.path.join(log_dir, "train_main_plots"), exist_ok=True)
101 | os.makedirs(os.path.join(log_dir, "train_pcd_plots"), exist_ok=True)
102 |
103 | coord_scale = args.dataset.coord_scale
104 | pt_plot_scale = 1.0 * coord_scale
105 | if "Barn" in args.dataset.path:
106 | pt_plot_scale *= 1.8
107 | if "Family" in args.dataset.path:
108 | pt_plot_scale *= 0.5
109 |
110 | # calculate depth, weighted sum the distances from top K points to image plane
111 | od = -rayo
112 | D = torch.sum(od * rayo)
113 | dists = torch.abs(torch.sum(selected_points.to(od.device) * od, -1) - D) / torch.norm(od)
114 | if model.bkg_feats is not None:
115 | dists = torch.cat([dists, torch.ones(N, H, W, model.bkg_feats.shape[0]).to(dists.device) * 0], dim=-1)
116 | cur_depth = (torch.sum(attn.squeeze(-1).to(od.device) * dists, dim=-1)).detach().cpu()
117 |
118 | train_tgt_rgb = train_img.squeeze().cpu().numpy().astype(np.float32)
119 | # For visualization, use the first patch/output in the batch
120 | train_tgt_patch = train_patch[0].cpu().numpy().astype(np.float32)
121 | train_pred_patch = train_out[0]
122 | test_tgt_rgb = img.squeeze().cpu().numpy().astype(np.float32)
123 | test_pred_rgb = rgb.squeeze().detach().cpu().numpy().astype(np.float32)
124 | points_np = model.points.detach().cpu().numpy()
125 | depth = cur_depth.squeeze().numpy().astype(np.float32)
126 | points_influ_scores_np = None
127 | if model.points_influ_scores is not None:
128 | points_influ_scores_np = model.points_influ_scores.squeeze().detach().cpu().numpy()
129 |
130 | # main plot
131 | main_plot = get_training_main_plot(args.index, steps, train_tgt_rgb, train_tgt_patch, train_pred_patch, test_tgt_rgb, test_pred_rgb, train_losses,
132 | eval_losses, points_np, pt_plot_scale, depth, pt_lrs, attn_lrs, eval_psnrs, points_influ_scores_np)
133 | save_name = os.path.join(log_dir, "train_main_plots", "%s_iter_%d.png" % (args.index, step))
134 | main_plot.save(save_name)
135 |
136 | # point cloud plot
137 | ro = train_rayo.squeeze().detach().cpu().numpy()
138 | rd = train_rayd.squeeze().detach().cpu().numpy()
139 |
140 | pcd_plot = get_training_pcd_plot(args.index, steps[-1], ro, rd, points_np, args.dataset.coord_scale, pt_plot_scale, points_influ_scores_np)
141 | save_name = os.path.join(log_dir, "train_pcd_plots", "%s_iter_%d.png" % (args.index, step))
142 | pcd_plot.save(save_name)
143 |
144 | model.save(step, log_dir)
145 | if step % 50000 == 0:
146 | torch.save(model.state_dict(), os.path.join(log_dir, "model_%d.pth" % step))
147 |
148 | torch.save(torch.tensor(train_losses), os.path.join(log_dir, "train_losses.pth"))
149 | torch.save(torch.tensor(eval_losses), os.path.join(log_dir, "eval_losses.pth"))
150 | torch.save(torch.tensor(eval_psnrs), os.path.join(log_dir, "eval_psnrs.pth"))
151 |
152 | return 0
153 |
154 |
155 | def train_step(step, model, device, dataset, batch, loss_fn, args):
156 | img_idx, _, tgt, rayd, rayo = batch
157 | # Get c2w for all images in the batch
158 | if isinstance(img_idx, torch.Tensor):
159 | c2w = torch.stack([dataset.get_c2w(idx.item()) for idx in img_idx])
160 | else:
161 | c2w = torch.stack([dataset.get_c2w(idx) for idx in img_idx])
162 |
163 | rayo = rayo.to(device)
164 | rayd = rayd.to(device)
165 | tgt = tgt.to(device)
166 | c2w = c2w.to(device)
167 |
168 | model.clear_grad()
169 | out = model(rayo, rayd, c2w, step)
170 | out = model.last_act(out)
171 | loss = loss_fn(out, tgt)
172 | model.scaler.scale(loss).backward()
173 | model.step(step)
174 | if args.scaler_min_scale > 0 and model.scaler.get_scale() < args.scaler_min_scale:
175 | model.scaler.update(args.scaler_min_scale)
176 | else:
177 | model.scaler.update()
178 |
179 | return loss.item(), out.detach().cpu().numpy()
180 |
181 |
182 | def train_and_eval(start_step, model, device, dataset, eval_dataset, losses, args):
183 | trainloader = get_loader(dataset, args.dataset, mode="train")
184 |
185 | loss_fn = get_loss(args.training.losses)
186 | loss_fn = loss_fn.to(device)
187 |
188 | log_dir = os.path.join(args.save_dir, args.index)
189 | os.makedirs(os.path.join(log_dir, "test"), exist_ok=True)
190 | log_dir = os.path.join(log_dir, "test")
191 |
192 | steps = []
193 | train_losses, eval_losses, eval_psnrs = losses
194 | pt_lrs = []
195 | attn_lrs = []
196 |
197 | avg_train_loss = 0.
198 | step = start_step
199 | eval_step_cnt = start_step
200 | pruned = False
201 | pc_frames = []
202 |
203 | print("Start step:", start_step, "Total steps:", args.training.steps)
204 | start_time = time.time()
205 | while step < args.training.steps:
206 | for _, batch in enumerate(trainloader):
207 | if (args.training.prune_steps > 0) and (step < args.training.prune_stop) and (step >= args.training.prune_start):
208 | if len(args.training.prune_steps_list) > 0 and step % args.training.prune_steps == 0:
209 | cur_prune_thresh = args.training.prune_thresh_list[bisect.bisect_left(args.training.prune_steps_list, step)]
210 | model.clear_optimizer()
211 | model.clear_scheduler()
212 | num_pruned = model.prune_points(cur_prune_thresh)
213 | model.init_optimizers(step)
214 | pruned = True
215 | print("Step %d: Pruned %d points, prune threshold %f" % (step, num_pruned, cur_prune_thresh))
216 |
217 | elif step % args.training.prune_steps == 0:
218 | model.clear_optimizer()
219 | model.clear_scheduler()
220 | num_pruned = model.prune_points(args.training.prune_thresh)
221 | model.init_optimizers(step)
222 | pruned = True
223 | print("Step %d: Pruned %d points" % (step, num_pruned))
224 |
225 | if pruned and len(args.training.add_steps_list) > 0:
226 | if step in args.training.add_steps_list:
227 | cur_add_num = args.training.add_num_list[args.training.add_steps_list.index(step)]
228 | if 'max_num_pts' in args and args.max_num_pts > 0:
229 | cur_add_num = min(cur_add_num, args.max_num_pts - model.points.shape[0])
230 |
231 | if cur_add_num > 0:
232 | model.clear_optimizer()
233 | model.clear_scheduler()
234 | num_added = model.add_points(cur_add_num)
235 | model.init_optimizers(step)
236 | model.added_points = True
237 | print("Step %d: Added %d points" % (step, num_added))
238 |
239 | elif pruned and (args.training.add_steps > 0) and (step % args.training.add_steps == 0) and (step < args.training.add_stop) and (step >= args.training.add_start):
240 | cur_add_num = args.training.add_num
241 | if 'max_num_pts' in args and args.max_num_pts > 0:
242 | cur_add_num = min(cur_add_num, args.max_num_pts - model.points.shape[0])
243 |
244 | if cur_add_num > 0:
245 | model.clear_optimizer()
246 | model.clear_scheduler()
247 | num_added = model.add_points(cur_add_num)  # use the clamped count, matching the add_steps_list branch above
248 | model.init_optimizers(step)
249 | model.added_points = True
250 | print("Step %d: Added %d points" % (step, num_added))
251 |
252 | loss, out = train_step(step, model, device, dataset, batch, loss_fn, args)
253 | avg_train_loss += loss
254 | step += 1
255 | eval_step_cnt += 1
256 |
257 | if step % 200 == 0:
258 | time_used = time.time() - start_time
259 | print("Train step:", step, "loss:", loss, "attn_lr:", model.attn_lr, "pts_lr:", model.pts_lr, "scale:", model.scaler.get_scale(), f"time: {time_used:.2f}s")
260 | start_time = time.time()
261 |
262 | if (step % args.eval.step == 0) or (step % 500 == 0 and step < 10000):
263 | train_losses.append(avg_train_loss / eval_step_cnt)
264 | pt_lrs.append(model.pts_lr)
265 | attn_lrs.append(model.attn_lr)
266 | steps.append(step)
267 | eval_step(steps, model, device, dataset, eval_dataset, batch, loss_fn, out, args, train_losses, eval_losses, eval_psnrs, pt_lrs, attn_lrs)
268 | avg_train_loss = 0.
269 | eval_step_cnt = 0
270 |
271 | if ((step - 1) % 200 == 0) and args.eval.save_fig:
272 | coord_scale = args.dataset.coord_scale
273 | pt_plot_scale = 0.8 * coord_scale
274 | if "Barn" in args.dataset.path:
275 | pt_plot_scale *= 1.5
276 | if "Family" in args.dataset.path:
277 | pt_plot_scale *= 0.5
278 |
279 | pc_dir = os.path.join(log_dir, "point_clouds")
280 | os.makedirs(pc_dir, exist_ok=True)
281 |
282 | points_np = model.points.detach().cpu().numpy()
283 | points_influ_scores_np = None
284 | if model.points_influ_scores is not None:
285 | points_influ_scores_np = model.points_influ_scores.squeeze().detach().cpu().numpy()
286 | pcd_plot = get_training_pcd_single_plot(step, points_np, pt_plot_scale, points_influ_scores_np)
287 | pc_frames.append(pcd_plot)
288 |
289 | if step == 1:
290 | pcd_plot.save(os.path.join(pc_dir, "init_pcd.png"))
291 |
292 | if step >= args.training.steps:
293 | break
294 |
295 | if args.eval.save_fig and pc_frames != []:
296 | f = os.path.join(log_dir, f"{args.index}-pc.mp4")
297 | imageio.mimwrite(f, pc_frames, fps=30, quality=10)
298 |
299 | print("Training finished!")
300 |
301 |
302 | def main(args, eval_args, resume):
303 | log_dir = os.path.join(args.save_dir, args.index)
304 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
305 |
306 | model = get_model(args, device)
307 | dataset = get_dataset(args.dataset, mode="train")
308 | eval_dataset = get_dataset(eval_args.dataset, mode="test")
309 | model = model.to(device)
310 |
311 | # if torch.__version__ >= "2.0":
312 | # model = torch.compile(model)
313 |
314 | start_step = 0
315 | losses = [[], [], []]
316 | if resume > 0:
317 | start_step = model.load(log_dir)
318 |
319 | train_losses = torch.load(os.path.join(log_dir, "train_losses.pth")).tolist()
320 | eval_losses = torch.load(os.path.join(log_dir, "eval_losses.pth")).tolist()
321 | eval_psnrs = torch.load(os.path.join(log_dir, "eval_psnrs.pth")).tolist()
322 | losses = [train_losses, eval_losses, eval_psnrs]
323 |
324 | print("!!!!! Resume from step %s" % start_step)
325 | elif args.load_path:
326 | try:
327 | resume_step = model.load(os.path.join(args.save_dir, args.load_path))
328 | except:
329 | model_state_dict = torch.load(os.path.join(args.save_dir, args.load_path, "model.pth"))
330 | for step, state_dict in model_state_dict.items():
331 | resume_step = step
332 | model.load_my_state_dict(state_dict)
333 | print("!!!!! Loaded model from %s at step %s" % (args.load_path, resume_step))
334 |
335 | train_and_eval(start_step, model, device, dataset, eval_dataset, losses, args)
336 | print(torch.cuda.memory_summary())
337 |
338 |
339 | if __name__ == '__main__':
340 |
341 | with open("configs/default.yml", 'r') as f:
342 | default_config = yaml.safe_load(f)
343 |
344 | args = parse_args()
345 | with open(args.opt, 'r') as f:
346 | config = yaml.safe_load(f)
347 |
348 | train_config = copy.deepcopy(default_config)
349 | update_dict(train_config, config)
350 |
351 | eval_config = copy.deepcopy(train_config)
352 | eval_config['dataset'].update(eval_config['eval']['dataset'])
353 | eval_config = DictAsMember(eval_config)
354 | train_config = DictAsMember(train_config)
355 |
356 | log_dir = os.path.join(train_config.save_dir, train_config.index)
357 | os.makedirs(log_dir, exist_ok=True)
358 |
359 | sys.stdout = Logger(os.path.join(log_dir, 'train.log'), sys.stdout)
360 | sys.stderr = Logger(os.path.join(log_dir, 'train_error.log'), sys.stderr)
361 |
362 | shutil.copyfile(__file__, os.path.join(log_dir, os.path.basename(__file__)))
363 | shutil.copyfile(args.opt, os.path.join(log_dir, os.path.basename(args.opt)))
364 |
365 | find_all_python_files_and_zip(".", os.path.join(log_dir, "code.zip"))
366 |
367 | setup_seed(train_config.seed)
368 |
369 | main(train_config, eval_config, args.resume)
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.optim.lr_scheduler as lr_scheduler
4 | import scipy
5 | from scipy.spatial import KDTree
6 | import numpy as np
7 |
8 |
9 | def add_points_knn(coords, influ_scores, add_num, k, comb_type="mean", sample_type="random", sample_k=10, point_features=None):
10 | """
11 | Add points to the point cloud by kNN
12 | """
13 | pc = KDTree(coords)
14 | N = coords.shape[0]
15 |
16 | # Step 1: Determine where to add points
17 | if N <= add_num and "random" in comb_type:
18 | inds = np.random.choice(N, add_num, replace=True)
19 | query_coords = coords[inds, :]
20 | elif N <= add_num:
21 | query_coords = coords
22 | inds = list(range(N))
23 | else:
24 | if sample_type == "random":
25 | inds = np.random.choice(N, add_num, replace=False)
26 | query_coords = coords[inds, :]
27 | elif sample_type == "top-knn-std":
28 | assert k >= 2
29 | nns_dists, nns_inds = pc.query(coords, k=sample_k)
30 | inds = np.argsort(nns_dists.std(axis=-1))[-add_num:]
31 | query_coords = coords[inds, :]
32 | elif sample_type == "top-knn-mean":
33 | assert k >= 2
34 | nns_dists, nns_inds = pc.query(coords, k=sample_k)
35 | inds = np.argsort(nns_dists.mean(axis=-1))[-add_num:]
36 | query_coords = coords[inds, :]
37 | elif sample_type == "top-knn-max":
38 | assert k >= 2
39 | nns_dists, nns_inds = pc.query(coords, k=sample_k)
40 | inds = np.argsort(nns_dists.max(axis=-1))[-add_num:]
41 | query_coords = coords[inds, :]
42 | elif sample_type == "top-knn-min":
43 | assert k >= 2
44 | nns_dists, nns_inds = pc.query(coords, k=sample_k)
45 | inds = np.argsort(nns_dists.min(axis=-1))[-add_num:]
46 | query_coords = coords[inds, :]
47 | elif sample_type == "influ-scores-max":
48 | inds = np.argsort(influ_scores.squeeze())[-add_num:]
49 | query_coords = coords[inds, :]
50 | elif sample_type == "influ-scores-min":
51 | inds = np.argsort(influ_scores.squeeze())[:add_num]
52 | query_coords = coords[inds, :]
53 | else:
54 | raise NotImplementedError
55 |
56 | # Step 2: Add points by kNN
57 | new_features = None
58 | if comb_type == "duplicate":
59 | noise = np.random.randn(3).astype(np.float32)
60 | noise = noise / np.linalg.norm(noise)
61 | noise *= k
62 | new_coords = (query_coords + noise)
63 | new_influ_scores = influ_scores[inds, :]
64 | if point_features is not None:
65 | new_features = point_features[inds, :]
66 | else:
67 | nns_dists, nns_inds = pc.query(query_coords, k=k+1)
68 | nns_dists = nns_dists.astype(np.float32)
69 | nns_dists = nns_dists[:, 1:]
70 | nns_inds = nns_inds[:, 1:]
71 | if comb_type == "mean":
72 | new_coords = coords[nns_inds, :].mean(
73 | axis=-2) # (Nq, k, 3) -> (Nq, 3)
74 | new_influ_scores = influ_scores[nns_inds, :].mean(axis=-2)
75 | if point_features is not None:
76 | new_features = point_features[nns_inds, :].mean(axis=-2)
77 | elif comb_type == "random":
78 | rnd_w = np.random.uniform(
79 | 0, 1, (query_coords.shape[0], k)).astype(np.float32)
80 | rnd_w /= rnd_w.sum(axis=-1, keepdims=True)
81 | new_coords = (coords[nns_inds, :] *
82 | rnd_w.reshape(-1, k, 1)).sum(axis=-2)
83 | new_influ_scores = (
84 | influ_scores[nns_inds, :] * rnd_w.reshape(-1, k, 1)).sum(axis=-2)
85 | if point_features is not None:
86 | new_features = (
87 | point_features[nns_inds, :] * rnd_w.reshape(-1, k, 1)).sum(axis=-2)
88 | elif comb_type == "random-softmax":
89 | rnd_w = np.random.randn(
90 | query_coords.shape[0], k).astype(np.float32)
91 | rnd_w = scipy.special.softmax(rnd_w, axis=-1)
92 | new_coords = (coords[nns_inds, :] *
93 | rnd_w.reshape(-1, k, 1)).sum(axis=-2)
94 | new_influ_scores = (
95 | influ_scores[nns_inds, :] * rnd_w.reshape(-1, k, 1)).sum(axis=-2)
96 | if point_features is not None:
97 | new_features = (
98 | point_features[nns_inds, :] * rnd_w.reshape(-1, k, 1)).sum(axis=-2)
99 | elif comb_type == "weighted":
100 | new_coords = (coords[nns_inds, :] * (1 / (nns_dists + 1e-6)).reshape(-1, k, 1)).sum(
101 | axis=-2) / (1 / (nns_dists + 1e-6)).sum(axis=-1, keepdims=True)
102 | new_influ_scores = (influ_scores[nns_inds, :] * (1 / (nns_dists + 1e-6)).reshape(-1, k, 1)).sum(
103 | axis=-2) / (1 / (nns_dists + 1e-6)).sum(axis=-1, keepdims=True)
104 | if point_features is not None:
105 | new_features = (point_features[nns_inds, :] * (1 / (nns_dists + 1e-6)).reshape(-1, k, 1)).sum(
106 | axis=-2) / (1 / (nns_dists + 1e-6)).sum(axis=-1, keepdims=True)
107 | else:
108 | raise NotImplementedError
109 | return new_coords, len(new_coords), new_influ_scores, new_features
110 |
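# Usage sketch for add_points_knn with synthetic inputs (illustration only; shapes follow
# the function above: coords is (N, 3), influ_scores is (N, 1)). The import assumes this
# is run as a separate script from the repository root.
import numpy as np
from models.utils import add_points_knn

coords = np.random.rand(100, 3).astype(np.float32)
influ_scores = np.random.rand(100, 1).astype(np.float32)
feats = np.random.rand(100, 16).astype(np.float32)

new_xyz, num_new, new_scores, new_feats = add_points_knn(
    coords, influ_scores, add_num=10, k=3,
    comb_type="random", sample_type="top-knn-std", sample_k=10,
    point_features=feats)
print(new_xyz.shape, num_new, new_scores.shape, new_feats.shape)  # (10, 3) 10 (10, 1) (10, 16)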
111 |
112 | def cam_to_world(coords, c2w, vector=True):
113 | """
114 | coords: [B, H, W, N, 3], [N, H, W, 3], [H, W, 3] or [K, 3]
115 | c2w: [N, 4, 4] or [4, 4]
116 | """
117 | if vector: # Convert to homogeneous coordinates: append 0 so directions ignore translation
118 | coords = torch.cat([coords, torch.zeros_like(coords[..., :1])], -1)
119 | else: # append 1 so points are translated
120 | coords = torch.cat([coords, torch.ones_like(coords[..., :1])], -1)
121 |
122 | if coords.ndim == 5:
123 | assert c2w.ndim == 2
124 | B, H, W, N, _ = coords.shape
125 | transformed_coords = torch.sum(
126 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 1, 1, 4, 4), -1) # [B, H, W, N, 4]
127 | elif coords.ndim == 4:
128 | assert c2w.ndim == 3
129 | N, H, W, _ = coords.shape
130 | transformed_coords = torch.sum(
131 | coords.unsqueeze(-2) * c2w.reshape(N, 1, 1, 4, 4), -1) # [N, H, W, 4]
132 | elif coords.ndim == 3:
133 | assert c2w.ndim == 2
134 | H, W, _ = coords.shape
135 | transformed_coords = torch.sum(
136 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 4, 4), -1) # [H, W, 4]
137 | elif coords.ndim == 2:
138 | assert c2w.ndim == 2
139 | K, _ = coords.shape
140 | transformed_coords = torch.sum(
141 | coords.unsqueeze(-2) * c2w.reshape(1, 4, 4), -1) # [K, 4]
142 | else:
143 | raise ValueError('Wrong dimension of coords')
144 | return transformed_coords[..., :3]
145 |
146 |
147 | def world_to_cam(coords, c2w, vector=True):
148 | """
149 | coords: [B, H, W, N, 3], [N, H, W, 3], [H, W, 3] or [K, 3]
150 | c2w: [N, 4, 4] or [4, 4]
151 | """
152 | if vector: # Convert to homogeneous coordinates: append 0 so directions ignore translation
153 | coords = torch.cat([coords, torch.zeros_like(coords[..., :1])], -1)
154 | else: # append 1 so points are translated
155 | coords = torch.cat([coords, torch.ones_like(coords[..., :1])], -1)
156 |
157 | c2w = torch.inverse(c2w)
158 | if coords.ndim == 5:
159 | assert c2w.ndim == 2
160 | B, H, W, N, _ = coords.shape
161 | transformed_coords = torch.sum(
162 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 1, 1, 4, 4), -1) # [B, H, W, N, 4]
163 | elif coords.ndim == 4:
164 | assert c2w.ndim == 3
165 | N, H, W, _ = coords.shape
166 | transformed_coords = torch.sum(
167 | coords.unsqueeze(-2) * c2w.reshape(N, 1, 1, 4, 4), -1) # [N, H, W, 4]
168 | elif coords.ndim == 3:
169 | assert c2w.ndim == 2
170 | H, W, _ = coords.shape
171 | transformed_coords = torch.sum(
172 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 4, 4), -1) # [H, W, 4]
173 | elif coords.ndim == 2:
174 | assert c2w.ndim == 2
175 | K, _ = coords.shape
176 | transformed_coords = torch.sum(
177 | coords.unsqueeze(-2) * c2w.reshape(1, 4, 4), -1) # [K, 4]
178 | else:
179 | raise ValueError('Wrong dimension of coords')
180 | return transformed_coords[..., :3]
181 |
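# Round-trip sanity check for the two transforms above (illustration only; the import
# assumes the repository root is on the path): camera -> world -> camera should be the
# identity, both for directions (vector=True) and for points (vector=False).
import math
import torch
from models.utils import cam_to_world, world_to_cam

c2w = torch.eye(4)
c, s = math.cos(0.3), math.sin(0.3)
c2w[:3, :3] = torch.tensor([[c, -s, 0.], [s, c, 0.], [0., 0., 1.]])  # rotation about z
c2w[:3, 3] = torch.tensor([0.5, -1.0, 2.0])                          # camera position

dirs = torch.randn(4, 4, 3)   # [H, W, 3] ray directions: rotation only
pts = torch.randn(4, 4, 3)    # [H, W, 3] points: rotation + translation
assert torch.allclose(world_to_cam(cam_to_world(dirs, c2w), c2w), dirs, atol=1e-5)
assert torch.allclose(world_to_cam(cam_to_world(pts, c2w, vector=False), c2w, vector=False), pts, atol=1e-5)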
182 |
183 | def activation_func(act_type='leakyrelu', neg_slope=0.2, inplace=True, num_channels=128, a=1., b=1., trainable=False):
184 | act_type = act_type.lower()
185 | if act_type == 'none':
186 | layer = nn.Identity()
187 | elif act_type == 'leakyrelu':
188 | layer = nn.LeakyReLU(neg_slope, inplace)
189 | elif act_type == 'prelu':
190 | layer = nn.PReLU(num_channels)
191 | elif act_type == 'relu':
192 | layer = nn.ReLU(inplace)
193 | elif act_type == '+1':
194 | layer = PlusOneActivation()
195 | elif act_type == 'relu+1':
196 | layer = nn.Sequential(nn.ReLU(inplace), PlusOneActivation())
197 | elif act_type == 'tanh':
198 | layer = nn.Tanh()
199 | elif act_type == 'shifted_tanh':
200 | layer = ShiftedTanh()
201 | elif act_type == 'sigmoid':
202 | layer = nn.Sigmoid()
203 | elif act_type == 'gelu':
204 | layer = nn.GELU()
205 | elif act_type == 'gaussian':
206 | layer = GaussianActivation(a, trainable)
207 | elif act_type == 'quadratic':
208 | layer = QuadraticActivation(a, trainable)
209 | elif act_type == 'multi-quadratic':
210 | layer = MultiQuadraticActivation(a, trainable)
211 | elif act_type == 'laplacian':
212 | layer = LaplacianActivation(a, trainable)
213 | elif act_type == 'super-gaussian':
214 | layer = SuperGaussianActivation(a, b, trainable)
215 | elif act_type == 'expsin':
216 | layer = ExpSinActivation(a, trainable)
217 | elif act_type == 'clamp':
218 | layer = Clamp(0, 1)
219 | elif 'sine' in act_type:
220 | layer = Sine(factor=a)
221 | elif 'softplus' in act_type:
222 | a, b, c = [float(i) for i in act_type.split('_')[1:]]
223 | print(
224 | 'Softplus activation: a={:.2f}, b={:.2f}, c={:.2f}'.format(a, b, c))
225 | layer = SoftplusActivation(a, b, c)
226 | else:
227 | raise NotImplementedError(
228 | 'activation layer [{:s}] is not found'.format(act_type))
229 | return layer
230 |
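# Quick illustration of the activation factory above (standalone sketch, arbitrary values);
# the "softplus_a_b_c" form parses its three constants out of the name itself.
import torch
from models.utils import activation_func

act = activation_func('gaussian', a=0.5, trainable=False)
print(act(torch.zeros(3)))                      # tensor([1., 1., 1.]) since exp(0) = 1

sp = activation_func('softplus_1.0_2.0_0.0')    # c1=1, c2=2, c3=0  ->  1 * softplus(2x + 0)
print(sp(torch.tensor([0.0])))                  # ~0.6931 (= ln 2)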
231 |
232 | def posenc(x, L_embed, factor=2.0, without_self=False, mult_factor=1.0):
233 | if without_self:
234 | rets = []
235 | else:
236 | rets = [x]
237 | for i in range(L_embed):
238 | for fn in [torch.sin, torch.cos]:
239 | rets.append(fn(factor**i * x * mult_factor))
240 | # return torch.cat(rets, 1)
241 | # To make sure the dimensions of the same meaning are together
242 | return torch.flatten(torch.stack(rets, -1), start_dim=-2, end_dim=-1)
243 |
244 |
245 | class PoseEnc(nn.Module):
246 | def __init__(self, factor=2.0, mult_factor=1.0):
247 | super(PoseEnc, self).__init__()
248 | self.factor = factor
249 | self.mult_factor = mult_factor
250 |
251 | def forward(self, x, L_embed, without_self=False):
252 | return posenc(x, L_embed, self.factor, without_self, self.mult_factor)
253 |
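# posenc expands each input dimension into [x, sin(2^i x), cos(2^i x)] terms, and the final
# flatten keeps all frequencies of one coordinate adjacent. Output width is
# D * (2 * L_embed + 1), or D * 2 * L_embed with without_self=True. Standalone example:
import torch
from models.utils import PoseEnc

x = torch.randn(5, 3)                               # five 3-D inputs
pe = PoseEnc(factor=2.0)
print(pe(x, L_embed=4).shape)                       # torch.Size([5, 27]) = 3 * (2*4 + 1)
print(pe(x, L_embed=4, without_self=True).shape)    # torch.Size([5, 24]) = 3 * 2*4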
254 |
255 | def normalize_vector(x, eps=0.):
256 | # assert(x.shape[-1] == 3)
257 | return x / (torch.norm(x, dim=-1, keepdim=True) + eps)
258 |
259 |
260 | def create_learning_rate_fn(optimizer, max_steps, args, debug=False):
261 | """Create learning rate schedule."""
262 | if args.type == "none":
263 | return None
264 |
265 | if args.warmup > 0:
266 | warmup_start_factor = 1e-16
267 | else:
268 | warmup_start_factor = 1.0
269 |
270 | warmup_fn = lr_scheduler.LinearLR(optimizer,
271 | start_factor=warmup_start_factor,
272 | end_factor=1.0,
273 | total_iters=args.warmup,
274 | verbose=debug)
275 |
276 | if args.type == "linear":
277 | decay_fn = lr_scheduler.LinearLR(optimizer,
278 | start_factor=1.0,
279 | end_factor=0.,
280 | total_iters=max_steps - args.warmup,
281 | verbose=debug)
282 | schedulers = [warmup_fn, decay_fn]
283 | milestones = [args.warmup]
284 |
285 | elif args.type == "cosine":
286 | cosine_steps = max(max_steps - args.warmup, 1)
287 | decay_fn = lr_scheduler.CosineAnnealingLR(optimizer,
288 | T_max=cosine_steps,
289 | verbose=debug)
290 | schedulers = [warmup_fn, decay_fn]
291 | milestones = [args.warmup]
292 |
293 | elif args.type == "cosine-hlfperiod":
294 | cosine_steps = max(max_steps - args.warmup, 1) * 2
295 | decay_fn = lr_scheduler.CosineAnnealingLR(optimizer,
296 | T_max=cosine_steps,
297 | verbose=debug)
298 | schedulers = [warmup_fn, decay_fn]
299 | milestones = [args.warmup]
300 |
301 | elif args.type == "exp":
302 | decay_fn = lr_scheduler.ExponentialLR(optimizer,
303 | gamma=args.gamma,
304 | verbose=debug)
305 | schedulers = [warmup_fn, decay_fn]
306 | milestones = [args.warmup]
307 |
308 | elif args.type == "stop":
309 | decay_fn = lr_scheduler.StepLR(
310 | optimizer, step_size=1, gamma=0.0, verbose=debug)
311 | schedulers = [warmup_fn, decay_fn]
312 | milestones = [args.warmup]
313 |
314 | else:
315 | raise NotImplementedError
316 |
317 | schedule_fn = lr_scheduler.SequentialLR(optimizer,
318 | schedulers=schedulers,
319 | milestones=milestones,
320 | verbose=debug)
321 |
322 | return schedule_fn
323 |
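# create_learning_rate_fn only reads args.type, args.warmup and (for "exp") args.gamma,
# so a SimpleNamespace can stand in for the real config when trying it out (standalone
# sketch; the import assumes the repository root is on the path):
import torch
from types import SimpleNamespace
from models.utils import create_learning_rate_fn

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
sched = create_learning_rate_fn(opt, max_steps=1000,
                                args=SimpleNamespace(type="cosine", warmup=100, gamma=None))
for _ in range(1000):
    opt.step()
    sched.step()   # linear warmup for 100 steps, then cosine decay toward 0
print(opt.param_groups[0]["lr"])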
324 |
325 | class Sine(nn.Module):
326 | def __init__(self, factor=30):
327 | super().__init__()
328 | self.factor = factor
329 |
330 | def forward(self, x):
331 | return torch.sin(x * self.factor)
332 |
333 |
334 | class Clamp(nn.Module):
335 | def __init__(self, min_val, max_val):
336 | super().__init__()
337 | self.min_val = min_val
338 | self.max_val = max_val
339 |
340 | def forward(self, x):
341 | return torch.clamp(x, self.min_val, self.max_val)
342 |
343 |
344 | class ShiftedTanh(nn.Module):
345 | def __init__(self):
346 | super().__init__()
347 |
348 | def forward(self, x):
349 | return (torch.tanh(x) + 1) / 2
350 |
351 |
352 | class SoftplusActivation(nn.Module):
353 | def __init__(self, c1=1, c2=1, c3=0):
354 | super().__init__()
355 | self.c1 = c1
356 | self.c2 = c2
357 | self.c3 = c3
358 |
359 | def forward(self, x):
360 | return self.c1 * nn.functional.softplus(self.c2 * x + self.c3)
361 |
362 |
363 | class GaussianActivation(nn.Module):
364 | def __init__(self, a=1., trainable=True):
365 | super().__init__()
366 | self.register_parameter('a', nn.Parameter(a*torch.ones(1), trainable))
367 |
368 | def forward(self, x):
369 | return torch.exp(-x**2/(2*self.a**2))
370 |
371 |
372 | class QuadraticActivation(nn.Module):
373 | def __init__(self, a=1., trainable=True):
374 | super().__init__()
375 | self.register_parameter('a', nn.Parameter(a*torch.ones(1), trainable))
376 |
377 | def forward(self, x):
378 | return 1/(1+(self.a*x)**2)
379 |
380 |
381 | class MultiQuadraticActivation(nn.Module):
382 | def __init__(self, a=1., trainable=True):
383 | super().__init__()
384 | self.register_parameter('a', nn.Parameter(a*torch.ones(1), trainable))
385 |
386 | def forward(self, x):
387 | return 1/(1+(self.a*x)**2)**0.5
388 |
389 |
390 | class LaplacianActivation(nn.Module):
391 | def __init__(self, a=1., trainable=True):
392 | super().__init__()
393 | self.register_parameter('a', nn.Parameter(a*torch.ones(1), trainable))
394 |
395 | def forward(self, x):
396 | return torch.exp(-torch.abs(x)/self.a)
397 |
398 |
399 | class SuperGaussianActivation(nn.Module):
400 | def __init__(self, a=1., b=1., trainable=True):
401 | super().__init__()
402 | self.register_parameter('a', nn.Parameter(a*torch.ones(1), trainable))
403 | self.register_parameter('b', nn.Parameter(b*torch.ones(1), trainable))
404 |
405 | def forward(self, x):
406 | return torch.exp(-x**2/(2*self.a**2))**self.b
407 |
408 |
409 | class ExpSinActivation(nn.Module):
410 | def __init__(self, a=1., trainable=True):
411 | super().__init__()
412 | self.register_parameter('a', nn.Parameter(a*torch.ones(1), trainable))
413 |
414 | def forward(self, x):
415 | return torch.exp(-torch.sin(self.a*x))
416 |
417 |
418 | class PlusOneActivation(nn.Module):
419 | def __init__(self):
420 | super().__init__()
421 |
422 | def forward(self, x):
423 | return x + 1
424 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from datetime import datetime
3 | import matplotlib.pyplot as plt
4 | import io
5 | from PIL import Image
6 | import numpy as np
7 | import os
8 | import zipfile
9 | import torch
10 | import random
11 | import copy
12 |
13 |
14 | class DictAsMember(dict):
15 | def __getattr__(self, name):
16 | value = self[name]
17 | if isinstance(value, dict):
18 | value = DictAsMember(value)
19 | return value
20 |
21 |
22 | def update_dict(original, param):
23 | for key in param.keys():
24 | if type(param[key]) == dict:
25 | update_dict(original[key], param[key])
26 | elif type(param[key]) == list and key == "datasets":
27 | for i in range(len(param[key])):
28 | name = param[key][i]['name']
29 | for j in range(len(original[key])):
30 | if original[key][j]['name'] == name:
31 | for k in param[key][i].keys():
32 | original[key][j][k] = param[key][i][k]
33 | break
34 | else:
35 | new_param = copy.deepcopy(original[key][0])
36 | update_dict(new_param, param[key][i])
37 | original[key].append(new_param)
38 | else:
39 | original[key] = param[key]
40 |
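# How train.py / test.py use the two helpers above: load configs/default.yml, overlay the
# per-scene YAML in place with update_dict, then wrap the result in DictAsMember for
# attribute-style access. Toy dictionaries below (not the real config schema):
from utils import DictAsMember, update_dict

default_cfg = {"index": "default", "training": {"lr": 1e-3, "steps": 100000}}
override = {"index": "chair", "training": {"steps": 250000}}

update_dict(default_cfg, override)     # nested, in-place merge
cfg = DictAsMember(default_cfg)
print(cfg.index, cfg.training.steps, cfg.training.lr)   # chair 250000 0.001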
41 |
42 | def setup_seed(seed):
43 | torch.manual_seed(seed)
44 | torch.cuda.manual_seed_all(seed)
45 | np.random.seed(seed)
46 | random.seed(seed)
47 |
48 |
49 | def find_all_python_files_and_zip(src_dir, dst_path):
50 | # find all python files in src_dir
51 | python_files = []
52 | for root, dirs, files in os.walk(src_dir):
53 | if 'experiment' in root:
54 | continue
55 | for cur_file in files:
56 | if cur_file.endswith('.py'):
57 | python_files.append(os.path.join(root, cur_file))
58 |
59 | # zip all python files
60 | with zipfile.ZipFile(dst_path, 'w') as zip_file:
61 | for cur_file in python_files:
62 | zip_file.write(cur_file, os.path.relpath(cur_file, src_dir))
63 |
64 |
65 | class Logger(object):
66 | def __init__(self, filename='default.log', stream=sys.stdout):
67 | self.terminal = stream
68 | self.log = open(filename, 'a')
69 | ct = datetime.now()
70 | self.log.write('*'*50 + '\n' + str(ct) + '\n' + '*'*50 + '\n')
71 |
72 | def write(self, message):
73 | self.terminal.write(message)
74 | self.log.write(message)
75 |
76 | def flush(self):
77 | pass
78 |
79 |
80 | def get_colors(weights):
81 | N = weights.shape[0]
82 | weights = (weights - weights.min()) / (weights.max() - weights.min())
83 | colors = np.full((N, 3), [1., 0., 0.])
84 | colors[:, 0] *= weights[:N]
85 | colors[:, 2] = (1 - weights[:N])
86 | return colors
87 |
88 |
89 | def get_training_main_plot(index, steps, train_tgt_rgb, train_tgt_patch, train_pred_patch, test_tgt_tgb, test_pred_rgb, train_losses,
90 | eval_losses, points_np, pt_plot_scale, depth_np, pt_lrs, attn_lrs, eval_psnrs, points_conf_scores_np=None):
91 | step = steps[-1]
92 | fig = plt.figure(figsize=(20, 10))
93 |
94 | ax = fig.add_subplot(2, 5, 1)
95 | ax.imshow(train_tgt_rgb)
96 | ax.set_title(f'Iteration: {step} train norm')
97 |
98 | ax = fig.add_subplot(2, 5, 2)
99 | ax.imshow(train_tgt_patch)
100 | ax.set_title(f'Iteration: {step} train norm patch')
101 |
102 | ax = fig.add_subplot(2, 5, 3)
103 | ax.imshow(train_pred_patch)
104 | ax.set_title(f'Iteration: {step} train output')
105 |
106 | ax = fig.add_subplot(2, 5, 4)
107 | ax.plot(steps, train_losses, label='train')
108 | ax.plot(steps, eval_losses, label='eval')
109 | ax.legend()
110 | ax.set_title('losses')
111 |
112 | ax = fig.add_subplot(2, 5, 5, projection='3d')
113 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale)
114 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale)
115 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale)
116 | ax.set_xlabel('x')
117 | ax.set_ylabel('y')
118 | ax.set_zlabel('z')
119 | cur_color = "grey"
120 | if points_conf_scores_np is not None:
121 | cur_color = get_colors(points_conf_scores_np)
122 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color)
123 | ax.set_title('Point Cloud')
124 |
125 | ax = fig.add_subplot(2, 5, 6)
126 | cd = ax.imshow(depth_np)
127 | fig.colorbar(cd, ax=ax)
128 | ax.set_title(f'depth map')
129 |
130 | ax = fig.add_subplot(2, 5, 7)
131 | ax.imshow(test_tgt_tgb)
132 | ax.set_title(f'Iteration: {step} eval norm')
133 |
134 | ax = fig.add_subplot(2, 5, 8)
135 | ax.imshow(test_pred_rgb)
136 | ax.set_title(f'Iteration: {step} eval predict')
137 |
138 | ax = fig.add_subplot(2, 5, 9)
139 | ax.plot(steps, np.log10(pt_lrs), label="pt lr")
140 | ax.plot(steps, np.log10(attn_lrs), label="attn lr")
141 | ax.legend()
142 | ax.set_title('learning rates log10')
143 |
144 | ax = fig.add_subplot(2, 5, 10)
145 | ax.plot(steps, eval_psnrs)
146 | ax.set_title('eval psnr')
147 |
148 | fig.suptitle("Main Plot\n%s\niter %d\nnum pts: %d" % (index, step, points_np.shape[0]))
149 |
150 | canvas = fig.canvas
151 | buffer = io.BytesIO()
152 | canvas.print_png(buffer)
153 | data = buffer.getvalue()
154 | buffer.write(data)
155 | img = Image.open(buffer)
156 | plt.close()
157 |
158 | return img
159 |
160 |
161 | def get_training_pcd_plot(index, step, ro, rd, points_np, coord_scale, pt_plot_scale, points_conf_scores_np=None):
162 | num_plots = 6 if points_conf_scores_np is not None else 4
163 | fig = plt.figure(figsize=(5 * num_plots, 6))
164 |
165 | H, W, _ = rd.shape
166 |
167 | ax = fig.add_subplot(1, num_plots, 1, projection='3d')
168 | ax.view_init(elev=0., azim=90)
169 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale)
170 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale)
171 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale)
172 | ax.set_xlabel('x')
173 | ax.set_ylabel('y')
174 | ax.set_zlabel('z')
175 | cur_color = "orange"
176 | if points_conf_scores_np is not None:
177 | cur_color = get_colors(points_conf_scores_np)
178 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color, s=0.8 * coord_scale)
179 | ax.scatter(ro[0], ro[1], ro[2], c="red", s=10)
180 | ax.quiver(ro[0], ro[1], ro[2], rd[H//2, W//2, 0], rd[H//2, W//2, 1], rd[H//2, W//2, 2], length=2, alpha=1, color="blue")
181 | ax.set_title('Point Cloud View 1')
182 |
183 | ax = fig.add_subplot(1, num_plots, 2, projection='3d')
184 | ax.view_init(elev=0., azim=180)
185 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale)
186 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale)
187 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale)
188 | ax.set_xlabel('x')
189 | ax.set_ylabel('y')
190 | ax.set_zlabel('z')
191 | cur_color = "orange"
192 | if points_conf_scores_np is not None:
193 | cur_color = get_colors(points_conf_scores_np)
194 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color, s=0.8 * coord_scale)
195 | ax.scatter(ro[0], ro[1], ro[2], c="red", s=10)
196 | ax.quiver(ro[0], ro[1], ro[2], rd[H//2, W//2, 0], rd[H//2, W//2, 1], rd[H//2, W//2, 2], length=2, alpha=1, color="blue")
197 | ax.set_title('Point Cloud View 2')
198 |
199 | ax = fig.add_subplot(1, num_plots, 3, projection='3d')
200 | ax.view_init(elev=0., azim=270)
201 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale)
202 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale)
203 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale)
204 | ax.set_xlabel('x')
205 | ax.set_ylabel('y')
206 | ax.set_zlabel('z')
207 | cur_color = "orange"
208 | if points_conf_scores_np is not None:
209 | cur_color = get_colors(points_conf_scores_np)
210 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color, s=0.8 * coord_scale)
211 | ax.scatter(ro[0], ro[1], ro[2], c="red", s=10)
212 | ax.quiver(ro[0], ro[1], ro[2], rd[H//2, W//2, 0], rd[H//2, W//2, 1], rd[H//2, W//2, 2], length=2, alpha=1, color="blue")
213 | ax.set_title('Point Cloud View 3')
214 |
215 | ax = fig.add_subplot(1, num_plots, 4, projection='3d')
216 | ax.view_init(elev=89.9, azim=90)
217 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale)
218 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale)
219 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale)
220 | ax.set_xlabel('x')
221 | ax.set_ylabel('y')
222 | ax.set_zlabel('z')
223 | cur_color = "orange"
224 | if points_conf_scores_np is not None:
225 | cur_color = get_colors(points_conf_scores_np)
226 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color, s=0.8 * coord_scale)
227 | ax.scatter(ro[0], ro[1], ro[2], c="red", s=10)
228 | ax.quiver(ro[0], ro[1], ro[2], rd[H//2, W//2, 0], rd[H//2, W//2, 1], rd[H//2, W//2, 2], length=2, alpha=1, color="blue")
229 | ax.set_title('Point Cloud View 1 Up')
230 |
231 | if points_conf_scores_np is not None:
232 | ax = fig.add_subplot(1, num_plots, 5)
233 | ax.scatter(range(len(points_conf_scores_np)), points_conf_scores_np)
234 | ax.set_title('Confidence Scores scatter plot')
235 |
236 | ax = fig.add_subplot(1, num_plots, 6)
237 | bins = np.linspace(-1, 1, 100).tolist()
238 | ax.hist(points_conf_scores_np, bins=bins)
239 | ax.set_title('Confidence Scores histogram')
240 |
241 | fig.suptitle("Point Clouds\n%s\niter %d" % (index, step))
242 |
243 | canvas = fig.canvas
244 | buffer = io.BytesIO()
245 | canvas.print_png(buffer)
246 | data = buffer.getvalue()
247 | buffer.write(data)
248 | img = Image.open(buffer)
249 | plt.close()
250 |
251 | return img
252 |
253 |
254 | def get_training_pcd_single_plot(step, points_np, pt_plot_scale, points_conf_scores_np=None):
255 | fig = plt.figure(figsize=(5, 5))
256 | fig.tight_layout()
257 |
258 | ax = fig.add_subplot(1, 1, 1, projection='3d')
259 | ax.view_init(elev=20., azim=90 + (step / 500) * (720. / 500))
260 | ax.set_xlim3d(-pt_plot_scale * 1.5, pt_plot_scale * 1.5)
261 | ax.set_ylim3d(-pt_plot_scale * 1.5, pt_plot_scale * 1.5)
262 | ax.set_zlim3d(-pt_plot_scale * 1.5, pt_plot_scale * 1.5)
263 | ax.set_axis_off()
264 | cur_color = "orange"
265 | if points_conf_scores_np is not None:
266 | cur_color = get_colors(points_conf_scores_np)
267 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color)
268 |
269 | fig.suptitle("iter %d\n#points: %d" % (step, points_np.shape[0]))
270 | fig.tight_layout()
271 |
272 | canvas = fig.canvas
273 | buffer = io.BytesIO()
274 | canvas.print_png(buffer)
275 | data = buffer.getvalue()
276 | buffer.write(data)
277 | img = Image.open(buffer)
278 | plt.close()
279 |
280 | return img
281 |
282 |
283 | def get_test_pcrgb(frame, th, azmin, test_psnr, points_np, rgb_pred_np, rgb_gt_np, depth_np, pt_plot_scale, points_conf_scores_np=None):
284 | fig = plt.figure(figsize=(30, 10))
285 |
286 | ax = fig.add_subplot(1, 5, 1, projection='3d')
287 | ax.axis('off')
288 | ax.view_init(elev=20., azim=90 - th)
289 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale)
290 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale)
291 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale)
292 | ax.set_xlabel('x')
293 | ax.set_ylabel('y')
294 | ax.set_zlabel('z')
295 | cur_color = "grey"
296 | if points_conf_scores_np is not None:
297 | cur_color = get_colors(points_conf_scores_np)
298 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color, s=0.5)
299 |
300 | ax = fig.add_subplot(1, 5, 2, projection='3d')
301 | # ax.axis('off')
302 | ax.view_init(elev=0., azim=azmin)
303 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale)
304 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale)
305 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale)
306 | ax.set_xlabel('x')
307 | ax.set_ylabel('y')
308 | ax.set_zlabel('z')
309 | switch_yz = np.array([[1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=np.float32)
310 | cur_points_np = points_np @ switch_yz
311 | cur_points_np[:, 2] = -cur_points_np[:, 2]
312 | cur_color = "orange"
313 | if points_conf_scores_np is not None:
314 | cur_color = get_colors(points_conf_scores_np)
315 | ax.scatter3D(cur_points_np[:, 0], cur_points_np[:, 1], cur_points_np[:, 2], c=cur_color, s=0.5)
316 |
317 | ax = fig.add_subplot(1, 5, 3)
318 | ax.axis('off')
319 | ax.imshow(rgb_pred_np)
320 |
321 | ax = fig.add_subplot(1, 5, 4)
322 | ax.axis('off')
323 | ax.imshow(rgb_gt_np)
324 |
325 | ax = fig.add_subplot(1, 5, 5)
326 | ax.axis('off')
327 | ax.imshow(depth_np)
328 |
329 | fig.suptitle("Point Cloud and RGB\nframe %d, PSNR %.3f, num points %d" % (frame, test_psnr, points_np.shape[0]))
330 |
331 | canvas = fig.canvas
332 | buffer = io.BytesIO()
333 | canvas.print_png(buffer)
334 | data = buffer.getvalue()
335 | buffer.write(data)
336 | img = Image.open(buffer)
337 | plt.close()
338 |
339 | return img
340 |
341 |
342 | def get_test_featmap_attn(frame, th, points_np, rgb_pred_np, rgb_gt_np, pt_plot_scale, featmap_np, attn_np, points_conf_scores_np=None):
343 | fig = plt.figure(figsize=(20, 15))
344 |
345 | ax = fig.add_subplot(3, 5, 1, projection='3d')
346 | ax.axis('off')
347 | ax.view_init(elev=20., azim=90 - th)
348 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale)
349 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale)
350 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale)
351 | ax.set_xlabel('x')
352 | ax.set_ylabel('y')
353 | ax.set_zlabel('z')
354 | if points_conf_scores_np is not None:
355 | cur_color = get_colors(points_conf_scores_np)
356 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color)
357 |
358 | ax = fig.add_subplot(3, 5, 2)
359 | ax.axis('off')
360 | ax.imshow(rgb_gt_np)
361 |
362 | ax = fig.add_subplot(3, 5, 3)
363 | ax.axis('off')
364 | ax.imshow(rgb_pred_np)
365 |
366 | C = featmap_np.shape[-1]
367 | for i in range(min(5, C)):
368 | ax = fig.add_subplot(3, 5, 6 + i)
369 | ax.axis('off')
370 | cur_dim = C * i // 5
371 | cur_min = featmap_np[..., cur_dim].min()
372 | cur_max = featmap_np[..., cur_dim].max()
373 | ax.imshow(featmap_np[..., cur_dim])
374 | ax.set_title(f'featmap dim {cur_dim}\nmin: %.3f, max: %.3f' % (cur_min, cur_max))
375 |
376 | K = attn_np.shape[-1]
377 | for i in range(min(K-1, 4)):
378 | ax = fig.add_subplot(3, 5, 11 + i)
379 | ax.axis('off')
380 | cur_dim = K * i // 4
381 | cur_min = attn_np[..., cur_dim].min()
382 | cur_max = attn_np[..., cur_dim].max()
383 | ax.imshow(attn_np[..., cur_dim])
384 | ax.set_title(f'attn dim {cur_dim}\nmin: %.3f, max: %.3f' % (cur_min, cur_max))
385 |
386 | ax = fig.add_subplot(3, 5, 15)
387 | ax.axis('off')
388 | cur_min = attn_np[..., -1].min()
389 | cur_max = attn_np[..., -1].max()
390 | ax.imshow(attn_np[..., -1])
391 | ax.set_title(f'attn dim -1\nmin: %.3f, max: %.3f' % (cur_min, cur_max))
392 |
393 | fig.suptitle("feature map and attention\nframe %d\n" % (frame))
394 |
395 | canvas = fig.canvas
396 | buffer = io.BytesIO()
397 | canvas.print_png(buffer)
398 | data = buffer.getvalue()
399 | buffer.write(data)
400 | img = Image.open(buffer)
401 | plt.close()
402 |
403 | return img
404 |
405 |
406 | def resample_shading_codes(shading_codes, args, model, dataset, img_id, loss_fn, step, full_img=False):
407 | if full_img:
408 | img, rayd, rayo = dataset.get_full_img(img_id)
409 | c2w = dataset.get_c2w(img_id)
410 | else:
411 | _, _, img, rayd, rayo = dataset[img_id]
412 | c2w = dataset.get_c2w(img_id)
413 | img = torch.from_numpy(img).unsqueeze(0)
414 | rayd = torch.from_numpy(rayd).unsqueeze(0)
415 | rayo = torch.from_numpy(rayo).unsqueeze(0)
416 |
417 | sampled_shading_codes = torch.randn(args.exposure_control.shading_code_num_samples,
418 | args.exposure_control.shading_code_dim, device=model.device) \
419 | * args.exposure_control.shading_code_scale
420 |
421 | N, H, W, _ = rayd.shape
422 | num_pts, _ = model.points.shape
423 |
424 | rayo = rayo.to(model.device)
425 | rayd = rayd.to(model.device)
426 | img = img.to(model.device)
427 | c2w = c2w.to(model.device)
428 |
429 | topk = min([num_pts, model.select_k])
430 |
431 | bkg_seq_len_attn = 0
432 | feat_dim = args.models.attn.embed.value.d_ff_out
433 | if model.bkg_feats is not None:
434 | bkg_seq_len_attn = model.bkg_feats.shape[0]
435 | feature_map = torch.zeros(N, H, W, 1, feat_dim).to(model.device)
436 | attn = torch.zeros(N, H, W, topk + bkg_seq_len_attn, 1).to(model.device)
437 |
438 | best_idx = 0
439 | best_loss = 1e10
440 | best_loss_idx = 0
441 | best_psnr = 0
442 | best_psnr_idx = 0
443 |
444 | with torch.no_grad():
445 | for height_start in range(0, H, args.eval.max_height):
446 | for width_start in range(0, W, args.eval.max_width):
447 | height_end = min(height_start + args.eval.max_height, H)
448 | width_end = min(width_start + args.eval.max_width, W)
449 |
450 | feature_map[:, height_start:height_end, width_start:width_end, :, :], \
451 | attn[:, height_start:height_end, width_start:width_end, :, :] = model.evaluate(rayo, rayd[:, height_start:height_end, width_start:width_end], c2w, step=step)
452 |
453 | for i in range(args.exposure_control.shading_code_num_samples):
454 | torch.cuda.empty_cache()
455 | cur_shading_code = sampled_shading_codes[i]
456 | cur_affine = model.mapping_mlp(cur_shading_code)
457 | cur_affine_dim = cur_affine.shape[-1]
458 | cur_gamma, cur_beta = cur_affine[:cur_affine_dim // 2], cur_affine[cur_affine_dim // 2:]
459 | # print(cur_shading_code.min().item(), cur_shading_code.max().item(), cur_gamma.min().item(), cur_gamma.max().item(), cur_beta.min().item(), cur_beta.max().item())
460 |
461 | foreground_rgb = model.renderer(feature_map.squeeze(-2).permute(0, 3, 1, 2), gamma=cur_gamma, beta=cur_beta).permute(0, 2, 3, 1).unsqueeze(-2) # (N, H, W, 1, 3)
462 |
463 | if model.bkg_feats is not None:
464 | bkg_attn = attn[..., topk:, :]
465 | if args.models.normalize_topk_attn:
466 | rgb = foreground_rgb * (1 - bkg_attn) + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn
467 | else:
468 | rgb = foreground_rgb + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn
469 | rgb = rgb.squeeze(-2)
470 | else:
471 | rgb = foreground_rgb.squeeze(-2)
472 |
473 | rgb = model.last_act(rgb)
474 | # rgb = torch.clamp(rgb, 0, 1)
475 |
476 | eval_loss = loss_fn(rgb, img)
477 | eval_psnr = -10. * np.log(((rgb - img)**2).mean().item()) / np.log(10.)
478 |
479 | if eval_loss < best_loss:
480 | best_loss = eval_loss
481 | best_loss_idx = i
482 |
483 | if eval_psnr > best_psnr:
484 | best_psnr = eval_psnr
485 | best_psnr_idx = i
486 |
487 | model.clear_grad()
488 |
489 | # print("Best loss:", best_loss, "Best loss idx:", best_loss_idx, "Best psnr:", best_psnr, "Best psnr idx:", best_psnr_idx)
490 | best_idx = best_loss_idx if args.exposure_control.shading_code_resample_select_by == "loss" else best_psnr_idx
491 | shading_codes[img_id] = sampled_shading_codes[best_idx]
492 |
493 | del rayo, rayd, img, c2w, attn
494 | del eval_loss
495 |
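# The mapping-MLP output above is split in half into cur_gamma / cur_beta and handed to
# model.renderer. The renderer itself is not shown in this file; the snippet below only
# illustrates the conventional per-channel affine (FiLM-style) modulation that such a
# gamma/beta split usually feeds, with made-up shapes (C = 32 is hypothetical):
import torch

feat = torch.randn(1, 32, 16, 16)      # (N, C, H, W) feature map
affine = torch.randn(2 * 32)           # first half -> gamma (scale), second half -> beta (shift)
gamma, beta = affine[:32], affine[32:]
out = feat * gamma.view(1, -1, 1, 1) + beta.view(1, -1, 1, 1)
print(out.shape)                       # torch.Size([1, 32, 16, 16])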
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | import torch
4 | import torch.nn as nn
5 | import os
6 | import io
7 | import shutil
8 | from PIL import Image
9 | import imageio
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | import random
13 | import sys
14 | from utils import *
15 | from dataset import get_dataset, get_loader
16 | from models import get_model, get_loss
17 | import lpips
18 | try:
19 | from skimage.measure import compare_ssim
20 | except:
21 | from skimage.metrics import structural_similarity
22 |
23 | def compare_ssim(gt, img, win_size, channel_axis=2):
24 | return structural_similarity(gt, img, win_size=win_size, channel_axis=channel_axis, data_range=1.0)
25 |
26 |
27 | def parse_args():
28 | parser = argparse.ArgumentParser(description="PAPR")
29 | parser.add_argument('--opt', type=str, default="", help='Option file path')
30 | parser.add_argument('--resume', type=int, default=250000, help='Resume step')
31 | parser.add_argument('--exp', action='store_true', help='[Exposure control] To test with exposure control enabled')
32 | parser.add_argument('--intrp', action='store_true', help='[Exposure control] Interpolation')
33 | parser.add_argument('--random', action='store_true', help='[Exposure control] Random exposure control')
34 | parser.add_argument('--resample', action='store_true', help='[Exposure control] Resample shading codes')
35 | parser.add_argument('--seed', type=int, default=1, help='[Exposure control] Random seed')
36 | parser.add_argument('--view', type=int, default=0, help='[Exposure control] Test frame index')
37 | parser.add_argument('--scale', type=float, default=1.0, help='[Exposure control] Shading code scale')
38 | parser.add_argument('--num_samples', type=int, default=20, help='[Exposure control] Number of samples for random exposure control')
39 | parser.add_argument('--start_index', type=int, default=0, help='[Exposure control] Interpolation start index')
40 | parser.add_argument('--end_index', type=int, default=1, help='[Exposure control] Interpolation end index')
41 | parser.add_argument('--num_intrp', type=int, default=10, help='[Exposure control] Number of interpolations')
42 | return parser.parse_args()
43 |
44 |
45 | def test_step(frame, i, num_frames, model, device, dataset, batch, loss_fn, lpips_loss_fn_alex, lpips_loss_fn_vgg, args, config,
46 | test_losses, test_psnrs, test_ssims, test_lpips_alexs, test_lpips_vggs, resume_step, cur_shading_code=None, suffix=""):
47 | idx, _, img, rayd, rayo = batch
48 | c2w = dataset.get_c2w(idx.squeeze())
49 |
50 | N, H, W, _ = rayd.shape
51 | num_pts, _ = model.points.shape
52 |
53 | rayo = rayo.to(device)
54 | rayd = rayd.to(device)
55 | img = img.to(device)
56 | c2w = c2w.to(device)
57 |
58 | topk = min([num_pts, model.select_k])
59 | selected_points = torch.zeros(1, H, W, topk, 3)
60 |
61 | bkg_seq_len_attn = 0
62 | feat_dim = args.models.attn.embed.value.d_ff_out
63 | if model.bkg_feats is not None:
64 | bkg_seq_len_attn = model.bkg_feats.shape[0]
65 | feature_map = torch.zeros(N, H, W, 1, feat_dim).to(device)
66 | attn = torch.zeros(N, H, W, topk + bkg_seq_len_attn, 1).to(device)
67 |
68 | with torch.no_grad():
69 | cur_gamma, cur_beta, code_mean = None, None, 0
70 | if cur_shading_code is not None:
71 | code_mean = cur_shading_code.mean().item()
72 | cur_affine = model.mapping_mlp(cur_shading_code)
73 | cur_affine_dim = cur_affine.shape[-1]
74 | cur_gamma, cur_beta = cur_affine[:cur_affine_dim // 2], cur_affine[cur_affine_dim // 2:]
75 |
76 | for height_start in range(0, H, args.test.max_height):
77 | for width_start in range(0, W, args.test.max_width):
78 | height_end = min(height_start + args.test.max_height, H)
79 | width_end = min(width_start + args.test.max_width, W)
80 |
81 | feature_map[:, height_start:height_end, width_start:width_end, :, :], \
82 | attn[:, height_start:height_end, width_start:width_end, :, :] = model.evaluate(rayo, rayd[:, height_start:height_end, width_start:width_end], c2w, step=resume_step)
83 |
84 | selected_points[:, height_start:height_end, width_start:width_end, :, :] = model.selected_points
85 |
86 | if args.models.use_renderer:
87 | foreground_rgb = model.renderer(feature_map.squeeze(-2).permute(0, 3, 1, 2), gamma=cur_gamma, beta=cur_beta).permute(0, 2, 3, 1).unsqueeze(-2) # (N, H, W, 1, 3)
88 | else:
89 | foreground_rgb = feature_map
90 |
91 | if model.bkg_feats is not None:
92 | bkg_attn = attn[..., topk:, :]
93 | bkg_mask = (model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn).squeeze()
94 | if args.models.normalize_topk_attn:
95 | rgb = foreground_rgb * (1 - bkg_attn) + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn
96 | else:
97 | rgb = foreground_rgb + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn
98 | rgb = rgb.squeeze(-2)
99 | else:
100 | rgb = foreground_rgb.squeeze(-2)
101 | bkg_mask = torch.zeros(N, H, W, 1).to(device)
102 |
103 | rgb = model.last_act(rgb)
104 | rgb = torch.clamp(rgb, 0, 1)
105 |
106 | test_loss = loss_fn(rgb, img)
107 | test_psnr = -10. * np.log(((rgb - img)**2).mean().item()) / np.log(10.)
108 | test_ssim = compare_ssim(rgb.squeeze().detach().cpu().numpy(), img.squeeze().detach().cpu().numpy(), 11, channel_axis=2)
109 | test_lpips_alex = lpips_loss_fn_alex(rgb.permute(0, 3, 1, 2), img.permute(0, 3, 1, 2)).squeeze().item()
110 | test_lpips_vgg = lpips_loss_fn_vgg(rgb.permute(0, 3, 1, 2), img.permute(0, 3, 1, 2)).squeeze().item()
111 |
112 | test_losses.append(test_loss.item())
113 | test_psnrs.append(test_psnr)
114 | test_ssims.append(test_ssim)
115 | test_lpips_alexs.append(test_lpips_alex)
116 | test_lpips_vggs.append(test_lpips_vgg)
117 |
118 | print(f"Test frame: {frame}, code mean: {code_mean}, test_loss: {test_losses[-1]:.4f}, test_psnr: {test_psnrs[-1]:.4f}, test_ssim: {test_ssims[-1]:.4f}, test_lpips_alex: {test_lpips_alexs[-1]:.4f}, test_lpips_vgg: {test_lpips_vggs[-1]:.4f}")
119 |
120 | od = -rayo  # approximate depth axis: from the camera center toward the scene origin
121 | D = torch.sum(od * rayo)  # offset of the plane through the camera center with normal od
122 | dists = torch.abs(torch.sum(selected_points.to(od.device) * od, -1) - D) / torch.norm(od)  # point-to-plane distance as a depth proxy
123 | if model.bkg_feats is not None:
124 | dists = torch.cat([dists, torch.ones(N, H, W, model.bkg_feats.shape[0]).to(dists.device) * 0], dim=-1)
125 | cur_depth = (torch.sum(attn.squeeze(-1).to(od.device) * dists, dim=-1)).detach().cpu().squeeze().numpy().astype(np.float32)
126 | depth_np = cur_depth.copy()
127 |
128 | if args.test.save_fig:
129 | # To save the rendered images, depth maps, foreground rgb, and background mask
130 | dir_name = "images"
131 | if cur_shading_code is not None:
132 | dir_name = f'exposure_control_{suffix}_scale{config.scale}' if suffix in ['intrp', 'random'] else f'exposure_control_{suffix}'
133 | log_dir = os.path.join(args.save_dir, args.index, 'test', dir_name)
134 | os.makedirs(log_dir, exist_ok=True)
135 | cur_depth /= args.dataset.coord_scale
136 | cur_depth *= (65536 / 10)
137 | cur_depth = cur_depth.astype(np.uint16)
138 | imageio.imwrite(os.path.join(log_dir, "test-{:04d}-{:02d}-predrgb-codeMean{:.4f}-PSNR{:.3f}-SSIM{:.4f}-LPIPSA{:.4f}-LPIPSV{:.4f}.png".format(frame, i, code_mean, test_psnr, test_ssim, test_lpips_alex, test_lpips_vgg)), (rgb.squeeze().detach().cpu().numpy() * 255).astype(np.uint8))
139 | imageio.imwrite(os.path.join(log_dir, "test-{:04d}-{:02d}-depth-codeMean{:.4f}-PSNR{:.3f}-SSIM{:.4f}-LPIPSA{:.4f}-LPIPSV{:.4f}.png".format(frame, i, code_mean, test_psnr, test_ssim, test_lpips_alex, test_lpips_vgg)), cur_depth)
140 | imageio.imwrite(os.path.join(log_dir, "test-{:04d}-{:02d}-fgrgb-codeMean{:.4f}-PSNR{:.3f}-SSIM{:.4f}-LPIPSA{:.4f}-LPIPSV{:.4f}.png".format(frame, i, code_mean, test_psnr, test_ssim, test_lpips_alex, test_lpips_vgg)), (foreground_rgb.squeeze().clamp(0, 1).detach().cpu().numpy() * 255).astype(np.uint8))
141 | imageio.imwrite(os.path.join(log_dir, "test-{:04d}-{:02d}-bkgmask-codeMean{:.4f}-PSNR{:.3f}-SSIM{:.4f}-LPIPSA{:.4f}-LPIPSV{:.4f}.png".format(frame, i, code_mean, test_psnr, test_ssim, test_lpips_alex, test_lpips_vgg)), (bkg_mask.detach().cpu().numpy() * 255).astype(np.uint8))
142 |
143 | plots = {}
144 |
145 | if args.test.save_video:
146 | # To save the rendered videos
147 | coord_scale = args.dataset.coord_scale
148 | if "Barn" in args.dataset.path:
149 | coord_scale *= 1.5
150 | if "Family" in args.dataset.path:
151 | coord_scale *= 0.5
152 | pt_plot_scale = 1.0 * coord_scale
153 |
154 | plot_opt = args.test.plots
155 | th = -frame * (360. / num_frames)
156 | azims = np.linspace(180, -180, num_frames)
157 | azmin = azims[frame]
158 |
159 | points_np = model.points.detach().cpu().numpy()
160 | rgb_pred_np = rgb.squeeze().detach().cpu().numpy().astype(np.float32)
161 | rgb_gt_np = img.squeeze().detach().cpu().numpy().astype(np.float32)
162 | points_influ_scores_np = None
163 | if model.points_influ_scores is not None:
164 | points_influ_scores_np = model.points_influ_scores.squeeze().detach().cpu().numpy()
165 |
166 | if plot_opt.pcrgb:
167 | pcrgb_plot = get_test_pcrgb(frame, th, azmin, test_psnr, points_np,
168 | rgb_pred_np, rgb_gt_np, depth_np, pt_plot_scale, points_influ_scores_np)
169 | plots["pcrgb"] = pcrgb_plot
170 |
171 | if plot_opt.featattn: # Note that these plots are not necessarily meaningful since each ray has different top K points
172 | featmap_np = feature_map[0].squeeze().detach().cpu().numpy().astype(np.float32)
173 | attn_np = attn[0].squeeze().detach().cpu().numpy().astype(np.float32)
174 | featattn_plot = get_test_featmap_attn(frame, th, points_np, rgb_pred_np, rgb_gt_np,
175 | pt_plot_scale, featmap_np, attn_np, points_influ_scores_np)
176 | plots["featattn"] = featattn_plot
177 |
178 | return plots
179 |
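# The PSNR in test_step is -10 * log10(MSE) for images scaled to [0, 1]; e.g. a mean
# squared error of 1e-3 corresponds to 30 dB. Quick check of the same expression:
import numpy as np
print(-10. * np.log(1e-3) / np.log(10.))   # ~30.0 (up to floating-point rounding)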
180 |
181 | def test(model, device, dataset, save_name, args, config, resume_step, shading_codes=None):
182 | testloader = get_loader(dataset, args.dataset, mode="test")
183 | print("testloader:", testloader)
184 |
185 | loss_fn = get_loss(args.training.losses)
186 | loss_fn = loss_fn.to(device)
187 |
188 | lpips_loss_fn_alex = lpips.LPIPS(net='alex', version='0.1')
189 | lpips_loss_fn_alex = lpips_loss_fn_alex.to(device)
190 | lpips_loss_fn_vgg = lpips.LPIPS(net='vgg', version='0.1')
191 | lpips_loss_fn_vgg = lpips_loss_fn_vgg.to(device)
192 |
193 | test_losses = []
194 | test_psnrs = []
195 | test_ssims = []
196 | test_lpips_alexs = []
197 | test_lpips_vggs = []
198 |
199 | frames = {}
200 |
201 | if config.exp: # test with exposure control, the model needs to be finetuned with exposure control first
202 | if config.random:
203 | suffix = "random"
204 | for frame, batch in enumerate(testloader):
205 | if frame != config.view:
206 | continue
207 |
208 | for i in range(config.num_samples):
209 | print("test seed:", config.seed, "i:", i)
210 | shading_codes = torch.randn(1, args.exposure_control.shading_code_dim, device=device) * config.scale
211 | plots = test_step(frame, i, len(testloader), model, device, dataset, batch, loss_fn, lpips_loss_fn_alex,
212 | lpips_loss_fn_vgg, args, config, test_losses, test_psnrs, test_ssims, test_lpips_alexs,
213 | test_lpips_vggs, resume_step, shading_codes, suffix)
214 |
215 | elif config.intrp:
216 | suffix = "intrp"
217 | latent_codes = []
218 | ids = [config.start_index, config.end_index]
219 | for i in range(config.num_samples):
220 | print("test seed:", config.seed, "i:", i)
221 | shading_codes = torch.randn(1, args.exposure_control.shading_code_dim, device=device) * config.scale
222 |
223 | if i in ids:
224 | latent_codes.append(shading_codes)
225 |
226 | interpolated_codes = []
227 | for j in range(config.num_intrp):
228 | interpolated_codes.append(latent_codes[0] + (latent_codes[1] - latent_codes[0]) * (j + 1) / config.num_intrp)
229 |
230 | frames = {}
231 | for frame, batch in enumerate(testloader):
232 | if frame != config.view:
233 | continue
234 |
235 | for i in range(config.num_intrp):
236 | shading_codes = interpolated_codes[i]
237 | plots = test_step(frame, i, len(testloader), model, device, dataset, batch, loss_fn, lpips_loss_fn_alex,
238 | lpips_loss_fn_vgg, args, config, test_losses, test_psnrs, test_ssims, test_lpips_alexs,
239 | test_lpips_vggs, resume_step, shading_codes, suffix)
240 |
241 | else:
242 | suffix = "test"
243 | shading_code = torch.randn(args.exposure_control.shading_code_dim, device=device) * config.scale
244 | for frame, batch in enumerate(testloader):
245 | # shading_code = shading_codes[frame]
246 | plots = test_step(frame, 0, len(testloader), model, device, dataset, batch, loss_fn, lpips_loss_fn_alex,
247 | lpips_loss_fn_vgg, args, config, test_losses, test_psnrs, test_ssims, test_lpips_alexs,
248 | test_lpips_vggs, resume_step, shading_code, suffix)
249 |
250 | if plots:
251 | for key, value in plots.items():
252 | if key not in frames:
253 | frames[key] = []
254 | frames[key].append(value)
255 |
256 | else: # test without exposure control
257 | for frame, batch in enumerate(testloader):
258 | plots = test_step(frame, 0, len(testloader), model, device, dataset, batch, loss_fn, lpips_loss_fn_alex,
259 | lpips_loss_fn_vgg, args, config, test_losses, test_psnrs, test_ssims, test_lpips_alexs,
260 | test_lpips_vggs, resume_step)
261 |
262 | if plots:
263 | for key, value in plots.items():
264 | if key not in frames:
265 | frames[key] = []
266 | frames[key].append(value)
267 |
268 | test_loss = np.mean(test_losses)
269 | test_psnr = np.mean(test_psnrs)
270 | test_ssim = np.mean(test_ssims)
271 | test_lpips_alex = np.mean(test_lpips_alexs)
272 | test_lpips_vgg = np.mean(test_lpips_vggs)
273 |
274 | if frames:
275 | for key, value in frames.items():
276 | name = f"{args.index}-PSNR{test_psnr:.3f}-SSIM{test_ssim:.4f}-LPIPSA{test_lpips_alex:.4f}-LPIPSV{test_lpips_vgg:.4f}-{key}-{save_name}-step{resume_step}.mp4"
277 | # In case the name is too long
278 | name = name[-255:] if len(name) > 255 else name
279 | log_dir = os.path.join(args.save_dir, args.index, 'test', 'videos')
280 | os.makedirs(log_dir, exist_ok=True)
281 | f = os.path.join(log_dir, name)
282 | imageio.mimwrite(f, value, fps=30, quality=10)
283 |
284 | print(f"Avg test loss: {test_loss:.4f}, test PSNR: {test_psnr:.4f}, test SSIM: {test_ssim:.4f}, test LPIPS Alex: {test_lpips_alex:.4f}, test LPIPS VGG: {test_lpips_vgg:.4f}")
285 |
286 |
287 | def main(config, args, save_name, mode, resume_step=0):
288 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
289 | model = get_model(args, device)
290 | dataset = get_dataset(args.dataset, mode=mode)
291 |
292 | if args.test.load_path:
293 | try:
294 | model_state_dict = torch.load(args.test.load_path)
295 | for step, state_dict in model_state_dict.items():
296 | resume_step = int(step)
297 | if config.exp:
298 | model.train_shading_codes = nn.Parameter(torch.zeros_like(state_dict['train_shading_codes']), requires_grad=False)
299 | model.eval_shading_codes = nn.Parameter(torch.zeros_like(state_dict['eval_shading_codes']), requires_grad=False)
300 | model.load_my_state_dict(state_dict)
301 | except:
302 | model_state_dict = torch.load(os.path.join(args.save_dir, args.test.load_path, "model.pth"))
303 | for step, state_dict in model_state_dict.items():
304 | resume_step = step
305 | if config.exp:
306 | model.train_shading_codes = nn.Parameter(torch.zeros_like(state_dict['train_shading_codes']), requires_grad=False)
307 | model.eval_shading_codes = nn.Parameter(torch.zeros_like(state_dict['eval_shading_codes']), requires_grad=False)
308 | model.load_my_state_dict(state_dict)
309 | print("!!!!! Loaded model from %s at step %s" % (args.test.load_path, resume_step))
310 | else:
311 | try:
312 | model_state_dict = torch.load(os.path.join(args.save_dir, args.index, "model.pth"))
313 | for step, state_dict in model_state_dict.items():
314 | resume_step = int(step)
315 | if config.exp:
316 | model.train_shading_codes = nn.Parameter(torch.zeros_like(state_dict['train_shading_codes']), requires_grad=False)
317 | model.eval_shading_codes = nn.Parameter(torch.zeros_like(state_dict['eval_shading_codes']), requires_grad=False)
318 | model.load_my_state_dict(state_dict)
319 | except:
320 | state_dict = torch.load(os.path.join(args.save_dir, args.index, f"model_{resume_step}.pth"))
321 | if config.exp:
322 | model.train_shading_codes = nn.Parameter(torch.zeros_like(state_dict['train_shading_codes']), requires_grad=False)
323 | model.eval_shading_codes = nn.Parameter(torch.zeros_like(state_dict['eval_shading_codes']), requires_grad=False)
324 | model.load_my_state_dict(state_dict)
325 | print("!!!!! Loaded model from %s at step %s" % (os.path.join(args.save_dir, args.index), resume_step))
326 |
327 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
328 | model = model.to(device)
329 |
330 | shading_codes = None
331 | if config.exp:
332 | if mode == 'train':
333 | shading_codes = model.train_shading_codes
334 | print("Using train shading_codes:", shading_codes.shape, shading_codes.min(), shading_codes.max())
335 | elif mode == 'test':
336 | shading_codes = model.eval_shading_codes
337 | print("Using eval shading_codes:", shading_codes.shape, shading_codes.min(), shading_codes.max())
338 | else:
339 | raise NotImplementedError
340 |
341 | test(model, device, dataset, save_name, args, config, resume_step, shading_codes)
342 |
343 |
344 | if __name__ == '__main__':
345 |
346 | with open("configs/default.yml", 'r') as f:
347 | default_config = yaml.safe_load(f)
348 |
349 | args = parse_args()
350 | if args.intrp or args.random: assert args.exp, "You need to turn on exposure control (--exp) for exposure interpolation or for generating images with random exposure levels."
351 | assert not args.intrp or not args.random, "Cannot do exposure interpolation and random exposure generation at the same time."
352 | with open(args.opt, 'r') as f:
353 | config = yaml.safe_load(f)
354 |
355 | test_config = copy.deepcopy(default_config)
356 | update_dict(test_config, config)
357 |
358 | resume_step = args.resume
359 |
360 | log_dir = os.path.join(test_config["save_dir"], test_config['index'])
361 | os.makedirs(log_dir, exist_ok=True)
362 |
363 | sys.stdout = Logger(os.path.join(log_dir, 'test.log'), sys.stdout)
364 | sys.stderr = Logger(os.path.join(log_dir, 'test_error.log'), sys.stderr)
365 |
366 | shutil.copyfile(__file__, os.path.join(log_dir, os.path.basename(__file__)))
367 | shutil.copyfile(args.opt, os.path.join(log_dir, os.path.basename(args.opt)))
368 |
369 | setup_seed(test_config['seed'])
370 |
371 | for i, dataset in enumerate(test_config['test']['datasets']):
372 | name = dataset['name']
373 | mode = dataset['mode']
374 | print(name, dataset)
375 | test_config['dataset'].update(dataset)
376 | test_config = DictAsMember(test_config)
377 |
378 | if args.exp:
379 | assert test_config.models.use_renderer, "Currently only support using renderer for exposure control"
380 |
381 | main(args, test_config, name, mode, resume_step)
382 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | import os
6 | import numpy as np
7 | from .utils import normalize_vector, create_learning_rate_fn, add_points_knn, activation_func
8 | from .mlp import get_mapping_mlp
9 | from .attn import get_proximity_attention_layer
10 | from .renderer import get_generator
11 |
12 |
13 | def count_parameters(model):
14 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
15 |
16 |
17 | class PAPR(nn.Module):
18 | def __init__(self, args, device='cuda'):
19 | super(PAPR, self).__init__()
20 | self.args = args
21 | self.eps = args.eps
22 | self.device = device
23 |
24 | self.use_amp = args.use_amp
25 | self.amp_dtype = torch.float16 if args.amp_dtype == 'float16' else torch.bfloat16
26 | self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)
27 |
28 | point_opt = args.geoms.points
29 | pc_feat_opt = args.geoms.point_feats
30 | bkg_feat_opt = args.geoms.background
31 | exposure_opt = args.exposure_control
32 | self.exposure_opt = exposure_opt
33 |
34 | self.register_buffer('select_k', torch.tensor(
35 | point_opt.select_k, device=device, dtype=torch.int32))
36 |
37 | self.coord_scale = args.dataset.coord_scale
38 |
39 | if point_opt.load_path:
40 | if point_opt.load_path.endswith('.pth') or point_opt.load_path.endswith('.pt'):
41 | points = torch.load(point_opt.load_path, map_location='cpu')
42 | points = np.asarray(points).astype(np.float32)
43 | np.random.shuffle(points)
44 | points = points[:args.max_num_pts, :]
45 | points = torch.from_numpy(points).float()
46 | print("Loaded points from {}, shape: {}, dtype {}".format(point_opt.load_path, points.shape, points.dtype))
47 | print("Loaded points scale: ", points[:, 0].min(), points[:, 0].max(), points[:, 1].min(), points[:, 1].max(), points[:, 2].min(), points[:, 2].max())
48 | else:
49 | # Initialize point positions
50 | pt_init_center = [i * self.coord_scale for i in point_opt.init_center]
51 | pt_init_scale = [i * self.coord_scale for i in point_opt.init_scale]
52 | if point_opt.init_type == 'sphere': # initial points on a sphere
53 | points = self._sphere_pc(pt_init_center, point_opt.init_num, pt_init_scale)
54 | elif point_opt.init_type == 'cube': # initial points in a cube
55 | points = self._cube_normal_pc(pt_init_center, point_opt.init_num, pt_init_scale)
56 | else:
57 | raise NotImplementedError("Point init type [{:s}] is not found".format(point_opt.init_type))
58 | print("Initialized points scale: ", points[:, 0].min(), points[:, 0].max(), points[:, 1].min(), points[:, 1].max(), points[:, 2].min(), points[:, 2].max())
59 | self.points = torch.nn.Parameter(points, requires_grad=True)
60 |
61 | # Initialize point influence scores
62 | self.points_influ_scores = torch.nn.Parameter(torch.ones(
63 | points.shape[0], 1, device=device) * point_opt.influ_init_val, requires_grad=True)
64 |
65 | # Initialize the mapping MLP (used only when fine-tuning with IMLE for exposure control)
66 | self.mapping_mlp = None
67 | if exposure_opt.use:
68 | self.mapping_mlp = get_mapping_mlp(exposure_opt, use_amp=self.use_amp, amp_dtype=self.amp_dtype)
69 |
70 | # Initialize UNet
71 | if args.models.use_renderer:
72 | attn_opt = args.models.attn
73 | feat_dim = attn_opt.embed.value.d_ff_out
74 | self.renderer = get_generator(args.models.renderer.generator, in_c=feat_dim,
75 | out_c=3, use_amp=self.use_amp, amp_dtype=self.amp_dtype)
76 | print("Number of parameters of renderer: ", count_parameters(self.renderer))
77 | else:
78 | assert args.models.attn.embed.value.d_ff_out == 3, \
79 | "Value embedding MLP should have output dim 3 if not using renderer"
80 |
81 | # Initialize background score and features
82 | self.bkg_feats = nn.Parameter(torch.FloatTensor(bkg_feat_opt.init_color)[None, :], requires_grad=bkg_feat_opt.learnable)
83 | self.bkg_score = torch.tensor(bkg_feat_opt.constant, device=device, dtype=torch.float32).reshape(1)
84 |
85 | # Initialize point features
86 | self.use_pc_feats = pc_feat_opt.use_ink or pc_feat_opt.use_inq or pc_feat_opt.use_inv
87 | if self.use_pc_feats:
88 | self.pc_feats = nn.Parameter(torch.randn(points.shape[0], pc_feat_opt.dim), requires_grad=True)
89 | print("Point features: ", self.pc_feats.shape, self.pc_feats.min(), self.pc_feats.max(), self.pc_feats.mean(), self.pc_feats.std())
90 |
91 | v_extra_dim = 0
92 | k_extra_dim = 0
93 | q_extra_dim = 0
94 | if pc_feat_opt.use_inv:
95 | v_extra_dim = self.pc_feats.shape[-1]
96 | print("Using v_extra_dim: ", v_extra_dim)
97 | if pc_feat_opt.use_ink:
98 | k_extra_dim = self.pc_feats.shape[-1]
99 | print("Using k_extra_dim: ", k_extra_dim)
100 | if pc_feat_opt.use_inq:
101 | q_extra_dim = self.pc_feats.shape[-1]
102 | print("Using q_extra_dim: ", q_extra_dim)
103 |
104 | self.last_act = activation_func(args.models.last_act)
105 |
106 | # Initialize proximity attention layer
107 | self.proximity_attn = get_proximity_attention_layer(args.models.attn,
108 | v_extra_dim=v_extra_dim,
109 | k_extra_dim=k_extra_dim,
110 | q_extra_dim=q_extra_dim,
111 | eps=self.eps,
112 | use_amp=self.use_amp,
113 | amp_dtype=self.amp_dtype)
114 |
115 | self.init_optimizers(total_steps=0)
116 |
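# NOTE (summary comment added for readability; not part of the original code):
# the constructor above assembles three pieces: (i) a learnable point cloud
# `self.points` with per-point influence scores and optional per-point
# features, (ii) a proximity-attention module that, for every ray, attends over
# its select_k nearest points, and (iii) an optional UNet-style renderer that
# maps the fused per-pixel features to RGB. The mapping MLP exists only when
# exposure control is enabled, where it turns a shading code into the
# (gamma, beta) modulation consumed by the renderer in forward().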
117 | def init_optimizers(self, total_steps):
118 | lr_opt = self.args.training.lr
119 | print("LR factor: ", lr_opt.lr_factor)
120 | optimizer_points = torch.optim.Adam([self.points], lr=lr_opt.points.base_lr * lr_opt.lr_factor)
121 | optimizer_attn = torch.optim.Adam(self.proximity_attn.parameters(), lr=lr_opt.attn.base_lr * lr_opt.lr_factor, weight_decay=lr_opt.attn.weight_decay)
122 | optimizer_points_influ_scores = torch.optim.Adam([self.points_influ_scores], lr=lr_opt.points_influ_scores.base_lr * lr_opt.lr_factor, weight_decay=lr_opt.points_influ_scores.weight_decay)
123 |
124 | debug = False
125 | lr_scheduler_points = create_learning_rate_fn(optimizer_points, self.args.training.steps, lr_opt.points, debug=debug)
126 | lr_scheduler_attn = create_learning_rate_fn(optimizer_attn, self.args.training.steps, lr_opt.attn, debug=debug)
127 | lr_scheduler_points_influ_scores = create_learning_rate_fn(optimizer_points_influ_scores, self.args.training.steps, lr_opt.points_influ_scores, debug=debug)
128 |
129 | self.optimizers = {
130 | "points": optimizer_points,
131 | "attn": optimizer_attn,
132 | "points_influ_scores": optimizer_points_influ_scores,
133 | }
134 |
135 | self.schedulers = {
136 | "points": lr_scheduler_points,
137 | "attn": lr_scheduler_attn,
138 | "points_influ_scores": lr_scheduler_points_influ_scores,
139 | }
140 |
141 | if self.use_pc_feats:
142 | optimizer_pc_feats = torch.optim.Adam([self.pc_feats], lr=lr_opt.feats.base_lr * lr_opt.lr_factor, weight_decay=lr_opt.feats.weight_decay)
143 | lr_scheduler_pc_feats = create_learning_rate_fn(optimizer_pc_feats, self.args.training.steps, lr_opt.feats, debug=debug)
144 |
145 | self.optimizers["pc_feats"] = optimizer_pc_feats
146 | self.schedulers["pc_feats"] = lr_scheduler_pc_feats
147 |
148 | if self.mapping_mlp is not None:
149 | optimizer_mapping_mlp = torch.optim.Adam(self.mapping_mlp.parameters(), lr=lr_opt.mapping_mlp.base_lr * lr_opt.lr_factor, weight_decay=lr_opt.mapping_mlp.weight_decay)
150 | lr_scheduler_mapping_mlp = create_learning_rate_fn(optimizer_mapping_mlp, self.args.training.steps, lr_opt.mapping_mlp, debug=debug)
151 |
152 | self.optimizers["mapping_mlp"] = optimizer_mapping_mlp
153 | self.schedulers["mapping_mlp"] = lr_scheduler_mapping_mlp
154 |
155 | if self.args.models.use_renderer:
156 | optimizer_renderer = torch.optim.Adam(self.renderer.parameters(), lr=lr_opt.generator.base_lr * lr_opt.lr_factor, weight_decay=lr_opt.generator.weight_decay)
157 | lr_scheduler_renderer = create_learning_rate_fn(optimizer_renderer, self.args.training.steps, lr_opt.generator, debug=debug)
158 |
159 | self.optimizers["renderer"] = optimizer_renderer
160 | self.schedulers["renderer"] = lr_scheduler_renderer
161 |
162 | if self.bkg_feats is not None and self.args.geoms.background.learnable:
163 | optimizer_bkg_feats = torch.optim.Adam([self.bkg_feats], lr=lr_opt.bkg_feats.base_lr * lr_opt.lr_factor, weight_decay=lr_opt.bkg_feats.weight_decay)
164 | lr_scheduler_bkg_feats = create_learning_rate_fn(optimizer_bkg_feats, self.args.training.steps, lr_opt.bkg_feats, debug=debug)
165 |
166 | self.optimizers["bkg_feats"] = optimizer_bkg_feats
167 | self.schedulers["bkg_feats"] = lr_scheduler_bkg_feats
168 |
169 | for name in self.args.training.fix_keys:
170 | if name in self.optimizers:
171 | print("Fixing {}".format(name))
172 | self.optimizers.pop(name)
173 | self.schedulers.pop(name)
174 |
175 | if total_steps > 0:
176 | for _, scheduler in self.schedulers.items():
177 | if scheduler is not None:
178 | for _ in range(total_steps):
179 | scheduler.step()
180 |
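# NOTE (illustrative sketch, not part of the original code): passing
# total_steps > 0 fast-forwards the freshly built schedulers by replaying
# scheduler.step() that many times, so learning rates match the resume step
# even when the optimizers themselves are rebuilt (presumably the case when
# points are pruned or added and init_optimizers() is called again). The same
# replay trick with a generic PyTorch scheduler, for illustration only:
#     >>> opt = torch.optim.Adam([nn.Parameter(torch.zeros(1))], lr=1e-3)
#     >>> sch = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.99)
#     >>> for _ in range(100): sch.step()
#     >>> sch.get_last_lr()    # learning rate after replaying 100 steps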
181 | def clear_optimizer(self):
182 | self.optimizers.clear()
183 | del self.optimizers
184 |
185 | def clear_scheduler(self):
186 | self.schedulers.clear()
187 | del self.schedulers
188 |
189 | def clear_grad(self):
190 | for _, optimizer in self.optimizers.items():
191 | if optimizer is not None:
192 | optimizer.zero_grad()
193 |
194 | def _sphere_pc(self, center, num_pts, scale):
195 | xs, ys, zs = [], [], []
196 | phi = math.pi * (3. - math.sqrt(5.))
197 | for i in range(num_pts):
198 | y = 1 - (i / float(num_pts - 1)) * 2
199 | radius = math.sqrt(1 - y * y)
200 | theta = phi * i
201 | x = math.cos(theta) * radius
202 | z = math.sin(theta) * radius
203 | xs.append(x * scale[0] + center[0])
204 | ys.append(y * scale[1] + center[1])
205 | zs.append(z * scale[2] + center[2])
206 | points = np.stack([np.array(xs), np.array(ys), np.array(zs)], axis=-1)
207 | return torch.from_numpy(points).float()
208 |
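# NOTE (worked example for _sphere_pc above; not part of the original code):
# point i is placed at y_i = 1 - 2i/(n-1), r_i = sqrt(1 - y_i^2) and longitude
# i * phi with phi = pi * (3 - sqrt(5)), the golden angle, which spreads the n
# points nearly uniformly over the unit sphere before scaling and recentering.
# With unit scale and zero center every point lies on the unit sphere:
#     >>> pts = model._sphere_pc([0., 0., 0.], 1000, [1., 1., 1.])
#     >>> pts.shape                        # torch.Size([1000, 3])
#     >>> torch.linalg.norm(pts, dim=-1)   # all approximately 1.0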
209 | def _semi_sphere_pc(self, center, num_pts, scale, flatten="-z", flatten_coord=0.0):
210 | xs, ys, zs = [], [], []
211 | phi = math.pi * (3. - math.sqrt(5.))
212 | for i in range(num_pts):
213 | y = 1 - (i / float(num_pts - 1)) * 2
214 | radius = math.sqrt(1 - y * y)
215 | theta = phi * i
216 | x = math.cos(theta) * radius
217 | z = math.sin(theta) * radius
218 | xs.append(x * scale[0] + center[0])
219 | ys.append(y * scale[1] + center[1])
220 | zs.append(z * scale[2] + center[2])
221 | points = np.stack([np.array(xs), np.array(ys), np.array(zs)], axis=-1)
222 | points = torch.from_numpy(points).float()
223 | if flatten == "-z":
224 | points[:, 2] = torch.clamp(points[:, 2], min=flatten_coord)
225 | elif flatten == "+z":
226 | points[:, 2] = torch.clamp(points[:, 2], max=flatten_coord)
227 | elif flatten == "-y":
228 | points[:, 1] = torch.clamp(points[:, 1], min=flatten_coord)
229 | elif flatten == "+y":
230 | points[:, 1] = torch.clamp(points[:, 1], max=flatten_coord)
231 | elif flatten == "-x":
232 | points[:, 0] = torch.clamp(points[:, 0], min=flatten_coord)
233 | elif flatten == "+x":
234 | points[:, 0] = torch.clamp(points[:, 0], max=flatten_coord)
235 | else:
236 | raise ValueError("Invalid flatten type")
237 | return points
238 |
239 | def _cube_pc(self, center, num_pts, scale):
240 | xs = np.random.uniform(-scale[0], scale[0], num_pts) + center[0]
241 | ys = np.random.uniform(-scale[1], scale[1], num_pts) + center[1]
242 | zs = np.random.uniform(-scale[2], scale[2], num_pts) + center[2]
243 | points = np.stack([np.array(xs), np.array(ys), np.array(zs)], axis=-1)
244 | return torch.from_numpy(points).float()
245 |
246 | def _cube_normal_pc(self, center, num_pts, scale):
247 | axis_num_pts = int(num_pts ** (1.0 / 3.0))
248 | xs = np.linspace(-scale[0], scale[0], axis_num_pts) + center[0]
249 | ys = np.linspace(-scale[1], scale[1], axis_num_pts) + center[1]
250 | zs = np.linspace(-scale[2], scale[2], axis_num_pts) + center[2]
251 | points = np.array([[i, j, k] for i in xs for j in ys for k in zs])
252 | rest_num_pts = num_pts - points.shape[0]
253 | if rest_num_pts > 0:
254 | rest_points = self._cube_pc(center, rest_num_pts, scale)
255 | points = np.concatenate([points, rest_points], axis=0)
256 | return torch.from_numpy(points).float()
257 |
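# NOTE (worked example for _cube_normal_pc above; not part of the original code):
# the grid uses int(num_pts ** (1/3)) points per axis and tops the total up
# with uniform random samples from _cube_pc. For example, num_pts = 10000
# gives 21 points per axis (21^3 = 9261 grid points) plus 739 random points
# to reach 10000.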
258 | def _calculate_global_distances(self, rays_o, rays_d, points):
259 | """
260 | For each ray, select from all points the top-k with the smallest perpendicular distance to that ray
261 |
262 | Args:
263 | rays_o: (N, 3)
264 | rays_d: (N, H, W, 3)
265 | points: (num_pts, 3)
266 | Returns:
267 | select_k_ind: (N, H, W, select_k)
268 | """
269 | N, H, W, _ = rays_d.shape
270 | num_pts, _ = points.shape
271 |
272 | rays_d = rays_d.unsqueeze(-2) # (N, H, W, 1, 3)
273 | rays_o = rays_o.reshape(N, 1, 1, 1, 3)
274 | points = points.reshape(1, 1, 1, num_pts, 3)
275 |
276 | v = points - rays_o # (N, 1, 1, num_pts, 3)
277 | proj = rays_d * (torch.sum(v * rays_d, dim=-1) / (torch.sum(rays_d * rays_d, dim=-1) + self.eps)).unsqueeze(-1)
278 | D = v - proj # (N, H, W, num_pts, 3)
279 | feature = torch.norm(D, dim=-1)
280 |
281 | _, select_k_ind = feature.topk(self.select_k, dim=-1, largest=False, sorted=False) # (N, H, W, select_k)
282 |
283 | return select_k_ind
284 |
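# NOTE (the math behind the selection above; not part of the original code):
# for a ray with origin o and direction d and a point p, the code forms
#     v = p - o,   s = d * <v, d> / (<d, d> + eps),   t = v - s,
# so ||t|| is the perpendicular distance from p to the ray, and
# topk(..., largest=False) keeps the select_k closest points per pixel.
# A shape-only sanity check on a constructed model:
#     >>> rays_o = torch.zeros(1, 3)
#     >>> rays_d = torch.randn(1, 8, 8, 3)
#     >>> idx = model._calculate_global_distances(rays_o, rays_d, model.points)
#     >>> idx.shape                        # (1, 8, 8, select_k)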
285 | def _calculate_distances(self, rays_o, rays_d, points, c2w):
286 | """
287 | Calculate the distances from the selected top-k points to the rays (TODO: redundant with _calculate_global_distances)
288 |
289 | Args:
290 | rays_o: (N, 3)
291 | rays_d: (N, H, W, 3)
292 | points: (N, H, W, select_k, 3)
293 | c2w: (N, 4, 4)
294 | Returns:
295 | proj_dists: (N, H, W, select_k, 1)
296 | dists_to_rays: (N, H, W, select_k, 1)
297 | proj: (N, H, W, select_k, 3) # the vector s in Figure 2
298 | D: (N, H, W, select_k, 3) # the vector t in Figure 2
299 | """
300 | N, H, W, _ = rays_d.shape
301 |
302 | rays = normalize_vector(rays_d, eps=self.eps).unsqueeze(-2) # (N, H, W, 1, 3)
303 | v = points - rays_o.reshape(N, 1, 1, 1, 3) # (N, H, W, select_k, 3)
304 | proj = rays * (torch.sum(v * rays, dim=-1) / (torch.sum(rays * rays, dim=-1) + self.eps)).unsqueeze(-1)
305 | D = v - proj # (N, H, W, select_k, 3)
306 |
307 | dists_to_rays = torch.norm(D, dim=-1, keepdim=True)
308 | proj_dists = torch.norm(proj, dim=-1, keepdim=True)
309 |
310 | return proj_dists, dists_to_rays, proj, D
311 |
312 | def _get_points(self, rays_o, rays_d, c2w, step=-1):
313 | """
314 | Select the top-k points with the smallest distance to the rays
315 |
316 | Args:
317 | rays_o: (N, 3)
318 | rays_d: (N, H, W, 3)
319 | c2w: (N, 4, 4)
320 | Returns:
321 | selected_points: (N, H, W, select_k, 3)
322 | select_k_ind: (N, H, W, select_k)
323 | """
324 | points = self.points
325 | N, H, W, _ = rays_d.shape
326 | if self.select_k >= points.shape[0] or self.select_k < 0:
327 | select_k_ind = torch.arange(points.shape[0], device=points.device).expand(N, H, W, -1)
328 | else:
329 | select_k_ind = self._calculate_global_distances(rays_o, rays_d, points) # (N, H, W, select_k)
330 | selected_points = points[select_k_ind, :] # (N, H, W, select_k, 3)
331 | self.selected_points = selected_points
332 |
333 | return selected_points, select_k_ind
334 |
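# NOTE (explanatory comment; not part of the original code): in _get_points
# above, a negative select_k (or one at least as large as the point count)
# means every point is used for every ray; otherwise the per-pixel indices
# from _calculate_global_distances gather the selected coordinates through
# advanced indexing. The indexing pattern on toy tensors:
#     >>> pts = torch.arange(12.).reshape(4, 3)
#     >>> idx = torch.tensor([[0, 2], [1, 3]])   # (2, 2) indices
#     >>> pts[idx, :].shape                      # torch.Size([2, 2, 3])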
335 | def prune_points(self, thresh):
336 | if self.points_influ_scores is not None:
337 | if self.args.training.prune_type == '<':
338 | mask = (self.points_influ_scores[:, 0] > thresh)
339 | elif self.args.training.prune_type == '>':
340 | mask = (self.points_influ_scores[:, 0] < thresh)
341 | print(
342 | "@@@@@@@@@ pruned {}/{}".format(torch.sum(mask == 0), mask.shape[0]))
343 |
344 | cur_requires_grad = self.points.requires_grad
345 | self.points = nn.Parameter(self.points[mask, :], requires_grad=cur_requires_grad)
346 | print("@@@@@@@@@ New points: ", self.points.shape)
347 |
348 | cur_requires_grad = self.points_influ_scores.requires_grad
349 | self.points_influ_scores = nn.Parameter(self.points_influ_scores[mask, :], requires_grad=cur_requires_grad)
350 | print("@@@@@@@@@ New points_influ_scores: ", self.points_influ_scores.shape)
351 |
352 | if self.use_pc_feats:
353 | cur_requires_grad = self.pc_feats.requires_grad
354 | self.pc_feats = nn.Parameter(self.pc_feats[mask, :], requires_grad=cur_requires_grad)
355 | print("@@@@@@@@@ New pc_feats: ", self.pc_feats.shape)
356 |
357 | return torch.sum(mask == 0)
358 | return 0
359 |
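# NOTE (explanatory comment; not part of the original code): prune_points above
# keeps the rows whose influence score passes the threshold (direction set by
# training.prune_type) and drops the same rows from points, points_influ_scores
# and pc_feats so they stay aligned, returning the number of points removed.
# Because the nn.Parameter objects are replaced, the optimizers presumably have
# to be rebuilt afterwards by the training loop (e.g. via init_optimizers).
# The masking itself is plain boolean indexing:
#     >>> scores = torch.tensor([[0.2], [1.5], [0.7]])
#     >>> mask = scores[:, 0] > 1.0
#     >>> scores[mask, :].shape                  # torch.Size([1, 1])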
360 | def add_points(self, add_num):
361 | points = self.points.detach().cpu()
362 | point_features = None
363 | cur_num_points = points.shape[0]
364 |
365 | if 'max_points' in self.args and self.args.max_points > 0 and (cur_num_points + add_num) >= self.args.max_points:
366 | add_num = self.args.max_points - cur_num_points
367 | if add_num <= 0:
368 | return 0
369 |
370 | if self.use_pc_feats:
371 | point_features = self.pc_feats.detach().cpu()
372 |
373 | new_points, num_new_points, new_influ_scores, new_point_features = add_points_knn(points, self.points_influ_scores.detach().cpu(), add_num=add_num,
374 | k=self.args.geoms.points.add_k, comb_type=self.args.geoms.points.add_type,
375 | sample_k=self.args.geoms.points.add_sample_k, sample_type=self.args.geoms.points.add_sample_type,
376 | point_features=point_features)
377 | print("@@@@@@@@@ added {} points".format(num_new_points))
378 |
379 | if num_new_points > 0:
380 | cur_requires_grad = self.points.requires_grad
381 | self.points = nn.Parameter(torch.cat([points, new_points], dim=0).to(self.points.device), requires_grad=cur_requires_grad)
382 | print("@@@@@@@@@ New points: ", self.points.shape)
383 |
384 | if self.points_influ_scores is not None:
385 | cur_requires_grad = self.points_influ_scores.requires_grad
386 | self.points_influ_scores = nn.Parameter(torch.cat([self.points_influ_scores, new_influ_scores.to(self.points_influ_scores.device)], dim=0), requires_grad=cur_requires_grad)
387 | print("@@@@@@@@@ New points_influ_scores: ", self.points_influ_scores.shape)
388 |
389 | if self.use_pc_feats:
390 | cur_requires_grad = self.pc_feats.requires_grad
391 | self.pc_feats = nn.Parameter(torch.cat([self.pc_feats, new_point_features.to(self.pc_feats.device)], dim=0), requires_grad=cur_requires_grad)
392 | print("@@@@@@@@@ New pc_feats: ", self.pc_feats.shape)
393 |
394 | return num_new_points
395 |
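# NOTE (explanatory comment; not part of the original code): point growing is
# delegated to add_points_knn (models/utils.py); add_points above only clamps
# the request against the optional max_points budget and appends the new rows
# to points, influence scores and features. With, say, 29500 current points,
# add_num = 1000 and max_points = 30000 (illustrative numbers), only 500 points
# are added, and once the cap is reached the call returns 0 without changes.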
396 | def _get_kqv(self, rays_o, rays_d, points, c2w, select_k_ind, step=-1):
397 | """
398 | Get the key, query, value for the proximity attention layer(s)
399 | """
400 | _, _, vec_p2o, vec_p2r = self._calculate_distances(rays_o, rays_d, points, c2w)
401 |
402 | k_type = self.args.models.attn.k_type
403 | k_L = self.args.models.attn.embed.k_L
404 | if k_type == 1:
405 | key = [points.detach(), vec_p2o, vec_p2r]
406 | else:
407 | raise ValueError('Invalid key type')
408 | assert len(key) == len(k_L)
409 |
410 | q_type = self.args.models.attn.q_type
411 | q_L = self.args.models.attn.embed.q_L
412 | if q_type == 1:
413 | query = [rays_d.unsqueeze(-2)]
414 | else:
415 | raise ValueError('Invalid query type')
416 | assert len(query) == len(q_L)
417 |
418 | v_type = self.args.models.attn.v_type
419 | v_L = self.args.models.attn.embed.v_L
420 | if v_type == 1:
421 | value = [vec_p2o, vec_p2r]
422 | else:
423 | raise ValueError('Invalid value type')
424 | assert len(value) == len(v_L)
425 |
426 | # Add extra features that won't be passed through positional encoding
427 | k_extra = None
428 | q_extra = None
429 | v_extra = None
430 | if self.args.geoms.point_feats.use_ink:
431 | k_extra = [self.pc_feats[select_k_ind, :]]
432 | if self.args.geoms.point_feats.use_inq:
433 | q_extra = [self.pc_feats[select_k_ind, :]]
434 | if self.args.geoms.point_feats.use_inv:
435 | v_extra = [self.pc_feats[select_k_ind, :]]
436 |
437 | return key, query, value, k_extra, q_extra, v_extra
438 |
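# NOTE (explanatory comment; not part of the original code): with the type-1
# settings above, the attention inputs are
#     key   = [detached point positions, along-ray vector s, perpendicular vector t]
#     query = [ray direction]
#     value = [s, t]
# where s and t are the two vectors returned by _calculate_distances. Each list
# entry gets its own positional-encoding length from k_L/q_L/v_L (hence the
# len() asserts), while the optional per-point features are passed separately
# as *_extra inputs that bypass positional encoding.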
439 | def step(self, step=-1):
440 | for _, optimizer in self.optimizers.items():
441 | if optimizer is not None:
442 | self.scaler.step(optimizer)
443 |
444 | for _, scheduler in self.schedulers.items():
445 | if scheduler is not None:
446 | scheduler.step()
447 |
448 | self.attn_lr = 0
449 | if 'attn' in self.optimizers:
450 | if self.schedulers['attn'] is not None:
451 | self.attn_lr = self.schedulers['attn'].get_last_lr()[0]
452 | else:
453 | self.attn_lr = self.optimizers['attn'].param_groups[0]['lr']
454 |
455 | self.pts_lr = 0
456 | if 'points' in self.optimizers:
457 | if self.schedulers['points'] is not None:
458 | self.pts_lr = self.schedulers['points'].get_last_lr()[0]
459 | else:
460 | self.pts_lr = self.optimizers['points'].param_groups[0]['lr']
461 |
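# NOTE (illustrative sketch; not part of the original code): step() above only
# advances the optimizers (through the shared GradScaler) and the schedulers;
# the scaled backward pass and scaler.update() are presumably handled by the
# training loop in train.py. The standard torch.cuda.amp recipe this fits into:
#     with torch.autocast("cuda", dtype=model.amp_dtype, enabled=model.use_amp):
#         loss = compute_loss(...)         # hypothetical loss computation
#     model.scaler.scale(loss).backward()
#     model.step(cur_step)                 # calls scaler.step() on every optimizer
#     model.scaler.update()
#     model.clear_grad()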
462 | def evaluate(self, rays_o, rays_d, c2w, step=-1, shading_code=None):
463 | points, select_k_ind = self._get_points(rays_o, rays_d, c2w, step)
464 | self.select_k_ind = select_k_ind
465 | key, query, value, k_extra, q_extra, v_extra = self._get_kqv(rays_o, rays_d, points, c2w, select_k_ind, step)
466 | N, H, W, _ = rays_d.shape
467 | num_pts = points.shape[-2]
468 |
469 | cur_points_influ_score = self.points_influ_scores[select_k_ind] if self.points_influ_scores is not None else None
470 |
471 | _, _, embedv, scores = self.proximity_attn(key, query, value, k_extra, q_extra, v_extra, step=step)
472 |
473 | embedv = embedv.reshape(N, H, W, -1, embedv.shape[-1])
474 | scores = scores.reshape(N, H, W, -1, 1)
475 |
476 | if cur_points_influ_score is not None:
477 | scores = scores * cur_points_influ_score
478 | if self.bkg_feats is not None:
479 | bkg_seq_len = self.bkg_feats.shape[0]
480 | scores = torch.cat([scores, self.bkg_score.expand(N, H, W, bkg_seq_len, -1)], dim=-2)
481 | attn = F.softmax(scores, dim=3) # (N, H, W, num_pts+bkg_seq_len, 1)
482 | topk_attn = attn[..., :num_pts, :]
483 | if self.args.models.normalize_topk_attn:
484 | topk_attn = topk_attn / torch.sum(topk_attn, dim=3, keepdim=True)
485 | fused_features = torch.sum(embedv * topk_attn, dim=3, keepdim=True) # (N, H, W, 1, C)
486 | else:
487 | attn = F.softmax(scores, dim=3)
488 | if self.args.models.normalize_topk_attn:
489 | attn = attn / torch.sum(attn, dim=3, keepdim=True)
490 | fused_features = torch.sum(embedv * attn, dim=3, keepdim=True) # (N, H, W, 1, C)
491 |
492 | return fused_features, attn
493 |
494 | def forward(self, rays_o, rays_d, c2w, step=-1, shading_code=None):
495 | gamma, beta = None, None
496 | if shading_code is not None and self.mapping_mlp is not None:
497 | affine = self.mapping_mlp(shading_code)
498 | affine_dim = affine.shape[-1]
499 | gamma, beta = affine[:affine_dim//2], affine[affine_dim//2:]
500 |
501 | if step % 200 == 0:
502 | print(shading_code.min().item(), shading_code.max().item(), gamma.min().item(), gamma.max().item(), beta.min().item(), beta.max().item())
503 |
504 | points, select_k_ind = self._get_points(rays_o, rays_d, c2w, step)
505 | key, query, value, k_extra, q_extra, v_extra = self._get_kqv(rays_o, rays_d, points, c2w, select_k_ind, step)
506 | N, H, W, _ = rays_d.shape
507 | num_pts = points.shape[-2]
508 |
509 | cur_points_influ_scores = self.points_influ_scores[select_k_ind] if self.points_influ_scores is not None else None
510 |
511 | _, _, embedv, scores = self.proximity_attn(key, query, value, k_extra, q_extra, v_extra, step=step)
512 |
513 | if step >= 0 and step % 200 == 0:
514 | print(' embedv:', step, embedv.shape, embedv.min().item(),
515 | embedv.max().item(), embedv.mean().item(), embedv.std().item())
516 | print(' scores:', step, scores.shape, scores.min().item(),
517 | scores.max().item(), scores.mean().item(), scores.std().item())
518 |
519 | embedv = embedv.reshape(N, H, W, -1, embedv.shape[-1])
520 | scores = scores.reshape(N, H, W, -1, 1)
521 |
522 | if cur_points_influ_scores is not None:
523 | # Multiply the influence scores to the attention scores
524 | scores = scores * cur_points_influ_scores
525 |
526 | if self.bkg_feats is not None:
527 | bkg_seq_len = self.bkg_feats.shape[0]
528 | scores = torch.cat([scores, self.bkg_score.expand(N, H, W, bkg_seq_len, -1)], dim=-2)
529 | attn = F.softmax(scores, dim=3) # (N, H, W, num_pts+bkg_seq_len, 1)
530 | topk_attn = attn[..., :num_pts, :]
531 | bkg_attn = attn[..., num_pts:, :]
532 | if self.args.models.normalize_topk_attn:
533 | topk_attn = topk_attn / torch.sum(topk_attn, dim=3, keepdim=True)
534 | fused_features = torch.sum(embedv * topk_attn, dim=3) # (N, H, W, C)
535 |
536 | if self.args.models.use_renderer:
537 | foreground = self.renderer(fused_features.permute(0, 3, 1, 2), gamma=gamma, beta=beta).permute(0, 2, 3, 1).unsqueeze(-2) # (N, H, W, 1, 3)
538 | else:
539 | foreground = fused_features.unsqueeze(-2)
540 |
541 | if self.args.models.normalize_topk_attn:
542 | rgb = foreground * (1 - bkg_attn) + self.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn
543 | else:
544 | rgb = foreground + self.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn
545 | rgb = rgb.squeeze(-2)
546 | else:
547 | attn = F.softmax(scores, dim=3)
548 | fused_features = torch.sum(embedv * attn, dim=3) # (N, H, W, C)
549 | if self.args.models.use_renderer:
550 | rgb = self.renderer(fused_features.permute(0, 3, 1, 2), gamma=gamma, beta=beta).permute(0, 2, 3, 1) # (N, H, W, 3)
551 | else:
552 | rgb = fused_features
553 |
554 | if step >= 0 and step % 1000 == 0:
555 | print(' feat map:', step, fused_features.shape, fused_features.min().item(),
556 | fused_features.max().item(), fused_features.mean().item(), fused_features.std().item())
557 | print(' predict rgb:', step, rgb.shape, rgb.min().item(),
558 | rgb.max().item(), rgb.mean().item(), rgb.std().item())
559 |
560 | return rgb
561 |
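# NOTE (the compositing in forward() written out; not part of the original code):
# the selected points' attention scores, scaled by their influence scores, are
# concatenated with the constant background score and softmaxed per pixel. With
# normalize_topk_attn the foreground features use re-normalized point weights
# and the output is composited as
#     rgb = foreground * (1 - a_bkg) + bkg_feats * a_bkg,
# otherwise rgb = foreground + bkg_feats * a_bkg, where a_bkg is the
# background's share of the softmax. The optional renderer maps the fused
# feature map to RGB (modulated by gamma/beta when a shading code is given);
# without it, the fused features are returned directly as colors.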
562 | def save(self, step, save_dir):
563 | torch.save({str(step): self.state_dict()},
564 | os.path.join(save_dir, 'model.pth'))
565 |
566 | optimizers_state_dict = {}
567 | for name, optimizer in self.optimizers.items():
568 | if optimizer is not None:
569 | optimizers_state_dict[name] = optimizer.state_dict()
570 | else:
571 | optimizers_state_dict[name] = None
572 | torch.save(optimizers_state_dict, os.path.join(
573 | save_dir, 'optimizers.pth'))
574 |
575 | schedulers_state_dict = {}
576 | for name, scheduler in self.schedulers.items():
577 | if scheduler is not None:
578 | schedulers_state_dict[name] = scheduler.state_dict()
579 | else:
580 | schedulers_state_dict[name] = None
581 | torch.save(schedulers_state_dict, os.path.join(
582 | save_dir, 'schedulers.pth'))
583 |
584 | scaler_state_dict = self.scaler.state_dict()
585 | torch.save(scaler_state_dict, os.path.join(
586 | save_dir, 'scaler.pth'))
587 |
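# NOTE (note on the checkpoint layout; not part of the original code):
# model.pth holds a single-entry dict {str(step): state_dict}, which is why
# load() below and test.py both iterate over .items() to recover the step.
# Optimizer and scheduler states go to separate files keyed by component name,
# and the GradScaler state to scaler.pth. Reading a checkpoint by hand (the
# path is illustrative):
#     >>> ckpt = torch.load("path/to/model.pth", map_location="cpu")
#     >>> step, state = next(iter(ckpt.items()))
#     >>> int(step), list(state.keys())[:3]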
588 | def load(self, load_dir, load_optimizer=False):
589 | if load_optimizer:
590 | optimizers_state_dict = torch.load(
591 | os.path.join(load_dir, 'optimizers.pth'))
592 | for name, optimizer in self.optimizers.items():
593 | if optimizer is not None:
594 | optimizer.load_state_dict(optimizers_state_dict[name])
595 | else:
596 | assert optimizers_state_dict[name] is None
597 |
598 | schedulers_state_dict = torch.load(
599 | os.path.join(load_dir, 'schedulers.pth'))
600 | for name, scheduler in self.schedulers.items():
601 | if scheduler is not None:
602 | scheduler.load_state_dict(schedulers_state_dict[name])
603 | else:
604 | assert schedulers_state_dict[name] is None
605 |
606 | if os.path.exists(os.path.join(load_dir, 'scaler.pth')):
607 | scaler_state_dict = torch.load(
608 | os.path.join(load_dir, 'scaler.pth'))
609 | self.scaler.load_state_dict(scaler_state_dict)
610 |
611 | model_state_dict = torch.load(os.path.join(load_dir, 'model.pth'))
612 | for step, state_dict in model_state_dict.items():
613 | # self.load_state_dict(state_dict)
614 | self.load_my_state_dict(state_dict)
615 | return int(step)
616 |
617 | def load_my_state_dict(self, state_dict, exclude_keys=()):
618 | own_state = self.state_dict()
619 | for name, param in state_dict.items():
620 | print(name, param.shape)
621 | for exclude_key in exclude_keys:
622 | if exclude_key in name:
623 | print("exclude", name)
624 | break
625 | else:
626 | if name not in ['points', 'points_influ_scores', 'pc_feats']:
627 | if isinstance(param, nn.Parameter):
628 | # backwards compatibility for serialized parameters
629 | param = param.data
630 | try:
631 | own_state[name].copy_(param)
632 | except Exception:
633 | print("Can't load", name)
634 |
635 | self.points = nn.Parameter(
636 | state_dict['points'].data, requires_grad=self.points.requires_grad)
637 | if self.points_influ_scores is not None:
638 | self.points_influ_scores = nn.Parameter(
639 | state_dict['points_influ_scores'].data, requires_grad=self.points_influ_scores.requires_grad)
640 | self.pc_feats = nn.Parameter(state_dict['pc_feats'].data, requires_grad=self.pc_feats.requires_grad)
641 | print("load pc_feats", self.pc_feats.shape, self.pc_feats.min(), self.pc_feats.max())
642 |
--------------------------------------------------------------------------------