├── vgg.pth ├── images ├── SFU_AI.png ├── APEX_lab.png └── papr_video_cover.png ├── .gitignore ├── configs ├── nerfsyn │ ├── mic.yml │ ├── ficus.yml │ ├── chair.yml │ ├── materials.yml │ ├── hotdog.yml │ ├── ship.yml │ ├── drums.yml │ └── lego.yml ├── t2 │ ├── Truck.yml │ ├── Barn.yml │ ├── Family.yml │ ├── Caterpillar.yml │ ├── Ignatius.yml │ └── Caterpillar_exposure_control.yml └── default.yml ├── papr.yml ├── requirements.txt ├── dataset ├── __init__.py ├── load_nerfsyn.py ├── load_t2.py ├── dataset.py └── utils.py ├── LICENSE ├── models ├── __init__.py ├── renderer.py ├── mlp.py ├── lpips.py ├── unet.py ├── attn.py ├── utils.py └── model.py ├── README.md ├── demo.ipynb ├── exposure_control_finetune.py ├── train.py ├── utils.py └── test.py /vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zvict/papr/HEAD/vgg.pth -------------------------------------------------------------------------------- /images/SFU_AI.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zvict/papr/HEAD/images/SFU_AI.png -------------------------------------------------------------------------------- /images/APEX_lab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zvict/papr/HEAD/images/APEX_lab.png -------------------------------------------------------------------------------- /images/papr_video_cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zvict/papr/HEAD/images/papr_video_cover.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | data 3 | .vscode/ 4 | data/ 5 | experiments/ 6 | checkpoints/ 7 | *.pth 8 | *.pt 9 | *.out 10 | *.mp4 11 | *.tar.gz 12 | !vgg.pth -------------------------------------------------------------------------------- /configs/nerfsyn/mic.yml: -------------------------------------------------------------------------------- 1 | index: "mic" 2 | dataset: 3 | path: "./data/nerf_synthetic/mic/" 4 | training: 5 | add_start: 10000 6 | add_stop: 80000 7 | add_num: 500 8 | eval: 9 | dataset: 10 | path: "./data/nerf_synthetic/mic/" 11 | test: 12 | # load_path: "checkpoints/mic.pth" 13 | datasets: 14 | - name: "testset" 15 | path: "./data/nerf_synthetic/mic/" -------------------------------------------------------------------------------- /configs/nerfsyn/ficus.yml: -------------------------------------------------------------------------------- 1 | index: "ficus" 2 | dataset: 3 | path: "./data/nerf_synthetic/ficus/" 4 | training: 5 | add_start: 10000 6 | add_stop: 90000 7 | add_num: 500 8 | eval: 9 | dataset: 10 | path: "./data/nerf_synthetic/ficus/" 11 | test: 12 | # load_path: "checkpoints/ficus.pth" 13 | datasets: 14 | - name: "testset" 15 | path: "./data/nerf_synthetic/ficus/" -------------------------------------------------------------------------------- /configs/nerfsyn/chair.yml: -------------------------------------------------------------------------------- 1 | index: "chair" 2 | dataset: 3 | path: "./data/nerf_synthetic/chair/" 4 | geoms: 5 | points: 6 | init_num: 10000 7 | training: 8 | add_start: 10000 9 | add_stop: 50000 10 | eval: 11 | dataset: 12 | path: "./data/nerf_synthetic/chair/" 13 | test: 14 | # load_path: "checkpoints/chair.pth" 15 | datasets: 
16 | - name: "testset" 17 | path: "./data/nerf_synthetic/chair/" -------------------------------------------------------------------------------- /configs/nerfsyn/materials.yml: -------------------------------------------------------------------------------- 1 | index: "materials" 2 | dataset: 3 | path: "./data/nerf_synthetic/materials/" 4 | geoms: 5 | point_feats: 6 | dim: 128 7 | training: 8 | add_start: 10000 9 | add_stop: 80000 10 | add_num: 500 11 | eval: 12 | dataset: 13 | path: "./data/nerf_synthetic/materials/" 14 | test: 15 | # load_path: "checkpoints/materials.pth" 16 | datasets: 17 | - name: "testset" 18 | path: "./data/nerf_synthetic/materials/" -------------------------------------------------------------------------------- /configs/nerfsyn/hotdog.yml: -------------------------------------------------------------------------------- 1 | index: "hotdog" 2 | dataset: 3 | path: "./data/nerf_synthetic/hotdog/" 4 | geoms: 5 | points: 6 | select_k: 30 7 | init_num: 10000 8 | training: 9 | add_start: 10000 10 | add_stop: 80000 11 | add_num: 500 12 | eval: 13 | dataset: 14 | path: "./data/nerf_synthetic/hotdog/" 15 | test: 16 | # load_path: "checkpoints/hotdog.pth" 17 | datasets: 18 | - name: "testset" 19 | path: "./data/nerf_synthetic/hotdog/" -------------------------------------------------------------------------------- /papr.yml: -------------------------------------------------------------------------------- 1 | name: papr 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.9 9 | - pip=23.3.1 10 | - pytorch-cuda=11.8 11 | - pytorch==2.1.2 12 | - torchaudio==2.1.2 13 | - torchvision==0.16.2 14 | - matplotlib 15 | - imageio 16 | - scipy 17 | - pip: 18 | - fsspec==2024.3.1 19 | - lpips==0.1.4 20 | - tqdm==4.66.2 21 | - imageio-ffmpeg==0.4.9 22 | - scikit-image==0.22.0 23 | -------------------------------------------------------------------------------- /configs/nerfsyn/ship.yml: -------------------------------------------------------------------------------- 1 | index: "ship" 2 | dataset: 3 | path: "./data/nerf_synthetic/ship/" 4 | geoms: 5 | points: 6 | init_num: 10000 7 | training: 8 | add_start: 10000 9 | add_stop: 80000 10 | add_num: 500 11 | lr: 12 | points: 13 | base_lr: 3.0e-3 14 | eval: 15 | dataset: 16 | path: "./data/nerf_synthetic/ship/" 17 | test: 18 | # load_path: "checkpoints/ship.pth" 19 | datasets: 20 | - name: "testset" 21 | path: "./data/nerf_synthetic/ship/" -------------------------------------------------------------------------------- /configs/nerfsyn/drums.yml: -------------------------------------------------------------------------------- 1 | index: "drums" 2 | dataset: 3 | path: "./data/nerf_synthetic/drums/" 4 | models: 5 | attn: 6 | embed: 7 | value: 8 | skip_layers: [5] 9 | training: 10 | add_start: 10000 11 | add_stop: 40000 12 | lr: 13 | generator: 14 | base_lr: 2.0e-4 15 | eval: 16 | dataset: 17 | path: "./data/nerf_synthetic/drums/" 18 | test: 19 | # load_path: "checkpoints/drums.pth" 20 | datasets: 21 | - name: "testset" 22 | path: "./data/nerf_synthetic/drums/" -------------------------------------------------------------------------------- /configs/t2/Truck.yml: -------------------------------------------------------------------------------- 1 | index: "Truck" 2 | dataset: 3 | coord_scale: 40.0 4 | type: "t2" 5 | path: "./data/tanks_temples/Truck/" 6 | factor: 2 7 | geoms: 8 | points: 9 | init_scale: [1.0, 1.0, 1.0] 10 | num: 5000 11 | constant: 4.0 12 | training: 13 | add_start: 20000 14 | 
add_stop: 60000 15 | lr: 16 | points: 17 | base_lr: 8.0e-3 18 | eval: 19 | dataset: 20 | type: "t2" 21 | path: "./data/tanks_temples/Truck/" 22 | factor: 2 23 | img_idx: 0 24 | test: 25 | # load_path: "checkpoints/Truck.pth"\ 26 | datasets: 27 | - name: "testset" 28 | type: "t2" 29 | path: "./data/tanks_temples/Truck/" 30 | factor: 2 -------------------------------------------------------------------------------- /configs/t2/Barn.yml: -------------------------------------------------------------------------------- 1 | index: "Barn" 2 | dataset: 3 | coord_scale: 30.0 4 | type: "t2" 5 | path: "./data/tanks_temples/Barn/" 6 | factor: 2 7 | patches: 8 | height: 180 9 | width: 180 10 | geoms: 11 | points: 12 | init_scale: [1.8, 1.8, 1.8] 13 | init_num: 5000 14 | training: 15 | add_start: 10000 16 | add_stop: 35000 17 | lr: 18 | points: 19 | base_lr: 1.0e-2 20 | eval: 21 | dataset: 22 | type: "t2" 23 | path: "./data/tanks_temples/Barn/" 24 | factor: 2 25 | img_idx: 0 26 | test: 27 | # load_path: "checkpoints/Barn.pth" 28 | datasets: 29 | - name: "testset" 30 | type: "t2" 31 | path: "./data/tanks_temples/Barn/" 32 | factor: 2 -------------------------------------------------------------------------------- /configs/t2/Family.yml: -------------------------------------------------------------------------------- 1 | index: "Family" 2 | dataset: 3 | coord_scale: 40.0 4 | type: "t2" 5 | path: "./data/tanks_temples/Family/" 6 | factor: 2 7 | geoms: 8 | points: 9 | init_scale: [0.3, 0.3, 0.3] 10 | init_num: 5000 11 | models: 12 | attn: 13 | embed: 14 | value: 15 | skip_layers: [5] 16 | training: 17 | add_start: 10000 18 | add_stop: 80000 19 | add_num: 500 20 | eval: 21 | dataset: 22 | type: "t2" 23 | path: "./data/tanks_temples/Family/" 24 | factor: 2 25 | img_idx: 0 26 | test: 27 | # load_path: "checkpoints/Family.pth" 28 | datasets: 29 | - name: "testset" 30 | type: "t2" 31 | path: "./data/tanks_temples/Family/" 32 | factor: 2 -------------------------------------------------------------------------------- /configs/nerfsyn/lego.yml: -------------------------------------------------------------------------------- 1 | index: "lego" 2 | dataset: 3 | path: "./data/nerf_synthetic/lego/" 4 | geoms: 5 | background: 6 | constant: 3.0 7 | models: 8 | attn: 9 | embed: 10 | key: 11 | ff_act: "leakyrelu" 12 | query: 13 | ff_act: "leakyrelu" 14 | value: 15 | ff_act: "leakyrelu" 16 | skip_layers: [5] 17 | training: 18 | prune_thresh_list: [0.0, 0.2] 19 | prune_steps_list: [40000] 20 | add_start: 20000 21 | lr: 22 | points: 23 | base_lr: 3.0e-3 24 | eval: 25 | dataset: 26 | path: "./data/nerf_synthetic/lego/" 27 | test: 28 | # load_path: "checkpoints/lego.pth" 29 | datasets: 30 | - name: "testset" 31 | path: "./data/nerf_synthetic/lego/" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2023.11.17 2 | charset-normalizer==3.3.2 3 | contourpy==1.2.0 4 | cycler==0.12.1 5 | fonttools==4.47.2 6 | idna==3.6 7 | imageio==2.33.1 8 | imageio-ffmpeg==0.4.9 9 | importlib-resources==6.1.1 10 | kiwisolver==1.4.5 11 | lazy-loader==0.3 12 | lpips==0.1.4 13 | matplotlib==3.7.2 14 | networkx==3.2.1 15 | numpy==1.25.2 16 | packaging==23.2 17 | Pillow==10.1.0 18 | pyparsing==3.0.9 19 | python-dateutil==2.8.2 20 | PyWavelets==1.3.0 21 | PyYAML==6.0.1 22 | requests==2.31.0 23 | scikit-image==0.19.2 24 | scipy==1.11.2 25 | six==1.16.0 26 | tifffile==2024.1.30 27 | torch==1.13.1+cu117 28 | 
torchaudio==0.13.1+cu117 29 | torchvision==0.14.1+cu117 30 | tqdm==4.66.1 31 | typing-extensions==4.9.0 32 | urllib3==2.1.0 33 | zipp==3.17.0 34 | -------------------------------------------------------------------------------- /configs/t2/Caterpillar.yml: -------------------------------------------------------------------------------- 1 | index: "Caterpillar" 2 | use_amp: false 3 | dataset: 4 | coord_scale: 30.0 5 | type: "t2" 6 | path: "./data/tanks_temples/Caterpillar/" 7 | factor: 2 8 | patches: 9 | height: 180 10 | width: 180 11 | geoms: 12 | points: 13 | init_scale: [1.0, 1.0, 1.0] 14 | init_num: 5000 15 | background: 16 | constant: 4.0 17 | models: 18 | attn: 19 | embed: 20 | k_L: [4, 4, 4] 21 | q_L: [4] 22 | v_L: [4, 4] 23 | training: 24 | add_start: 10000 25 | add_stop: 80000 26 | add_num: 500 27 | lr: 28 | points: 29 | base_lr: 6.0e-3 30 | eval: 31 | dataset: 32 | type: "t2" 33 | path: "./data/tanks_temples/Caterpillar/" 34 | factor: 2 35 | img_idx: 0 36 | test: 37 | # load_path: "checkpoints/Caterpillar.pth" 38 | datasets: 39 | - name: "testset" 40 | type: "t2" 41 | path: "./data/tanks_temples/Caterpillar/" 42 | factor: 2 -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import RINDataset 2 | from torch.utils.data import DataLoader 3 | 4 | 5 | def get_traindataset(args): 6 | return RINDataset(args, mode='train') 7 | 8 | 9 | def get_trainloader(dataset, args): 10 | return DataLoader(dataset, batch_size=args.batch_size, shuffle=args.shuffle) 11 | 12 | 13 | def get_testdataset(args): 14 | return RINDataset(args, mode='test') 15 | 16 | 17 | def get_testloader(dataset, args): 18 | return DataLoader(dataset, batch_size=1, shuffle=False) 19 | 20 | 21 | def get_dataset(args, mode): 22 | if mode == 'train': 23 | return get_traindataset(args) 24 | elif mode == 'test': 25 | return get_testdataset(args) 26 | else: 27 | raise ValueError("Unknown mode: {}".format(mode)) 28 | 29 | 30 | def get_loader(dataset, args, mode): 31 | if mode == 'train': 32 | return get_trainloader(dataset, args) 33 | elif mode == 'test': 34 | return get_testloader(dataset, args) 35 | else: 36 | raise ValueError("Unknown mode: {}".format(mode)) 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/t2/Ignatius.yml: -------------------------------------------------------------------------------- 1 | index: "Ignatius" 2 | dataset: 3 | coord_scale: 50.0 4 | type: "t2" 5 | path: "./data/tanks_temples/Ignatius/" 6 | factor: 2 7 | patches: 8 | height: 180 9 | width: 180 10 | geoms: 11 | points: 12 | init_scale: [0.4, 0.4, 0.4] 13 | num: 5000 14 | models: 15 | attn: 16 | embed: 17 | k_L: [4, 4, 4] 18 | q_L: [4] 19 | v_L: [4, 4] 20 | value: 21 | skip_layers: [5] 22 | training: 23 | add_start: 10000 24 | add_stop: 80000 25 | add_num: 500 26 | lr: 27 | attn: 28 | type: "cosine" 29 | warmup: 5000 30 | points: 31 | base_lr: 9.0e-3 32 | points_influ_scores: 33 | type: "cosine" 34 | warmup: 5000 35 | feats: 36 | type: "cosine" 37 | warmup: 5000 38 | generator: 39 | type: "cosine" 40 | warmup: 5000 41 | eval: 42 | dataset: 43 | type: "t2" 44 | path: "./data/tanks_temples/Ignatius/" 45 | factor: 2 46 | img_idx: 0 47 | test: 48 | # load_path: "checkpoints/Ignatius.pth" 49 | datasets: 50 | - name: "testset" 51 | type: "t2" 52 | path: "./data/tanks_temples/Ignatius/" 53 | factor: 2 -------------------------------------------------------------------------------- /configs/t2/Caterpillar_exposure_control.yml: -------------------------------------------------------------------------------- 1 | index: "Caterpillar_exposure_control3" 2 | load_path: "./checkpoints_new/Caterpillar.pth" 3 | use_amp: false 4 | dataset: 5 | coord_scale: 30.0 6 | type: "t2" 7 | path: "./data/tanks_temples/Caterpillar/" 8 | factor: 2 9 | geoms: 10 | background: 11 | constant: 4.0 12 | exposure_control: 13 | use: true 14 | models: 15 | attn: 16 | embed: 17 | k_L: [4, 4, 4] 18 | q_L: [4] 19 | v_L: [4, 4] 20 | renderer: 21 | generator: 22 | small_unet: 23 | affine_layer: -1 24 | training: 25 | steps: 100000 26 | lr: 27 | lr_factor: 0.2 28 | attn: 29 | type: "none" 30 | warmup: 0 31 | points: 32 | base_lr: 0.0 33 | points_influ_scores: 34 | type: "none" 35 | warmup: 0 36 | feats: 37 | type: "none" 38 | warmup: 0 39 | generator: 40 | type: "none" 41 | warmup: 0 42 | eval: 43 | dataset: 44 | type: "t2" 45 | path: "./data/tanks_temples/Caterpillar/" 46 | factor: 2 47 | img_idx: 0 48 | test: 49 | load_path: "checkpoints/Caterpillar_exposure_control.pth" 50 | datasets: 51 | - name: "testset" 52 | type: "t2" 53 | path: "./data/tanks_temples/Caterpillar/" 54 | factor: 2 -------------------------------------------------------------------------------- /dataset/load_nerfsyn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import imageio 4 | from PIL import Image 5 | import json 6 | 7 | 8 | def load_blender_data(basedir, split='train', factor=1, read_offline=True): 9 | with open(os.path.join(basedir, f'transforms_{split}.json'), 'r') as fp: 10 | meta = json.load(fp) 11 | 12 | poses = [] 13 | images = [] 14 | image_paths = [] 15 | 16 | for i, frame in enumerate(meta['frames']): 17 | img_path = os.path.abspath(os.path.join(basedir, frame['file_path'] + '.png')) 18 | poses.append(np.array(frame['transform_matrix'])) 19 | image_paths.append(img_path) 20 | 21 | if read_offline: 22 | img = 
imageio.imread(img_path) 23 | H, W = img.shape[:2] 24 | if factor > 1: 25 | img = Image.fromarray(img).resize((W//factor, H//factor)) 26 | images.append((np.array(img) / 255.).astype(np.float32)) 27 | elif i == 0: 28 | img = imageio.imread(img_path) 29 | H, W = img.shape[:2] 30 | if factor > 1: 31 | img = Image.fromarray(img).resize((W//factor, H//factor)) 32 | images.append((np.array(img) / 255.).astype(np.float32)) 33 | 34 | poses = np.array(poses).astype(np.float32) 35 | images = np.array(images).astype(np.float32) 36 | 37 | H, W = images[0].shape[:2] 38 | camera_angle_x = float(meta['camera_angle_x']) 39 | focal = .5 * W / np.tan(.5 * camera_angle_x) 40 | 41 | return images, poses, [H, W, focal], image_paths 42 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import lpips 4 | from .model import PAPR 5 | from .lpips import LPNet 6 | 7 | 8 | class BasicLoss(nn.Module): 9 | def __init__(self, losses_and_weights): 10 | super(BasicLoss, self).__init__() 11 | self.losses_and_weights = losses_and_weights 12 | 13 | def forward(self, pred, target): 14 | loss = 0 15 | for name_and_weight, loss_func in self.losses_and_weights.items(): 16 | name, weight = name_and_weight.split('/') 17 | cur_loss = loss_func(pred, target) 18 | loss += float(weight) * cur_loss 19 | # print(name, weight, cur_loss, loss) 20 | return loss 21 | 22 | 23 | def get_model(args, device='cuda'): 24 | return PAPR(args, device=device) 25 | 26 | 27 | def get_loss(args, bias=1.0): 28 | losses = nn.ModuleDict() 29 | for loss_name, weight in args.items(): 30 | if weight > 0: 31 | if loss_name == "mse": 32 | losses[loss_name + "/" + 33 | str(format(weight, '.0e'))] = nn.MSELoss() 34 | print("Using MSE loss, loss weight: ", weight) 35 | elif loss_name == "l1": 36 | losses[loss_name + "/" + 37 | str(format(weight, '.0e'))] = nn.L1Loss() 38 | print("Using L1 loss, loss weight: ", weight) 39 | elif loss_name == "lpips": 40 | lpips = LPNet() 41 | lpips.eval() 42 | losses[loss_name + "/" + str(format(weight, '.0e'))] = lpips 43 | print("Using LPIPS loss, loss weight: ", weight) 44 | elif loss_name == "lpips_alex": 45 | lpips = lpips.LPIPS() 46 | lpips.eval() 47 | losses[loss_name + "/" + str(format(weight, '.0e'))] = lpips 48 | print("Using LPIPS AlexNet loss, loss weight: ", weight) 49 | else: 50 | raise NotImplementedError( 51 | 'loss [{:s}] is not supported'.format(loss_name)) 52 | return BasicLoss(losses) 53 | -------------------------------------------------------------------------------- /models/renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .unet import SmallUNet 3 | from .mlp import MLP 4 | 5 | 6 | class MLPGenerator(torch.nn.Module): 7 | def __init__(self, inp_dim=2, num_layers=3, num_channels=128, out_dim=2, act_type="leakyrelu", last_act_type="none", 8 | use_wn=True, a=1., b=1., trainable=False, skip_layers=[], bias=True, half_layers=[], residual_layers=[], 9 | residual_dims=[]): 10 | super(MLPGenerator, self).__init__() 11 | self.mlp = MLP(inp_dim=inp_dim, num_layers=num_layers, num_channels=num_channels, out_dim=out_dim, 12 | act_type=act_type, last_act_type=last_act_type, use_wn=use_wn, a=a, b=b, trainable=trainable, 13 | skip_layers=skip_layers, bias=bias, half_layers=half_layers, residual_layers=residual_layers, 14 | residual_dims=residual_dims) 15 | 16 | def 
forward(self, x, residuals=[], gamma=None, beta=None): # (N, C, H, W) 17 | return self.mlp(x.permute(0, 2, 3, 1), residuals).permute(0, 3, 1, 2) 18 | 19 | 20 | 21 | def get_generator(args, in_c, out_c, use_amp=False, amp_dtype=torch.float16): 22 | if args.type == "small-unet": 23 | opt = args.small_unet 24 | return SmallUNet(in_c, out_c, bilinear=opt.bilinear, single=opt.single, norm=opt.norm, last_act=opt.last_act, 25 | use_amp=use_amp, amp_dtype=amp_dtype, affine_layer=opt.affine_layer) 26 | elif args.type == "mlp": 27 | opt = args.mlp 28 | return MLPGenerator(inp_dim=in_c, num_layers=opt.num_layers, num_channels=opt.num_channels, out_dim=out_c, 29 | act_type=opt.act_type, last_act_type=opt.last_act_type, use_wn=opt.use_wn, a=opt.act_a, b=opt.act_b, 30 | trainable=opt.act_trainable, skip_layers=opt.skip_layers, bias=opt.bias, half_layers=opt.half_layers, 31 | residual_layers=opt.residual_layers, residual_dims=opt.residual_dims) 32 | else: 33 | raise NotImplementedError( 34 | 'generator type [{:d}] is not supported'.format(args.type)) 35 | -------------------------------------------------------------------------------- /dataset/load_t2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import imageio 4 | from PIL import Image 5 | 6 | blender2opencv = np.array( 7 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 8 | 9 | 10 | def get_instrinsic(filepath): 11 | try: 12 | intrinsic = np.loadtxt(filepath).astype(np.float32)[:3, :3] 13 | return intrinsic 14 | except ValueError: 15 | pass 16 | 17 | # Get camera intrinsics 18 | with open(filepath, 'r') as file: 19 | f, cx, cy, _ = map(float, file.readline().split()) 20 | fy = fx = f 21 | 22 | # Build the intrinsic matrices 23 | intrinsic = np.array([[fx, 0., cx], 24 | [0., fy, cy], 25 | [0., 0, 1]]) 26 | return intrinsic 27 | 28 | 29 | def load_t2_data(basedir, factor=1, split="train", read_offline=True, tgtH=1280, tgtW=2176): 30 | colordir = os.path.join(basedir, "rgb") 31 | posedir = os.path.join(basedir, "pose") 32 | train_image_paths = [f for f in os.listdir(colordir) if os.path.isfile( 33 | os.path.join(colordir, f)) and f.startswith("0")] 34 | test_image_paths = [f for f in os.listdir(colordir) if os.path.isfile( 35 | os.path.join(colordir, f)) and f.startswith("1")] 36 | 37 | if split == "train": 38 | image_paths = train_image_paths 39 | elif split == "test": 40 | image_paths = test_image_paths 41 | else: 42 | raise ValueError("Unknown split: {}".format(split)) 43 | 44 | image_paths = sorted(image_paths, key=lambda x: int( 45 | x.split(".")[0].split("_")[-1])) 46 | 47 | images = [] 48 | poses = [] 49 | out_image_paths = [] 50 | 51 | intrinsic = get_instrinsic(os.path.join(basedir, "intrinsics.txt")) 52 | fx, _, cx = intrinsic[0] 53 | _, fy, cy = intrinsic[1] 54 | 55 | for i, img_path in enumerate(image_paths): 56 | image_path = os.path.abspath(os.path.join(colordir, img_path)) 57 | out_image_paths.append(image_path) 58 | 59 | if read_offline: 60 | image = imageio.imread(image_path) 61 | H, W = image.shape[:2] 62 | if factor != 1: 63 | image = Image.fromarray(image).resize( 64 | (tgtW // factor, tgtH // factor)) 65 | images.append((np.array(image) / 255.).astype(np.float32)) 66 | elif i == 0: 67 | image = imageio.imread(image_path) 68 | H, W = image.shape[:2] 69 | if factor != 1: 70 | image = Image.fromarray(image).resize( 71 | (tgtW // factor, tgtH // factor)) 72 | images.append((np.array(image) / 255.).astype(np.float32)) 73 | 74 | pose_path = 
os.path.join(posedir, img_path.replace(".png", ".txt")) 75 | pose = np.loadtxt(pose_path).astype(np.float32) 76 | pose = pose @ blender2opencv 77 | poses.append(pose) 78 | 79 | images = np.stack(images, 0) 80 | poses = np.stack(poses, 0) 81 | 82 | realH, realW = images.shape[1:3] 83 | fx = fx * (realW / W) 84 | fy = fy * (realH / H) 85 | 86 | return images, poses, [realH, realW, fx, fy], out_image_paths 87 | 88 | 89 | if __name__ == '__main__': 90 | images, poses, [realH, realW, fx, fy], out_image_paths = load_t2_data('../data/tanks_temples/Family/', factor=1, split="train", read_offline=False, tgtH=1280, tgtW=2176) 91 | # print(out_image_paths) 92 | for i, path in enumerate(out_image_paths): 93 | # print(path) 94 | if '0_0063_00000100' in path: 95 | print(i, path) 96 | -------------------------------------------------------------------------------- /models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils import weight_norm 4 | from torch import autocast 5 | from .utils import activation_func 6 | 7 | 8 | def get_mapping_mlp(args, use_amp=False, amp_dtype=torch.float16): 9 | return MappingMLP(args.mapping_mlp, inp_dim=args.shading_code_dim, out_dim=args.mapping_mlp.out_dim, use_amp=use_amp, amp_dtype=amp_dtype) 10 | 11 | 12 | class MLP(nn.Module): 13 | def __init__(self, inp_dim=2, num_layers=3, num_channels=128, out_dim=2, act_type="leakyrelu", last_act_type="none", 14 | use_wn=True, a=1., b=1., trainable=False, skip_layers=[], bias=True, half_layers=[], residual_layers=[], 15 | residual_dims=[]): 16 | super(MLP, self).__init__() 17 | self.skip_layers = skip_layers 18 | self.residual_layers = residual_layers 19 | self.residual_dims = residual_dims 20 | assert len(residual_dims) == len(residual_layers) 21 | wn = weight_norm if use_wn else lambda x, **kwargs: x 22 | layers = [nn.Identity()] 23 | for i in range(num_layers): 24 | cur_inp = inp_dim if i == 0 else num_channels 25 | cur_out = out_dim if i == num_layers - 1 else num_channels 26 | if (i+1) in half_layers: 27 | cur_out = cur_out // 2 28 | if i in half_layers: 29 | cur_inp = cur_inp // 2 30 | if i in self.skip_layers: 31 | cur_inp += inp_dim 32 | if i in self.residual_layers: 33 | cur_inp += self.residual_dims[residual_layers.index(i)] 34 | layers.append( 35 | wn(nn.Linear(cur_inp, cur_out, bias=bias), name='weight')) 36 | layers.append(activation_func(act_type=act_type, 37 | num_channels=cur_out, a=a, b=b, trainable=trainable)) 38 | layers[-1] = activation_func(act_type=last_act_type, 39 | num_channels=out_dim, a=a, b=b, trainable=trainable) 40 | assert len(layers) == 2 * num_layers + 1 41 | self.model = nn.ModuleList(layers) 42 | 43 | for p in self.model.parameters(): 44 | if p.dim() > 1: 45 | nn.init.xavier_uniform_(p) 46 | 47 | def forward(self, x, residuals=[]): 48 | skip_layers = [i*2+1 for i in self.skip_layers] 49 | residual_layers = [i*2+1 for i in self.residual_layers] 50 | assert len(residuals) == len(self.residual_layers) 51 | # print(skip_layers) 52 | inp = x 53 | for i, layer in enumerate(self.model): 54 | if i in skip_layers: 55 | x = torch.cat([x, inp], dim=-1) 56 | if i in residual_layers: 57 | x = torch.cat([x, residuals[residual_layers.index(i)]], dim=-1) 58 | x = layer(x) 59 | return x 60 | 61 | 62 | class MappingMLP(nn.Module): 63 | def __init__(self, args, inp_dim=2, out_dim=2, use_amp=False, amp_dtype=torch.float16): 64 | super(MappingMLP, self).__init__() 65 | self.args = args 66 | self.inp_dim = inp_dim 
67 | self.out_dim = out_dim 68 | self.use_amp = use_amp 69 | self.amp_dtype = amp_dtype 70 | self.model = MLP(inp_dim=inp_dim, num_layers=args.num_layers, num_channels=args.dim, out_dim=out_dim, 71 | act_type=args.act, last_act_type=args.last_act, use_wn=args.use_wn) 72 | print("Mapping MLP:\n", self.model) 73 | 74 | def forward(self, x): 75 | 76 | with autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp): 77 | out = self.model(x) 78 | return out 79 | -------------------------------------------------------------------------------- /models/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models as tv 4 | from collections import namedtuple 5 | import os 6 | 7 | 8 | class vgg16(nn.Module): 9 | def __init__(self, requires_grad=False, pretrained=True): 10 | super(vgg16, self).__init__() 11 | vgg_pretrained_features = tv.vgg16(weights=tv.VGG16_Weights.IMAGENET1K_V1 if pretrained else None).features 12 | self.slice1 = nn.Sequential() 13 | self.slice2 = nn.Sequential() 14 | self.slice3 = nn.Sequential() 15 | self.slice4 = nn.Sequential() 16 | self.slice5 = nn.Sequential() 17 | self.N_slices = 5 18 | for x in range(4): 19 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 20 | for x in range(4, 9): 21 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 22 | for x in range(9, 16): 23 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 24 | for x in range(16, 23): 25 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 26 | for x in range(23, 30): 27 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 28 | if not requires_grad: 29 | for param in self.parameters(): 30 | param.requires_grad = False 31 | 32 | def forward(self, x): 33 | h = self.slice1(x) 34 | h_relu1_2 = h 35 | h = self.slice2(h) 36 | h_relu2_2 = h 37 | h = self.slice3(h) 38 | h_relu3_3 = h 39 | h = self.slice4(h) 40 | h_relu4_3 = h 41 | h = self.slice5(h) 42 | h_relu5_3 = h 43 | vgg_outputs = namedtuple( 44 | "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 45 | out = vgg_outputs(h_relu1_2, h_relu2_2, 46 | h_relu3_3, h_relu4_3, h_relu5_3) 47 | 48 | return out 49 | 50 | 51 | class ScalingLayer(nn.Module): 52 | # For rescaling the input to vgg16 53 | def __init__(self): 54 | super(ScalingLayer, self).__init__() 55 | self.register_buffer('shift', torch.Tensor( 56 | [-.030, -.088, -.188])[None, :, None, None]) 57 | self.register_buffer('scale', torch.Tensor( 58 | [.458, .448, .450])[None, :, None, None]) 59 | 60 | def forward(self, inp): 61 | return (inp - self.shift) / self.scale 62 | 63 | 64 | def normalize_tensor(in_feat, eps=1e-10): 65 | norm_factor = torch.sqrt( 66 | torch.sum(in_feat ** 2, dim=1, keepdim=True) + eps) 67 | return in_feat / (norm_factor + eps) 68 | 69 | 70 | def spatial_average(in_tens, keepdim=True): 71 | return in_tens.mean([2, 3], keepdim=keepdim) 72 | 73 | 74 | class NetLinLayer(nn.Module): 75 | ''' A single linear layer used as placeholder for LPIPS learnt weights ''' 76 | 77 | def __init__(self): 78 | super(NetLinLayer, self).__init__() 79 | self.weight = None 80 | 81 | def forward(self, inp): 82 | out = torch.sum(self.weight * inp, 1, keepdim=True) 83 | return out 84 | 85 | 86 | class LPNet(nn.Module): 87 | def __init__(self): 88 | super(LPNet, self).__init__() 89 | 90 | self.scaling_layer = ScalingLayer() 91 | self.net = vgg16(pretrained=True, requires_grad=False) 92 | self.L = 5 93 | self.lins = [NetLinLayer() for 
_ in range(self.L)] 94 | self.lins = nn.ModuleList(self.lins) 95 | model_path = os.path.abspath( 96 | os.path.join('.', 'vgg.pth')) 97 | print('Loading model from: %s' % model_path) 98 | weights = torch.load(model_path, map_location='cpu') 99 | for ll in range(self.L): 100 | self.lins[ll].weight = nn.Parameter( 101 | weights["lin%d.model.1.weight" % ll]) 102 | 103 | def forward(self, in0, in1): 104 | in0 = in0.permute(0, 3, 1, 2) 105 | in1 = in1.permute(0, 3, 1, 2) 106 | in0 = 2 * in0 - 1 107 | in0_input = self.scaling_layer(in0) 108 | in1 = 2 * in1 - 1 109 | in1_input = self.scaling_layer(in1) 110 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 111 | feats0, feats1, diffs = {}, {}, {} 112 | 113 | for kk in range(self.L): 114 | feats0[kk], feats1[kk] = normalize_tensor( 115 | outs0[kk]), normalize_tensor(outs1[kk]) 116 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 117 | 118 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) 119 | for kk in range(self.L)] 120 | 121 | val = res[0] 122 | for ll in range(1, self.L): 123 | val += res[ll] 124 | 125 | return val.squeeze().mean() 126 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import imageio 5 | from PIL import Image 6 | from copy import deepcopy 7 | from .utils import load_meta_data, get_rays, extract_patches 8 | 9 | 10 | class RINDataset(Dataset): 11 | """ Ray Image Normal Dataset """ 12 | 13 | def __init__(self, args, mode='train'): 14 | self.args = args 15 | images, c2w, H, W, focal_x, focal_y, image_paths = load_meta_data(args, mode=mode) 16 | num_imgs = len(image_paths) 17 | 18 | self.num_imgs = num_imgs 19 | coord_scale = args.coord_scale 20 | if coord_scale != 1: 21 | scaling_matrix = torch.tensor([[coord_scale, 0, 0, 0], 22 | [0, coord_scale, 0, 0], 23 | [0, 0, coord_scale, 0], 24 | [0, 0, 0, 1]], dtype=torch.float32) 25 | c2w = torch.matmul(scaling_matrix, c2w) 26 | print("c2w: ", c2w.shape) 27 | 28 | self.H = H 29 | self.W = W 30 | self.focal_x = focal_x 31 | self.focal_y = focal_y 32 | self.c2w = c2w # (N, 4, 4) 33 | self.image_paths = image_paths 34 | self.images = images # (N, H, W, C) or None 35 | 36 | if self.args.read_offline: 37 | rays_o, rays_d = get_rays(H, W, focal_x, focal_y, c2w) 38 | self.rayd = rays_d # (N, H, W, 3) 39 | self.rayo = rays_o # (N, 3) 40 | 41 | if self.args.extract_patch == True and self.args.extract_online == False and self.args.read_offline == True: 42 | img_patches, rayd_patches, rayo_patches, num_patches = extract_patches(images, rays_o, rays_d, args) 43 | # (N, n_patches, patch_height, patch_width, C) or None 44 | self.img_patches = img_patches 45 | # (N, n_patches, patch_height, patch_width, 3) 46 | self.rayd_patches = rayd_patches 47 | self.rayo_patches = rayo_patches # (N, n_patches, 3) 48 | self.num_patches = num_patches 49 | 50 | def _read_image_from_path(self, image_idx): 51 | image_path = self.image_paths[image_idx] 52 | image = imageio.imread(image_path) 53 | image = Image.fromarray(image).resize((self.W, self.H)) 54 | image = (np.array(image) / 255.).astype(np.float32) 55 | 56 | if self.args.white_bg and image.shape[-1] == 4: 57 | image = image[..., :3] * image[..., -1:] + (1. - image[..., -1:]) 58 | elif not self.args.white_bg: 59 | image = image[..., :3] 60 | mask = image.sum(-1) == 3.0 61 | image[mask] = 0. 
62 | 63 | image = torch.from_numpy(image).float() 64 | 65 | rayo, rayd = get_rays(self.H, self.W, self.focal_x, self.focal_y, self.c2w[image_idx:image_idx+1]) 66 | 67 | return image, rayo, rayd 68 | 69 | def __len__(self): 70 | if self.args.extract_patch == True and self.args.extract_online == False and self.args.read_offline == True: 71 | return self.num_imgs * self.num_patches 72 | else: 73 | return self.num_imgs 74 | 75 | def __getitem__(self, idx): 76 | if self.args.extract_patch == True and self.args.extract_online == False and self.args.read_offline == True: 77 | img_idx = idx // self.num_patches 78 | patch_idx = idx % self.num_patches 79 | return img_idx, patch_idx, \ 80 | self.img_patches[img_idx, patch_idx] if self.img_patches is not None else 0, \ 81 | self.rayd_patches[img_idx, patch_idx], \ 82 | self.rayo_patches[img_idx, patch_idx] 83 | 84 | elif self.args.extract_patch == True and self.args.extract_online == True: 85 | img_idx = idx 86 | args = self.args 87 | args.patches.max_patches = 1 88 | if self.args.read_offline: 89 | img_patches, rayd_patches, rayo_patches, _ = extract_patches(self.images[img_idx:img_idx+1], 90 | self.rayo[img_idx:img_idx+1], 91 | self.rayd[img_idx:img_idx+1], 92 | args) 93 | else: 94 | image, rayo, rayd = self._read_image_from_path(img_idx) 95 | img_patches, rayd_patches, rayo_patches, _ = extract_patches(image[None, ...], rayo, rayd, args) 96 | 97 | return img_idx, 0, \ 98 | img_patches[0, 0] if img_patches is not None else 0, \ 99 | rayd_patches[0, 0], \ 100 | rayo_patches[0, 0] 101 | else: 102 | if self.args.read_offline: 103 | return idx, 0, self.images[idx] if self.images is not None else 0, \ 104 | self.rayd[idx], self.rayo[idx] 105 | else: 106 | image, rayo, rayd = self._read_image_from_path(idx) 107 | return idx, 0, image, rayd.squeeze(0), rayo.squeeze(0) 108 | 109 | def get_full_img(self, img_idx): 110 | if self.args.read_offline: 111 | return self.images[img_idx].unsqueeze(0) if self.images is not None else None, \ 112 | self.rayd[img_idx].unsqueeze(0), self.rayo[img_idx].unsqueeze(0) 113 | else: 114 | image, rayo, rayd = self._read_image_from_path(img_idx) 115 | return image[None, ...], rayd, rayo 116 | 117 | def get_c2w(self, img_idx): 118 | return self.c2w[img_idx] 119 | 120 | def get_new_rays(self, c2w): 121 | return get_rays(self.H, self.W, self.focal_x, self.focal_y, c2w) 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PAPR: Proximity Attention Point Rendering (NeurIPS 2023 Spotlight 🤩) 2 | [Yanshu Zhang*](https://zvict.github.io/), [Shichong Peng*](https://sites.google.com/view/niopeng/home), [Alireza Moazeni](https://amoazeni75.github.io/), [Ke Li](https://www.sfu.ca/~keli/) (* denotes equal contribution)
3 | 4 | 5 | 6 | [Project Sites](https://zvict.github.io/papr) 7 | | [Paper](https://arxiv.org/abs/2307.11086) | 8 | Primary contact: [Yanshu Zhang](https://zvict.github.io/) 9 | 10 | Proximity Attention Point Rendering (PAPR) is a new method for joint novel view synthesis and 3D reconstruction. It simultaneously learns from scratch an accurate point cloud representation of the scene surface, and an attention-based neural network that renders the point cloud from novel views. 11 | 12 | 13 | 14 | [![NeurIPS 2023 Presentation](https://github.com/zvict/papr/blob/main/images/papr_video_cover.png)](https://youtu.be/1atBGH_pDHY) 15 | 16 | ## BibTeX 17 | PAPR: Proximity Attention Point Rendering.     18 | ``` 19 | @inproceedings{zhang2023papr, 20 | title={PAPR: Proximity Attention Point Rendering}, 21 | author={Yanshu Zhang and Shichong Peng and Seyed Alireza Moazenipourasil and Ke Li}, 22 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 23 | year={2023} 24 | } 25 | ``` 26 | 27 | ## Installation 28 | ``` 29 | git clone git@github.com:zvict/papr.git # or 'git clone https://github.com/zvict/papr' 30 | cd papr 31 | conda env create -f papr.yml 32 | conda activate papr 33 | ``` 34 | Or use a virtual environment with `python=3.9`: 35 | ``` 36 | python -m venv path/to/<env_name> 37 | source path/to/<env_name>/bin/activate 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | ## Data Preparation 42 | 43 | Expected dataset structure in the source path location: 44 | ``` 45 | papr 46 | ├── data 47 | │ ├── nerf_synthetic 48 | │ │ ├── chair 49 | │ │ │ ├── train 50 | │ │ │ ├── val 51 | │ │ │ ├── test 52 | │ │ │ ├── transforms_train.json 53 | │ │ │ ├── transforms_val.json 54 | │ │ │ ├── transforms_test.json 55 | │ │ ├── ... 56 | │ ├── tanks_temples 57 | │ │ ├── Barn 58 | │ │ │ ├── pose 59 | │ │ │ ├── rgb 60 | │ │ │ ├── intrinsics.txt 61 | │ │ ├── ... 62 | ``` 63 | ### NeRF Synthetic 64 | Download the NeRF Synthetic dataset from [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) and put it under `data/nerf_synthetic/` 65 | 66 | 67 | ### Tanks & Temples 68 | Download [Tanks&Temples](https://www.tanksandtemples.org/) from [here](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip) and put it under: 69 | `data/tanks_temples/` 70 | 71 | ### Use your own data 72 | You can refer to this [issue](https://github.com/zvict/papr/issues/3#issuecomment-1907260683) for instructions on how to prepare the dataset. 73 | 74 | You need to create a new configuration file for your own dataset, and put it under `configs`. The parameter `dataset.type` in the configuration file specifies the type of the dataset. If your dataset is in the same format as the NeRF Synthetic dataset, you can directly set `dataset.type` to `"synthetic"`. Otherwise, you need to implement your own python script to load the dataset under the `dataset` folder, and add it in the function `load_meta_data` in `dataset/utils.py` (a minimal loader sketch is included at the end of this README). 75 | 76 | Most default parameters in `configs/default.yml` are general and can be used for your own dataset. You can specify the parameters that are specific to your dataset in the configuration file you created, similar to the configuration files for the NeRF Synthetic dataset and the Tanks and Temples dataset. 77 | 78 | ## Overview 79 | 80 | The codebase has two main components: the data loading part in `dataset/` and the models in `models/`. Class `PAPR` in `models/model.py` defines our main model.
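For orientation, the sketch below shows how the data-loading factories fit together. It is illustrative only: the real entry point is `train.py`, which also builds the model and loss through `get_model` and `get_loss` in `models/__init__.py`, and the actual config parsing is done by the top-level scripts (not shown in this listing), so the small `to_namespace` helper here is only a stand-in for it. The sketch assumes the NeRF Synthetic data has already been placed under `data/` as described above.

```
# Illustrative sketch only: train.py is the real entry point, and the repository's
# own config handling is not shown here; to_namespace below is a simple stand-in.
from types import SimpleNamespace

import yaml

from dataset import get_dataset, get_loader  # factories defined in dataset/__init__.py


def to_namespace(obj):
    # Turn nested config dicts into attribute-style access, as the dataset code expects.
    if isinstance(obj, dict):
        return SimpleNamespace(**{k: to_namespace(v) for k, v in obj.items()})
    return obj


with open("configs/default.yml") as f:  # per-scene files under configs/ override these defaults
    cfg = to_namespace(yaml.safe_load(f))

train_set = get_dataset(cfg.dataset, mode="train")        # RINDataset over images, poses and rays
train_loader = get_loader(train_set, cfg.dataset, mode="train")

for img_idx, patch_idx, img, rayd, rayo in train_loader:
    # img:  (B, patch_h, patch_w, 3) target RGB patch
    # rayd: (B, patch_h, patch_w, 3) per-pixel ray directions
    # rayo: (B, 3) camera origin for each patch
    break  # the actual optimization step lives in train.py / models/model.py
```

The model and losses are constructed analogously via `get_model` (which returns the `PAPR` module) and `get_loss` in `models/__init__.py`.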
All the configurations are in `configs/`, and `configs/demo.yml` is a demo configuration with comments on the important arguments. 81 | 82 | We provide a notebook `demo.ipynb` to demonstrate how to train and test the model with the demo configuration file, as well as how to use exposure control to improve the rendering quality of real-world scenes captured with auto-exposure turned on. 83 | 84 | ## Training 85 | ``` 86 | python train.py --opt configs/nerfsyn/chair.yml 87 | ``` 88 | 89 | ## Finetuning with [cIMLE](https://arxiv.org/abs/2004.03590) (Optional) 90 | 91 | For real-world scenes where exposure can change between views, we can introduce an additional latent code input into our model and finetune the model using a technique called [conditional Implicit Maximum Likelihood Estimation (cIMLE)](https://arxiv.org/abs/2004.03590) to control the exposure level of the rendered image, as described in Section 4.4 and Appendix A.8 in the paper. A pre-trained model, obtained by running `train.py` with the default configurations, is required before finetuning with exposure control. We provide a demo configuration file for the Caterpillar scene from the Tanks and Temples dataset at `configs/t2/Caterpillar_exposure_control.yml`. 92 | 93 | To finetune a pre-trained model with exposure control, run: 94 | ``` 95 | python exposure_control_finetune.py --opt configs/t2/Caterpillar_exposure_control.yml 96 | ``` 97 | 98 | ## Evaluation 99 | To evaluate your trained model without the exposure-control finetuning, run: 100 | ``` 101 | python test.py --opt configs/nerfsyn/chair.yml 102 | ``` 103 | This gives you rendered images and metrics on the test set. 104 | With a finetuned model, you can render all the test views with a single random exposure level by running: 105 | ``` 106 | python test.py --opt configs/t2/Caterpillar_exposure_control.yml --exp 107 | ``` 108 | To generate images with different random exposure levels for a single view, run: 109 | ``` 110 | python test.py --opt configs/t2/Caterpillar_exposure_control.yml --exp --random --view 0 111 | ``` 112 | Note that during testing, the scale of the latent codes should be increased to generate images with more diverse exposures, for example: 113 | ``` 114 | python test.py --opt configs/t2/Caterpillar_exposure_control.yml --exp --random --view 0 --scale 8 115 | ``` 116 | Once you have generated images with different exposure levels, you can interpolate between two picked exposure levels by specifying their indices, for example: 117 | ``` 118 | python test.py --opt configs/t2/Caterpillar_exposure_control.yml --exp --intrp --view 0 --start_index 0 --end_index 1 119 | ``` 120 | 121 | ## Pretrained Models 122 | 123 | 124 | We provide pretrained models on the NeRF Synthetic and Tanks&Temples datasets here (without finetuning): [Google Drive](https://drive.google.com/drive/folders/1HSNlMu6Uup9o5hqi7T0hgDf63yR9W90s?usp=sharing). We also provide a pre-trained model with exposure control on the Caterpillar scene in the same Google Drive. To load the pretrained models, please put them under `checkpoints/` and change `test.load_path` in the config file. 125 | 126 | ## Acknowledgement 127 | This research was enabled in part by support provided by NSERC, the BC DRI Group and the Digital Research Alliance of Canada.
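## Example: Adding a Custom Dataset Loader

As a companion to the *Use your own data* section above, here is a minimal, hypothetical sketch of a new dataset type. Everything specific in it is made up for illustration (the `"my_data"` type name, the `meta_{split}.json` layout and its `frames`/`c2w`/`focal` fields); it simply mirrors the structure of `load_blender_data` and `load_t2_data`: return images in `[0, 1]`, 4x4 camera-to-world matrices in the camera convention assumed by `get_rays` (camera looking down `-z`, `+y` up), the `[H, W, focal_x, focal_y]` list and the image paths. Factor-based downsampling and lazy loading (`read_offline=False`) are omitted for brevity.
```
# dataset/load_my_data.py -- hypothetical loader for a new "my_data" dataset type
import json
import os

import imageio
import numpy as np


def load_my_data(basedir, split="train", factor=1, read_offline=True):
    with open(os.path.join(basedir, f"meta_{split}.json"), "r") as fp:
        meta = json.load(fp)

    images, poses, image_paths = [], [], []
    for frame in meta["frames"]:
        img_path = os.path.abspath(os.path.join(basedir, frame["file_path"]))
        image_paths.append(img_path)
        poses.append(np.array(frame["c2w"], dtype=np.float32))      # 4x4 camera-to-world
        img = imageio.imread(img_path)
        images.append((np.array(img) / 255.).astype(np.float32))    # RGB(A) in [0, 1]

    images = np.stack(images, 0)
    poses = np.stack(poses, 0)
    H, W = images.shape[1:3]
    focal = float(meta["focal"])
    return images, poses, [H, W, focal, focal], image_paths
```
The new type is then registered with one more branch in `load_meta_data` (`dataset/utils.py`), next to the existing `"synthetic"` and `"t2"` branches; the background/`white_bg` handling can follow the existing branches:
```
    elif args.type == "my_data":
        images, poses, hwf, image_paths = load_my_data(
            args.path, split=mode, factor=args.factor, read_offline=args.read_offline)
```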
128 | -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import math 4 | from .load_t2 import load_t2_data 5 | from .load_nerfsyn import load_blender_data 6 | 7 | 8 | def cam_to_world(coords, c2w, vector=True): 9 | """ 10 | coords: [N, H, W, 3] or [H, W, 3] or [K, 3] 11 | c2w: [N, 4, 4] or [4, 4] 12 | """ 13 | if vector: # Convert to homogeneous coordinates 14 | coords = torch.cat([coords, torch.zeros_like(coords[..., :1])], -1) 15 | else: 16 | coords = torch.cat([coords, torch.ones_like(coords[..., :1])], -1) 17 | 18 | if coords.ndim == 5: 19 | assert c2w.ndim == 2 20 | B, H, W, N, _ = coords.shape 21 | transformed_coords = torch.sum( 22 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 1, 1, 4, 4), -1) # [B, H, W, N, 3] 23 | elif coords.ndim == 4: 24 | assert c2w.ndim == 3 25 | _, H, W, _ = coords.shape 26 | N = c2w.shape[0] 27 | transformed_coords = torch.sum( 28 | coords.unsqueeze(-2) * c2w.reshape(N, 1, 1, 4, 4), -1) # [N, H, W, 4] 29 | elif coords.ndim == 3: 30 | assert c2w.ndim == 2 31 | H, W, _ = coords.shape 32 | transformed_coords = torch.sum( 33 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 4, 4), -1) # [H, W, 4] 34 | elif coords.ndim == 2: 35 | assert c2w.ndim == 2 36 | K, _ = coords.shape 37 | transformed_coords = torch.sum( 38 | coords.unsqueeze(-2) * c2w.reshape(1, 4, 4), -1) # [K, 4] 39 | else: 40 | raise ValueError('Wrong dimension of coords') 41 | return transformed_coords[..., :3] 42 | 43 | 44 | def world_to_cam(coords, c2w, vector=True): 45 | """ 46 | coords: [N, H, W, 3] or [H, W, 3] or [K, 3] 47 | c2w: [N, 4, 4] or [4, 4] 48 | """ 49 | if vector: # Convert to homogeneous coordinates 50 | coords = torch.cat([coords, torch.zeros_like(coords[..., :1])], -1) 51 | else: 52 | coords = torch.cat([coords, torch.ones_like(coords[..., :1])], -1) 53 | 54 | c2w = torch.inverse(c2w) 55 | if coords.ndim == 5: 56 | assert c2w.ndim == 2 57 | B, H, W, N, _ = coords.shape 58 | transformed_coords = torch.sum( 59 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 1, 1, 4, 4), -1) # [B, H, W, N, 3] 60 | elif coords.ndim == 4: 61 | assert c2w.ndim == 3 62 | _, H, W, _ = coords.shape 63 | N = c2w.shape[0] 64 | transformed_coords = torch.sum( 65 | coords.unsqueeze(-2) * c2w.reshape(N, 1, 1, 4, 4), -1) # [N, H, W, 4] 66 | elif coords.ndim == 3: 67 | assert c2w.ndim == 2 68 | H, W, _ = coords.shape 69 | transformed_coords = torch.sum( 70 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 4, 4), -1) # [H, W, 4] 71 | elif coords.ndim == 2: 72 | assert c2w.ndim == 2 73 | K, _ = coords.shape 74 | transformed_coords = torch.sum( 75 | coords.unsqueeze(-2) * c2w.reshape(1, 4, 4), -1) # [K, 4] 76 | else: 77 | raise ValueError('Wrong dimension of coords') 78 | return transformed_coords[..., :3] 79 | 80 | 81 | def get_rays(H, W, focal_x, focal_y, c2w, fineness=1): 82 | N = c2w.shape[0] 83 | width = torch.linspace( 84 | 0, W / focal_x, steps=int(W / fineness) + 1, dtype=torch.float32) 85 | height = torch.linspace( 86 | 0, H / focal_y, steps=int(H / fineness) + 1, dtype=torch.float32) 87 | y, x = torch.meshgrid(height, width, indexing='ij') 88 | pixel_size_x = width[1] - width[0] 89 | pixel_size_y = height[1] - height[0] 90 | x = (x - W / focal_x / 2 + pixel_size_x / 2)[:-1, :-1] 91 | y = -(y - H / focal_y / 2 + pixel_size_y / 2)[:-1, :-1] 92 | # [H, W, 3], vectors, since the camera is at the origin 93 | dirs_d = torch.stack([x, y, -torch.ones_like(x)], -1) 94 
| rays_d = cam_to_world(dirs_d.unsqueeze(0), c2w) # [N, H, W, 3] 95 | rays_o = c2w[:, :3, -1] # [N, 3] 96 | return rays_o, rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 97 | 98 | 99 | def extract_patches(imgs, rays_o, rays_d, args): 100 | patch_opt = args.patches 101 | N, H, W, C = imgs.shape 102 | 103 | num_patches = patch_opt.max_patches 104 | rayd_patches = np.zeros((N, num_patches, patch_opt.height, patch_opt.width, 3), dtype=np.float32) 105 | rayo_patches = np.zeros((N, num_patches, 3), dtype=np.float32) 106 | img_patches = np.zeros((N, num_patches, patch_opt.height, patch_opt.width, C), dtype=np.float32) 107 | 108 | for i in range(N): 109 | for n_patch in range(num_patches): 110 | start_height = np.random.randint(0, H - patch_opt.height) 111 | start_width = np.random.randint(0, W - patch_opt.width) 112 | end_height = start_height + patch_opt.height 113 | end_width = start_width + patch_opt.width 114 | rayd_patches[i, n_patch, :, :] = rays_d[i, start_height:end_height, start_width:end_width] 115 | rayo_patches[i, n_patch, :] = rays_o[i, :] 116 | img_patches[i, n_patch, :, :] = imgs[i, start_height:end_height, start_width:end_width] 117 | 118 | return img_patches, rayd_patches, rayo_patches, num_patches 119 | 120 | 121 | def load_meta_data(args, mode="train"): 122 | """ 123 | 0 -----------> W 124 | | 125 | | 126 | | 127 | ⬇ 128 | H 129 | [H, W, 4] 130 | """ 131 | image_paths = None 132 | 133 | if args.type == "synthetic": 134 | images, poses, hwf, image_paths = load_blender_data( 135 | args.path, split=mode, factor=args.factor, read_offline=args.read_offline) 136 | print('Loaded blender', images.shape, hwf, args.path) 137 | 138 | H, W, focal = hwf 139 | hwf = [H, W, focal, focal] 140 | 141 | if args.white_bg: 142 | images = images[..., :3] * \ 143 | images[..., -1:] + (1. - images[..., -1:]) 144 | else: 145 | images = images[..., :3] 146 | 147 | elif args.type == "t2": 148 | images, poses, hwf, image_paths = load_t2_data( 149 | args.path, factor=args.factor, split=mode, read_offline=args.read_offline) 150 | print('Loaded t2', images.shape, hwf, args.path, 151 | images.min(), images.max(), images[0, 10, 10, :]) 152 | 153 | if args.white_bg and images.shape[-1] == 4: 154 | images = images[..., :3] * \ 155 | images[..., -1:] + (1. - images[..., -1:]) 156 | elif not args.white_bg: 157 | images = images[..., :3] 158 | mask = images.sum(-1) == 3.0 159 | images[mask] = 0. 160 | 161 | else: 162 | raise ValueError("Unknown dataset type: {}".format(args.type)) 163 | 164 | H, W, focal_x, focal_y = hwf 165 | 166 | images = torch.from_numpy(images).float() 167 | poses = torch.from_numpy(poses).float() 168 | 169 | return images, poses, H, W, focal_x, focal_y, image_paths 170 | 171 | 172 | def rgb2norm(img): 173 | norm_vec = np.stack([img[..., 0] * 2.0 / 255.0 - 1.0, 174 | img[..., 1] * 2.0 / 255.0 - 1.0, 175 | img[..., 2] * 2.0 / 255.0 - 1.0, 176 | img[..., 3] / 255.0], axis=-1) 177 | return norm_vec 178 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 0. Environment setup and dataset preparation\n", 8 | "\n", 9 | "Please follow the instructions in the README file to setup the environment and prepare the dataset." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "# 1. 
Configurations\n", 17 | "\n", 18 | "We provide the default values of the parameters in `configs/default.yml`, with comments explaining the meaning of most parameters. For each individual scene, we can override the default values by providing a separate configuration file in the `configs` directory. For example, the configuration file for the Lego scene from the NeRF Synthetic dataset is `configs/nerfsyn/lego.yml`. " 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "# 2. Novel view synthesis using PAPR\n", 26 | "\n", 27 | "## 2.1 Training\n", 28 | "\n", 29 | "To synthesize novel views of a scene, we need to train a PAPR model. We provide the configuration files for each scene in the NeRF Synthetic dataset and the Tanks and Temples dataset under the `configs` directory. \n", 30 | "\n", 31 | "The configuration files specify the training parameters and the dataset paths. Detailed comments are provided in `configs/demo.yml` to help you understand the parameters.\n", 32 | "\n", 33 | "To train a PAPR model, for example on the Lego scene of the NeRF Synthetic dataset, run the following command:" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "!python train.py --opt configs/nerfsyn/lego.yml" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "You can find the trained model in the `experiments` folder. You can specify the folder name using the `index` parameter in the configuration files. You can also find the training logs and intermediate plots in the folder during training.\n", 50 | "\n", 51 | "## 2.2 Testing\n", 52 | "\n", 53 | "Once the training is finished, you can synthesize images from the test views by running:" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "!python test.py --opt configs/nerfsyn/lego.yml" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "The generated images and evaluation metrics will be saved under `experiments/{index}/test`." 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "# 3. Improve the performance for in-the-wild data with exposure control\n", 77 | "\n", 78 | "Whenever people take real-life photos with auto-exposure turned on, the exposure of the images can vary significantly across different views, for example in the Tanks and Temples dataset. To improve the performance of PAPR on in-the-wild data, we propose a simple yet effective method to control the exposure of the rendered images. We introduce an additional latent code input into our model and finetune it using a technique called [conditional Implicit Maximum Likelihood Estimation (cIMLE)](https://arxiv.org/abs/2004.03590). During test time, we can manipulate the exposure of the rendered image by changing the latent code input.\n", 79 | "\n", 80 | "One can refer to Section 4.4 and Appendix A.8 in the paper for more details.\n", 81 | "\n", 82 | "## 3.1 Finetuning\n", 83 | "\n", 84 | "To finetune the model with exposure control, we need a pretrained model generated by following the instructions in [Section 2.1](##-2.1-Training). We provide an exposure control configuration file for the Caterpillar scene from the Tanks and Temples dataset in `configs/t2/Caterpillar_exposure_control.yml`. 
You need to specify the path to the pretrained model by setting the `load_path` parameter in the configuration file.\n", 85 | "\n", 86 | "To finetune the model with exposure control, run the following command:\n", 87 | "\n" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "!python exposure_control_finetune.py --opt configs/t2/Caterpillar_exposure_control.yml" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "## 3.2 Testing\n", 104 | "\n", 105 | "To render all the test views with a single random exposure level, run:" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "!python test.py --opt configs/t2/Caterpillar_exposure_control.yml --exp" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "## 3.3 Generating images with different exposure levels\n", 122 | "\n", 123 | "You can generate images with different exposure levels by changing the latent code input for a given view. We recommend to increase the `shading_code_scale` to generate images with more diverse exposure levels. You can either change the parameter in the configuration file or specify it in the command line. For example, to generate images with different exposure levels for the Caterpillar scene, run:" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "!python test.py --opt configs/t2/Caterpillar_exposure_control.yml --exp --random --scale 8.0 --num_samples 20 --view 0" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "## 3.4 Generating images with continuous exposure changes" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "After you generate the images with different exposure levels, you can interpolate the latent codes to generate images with intermediate exposure levels. For example, to generate images with intermediate exposure levels for the Caterpillar scene, run:" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "!python test.py --opt configs/t2/Caterpillar_exposure_control.yml --exp --intrp --scale 8 --num_samples 20 --view 0 --start_index 8 --end_index 11 --num_intrp 10" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "# 4. Use PAPR for your own dataset\n", 163 | "\n", 164 | "## 4.1 Prepare the dataset\n", 165 | "\n", 166 | "You can refer to this [issue](https://github.com/zvict/papr/issues/3#issuecomment-1907260683) for the instructions on how to prepare the dataset.\n", 167 | "\n", 168 | "## 4.2 Configuration\n", 169 | "\n", 170 | "You need to create a new configuration file for your own dataset, and put it under `configs`. The parameter `dataset.type` in the configuration file specifies the type of the dataset. If your dataset is in the same format as the NeRF Synthetic dataset, you can directly set `dataset.type` to `\"synthetic\"`. 
Otherwise, you need to implement your own python script to load the dataset under the `dataset` folder, and add it in the function `load_meta_data` in `dataset/utils.py`.\n", 171 | "\n", 172 | "Most default parameters in `configs/default.yml` are general and can be used for your own dataset. You can specify the parameters that are specific to your dataset in the configuration file you created, similar to the configuration files for the NeRF Synthetic dataset and the Tanks and Temples dataset." 173 | ] 174 | } 175 | ], 176 | "metadata": { 177 | "kernelspec": { 178 | "display_name": "base", 179 | "language": "python", 180 | "name": "python3" 181 | }, 182 | "language_info": { 183 | "codemirror_mode": { 184 | "name": "ipython", 185 | "version": 3 186 | }, 187 | "file_extension": ".py", 188 | "mimetype": "text/x-python", 189 | "name": "python", 190 | "nbconvert_exporter": "python", 191 | "pygments_lexer": "ipython3", 192 | "version": "3.12.2" 193 | } 194 | }, 195 | "nbformat": 4, 196 | "nbformat_minor": 2 197 | } 198 | -------------------------------------------------------------------------------- /configs/default.yml: -------------------------------------------------------------------------------- 1 | index: "lego" # this is the name of the experiment 2 | load_path: "" # path to load the model from 3 | save_dir: "./experiments" # folder to save the model to 4 | seed: 1 # random seed 5 | eps: 1.0e-6 # epsilon for numerical stability 6 | use_amp: true # use automatic mixed precision 7 | amp_dtype: "float16" 8 | scaler_min_scale: -1.0 # set a lower bound for the scaler for numerical stability 9 | max_num_pts: 30000 # upper bound of total number of points, set to -1 to disable 10 | dataset: 11 | mode: "train" # train or test set 12 | coord_scale: 10.0 # scale the global coordinate by this factor, larger scale helps with geometry details in the point cloud, but not always larger the better 13 | type: "synthetic" # synthetic (nerf synthetic) or t2 (tanks and temples) 14 | white_bg: true # use white or black background 15 | path: "./data/nerf_synthetic/lego/" 16 | factor: 1 # downsample the target images by this factor 17 | batch_size: 1 18 | shuffle: true 19 | extract_patch: true # extract patches from the target images 20 | extract_online: true # set to false if you have enough memory to load all the patches into memory before training 21 | read_offline: false # set to true if you have enough memory to load all the images into memory before training 22 | patches: 23 | height: 160 # patch height 24 | width: 160 # patch width 25 | max_patches: 10 # maximum number of patches to extract from each image (if extract_online is false) 26 | geoms: 27 | points: 28 | select_k: 20 # number of top-k nearby points to select from the point cloud for each ray 29 | select_k_type: "d2r" # select k points based on their distances to the rays (d2r) 30 | select_k_sorted: false # sort the indices of selected points by select_k_type 31 | load_path: "" # path to load the point cloud from 32 | init_type: "cube" # initialize the point cloud in a cube or on a sphere ("sphere") 33 | init_scale: [1.2, 1.2, 1.2] # initial scale of the point cloud, normalized by the coord_scale 34 | init_center: [0.0, 0.0, 0.0] # initial center of the point cloud, normalized by the coord_scale 35 | init_num: 3000 # initial number of points in the point cloud 36 | influ_init_val: 0.0 # initial influence score of each point 37 | add_type: "random" # add points by interpolate add_k points randomly (or "mean", ...) 
38 | add_k: 3 # number of points to interpolate when adding new points 39 | add_sample_type: "top-knn-std" # determine where to add new points by sparsity ("top-knn-std") or ... 40 | add_sample_k: 10 # number of points to consider when measuring sparsity or ... 41 | background: 42 | learnable: false # learn the background color or not 43 | init_color: [1.0, 1.0, 1.0] # initial background color, normalized by 255 44 | constant: 5.0 # constant background score 45 | point_feats: 46 | dim: 64 # dimension of the point features 47 | use_inv: true # use as a value feature in the attention layer 48 | use_ink: false # use as a key feature in the attention layer 49 | use_inq: false # use as a query feature in the attention layer 50 | exposure_control: 51 | use: false 52 | shading_code_dim: 128 # dimension of the shading codes 53 | shading_code_scale: 1.0 # scale the shading codes by this factor, use 8.0 for random generation 54 | shading_code_num_samples: 20 # number of samples to generate the shading codes 55 | shading_code_resample_iter: 10000 # number of iterations to resample the shading codes 56 | shading_code_resample_size: 200 # number of samples to resample the shading codes 57 | shading_code_resample_select_by: "psnr" # select the best samples by "psnr" or "lpips" 58 | mapping_mlp: 59 | num_layers: 8 60 | dim: 256 61 | act: "relu" 62 | last_act: "relu+1" 63 | use_wn: false 64 | out_dim: 64 65 | models: 66 | use_renderer: true # use the UNet or not, if not, predicted rgb is a fused output of value embedding MLP 67 | last_act: "none" # last activation function to normalize the predicted rgb 68 | normalize_topk_attn: true # normalize the top-k attention weights after softmax 69 | attn: 70 | k_type: 1 71 | q_type: 1 72 | v_type: 1 73 | d_model: 256 74 | score_act: "relu" 75 | embed: 76 | embed_type: 1 # 1: positional encoding with input itself, 2: positional encoding without input itself 77 | k_L: [6, 6, 6] # order of the positional encoding for each feature in k 78 | q_L: [6] # order of the positional encoding for each feature in q 79 | v_L: [6, 6] # order of the positional encoding for each feature in v 80 | pe_factor: 2.0 81 | pe_mult_factor: 1.0 82 | key: 83 | d_ff: 256 # dimension of the hidden layer in the embedding MLP 84 | d_ff_out: 256 # dimension of the output of the embedding MLP 85 | n_ff_layer: 5 # number of layers in the embedding MLP 86 | ff_act: "relu" 87 | ff_act_a: 1.0 88 | ff_act_b: 1.0 89 | ff_act_trainable: false 90 | ff_last_act: "none" # last activation function in the embedding MLP 91 | norm: "layernorm" 92 | dropout_ff: 0.0 93 | use_wn: false 94 | residual_ff: false 95 | skip_layers: [] 96 | half_layers: [] 97 | residual_layers: [] 98 | residual_dims: [] 99 | query: 100 | d_ff: 256 101 | d_ff_out: 256 102 | n_ff_layer: 5 103 | ff_act: "relu" 104 | ff_act_a: 1.0 105 | ff_act_b: 1.0 106 | ff_act_trainable: false 107 | ff_last_act: "none" 108 | norm: "layernorm" 109 | dropout_ff: 0.0 110 | use_wn: false 111 | residual_ff: false 112 | skip_layers: [] 113 | half_layers: [] 114 | residual_layers: [] 115 | residual_dims: [] 116 | value: 117 | d_ff: 256 118 | d_ff_out: 32 119 | n_ff_layer: 8 120 | ff_act: "relu" 121 | ff_act_a: 1.0 122 | ff_act_b: 1.0 123 | ff_act_trainable: false 124 | ff_last_act: "none" 125 | norm: "none" 126 | dropout_ff: 0.0 127 | use_wn: false 128 | residual_ff: false 129 | skip_layers: [] 130 | half_layers: [] 131 | residual_layers: [] 132 | residual_dims: [] 133 | renderer: # the UNet 134 | generator: 135 | type: "small-unet" 136 | small_unet: 137 | 
bilinear: false 138 | norm: "none" 139 | single: true 140 | last_act: "none" 141 | affine_layer: -1 142 | training: 143 | steps: 250000 # number of training steps 144 | prune_steps: 500 # interval to prune the point cloud 145 | prune_start: 10000 # start pruning after this step 146 | prune_stop: 150000 # stop pruning after this step 147 | prune_thresh: 0.0 # influence score threshold to prune the point cloud 148 | prune_thresh_list: [] # change the threshold at steps in prune_steps_list, if empty, then use prune_thresh 149 | prune_steps_list: [] 150 | prune_type: "<" # prune points with influence scores smaller than the threshold ("<") or larger than the threshold (">") 151 | add_steps: 1000 # interval to add new points 152 | add_start: 20000 # start adding new points after this step 153 | add_stop: 70000 # stop adding new points after this step 154 | add_num: 1000 # number of points to add each time 155 | add_num_list: [] 156 | add_steps_list: [] 157 | exclude_keys: [] # not loading these parameters when loading the model 158 | fix_keys: [ # fix these parameters when training the model 159 | # "points", 160 | # "pc_norms", 161 | # # "norm_mlp", 162 | # "attn", 163 | # "pc_feats", 164 | # "attn_mlp", 165 | # # "renderer", 166 | # # "bkg_feats", 167 | # "bkg_points", 168 | # "points_influ_scores" 169 | ] 170 | losses: # loss weights 171 | mse: 1.0 172 | lpips: 1.0e-2 173 | lpips_alex: 0.0 174 | lr: 175 | lr_factor: 1.0 # learning rate factor, multiply all the learning rates by this factor 176 | mapping_mlp: 177 | type: "none" 178 | base_lr: 1.0e-6 179 | factor: 1 180 | warmup: 0 181 | weight_decay: 0 182 | attn: 183 | type: "cosine-hlfperiod" 184 | base_lr: 3.0e-4 185 | factor: 1 186 | warmup: 10000 187 | weight_decay: 0 188 | points: 189 | type: "cosine" 190 | base_lr: 2.0e-3 191 | factor: 1 192 | warmup: 0 193 | weight_decay: 0 194 | bkg_feats: 195 | type: "none" 196 | base_lr: 0.0 197 | factor: 1 198 | warmup: 10000 199 | weight_decay: 0 200 | points_influ_scores: 201 | type: "cosine-hlfperiod" 202 | base_lr: 1.0e-3 203 | factor: 1 204 | warmup: 10000 205 | weight_decay: 0 206 | feats: 207 | type: "cosine-hlfperiod" 208 | base_lr: 1.0e-3 209 | factor: 1 210 | warmup: 10000 211 | weight_decay: 0 212 | generator: 213 | type: "cosine-hlfperiod" 214 | base_lr: 1.0e-4 215 | factor: 1 216 | warmup: 10000 217 | weight_decay: 0 218 | eval: 219 | dataset: 220 | name: "testset" 221 | mode: "test" 222 | extract_patch: false 223 | type: "synthetic" 224 | white_bg: true 225 | path: "./data/nerf_synthetic/lego/" 226 | factor: 1 227 | num_workers: 0 228 | num_slices: -1 229 | step: 5000 # evaluate the model every this number of steps 230 | img_idx: 50 # index of the image to evaluate 231 | max_height: 100 # maximum size for each loop when rendering a full image, if the image is too large, it will be rendered in multiple loops due to memory constraints 232 | max_width: 100 233 | save_fig: true # save the log images during training 234 | test: 235 | load_path: "" # path to load the model from for testing 236 | save_fig: true # save the log images during testing 237 | save_video: false # save the video during testing 238 | max_height: 100 239 | max_width: 100 240 | datasets: 241 | - name: "testset" 242 | mode: "test" 243 | extract_patch: false 244 | type: "synthetic" 245 | white_bg: true 246 | path: "./data/nerf_synthetic/lego/" 247 | factor: 1 248 | num_workers: 0 249 | num_slices: -1 250 | plots: # videos that visualize different components of the model 251 | pcrgb: true 252 | featattn: false 
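Note: the per-scene configs under `configs/` only list the fields they change; everything else falls back to the defaults in `configs/default.yml` above. The snippet below is a minimal sketch of how this merge can be reproduced, assuming a recursive dictionary update similar to the `update_dict` helper used in `train.py`; the `deep_update` function and the inspected keys are illustrative, not part of the repository.

```python
import copy
import yaml

def deep_update(base, override):
    # Recursively merge `override` into `base`; illustrative stand-in for the repo's `update_dict`.
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_update(base[key], value)
        else:
            base[key] = value

with open("configs/default.yml", "r") as f:
    default_config = yaml.safe_load(f)
with open("configs/nerfsyn/lego.yml", "r") as f:
    scene_config = yaml.safe_load(f)

train_config = copy.deepcopy(default_config)
deep_update(train_config, scene_config)  # scene values win, defaults fill the gaps

# Fields set by the scene config are overridden; the rest keep their default values.
print(train_config["dataset"]["path"])              # taken from the scene config
print(train_config["geoms"]["points"]["select_k"])  # inherited from default.yml
```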
-------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | # credit: https://github.com/princeton-vl/SNP 2 | """ Parts of the U-Net model """ 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import autocast 8 | from .utils import activation_func 9 | 10 | 11 | class SingleConv(nn.Module): 12 | """(convolution => [BN] => ReLU) * 2""" 13 | 14 | def __init__(self, in_channels, out_channels, mid_channels=None, norm='none'): 15 | super().__init__() 16 | if not mid_channels: 17 | mid_channels = out_channels 18 | if norm == 'instance': 19 | self.double_conv = nn.Sequential( 20 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 21 | nn.InstanceNorm2d(mid_channels), 22 | nn.ReLU(inplace=True), 23 | ) 24 | elif norm == 'batch': 25 | self.double_conv = nn.Sequential( 26 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 27 | nn.BatchNorm2d(mid_channels), 28 | nn.ReLU(inplace=True), 29 | ) 30 | elif norm == 'none': 31 | self.double_conv = nn.Sequential( 32 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 33 | nn.ReLU(inplace=True), 34 | ) 35 | else: 36 | raise NotImplementedError 37 | 38 | def forward(self, x): 39 | return self.double_conv(x) 40 | 41 | 42 | class DoubleConv(nn.Module): 43 | """(convolution => [BN] => ReLU) * 2""" 44 | 45 | def __init__(self, in_channels, out_channels, mid_channels=None, norm='none'): 46 | super().__init__() 47 | if not mid_channels: 48 | mid_channels = out_channels 49 | if norm == 'instance': 50 | self.double_conv = nn.Sequential( 51 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 52 | nn.InstanceNorm2d(mid_channels), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(mid_channels, out_channels, 55 | kernel_size=3, padding=1), 56 | nn.InstanceNorm2d(out_channels), 57 | nn.ReLU(inplace=True) 58 | ) 59 | elif norm == 'batch': 60 | self.double_conv = nn.Sequential( 61 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 62 | nn.BatchNorm2d(mid_channels), 63 | nn.ReLU(inplace=True), 64 | nn.Conv2d(mid_channels, out_channels, 65 | kernel_size=3, padding=1), 66 | nn.BatchNorm2d(out_channels), 67 | nn.ReLU(inplace=True) 68 | ) 69 | elif norm == 'none': 70 | self.double_conv = nn.Sequential( 71 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 72 | nn.ReLU(inplace=True), 73 | nn.Conv2d(mid_channels, out_channels, 74 | kernel_size=3, padding=1), 75 | nn.ReLU(inplace=True) 76 | ) 77 | 78 | def forward(self, x): 79 | return self.double_conv(x) 80 | 81 | 82 | class Down(nn.Module): 83 | """Downscaling with maxpool then double conv""" 84 | 85 | def __init__(self, in_channels, out_channels, single=False, norm='none'): 86 | super().__init__() 87 | # print('down norm:', norm) 88 | 89 | if single: 90 | self.maxpool_conv = nn.Sequential( 91 | nn.MaxPool2d(2), 92 | SingleConv(in_channels, out_channels, norm=norm) 93 | ) 94 | else: 95 | self.maxpool_conv = nn.Sequential( 96 | nn.MaxPool2d(2), 97 | DoubleConv(in_channels, out_channels, norm=norm) 98 | ) 99 | 100 | def forward(self, x): 101 | return self.maxpool_conv(x) 102 | 103 | 104 | class Up(nn.Module): 105 | """Upscaling then double conv""" 106 | 107 | def __init__(self, in_channels, out_channels, bilinear=True, single=False, norm='none'): 108 | super().__init__() 109 | 110 | # if bilinear, use the normal convolutions to reduce the number of channels 111 | if bilinear: 112 | 
self.up = nn.Upsample( 113 | scale_factor=2, mode='bilinear', align_corners=True) 114 | if single: 115 | self.conv = SingleConv( 116 | in_channels, out_channels, in_channels // 2, norm=norm) 117 | else: 118 | self.conv = DoubleConv( 119 | in_channels, out_channels, in_channels // 2, norm=norm) 120 | else: 121 | self.up = nn.ConvTranspose2d( 122 | in_channels, in_channels // 2, kernel_size=2, stride=2) 123 | if single: 124 | self.conv = SingleConv(in_channels, out_channels, norm=norm) 125 | else: 126 | self.conv = DoubleConv(in_channels, out_channels, norm=norm) 127 | 128 | def forward(self, x1, x2): 129 | x1 = self.up(x1) 130 | # input is CHW 131 | diffY = x2.size()[2] - x1.size()[2] 132 | diffX = x2.size()[3] - x1.size()[3] 133 | 134 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 135 | diffY // 2, diffY - diffY // 2]) 136 | # if you have padding issues, see 137 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 138 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 139 | x = torch.cat([x2, x1], dim=1) 140 | return self.conv(x) 141 | 142 | 143 | class UpSample(nn.Module): 144 | """Upscaling then double conv""" 145 | 146 | def __init__(self, in_channels, out_channels, bilinear=True, single=False, norm='none'): 147 | super().__init__() 148 | # print('up norm:', norm) 149 | 150 | # if bilinear, use the normal convolutions to reduce the number of channels 151 | if bilinear: 152 | self.up = nn.Upsample( 153 | scale_factor=2, mode='bilinear', align_corners=True) 154 | if single: 155 | self.conv = SingleConv( 156 | in_channels, out_channels, in_channels // 2, norm=norm) 157 | else: 158 | self.conv = DoubleConv( 159 | in_channels, out_channels, in_channels // 2, norm=norm) 160 | else: 161 | self.up = nn.ConvTranspose2d( 162 | in_channels, in_channels, kernel_size=2, stride=2) 163 | if single: 164 | self.conv = SingleConv(in_channels, out_channels, norm=norm) 165 | else: 166 | self.conv = DoubleConv(in_channels, out_channels, norm=norm) 167 | 168 | def forward(self, x1): 169 | x1 = self.up(x1) 170 | return self.conv(x1) 171 | 172 | 173 | class OutConv(nn.Module): 174 | def __init__(self, in_channels, out_channels): 175 | super(OutConv, self).__init__() 176 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 177 | 178 | def forward(self, x): 179 | return self.conv(x) 180 | 181 | 182 | class SmallUNet(nn.Module): 183 | def __init__(self, n_channels, n_classes, bilinear=False, single=True, norm='none', render_scale=1, last_act='none', 184 | use_amp=False, amp_dtype=torch.float16, affine_layer=-1): 185 | super(SmallUNet, self).__init__() 186 | self.n_channels = n_channels 187 | self.n_classes = n_classes 188 | self.bilinear = bilinear 189 | self.use_amp = use_amp 190 | self.amp_dtype = amp_dtype 191 | self.affine_layer = affine_layer 192 | 193 | assert (render_scale == 1 or render_scale == 2) 194 | self.render_scale = render_scale 195 | 196 | self.inc = SingleConv(n_channels, 128, norm=norm) 197 | self.down1 = Down(128, 256, single=single, norm=norm) 198 | self.down2 = Down(256, 512, single=single, norm=norm) 199 | self.up1 = Up(512, 256, bilinear, single=single, norm=norm) 200 | self.up2 = Up(256, 128, bilinear, single=single, norm=norm) 201 | 202 | if render_scale == 2: 203 | self.up3 = UpSample(128, 128, bilinear, single=False, norm=norm) 204 | 205 | self.outc = OutConv(128, n_classes) 206 | self.last_act = activation_func(last_act) 207 | 208 | def forward(self, x, 
log=False, gamma=None, beta=None): 209 | if self.affine_layer >= 0: 210 | assert gamma is not None and beta is not None 211 | 212 | with autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp): 213 | if self.affine_layer == 0: 214 | B, C, H, W = x.shape 215 | assert gamma.shape == (C,) and beta.shape == (C,) 216 | # print(gamma.mean(), beta.mean()) 217 | x = x * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1) 218 | 219 | x1 = self.inc(x) 220 | if self.affine_layer == 1: 221 | B, C, H, W = x1.shape 222 | assert gamma.shape == (C,) and beta.shape == (C,) 223 | x1 = x1 * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1) 224 | 225 | x2 = self.down1(x1) 226 | if self.affine_layer == 2: 227 | B, C, H, W = x2.shape 228 | assert gamma.shape == (C,) and beta.shape == (C,) 229 | x2 = x2 * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1) 230 | 231 | x3 = self.down2(x2) 232 | if self.affine_layer == 3: 233 | B, C, H, W = x3.shape 234 | assert gamma.shape == (C,) and beta.shape == (C,) 235 | x3 = x3 * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1) 236 | 237 | x = self.up1(x3, x2) 238 | if self.affine_layer == 4: 239 | B, C, H, W = x.shape 240 | assert gamma.shape == (C,) and beta.shape == (C,) 241 | x = x * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1) 242 | 243 | x = self.up2(x, x1) 244 | if self.affine_layer == 5: 245 | B, C, H, W = x.shape 246 | assert gamma.shape == (C,) and beta.shape == (C,) 247 | x = x * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1) 248 | 249 | if self.render_scale == 2: 250 | x = self.up3(x) 251 | logits = self.outc(x) 252 | 253 | logits = self.last_act(logits) 254 | 255 | if log: 256 | print(x1.dtype, x2.dtype, x3.dtype, x.dtype, logits.dtype) 257 | 258 | return logits 259 | -------------------------------------------------------------------------------- /models/attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import autocast 4 | import math 5 | from .mlp import MLP 6 | from .utils import PoseEnc, activation_func 7 | 8 | 9 | def get_proximity_attention_layer(args, v_extra_dim=0, k_extra_dim=0, q_extra_dim=0, eps=1e-6, use_amp=False, amp_dtype=torch.float16): 10 | k_dim_map = { 11 | 1: [3, 3, 3], 12 | } 13 | k_dim = k_dim_map[args.k_type] 14 | 15 | q_dim_map = { 16 | 1: [3], 17 | } 18 | q_dim = q_dim_map[args.q_type] 19 | 20 | v_dim_map = { 21 | 1: [3, 3], 22 | } 23 | v_dim = v_dim_map[args.v_type] 24 | 25 | return ProximityAttention(d_k=k_dim, d_q=q_dim, d_v=v_dim, d_model=args.d_model, embed_args=args.embed, 26 | score_act=args.score_act, d_ko=k_extra_dim, d_qo=q_extra_dim, 27 | d_vo=v_extra_dim, eps=eps, use_amp=use_amp, amp_dtype=amp_dtype) 28 | 29 | 30 | class LayerNorm(nn.Module): 31 | "Construct a layernorm module" 32 | 33 | def __init__(self, features, eps=1e-6): 34 | super(LayerNorm, self).__init__() 35 | self.a_2 = nn.Parameter(torch.ones(features)) 36 | self.b_2 = nn.Parameter(torch.zeros(features)) 37 | self.eps = eps 38 | 39 | def forward(self, x): 40 | mean = x.mean(-1, keepdim=True) 41 | std = x.std(-1, keepdim=True) 42 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 43 | 44 | 45 | def attention(query, key, kernel_type): 46 | """ 47 | Compute Attention Scores 48 | query: [batch_size, 1, query_len, d_kq] or [batch_size, query_len, d_kq] 49 | key: [batch_size, 1, seq_len, d_kq] or [batch_size, seq_len, d_kq] 50 | """ 51 | d_kq = query.size(-1) 52 | 53 | if kernel_type == "scaled-dot": 54 | scores = 
torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_kq) 55 | elif kernel_type == "-scaled-dot": 56 | scores = -torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_kq) 57 | elif kernel_type == "dot": 58 | scores = torch.matmul(query, key.transpose(-2, -1)) 59 | elif kernel_type == "-dot": 60 | scores = -torch.matmul(query, key.transpose(-2, -1)) 61 | elif kernel_type == "l1-dist": 62 | scores = torch.norm(query.unsqueeze(-2) - 63 | key.unsqueeze(-3), p=1, dim=-1) 64 | elif kernel_type == "-l1-dist": 65 | scores = -torch.norm(query.unsqueeze(-2) - 66 | key.unsqueeze(-3), p=1, dim=-1) 67 | elif kernel_type == "l2-dist": 68 | scores = torch.norm(query.unsqueeze(-2) - 69 | key.unsqueeze(-3), p=2, dim=-1) 70 | elif kernel_type == "-l2-dist": 71 | scores = -torch.norm(query.unsqueeze(-2) - 72 | key.unsqueeze(-3), p=2, dim=-1) 73 | elif kernel_type == "scaled-l2-dist": 74 | scores = torch.norm(query.unsqueeze(-2) - 75 | key.unsqueeze(-3), p=2, dim=-1) / math.sqrt(d_kq) 76 | elif kernel_type == "-scaled-l2-dist": 77 | scores = -torch.norm(query.unsqueeze(-2) - 78 | key.unsqueeze(-3), p=2, dim=-1) / math.sqrt(d_kq) 79 | elif kernel_type == "cosine": 80 | scores = torch.matmul(query, key.transpose(-2, -1)) / ( 81 | torch.norm(query, dim=-1, keepdim=True) 82 | * torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1) 83 | ) 84 | else: 85 | raise ValueError("Unknown kernel type: {}".format(kernel_type)) 86 | 87 | return scores 88 | 89 | 90 | class FeedForward(nn.Module): 91 | "Implements FFN module." 92 | 93 | def __init__(self, d_input, d_output, d_ff, n_layer=2, act="relu", last_act="none", dropout=0.1, norm="layernorm", 94 | residual=False, act_a=1.0, act_b=1.0, act_trainable=False, use_wn=False, eps=1e-6, skip_layers=[], 95 | half_layers=[]): 96 | super(FeedForward, self).__init__() 97 | self.eps = eps 98 | self.d_input = d_input 99 | self.d_output = d_output 100 | if norm == "layernorm": 101 | self.innorm = LayerNorm(d_input, eps) 102 | self.outnorm = LayerNorm(d_output, eps) 103 | elif norm == "none": 104 | self.innorm = nn.Identity() 105 | self.outnorm = nn.Identity() 106 | else: 107 | raise ValueError("Invalid attention norm type") 108 | self.dropout = nn.Dropout(dropout) 109 | self.mlp = MLP(d_input, n_layer, d_ff, d_output, act_type=act, last_act_type=last_act, use_wn=use_wn, 110 | a=act_a, b=act_b, trainable=act_trainable, skip_layers=skip_layers, half_layers=half_layers) 111 | self.residual = residual 112 | 113 | def forward(self, x): 114 | if self.residual and x.shape[-1] == self.d_output: 115 | return self.outnorm(x + self.dropout(self.mlp(self.innorm(x)))) 116 | else: 117 | return self.outnorm(self.dropout(self.mlp(self.innorm(x)))) 118 | 119 | 120 | class Embeddings(nn.Module): 121 | def __init__(self, d_k, d_q, d_v, d_model, args, d_ko=0, d_qo=0, d_vo=0, eps=1e-6): 122 | super(Embeddings, self).__init__() 123 | self.d_k = d_k 124 | self.d_q = d_q 125 | self.d_v = d_v 126 | self.args = args 127 | self.embed_type = args.embed_type 128 | self.d_model = d_model 129 | self.d_ko = d_ko 130 | self.d_qo = d_qo 131 | self.d_vo = d_vo 132 | self.eps = eps 133 | 134 | self.posenc = PoseEnc(args.pe_factor, args.pe_mult_factor) 135 | 136 | if self.embed_type == 1: 137 | # Positional Encoding with itself 138 | d_k = sum([d + d * 2 * args.k_L[i] for i, d in enumerate(d_k)]) + d_ko 139 | d_q = sum([d + d * 2 * args.q_L[i] for i, d in enumerate(d_q)]) + d_qo 140 | d_v = sum([d + d * 2 * args.v_L[i] for i, d in enumerate(d_v)]) + d_vo 141 | 142 | elif self.embed_type == 2: 143 | # Positional 
Encoding without itself 144 | d_k = sum([d * 2 * args.k_L[i] for i, d in enumerate(d_k)]) + d_ko 145 | d_q = sum([d * 2 * args.q_L[i] for i, d in enumerate(d_q)]) + d_qo 146 | d_v = sum([d * 2 * args.v_L[i] for i, d in enumerate(d_v)]) + d_vo 147 | 148 | else: 149 | raise ValueError( 150 | 'Unknown embedding type: {}'.format(self.embed_type)) 151 | 152 | self.embed_k = FeedForward(d_k, args.key.d_ff_out, args.key.d_ff, args.key.n_ff_layer, args.key.ff_act, 153 | args.key.ff_last_act, args.key.dropout_ff, args.key.norm, args.key.residual_ff, 154 | args.key.ff_act_a, args.key.ff_act_b, args.key.ff_act_trainable, args.key.use_wn, eps, 155 | args.key.skip_layers, args.key.half_layers) 156 | self.embed_q = FeedForward(d_q, args.query.d_ff_out, args.query.d_ff, args.query.n_ff_layer, args.query.ff_act, 157 | args.query.ff_last_act, args.query.dropout_ff, args.query.norm, args.query.residual_ff, 158 | args.query.ff_act_a, args.query.ff_act_b, args.query.ff_act_trainable, args.query.use_wn, eps, 159 | args.query.skip_layers, args.query.half_layers) 160 | self.embed_v = FeedForward(d_v, args.value.d_ff_out, args.value.d_ff, args.value.n_ff_layer, args.value.ff_act, 161 | args.value.ff_last_act, args.value.dropout_ff, args.value.norm, args.value.residual_ff, 162 | args.value.ff_act_a, args.value.ff_act_b, args.value.ff_act_trainable, args.value.use_wn, eps, 163 | args.value.skip_layers, args.value.half_layers) 164 | 165 | def forward(self, k_features, q_features, v_features, k_other=None, q_other=None, v_other=None): 166 | """ 167 | k_features: [(B, H, W, N, Dk_i)] 168 | q_features: [(B, H, W, 1, Dq_i)] 169 | v_features: [(B, H, W, N, Dv_i)] 170 | """ 171 | 172 | if self.embed_type == 1: 173 | pe_k_features = [self.posenc(f, self.args.k_L[i]) for i, f in enumerate(k_features)] 174 | pe_q_features = [self.posenc(f, self.args.q_L[i]) for i, f in enumerate(q_features)] 175 | pe_v_features = [self.posenc(f, self.args.v_L[i]) for i, f in enumerate(v_features)] 176 | 177 | elif self.embed_type == 2: 178 | pe_k_features = [self.posenc(f, self.args.k_L[i], without_self=True) for i, f in enumerate(k_features)] 179 | pe_q_features = [self.posenc(f, self.args.q_L[i], without_self=True) for i, f in enumerate(q_features)] 180 | pe_v_features = [self.posenc(f, self.args.v_L[i], without_self=True) for i, f in enumerate(v_features)] 181 | 182 | else: 183 | raise ValueError('Unknown embedding type: {}'.format(self.embed_type)) 184 | 185 | if self.d_ko > 0: pe_k_features = pe_k_features + k_other 186 | if self.d_qo > 0: pe_q_features = pe_q_features + q_other 187 | if self.d_vo > 0: pe_v_features = pe_v_features + v_other 188 | 189 | k = torch.cat(pe_k_features, dim=-1).flatten(0, 2) 190 | q = torch.cat(pe_q_features, dim=-1).flatten(0, 2) 191 | v = torch.cat(pe_v_features, dim=-1).flatten(0, 2) 192 | 193 | k = self.embed_k(k) 194 | q = self.embed_q(q) 195 | v = self.embed_v(v) 196 | 197 | return k, q, v 198 | 199 | 200 | class AttentionLayer(nn.Module): 201 | def __init__(self, embed_args, d_model, score_act_type): 202 | super(AttentionLayer, self).__init__() 203 | self.d_model = d_model 204 | self.w_k = nn.Linear(embed_args.key.d_ff_out, d_model) 205 | self.w_q = nn.Linear(embed_args.query.d_ff_out, d_model) 206 | 207 | nn.init.xavier_uniform_(self.w_k.weight) 208 | nn.init.xavier_uniform_(self.w_q.weight) 209 | 210 | self.score_act = activation_func(score_act_type) 211 | 212 | def forward(self, key, query, value): 213 | nbatches, nseqv, _ = value.shape 214 | _, nseqk, _ = key.shape 215 | assert nseqv == nseqk 216 
| 217 | key = self.w_k(key) 218 | query = self.w_q(query) 219 | 220 | key = key.view(nbatches, -1, 1, self.d_model).transpose(1, 2) 221 | query = query.view(nbatches, -1, 1, self.d_model).transpose(1, 2) 222 | 223 | # [nbatches, nhead, nseq, nseq] 224 | scores = attention(query, key, "scaled-dot") 225 | scores = self.score_act(scores) 226 | return scores 227 | 228 | 229 | class ProximityAttention(nn.Module): 230 | def __init__(self, d_k, d_q, d_v, d_model, embed_args, score_act="none", 231 | d_ko=0, d_qo=0, d_vo=0, eps=1e-6, use_amp=False, amp_dtype=torch.float16): 232 | super(ProximityAttention, self).__init__() 233 | self.eps = eps 234 | self.use_amp = use_amp 235 | self.amp_dtype = amp_dtype 236 | 237 | self.embed = Embeddings(d_k, d_q, d_v, d_model, embed_args, d_ko, d_qo, d_vo, eps) 238 | 239 | self.attention_layer = AttentionLayer(embed_args, d_model, score_act) 240 | 241 | def forward(self, k_features, q_features, v_features, k_other=None, q_other=None, v_other=None, step=-1): 242 | """ 243 | k_features: [(H, W, N, Dk_i)] 244 | q_features: [(H, W, 1, Dq_i)] or [(H, W, N, Dq_i)] 245 | v_features: [(H, W, N, Dv_i)] 246 | """ 247 | 248 | with autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp): 249 | k, q, v = self.embed(k_features, q_features, v_features, k_other, q_other, v_other) 250 | scores = self.attention_layer(k, q, v) 251 | 252 | return k, q, v, scores 253 | -------------------------------------------------------------------------------- /exposure_control_finetune.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import os 6 | import shutil 7 | import zipfile 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import random 11 | import copy 12 | import bisect 13 | import time 14 | import sys 15 | import io 16 | import imageio 17 | from tqdm import tqdm 18 | from PIL import Image 19 | from utils import * 20 | from dataset import get_dataset, get_loader 21 | from models import get_model, get_loss 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser(description="PAPR") 26 | parser.add_argument('--opt', type=str, default="", help='Option file path') 27 | parser.add_argument('--resume', type=int, default=0, help='Resume training') 28 | return parser.parse_args() 29 | 30 | 31 | def eval_step(steps, model, device, dataset, eval_dataset, batch, loss_fn, train_out, args, train_losses, eval_losses, eval_psnrs, pt_lrs, attn_lrs): 32 | step = steps[-1] 33 | train_img_idx, _, train_patch, _, _ = batch 34 | train_img, train_rayd, train_rayo = dataset.get_full_img(train_img_idx[0]) 35 | img, rayd, rayo = eval_dataset.get_full_img(args.eval.img_idx) 36 | c2w = dataset.get_c2w(args.eval.img_idx) 37 | 38 | print("Before resample shading codes, eval shading codes mean: ", model.eval_shading_codes[args.eval.img_idx].mean()) 39 | resample_shading_codes(model.eval_shading_codes, args, model, eval_dataset, args.eval.img_idx, loss_fn, step, full_img=True) 40 | print("After resample shading codes, eval shading codes mean: ", model.eval_shading_codes[args.eval.img_idx].mean()) 41 | 42 | N, H, W, _ = rayd.shape 43 | num_pts, _ = model.points.shape 44 | 45 | rayo = rayo.to(device) 46 | rayd = rayd.to(device) 47 | img = img.to(device) 48 | c2w = c2w.to(device) 49 | 50 | topk = min([num_pts, model.select_k]) 51 | 52 | selected_points = torch.zeros(1, H, W, topk, 3) 53 | 54 | bkg_seq_len_attn = 0 55 | feat_dim = 
args.models.attn.embed.value.d_ff_out 56 | if model.bkg_feats is not None: 57 | bkg_seq_len_attn = model.bkg_feats.shape[0] 58 | feature_map = torch.zeros(N, H, W, 1, feat_dim).to(device) 59 | attn = torch.zeros(N, H, W, topk + bkg_seq_len_attn, 1).to(device) 60 | 61 | with torch.no_grad(): 62 | cur_shading_code = model.eval_shading_codes[args.eval.img_idx] 63 | cur_affine = model.mapping_mlp(cur_shading_code) 64 | cur_affine_dim = cur_affine.shape[-1] 65 | cur_gamma, cur_beta = cur_affine[:cur_affine_dim // 2], cur_affine[cur_affine_dim // 2:] 66 | 67 | for height_start in range(0, H, args.eval.max_height): 68 | for width_start in range(0, W, args.eval.max_width): 69 | height_end = min(height_start + args.eval.max_height, H) 70 | width_end = min(width_start + args.eval.max_width, W) 71 | 72 | feature_map[:, height_start:height_end, width_start:width_end, :, :], \ 73 | attn[:, height_start:height_end, width_start:width_end, :, :] = model.evaluate(rayo, rayd[:, height_start:height_end, width_start:width_end], c2w, step=step) 74 | 75 | selected_points[:, height_start:height_end, width_start:width_end, :, :] = model.selected_points 76 | 77 | foreground_rgb = model.renderer(feature_map.squeeze(-2).permute(0, 3, 1, 2), gamma=cur_gamma, beta=cur_beta).permute(0, 2, 3, 1).unsqueeze(-2) # (N, H, W, 1, 3) 78 | 79 | if model.bkg_feats is not None: 80 | bkg_attn = attn[..., topk:, :] 81 | if args.models.normalize_topk_attn: 82 | rgb = foreground_rgb * (1 - bkg_attn) + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn 83 | else: 84 | rgb = foreground_rgb + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn 85 | rgb = rgb.squeeze(-2) 86 | else: 87 | rgb = foreground_rgb.squeeze(-2) 88 | 89 | rgb = model.last_act(rgb) 90 | rgb = torch.clamp(rgb, 0, 1) 91 | 92 | eval_loss = loss_fn(rgb, img) 93 | eval_psnr = -10. * np.log(((rgb - img)**2).mean().item()) / np.log(10.) 
94 | 95 | model.clear_grad() 96 | 97 | eval_losses.append(eval_loss.item()) 98 | eval_psnrs.append(eval_psnr.item()) 99 | 100 | print("Eval step:", step, "train_loss:", train_losses[-1], "eval_loss:", eval_losses[-1], "eval_psnr:", eval_psnrs[-1]) 101 | 102 | log_dir = os.path.join(args.save_dir, args.index) 103 | os.makedirs(log_dir, exist_ok=True) 104 | if args.eval.save_fig: 105 | os.makedirs(os.path.join(log_dir, "train_main_plots"), exist_ok=True) 106 | os.makedirs(os.path.join(log_dir, "train_pcd_plots"), exist_ok=True) 107 | 108 | coord_scale = args.dataset.coord_scale 109 | pt_plot_scale = 1.0 * coord_scale 110 | if "Barn" in args.dataset.path: 111 | pt_plot_scale *= 1.8 112 | if "Family" in args.dataset.path: 113 | pt_plot_scale *= 0.5 114 | 115 | # calculate depth, weighted sum the distances from top K points to image plane 116 | od = -rayo 117 | D = torch.sum(od * rayo) 118 | dists = torch.abs(torch.sum(selected_points.to(od.device) * od, -1) - D) / torch.norm(od) 119 | if model.bkg_feats is not None: 120 | dists = torch.cat([dists, torch.ones(N, H, W, model.bkg_feats.shape[0]).to(dists.device) * 0], dim=-1) 121 | cur_depth = (torch.sum(attn.squeeze(-1).to(od.device) * dists, dim=-1)).detach().cpu() 122 | 123 | train_tgt_rgb = train_img.squeeze().cpu().numpy().astype(np.float32) 124 | train_tgt_patch = train_patch[0].cpu().numpy().astype(np.float32) 125 | train_pred_patch = train_out[0] 126 | test_tgt_rgb = img.squeeze().cpu().numpy().astype(np.float32) 127 | test_pred_rgb = rgb.squeeze().detach().cpu().numpy().astype(np.float32) 128 | points_np = model.points.detach().cpu().numpy() 129 | depth = cur_depth.squeeze().numpy().astype(np.float32) 130 | points_influ_scores_np = None 131 | if model.points_influ_scores is not None: 132 | points_influ_scores_np = model.points_influ_scores.squeeze().detach().cpu().numpy() 133 | 134 | # main plot 135 | main_plot = get_training_main_plot(args.index, steps, train_tgt_rgb, train_tgt_patch, train_pred_patch, test_tgt_rgb, test_pred_rgb, train_losses, 136 | eval_losses, points_np, pt_plot_scale, depth, pt_lrs, attn_lrs, eval_psnrs, points_influ_scores_np) 137 | save_name = os.path.join(log_dir, "train_main_plots", "%s_iter_%d.png" % (args.index, step)) 138 | main_plot.save(save_name) 139 | 140 | # point cloud plot 141 | ro = train_rayo.squeeze().detach().cpu().numpy() 142 | rd = train_rayd.squeeze().detach().cpu().numpy() 143 | 144 | pcd_plot = get_training_pcd_plot(args.index, steps[-1], ro, rd, points_np, args.dataset.coord_scale, pt_plot_scale, points_influ_scores_np) 145 | save_name = os.path.join(log_dir, "train_pcd_plots", "%s_iter_%d.png" % (args.index, step)) 146 | pcd_plot.save(save_name) 147 | 148 | model.save(step, log_dir) 149 | if step % 50000 == 0: 150 | torch.save(model.state_dict(), os.path.join(log_dir, "model_%d.pth" % step)) 151 | 152 | torch.save(torch.tensor(train_losses), os.path.join(log_dir, "train_losses.pth")) 153 | torch.save(torch.tensor(eval_losses), os.path.join(log_dir, "eval_losses.pth")) 154 | torch.save(torch.tensor(eval_psnrs), os.path.join(log_dir, "eval_psnrs.pth")) 155 | 156 | return 0 157 | 158 | 159 | def train_step(step, model, device, dataset, batch, loss_fn, args): 160 | img_idx, _, tgt, rayd, rayo = batch 161 | c2w = dataset.get_c2w(img_idx[0]) 162 | 163 | rayo = rayo.to(device) 164 | rayd = rayd.to(device) 165 | tgt = tgt.to(device) 166 | c2w = c2w.to(device) 167 | 168 | shading_code = model.train_shading_codes[img_idx[0]] 169 | 170 | model.clear_grad() 171 | out = model(rayo, rayd, c2w, step, 
shading_code=shading_code) 172 | out = model.last_act(out) 173 | loss = loss_fn(out, tgt) 174 | model.scaler.scale(loss).backward() 175 | model.step(step) 176 | if args.scaler_min_scale > 0 and model.scaler.get_scale() < args.scaler_min_scale: 177 | model.scaler.update(args.scaler_min_scale) 178 | else: 179 | model.scaler.update() 180 | 181 | return loss.item(), out.detach().cpu().numpy() 182 | 183 | 184 | def train_and_eval(start_step, model, device, dataset, eval_dataset, sample_dataset, losses, args): 185 | trainloader = get_loader(dataset, args.dataset, mode="train") 186 | 187 | loss_fn = get_loss(args.training.losses) 188 | loss_fn = loss_fn.to(device) 189 | 190 | log_dir = os.path.join(args.save_dir, args.index) 191 | os.makedirs(os.path.join(log_dir, "test"), exist_ok=True) 192 | log_dir = os.path.join(log_dir, "test") 193 | 194 | steps = [] 195 | train_losses, eval_losses, eval_psnrs = losses 196 | pt_lrs = [] 197 | attn_lrs = [] 198 | 199 | avg_train_loss = 0. 200 | step = start_step 201 | eval_step_cnt = start_step 202 | pc_frames = [] 203 | 204 | model.train_shading_codes = nn.Parameter(torch.randn(len(dataset), args.exposure_control.shading_code_dim, device=device) * args.exposure_control.shading_code_scale, requires_grad=False) 205 | model.eval_shading_codes = nn.Parameter(torch.randn(len(eval_dataset), args.exposure_control.shading_code_dim, device=device) * args.exposure_control.shading_code_scale, requires_grad=False) 206 | print("!!!!! train_shading_codes:", model.train_shading_codes.shape, model.train_shading_codes.min(), model.train_shading_codes.max()) 207 | print("!!!!! eval_shading_codes:", model.eval_shading_codes.shape, model.eval_shading_codes.min(), model.eval_shading_codes.max()) 208 | 209 | print("Start step:", start_step, "Total steps:", args.training.steps) 210 | start_time = time.time() 211 | while step < args.training.steps: 212 | for _, batch in enumerate(trainloader): 213 | if step % args.exposure_control.shading_code_resample_iter == 0: # Resample shading codes 214 | print("Resampling shading codes") 215 | print("Before resampling:", model.train_shading_codes.shape, model.train_shading_codes.min(), model.train_shading_codes.max()) 216 | for img_idx in tqdm(range(len(sample_dataset))): 217 | resample_shading_codes(model.train_shading_codes, args, model, sample_dataset, img_idx, loss_fn, step) 218 | print("After resampling:", model.train_shading_codes.shape, model.train_shading_codes.min(), model.train_shading_codes.max()) 219 | 220 | loss, out = train_step(step, model, device, dataset, batch, loss_fn, args) 221 | avg_train_loss += loss 222 | step += 1 223 | eval_step_cnt += 1 224 | 225 | if step % 200 == 0: 226 | time_used = time.time() - start_time 227 | print("Train step:", step, "loss:", loss, "attn_lr:", model.attn_lr, "pts_lr:", model.pts_lr, "scale:", model.scaler.get_scale(), f"time: {time_used:.2f}s") 228 | print(model.mapping_mlp.model.model[7].weight[0, :5]) 229 | start_time = time.time() 230 | 231 | if (step % args.eval.step == 0) or (step % 500 == 0 and step < 10000): 232 | train_losses.append(avg_train_loss / eval_step_cnt) 233 | pt_lrs.append(model.pts_lr) 234 | attn_lrs.append(model.attn_lr) 235 | steps.append(step) 236 | eval_step(steps, model, device, dataset, eval_dataset, batch, loss_fn, out, args, train_losses, eval_losses, eval_psnrs, pt_lrs, attn_lrs) 237 | avg_train_loss = 0. 
238 | eval_step_cnt = 0 239 | 240 | if ((step - 1) % 200 == 0) and args.eval.save_fig: 241 | coord_scale = args.dataset.coord_scale 242 | pt_plot_scale = 0.8 * coord_scale 243 | if "Barn" in args.dataset.path: 244 | pt_plot_scale *= 1.5 245 | if "Family" in args.dataset.path: 246 | pt_plot_scale *= 0.5 247 | 248 | pc_dir = os.path.join(log_dir, "point_clouds") 249 | os.makedirs(pc_dir, exist_ok=True) 250 | 251 | points_np = model.points.detach().cpu().numpy() 252 | points_influ_scores_np = None 253 | if model.points_influ_scores is not None: 254 | points_influ_scores_np = model.points_influ_scores.squeeze().detach().cpu().numpy() 255 | pcd_plot = get_training_pcd_single_plot(step, points_np, pt_plot_scale, points_influ_scores_np) 256 | pc_frames.append(pcd_plot) 257 | 258 | if step == 1: 259 | pcd_plot.save(os.path.join(pc_dir, "init_pcd.png")) 260 | 261 | if step >= args.training.steps: 262 | break 263 | 264 | if args.eval.save_fig and pc_frames != []: 265 | f = os.path.join(log_dir, f"{args.index}-pc.mp4") 266 | imageio.mimwrite(f, pc_frames, fps=30, quality=10) 267 | 268 | print("Training finished!") 269 | 270 | 271 | def main(args, eval_args, sample_args, resume): 272 | log_dir = os.path.join(args.save_dir, args.index) 273 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 274 | 275 | model = get_model(args, device) 276 | dataset = get_dataset(args.dataset, mode="train") 277 | sample_dataset = get_dataset(sample_args.dataset, mode="train") 278 | eval_dataset = get_dataset(eval_args.dataset, mode="test") 279 | model = model.to(device) 280 | 281 | # if torch.__version__ >= "2.0": 282 | # model = torch.compile(model) 283 | 284 | start_step = 0 285 | losses = [[], [], []] 286 | if resume > 0: 287 | start_step = model.load(log_dir) 288 | 289 | train_losses = torch.load(os.path.join(log_dir, "train_losses.pth")).tolist() 290 | eval_losses = torch.load(os.path.join(log_dir, "eval_losses.pth")).tolist() 291 | eval_psnrs = torch.load(os.path.join(log_dir, "eval_psnrs.pth")).tolist() 292 | losses = [train_losses, eval_losses, eval_psnrs] 293 | 294 | print("!!!!! Resume from step %s" % start_step) 295 | elif args.load_path: 296 | try: 297 | model_state_dict = torch.load(args.load_path) 298 | for step, state_dict in model_state_dict.items(): 299 | resume_step = int(step) 300 | model.load_my_state_dict(state_dict) 301 | except: 302 | model_state_dict = torch.load(os.path.join(args.save_dir, args.load_path, "model.pth")) 303 | for step, state_dict in model_state_dict.items(): 304 | resume_step = step 305 | model.load_my_state_dict(state_dict) 306 | print("!!!!! 
Loaded model from %s at step %s" % (args.load_path, resume_step)) 307 | 308 | train_and_eval(start_step, model, device, dataset, eval_dataset, sample_dataset, losses, args) 309 | print(torch.cuda.memory_summary()) 310 | 311 | 312 | if __name__ == '__main__': 313 | 314 | with open("configs/default.yml", 'r') as f: 315 | default_config = yaml.safe_load(f) 316 | 317 | args = parse_args() 318 | with open(args.opt, 'r') as f: 319 | config = yaml.safe_load(f) 320 | 321 | train_config = copy.deepcopy(default_config) 322 | update_dict(train_config, config) 323 | 324 | sample_config = copy.deepcopy(train_config) 325 | sample_config['dataset']['patches']['height'] = train_config['exposure_control']['shading_code_resample_size'] 326 | sample_config['dataset']['patches']['width'] = train_config['exposure_control']['shading_code_resample_size'] 327 | sample_config = DictAsMember(sample_config) 328 | 329 | eval_config = copy.deepcopy(train_config) 330 | eval_config['dataset'].update(eval_config['eval']['dataset']) 331 | eval_config = DictAsMember(eval_config) 332 | train_config = DictAsMember(train_config) 333 | 334 | assert train_config.models.use_renderer, "Currently only support using renderer for exposure control" 335 | 336 | log_dir = os.path.join(train_config.save_dir, train_config.index) 337 | os.makedirs(log_dir, exist_ok=True) 338 | 339 | sys.stdout = Logger(os.path.join(log_dir, 'train.log'), sys.stdout) 340 | sys.stderr = Logger(os.path.join(log_dir, 'train_error.log'), sys.stderr) 341 | 342 | shutil.copyfile(__file__, os.path.join(log_dir, os.path.basename(__file__))) 343 | shutil.copyfile(args.opt, os.path.join(log_dir, os.path.basename(args.opt))) 344 | 345 | find_all_python_files_and_zip(".", os.path.join(log_dir, "code.zip")) 346 | 347 | setup_seed(train_config.seed) 348 | 349 | main(train_config, eval_config, sample_config, args.resume) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import torch 4 | import os 5 | import shutil 6 | import zipfile 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import random 10 | import copy 11 | import bisect 12 | import time 13 | import sys 14 | import io 15 | import imageio 16 | from PIL import Image 17 | from utils import * 18 | from dataset import get_dataset, get_loader 19 | from models import get_model, get_loss 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description="PAPR") 24 | parser.add_argument('--opt', type=str, default="", help='Option file path') 25 | parser.add_argument('--resume', type=int, default=0, help='Resume training') 26 | return parser.parse_args() 27 | 28 | 29 | def eval_step(steps, model, device, dataset, eval_dataset, batch, loss_fn, train_out, args, train_losses, eval_losses, eval_psnrs, pt_lrs, attn_lrs): 30 | step = steps[-1] 31 | train_img_idx, _, train_patch, _, _ = batch 32 | # For visualization, use the first image in the batch 33 | first_idx = train_img_idx[0].item() if isinstance(train_img_idx, torch.Tensor) else train_img_idx[0] 34 | train_img, train_rayd, train_rayo = dataset.get_full_img(first_idx) 35 | img, rayd, rayo = eval_dataset.get_full_img(args.eval.img_idx) 36 | c2w = eval_dataset.get_c2w(args.eval.img_idx) 37 | 38 | N, H, W, _ = rayd.shape 39 | num_pts, _ = model.points.shape 40 | 41 | rayo = rayo.to(device) 42 | rayd = rayd.to(device) 43 | img = img.to(device) 44 | c2w = c2w.to(device) 45 | 46 | topk = 
min([num_pts, model.select_k]) 47 | 48 | selected_points = torch.zeros(1, H, W, topk, 3) 49 | 50 | bkg_seq_len_attn = 0 51 | attn_opt = args.models.attn 52 | feat_dim = attn_opt.embed.value.d_ff_out 53 | if model.bkg_feats is not None: 54 | bkg_seq_len_attn = model.bkg_feats.shape[0] 55 | feature_map = torch.zeros(N, H, W, 1, feat_dim).to(device) 56 | attn = torch.zeros(N, H, W, topk + bkg_seq_len_attn, 1).to(device) 57 | 58 | with torch.no_grad(): 59 | for height_start in range(0, H, args.eval.max_height): 60 | for width_start in range(0, W, args.eval.max_width): 61 | height_end = min(height_start + args.eval.max_height, H) 62 | width_end = min(width_start + args.eval.max_width, W) 63 | 64 | feature_map[:, height_start:height_end, width_start:width_end, :, :], \ 65 | attn[:, height_start:height_end, width_start:width_end, :, :] = model.evaluate(rayo, rayd[:, height_start:height_end, width_start:width_end], c2w, step=step) 66 | 67 | selected_points[:, height_start:height_end, width_start:width_end, :, :] = model.selected_points 68 | 69 | if args.models.use_renderer: 70 | foreground_rgb = model.renderer(feature_map.squeeze(-2).permute(0, 3, 1, 2)).permute(0, 2, 3, 1).unsqueeze(-2) # (N, H, W, 1, 3) 71 | else: 72 | foreground_rgb = feature_map 73 | 74 | if model.bkg_feats is not None: 75 | bkg_attn = attn[..., topk:, :] 76 | if args.models.normalize_topk_attn: 77 | rgb = foreground_rgb * (1 - bkg_attn) + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn 78 | else: 79 | rgb = foreground_rgb + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn 80 | rgb = rgb.squeeze(-2) 81 | else: 82 | rgb = foreground_rgb.squeeze(-2) 83 | 84 | rgb = model.last_act(rgb) 85 | rgb = torch.clamp(rgb, 0, 1) 86 | 87 | eval_loss = loss_fn(rgb, img) 88 | eval_psnr = -10. * np.log(((rgb - img)**2).mean().item()) / np.log(10.) 
89 | 90 | model.clear_grad() 91 | 92 | eval_losses.append(eval_loss.item()) 93 | eval_psnrs.append(eval_psnr.item()) 94 | 95 | print("Eval step:", step, "train_loss:", train_losses[-1], "eval_loss:", eval_losses[-1], "eval_psnr:", eval_psnrs[-1]) 96 | 97 | log_dir = os.path.join(args.save_dir, args.index) 98 | os.makedirs(log_dir, exist_ok=True) 99 | if args.eval.save_fig: 100 | os.makedirs(os.path.join(log_dir, "train_main_plots"), exist_ok=True) 101 | os.makedirs(os.path.join(log_dir, "train_pcd_plots"), exist_ok=True) 102 | 103 | coord_scale = args.dataset.coord_scale 104 | pt_plot_scale = 1.0 * coord_scale 105 | if "Barn" in args.dataset.path: 106 | pt_plot_scale *= 1.8 107 | if "Family" in args.dataset.path: 108 | pt_plot_scale *= 0.5 109 | 110 | # calculate depth, weighted sum the distances from top K points to image plane 111 | od = -rayo 112 | D = torch.sum(od * rayo) 113 | dists = torch.abs(torch.sum(selected_points.to(od.device) * od, -1) - D) / torch.norm(od) 114 | if model.bkg_feats is not None: 115 | dists = torch.cat([dists, torch.ones(N, H, W, model.bkg_feats.shape[0]).to(dists.device) * 0], dim=-1) 116 | cur_depth = (torch.sum(attn.squeeze(-1).to(od.device) * dists, dim=-1)).detach().cpu() 117 | 118 | train_tgt_rgb = train_img.squeeze().cpu().numpy().astype(np.float32) 119 | # For visualization, use the first patch/output in the batch 120 | train_tgt_patch = train_patch[0].cpu().numpy().astype(np.float32) 121 | train_pred_patch = train_out[0] 122 | test_tgt_rgb = img.squeeze().cpu().numpy().astype(np.float32) 123 | test_pred_rgb = rgb.squeeze().detach().cpu().numpy().astype(np.float32) 124 | points_np = model.points.detach().cpu().numpy() 125 | depth = cur_depth.squeeze().numpy().astype(np.float32) 126 | points_influ_scores_np = None 127 | if model.points_influ_scores is not None: 128 | points_influ_scores_np = model.points_influ_scores.squeeze().detach().cpu().numpy() 129 | 130 | # main plot 131 | main_plot = get_training_main_plot(args.index, steps, train_tgt_rgb, train_tgt_patch, train_pred_patch, test_tgt_rgb, test_pred_rgb, train_losses, 132 | eval_losses, points_np, pt_plot_scale, depth, pt_lrs, attn_lrs, eval_psnrs, points_influ_scores_np) 133 | save_name = os.path.join(log_dir, "train_main_plots", "%s_iter_%d.png" % (args.index, step)) 134 | main_plot.save(save_name) 135 | 136 | # point cloud plot 137 | ro = train_rayo.squeeze().detach().cpu().numpy() 138 | rd = train_rayd.squeeze().detach().cpu().numpy() 139 | 140 | pcd_plot = get_training_pcd_plot(args.index, steps[-1], ro, rd, points_np, args.dataset.coord_scale, pt_plot_scale, points_influ_scores_np) 141 | save_name = os.path.join(log_dir, "train_pcd_plots", "%s_iter_%d.png" % (args.index, step)) 142 | pcd_plot.save(save_name) 143 | 144 | model.save(step, log_dir) 145 | if step % 50000 == 0: 146 | torch.save(model.state_dict(), os.path.join(log_dir, "model_%d.pth" % step)) 147 | 148 | torch.save(torch.tensor(train_losses), os.path.join(log_dir, "train_losses.pth")) 149 | torch.save(torch.tensor(eval_losses), os.path.join(log_dir, "eval_losses.pth")) 150 | torch.save(torch.tensor(eval_psnrs), os.path.join(log_dir, "eval_psnrs.pth")) 151 | 152 | return 0 153 | 154 | 155 | def train_step(step, model, device, dataset, batch, loss_fn, args): 156 | img_idx, _, tgt, rayd, rayo = batch 157 | # Get c2w for all images in the batch 158 | if isinstance(img_idx, torch.Tensor): 159 | c2w = torch.stack([dataset.get_c2w(idx.item()) for idx in img_idx]) 160 | else: 161 | c2w = torch.stack([dataset.get_c2w(idx) for idx in 
img_idx]) 162 | 163 | rayo = rayo.to(device) 164 | rayd = rayd.to(device) 165 | tgt = tgt.to(device) 166 | c2w = c2w.to(device) 167 | 168 | model.clear_grad() 169 | out = model(rayo, rayd, c2w, step) 170 | out = model.last_act(out) 171 | loss = loss_fn(out, tgt) 172 | model.scaler.scale(loss).backward() 173 | model.step(step) 174 | if args.scaler_min_scale > 0 and model.scaler.get_scale() < args.scaler_min_scale: 175 | model.scaler.update(args.scaler_min_scale) 176 | else: 177 | model.scaler.update() 178 | 179 | return loss.item(), out.detach().cpu().numpy() 180 | 181 | 182 | def train_and_eval(start_step, model, device, dataset, eval_dataset, losses, args): 183 | trainloader = get_loader(dataset, args.dataset, mode="train") 184 | 185 | loss_fn = get_loss(args.training.losses) 186 | loss_fn = loss_fn.to(device) 187 | 188 | log_dir = os.path.join(args.save_dir, args.index) 189 | os.makedirs(os.path.join(log_dir, "test"), exist_ok=True) 190 | log_dir = os.path.join(log_dir, "test") 191 | 192 | steps = [] 193 | train_losses, eval_losses, eval_psnrs = losses 194 | pt_lrs = [] 195 | attn_lrs = [] 196 | 197 | avg_train_loss = 0. 198 | step = start_step 199 | eval_step_cnt = start_step 200 | pruned = False 201 | pc_frames = [] 202 | 203 | print("Start step:", start_step, "Total steps:", args.training.steps) 204 | start_time = time.time() 205 | while step < args.training.steps: 206 | for _, batch in enumerate(trainloader): 207 | if (args.training.prune_steps > 0) and (step < args.training.prune_stop) and (step >= args.training.prune_start): 208 | if len(args.training.prune_steps_list) > 0 and step % args.training.prune_steps == 0: 209 | cur_prune_thresh = args.training.prune_thresh_list[bisect.bisect_left(args.training.prune_steps_list, step)] 210 | model.clear_optimizer() 211 | model.clear_scheduler() 212 | num_pruned = model.prune_points(cur_prune_thresh) 213 | model.init_optimizers(step) 214 | pruned = True 215 | print("Step %d: Pruned %d points, prune threshold %f" % (step, num_pruned, cur_prune_thresh)) 216 | 217 | elif step % args.training.prune_steps == 0: 218 | model.clear_optimizer() 219 | model.clear_scheduler() 220 | num_pruned = model.prune_points(args.training.prune_thresh) 221 | model.init_optimizers(step) 222 | pruned = True 223 | print("Step %d: Pruned %d points" % (step, num_pruned)) 224 | 225 | if pruned and len(args.training.add_steps_list) > 0: 226 | if step in args.training.add_steps_list: 227 | cur_add_num = args.training.add_num_list[args.training.add_steps_list.index(step)] 228 | if 'max_num_pts' in args and args.max_num_pts > 0: 229 | cur_add_num = min(cur_add_num, args.max_num_pts - model.points.shape[0]) 230 | 231 | if cur_add_num > 0: 232 | model.clear_optimizer() 233 | model.clear_scheduler() 234 | num_added = model.add_points(cur_add_num) 235 | model.init_optimizers(step) 236 | model.added_points = True 237 | print("Step %d: Added %d points" % (step, num_added)) 238 | 239 | elif pruned and (args.training.add_steps > 0) and (step % args.training.add_steps == 0) and (step < args.training.add_stop) and (step >= args.training.add_start): 240 | cur_add_num = args.training.add_num 241 | if 'max_num_pts' in args and args.max_num_pts > 0: 242 | cur_add_num = min(cur_add_num, args.max_num_pts - model.points.shape[0]) 243 | 244 | if cur_add_num > 0: 245 | model.clear_optimizer() 246 | model.clear_scheduler() 247 | num_added = model.add_points(args.training.add_num) 248 | model.init_optimizers(step) 249 | model.added_points = True 250 | print("Step %d: Added %d points" % (step, 
num_added)) 251 | 252 | loss, out = train_step(step, model, device, dataset, batch, loss_fn, args) 253 | avg_train_loss += loss 254 | step += 1 255 | eval_step_cnt += 1 256 | 257 | if step % 200 == 0: 258 | time_used = time.time() - start_time 259 | print("Train step:", step, "loss:", loss, "attn_lr:", model.attn_lr, "pts_lr:", model.pts_lr, "scale:", model.scaler.get_scale(), f"time: {time_used:.2f}s") 260 | start_time = time.time() 261 | 262 | if (step % args.eval.step == 0) or (step % 500 == 0 and step < 10000): 263 | train_losses.append(avg_train_loss / eval_step_cnt) 264 | pt_lrs.append(model.pts_lr) 265 | attn_lrs.append(model.attn_lr) 266 | steps.append(step) 267 | eval_step(steps, model, device, dataset, eval_dataset, batch, loss_fn, out, args, train_losses, eval_losses, eval_psnrs, pt_lrs, attn_lrs) 268 | avg_train_loss = 0. 269 | eval_step_cnt = 0 270 | 271 | if ((step - 1) % 200 == 0) and args.eval.save_fig: 272 | coord_scale = args.dataset.coord_scale 273 | pt_plot_scale = 0.8 * coord_scale 274 | if "Barn" in args.dataset.path: 275 | pt_plot_scale *= 1.5 276 | if "Family" in args.dataset.path: 277 | pt_plot_scale *= 0.5 278 | 279 | pc_dir = os.path.join(log_dir, "point_clouds") 280 | os.makedirs(pc_dir, exist_ok=True) 281 | 282 | points_np = model.points.detach().cpu().numpy() 283 | points_influ_scores_np = None 284 | if model.points_influ_scores is not None: 285 | points_influ_scores_np = model.points_influ_scores.squeeze().detach().cpu().numpy() 286 | pcd_plot = get_training_pcd_single_plot(step, points_np, pt_plot_scale, points_influ_scores_np) 287 | pc_frames.append(pcd_plot) 288 | 289 | if step == 1: 290 | pcd_plot.save(os.path.join(pc_dir, "init_pcd.png")) 291 | 292 | if step >= args.training.steps: 293 | break 294 | 295 | if args.eval.save_fig and pc_frames != []: 296 | f = os.path.join(log_dir, f"{args.index}-pc.mp4") 297 | imageio.mimwrite(f, pc_frames, fps=30, quality=10) 298 | 299 | print("Training finished!") 300 | 301 | 302 | def main(args, eval_args, resume): 303 | log_dir = os.path.join(args.save_dir, args.index) 304 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 305 | 306 | model = get_model(args, device) 307 | dataset = get_dataset(args.dataset, mode="train") 308 | eval_dataset = get_dataset(eval_args.dataset, mode="test") 309 | model = model.to(device) 310 | 311 | # if torch.__version__ >= "2.0": 312 | # model = torch.compile(model) 313 | 314 | start_step = 0 315 | losses = [[], [], []] 316 | if resume > 0: 317 | start_step = model.load(log_dir) 318 | 319 | train_losses = torch.load(os.path.join(log_dir, "train_losses.pth")).tolist() 320 | eval_losses = torch.load(os.path.join(log_dir, "eval_losses.pth")).tolist() 321 | eval_psnrs = torch.load(os.path.join(log_dir, "eval_psnrs.pth")).tolist() 322 | losses = [train_losses, eval_losses, eval_psnrs] 323 | 324 | print("!!!!! Resume from step %s" % start_step) 325 | elif args.load_path: 326 | try: 327 | resume_step = model.load(os.path.join(args.save_dir, args.load_path)) 328 | except: 329 | model_state_dict = torch.load(os.path.join(args.save_dir, args.load_path, "model.pth")) 330 | for step, state_dict in model_state_dict.items(): 331 | resume_step = step 332 | model.load_my_state_dict(state_dict) 333 | print("!!!!! 
Loaded model from %s at step %s" % (args.load_path, resume_step)) 334 | 335 | train_and_eval(start_step, model, device, dataset, eval_dataset, losses, args) 336 | print(torch.cuda.memory_summary()) 337 | 338 | 339 | if __name__ == '__main__': 340 | 341 | with open("configs/default.yml", 'r') as f: 342 | default_config = yaml.safe_load(f) 343 | 344 | args = parse_args() 345 | with open(args.opt, 'r') as f: 346 | config = yaml.safe_load(f) 347 | 348 | train_config = copy.deepcopy(default_config) 349 | update_dict(train_config, config) 350 | 351 | eval_config = copy.deepcopy(train_config) 352 | eval_config['dataset'].update(eval_config['eval']['dataset']) 353 | eval_config = DictAsMember(eval_config) 354 | train_config = DictAsMember(train_config) 355 | 356 | log_dir = os.path.join(train_config.save_dir, train_config.index) 357 | os.makedirs(log_dir, exist_ok=True) 358 | 359 | sys.stdout = Logger(os.path.join(log_dir, 'train.log'), sys.stdout) 360 | sys.stderr = Logger(os.path.join(log_dir, 'train_error.log'), sys.stderr) 361 | 362 | shutil.copyfile(__file__, os.path.join(log_dir, os.path.basename(__file__))) 363 | shutil.copyfile(args.opt, os.path.join(log_dir, os.path.basename(args.opt))) 364 | 365 | find_all_python_files_and_zip(".", os.path.join(log_dir, "code.zip")) 366 | 367 | setup_seed(train_config.seed) 368 | 369 | main(train_config, eval_config, args.resume) -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.optim.lr_scheduler as lr_scheduler 4 | import scipy 5 | from scipy.spatial import KDTree 6 | import numpy as np 7 | 8 | 9 | def add_points_knn(coords, influ_scores, add_num, k, comb_type="mean", sample_type="random", sample_k=10, point_features=None): 10 | """ 11 | Add points to the point cloud by kNN 12 | """ 13 | pc = KDTree(coords) 14 | N = coords.shape[0] 15 | 16 | # Step 1: Determine where to add points 17 | if N <= add_num and "random" in comb_type: 18 | inds = np.random.choice(N, add_num, replace=True) 19 | query_coords = coords[inds, :] 20 | elif N <= add_num: 21 | query_coords = coords 22 | inds = list(range(N)) 23 | else: 24 | if sample_type == "random": 25 | inds = np.random.choice(N, add_num, replace=False) 26 | query_coords = coords[inds, :] 27 | elif sample_type == "top-knn-std": 28 | assert k >= 2 29 | nns_dists, nns_inds = pc.query(coords, k=sample_k) 30 | inds = np.argsort(nns_dists.std(axis=-1))[-add_num:] 31 | query_coords = coords[inds, :] 32 | elif sample_type == "top-knn-mean": 33 | assert k >= 2 34 | nns_dists, nns_inds = pc.query(coords, k=sample_k) 35 | inds = np.argsort(nns_dists.mean(axis=-1))[-add_num:] 36 | query_coords = coords[inds, :] 37 | elif sample_type == "top-knn-max": 38 | assert k >= 2 39 | nns_dists, nns_inds = pc.query(coords, k=sample_k) 40 | inds = np.argsort(nns_dists.max(axis=-1))[-add_num:] 41 | query_coords = coords[inds, :] 42 | elif sample_type == "top-knn-min": 43 | assert k >= 2 44 | nns_dists, nns_inds = pc.query(coords, k=sample_k) 45 | inds = np.argsort(nns_dists.min(axis=-1))[-add_num:] 46 | query_coords = coords[inds, :] 47 | elif sample_type == "influ-scores-max": 48 | inds = np.argsort(influ_scores.squeeze())[-add_num:] 49 | query_coords = coords[inds, :] 50 | elif sample_type == "influ-scores-min": 51 | inds = np.argsort(influ_scores.squeeze())[:add_num] 52 | query_coords = coords[inds, :] 53 | else: 54 | raise NotImplementedError 55 | 
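    # Illustrative sketch of the "mean" combination below, assuming k = 4 neighbours:
    #   pc.query(query_coords, k=5) returns each query point plus its 4 nearest neighbours;
    #   dropping the self-match and averaging gives
    #   new_coords = coords[nns_inds[:, 1:], :].mean(axis=-2)   # (Nq, 4, 3) -> (Nq, 3)
    # The same neighbour indices (and weights, for the weighted variants) are reused for the
    # influence scores and the optional per-point features.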
56 | # Step 2: Add points by kNN 57 | new_features = None 58 | if comb_type == "duplicate": 59 | noise = np.random.randn(3).astype(np.float32) 60 | noise = noise / np.linalg.norm(noise) 61 | noise *= k 62 | new_coords = (query_coords + noise) 63 | new_influ_scores = influ_scores[inds, :] 64 | if point_features is not None: 65 | new_features = point_features[inds, :] 66 | else: 67 | nns_dists, nns_inds = pc.query(query_coords, k=k+1) 68 | nns_dists = nns_dists.astype(np.float32) 69 | nns_dists = nns_dists[:, 1:] 70 | nns_inds = nns_inds[:, 1:] 71 | if comb_type == "mean": 72 | new_coords = coords[nns_inds, :].mean( 73 | axis=-2) # (Nq, k, 3) -> (Nq, 3) 74 | new_influ_scores = influ_scores[nns_inds, :].mean(axis=-2) 75 | if point_features is not None: 76 | new_features = point_features[nns_inds, :].mean(axis=-2) 77 | elif comb_type == "random": 78 | rnd_w = np.random.uniform( 79 | 0, 1, (query_coords.shape[0], k)).astype(np.float32) 80 | rnd_w /= rnd_w.sum(axis=-1, keepdims=True) 81 | new_coords = (coords[nns_inds, :] * 82 | rnd_w.reshape(-1, k, 1)).sum(axis=-2) 83 | new_influ_scores = ( 84 | influ_scores[nns_inds, :] * rnd_w.reshape(-1, k, 1)).sum(axis=-2) 85 | if point_features is not None: 86 | new_features = ( 87 | point_features[nns_inds, :] * rnd_w.reshape(-1, k, 1)).sum(axis=-2) 88 | elif comb_type == "random-softmax": 89 | rnd_w = np.random.randn( 90 | query_coords.shape[0], k).astype(np.float32) 91 | rnd_w = scipy.special.softmax(rnd_w, axis=-1) 92 | new_coords = (coords[nns_inds, :] * 93 | rnd_w.reshape(-1, k, 1)).sum(axis=-2) 94 | new_influ_scores = ( 95 | influ_scores[nns_inds, :] * rnd_w.reshape(-1, k, 1)).sum(axis=-2) 96 | if point_features is not None: 97 | new_features = ( 98 | point_features[nns_inds, :] * rnd_w.reshape(-1, k, 1)).sum(axis=-2) 99 | elif comb_type == "weighted": 100 | new_coords = (coords[nns_inds, :] * (1 / (nns_dists + 1e-6)).reshape(-1, k, 1)).sum( 101 | axis=-2) / (1 / (nns_dists + 1e-6)).sum(axis=-1, keepdims=True) 102 | new_influ_scores = (influ_scores[nns_inds, :] * (1 / (nns_dists + 1e-6)).reshape(-1, k, 1)).sum( 103 | axis=-2) / (1 / (nns_dists + 1e-6)).sum(axis=-1, keepdims=True) 104 | if point_features is not None: 105 | new_features = (point_features[nns_inds, :] * (1 / (nns_dists + 1e-6)).reshape(-1, k, 1)).sum( 106 | axis=-2) / (1 / (nns_dists + 1e-6)).sum(axis=-1, keepdims=True) 107 | else: 108 | raise NotImplementedError 109 | return new_coords, len(new_coords), new_influ_scores, new_features 110 | 111 | 112 | def cam_to_world(coords, c2w, vector=True): 113 | """ 114 | coords: [N, H, W, 3] or [H, W, 3] or [K, 3] 115 | c2w: [N, 4, 4] or [4, 4] 116 | """ 117 | if vector: # Convert to homogeneous coordinates 118 | coords = torch.cat([coords, torch.zeros_like(coords[..., :1])], -1) 119 | else: 120 | coords = torch.cat([coords, torch.ones_like(coords[..., :1])], -1) 121 | 122 | if coords.ndim == 5: 123 | assert c2w.ndim == 2 124 | B, H, W, N, _ = coords.shape 125 | transformed_coords = torch.sum( 126 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 1, 1, 4, 4), -1) # [B, H, W, N, 3] 127 | elif coords.ndim == 4: 128 | assert c2w.ndim == 3 129 | N, H, W, _ = coords.shape 130 | transformed_coords = torch.sum( 131 | coords.unsqueeze(-2) * c2w.reshape(N, 1, 1, 4, 4), -1) # [N, H, W, 4] 132 | elif coords.ndim == 3: 133 | assert c2w.ndim == 2 134 | H, W, _ = coords.shape 135 | transformed_coords = torch.sum( 136 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 4, 4), -1) # [H, W, 4] 137 | elif coords.ndim == 2: 138 | assert c2w.ndim == 2 139 | K, _ = coords.shape 140 
| transformed_coords = torch.sum( 141 | coords.unsqueeze(-2) * c2w.reshape(1, 4, 4), -1) # [K, 4] 142 | else: 143 | raise ValueError('Wrong dimension of coords') 144 | return transformed_coords[..., :3] 145 | 146 | 147 | def world_to_cam(coords, c2w, vector=True): 148 | """ 149 | coords: [N, H, W, 3] or [H, W, 3] or [K, 3] 150 | c2w: [N, 4, 4] or [4, 4] 151 | """ 152 | if vector: # Convert to homogeneous coordinates 153 | coords = torch.cat([coords, torch.zeros_like(coords[..., :1])], -1) 154 | else: 155 | coords = torch.cat([coords, torch.ones_like(coords[..., :1])], -1) 156 | 157 | c2w = torch.inverse(c2w) 158 | if coords.ndim == 5: 159 | assert c2w.ndim == 2 160 | B, H, W, N, _ = coords.shape 161 | transformed_coords = torch.sum( 162 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 1, 1, 4, 4), -1) # [B, H, W, N, 3] 163 | elif coords.ndim == 4: 164 | assert c2w.ndim == 3 165 | N, H, W, _ = coords.shape 166 | transformed_coords = torch.sum( 167 | coords.unsqueeze(-2) * c2w.reshape(N, 1, 1, 4, 4), -1) # [N, H, W, 4] 168 | elif coords.ndim == 3: 169 | assert c2w.ndim == 2 170 | H, W, _ = coords.shape 171 | transformed_coords = torch.sum( 172 | coords.unsqueeze(-2) * c2w.reshape(1, 1, 4, 4), -1) # [H, W, 4] 173 | elif coords.ndim == 2: 174 | assert c2w.ndim == 2 175 | K, _ = coords.shape 176 | transformed_coords = torch.sum( 177 | coords.unsqueeze(-2) * c2w.reshape(1, 4, 4), -1) # [K, 4] 178 | else: 179 | raise ValueError('Wrong dimension of coords') 180 | return transformed_coords[..., :3] 181 | 182 | 183 | def activation_func(act_type='leakyrelu', neg_slope=0.2, inplace=True, num_channels=128, a=1., b=1., trainable=False): 184 | act_type = act_type.lower() 185 | if act_type == 'none': 186 | layer = nn.Identity() 187 | elif act_type == 'leakyrelu': 188 | layer = nn.LeakyReLU(neg_slope, inplace) 189 | elif act_type == 'prelu': 190 | layer = nn.PReLU(num_channels) 191 | elif act_type == 'relu': 192 | layer = nn.ReLU(inplace) 193 | elif act_type == '+1': 194 | layer = PlusOneActivation() 195 | elif act_type == 'relu+1': 196 | layer = nn.Sequential(nn.ReLU(inplace), PlusOneActivation()) 197 | elif act_type == 'tanh': 198 | layer = nn.Tanh() 199 | elif act_type == 'shifted_tanh': 200 | layer = ShiftedTanh() 201 | elif act_type == 'sigmoid': 202 | layer = nn.Sigmoid() 203 | elif act_type == 'gelu': 204 | layer = nn.GELU() 205 | elif act_type == 'gaussian': 206 | layer = GaussianActivation(a, trainable) 207 | elif act_type == 'quadratic': 208 | layer = QuadraticActivation(a, trainable) 209 | elif act_type == 'multi-quadratic': 210 | layer = MultiQuadraticActivation(a, trainable) 211 | elif act_type == 'laplacian': 212 | layer = LaplacianActivation(a, trainable) 213 | elif act_type == 'super-gaussian': 214 | layer = SuperGaussianActivation(a, b, trainable) 215 | elif act_type == 'expsin': 216 | layer = ExpSinActivation(a, trainable) 217 | elif act_type == 'clamp': 218 | layer = Clamp(0, 1) 219 | elif 'sine' in act_type: 220 | layer = Sine(factor=a) 221 | elif 'softplus' in act_type: 222 | a, b, c = [float(i) for i in act_type.split('_')[1:]] 223 | print( 224 | 'Softplus activation: a={:.2f}, b={:.2f}, c={:.2f}'.format(a, b, c)) 225 | layer = SoftplusActivation(a, b, c) 226 | else: 227 | raise NotImplementedError( 228 | 'activation layer [{:s}] is not found'.format(act_type)) 229 | return layer 230 | 231 | 232 | def posenc(x, L_embed, factor=2.0, without_self=False, mult_factor=1.0): 233 | if without_self: 234 | rets = [] 235 | else: 236 | rets = [x] 237 | for i in range(L_embed): 238 | for fn in 
[torch.sin, torch.cos]: 239 | rets.append(fn(factor**i * x * mult_factor)) 240 | # return torch.cat(rets, 1) 241 | # To make sure the dimensions of the same meaning are together 242 | return torch.flatten(torch.stack(rets, -1), start_dim=-2, end_dim=-1) 243 | 244 | 245 | class PoseEnc(nn.Module): 246 | def __init__(self, factor=2.0, mult_factor=1.0): 247 | super(PoseEnc, self).__init__() 248 | self.factor = factor 249 | self.mult_factor = mult_factor 250 | 251 | def forward(self, x, L_embed, without_self=False): 252 | return posenc(x, L_embed, self.factor, without_self, self.mult_factor) 253 | 254 | 255 | def normalize_vector(x, eps=0.): 256 | # assert(x.shape[-1] == 3) 257 | return x / (torch.norm(x, dim=-1, keepdim=True) + eps) 258 | 259 | 260 | def create_learning_rate_fn(optimizer, max_steps, args, debug=False): 261 | """Create learning rate schedule.""" 262 | if args.type == "none": 263 | return None 264 | 265 | if args.warmup > 0: 266 | warmup_start_factor = 1e-16 267 | else: 268 | warmup_start_factor = 1.0 269 | 270 | warmup_fn = lr_scheduler.LinearLR(optimizer, 271 | start_factor=warmup_start_factor, 272 | end_factor=1.0, 273 | total_iters=args.warmup, 274 | verbose=debug) 275 | 276 | if args.type == "linear": 277 | decay_fn = lr_scheduler.LinearLR(optimizer, 278 | start_factor=1.0, 279 | end_factor=0., 280 | total_iters=max_steps - args.warmup, 281 | verbose=debug) 282 | schedulers = [warmup_fn, decay_fn] 283 | milestones = [args.warmup] 284 | 285 | elif args.type == "cosine": 286 | cosine_steps = max(max_steps - args.warmup, 1) 287 | decay_fn = lr_scheduler.CosineAnnealingLR(optimizer, 288 | T_max=cosine_steps, 289 | verbose=debug) 290 | schedulers = [warmup_fn, decay_fn] 291 | milestones = [args.warmup] 292 | 293 | elif args.type == "cosine-hlfperiod": 294 | cosine_steps = max(max_steps - args.warmup, 1) * 2 295 | decay_fn = lr_scheduler.CosineAnnealingLR(optimizer, 296 | T_max=cosine_steps, 297 | verbose=debug) 298 | schedulers = [warmup_fn, decay_fn] 299 | milestones = [args.warmup] 300 | 301 | elif args.type == "exp": 302 | decay_fn = lr_scheduler.ExponentialLR(optimizer, 303 | gamma=args.gamma, 304 | verbose=debug) 305 | schedulers = [warmup_fn, decay_fn] 306 | milestones = [args.warmup] 307 | 308 | elif args.type == "stop": 309 | decay_fn = lr_scheduler.StepLR( 310 | optimizer, step_size=1, gamma=0.0, verbose=debug) 311 | schedulers = [warmup_fn, decay_fn] 312 | milestones = [args.warmup] 313 | 314 | else: 315 | raise NotImplementedError 316 | 317 | schedule_fn = lr_scheduler.SequentialLR(optimizer, 318 | schedulers=schedulers, 319 | milestones=milestones, 320 | verbose=debug) 321 | 322 | return schedule_fn 323 | 324 | 325 | class Sine(nn.Module): 326 | def __init__(self, factor=30): 327 | super().__init__() 328 | self.factor = factor 329 | 330 | def forward(self, x): 331 | return torch.sin(x * self.factor) 332 | 333 | 334 | class Clamp(nn.Module): 335 | def __init__(self, min_val, max_val): 336 | super().__init__() 337 | self.min_val = min_val 338 | self.max_val = max_val 339 | 340 | def forward(self, x): 341 | return torch.clamp(x, self.min_val, self.max_val) 342 | 343 | 344 | class ShiftedTanh(nn.Module): 345 | def __init__(self): 346 | super().__init__() 347 | 348 | def forward(self, x): 349 | return (torch.tanh(x) + 1) / 2 350 | 351 | 352 | class SoftplusActivation(nn.Module): 353 | def __init__(self, c1=1, c2=1, c3=0): 354 | super().__init__() 355 | self.c1 = c1 356 | self.c2 = c2 357 | self.c3 = c3 358 | 359 | def forward(self, x): 360 | return self.c1 * 
nn.functional.softplus(self.c2 * x + self.c3) 361 | 362 | 363 | class GaussianActivation(nn.Module): 364 | def __init__(self, a=1., trainable=True): 365 | super().__init__() 366 | self.register_parameter('a', nn.Parameter(a*torch.ones(1), trainable)) 367 | 368 | def forward(self, x): 369 | return torch.exp(-x**2/(2*self.a**2)) 370 | 371 | 372 | class QuadraticActivation(nn.Module): 373 | def __init__(self, a=1., trainable=True): 374 | super().__init__() 375 | self.register_parameter('a', nn.Parameter(a*torch.ones(1), trainable)) 376 | 377 | def forward(self, x): 378 | return 1/(1+(self.a*x)**2) 379 | 380 | 381 | class MultiQuadraticActivation(nn.Module): 382 | def __init__(self, a=1., trainable=True): 383 | super().__init__() 384 | self.register_parameter('a', nn.Parameter(a*torch.ones(1), trainable)) 385 | 386 | def forward(self, x): 387 | return 1/(1+(self.a*x)**2)**0.5 388 | 389 | 390 | class LaplacianActivation(nn.Module): 391 | def __init__(self, a=1., trainable=True): 392 | super().__init__() 393 | self.register_parameter('a', nn.Parameter(a*torch.ones(1), trainable)) 394 | 395 | def forward(self, x): 396 | return torch.exp(-torch.abs(x)/self.a) 397 | 398 | 399 | class SuperGaussianActivation(nn.Module): 400 | def __init__(self, a=1., b=1., trainable=True): 401 | super().__init__() 402 | self.register_parameter('a', nn.Parameter(a*torch.ones(1), trainable)) 403 | self.register_parameter('b', nn.Parameter(b*torch.ones(1), trainable)) 404 | 405 | def forward(self, x): 406 | return torch.exp(-x**2/(2*self.a**2))**self.b 407 | 408 | 409 | class ExpSinActivation(nn.Module): 410 | def __init__(self, a=1., trainable=True): 411 | super().__init__() 412 | self.register_parameter('a', nn.Parameter(a*torch.ones(1), trainable)) 413 | 414 | def forward(self, x): 415 | return torch.exp(-torch.sin(self.a*x)) 416 | 417 | 418 | class PlusOneActivation(nn.Module): 419 | def __init__(self): 420 | super().__init__() 421 | 422 | def forward(self, x): 423 | return x + 1 424 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from datetime import datetime 3 | import matplotlib.pyplot as plt 4 | import io 5 | from PIL import Image 6 | import numpy as np 7 | import os 8 | import zipfile 9 | import torch 10 | import random 11 | import copy 12 | 13 | 14 | class DictAsMember(dict): 15 | def __getattr__(self, name): 16 | value = self[name] 17 | if isinstance(value, dict): 18 | value = DictAsMember(value) 19 | return value 20 | 21 | 22 | def update_dict(original, param): 23 | for key in param.keys(): 24 | if type(param[key]) == dict: 25 | update_dict(original[key], param[key]) 26 | elif type(param[key]) == list and key == "datasets": 27 | for i in range(len(param[key])): 28 | name = param[key][i]['name'] 29 | for j in range(len(original[key])): 30 | if original[key][j]['name'] == name: 31 | for k in param[key][i].keys(): 32 | original[key][j][k] = param[key][i][k] 33 | break 34 | else: 35 | new_param = copy.deepcopy(original[key][0]) 36 | update_dict(new_param, param[key][i]) 37 | original[key].append(new_param) 38 | else: 39 | original[key] = param[key] 40 | 41 | 42 | def setup_seed(seed): 43 | torch.manual_seed(seed) 44 | torch.cuda.manual_seed_all(seed) 45 | np.random.seed(seed) 46 | random.seed(seed) 47 | 48 | 49 | def find_all_python_files_and_zip(src_dir, dst_path): 50 | # find all python files in src_dir 51 | python_files = [] 52 | for root, dirs, files in 
os.walk(src_dir): 53 | if 'experiment' in root: 54 | continue 55 | for cur_file in files: 56 | if cur_file.endswith('.py'): 57 | python_files.append(os.path.join(root, cur_file)) 58 | 59 | # zip all python files 60 | with zipfile.ZipFile(dst_path, 'w') as zip_file: 61 | for cur_file in python_files: 62 | zip_file.write(cur_file, os.path.relpath(cur_file, src_dir)) 63 | 64 | 65 | class Logger(object): 66 | def __init__(self, filename='default.log', stream=sys.stdout): 67 | self.terminal = stream 68 | self.log = open(filename, 'a') 69 | ct = datetime.now() 70 | self.log.write('*'*50 + '\n' + str(ct) + '\n' + '*'*50 + '\n') 71 | 72 | def write(self, message): 73 | self.terminal.write(message) 74 | self.log.write(message) 75 | 76 | def flush(self): 77 | pass 78 | 79 | 80 | def get_colors(weights): 81 | N = weights.shape[0] 82 | weights = (weights - weights.min()) / (weights.max() - weights.min()) 83 | colors = np.full((N, 3), [1., 0., 0.]) 84 | colors[:, 0] *= weights[:N] 85 | colors[:, 2] = (1 - weights[:N]) 86 | return colors 87 | 88 | 89 | def get_training_main_plot(index, steps, train_tgt_rgb, train_tgt_patch, train_pred_patch, test_tgt_tgb, test_pred_rgb, train_losses, 90 | eval_losses, points_np, pt_plot_scale, depth_np, pt_lrs, attn_lrs, eval_psnrs, points_conf_scores_np=None): 91 | step = steps[-1] 92 | fig = plt.figure(figsize=(20, 10)) 93 | 94 | ax = fig.add_subplot(2, 5, 1) 95 | ax.imshow(train_tgt_rgb) 96 | ax.set_title(f'Iteration: {step} train norm') 97 | 98 | ax = fig.add_subplot(2, 5, 2) 99 | ax.imshow(train_tgt_patch) 100 | ax.set_title(f'Iteration: {step} train norm patch') 101 | 102 | ax = fig.add_subplot(2, 5, 3) 103 | ax.imshow(train_pred_patch) 104 | ax.set_title(f'Iteration: {step} train output') 105 | 106 | ax = fig.add_subplot(2, 5, 4) 107 | ax.plot(steps, train_losses, label='train') 108 | ax.plot(steps, eval_losses, label='eval') 109 | ax.legend() 110 | ax.set_title('losses') 111 | 112 | ax = fig.add_subplot(2, 5, 5, projection='3d') 113 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale) 114 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale) 115 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale) 116 | ax.set_xlabel('x') 117 | ax.set_ylabel('y') 118 | ax.set_zlabel('z') 119 | cur_color = "grey" 120 | if points_conf_scores_np is not None: 121 | cur_color = get_colors(points_conf_scores_np) 122 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color) 123 | ax.set_title('Point Cloud') 124 | 125 | ax = fig.add_subplot(2, 5, 6) 126 | cd = ax.imshow(depth_np) 127 | fig.colorbar(cd, ax=ax) 128 | ax.set_title(f'depth map') 129 | 130 | ax = fig.add_subplot(2, 5, 7) 131 | ax.imshow(test_tgt_tgb) 132 | ax.set_title(f'Iteration: {step} eval norm') 133 | 134 | ax = fig.add_subplot(2, 5, 8) 135 | ax.imshow(test_pred_rgb) 136 | ax.set_title(f'Iteration: {step} eval predict') 137 | 138 | ax = fig.add_subplot(2, 5, 9) 139 | ax.plot(steps, np.log10(pt_lrs), label="pt lr") 140 | ax.plot(steps, np.log10(attn_lrs), label="attn lr") 141 | ax.legend() 142 | ax.set_title('learning rates log10') 143 | 144 | ax = fig.add_subplot(2, 5, 10) 145 | ax.plot(steps, eval_psnrs) 146 | ax.set_title('eval psnr') 147 | 148 | fig.suptitle("Main Plot\n%s\niter %d\nnum pts: %d" % (index, step, points_np.shape[0])) 149 | 150 | canvas = fig.canvas 151 | buffer = io.BytesIO() 152 | canvas.print_png(buffer) 153 | data = buffer.getvalue() 154 | buffer.write(data) 155 | img = Image.open(buffer) 156 | plt.close() 157 | 158 | return img 159 | 160 | 161 | def get_training_pcd_plot(index, step, ro, rd, 
points_np, coord_scale, pt_plot_scale, points_conf_scores_np=None): 162 | num_plots = 6 if points_conf_scores_np is not None else 4 163 | fig = plt.figure(figsize=(5 * num_plots, 6)) 164 | 165 | H, W, _ = rd.shape 166 | 167 | ax = fig.add_subplot(1, num_plots, 1, projection='3d') 168 | ax.view_init(elev=0., azim=90) 169 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale) 170 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale) 171 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale) 172 | ax.set_xlabel('x') 173 | ax.set_ylabel('y') 174 | ax.set_zlabel('z') 175 | cur_color = "orange" 176 | if points_conf_scores_np is not None: 177 | cur_color = get_colors(points_conf_scores_np) 178 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color, s=0.8 * coord_scale) 179 | ax.scatter(ro[0], ro[1], ro[2], c="red", s=10) 180 | ax.quiver(ro[0], ro[1], ro[2], rd[H//2, W//2, 0], rd[H//2, W//2, 1], rd[H//2, W//2, 2], length=2, alpha=1, color="blue") 181 | ax.set_title('Point Cloud View 1') 182 | 183 | ax = fig.add_subplot(1, num_plots, 2, projection='3d') 184 | ax.view_init(elev=0., azim=180) 185 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale) 186 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale) 187 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale) 188 | ax.set_xlabel('x') 189 | ax.set_ylabel('y') 190 | ax.set_zlabel('z') 191 | cur_color = "orange" 192 | if points_conf_scores_np is not None: 193 | cur_color = get_colors(points_conf_scores_np) 194 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color, s=0.8 * coord_scale) 195 | ax.scatter(ro[0], ro[1], ro[2], c="red", s=10) 196 | ax.quiver(ro[0], ro[1], ro[2], rd[H//2, W//2, 0], rd[H//2, W//2, 1], rd[H//2, W//2, 2], length=2, alpha=1, color="blue") 197 | ax.set_title('Point Cloud View 2') 198 | 199 | ax = fig.add_subplot(1, num_plots, 3, projection='3d') 200 | ax.view_init(elev=0., azim=270) 201 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale) 202 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale) 203 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale) 204 | ax.set_xlabel('x') 205 | ax.set_ylabel('y') 206 | ax.set_zlabel('z') 207 | cur_color = "orange" 208 | if points_conf_scores_np is not None: 209 | cur_color = get_colors(points_conf_scores_np) 210 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color, s=0.8 * coord_scale) 211 | ax.scatter(ro[0], ro[1], ro[2], c="red", s=10) 212 | ax.quiver(ro[0], ro[1], ro[2], rd[H//2, W//2, 0], rd[H//2, W//2, 1], rd[H//2, W//2, 2], length=2, alpha=1, color="blue") 213 | ax.set_title('Point Cloud View 3') 214 | 215 | ax = fig.add_subplot(1, num_plots, 4, projection='3d') 216 | ax.view_init(elev=89.9, azim=90) 217 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale) 218 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale) 219 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale) 220 | ax.set_xlabel('x') 221 | ax.set_ylabel('y') 222 | ax.set_zlabel('z') 223 | cur_color = "orange" 224 | if points_conf_scores_np is not None: 225 | cur_color = get_colors(points_conf_scores_np) 226 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color, s=0.8 * coord_scale) 227 | ax.scatter(ro[0], ro[1], ro[2], c="red", s=10) 228 | ax.quiver(ro[0], ro[1], ro[2], rd[H//2, W//2, 0], rd[H//2, W//2, 1], rd[H//2, W//2, 2], length=2, alpha=1, color="blue") 229 | ax.set_title('Point Cloud View 1 Up') 230 | 231 | if points_conf_scores_np is not None: 232 | ax = fig.add_subplot(1, num_plots, 5) 233 | ax.scatter(range(len(points_conf_scores_np)), points_conf_scores_np) 234 | ax.set_title('Confidence Scores scatter plot') 
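        # The sixth panel below plots the same confidence scores as a histogram binned over [-1, 1].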
235 | 236 | ax = fig.add_subplot(1, num_plots, 6) 237 | bins = np.linspace(-1, 1, 100).tolist() 238 | ax.hist(points_conf_scores_np, bins=bins) 239 | ax.set_title('Confidence Scores histogram') 240 | 241 | fig.suptitle("Point Clouds\n%s\niter %d" % (index, step)) 242 | 243 | canvas = fig.canvas 244 | buffer = io.BytesIO() 245 | canvas.print_png(buffer) 246 | data = buffer.getvalue() 247 | buffer.write(data) 248 | img = Image.open(buffer) 249 | plt.close() 250 | 251 | return img 252 | 253 | 254 | def get_training_pcd_single_plot(step, points_np, pt_plot_scale, points_conf_scores_np=None): 255 | fig = plt.figure(figsize=(5, 5)) 256 | fig.tight_layout() 257 | 258 | ax = fig.add_subplot(1, 1, 1, projection='3d') 259 | ax.view_init(elev=20., azim=90 + (step / 500) * (720. / 500)) 260 | ax.set_xlim3d(-pt_plot_scale * 1.5, pt_plot_scale * 1.5) 261 | ax.set_ylim3d(-pt_plot_scale * 1.5, pt_plot_scale * 1.5) 262 | ax.set_zlim3d(-pt_plot_scale * 1.5, pt_plot_scale * 1.5) 263 | ax.set_axis_off() 264 | cur_color = "orange" 265 | if points_conf_scores_np is not None: 266 | cur_color = get_colors(points_conf_scores_np) 267 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color) 268 | 269 | fig.suptitle("iter %d\n#points: %d" % (step, points_np.shape[0])) 270 | fig.tight_layout() 271 | 272 | canvas = fig.canvas 273 | buffer = io.BytesIO() 274 | canvas.print_png(buffer) 275 | data = buffer.getvalue() 276 | buffer.write(data) 277 | img = Image.open(buffer) 278 | plt.close() 279 | 280 | return img 281 | 282 | 283 | def get_test_pcrgb(frame, th, azmin, test_psnr, points_np, rgb_pred_np, rgb_gt_np, depth_np, pt_plot_scale, points_conf_scores_np=None): 284 | fig = plt.figure(figsize=(30, 10)) 285 | 286 | ax = fig.add_subplot(1, 5, 1, projection='3d') 287 | ax.axis('off') 288 | ax.view_init(elev=20., azim=90 - th) 289 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale) 290 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale) 291 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale) 292 | ax.set_xlabel('x') 293 | ax.set_ylabel('y') 294 | ax.set_zlabel('z') 295 | cur_color = "grey" 296 | if points_conf_scores_np is not None: 297 | cur_color = get_colors(points_conf_scores_np) 298 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color, s=0.5) 299 | 300 | ax = fig.add_subplot(1, 5, 2, projection='3d') 301 | # ax.axis('off') 302 | ax.view_init(elev=0., azim=azmin) 303 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale) 304 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale) 305 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale) 306 | ax.set_xlabel('x') 307 | ax.set_ylabel('y') 308 | ax.set_zlabel('z') 309 | switch_yz = np.array([[1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=np.float32) 310 | cur_points_np = points_np @ switch_yz 311 | cur_points_np[:, 2] = -cur_points_np[:, 2] 312 | cur_color = "orange" 313 | if points_conf_scores_np is not None: 314 | cur_color = get_colors(points_conf_scores_np) 315 | ax.scatter3D(cur_points_np[:, 0], cur_points_np[:, 1], cur_points_np[:, 2], c=cur_color, s=0.5) 316 | 317 | ax = fig.add_subplot(1, 5, 3) 318 | ax.axis('off') 319 | ax.imshow(rgb_pred_np) 320 | 321 | ax = fig.add_subplot(1, 5, 4) 322 | ax.axis('off') 323 | ax.imshow(rgb_gt_np) 324 | 325 | ax = fig.add_subplot(1, 5, 5) 326 | ax.axis('off') 327 | ax.imshow(depth_np) 328 | 329 | fig.suptitle("Point Cloud and RGB\nframe %d, PSNR %.3f, num points %d" % (frame, test_psnr, points_np.shape[0])) 330 | 331 | canvas = fig.canvas 332 | buffer = io.BytesIO() 333 | canvas.print_png(buffer) 334 | data = buffer.getvalue() 335 | 
buffer.write(data) 336 | img = Image.open(buffer) 337 | plt.close() 338 | 339 | return img 340 | 341 | 342 | def get_test_featmap_attn(frame, th, points_np, rgb_pred_np, rgb_gt_np, pt_plot_scale, featmap_np, attn_np, points_conf_scores_np=None): 343 | fig = plt.figure(figsize=(20, 15)) 344 | 345 | ax = fig.add_subplot(3, 5, 1, projection='3d') 346 | ax.axis('off') 347 | ax.view_init(elev=20., azim=90 - th) 348 | ax.set_xlim3d(-pt_plot_scale, pt_plot_scale) 349 | ax.set_ylim3d(-pt_plot_scale, pt_plot_scale) 350 | ax.set_zlim3d(-pt_plot_scale, pt_plot_scale) 351 | ax.set_xlabel('x') 352 | ax.set_ylabel('y') 353 | ax.set_zlabel('z') 354 | if points_conf_scores_np is not None: 355 | cur_color = get_colors(points_conf_scores_np) 356 | ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=cur_color) 357 | 358 | ax = fig.add_subplot(3, 5, 2) 359 | ax.axis('off') 360 | ax.imshow(rgb_gt_np) 361 | 362 | ax = fig.add_subplot(3, 5, 3) 363 | ax.axis('off') 364 | ax.imshow(rgb_pred_np) 365 | 366 | C = featmap_np.shape[-1] 367 | for i in range(min(5, C)): 368 | ax = fig.add_subplot(3, 5, 6 + i) 369 | ax.axis('off') 370 | cur_dim = C * i // 5 371 | cur_min = featmap_np[..., cur_dim].min() 372 | cur_max = featmap_np[..., cur_dim].max() 373 | ax.imshow(featmap_np[..., cur_dim]) 374 | ax.set_title(f'featmap dim {cur_dim}\nmin: %.3f, max: %.3f' % (cur_min, cur_max)) 375 | 376 | K = attn_np.shape[-1] 377 | for i in range(min(K-1, 4)): 378 | ax = fig.add_subplot(3, 5, 11 + i) 379 | ax.axis('off') 380 | cur_dim = K * i // 4 381 | cur_min = attn_np[..., cur_dim].min() 382 | cur_max = attn_np[..., cur_dim].max() 383 | ax.imshow(attn_np[..., cur_dim]) 384 | ax.set_title(f'attn dim {cur_dim}\nmin: %.3f, max: %.3f' % (cur_min, cur_max)) 385 | 386 | ax = fig.add_subplot(3, 5, 15) 387 | ax.axis('off') 388 | cur_min = attn_np[..., -1].min() 389 | cur_max = attn_np[..., -1].max() 390 | ax.imshow(attn_np[..., -1]) 391 | ax.set_title(f'attn dim -1\nmin: %.3f, max: %.3f' % (cur_min, cur_max)) 392 | 393 | fig.suptitle("feature map and attention\nframe %d\n" % (frame)) 394 | 395 | canvas = fig.canvas 396 | buffer = io.BytesIO() 397 | canvas.print_png(buffer) 398 | data = buffer.getvalue() 399 | buffer.write(data) 400 | img = Image.open(buffer) 401 | plt.close() 402 | 403 | return img 404 | 405 | 406 | def resample_shading_codes(shading_codes, args, model, dataset, img_id, loss_fn, step, full_img=False): 407 | if full_img == True: 408 | img, rayd, rayo = dataset.get_full_img(img_id) 409 | c2w = dataset.get_c2w(img_id) 410 | else: 411 | _, _, img, rayd, rayo = dataset[img_id] 412 | c2w = dataset.get_c2w(img_id) 413 | img = torch.from_numpy(img).unsqueeze(0) 414 | rayd = torch.from_numpy(rayd).unsqueeze(0) 415 | rayo = torch.from_numpy(rayo).unsqueeze(0) 416 | 417 | sampled_shading_codes = torch.randn(args.exposure_control.shading_code_num_samples, 418 | args.exposure_control.shading_code_dim, device=model.device) \ 419 | * args.exposure_control.shading_code_scale 420 | 421 | N, H, W, _ = rayd.shape 422 | num_pts, _ = model.points.shape 423 | 424 | rayo = rayo.to(model.device) 425 | rayd = rayd.to(model.device) 426 | img = img.to(model.device) 427 | c2w = c2w.to(model.device) 428 | 429 | topk = min([num_pts, model.select_k]) 430 | 431 | bkg_seq_len_attn = 0 432 | feat_dim = args.models.attn.embed.value.d_ff_out 433 | if model.bkg_feats is not None: 434 | bkg_seq_len_attn = model.bkg_feats.shape[0] 435 | feature_map = torch.zeros(N, H, W, 1, feat_dim).to(model.device) 436 | attn = torch.zeros(N, H, W, topk + 
bkg_seq_len_attn, 1).to(model.device) 437 | 438 | best_idx = 0 439 | best_loss = 1e10 440 | best_loss_idx = 0 441 | best_psnr = 0 442 | best_psnr_idx = 0 443 | 444 | with torch.no_grad(): 445 | for height_start in range(0, H, args.eval.max_height): 446 | for width_start in range(0, W, args.eval.max_width): 447 | height_end = min(height_start + args.eval.max_height, H) 448 | width_end = min(width_start + args.eval.max_width, W) 449 | 450 | feature_map[:, height_start:height_end, width_start:width_end, :, :], \ 451 | attn[:, height_start:height_end, width_start:width_end, :, :] = model.evaluate(rayo, rayd[:, height_start:height_end, width_start:width_end], c2w, step=step) 452 | 453 | for i in range(args.exposure_control.shading_code_num_samples): 454 | torch.cuda.empty_cache() 455 | cur_shading_code = sampled_shading_codes[i] 456 | cur_affine = model.mapping_mlp(cur_shading_code) 457 | cur_affine_dim = cur_affine.shape[-1] 458 | cur_gamma, cur_beta = cur_affine[:cur_affine_dim // 2], cur_affine[cur_affine_dim // 2:] 459 | # print(cur_shading_code.min().item(), cur_shading_code.max().item(), cur_gamma.min().item(), cur_gamma.max().item(), cur_beta.min().item(), cur_beta.max().item()) 460 | 461 | foreground_rgb = model.renderer(feature_map.squeeze(-2).permute(0, 3, 1, 2), gamma=cur_gamma, beta=cur_beta).permute(0, 2, 3, 1).unsqueeze(-2) # (N, H, W, 1, 3) 462 | 463 | if model.bkg_feats is not None: 464 | bkg_attn = attn[..., topk:, :] 465 | if args.models.normalize_topk_attn: 466 | rgb = foreground_rgb * (1 - bkg_attn) + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn 467 | else: 468 | rgb = foreground_rgb + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn 469 | rgb = rgb.squeeze(-2) 470 | else: 471 | rgb = foreground_rgb.squeeze(-2) 472 | 473 | rgb = model.last_act(rgb) 474 | # rgb = torch.clamp(rgb, 0, 1) 475 | 476 | eval_loss = loss_fn(rgb, img) 477 | eval_psnr = -10. * np.log(((rgb - img)**2).mean().item()) / np.log(10.) 
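        # Track the sampled shading code with the lowest loss and the one with the highest PSNR;
        # which of the two is kept is decided by exposure_control.shading_code_resample_select_by below.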
478 | 479 | if eval_loss < best_loss: 480 | best_loss = eval_loss 481 | best_loss_idx = i 482 | 483 | if eval_psnr > best_psnr: 484 | best_psnr = eval_psnr 485 | best_psnr_idx = i 486 | 487 | model.clear_grad() 488 | 489 | # print("Best loss:", best_loss, "Best loss idx:", best_loss_idx, "Best psnr:", best_psnr, "Best psnr idx:", best_psnr_idx) 490 | best_idx = best_loss_idx if args.exposure_control.shading_code_resample_select_by == "loss" else best_psnr_idx 491 | shading_codes[img_id] = sampled_shading_codes[best_idx] 492 | 493 | del rayo, rayd, img, c2w, attn 494 | del eval_loss 495 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import os 6 | import io 7 | import shutil 8 | from PIL import Image 9 | import imageio 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import random 13 | import sys 14 | from utils import * 15 | from dataset import get_dataset, get_loader 16 | from models import get_model, get_loss 17 | import lpips 18 | try: 19 | from skimage.measure import compare_ssim 20 | except: 21 | from skimage.metrics import structural_similarity 22 | 23 | def compare_ssim(gt, img, win_size, channel_axis=2): 24 | return structural_similarity(gt, img, win_size=win_size, channel_axis=channel_axis, data_range=1.0) 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser(description="PAPR") 29 | parser.add_argument('--opt', type=str, default="", help='Option file path') 30 | parser.add_argument('--resume', type=int, default=250000, help='Resume step') 31 | parser.add_argument('--exp', action='store_true', help='[Exposure control] To test with exposure control enabled') 32 | parser.add_argument('--intrp', action='store_true', help='[Exposure control] Interpolation') 33 | parser.add_argument('--random', action='store_true', help='[Exposure control] Random exposure control') 34 | parser.add_argument('--resample', action='store_true', help='[Exposure control] Resample shading codes') 35 | parser.add_argument('--seed', type=int, default=1, help='[Exposure control] Random seed') 36 | parser.add_argument('--view', type=int, default=0, help='[Exposure control] Test frame index') 37 | parser.add_argument('--scale', type=float, default=1.0, help='[Exposure control] Shading code scale') 38 | parser.add_argument('--num_samples', type=int, default=20, help='[Exposure control] Number of samples for random exposure control') 39 | parser.add_argument('--start_index', type=int, default=0, help='[Exposure control] Interpolation start index') 40 | parser.add_argument('--end_index', type=int, default=1, help='[Exposure control] Interpolation end index') 41 | parser.add_argument('--num_intrp', type=int, default=10, help='[Exposure control] Number of interpolations') 42 | return parser.parse_args() 43 | 44 | 45 | def test_step(frame, i, num_frames, model, device, dataset, batch, loss_fn, lpips_loss_fn_alex, lpips_loss_fn_vgg, args, config, 46 | test_losses, test_psnrs, test_ssims, test_lpips_alexs, test_lpips_vggs, resume_step, cur_shading_code=None, suffix=""): 47 | idx, _, img, rayd, rayo = batch 48 | c2w = dataset.get_c2w(idx.squeeze()) 49 | 50 | N, H, W, _ = rayd.shape 51 | num_pts, _ = model.points.shape 52 | 53 | rayo = rayo.to(device) 54 | rayd = rayd.to(device) 55 | img = img.to(device) 56 | c2w = c2w.to(device) 57 | 58 | topk = min([num_pts, model.select_k]) 59 | 
selected_points = torch.zeros(1, H, W, topk, 3) 60 | 61 | bkg_seq_len_attn = 0 62 | feat_dim = args.models.attn.embed.value.d_ff_out 63 | if model.bkg_feats is not None: 64 | bkg_seq_len_attn = model.bkg_feats.shape[0] 65 | feature_map = torch.zeros(N, H, W, 1, feat_dim).to(device) 66 | attn = torch.zeros(N, H, W, topk + bkg_seq_len_attn, 1).to(device) 67 | 68 | with torch.no_grad(): 69 | cur_gamma, cur_beta, code_mean = None, None, 0 70 | if cur_shading_code is not None: 71 | code_mean = cur_shading_code.mean().item() 72 | cur_affine = model.mapping_mlp(cur_shading_code) 73 | cur_affine_dim = cur_affine.shape[-1] 74 | cur_gamma, cur_beta = cur_affine[:cur_affine_dim // 2], cur_affine[cur_affine_dim // 2:] 75 | 76 | for height_start in range(0, H, args.test.max_height): 77 | for width_start in range(0, W, args.test.max_width): 78 | height_end = min(height_start + args.test.max_height, H) 79 | width_end = min(width_start + args.test.max_width, W) 80 | 81 | feature_map[:, height_start:height_end, width_start:width_end, :, :], \ 82 | attn[:, height_start:height_end, width_start:width_end, :, :] = model.evaluate(rayo, rayd[:, height_start:height_end, width_start:width_end], c2w, step=resume_step) 83 | 84 | selected_points[:, height_start:height_end, width_start:width_end, :, :] = model.selected_points 85 | 86 | if args.models.use_renderer: 87 | foreground_rgb = model.renderer(feature_map.squeeze(-2).permute(0, 3, 1, 2), gamma=cur_gamma, beta=cur_beta).permute(0, 2, 3, 1).unsqueeze(-2) # (N, H, W, 1, 3) 88 | else: 89 | foreground_rgb = feature_map 90 | 91 | if model.bkg_feats is not None: 92 | bkg_attn = attn[..., topk:, :] 93 | bkg_mask = (model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn).squeeze() 94 | if args.models.normalize_topk_attn: 95 | rgb = foreground_rgb * (1 - bkg_attn) + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn 96 | else: 97 | rgb = foreground_rgb + model.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn 98 | rgb = rgb.squeeze(-2) 99 | else: 100 | rgb = foreground_rgb.squeeze(-2) 101 | bkg_mask = torch.zeros(N, H, W, 1).to(device) 102 | 103 | rgb = model.last_act(rgb) 104 | rgb = torch.clamp(rgb, 0, 1) 105 | 106 | test_loss = loss_fn(rgb, img) 107 | test_psnr = -10. * np.log(((rgb - img)**2).mean().item()) / np.log(10.) 
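    # SSIM and LPIPS (AlexNet and VGG backbones) are computed below to complement the PSNR above.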
108 | test_ssim = compare_ssim(rgb.squeeze().detach().cpu().numpy(), img.squeeze().detach().cpu().numpy(), 11, channel_axis=2) 109 | test_lpips_alex = lpips_loss_fn_alex(rgb.permute(0, 3, 1, 2), img.permute(0, 3, 1, 2)).squeeze().item() 110 | test_lpips_vgg = lpips_loss_fn_vgg(rgb.permute(0, 3, 1, 2), img.permute(0, 3, 1, 2)).squeeze().item() 111 | 112 | test_losses.append(test_loss.item()) 113 | test_psnrs.append(test_psnr) 114 | test_ssims.append(test_ssim) 115 | test_lpips_alexs.append(test_lpips_alex) 116 | test_lpips_vggs.append(test_lpips_vgg) 117 | 118 | print(f"Test frame: {frame}, code mean: {code_mean}, test_loss: {test_losses[-1]:.4f}, test_psnr: {test_psnrs[-1]:.4f}, test_ssim: {test_ssims[-1]:.4f}, test_lpips_alex: {test_lpips_alexs[-1]:.4f}, test_lpips_vgg: {test_lpips_vggs[-1]:.4f}") 119 | 120 | od = -rayo 121 | D = torch.sum(od * rayo) 122 | dists = torch.abs(torch.sum(selected_points.to(od.device) * od, -1) - D) / torch.norm(od) 123 | if model.bkg_feats is not None: 124 | dists = torch.cat([dists, torch.ones(N, H, W, model.bkg_feats.shape[0]).to(dists.device) * 0], dim=-1) 125 | cur_depth = (torch.sum(attn.squeeze(-1).to(od.device) * dists, dim=-1)).detach().cpu().squeeze().numpy().astype(np.float32) 126 | depth_np = cur_depth.copy() 127 | 128 | if args.test.save_fig: 129 | # To save the rendered images, depth maps, foreground rgb, and background mask 130 | dir_name = "images" 131 | if cur_shading_code is not None: 132 | dir_name = f'exposure_control_{suffix}_scale{config.scale}' if suffix in ['intrp', 'random'] else f'exposure_control_{suffix}' 133 | log_dir = os.path.join(args.save_dir, args.index, 'test', dir_name) 134 | os.makedirs(log_dir, exist_ok=True) 135 | cur_depth /= args.dataset.coord_scale 136 | cur_depth *= (65536 / 10) 137 | cur_depth = cur_depth.astype(np.uint16) 138 | imageio.imwrite(os.path.join(log_dir, "test-{:04d}-{:02d}-predrgb-codeMean{:.4f}-PSNR{:.3f}-SSIM{:.4f}-LPIPSA{:.4f}-LPIPSV{:.4f}.png".format(frame, i, code_mean, test_psnr, test_ssim, test_lpips_alex, test_lpips_vgg)), (rgb.squeeze().detach().cpu().numpy() * 255).astype(np.uint8)) 139 | imageio.imwrite(os.path.join(log_dir, "test-{:04d}-{:02d}-depth-codeMean{:.4f}-PSNR{:.3f}-SSIM{:.4f}-LPIPSA{:.4f}-LPIPSV{:.4f}.png".format(frame, i, code_mean, test_psnr, test_ssim, test_lpips_alex, test_lpips_vgg)), cur_depth) 140 | imageio.imwrite(os.path.join(log_dir, "test-{:04d}-{:02d}-fgrgb-codeMean{:.4f}-PSNR{:.3f}-SSIM{:.4f}-LPIPSA{:.4f}-LPIPSV{:.4f}.png".format(frame, i, code_mean, test_psnr, test_ssim, test_lpips_alex, test_lpips_vgg)), (foreground_rgb.squeeze().clamp(0, 1).detach().cpu().numpy() * 255).astype(np.uint8)) 141 | imageio.imwrite(os.path.join(log_dir, "test-{:04d}-{:02d}-bkgmask-codeMean{:.4f}-PSNR{:.3f}-SSIM{:.4f}-LPIPSA{:.4f}-LPIPSV{:.4f}.png".format(frame, i, code_mean, test_psnr, test_ssim, test_lpips_alex, test_lpips_vgg)), (bkg_mask.detach().cpu().numpy() * 255).astype(np.uint8)) 142 | 143 | plots = {} 144 | 145 | if args.test.save_video: 146 | # To save the rendered videos 147 | coord_scale = args.dataset.coord_scale 148 | if "Barn" in args.dataset.path: 149 | coord_scale *= 1.5 150 | if "Family" in args.dataset.path: 151 | coord_scale *= 0.5 152 | pt_plot_scale = 1.0 * coord_scale 153 | 154 | plot_opt = args.test.plots 155 | th = -frame * (360. 
/ num_frames) 156 | azims = np.linspace(180, -180, num_frames) 157 | azmin = azims[frame] 158 | 159 | points_np = model.points.detach().cpu().numpy() 160 | rgb_pred_np = rgb.squeeze().detach().cpu().numpy().astype(np.float32) 161 | rgb_gt_np = img.squeeze().detach().cpu().numpy().astype(np.float32) 162 | points_influ_scores_np = None 163 | if model.points_influ_scores is not None: 164 | points_influ_scores_np = model.points_influ_scores.squeeze().detach().cpu().numpy() 165 | 166 | if plot_opt.pcrgb: 167 | pcrgb_plot = get_test_pcrgb(frame, th, azmin, test_psnr, points_np, 168 | rgb_pred_np, rgb_gt_np, depth_np, pt_plot_scale, points_influ_scores_np) 169 | plots["pcrgb"] = pcrgb_plot 170 | 171 | if plot_opt.featattn: # Note that these plots are not necessarily meaningful since each ray has different top K points 172 | featmap_np = feature_map[0].squeeze().detach().cpu().numpy().astype(np.float32) 173 | attn_np = attn[0].squeeze().detach().cpu().numpy().astype(np.float32) 174 | featattn_plot = get_test_featmap_attn(frame, th, points_np, rgb_pred_np, rgb_gt_np, 175 | pt_plot_scale, featmap_np, attn_np, points_influ_scores_np) 176 | plots["featattn"] = featattn_plot 177 | 178 | return plots 179 | 180 | 181 | def test(model, device, dataset, save_name, args, config, resume_step, shading_codes=None): 182 | testloader = get_loader(dataset, args.dataset, mode="test") 183 | print("testloader:", testloader) 184 | 185 | loss_fn = get_loss(args.training.losses) 186 | loss_fn = loss_fn.to(device) 187 | 188 | lpips_loss_fn_alex = lpips.LPIPS(net='alex', version='0.1') 189 | lpips_loss_fn_alex = lpips_loss_fn_alex.to(device) 190 | lpips_loss_fn_vgg = lpips.LPIPS(net='vgg', version='0.1') 191 | lpips_loss_fn_vgg = lpips_loss_fn_vgg.to(device) 192 | 193 | test_losses = [] 194 | test_psnrs = [] 195 | test_ssims = [] 196 | test_lpips_alexs = [] 197 | test_lpips_vggs = [] 198 | 199 | frames = {} 200 | 201 | if config.exp: # test with exposure control, the model needs to be finetuned with exposure control first 202 | if config.random: 203 | suffix = "random" 204 | for frame, batch in enumerate(testloader): 205 | if frame != config.view: 206 | continue 207 | 208 | for i in range(config.num_samples): 209 | print("test seed:", config.seed, "i:", i) 210 | shading_codes = torch.randn(1, args.exposure_control.shading_code_dim, device=device) * config.scale 211 | plots = test_step(frame, i, len(testloader), model, device, dataset, batch, loss_fn, lpips_loss_fn_alex, 212 | lpips_loss_fn_vgg, args, config, test_losses, test_psnrs, test_ssims, test_lpips_alexs, 213 | test_lpips_vggs, resume_step, shading_codes, suffix) 214 | 215 | elif config.intrp: 216 | suffix = "intrp" 217 | latent_codes = [] 218 | ids = [config.start_index, config.end_index] 219 | for i in range(config.num_samples): 220 | print("test seed:", config.seed, "i:", i) 221 | shading_codes = torch.randn(1, args.exposure_control.shading_code_dim, device=device) * config.scale 222 | 223 | if i in ids: 224 | latent_codes.append(shading_codes) 225 | 226 | interpolated_codes = [] 227 | for j in range(config.num_intrp): 228 | interpolated_codes.append(latent_codes[0] + (latent_codes[1] - latent_codes[0]) * (j + 1) / config.num_intrp) 229 | 230 | frames = {} 231 | for frame, batch in enumerate(testloader): 232 | if frame != config.view: 233 | continue 234 | 235 | for i in range(config.num_intrp): 236 | shading_codes = interpolated_codes[i] 237 | plots = test_step(frame, i, len(testloader), model, device, dataset, batch, loss_fn, lpips_loss_fn_alex, 238 | 
lpips_loss_fn_vgg, args, config, test_losses, test_psnrs, test_ssims, test_lpips_alexs, 239 | test_lpips_vggs, resume_step, shading_codes, suffix) 240 | 241 | else: 242 | suffix = "test" 243 | shading_code = torch.randn(args.exposure_control.shading_code_dim, device=device) * config.scale 244 | for frame, batch in enumerate(testloader): 245 | # shading_code = shading_codes[frame] 246 | plots = test_step(frame, 0, len(testloader), model, device, dataset, batch, loss_fn, lpips_loss_fn_alex, 247 | lpips_loss_fn_vgg, args, config, test_losses, test_psnrs, test_ssims, test_lpips_alexs, 248 | test_lpips_vggs, resume_step, shading_code, suffix) 249 | 250 | if plots: 251 | for key, value in plots.items(): 252 | if key not in frames: 253 | frames[key] = [] 254 | frames[key].append(value) 255 | 256 | else: # test without exposure control 257 | for frame, batch in enumerate(testloader): 258 | plots = test_step(frame, 0, len(testloader), model, device, dataset, batch, loss_fn, lpips_loss_fn_alex, 259 | lpips_loss_fn_vgg, args, config, test_losses, test_psnrs, test_ssims, test_lpips_alexs, 260 | test_lpips_vggs, resume_step) 261 | 262 | if plots: 263 | for key, value in plots.items(): 264 | if key not in frames: 265 | frames[key] = [] 266 | frames[key].append(value) 267 | 268 | test_loss = np.mean(test_losses) 269 | test_psnr = np.mean(test_psnrs) 270 | test_ssim = np.mean(test_ssims) 271 | test_lpips_alex = np.mean(test_lpips_alexs) 272 | test_lpips_vgg = np.mean(test_lpips_vggs) 273 | 274 | if frames: 275 | for key, value in frames.items(): 276 | name = f"{args.index}-PSNR{test_psnr:.3f}-SSIM{test_ssim:.4f}-LPIPSA{test_lpips_alex:.4f}-LPIPSV{test_lpips_vgg:.4f}-{key}-{save_name}-step{resume_step}.mp4" 277 | # In case the name is too long 278 | name = name[-255:] if len(name) > 255 else name 279 | log_dir = os.path.join(args.save_dir, args.index, 'test', 'videos') 280 | os.makedirs(log_dir, exist_ok=True) 281 | f = os.path.join(log_dir, name) 282 | imageio.mimwrite(f, value, fps=30, quality=10) 283 | 284 | print(f"Avg test loss: {test_loss:.4f}, test PSNR: {test_psnr:.4f}, test SSIM: {test_ssim:.4f}, test LPIPS Alex: {test_lpips_alex:.4f}, test LPIPS VGG: {test_lpips_vgg:.4f}") 285 | 286 | 287 | def main(config, args, save_name, mode, resume_step=0): 288 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 289 | model = get_model(args, device) 290 | dataset = get_dataset(args.dataset, mode=mode) 291 | 292 | if args.test.load_path: 293 | try: 294 | model_state_dict = torch.load(args.test.load_path) 295 | for step, state_dict in model_state_dict.items(): 296 | resume_step = int(step) 297 | if config.exp: 298 | model.train_shading_codes = nn.Parameter(torch.zeros_like(state_dict['train_shading_codes']), requires_grad=False) 299 | model.eval_shading_codes = nn.Parameter(torch.zeros_like(state_dict['eval_shading_codes']), requires_grad=False) 300 | model.load_my_state_dict(state_dict) 301 | except: 302 | model_state_dict = torch.load(os.path.join(args.save_dir, args.test.load_path, "model.pth")) 303 | for step, state_dict in model_state_dict.items(): 304 | resume_step = step 305 | if config.exp: 306 | model.train_shading_codes = nn.Parameter(torch.zeros_like(state_dict['train_shading_codes']), requires_grad=False) 307 | model.eval_shading_codes = nn.Parameter(torch.zeros_like(state_dict['eval_shading_codes']), requires_grad=False) 308 | model.load_my_state_dict(state_dict) 309 | print("!!!!! 
Loaded model from %s at step %s" % (args.test.load_path, resume_step)) 310 | else: 311 | try: 312 | model_state_dict = torch.load(os.path.join(args.save_dir, args.index, "model.pth")) 313 | for step, state_dict in model_state_dict.items(): 314 | resume_step = int(step) 315 | if config.exp: 316 | model.train_shading_codes = nn.Parameter(torch.zeros_like(state_dict['train_shading_codes']), requires_grad=False) 317 | model.eval_shading_codes = nn.Parameter(torch.zeros_like(state_dict['eval_shading_codes']), requires_grad=False) 318 | model.load_my_state_dict(state_dict) 319 | except: 320 | state_dict = torch.load(os.path.join(args.save_dir, args.index, f"model_{resume_step}.pth")) 321 | if config.exp: 322 | model.train_shading_codes = nn.Parameter(torch.zeros_like(state_dict['train_shading_codes']), requires_grad=False) 323 | model.eval_shading_codes = nn.Parameter(torch.zeros_like(state_dict['eval_shading_codes']), requires_grad=False) 324 | model.load_my_state_dict(state_dict) 325 | print("!!!!! Loaded model from %s at step %s" % (os.path.join(args.save_dir, args.index), resume_step)) 326 | 327 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 328 | model = model.to(device) 329 | 330 | shading_codes = None 331 | if config.exp: 332 | if mode == 'train': 333 | shading_codes = model.train_shading_codes 334 | print("Using train shading_codes:", shading_codes.shape, shading_codes.min(), shading_codes.max()) 335 | elif mode == 'test': 336 | shading_codes = model.eval_shading_codes 337 | print("Using eval shading_codes:", shading_codes.shape, shading_codes.min(), shading_codes.max()) 338 | else: 339 | raise NotImplementedError 340 | 341 | test(model, device, dataset, save_name, args, config, resume_step, shading_codes) 342 | 343 | 344 | if __name__ == '__main__': 345 | 346 | with open("configs/default.yml", 'r') as f: 347 | default_config = yaml.safe_load(f) 348 | 349 | args = parse_args() 350 | if args.intrp or args.random: assert args.exp, "You need to turn on the exposure control (--exp) for exposure interpolation or generating images with random exposure levels." 351 | assert not args.intrp or not args.random, "Cannot do exposure interpolation and random exposure generation at the same time."
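    # Load the per-scene config and merge it over the defaults, mirroring the setup in train.py.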
352 | with open(args.opt, 'r') as f: 353 | config = yaml.safe_load(f) 354 | 355 | test_config = copy.deepcopy(default_config) 356 | update_dict(test_config, config) 357 | 358 | resume_step = args.resume 359 | 360 | log_dir = os.path.join(test_config["save_dir"], test_config['index']) 361 | os.makedirs(log_dir, exist_ok=True) 362 | 363 | sys.stdout = Logger(os.path.join(log_dir, 'test.log'), sys.stdout) 364 | sys.stderr = Logger(os.path.join(log_dir, 'test_error.log'), sys.stderr) 365 | 366 | shutil.copyfile(__file__, os.path.join(log_dir, os.path.basename(__file__))) 367 | shutil.copyfile(args.opt, os.path.join(log_dir, os.path.basename(args.opt))) 368 | 369 | setup_seed(test_config['seed']) 370 | 371 | for i, dataset in enumerate(test_config['test']['datasets']): 372 | name = dataset['name'] 373 | mode = dataset['mode'] 374 | print(name, dataset) 375 | test_config['dataset'].update(dataset) 376 | test_config = DictAsMember(test_config) 377 | 378 | if args.exp: 379 | assert test_config.models.use_renderer, "Currently only support using renderer for exposure control" 380 | 381 | main(args, test_config, name, mode, resume_step) 382 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import os 6 | import numpy as np 7 | from .utils import normalize_vector, create_learning_rate_fn, add_points_knn, activation_func 8 | from .mlp import get_mapping_mlp 9 | from .attn import get_proximity_attention_layer 10 | from .renderer import get_generator 11 | 12 | 13 | def count_parameters(model): 14 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 15 | 16 | 17 | class PAPR(nn.Module): 18 | def __init__(self, args, device='cuda'): 19 | super(PAPR, self).__init__() 20 | self.args = args 21 | self.eps = args.eps 22 | self.device = device 23 | 24 | self.use_amp = args.use_amp 25 | self.amp_dtype = torch.float16 if args.amp_dtype == 'float16' else torch.bfloat16 26 | self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp) 27 | 28 | point_opt = args.geoms.points 29 | pc_feat_opt = args.geoms.point_feats 30 | bkg_feat_opt = args.geoms.background 31 | exposure_opt = args.exposure_control 32 | self.exposure_opt = exposure_opt 33 | 34 | self.register_buffer('select_k', torch.tensor( 35 | point_opt.select_k, device=device, dtype=torch.int32)) 36 | 37 | self.coord_scale = args.dataset.coord_scale 38 | 39 | if point_opt.load_path: 40 | if point_opt.load_path.endswith('.pth') or point_opt.load_path.endswith('.pt'): 41 | points = torch.load(point_opt.load_path, map_location='cpu') 42 | points = np.asarray(points).astype(np.float32) 43 | np.random.shuffle(points) 44 | points = points[:args.max_num_pts, :] 45 | points = torch.from_numpy(points).float() 46 | print("Loaded points from {}, shape: {}, dtype {}".format(point_opt.load_path, points.shape, points.dtype)) 47 | print("Loaded points scale: ", points[:, 0].min(), points[:, 0].max(), points[:, 1].min(), points[:, 1].max(), points[:, 2].min(), points[:, 2].max()) 48 | else: 49 | # Initialize point positions 50 | pt_init_center = [i * self.coord_scale for i in point_opt.init_center] 51 | pt_init_scale = [i * self.coord_scale for i in point_opt.init_scale] 52 | if point_opt.init_type == 'sphere': # initial points on a sphere 53 | points = self._sphere_pc(pt_init_center, point_opt.init_num, pt_init_scale) 54 | elif 
point_opt.init_type == 'cube': # initial points in a cube 55 | points = self._cube_normal_pc(pt_init_center, point_opt.init_num, pt_init_scale) 56 | else: 57 | raise NotImplementedError("Point init type [{:s}] is not found".format(point_opt.init_type)) 58 | print("Initialized points scale: ", points[:, 0].min(), points[:, 0].max(), points[:, 1].min(), points[:, 1].max(), points[:, 2].min(), points[:, 2].max()) 59 | self.points = torch.nn.Parameter(points, requires_grad=True) 60 | 61 | # Initialize point influence scores 62 | self.points_influ_scores = torch.nn.Parameter(torch.ones( 63 | points.shape[0], 1, device=device) * point_opt.influ_init_val, requires_grad=True) 64 | 65 | # Initialize mapping MLP, only if fine-tuning with IMLE for the exposure control 66 | self.mapping_mlp = None 67 | if exposure_opt.use: 68 | self.mapping_mlp = get_mapping_mlp(exposure_opt, use_amp=self.use_amp, amp_dtype=self.amp_dtype) 69 | 70 | # Initialize UNet 71 | if args.models.use_renderer: 72 | attn_opt = args.models.attn 73 | feat_dim = attn_opt.embed.value.d_ff_out 74 | self.renderer = get_generator(args.models.renderer.generator, in_c=feat_dim, 75 | out_c=3, use_amp=self.use_amp, amp_dtype=self.amp_dtype) 76 | print("Number of parameters of renderer: ", count_parameters(self.renderer)) 77 | else: 78 | assert args.models.attn.embed.value.d_ff_out == 3, \ 79 | "Value embedding MLP should have output dim 3 if not using renderer" 80 | 81 | # Initialize background score and features 82 | self.bkg_feats = nn.Parameter(torch.FloatTensor(bkg_feat_opt.init_color)[None, :], requires_grad=bkg_feat_opt.learnable) 83 | self.bkg_score = torch.tensor(bkg_feat_opt.constant, device=device, dtype=torch.float32).reshape(1) 84 | 85 | # Initialize point features 86 | self.use_pc_feats = pc_feat_opt.use_ink or pc_feat_opt.use_inq or pc_feat_opt.use_inv 87 | if self.use_pc_feats: 88 | self.pc_feats = nn.Parameter(torch.randn(points.shape[0], pc_feat_opt.dim), requires_grad=True) 89 | print("Point features: ", self.pc_feats.shape, self.pc_feats.min(), self.pc_feats.max(), self.pc_feats.mean(), self.pc_feats.std()) 90 | 91 | v_extra_dim = 0 92 | k_extra_dim = 0 93 | q_extra_dim = 0 94 | if pc_feat_opt.use_inv: 95 | v_extra_dim = self.pc_feats.shape[-1] 96 | print("Using v_extra_dim: ", v_extra_dim) 97 | if pc_feat_opt.use_ink: 98 | k_extra_dim = self.pc_feats.shape[-1] 99 | print("Using k_extra_dim: ", k_extra_dim) 100 | if pc_feat_opt.use_inq: 101 | q_extra_dim = self.pc_feats.shape[-1] 102 | print("Using q_extra_dim: ", q_extra_dim) 103 | 104 | self.last_act = activation_func(args.models.last_act) 105 | 106 | # Initialize proximity attention layer 107 | self.proximity_attn = get_proximity_attention_layer(args.models.attn, 108 | v_extra_dim=v_extra_dim, 109 | k_extra_dim=k_extra_dim, 110 | q_extra_dim=q_extra_dim, 111 | eps=self.eps, 112 | use_amp=self.use_amp, 113 | amp_dtype=self.amp_dtype) 114 | 115 | self.init_optimizers(total_steps=0) 116 | 117 | def init_optimizers(self, total_steps): 118 | lr_opt = self.args.training.lr 119 | print("LR factor: ", lr_opt.lr_factor) 120 | optimizer_points = torch.optim.Adam([self.points], lr=lr_opt.points.base_lr * lr_opt.lr_factor) 121 | optimizer_attn = torch.optim.Adam(self.proximity_attn.parameters(), lr=lr_opt.attn.base_lr * lr_opt.lr_factor, weight_decay=lr_opt.attn.weight_decay) 122 | optimizer_points_influ_scores = torch.optim.Adam([self.points_influ_scores], lr=lr_opt.points_influ_scores.base_lr * lr_opt.lr_factor, weight_decay=lr_opt.points_influ_scores.weight_decay) 123 | 124 | 
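        # Build one learning-rate scheduler per optimizer from the training.lr.* schedule settings.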
debug = False 125 | lr_scheduler_points = create_learning_rate_fn(optimizer_points, self.args.training.steps, lr_opt.points, debug=debug) 126 | lr_scheduler_attn = create_learning_rate_fn(optimizer_attn, self.args.training.steps, lr_opt.attn, debug=debug) 127 | lr_scheduler_points_influ_scores = create_learning_rate_fn(optimizer_points_influ_scores, self.args.training.steps, lr_opt.points_influ_scores, debug=debug) 128 | 129 | self.optimizers = { 130 | "points": optimizer_points, 131 | "attn": optimizer_attn, 132 | "points_influ_scores": optimizer_points_influ_scores, 133 | } 134 | 135 | self.schedulers = { 136 | "points": lr_scheduler_points, 137 | "attn": lr_scheduler_attn, 138 | "points_influ_scores": lr_scheduler_points_influ_scores, 139 | } 140 | 141 | if self.use_pc_feats: 142 | optimizer_pc_feats = torch.optim.Adam([self.pc_feats], lr=lr_opt.feats.base_lr * lr_opt.lr_factor, weight_decay=lr_opt.feats.weight_decay) 143 | lr_scheduler_pc_feats = create_learning_rate_fn(optimizer_pc_feats, self.args.training.steps, lr_opt.feats, debug=debug) 144 | 145 | self.optimizers["pc_feats"] = optimizer_pc_feats 146 | self.schedulers["pc_feats"] = lr_scheduler_pc_feats 147 | 148 | if self.mapping_mlp is not None: 149 | optimizer_mapping_mlp = torch.optim.Adam(self.mapping_mlp.parameters(), lr=lr_opt.mapping_mlp.base_lr * lr_opt.lr_factor, weight_decay=lr_opt.mapping_mlp.weight_decay) 150 | lr_scheduler_mapping_mlp = create_learning_rate_fn(optimizer_mapping_mlp, self.args.training.steps, lr_opt.mapping_mlp, debug=debug) 151 | 152 | self.optimizers["mapping_mlp"] = optimizer_mapping_mlp 153 | self.schedulers["mapping_mlp"] = lr_scheduler_mapping_mlp 154 | 155 | if self.args.models.use_renderer: 156 | optimizer_renderer = torch.optim.Adam(self.renderer.parameters(), lr=lr_opt.generator.base_lr * lr_opt.lr_factor, weight_decay=lr_opt.generator.weight_decay) 157 | lr_scheduler_renderer = create_learning_rate_fn(optimizer_renderer, self.args.training.steps, lr_opt.generator, debug=debug) 158 | 159 | self.optimizers["renderer"] = optimizer_renderer 160 | self.schedulers["renderer"] = lr_scheduler_renderer 161 | 162 | if self.bkg_feats is not None and self.args.geoms.background.learnable: 163 | optimizer_bkg_feats = torch.optim.Adam([self.bkg_feats], lr=lr_opt.bkg_feats.base_lr * lr_opt.lr_factor, weight_decay=lr_opt.bkg_feats.weight_decay) 164 | lr_scheduler_bkg_feats = create_learning_rate_fn(optimizer_bkg_feats, self.args.training.steps, lr_opt.bkg_feats, debug=debug) 165 | 166 | self.optimizers["bkg_feats"] = optimizer_bkg_feats 167 | self.schedulers["bkg_feats"] = lr_scheduler_bkg_feats 168 | 169 | for name in self.args.training.fix_keys: 170 | if name in self.optimizers: 171 | print("Fixing {}".format(name)) 172 | self.optimizers.pop(name) 173 | self.schedulers.pop(name) 174 | 175 | if total_steps > 0: 176 | for _, scheduler in self.schedulers.items(): 177 | if scheduler is not None: 178 | for _ in range(total_steps): 179 | scheduler.step() 180 | 181 | def clear_optimizer(self): 182 | self.optimizers.clear() 183 | del self.optimizers 184 | 185 | def clear_scheduler(self): 186 | self.schedulers.clear() 187 | del self.schedulers 188 | 189 | def clear_grad(self): 190 | for _, optimizer in self.optimizers.items(): 191 | if optimizer is not None: 192 | optimizer.zero_grad() 193 | 194 | def _sphere_pc(self, center, num_pts, scale): 195 | xs, ys, zs = [], [], [] 196 | phi = math.pi * (3. 
- math.sqrt(5.)) 197 | for i in range(num_pts): 198 | y = 1 - (i / float(num_pts - 1)) * 2 199 | radius = math.sqrt(1 - y * y) 200 | theta = phi * i 201 | x = math.cos(theta) * radius 202 | z = math.sin(theta) * radius 203 | xs.append(x * scale[0] + center[0]) 204 | ys.append(y * scale[1] + center[1]) 205 | zs.append(z * scale[2] + center[2]) 206 | points = np.stack([np.array(xs), np.array(ys), np.array(zs)], axis=-1) 207 | return torch.from_numpy(points).float() 208 | 209 | def _semi_sphere_pc(self, center, num_pts, scale, flatten="-z", flatten_coord=0.0): 210 | xs, ys, zs = [], [], [] 211 | phi = math.pi * (3. - math.sqrt(5.)) 212 | for i in range(num_pts): 213 | y = 1 - (i / float(num_pts - 1)) * 2 214 | radius = math.sqrt(1 - y * y) 215 | theta = phi * i 216 | x = math.cos(theta) * radius 217 | z = math.sin(theta) * radius 218 | xs.append(x * scale[0] + center[0]) 219 | ys.append(y * scale[1] + center[1]) 220 | zs.append(z * scale[2] + center[2]) 221 | points = np.stack([np.array(xs), np.array(ys), np.array(zs)], axis=-1) 222 | points = torch.from_numpy(points).float() 223 | if flatten == "-z": 224 | points[:, 2] = torch.clamp(points[:, 2], min=flatten_coord) 225 | elif flatten == "+z": 226 | points[:, 2] = torch.clamp(points[:, 2], max=flatten_coord) 227 | elif flatten == "-y": 228 | points[:, 1] = torch.clamp(points[:, 1], min=flatten_coord) 229 | elif flatten == "+y": 230 | points[:, 1] = torch.clamp(points[:, 1], max=flatten_coord) 231 | elif flatten == "-x": 232 | points[:, 0] = torch.clamp(points[:, 0], min=flatten_coord) 233 | elif flatten == "+x": 234 | points[:, 0] = torch.clamp(points[:, 0], max=flatten_coord) 235 | else: 236 | raise ValueError("Invalid flatten type") 237 | return points 238 | 239 | def _cube_pc(self, center, num_pts, scale): 240 | xs = np.random.uniform(-scale[0], scale[0], num_pts) + center[0] 241 | ys = np.random.uniform(-scale[1], scale[1], num_pts) + center[1] 242 | zs = np.random.uniform(-scale[2], scale[2], num_pts) + center[2] 243 | points = np.stack([np.array(xs), np.array(ys), np.array(zs)], axis=-1) 244 | return torch.from_numpy(points).float() 245 | 246 | def _cube_normal_pc(self, center, num_pts, scale): 247 | axis_num_pts = int(num_pts ** (1.0 / 3.0)) 248 | xs = np.linspace(-scale[0], scale[0], axis_num_pts) + center[0] 249 | ys = np.linspace(-scale[1], scale[1], axis_num_pts) + center[1] 250 | zs = np.linspace(-scale[2], scale[2], axis_num_pts) + center[2] 251 | points = np.array([[i, j, k] for i in xs for j in ys for k in zs]) 252 | rest_num_pts = num_pts - points.shape[0] 253 | if rest_num_pts > 0: 254 | rest_points = self._cube_pc(center, rest_num_pts, scale) 255 | points = np.concatenate([points, rest_points], axis=0) 256 | return torch.from_numpy(points).float() 257 | 258 | def _calculate_global_distances(self, rays_o, rays_d, points): 259 | """ 260 | Select the top-k points with the smallest distance to the rays from all points 261 | 262 | Args: 263 | rays_o: (N, 3) 264 | rays_d: (N, H, W, 3) 265 | points: (num_pts, 3) 266 | Returns: 267 | select_k_ind: (N, H, W, select_k) 268 | """ 269 | N, H, W, _ = rays_d.shape 270 | num_pts, _ = points.shape 271 | 272 | rays_d = rays_d.unsqueeze(-2) # (N, H, W, 1, 3) 273 | rays_o = rays_o.reshape(N, 1, 1, 1, 3) 274 | points = points.reshape(1, 1, 1, num_pts, 3) 275 | 276 | v = points - rays_o # (N, 1, 1, num_pts, 3) 277 | proj = rays_d * (torch.sum(v * rays_d, dim=-1) / (torch.sum(rays_d * rays_d, dim=-1) + self.eps)).unsqueeze(-1) 278 | D = v - proj # (N, H, W, num_pts, 3) 279 | feature = 
torch.norm(D, dim=-1) 280 | 281 | _, select_k_ind = feature.topk(self.select_k, dim=-1, largest=False, sorted=False) # (N, H, W, select_k) 282 | 283 | return select_k_ind 284 | 285 | def _calculate_distances(self, rays_o, rays_d, points, c2w): 286 | """ 287 | Calculate the distances from the top-k selected points to the rays. TODO: redundant with _calculate_global_distances 288 | 289 | Args: 290 | rays_o: (N, 3) 291 | rays_d: (N, H, W, 3) 292 | points: (N, H, W, select_k, 3) 293 | c2w: (N, 4, 4) 294 | Returns: 295 | proj_dists: (N, H, W, select_k, 1) 296 | dists_to_rays: (N, H, W, select_k, 1) 297 | proj: (N, H, W, select_k, 3) # the vector s in Figure 2 298 | D: (N, H, W, select_k, 3) # the vector t in Figure 2 299 | """ 300 | N, H, W, _ = rays_d.shape 301 | 302 | rays = normalize_vector(rays_d, eps=self.eps).unsqueeze(-2) # (N, H, W, 1, 3) 303 | v = points - rays_o.reshape(N, 1, 1, 1, 3) # (N, H, W, select_k, 3) 304 | proj = rays * (torch.sum(v * rays, dim=-1) / (torch.sum(rays * rays, dim=-1) + self.eps)).unsqueeze(-1) 305 | D = v - proj # (N, H, W, select_k, 3) 306 | 307 | dists_to_rays = torch.norm(D, dim=-1, keepdim=True) 308 | proj_dists = torch.norm(proj, dim=-1, keepdim=True) 309 | 310 | return proj_dists, dists_to_rays, proj, D 311 | 312 | def _get_points(self, rays_o, rays_d, c2w, step=-1): 313 | """ 314 | Select the top-k points with the smallest distance to the rays 315 | 316 | Args: 317 | rays_o: (N, 3) 318 | rays_d: (N, H, W, 3) 319 | c2w: (N, 4, 4) 320 | Returns: 321 | selected_points: (N, H, W, select_k, 3) 322 | select_k_ind: (N, H, W, select_k) 323 | """ 324 | points = self.points 325 | N, H, W, _ = rays_d.shape 326 | if self.select_k >= points.shape[0] or self.select_k < 0: 327 | select_k_ind = torch.arange(points.shape[0], device=points.device).expand(N, H, W, -1) 328 | else: 329 | select_k_ind = self._calculate_global_distances(rays_o, rays_d, points) # (N, H, W, select_k) 330 | selected_points = points[select_k_ind, :] # (N, H, W, select_k, 3) 331 | self.selected_points = selected_points 332 | 333 | return selected_points, select_k_ind 334 | 335 | def prune_points(self, thresh): 336 | if self.points_influ_scores is not None: 337 | if self.args.training.prune_type == '<': 338 | mask = (self.points_influ_scores[:, 0] > thresh) 339 | elif self.args.training.prune_type == '>': 340 | mask = (self.points_influ_scores[:, 0] < thresh) 341 | print( 342 | "@@@@@@@@@ pruned {}/{}".format(torch.sum(mask == 0), mask.shape[0])) 343 | 344 | cur_requires_grad = self.points.requires_grad 345 | self.points = nn.Parameter(self.points[mask, :], requires_grad=cur_requires_grad) 346 | print("@@@@@@@@@ New points: ", self.points.shape) 347 | 348 | cur_requires_grad = self.points_influ_scores.requires_grad 349 | self.points_influ_scores = nn.Parameter(self.points_influ_scores[mask, :], requires_grad=cur_requires_grad) 350 | print("@@@@@@@@@ New points_influ_scores: ", self.points_influ_scores.shape) 351 | 352 | if self.use_pc_feats: 353 | cur_requires_grad = self.pc_feats.requires_grad 354 | self.pc_feats = nn.Parameter(self.pc_feats[mask, :], requires_grad=cur_requires_grad) 355 | print("@@@@@@@@@ New pc_feats: ", self.pc_feats.shape) 356 | 357 | return torch.sum(mask == 0) 358 | return 0 359 | 360 | def add_points(self, add_num): 361 | points = self.points.detach().cpu() 362 | point_features = None 363 | cur_num_points = points.shape[0] 364 | 365 | if 'max_points' in self.args and self.args.max_points > 0 and (cur_num_points + add_num) >= self.args.max_points: 366 | add_num = self.args.max_points -
cur_num_points 367 | if add_num <= 0: 368 | return 0 369 | 370 | if self.use_pc_feats: 371 | point_features = self.pc_feats.detach().cpu() 372 | 373 | new_points, num_new_points, new_influ_scores, new_point_features = add_points_knn(points, self.points_influ_scores.detach().cpu(), add_num=add_num, 374 | k=self.args.geoms.points.add_k, comb_type=self.args.geoms.points.add_type, 375 | sample_k=self.args.geoms.points.add_sample_k, sample_type=self.args.geoms.points.add_sample_type, 376 | point_features=point_features) 377 | print("@@@@@@@@@ added {} points".format(num_new_points)) 378 | 379 | if num_new_points > 0: 380 | cur_requires_grad = self.points.requires_grad 381 | self.points = nn.Parameter(torch.cat([points, new_points], dim=0).to(self.points.device), requires_grad=cur_requires_grad) 382 | print("@@@@@@@@@ New points: ", self.points.shape) 383 | 384 | if self.points_influ_scores is not None: 385 | cur_requires_grad = self.points_influ_scores.requires_grad 386 | self.points_influ_scores = nn.Parameter(torch.cat([self.points_influ_scores, new_influ_scores.to(self.points_influ_scores.device)], dim=0), requires_grad=cur_requires_grad) 387 | print("@@@@@@@@@ New points_influ_scores: ", self.points_influ_scores.shape) 388 | 389 | if self.use_pc_feats: 390 | cur_requires_grad = self.pc_feats.requires_grad 391 | self.pc_feats = nn.Parameter(torch.cat([self.pc_feats, new_point_features.to(self.pc_feats.device)], dim=0), requires_grad=cur_requires_grad) 392 | print("@@@@@@@@@ New pc_feats: ", self.pc_feats.shape) 393 | 394 | return num_new_points 395 | 396 | def _get_kqv(self, rays_o, rays_d, points, c2w, select_k_ind, step=-1): 397 | """ 398 | Get the key, query, value for the proximity attention layer(s) 399 | """ 400 | _, _, vec_p2o, vec_p2r = self._calculate_distances(rays_o, rays_d, points, c2w) 401 | 402 | k_type = self.args.models.attn.k_type 403 | k_L = self.args.models.attn.embed.k_L 404 | if k_type == 1: 405 | key = [points.detach(), vec_p2o, vec_p2r] 406 | else: 407 | raise ValueError('Invalid key type') 408 | assert len(key) == (len(k_L)) 409 | 410 | q_type = self.args.models.attn.q_type 411 | q_L = self.args.models.attn.embed.q_L 412 | if q_type == 1: 413 | query = [rays_d.unsqueeze(-2)] 414 | else: 415 | raise ValueError('Invalid query type') 416 | assert len(query) == (len(q_L)) 417 | 418 | v_type = self.args.models.attn.v_type 419 | v_L = self.args.models.attn.embed.v_L 420 | if v_type == 1: 421 | value = [vec_p2o, vec_p2r] 422 | else: 423 | raise ValueError('Invalid value type') 424 | assert len(value) == (len(v_L)) 425 | 426 | # Add extra features that won't be passed through positional encoding 427 | k_extra = None 428 | q_extra = None 429 | v_extra = None 430 | if self.args.geoms.point_feats.use_ink: 431 | k_extra = [self.pc_feats[select_k_ind, :]] 432 | if self.args.geoms.point_feats.use_inq: 433 | q_extra = [self.pc_feats[select_k_ind, :]] 434 | if self.args.geoms.point_feats.use_inv: 435 | v_extra = [self.pc_feats[select_k_ind, :]] 436 | 437 | return key, query, value, k_extra, q_extra, v_extra 438 | 439 | def step(self, step=-1): 440 | for _, optimizer in self.optimizers.items(): 441 | if optimizer is not None: 442 | self.scaler.step(optimizer) 443 | 444 | for _, scheduler in self.schedulers.items(): 445 | if scheduler is not None: 446 | scheduler.step() 447 | 448 | self.attn_lr = 0 449 | if 'attn' in self.optimizers: 450 | if self.schedulers['attn'] is not None: 451 | self.attn_lr = self.schedulers['attn'].get_last_lr()[0] 452 | else: 453 | self.attn_lr = 
self.optimizers['attn'].param_groups[0]['lr'] 454 | 455 | self.pts_lr = 0 456 | if 'points' in self.optimizers: 457 | if self.schedulers['points'] is not None: 458 | self.pts_lr = self.schedulers['points'].get_last_lr()[0] 459 | else: 460 | self.pts_lr = self.optimizers['points'].param_groups[0]['lr'] 461 | 462 | def evaluate(self, rays_o, rays_d, c2w, step=-1, shading_code=None): 463 | points, select_k_ind = self._get_points(rays_o, rays_d, c2w, step) 464 | self.select_k_ind = select_k_ind 465 | key, query, value, k_extra, q_extra, v_extra = self._get_kqv(rays_o, rays_d, points, c2w, select_k_ind, step) 466 | N, H, W, _ = rays_d.shape 467 | num_pts = points.shape[-2] 468 | 469 | cur_points_influ_score = self.points_influ_scores[select_k_ind] if self.points_influ_scores is not None else None 470 | 471 | _, _, embedv, scores = self.proximity_attn(key, query, value, k_extra, q_extra, v_extra, step=step) 472 | 473 | embedv = embedv.reshape(N, H, W, -1, embedv.shape[-1]) 474 | scores = scores.reshape(N, H, W, -1, 1) 475 | 476 | if cur_points_influ_score is not None: 477 | scores = scores * cur_points_influ_score 478 | if self.bkg_feats is not None: 479 | bkg_seq_len = self.bkg_feats.shape[0] 480 | scores = torch.cat([scores, self.bkg_score.expand(N, H, W, bkg_seq_len, -1)], dim=-2) # append the constant background score so the softmax can assign weight to the background 481 | attn = F.softmax(scores, dim=3) # (N, H, W, num_pts+bkg_seq_len, 1) 482 | topk_attn = attn[..., :num_pts, :] 483 | if self.args.models.normalize_topk_attn: 484 | topk_attn = topk_attn / torch.sum(topk_attn, dim=3, keepdim=True) 485 | fused_features = torch.sum(embedv * topk_attn, dim=3, keepdim=True) # (N, H, W, 1, C) 486 | else: 487 | attn = F.softmax(scores, dim=3) 488 | if self.args.models.normalize_topk_attn: 489 | attn = attn / torch.sum(attn, dim=3, keepdim=True) 490 | fused_features = torch.sum(embedv * attn, dim=3, keepdim=True) # (N, H, W, 1, C) 491 | 492 | return fused_features, attn 493 | 494 | def forward(self, rays_o, rays_d, c2w, step=-1, shading_code=None): 495 | gamma, beta = None, None 496 | if shading_code is not None and self.mapping_mlp is not None: 497 | affine = self.mapping_mlp(shading_code) 498 | affine_dim = affine.shape[-1] 499 | gamma, beta = affine[:affine_dim//2], affine[affine_dim//2:] # split the mapping MLP output into gamma and beta for the renderer 500 | 501 | if step % 200 == 0: 502 | print(shading_code.min().item(), shading_code.max().item(), gamma.min().item(), gamma.max().item(), beta.min().item(), beta.max().item()) 503 | 504 | points, select_k_ind = self._get_points(rays_o, rays_d, c2w, step) 505 | key, query, value, k_extra, q_extra, v_extra = self._get_kqv(rays_o, rays_d, points, c2w, select_k_ind, step) 506 | N, H, W, _ = rays_d.shape 507 | num_pts = points.shape[-2] 508 | 509 | cur_points_influ_scores = self.points_influ_scores[select_k_ind] if self.points_influ_scores is not None else None 510 | 511 | _, _, embedv, scores = self.proximity_attn(key, query, value, k_extra, q_extra, v_extra, step=step) 512 | 513 | if step >= 0 and step % 200 == 0: 514 | print(' embedv:', step, embedv.shape, embedv.min().item(), 515 | embedv.max().item(), embedv.mean().item(), embedv.std().item()) 516 | print(' scores:', step, scores.shape, scores.min().item(), 517 | scores.max().item(), scores.mean().item(), scores.std().item()) 518 | 519 | embedv = embedv.reshape(N, H, W, -1, embedv.shape[-1]) 520 | scores = scores.reshape(N, H, W, -1, 1) 521 | 522 | if cur_points_influ_scores is not None: 523 | # Multiply the attention scores by the per-point influence scores 524 | scores = scores * cur_points_influ_scores 525 | 526 | if self.bkg_feats is not None: 527 |
bkg_seq_len = self.bkg_feats.shape[0] 528 | scores = torch.cat([scores, self.bkg_score.expand(N, H, W, bkg_seq_len, -1)], dim=-2) 529 | attn = F.softmax(scores, dim=3) # (N, H, W, num_pts+bkg_seq_len, 1) 530 | topk_attn = attn[..., :num_pts, :] 531 | bkg_attn = attn[..., num_pts:, :] 532 | if self.args.models.normalize_topk_attn: 533 | topk_attn = topk_attn / torch.sum(topk_attn, dim=3, keepdim=True) 534 | fused_features = torch.sum(embedv * topk_attn, dim=3) # (N, H, W, C) 535 | 536 | if self.args.models.use_renderer: 537 | foreground = self.renderer(fused_features.permute(0, 3, 1, 2), gamma=gamma, beta=beta).permute(0, 2, 3, 1).unsqueeze(-2) # (N, H, W, 1, 3) 538 | else: 539 | foreground = fused_features.unsqueeze(-2) 540 | 541 | if self.args.models.normalize_topk_attn: 542 | rgb = foreground * (1 - bkg_attn) + self.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn 543 | else: 544 | rgb = foreground + self.bkg_feats.expand(N, H, W, -1, -1) * bkg_attn 545 | rgb = rgb.squeeze(-2) 546 | else: 547 | attn = F.softmax(scores, dim=3) 548 | fused_features = torch.sum(embedv * attn, dim=3) # (N, H, W, C) 549 | if self.args.models.use_renderer: 550 | rgb = self.renderer(fused_features.permute(0, 3, 1, 2), gamma=gamma, beta=beta).permute(0, 2, 3, 1) # (N, H, W, 3) 551 | else: 552 | rgb = fused_features 553 | 554 | if step >= 0 and step % 1000 == 0: 555 | print(' feat map:', step, fused_features.shape, fused_features.min().item(), 556 | fused_features.max().item(), fused_features.mean().item(), fused_features.std().item()) 557 | print(' predict rgb:', step, rgb.shape, rgb.min().item(), 558 | rgb.max().item(), rgb.mean().item(), rgb.std().item()) 559 | 560 | return rgb 561 | 562 | def save(self, step, save_dir): 563 | torch.save({str(step): self.state_dict()}, 564 | os.path.join(save_dir, 'model.pth')) 565 | 566 | optimizers_state_dict = {} 567 | for name, optimizer in self.optimizers.items(): 568 | if optimizer is not None: 569 | optimizers_state_dict[name] = optimizer.state_dict() 570 | else: 571 | optimizers_state_dict[name] = None 572 | torch.save(optimizers_state_dict, os.path.join( 573 | save_dir, 'optimizers.pth')) 574 | 575 | schedulers_state_dict = {} 576 | for name, scheduler in self.schedulers.items(): 577 | if scheduler is not None: 578 | schedulers_state_dict[name] = scheduler.state_dict() 579 | else: 580 | schedulers_state_dict[name] = None 581 | torch.save(schedulers_state_dict, os.path.join( 582 | save_dir, 'schedulers.pth')) 583 | 584 | scaler_state_dict = self.scaler.state_dict() 585 | torch.save(scaler_state_dict, os.path.join( 586 | save_dir, 'scaler.pth')) 587 | 588 | def load(self, load_dir, load_optimizer=False): 589 | if load_optimizer == True: 590 | optimizers_state_dict = torch.load( 591 | os.path.join(load_dir, 'optimizers.pth')) 592 | for name, optimizer in self.optimizers.items(): 593 | if optimizer is not None: 594 | optimizer.load_state_dict(optimizers_state_dict[name]) 595 | else: 596 | assert optimizers_state_dict[name] is None 597 | 598 | schedulers_state_dict = torch.load( 599 | os.path.join(load_dir, 'schedulers.pth')) 600 | for name, scheduler in self.schedulers.items(): 601 | if scheduler is not None: 602 | scheduler.load_state_dict(schedulers_state_dict[name]) 603 | else: 604 | assert schedulers_state_dict[name] is None 605 | 606 | if os.path.exists(os.path.join(load_dir, 'scaler.pth')): 607 | scaler_state_dict = torch.load( 608 | os.path.join(load_dir, 'scaler.pth')) 609 | self.scaler.load_state_dict(scaler_state_dict) 610 | 611 | model_state_dict = 
torch.load(os.path.join(load_dir, 'model.pth')) 612 | for step, state_dict in model_state_dict.items(): # the checkpoint is saved as a single {str(step): state_dict} entry 613 | # self.load_state_dict(state_dict) 614 | self.load_my_state_dict(state_dict) 615 | return int(step) 616 | 617 | def load_my_state_dict(self, state_dict, exclude_keys=[]): 618 | own_state = self.state_dict() 619 | for name, param in state_dict.items(): 620 | print(name, param.shape) 621 | for exclude_key in exclude_keys: 622 | if exclude_key in name: 623 | print("exclude", name) 624 | break 625 | else: # no exclude key matched this parameter 626 | if name not in ['points', 'points_influ_scores', 'pc_feats']: 627 | if isinstance(param, nn.Parameter): 628 | # backwards compatibility for serialized parameters 629 | param = param.data 630 | try: 631 | own_state[name].copy_(param) 632 | except Exception: # e.g. shape mismatch between the checkpoint and the current model 633 | print("Can't load", name) 634 | 635 | self.points = nn.Parameter( 636 | state_dict['points'].data, requires_grad=self.points.requires_grad) 637 | if self.points_influ_scores is not None: 638 | self.points_influ_scores = nn.Parameter( 639 | state_dict['points_influ_scores'].data, requires_grad=self.points_influ_scores.requires_grad) 640 | self.pc_feats = nn.Parameter(state_dict['pc_feats'].data, requires_grad=self.pc_feats.requires_grad) 641 | print("load pc_feats", self.pc_feats.shape, self.pc_feats.min(), self.pc_feats.max()) 642 | --------------------------------------------------------------------------------
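To make the point selection above concrete: `_calculate_global_distances` takes the vector from each ray origin to each candidate point, removes its projection onto the ray direction, and uses the norm of the remaining perpendicular component as the point-to-ray distance, keeping the `select_k` closest points per ray. The short sketch below reproduces just that computation on toy tensors; the shapes and the `eps` and `select_k` values are made-up examples, and the snippet is not part of the repository code.

import torch

# Toy setup: 1 camera, a 4x4 grid of rays, 100 candidate points, keep the 8 nearest per ray.
N, H, W, num_pts, select_k = 1, 4, 4, 100, 8
eps = 1e-8
rays_o = torch.randn(N, 3)        # one ray origin per camera
rays_d = torch.randn(N, H, W, 3)  # per-pixel ray directions
points = torch.randn(num_pts, 3)  # candidate point cloud

d = rays_d.unsqueeze(-2)                 # (N, H, W, 1, 3)
o = rays_o.reshape(N, 1, 1, 1, 3)
p = points.reshape(1, 1, 1, num_pts, 3)

v = p - o                                # origin-to-point vectors, broadcast to (N, 1, 1, num_pts, 3)
# Projection of v onto the ray direction: d * <v, d> / (<d, d> + eps)
proj = d * (torch.sum(v * d, dim=-1) / (torch.sum(d * d, dim=-1) + eps)).unsqueeze(-1)
perp = v - proj                          # component of v perpendicular to the ray
dist = torch.norm(perp, dim=-1)          # (N, H, W, num_pts) point-to-ray distances

# Indices of the k closest points per ray, then gather their coordinates.
_, select_k_ind = dist.topk(select_k, dim=-1, largest=False, sorted=False)
selected_points = points[select_k_ind, :]   # (N, H, W, select_k, 3)
print(select_k_ind.shape, selected_points.shape)

In the model itself, the same projection and perpendicular vectors are recomputed per selected point in `_calculate_distances` and fed into `_get_kqv` as the geometric inputs to the proximity attention layer.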