├── CONTRIBUTING.md ├── images └── tensorboard.png ├── ibrnet ├── data_loaders │ ├── __init__.py │ ├── flow_utils.py │ ├── create_training_dataset.py │ ├── data_utils.py │ ├── llff_data_utils.py │ └── monocular.py ├── criterion.py ├── projection.py ├── feature_network.py ├── sample_ray.py ├── render_image.py ├── model.py └── mlp_network.py ├── configs_nvidia ├── eval_truck_long.txt ├── eval_jumping_long.txt ├── eval_skating_long.txt ├── eval_balloon1_long.txt ├── eval_balloon2_long.txt ├── eval_umbrella_long.txt ├── eval_dynamicFace_long.txt └── eval_playground_long.txt ├── environment_dynibar.yml ├── configs ├── test_kid-running.txt └── train_kid-running.txt ├── utils.py ├── save_monocular_cameras.py ├── README.md ├── config.py ├── render_source_vv.py ├── LICENSE ├── render_monocular_bt.py └── eval_nvidia.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /images/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/dynibar/HEAD/images/tensorboard.png -------------------------------------------------------------------------------- /ibrnet/data_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | """Defining a dictionary of dataset class.""" 2 | 3 | from .monocular import MonocularDataset 4 | 5 | dataset_dict = { 6 | 'monocular': MonocularDataset, 7 | } 8 | -------------------------------------------------------------------------------- /configs_nvidia/eval_truck_long.txt: -------------------------------------------------------------------------------- 1 | expname = truck 2 | 3 | rootdir = /home/zhengqili/dynibar 4 | 5 | folder_path = /home/zhengqili/nvidia_long_release 6 | 7 | coarse_dir = checkpoints/coarse/truck 8 | 9 | distributed = False 10 | 11 | ## dataset 12 | eval_dataset = Nvidia 13 | eval_scenes = Truck 14 | ### TESTING 15 | chunk_size = 8192 16 | 17 | ### RENDERING 18 | N_importance = 64 19 | N_samples = 64 20 | inv_uniform = True 21 | anti_alias_pooling = 1 22 | mask_rgb = 0 23 | 24 | input_dir = True 25 | input_xyz = False 26 | 27 | mask_static = True -------------------------------------------------------------------------------- /configs_nvidia/eval_jumping_long.txt: -------------------------------------------------------------------------------- 1 | expname = jumping 2 | 3 | rootdir = /home/zhengqili/dynibar 4 | 5 | folder_path = /home/zhengqili/nvidia_long_release 6 | 7 | coarse_dir = checkpoints/coarse/jumping 8 | 9 | distributed = False 10 | 11 | ## dataset 12 | eval_dataset = Nvidia 13 | eval_scenes = Jumping 14 | ### TESTING 15 | chunk_size = 8192 16 | 17 | ### RENDERING 18 | N_importance = 64 19 | N_samples = 64 20 | inv_uniform = True 21 | anti_alias_pooling = 1 22 | mask_rgb = 0 23 | 24 | input_dir = True 25 | input_xyz = False 26 | 27 | mask_static = True -------------------------------------------------------------------------------- /configs_nvidia/eval_skating_long.txt: -------------------------------------------------------------------------------- 1 | expname = skating 2 | 3 | rootdir = /home/zhengqili/dynibar 4 | 5 | folder_path = /home/zhengqili/nvidia_long_release 6 | 7 | coarse_dir = checkpoints/coarse/skating 8 | 9 | distributed = False 10 | 11 | ## dataset 12 | eval_dataset = Nvidia 13 | eval_scenes = Skating 14 | ### TESTING 15 | chunk_size = 8192 16 | 17 | ### RENDERING 18 | 
N_importance = 64 19 | N_samples = 64 20 | inv_uniform = True 21 | anti_alias_pooling = 1 22 | mask_rgb = 0 23 | 24 | input_dir = True 25 | input_xyz = False 26 | 27 | mask_static = True -------------------------------------------------------------------------------- /configs_nvidia/eval_balloon1_long.txt: -------------------------------------------------------------------------------- 1 | expname = balloon1 2 | 3 | rootdir = /home/zhengqili/dynibar 4 | 5 | folder_path = /home/zhengqili/nvidia_long_release 6 | 7 | coarse_dir = checkpoints/coarse/balloon1 8 | 9 | distributed = False 10 | 11 | ## dataset 12 | eval_dataset = Nvidia 13 | eval_scenes = Balloon1 14 | ### TESTING 15 | chunk_size = 8192 16 | 17 | ### RENDERING 18 | N_importance = 64 19 | N_samples = 64 20 | inv_uniform = True 21 | anti_alias_pooling = 1 22 | mask_rgb = 0 23 | 24 | input_dir = True 25 | input_xyz = False 26 | 27 | mask_static = True -------------------------------------------------------------------------------- /configs_nvidia/eval_balloon2_long.txt: -------------------------------------------------------------------------------- 1 | expname = balloon2 2 | 3 | rootdir = /home/zhengqili/dynibar 4 | 5 | folder_path = /home/zhengqili/nvidia_long_release 6 | 7 | coarse_dir = checkpoints/coarse/balloon2 8 | 9 | distributed = False 10 | 11 | ## dataset 12 | eval_dataset = Nvidia 13 | eval_scenes = Balloon2 14 | ### TESTING 15 | chunk_size = 8192 16 | 17 | ### RENDERING 18 | N_importance = 64 19 | N_samples = 64 20 | inv_uniform = True 21 | anti_alias_pooling = 1 22 | mask_rgb = 0 23 | 24 | input_dir = True 25 | input_xyz = False 26 | 27 | mask_static = True -------------------------------------------------------------------------------- /configs_nvidia/eval_umbrella_long.txt: -------------------------------------------------------------------------------- 1 | expname = umbrella 2 | 3 | rootdir = /home/zhengqili/dynibar 4 | 5 | folder_path = /home/zhengqili/nvidia_long_release 6 | 7 | coarse_dir = checkpoints/coarse/umbrella 8 | 9 | distributed = False 10 | 11 | ## dataset 12 | eval_dataset = Nvidia 13 | eval_scenes = Umbrella 14 | ### TESTING 15 | chunk_size = 8192 16 | 17 | ### RENDERING 18 | N_importance = 64 19 | N_samples = 64 20 | inv_uniform = True 21 | anti_alias_pooling = 1 22 | mask_rgb = 0 23 | 24 | input_dir = True 25 | input_xyz = False 26 | 27 | mask_static = True -------------------------------------------------------------------------------- /configs_nvidia/eval_dynamicFace_long.txt: -------------------------------------------------------------------------------- 1 | expname = dynamicFace 2 | 3 | rootdir = /home/zhengqili/dynibar 4 | 5 | folder_path = /home/zhengqili/nvidia_long_release 6 | 7 | coarse_dir = checkpoints/coarse/dynamicFace 8 | 9 | distributed = False 10 | 11 | ## dataset 12 | eval_dataset = Nvidia 13 | eval_scenes = dynamicFace 14 | ### TESTING 15 | chunk_size = 8192 16 | 17 | ### RENDERING 18 | N_importance = 64 19 | N_samples = 64 20 | inv_uniform = True 21 | anti_alias_pooling = 1 22 | mask_rgb = 0 23 | 24 | input_dir = True 25 | input_xyz = False 26 | 27 | mask_static = True -------------------------------------------------------------------------------- /configs_nvidia/eval_playground_long.txt: -------------------------------------------------------------------------------- 1 | expname = playground 2 | 3 | rootdir = /home/zhengqili/dynibar 4 | 5 | folder_path = /home/zhengqili/nvidia_long_release 6 | 7 | coarse_dir = checkpoints/coarse/playground 8 | 9 | distributed = False 10 | 11 | 
## dataset 12 | eval_dataset = Nvidia 13 | eval_scenes = Playground 14 | ### TESTING 15 | chunk_size = 8192 16 | 17 | ### RENDERING 18 | N_importance = 64 19 | N_samples = 64 20 | inv_uniform = True 21 | anti_alias_pooling = 1 22 | mask_rgb = 0 23 | 24 | input_dir = True 25 | input_xyz = False 26 | 27 | mask_static = True -------------------------------------------------------------------------------- /environment_dynibar.yml: -------------------------------------------------------------------------------- 1 | name: dynibar 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - python=3.8 8 | - pip=20.3 9 | - conda-forge::cudatoolkit=11.3 10 | - pytorch::pytorch=1.10.1 11 | - pytorch::torchvision==0.11.2 12 | - pip: 13 | - configargparse 14 | - scikit-image==0.19.3 15 | - matplotlib 16 | - opencv-python 17 | - torch_efficient_distloss 18 | - imageio==2.22.0 19 | - tensorboard==2.10.0 20 | - scipy==1.9.1 21 | - timm==0.6.7 22 | - kornia==0.6.7 23 | - ninja==1.11.1 24 | - setuptools==59.5.0 25 | -------------------------------------------------------------------------------- /configs/test_kid-running.txt: -------------------------------------------------------------------------------- 1 | # make sure expname is the saved folder name in 'out' directory 2 | expname = kid-running-test_mr-42_w-disp-0.100_w-flow-0.010_anneal_cycle-0.1-0.1-w_mode-0 3 | 4 | rootdir = /home/zhengqili/dynibar 5 | 6 | folder_path = /home/zhengqili/release 7 | 8 | distributed = False 9 | 10 | ## dataset 11 | eval_dataset = dynamic-test 12 | eval_scenes = kid-running 13 | ### TESTING 14 | chunk_size = 8192 15 | 16 | ### RENDERING 17 | N_importance = 64 18 | N_samples = 64 19 | inv_uniform = True 20 | white_bkgd = False 21 | 22 | anti_alias_pooling = 0 23 | mask_rgb = 1 24 | input_dir = True 25 | input_xyz = False 26 | 27 | training_height = 288 28 | 29 | max_range = 40 30 | num_source_views = 7 31 | 32 | render_idx = 30 33 | 34 | mask_src_view = True 35 | num_vv = 3 36 | -------------------------------------------------------------------------------- /configs/train_kid-running.txt: -------------------------------------------------------------------------------- 1 | expname = kid-running-test 2 | 3 | rootdir = /home/zhengqili/dynibar 4 | 5 | folder_path = /home/zhengqili/release 6 | 7 | no_reload = False 8 | render_stride = 1 9 | distributed = False 10 | no_load_opt = True 11 | no_load_scheduler = True 12 | n_iters = 400000 13 | 14 | ## dataset 15 | train_dataset = monocular 16 | train_scenes = kid-running 17 | eval_dataset = monocular 18 | eval_scenes = kid-running 19 | 20 | ### TRAINING 21 | N_rand = 3072 22 | lrate_feature = 8e-4 23 | lrate_mlp = 4e-4 24 | lrate_decay_factor = 0.5 25 | init_decay_epoch = 400 # modify this s.t. 
num_imgs * num_epoch ~= 30-40K 26 | 27 | ### TESTING 28 | chunk_size = 8192 29 | 30 | ### RENDERING 31 | N_importance = 0 32 | N_samples = 64 33 | inv_uniform = True 34 | white_bkgd = False 35 | 36 | ### CONSOLE AND TENSORBOARD 37 | i_img = 5000 38 | i_print = 5000 39 | i_weights = 10000 40 | 41 | anti_alias_pooling = 0 42 | mask_rgb = 1 43 | input_dir = True 44 | input_xyz = False 45 | 46 | training_height = 288 47 | 48 | w_cycle = 0.1 49 | cycle_factor = 0.1 50 | 51 | w_disp = 1e-1 52 | w_flow = 1e-2 53 | w_distortion = 1e-3 54 | w_reg = 0.05 55 | 56 | w_skew_entropy = 5e-4 57 | lr_multipler = 1.0 58 | 59 | decay_rate = 10 60 | anneal_cycle = True 61 | 62 | erosion_radius = 3 63 | occ_weights_mode = 0 64 | 65 | max_range = 42 66 | num_source_views = 7 67 | 68 | num_vv = 3 69 | mask_src_view = True 70 | -------------------------------------------------------------------------------- /ibrnet/criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch.nn as nn 16 | import torch 17 | from utils import img2charbonier 18 | 19 | EPSILON = 0.001 20 | 21 | class Criterion(nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | 25 | def forward(self, outputs, ray_batch, motion_mask=None): 26 | ''' 27 | training criterion 28 | ''' 29 | pred_rgb = outputs['rgb'] 30 | pred_mask = outputs['mask'].float() 31 | gt_rgb = ray_batch['rgb'] 32 | 33 | if motion_mask is not None: 34 | pred_mask = pred_mask * motion_mask.float() 35 | 36 | loss = img2charbonier(pred_rgb, gt_rgb, pred_mask, EPSILON) 37 | 38 | return loss 39 | 40 | 41 | 42 | def compute_temporal_rgb_loss(outputs, ray_batch, motion_mask=None): 43 | pred_rgb = outputs['rgb'] 44 | gt_rgb = ray_batch['rgb'] 45 | 46 | occ_weight_map = outputs['occ_weight_map'] 47 | pred_mask = outputs['mask'].float() 48 | 49 | if motion_mask is not None: 50 | pred_mask = pred_mask * motion_mask 51 | 52 | final_w = pred_mask * occ_weight_map 53 | final_w = final_w.unsqueeze(-1).repeat(1, 3) 54 | 55 | loss = torch.sum(final_w * torch.sqrt((pred_rgb - gt_rgb)**2 + EPSILON**2) ) / (torch.sum(final_w) + 1e-8) 56 | return loss 57 | 58 | def compute_rgb_loss(pred_rgb, ray_batch, pred_mask): 59 | gt_rgb = ray_batch['rgb'] 60 | loss = img2charbonier(pred_rgb, gt_rgb, pred_mask, EPSILON) 61 | 62 | return loss 63 | 64 | # def compute_mask_ssi_depth_loss(pred_depth, gt_depth, mask): 65 | # t_pred = torch.median(pred_depth) 66 | # s_pred = torch.mean(torch.abs(pred_depth - t_pred)) 67 | 68 | # t_gt = torch.median(gt_depth) 69 | # s_gt = torch.mean(torch.abs(gt_depth - t_gt)) 70 | 71 | # pred_depth_n = (pred_depth - t_pred) / s_pred 72 | # gt_depth_n = (gt_depth - t_gt) / s_gt 73 | 74 | # num_pixel = torch.sum(mask) + 1e-8 75 | 76 | # return torch.sum(torch.abs(pred_depth_n - gt_depth_n) * mask)/num_pixel 77 | 78 | 79 | def compute_entropy(x): 80 | return -torch.mean(x * torch.log(x + 1e-8)) 81 | 82 
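# compute_flow_loss (below) is a masked mean absolute error between the flow
# rendered by the model and the precomputed optical flow: gt_mask is repeated
# across the two flow channels (u, v), so only pixels with a valid ground-truth
# flow contribute, and the sum is normalized by the number of valid entries
# (plus 1e-8 to avoid division by zero).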
| 83 | def compute_flow_loss(render_flow, gt_flow, gt_mask): 84 | gt_mask_rep = gt_mask.repeat(1, 1, 2) 85 | return torch.sum(torch.abs(render_flow - gt_flow) * gt_mask_rep) / (torch.sum(gt_mask_rep) + 1e-8) 86 | -------------------------------------------------------------------------------- /ibrnet/data_loaders/flow_utils.py: -------------------------------------------------------------------------------- 1 | """Optical flow helper functions.""" 2 | 3 | import numpy as np 4 | import cv2 5 | 6 | def warp_flow(img, flow): 7 | h, w = flow.shape[:2] 8 | flow_new = flow.copy() 9 | flow_new[:,:,0] += np.arange(w) 10 | flow_new[:,:,1] += np.arange(h)[:,np.newaxis] 11 | 12 | res = cv2.remap(img, flow_new, None, 13 | cv2.INTER_LINEAR, 14 | borderMode=cv2.BORDER_CONSTANT) 15 | return res 16 | 17 | def make_color_wheel(): 18 | """ 19 | Generate color wheel according Middlebury color code 20 | :return: Color wheel 21 | """ 22 | RY = 15 23 | YG = 6 24 | GC = 4 25 | CB = 11 26 | BM = 13 27 | MR = 6 28 | 29 | ncols = RY + YG + GC + CB + BM + MR 30 | 31 | colorwheel = np.zeros([ncols, 3]) 32 | 33 | col = 0 34 | 35 | # RY 36 | colorwheel[0:RY, 0] = 255 37 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY)) 38 | col += RY 39 | 40 | # YG 41 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG)) 42 | colorwheel[col:col+YG, 1] = 255 43 | col += YG 44 | 45 | # GC 46 | colorwheel[col:col+GC, 1] = 255 47 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC)) 48 | col += GC 49 | 50 | # CB 51 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB)) 52 | colorwheel[col:col+CB, 2] = 255 53 | col += CB 54 | 55 | # BM 56 | colorwheel[col:col+BM, 2] = 255 57 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM)) 58 | col += + BM 59 | 60 | # MR 61 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 62 | colorwheel[col:col+MR, 0] = 255 63 | 64 | return colorwheel 65 | 66 | 67 | def compute_color(u, v): 68 | """ 69 | compute optical flow color map 70 | :param u: optical flow horizontal map 71 | :param v: optical flow vertical map 72 | :return: optical flow in color code 73 | """ 74 | [h, w] = u.shape 75 | img = np.zeros([h, w, 3]) 76 | nanIdx = np.isnan(u) | np.isnan(v) 77 | u[nanIdx] = 0 78 | v[nanIdx] = 0 79 | 80 | colorwheel = make_color_wheel() 81 | ncols = np.size(colorwheel, 0) 82 | 83 | rad = np.sqrt(u**2+v**2) 84 | 85 | a = np.arctan2(-v, -u) / np.pi 86 | 87 | fk = (a+1) / 2 * (ncols - 1) + 1 88 | 89 | k0 = np.floor(fk).astype(int) 90 | 91 | k1 = k0 + 1 92 | k1[k1 == ncols+1] = 1 93 | f = fk - k0 94 | 95 | for i in range(0, np.size(colorwheel,1)): 96 | tmp = colorwheel[:, i] 97 | col0 = tmp[k0-1] / 255 98 | col1 = tmp[k1-1] / 255 99 | col = (1-f) * col0 + f * col1 100 | 101 | idx = rad <= 1 102 | col[idx] = 1-rad[idx]*(1-col[idx]) 103 | notidx = np.logical_not(idx) 104 | 105 | col[notidx] *= 0.75 106 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) 107 | 108 | return img 109 | 110 | 111 | 112 | def flow_to_image(flow, display=False): 113 | """ 114 | Convert flow into middlebury color code image 115 | :param flow: optical flow map 116 | :return: optical flow image in middlebury color 117 | """ 118 | UNKNOWN_FLOW_THRESH = 200 119 | u = flow[:, :, 0] 120 | v = flow[:, :, 1] 121 | 122 | maxu = -999. 123 | maxv = -999. 124 | minu = 999. 125 | minv = 999. 
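# Flow components with magnitude above UNKNOWN_FLOW_THRESH are treated as
# unknown below: they are zeroed before the flow is normalized by its maximum
# radius, and the corresponding pixels are set to black in the output image.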
126 | 127 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) 128 | u[idxUnknow] = 0 129 | v[idxUnknow] = 0 130 | 131 | maxu = max(maxu, np.max(u)) 132 | minu = min(minu, np.min(u)) 133 | 134 | maxv = max(maxv, np.max(v)) 135 | minv = min(minv, np.min(v)) 136 | 137 | # sqrt_rad = u**2 + v**2 138 | rad = np.sqrt(u**2 + v**2) 139 | 140 | maxrad = max(-1, np.max(rad)) 141 | 142 | if display: 143 | print("max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. %.3f" % (maxrad, minu,maxu, minv, maxv)) 144 | 145 | u = u/(maxrad + np.finfo(float).eps) 146 | v = v/(maxrad + np.finfo(float).eps) 147 | 148 | img = compute_color(u, v) 149 | 150 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) 151 | img[idx] = 0 152 | 153 | return np.uint8(img) -------------------------------------------------------------------------------- /ibrnet/data_loaders/create_training_dataset.py: -------------------------------------------------------------------------------- 1 | """Class definition of data sampler.""" 2 | 3 | from operator import itemgetter 4 | from typing import Optional 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | from torch.utils.data import DistributedSampler 10 | from torch.utils.data import Sampler 11 | from torch.utils.data import WeightedRandomSampler 12 | 13 | from . import dataset_dict 14 | 15 | 16 | class DatasetFromSampler(Dataset): 17 | """Dataset to create indexes from `Sampler`.""" 18 | 19 | def __init__(self, sampler: Sampler): 20 | """Initialisation for DatasetFromSampler.""" 21 | self.sampler = sampler 22 | self.sampler_list = None 23 | 24 | def __getitem__(self, index: int): 25 | """Gets element of the dataset. 26 | 27 | Args: 28 | index: index of the element in the dataset 29 | 30 | Returns: 31 | Single element by index 32 | """ 33 | if self.sampler_list is None: 34 | self.sampler_list = list(self.sampler) 35 | return self.sampler_list[index] 36 | 37 | def __len__(self) -> int: 38 | return len(self.sampler) 39 | 40 | 41 | class DistributedSamplerWrapper(DistributedSampler): 42 | """Wrapper over `Sampler` for distributed training. 43 | 44 | Allows you to use any sampler in distributed mode. It is especially useful in 45 | conjunction with `torch.nn.parallel.DistributedDataParallel`. In such case, 46 | each process can pass a DistributedSamplerWrapper instance as a DataLoader 47 | sampler, and load a subset of subsampled data of the original dataset that is 48 | exclusive to it. .. note:: 49 | 50 | Sampler is assumed to be of constant size. 51 | """ 52 | 53 | def __init__( 54 | self, 55 | sampler, 56 | num_replicas: Optional[int] = None, 57 | rank: Optional[int] = None, 58 | shuffle: bool = True, 59 | ): 60 | super(DistributedSamplerWrapper, self).__init__( 61 | DatasetFromSampler(sampler), 62 | num_replicas=num_replicas, 63 | rank=rank, 64 | shuffle=shuffle, 65 | ) 66 | self.sampler = sampler 67 | 68 | def __iter__(self): 69 | self.dataset = DatasetFromSampler(self.sampler) 70 | indexes_of_indexes = super().__iter__() 71 | subsampler_indexes = self.dataset 72 | return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) 73 | 74 | 75 | def create_training_dataset(args): 76 | """Creating training dataset. 
77 | 78 | Args: 79 | args: input argument 80 | 81 | Returns: 82 | train_dataset: training dataset 83 | train_sampler: training sampler 84 | """ 85 | # parse args.train_dataset, "+" indicates that multiple datasets are used, 86 | # for example "ibrnet_collect+llff+spaces" 87 | # otherwise only one dataset is used 88 | 89 | print('training dataset: {}'.format(args.train_dataset)) 90 | 91 | mode = 'train' 92 | if '+' not in args.train_dataset: 93 | train_dataset = dataset_dict[args.train_dataset]( 94 | args, mode, scenes=args.train_scenes 95 | ) 96 | train_sampler = ( 97 | torch.utils.data.distributed.DistributedSampler(train_dataset) 98 | if args.distributed 99 | else None 100 | ) 101 | else: 102 | train_dataset_names = args.train_dataset.split('+') 103 | weights = args.dataset_weights 104 | assert len(train_dataset_names) == len(weights) 105 | assert np.abs(np.sum(weights) - 1.0) < 1e-6 106 | print('weights:{}'.format(weights)) 107 | train_datasets = [] 108 | train_weights_samples = [] 109 | for training_dataset_name, weight in zip(train_dataset_names, weights): 110 | train_dataset = dataset_dict[training_dataset_name]( 111 | args, 112 | mode, 113 | scenes=args.train_scenes, 114 | ) 115 | train_datasets.append(train_dataset) 116 | num_samples = len(train_dataset) 117 | weight_each_sample = weight / num_samples 118 | train_weights_samples.extend([weight_each_sample] * num_samples) 119 | 120 | train_dataset = torch.utils.data.ConcatDataset(train_datasets) 121 | train_weights = torch.from_numpy(np.array(train_weights_samples)) 122 | sampler = WeightedRandomSampler(train_weights, len(train_weights)) 123 | train_sampler = ( 124 | DistributedSamplerWrapper(sampler) if args.distributed else sampler 125 | ) 126 | 127 | return train_dataset, train_sampler 128 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | 3 | import cv2 4 | import matplotlib as mpl 5 | from matplotlib import cm 6 | from matplotlib.backends.backend_agg import FigureCanvasAgg 7 | from matplotlib.figure import Figure 8 | import numpy as np 9 | import torch 10 | 11 | HUGE_NUMBER = 1e10 12 | TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision 13 | 14 | img_HWC2CHW = lambda x: x.permute(2, 0, 1) 15 | gray2rgb = lambda x: x.unsqueeze(2).repeat(1, 1, 3) 16 | 17 | 18 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 19 | mse2psnr = lambda x: -10.0 * np.log(x + TINY_NUMBER) / np.log(10.0) 20 | 21 | 22 | def img2mse(x, y, mask=None): 23 | """MSE between two images.""" 24 | if mask is None: 25 | return torch.mean((x - y) * (x - y)) 26 | else: 27 | return torch.sum((x - y) * (x - y) * mask.unsqueeze(-1)) / ( 28 | torch.sum(mask) * x.shape[-1] + TINY_NUMBER 29 | ) 30 | 31 | 32 | def img2charbonier(x, y, mask=None, eps=0.001): 33 | """Charbonier loss between two images.""" 34 | if mask is None: 35 | return torch.mean(torch.sqrt((x - y) ** 2 + eps**2)) 36 | else: 37 | return torch.sum( 38 | torch.sqrt((x - y) ** 2 + eps**2) * mask.unsqueeze(-1) 39 | ) / (torch.sum(mask) * x.shape[-1] + TINY_NUMBER) 40 | 41 | 42 | def img2psnr(x, y, mask=None): 43 | return mse2psnr(img2mse(x, y, mask).item()) 44 | 45 | 46 | def cycle(iterable): 47 | while True: 48 | for x in iterable: 49 | yield x 50 | 51 | 52 | def get_vertical_colorbar( 53 | h, vmin, vmax, cmap_name='jet', label=None, cbar_precision=2 54 | ): 55 | """Get colorbar.""" 56 | fig = Figure(figsize=(2, 8), dpi=100) 57 | 
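# The colorbar is drawn off-screen with the Agg canvas, converted to a float
# RGB image via print_to_buffer(), and resized at the end to the requested
# height h.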
fig.subplots_adjust(right=1.5) 58 | canvas = FigureCanvasAgg(fig) 59 | 60 | # Do some plotting. 61 | ax = fig.add_subplot(111) 62 | cmap = cm.get_cmap(cmap_name) 63 | norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) 64 | 65 | tick_cnt = 6 66 | tick_loc = np.linspace(vmin, vmax, tick_cnt) 67 | cb1 = mpl.colorbar.ColorbarBase( 68 | ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation='vertical' 69 | ) 70 | 71 | tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc] 72 | if cbar_precision == 0: 73 | tick_label = [x[:-2] for x in tick_label] 74 | 75 | cb1.set_ticklabels(tick_label) 76 | 77 | cb1.ax.tick_params(labelsize=18, rotation=0) 78 | 79 | if label is not None: 80 | cb1.set_label(label) 81 | 82 | fig.tight_layout() 83 | 84 | canvas.draw() 85 | s, (width, height) = canvas.print_to_buffer() 86 | 87 | im = np.frombuffer(s, np.uint8).reshape((height, width, 4)) 88 | 89 | im = im[:, :, :3].astype(np.float32) / 255.0 90 | if h != im.shape[0]: 91 | w = int(im.shape[1] / im.shape[0] * h) 92 | im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA) 93 | 94 | return im 95 | 96 | 97 | def colorize_np( 98 | x, 99 | cmap_name='jet', 100 | mask=None, 101 | range=None, 102 | append_cbar=False, 103 | cbar_in_image=False, 104 | cbar_precision=2, 105 | ): 106 | """turn a grayscale image into a color image.""" 107 | if range is not None: 108 | vmin, vmax = range 109 | elif mask is not None: 110 | # vmin, vmax = np.percentile(x[mask], (2, 100)) 111 | vmin = np.min(x[mask][np.nonzero(x[mask])]) 112 | vmax = np.max(x[mask]) 113 | # vmin = vmin - np.abs(vmin) * 0.01 114 | x[np.logical_not(mask)] = vmin 115 | # print(vmin, vmax) 116 | else: 117 | vmin, vmax = np.percentile(x, (1, 99)) 118 | vmax += TINY_NUMBER 119 | 120 | x = np.clip(x, vmin, vmax) 121 | x = (x - vmin) / (vmax - vmin) 122 | x = np.clip(x, 0.0, 1.0) 123 | 124 | cmap = cm.get_cmap(cmap_name) 125 | x_new = cmap(x)[:, :, :3] 126 | 127 | if mask is not None: 128 | mask = np.float32(mask[:, :, np.newaxis]) 129 | x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask) 130 | 131 | cbar = get_vertical_colorbar( 132 | h=x.shape[0], 133 | vmin=vmin, 134 | vmax=vmax, 135 | cmap_name=cmap_name, 136 | cbar_precision=cbar_precision, 137 | ) 138 | 139 | if append_cbar: 140 | if cbar_in_image: 141 | x_new[:, -cbar.shape[1] :, :] = cbar 142 | else: 143 | x_new = np.concatenate( 144 | (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1 145 | ) 146 | return x_new 147 | else: 148 | return x_new 149 | 150 | 151 | # tensor 152 | def colorize( 153 | x, 154 | cmap_name='jet', 155 | mask=None, 156 | range=None, 157 | append_cbar=False, 158 | cbar_in_image=False, 159 | ): 160 | """Convert gray scale image such as depth to RGB image.""" 161 | device = x.device 162 | x = x.cpu().numpy() 163 | if mask is not None: 164 | mask = mask.cpu().numpy() > 0.99 165 | kernel = np.ones((3, 3), np.uint8) 166 | mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1).astype(bool) 167 | 168 | x = colorize_np(x, cmap_name, mask, range, append_cbar, cbar_in_image) 169 | x = torch.from_numpy(x).to(device) 170 | return x 171 | -------------------------------------------------------------------------------- /save_monocular_cameras.py: -------------------------------------------------------------------------------- 1 | """Save images, depth, flow and mask data into dynibar input format.""" 2 | 3 | ''' 4 | 15 | 16 | 17 | ''' 18 | 19 | 20 | import argparse 21 | import glob 22 | import os 23 | import cv2 24 | import imageio 25 | import numpy as np 26 | 27 | 28 | SAVE_IMG = 
True 29 | FINAL_H = 288 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--cvd_dir', type=str, help='depth directory') 34 | parser.add_argument('--data_dir', type=str, help='dataset directory') 35 | # parser.add_argument("--scene_name", type=str, 36 | # help='Scene name') # 'kid-running' 37 | args = parser.parse_args() 38 | 39 | pt_out_list = sorted(glob.glob(os.path.join(args.cvd_dir, '*.npz'))) 40 | data_dir = os.path.join(args.data_dir, 'dense') 41 | 42 | try: 43 | original_img_path = os.path.join(data_dir, 'images', '00000.png') 44 | o_img = imageio.imread(original_img_path) 45 | except: 46 | original_img_path = os.path.join(data_dir, 'images', '00000.jpg') 47 | o_img = imageio.imread(original_img_path) 48 | 49 | o_ar = float(o_img.shape[1]) / float(o_img.shape[0]) 50 | 51 | final_w, final_h = int(round(FINAL_H * o_ar)), int(FINAL_H) 52 | 53 | img_dir = os.path.join(data_dir, 'images_%dx%d' % (final_w, final_h)) 54 | os.makedirs(img_dir, exist_ok=True) 55 | print('img_dir ', img_dir) 56 | disp_dir = os.path.join(data_dir, 'disp') 57 | os.makedirs(disp_dir, exist_ok=True) 58 | 59 | Ks = [] 60 | mono_depths = [] 61 | c2w_mats = [] 62 | imgs = [] 63 | bounds_mats = [] 64 | 65 | for i, pt_out_path in enumerate(pt_out_list): 66 | print(i) 67 | out_name = pt_out_path.split('/')[-1] 68 | pt_data = np.load(pt_out_path) 69 | 70 | img = pt_data['img_1'][0].transpose(1, 2, 0) 71 | pred_depth = pt_data['depth'][0, 0, ...] 72 | pred_disp = 1.0 / pred_depth 73 | K = pt_data['K'][0, 0, 0, ...].transpose() 74 | img = pt_data['img_1'][0].transpose(1, 2, 0) 75 | cam_c2w = pt_data['cam_c2w'][0] 76 | 77 | K[0, :] *= final_w / img.shape[1] 78 | K[1, :] *= final_h / img.shape[0] 79 | 80 | print('K ', K, abs(K[0, 0] - K[1, 1]) / (K[1, 1] + K[0, 0])) 81 | assert ( 82 | abs(K[0, 0] - K[1, 1]) / (K[1, 1] + K[0, 0]) < 0.005 83 | ) # we assume fx ~= fy 84 | 85 | original_img_path = os.path.join( 86 | data_dir, 'images', '%05d.png' % int(out_name[5:9]) 87 | ) 88 | o_img = imageio.imread(original_img_path) 89 | print(o_img.shape, final_w, final_h) 90 | img_resized = cv2.resize( 91 | o_img, (final_w, final_h), interpolation=cv2.INTER_AREA 92 | ) 93 | pred_disp_resized = cv2.resize( 94 | pred_disp, (final_w, final_h), interpolation=cv2.INTER_LINEAR 95 | ) 96 | 97 | if SAVE_IMG: 98 | imageio.imwrite(os.path.join(img_dir, '%05d.png' % i), img_resized) 99 | np.save( 100 | os.path.join(disp_dir, '%05d.npy' % i), 101 | pred_disp_resized.astype(np.float32), 102 | ) 103 | 104 | mono_depths.append(pred_depth) 105 | c2w_mats.append(cam_c2w) 106 | imgs.append(img_resized) 107 | 108 | close_depth, inf_depth = np.percentile(pred_depth, 5), np.percentile( 109 | pred_depth, 95 110 | ) 111 | # print(close_depth, inf_depth) 112 | bounds = np.array([close_depth, inf_depth]) 113 | bounds_mats.append(bounds) 114 | 115 | c2w_mats = np.stack(c2w_mats, 0) 116 | bounds_mats = np.stack(bounds_mats, 0) 117 | 118 | h, w, fx, fy = imgs[0].shape[0], imgs[0].shape[1], K[0, 0], K[1, 1] 119 | 120 | print('h, w ', h, w, fx, fy) 121 | print('bounds_mats ', np.min(bounds_mats), np.max(bounds_mats)) 122 | 123 | ff = (fx + fy) / 2.0 124 | # hwf = np.array([h, w, fx, fy]).reshape([1, 4]) 125 | hwf = np.array([h, w, ff]).reshape([3, 1]) 126 | 127 | poses = c2w_mats[:, :3, :4].transpose([1, 2, 0]) 128 | 129 | poses = np.concatenate( 130 | [poses, np.tile(hwf[..., np.newaxis], [1, 1, poses.shape[-1]])], 1 131 | ) 132 | 133 | # must switch to [-y, x, z] from [x, -y, -z], NOT [r, u, -t] 134 | poses = 
np.concatenate( 135 | [ 136 | poses[:, 1:2, :], 137 | poses[:, 0:1, :], 138 | -poses[:, 2:3, :], 139 | poses[:, 3:4, :], 140 | poses[:, 4:5, :], 141 | ], 142 | 1, 143 | ) 144 | 145 | save_arr = [] 146 | for i in range((poses.shape[2])): 147 | save_arr.append(np.concatenate([poses[..., i].ravel(), bounds_mats[i]], 0)) 148 | 149 | np.save(os.path.join(data_dir, 'poses_bounds_cvd.npy'), save_arr) 150 | -------------------------------------------------------------------------------- /ibrnet/data_loaders/data_utils.py: -------------------------------------------------------------------------------- 1 | """utility function definition for data loader.""" 2 | 3 | import math 4 | import numpy as np 5 | 6 | rng = np.random.RandomState(234) 7 | _EPS = np.finfo(float).eps * 4.0 8 | TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision 9 | 10 | 11 | def vector_norm(data, axis=None, out=None): 12 | """Return length, i.e. eucledian norm, of ndarray along axis.""" 13 | data = np.array(data, dtype=np.float64, copy=True) 14 | if out is None: 15 | if data.ndim == 1: 16 | return math.sqrt(np.dot(data, data)) 17 | data *= data 18 | out = np.atleast_1d(np.sum(data, axis=axis)) 19 | np.sqrt(out, out) 20 | return out 21 | else: 22 | data *= data 23 | np.sum(data, axis=axis, out=out) 24 | np.sqrt(out, out) 25 | 26 | 27 | def quaternion_about_axis(angle, axis): 28 | """Return quaternion for rotation about axis.""" 29 | quaternion = np.zeros((4,), dtype=np.float64) 30 | quaternion[:3] = axis[:3] 31 | qlen = vector_norm(quaternion) 32 | if qlen > _EPS: 33 | quaternion *= math.sin(angle / 2.0) / qlen 34 | quaternion[3] = math.cos(angle / 2.0) 35 | return quaternion 36 | 37 | 38 | def quaternion_matrix(quaternion): 39 | """Return homogeneous rotation matrix from quaternion.""" 40 | q = np.array(quaternion[:4], dtype=np.float64, copy=True) 41 | nq = np.dot(q, q) 42 | if nq < _EPS: 43 | return np.identity(4) 44 | q *= math.sqrt(2.0 / nq) 45 | q = np.outer(q, q) 46 | return np.array( 47 | ( 48 | (1.0 - q[1, 1] - q[2, 2], q[0, 1] - q[2, 3], q[0, 2] + q[1, 3], 0.0), 49 | (q[0, 1] + q[2, 3], 1.0 - q[0, 0] - q[2, 2], q[1, 2] - q[0, 3], 0.0), 50 | (q[0, 2] - q[1, 3], q[1, 2] + q[0, 3], 1.0 - q[0, 0] - q[1, 1], 0.0), 51 | (0.0, 0.0, 0.0, 1.0), 52 | ), 53 | dtype=np.float64, 54 | ) 55 | 56 | 57 | def angular_dist_between_2_vectors(vec1, vec2): 58 | vec1_unit = vec1 / (np.linalg.norm(vec1, axis=1, keepdims=True) + TINY_NUMBER) 59 | vec2_unit = vec2 / (np.linalg.norm(vec2, axis=1, keepdims=True) + TINY_NUMBER) 60 | angular_dists = np.arccos( 61 | np.clip(np.sum(vec1_unit * vec2_unit, axis=-1), -1.0, 1.0) 62 | ) 63 | return angular_dists 64 | 65 | 66 | def batched_angular_dist_rot_matrix(r1, r2): 67 | """calculate the angular distance between two rotation matrices (batched).""" 68 | 69 | assert ( 70 | r1.shape[-1] == 3 71 | and r2.shape[-1] == 3 72 | and r1.shape[-2] == 3 73 | and r2.shape[-2] == 3 74 | ) 75 | return np.arccos( 76 | np.clip( 77 | (np.trace(np.matmul(r2.transpose(0, 2, 1), r1), axis1=1, axis2=2) - 1) 78 | / 2.0, 79 | a_min=-1 + TINY_NUMBER, 80 | a_max=1 - TINY_NUMBER, 81 | ) 82 | ) 83 | 84 | 85 | def get_nearest_pose_ids( 86 | tar_pose, 87 | ref_poses, 88 | tar_id=-1, 89 | angular_dist_method='vector', 90 | scene_center=(0, 0, 0), 91 | ): 92 | """Get poses id in nearest neighboorhood manner.""" 93 | num_cams = len(ref_poses) 94 | batched_tar_pose = tar_pose[None, ...].repeat(num_cams, 0) 95 | 96 | if angular_dist_method == 'matrix': 97 | dists = batched_angular_dist_rot_matrix( 98 | batched_tar_pose[:, 
:3, :3], ref_poses[:, :3, :3] 99 | ) 100 | elif angular_dist_method == 'vector': 101 | tar_cam_locs = batched_tar_pose[:, :3, 3] 102 | ref_cam_locs = ref_poses[:, :3, 3] 103 | scene_center = np.array(scene_center)[None, ...] 104 | tar_vectors = tar_cam_locs - scene_center 105 | ref_vectors = ref_cam_locs - scene_center 106 | dists = angular_dist_between_2_vectors(tar_vectors, ref_vectors) 107 | elif angular_dist_method == 'dist': 108 | tar_cam_locs = batched_tar_pose[:, :3, 3] 109 | ref_cam_locs = ref_poses[:, :3, 3] 110 | dists = np.linalg.norm(tar_cam_locs - ref_cam_locs, axis=1) 111 | else: 112 | raise NotImplementedError 113 | 114 | if tar_id >= 0: 115 | assert tar_id < num_cams 116 | dists[tar_id] = 1e3 117 | 118 | sorted_ids = np.argsort(dists) 119 | 120 | return sorted_ids 121 | 122 | 123 | def get_interval_pose_ids( 124 | tar_pose, 125 | ref_poses, 126 | tar_id=-1, 127 | angular_dist_method='dist', 128 | interval=2, 129 | scene_center=(0, 0, 0)): 130 | """Get poses id in nearest neighboorhood manner from every 'interval' frames.""" 131 | 132 | original_indices = np.array(range(0, len(ref_poses))) 133 | 134 | ref_poses = ref_poses[::interval] 135 | subsample_indices = original_indices[::interval] 136 | 137 | num_cams = len(ref_poses) 138 | batched_tar_pose = tar_pose[None, ...].repeat(num_cams, 0) 139 | 140 | if angular_dist_method == 'matrix': 141 | dists = batched_angular_dist_rot_matrix(batched_tar_pose[:, :3, :3], 142 | ref_poses[:, :3, :3]) 143 | elif angular_dist_method == 'vector': 144 | tar_cam_locs = batched_tar_pose[:, :3, 3] 145 | ref_cam_locs = ref_poses[:, :3, 3] 146 | scene_center = np.array(scene_center)[None, ...] 147 | tar_vectors = tar_cam_locs - scene_center 148 | ref_vectors = ref_cam_locs - scene_center 149 | dists = angular_dist_between_2_vectors(tar_vectors, ref_vectors) 150 | elif angular_dist_method == 'dist': 151 | tar_cam_locs = batched_tar_pose[:, :3, 3] 152 | ref_cam_locs = ref_poses[:, :3, 3] 153 | dists = np.linalg.norm(tar_cam_locs - ref_cam_locs, axis=1) 154 | else: 155 | raise NotImplementedError 156 | 157 | if tar_id >= 0: 158 | assert tar_id < num_cams 159 | dists[tar_id] = 1e3 160 | 161 | sorted_ids = np.argsort(dists) 162 | 163 | final_ids = subsample_indices[sorted_ids] 164 | 165 | return final_ids 166 | -------------------------------------------------------------------------------- /ibrnet/projection.py: -------------------------------------------------------------------------------- 1 | """Class definition for perspective projection.""" 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class Projector: 8 | """Class for performing perspective projection.""" 9 | 10 | def __init__(self, device): 11 | self.device = device 12 | 13 | def inbound(self, pixel_locations, h, w): 14 | """Check if the pixel locations are in valid range.""" 15 | return ( 16 | (pixel_locations[..., 0] <= w - 1.0) 17 | & (pixel_locations[..., 0] >= 0) 18 | & (pixel_locations[..., 1] <= h - 1.0) 19 | & (pixel_locations[..., 1] >= 0) 20 | ) 21 | 22 | def normalize(self, pixel_locations, h, w): 23 | """Normalize pixel locations for grid_sampler function.""" 24 | resize_factor = torch.tensor([w - 1.0, h - 1.0]).to(self.device)[ 25 | None, None, : 26 | ] 27 | normalized_pixel_locations = ( 28 | 2 * pixel_locations / resize_factor - 1.0 29 | ) # [n_views, n_points, 2] 30 | return normalized_pixel_locations 31 | 32 | def compute_projections(self, xyz, train_cameras): 33 | """Project 3D points into views using training camera parameteres.""" 34 | original_shape 
= xyz.shape[:-1] 35 | xyz = xyz.reshape(original_shape[0], -1, 3) 36 | 37 | num_views = len(train_cameras) 38 | train_intrinsics = train_cameras[:, 2:18].reshape( 39 | -1, 4, 4 40 | ) # [n_views, 4, 4] 41 | train_poses = train_cameras[:, -16:].reshape(-1, 4, 4) # [n_views, 4, 4] 42 | xyz_h = torch.cat( 43 | [xyz, torch.ones_like(xyz[..., :1])], dim=-1 44 | ) # [n_points, 4] 45 | 46 | projections = train_intrinsics.bmm(torch.inverse(train_poses)).bmm( 47 | xyz_h.permute(0, 2, 1) 48 | ) # [n_views, 4, n_points] 49 | 50 | projections = projections.permute(0, 2, 1) # [n_views, n_points, 4] 51 | pixel_locations = projections[..., :2] / torch.clamp( 52 | projections[..., 2:3], min=1e-8 53 | ) # [n_views, n_points, 2] 54 | pixel_locations = torch.clamp(pixel_locations, min=-1e6, max=1e6) 55 | 56 | mask = projections[..., 2] > 0 # a point is invalid if behind the camera 57 | return pixel_locations.reshape( 58 | (num_views,) + original_shape[1:] + (2,) 59 | ), mask.reshape((num_views,) + original_shape[1:]) 60 | 61 | def compute_angle(self, xyz_st, xyz, query_camera, train_cameras): 62 | """Compute difference of viewing angle between rays from source and ones from target view. 63 | 64 | Args: 65 | 66 | xyz_st: reference 3D point location without scene motion 67 | xyz: 3D positions displaced by scene motion at nearby times 68 | query_camera: target view camera parameters 69 | train_imgs: source view images 70 | 71 | Returns: 72 | Difference of viewing angle between rays from source and ones from target 73 | view. 74 | """ 75 | original_shape = xyz.shape[:-1] 76 | xyz_st_ = xyz_st.reshape(xyz_st.shape[0], -1, 3) 77 | xyz_ = xyz.reshape(xyz.shape[0], -1, 3) 78 | 79 | train_poses = train_cameras[:, -16:].reshape(-1, 4, 4) # [n_views, 4, 4] 80 | num_views = len(train_poses) 81 | query_pose = ( 82 | query_camera[-16:].reshape(-1, 4, 4).repeat(num_views, 1, 1) 83 | ) # [n_views, 4, 4] 84 | 85 | ray2tar_pose = F.normalize( 86 | query_pose[:, :3, 3].unsqueeze(1) - xyz_st_, dim=-1 87 | ) 88 | ray2train_pose = F.normalize( 89 | train_poses[:, :3, 3].unsqueeze(1) - xyz_, dim=-1 90 | ) 91 | ray_diff = ray2tar_pose - ray2train_pose 92 | 93 | ray_diff_dot = torch.sum( 94 | ray2tar_pose * ray2train_pose, dim=-1, keepdim=True 95 | ) 96 | ray_diff_direction = F.normalize( 97 | ray_diff, dim=-1 98 | ) # ray_diff / torch.clamp(ray_diff_norm, min=1e-6) 99 | 100 | ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1) 101 | return ray_diff.reshape((num_views,) + original_shape[1:] + (4,)) 102 | 103 | def compute_with_motions( 104 | self, xyz_st, xyz, query_camera, train_imgs, train_cameras, featmaps 105 | ): 106 | """Extract 2D feature by projecting 3D points displaced by scene motion. 107 | 108 | Args: 109 | xyz_st: reference point location without scene motion 110 | xyz: 3D point positions displaced by scene motion 111 | query_camera: target view camera parameters 112 | train_imgs: source view images 113 | train_cameras: source view camera parameters 114 | featmaps: source view 2D image feature maps. 
115 | 116 | Returns: 117 | rgb_feat_sampled: extracted 2D feature 118 | ray_diff: viewing angle difference between target ray and source ray 119 | mask: valid masks 120 | """ 121 | 122 | assert ( 123 | (train_imgs.shape[0] == 1) 124 | and (train_cameras.shape[0] == 1) 125 | and (query_camera.shape[0] == 1) 126 | ), 'only support batch_size=1 for now' 127 | 128 | xyz_st = xyz_st[None, ...].expand(xyz.shape[0], -1, -1, -1) 129 | 130 | train_imgs = train_imgs.squeeze(0) # [n_views, h, w, 3] 131 | train_cameras = train_cameras.squeeze(0) # [n_views, 34] 132 | query_camera = query_camera.squeeze(0) # [34, ] 133 | 134 | train_imgs = train_imgs.permute(0, 3, 1, 2) # [n_views, 3, h, w] 135 | 136 | h, w = train_cameras[0][:2] 137 | 138 | # compute the projection of the query points to each reference image 139 | pixel_locations, mask_in_front = self.compute_projections( 140 | xyz, train_cameras 141 | ) 142 | 143 | normalized_pixel_locations = self.normalize( 144 | pixel_locations, h, w 145 | ) # [n_views, n_rays, n_samples, 2] 146 | 147 | # rgb sampling 148 | rgbs_sampled = F.grid_sample( 149 | train_imgs, normalized_pixel_locations, align_corners=True 150 | ) 151 | rgbs_sampled_ = rgbs_sampled.permute( 152 | 2, 3, 0, 1 153 | ) # [n_rays, n_samples, n_views, 3] 154 | 155 | # deep feature sampling 156 | feat_sampled = F.grid_sample( 157 | featmaps, normalized_pixel_locations, align_corners=True 158 | ) 159 | feat_sampled = feat_sampled.permute( 160 | 2, 3, 0, 1 161 | ) # [n_rays, n_samples, n_views, d] 162 | rgb_feat_sampled = torch.cat( 163 | [rgbs_sampled_, feat_sampled], dim=-1 164 | ) # [n_rays, n_samples, n_views, d+3] 165 | 166 | inbound = self.inbound(pixel_locations, h, w) 167 | ray_diff = self.compute_angle( 168 | xyz_st, xyz, query_camera, train_cameras 169 | ).detach() 170 | 171 | ray_diff = ray_diff.permute(1, 2, 0, 3) 172 | mask = ( 173 | (inbound * mask_in_front).float().permute(1, 2, 0)[..., None] 174 | ) # [n_rays, n_samples, n_views, 1] 175 | 176 | return rgb_feat_sampled, ray_diff, mask 177 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is not an officially supported Google product. 2 | 3 | # DynIBaR: Neural Dynamic Image-Based Rendering 4 | 5 | ### [Project Page](https://dynibar.github.io/) 6 | 7 | Implementation for CVPR 2023 paper (best paper honorable mention) 8 | 9 | [DynIBaR: Neural Dynamic Image-Based Rendering, CVPR 2023](https://dynibar.github.io/)
10 | 11 | [Zhengqi Li](https://zhengqili.github.io/)<sup>1</sup>, [Qianqian Wang](https://www.cs.cornell.edu/~qqw/)<sup>1,2</sup>, [Forrester Cole](https://people.csail.mit.edu/fcole/)<sup>1</sup>, [Richard Tucker](https://research.google/people/RichardTucker/)<sup>1</sup>, [Noah Snavely](https://www.cs.cornell.edu/~snavely/)<sup>1</sup> 12 |
13 | <sup>1</sup>Google Research, <sup>2</sup>Cornell Tech, Cornell University \ 14 |
15 | 16 | ## Instructions for installing dependencies 17 | 18 | ### Python Environment 19 | 20 | The following codebase was successfully run with Python 3.8 and CUDA 11.3. We 21 | suggest installing the library in a virtual environment such as Anaconda. 22 | 23 | To install required libraries, run: \ 24 | `conda env create -f environment_dynibar.yml` 25 | 26 | To install softmax splatting for preprocessing, clone and install the library 27 | from [here](https://github.com/hperrot/splatting). 28 | 29 | To measure LPIPS, copy the "models" folder from 30 | [NSFF](https://github.com/zhengqili/Neural-Scene-Flow-Fields/tree/main/nsff_exp/models), 31 | and put it in the code root directory. 32 | 33 | ## Evaluation on the Nvidia Dynamic Scene dataset 34 | 35 | ### Downloading data and pretrained checkpoint 36 | 37 | We include pretrained checkpoints that can be accessed by running: 38 | 39 | ``` 40 | wget https://storage.googleapis.com/gresearch/dynibar/nvidia_checkpoints.zip 41 | unzip nvidia_checkpoints.zip 42 | ``` 43 | 44 | Put the unzipped "checkpoints" folder in the code root directory. 45 | 46 | Each scene in the Nvidia dataset can be accessed 47 | [here](https://drive.google.com/drive/folders/1Gv6j_RvDG2WrpqEJWtx73u1tlCZKsPiM?usp=sharing). 48 | 49 | The input data directory should be similar to the following format: 50 | xxx/nvidia_long_release/Balloon1 51 | 52 | Run the following command for each scene to obtain the reported quantitative results: 53 | 54 | ```bash 55 | # Usage: in the txt file, you need to change "rootdir" to your code root directory 56 | # and "folder_path" to the input data directory, and make sure "coarse_dir" points to 57 | # the "checkpoints" folder you unzipped. 58 | python eval_nvidia.py --config configs_nvidia/eval_balloon1_long.txt 59 | ``` 60 | 61 | Note: It will take ~8 hours to evaluate each scene with 4x Nvidia A100 GPUs. 62 | 63 | ## Training/rendering on monocular videos 64 | 65 | ### Required inputs and corresponding folders or files: 66 | 67 | We provide template input data for the NSFF example video, which can 68 | be downloaded 69 | [here](https://drive.google.com/file/d/1t6VLtcdxITFcdm9fi9SSFOiHqgHu9wdP/view?usp=sharing). 70 | 71 | The input data directory should be in the following format: 72 | xxx/release/kid-running/dense/*** 73 | 74 | For your own video, you need to include the following folders to run training. 75 | 76 | * disp: disparity maps from 77 | [dynamic-cvd](https://github.com/google/dynamic-video-depth). Note that you 78 | need to run test.py to save the disparity and camera parameters to disk. 79 | * images_wxh: resized images at resolution w x h. 80 | * poses_bounds_cvd.npy: camera parameters of the input video in LLFF format. 81 | 82 | You can generate the above three items with the following script: 83 | 84 | ```bash 85 | # Usage: data_dir is the input video directory path, 86 | # cvd_dir is the saved depth directory resulting from running 87 | # "test.py" at https://github.com/google/dynamic-video-depth 88 | python save_monocular_cameras.py \ 89 | --data_dir xxx/release/kid-running \ 90 | --cvd_dir xxx/kid-running_scene_flow_motion_field_epoch_20/epoch0020_test 91 | ``` 92 | 93 | * source_virtual_views_wxh: virtual source views used to improve training 94 | stability and rendering quality (used in monocular video only). Run 95 | the following script to obtain them: 96 | 97 | ```bash 98 | # Usage: data_dir is the input video directory path, 99 | # cvd_dir is the saved depth directory resulting from running 100 | # "test.py" at https://github.com/google/dynamic-video-depth 101 | python render_source_vv.py \ 102 | --data_dir xxx/release/kid-running \ 103 | --cvd_dir xxx/kid-running_scene_flow_motion_field_epoch_20/epoch0020_test 104 | ``` 105 | 106 | * flow_i1, flow_i2, flow_i3: estimated optical flows within a temporal window of 107 | length 3. You can follow the NSFF 108 | [script](https://github.com/zhengqili/Neural-Scene-Flow-Fields/blob/main/nsff_scripts/run_flows_video.py) 109 | to compute optical flow between frame i and its nearby frames i+1, i+2, 110 | i+3, and save the results in folders "flow_i1", "flow_i2", "flow_i3" respectively. 111 | For example, 00000_fwd.npz in folder "flow_i1" stores the forward flow and valid 112 | mask from frame 0 to frame 1, and 00000_bwd.npz stores the backward flow and 113 | valid mask from frame 1 to frame 0. 114 | 115 | * static_masks, dynamic_masks: motion masks indicating which regions are 116 | stationary or moving. You can perform morphological dilation and erosion operations respectively 117 | to ensure that static_masks sufficiently cover the regions of moving objects and that the regions from dynamic_masks 118 | are within the true regions of moving objects; a rough sketch of this step is shown after this list. 119 | (Note: for dependency reasons, we don't release the code that generates these masks. Instead, you can use this [script](https://github.com/zhengqili/Neural-Scene-Flow-Fields/blob/main/nsff_scripts/run_flows_video.py#L87) from NSFF to generate coarse masks.) 120 | 
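The following is a minimal sketch of the dilation/erosion step described in the static_masks/dynamic_masks item above, assuming you already have one coarse binary motion mask per frame (e.g., from the NSFF script linked above, with nonzero pixels marking moving objects). The helper name `save_coarse_masks`, the PNG naming, and the kernel radius are illustrative assumptions, not released preprocessing code; check the expected file layout and mask polarity against `ibrnet/data_loaders/monocular.py` before relying on it.

```python
# Sketch only: expand/shrink coarse per-frame motion masks into the
# "static_masks" and "dynamic_masks" folders described above.
import glob
import os

import cv2
import imageio
import numpy as np


def save_coarse_masks(mask_dir, out_dir, radius=11):
  """mask_dir holds per-frame masks where nonzero pixels mark moving objects."""
  kernel = np.ones((radius, radius), np.uint8)
  static_dir = os.path.join(out_dir, 'static_masks')
  dynamic_dir = os.path.join(out_dir, 'dynamic_masks')
  os.makedirs(static_dir, exist_ok=True)
  os.makedirs(dynamic_dir, exist_ok=True)

  for mask_path in sorted(glob.glob(os.path.join(mask_dir, '*.png'))):
    name = os.path.basename(mask_path)
    motion = imageio.imread(mask_path)
    if motion.ndim == 3:  # collapse RGB masks to a single channel
      motion = motion[..., 0]
    motion = (motion > 127).astype(np.uint8)
    # Dilate so the moving region recorded for the static model generously
    # covers the true moving objects.
    dilated = cv2.dilate(motion, kernel, iterations=1)
    # Erode so the dynamic mask stays strictly inside the true moving regions.
    eroded = cv2.erode(motion, kernel, iterations=1)
    imageio.imwrite(os.path.join(static_dir, name), dilated * 255)
    imageio.imwrite(os.path.join(dynamic_dir, name), eroded * 255)
```

The kernel radius is a per-video choice; the training config exposes a related erosion_radius flag that is applied to the masks at training time.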
121 | ### To train the model: 122 | 123 | ```bash 124 | # Usage: config is the config txt file for the training video. 125 | # Make sure "rootdir" is your code root directory, 126 | # "folder_path" is your input data directory path, and 127 | # "train_scenes" is your folder name. 128 | # For example, if the data is in xxx/release/kid-running/dense/, then "folder_path" is 129 | # "xxx/release/" and "train_scenes" is "kid-running". 130 | python train.py \ 131 | --config configs/train_kid-running.txt 132 | ``` 133 | 134 | Hyperparameters in the config txt file you might need to tune for training a good model on in-the-wild videos: 135 | * rootdir: code root directory; should be in the format YOUR_PATH/dynibar. 136 | * folder_path: data root directory. 137 | * N_rand: number of random rays sampled at each iteration. Try to set it as large as possible; typically > 3000 gives good results. 138 | * init_decay_epoch: number of epochs over which to linearly decay the data-driven depth and optical flow losses. Modify this such that num_video_frames * init_decay_epoch ~= 30-40K (for example, a 100-frame video would use init_decay_epoch of roughly 300-400). 139 | * max_range, num_source_views: max_range is the maximum frame range searched when selecting source views for the static model, and num_source_views*2 is the number of source views used for the static model. 140 | 141 | The tensorboard logs include rendering visualizations as shown below. 142 | 143 | ![Tensorboard rendering visualization](images/tensorboard.png) 144 | 145 | ### To render the model: 146 | 147 | ```bash 148 | # Usage: config is the config txt file for the training video; 149 | # please make sure expname in the txt file is the saved folder name in the 'out' directory. 150 | python render_monocular_bt.py \ 151 | --config configs/test_kid-running.txt 152 | ``` 153 | 154 | ### Contact 155 | 156 | For any questions related to our paper and implementation, 157 | please send an email to zhengqili@google.com. 
158 | 159 | ## Citation 160 | 161 | ``` 162 | @InProceedings{Li_2023_CVPR, 163 | author = {Li, Zhengqi and Wang, Qianqian and Cole, Forrester and Tucker, Richard and Snavely, Noah}, 164 | title = {DynIBaR: Neural Dynamic Image-Based Rendering}, 165 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 166 | month = {June}, 167 | year = {2023}, 168 | pages = {4273-4284} 169 | } 170 | ``` 171 | -------------------------------------------------------------------------------- /ibrnet/feature_network.py: -------------------------------------------------------------------------------- 1 | """Class definition for 2D feature extractor.""" 2 | 3 | import importlib 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def class_for_name(module_name, class_name): 10 | m = importlib.import_module(module_name) 11 | return getattr(m, class_name) 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 15 | """3x3 convolution with padding.""" 16 | return nn.Conv2d( 17 | in_planes, 18 | out_planes, 19 | kernel_size=3, 20 | stride=stride, 21 | padding=dilation, 22 | groups=groups, 23 | bias=False, 24 | dilation=dilation, 25 | padding_mode='reflect', 26 | ) 27 | 28 | 29 | def conv1x1(in_planes, out_planes, stride=1): 30 | """1x1 convolution layer.""" 31 | return nn.Conv2d( 32 | in_planes, 33 | out_planes, 34 | kernel_size=1, 35 | stride=stride, 36 | bias=False, 37 | padding_mode='reflect', 38 | ) 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | """Basic CNN block.""" 43 | expansion = 1 44 | 45 | def __init__( 46 | self, 47 | inplanes, 48 | planes, 49 | stride=1, 50 | downsample=None, 51 | groups=1, 52 | base_width=64, 53 | dilation=1, 54 | norm_layer=None, 55 | ): 56 | super(BasicBlock, self).__init__() 57 | if norm_layer is None: 58 | norm_layer = nn.InstanceNorm2d 59 | 60 | self.conv1 = conv3x3(inplanes, planes, stride) 61 | self.bn1 = norm_layer(planes, track_running_stats=False, affine=True) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.conv2 = conv3x3(planes, planes) 64 | self.bn2 = norm_layer(planes, track_running_stats=False, affine=True) 65 | self.downsample = downsample 66 | self.stride = stride 67 | 68 | def forward(self, x): 69 | identity = x 70 | 71 | out = self.conv1(x) 72 | out = self.bn1(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv2(out) 76 | out = self.bn2(out) 77 | 78 | if self.downsample is not None: 79 | identity = self.downsample(x) 80 | 81 | out += identity 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | 87 | class Bottleneck(nn.Module): 88 | """Bottleneck CNN block.""" 89 | 90 | expansion = 4 91 | 92 | def __init__( 93 | self, 94 | inplanes, 95 | planes, 96 | stride=1, 97 | downsample=None, 98 | groups=1, 99 | base_width=64, 100 | dilation=1, 101 | norm_layer=None, 102 | ): 103 | super(Bottleneck, self).__init__() 104 | if norm_layer is None: 105 | norm_layer = nn.InstanceNorm2d 106 | width = int(planes * (base_width / 64.0)) * groups 107 | self.conv1 = conv1x1(inplanes, width) 108 | self.bn1 = norm_layer(width, track_running_stats=False, affine=True) 109 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 110 | self.bn2 = norm_layer(width, track_running_stats=False, affine=True) 111 | self.conv3 = conv1x1(width, planes * self.expansion) 112 | self.bn3 = norm_layer( 113 | planes * self.expansion, track_running_stats=False, affine=True 114 | ) 115 | self.relu = nn.ReLU(inplace=True) 116 | self.downsample = downsample 117 | self.stride = 
stride 118 | 119 | def forward(self, x): 120 | identity = x 121 | 122 | out = self.conv1(x) 123 | out = self.bn1(out) 124 | out = self.relu(out) 125 | 126 | out = self.conv2(out) 127 | out = self.bn2(out) 128 | out = self.relu(out) 129 | 130 | out = self.conv3(out) 131 | out = self.bn3(out) 132 | 133 | if self.downsample is not None: 134 | identity = self.downsample(x) 135 | 136 | out += identity 137 | out = self.relu(out) 138 | 139 | return out 140 | 141 | 142 | class conv(nn.Module): 143 | """Convolutional layer.""" 144 | 145 | def __init__(self, num_in_layers, num_out_layers, kernel_size, stride): 146 | super(conv, self).__init__() 147 | self.kernel_size = kernel_size 148 | self.conv = nn.Conv2d( 149 | num_in_layers, 150 | num_out_layers, 151 | kernel_size=kernel_size, 152 | stride=stride, 153 | padding=(self.kernel_size - 1) // 2, 154 | padding_mode='reflect', 155 | ) 156 | self.bn = nn.InstanceNorm2d( 157 | num_out_layers, track_running_stats=False, affine=True 158 | ) 159 | 160 | def forward(self, x): 161 | return F.elu(self.bn(self.conv(x)), inplace=True) 162 | 163 | 164 | class upconv(nn.Module): 165 | """Convolutional layers followed by upsampling.""" 166 | 167 | def __init__(self, num_in_layers, num_out_layers, kernel_size, scale): 168 | super(upconv, self).__init__() 169 | self.scale = scale 170 | self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1) 171 | 172 | def forward(self, x): 173 | x = nn.functional.interpolate( 174 | x, scale_factor=self.scale, align_corners=True, mode='bilinear' 175 | ) 176 | return self.conv(x) 177 | 178 | 179 | class ResNet(nn.Module): 180 | """Main ResNet based feature extractor.""" 181 | def __init__( 182 | self, 183 | encoder='resnet34', 184 | coarse_out_ch=32, 185 | fine_out_ch=32, 186 | norm_layer=None, 187 | coarse_only=False, 188 | ): 189 | super(ResNet, self).__init__() 190 | assert encoder in [ 191 | 'resnet18', 192 | 'resnet34', 193 | 'resnet50', 194 | 'resnet101', 195 | 'resnet152', 196 | ], 'Incorrect encoder type' 197 | if encoder in ['resnet18', 'resnet34']: 198 | filters = [64, 128, 256, 512] 199 | else: 200 | filters = [256, 512, 1024, 2048] 201 | self.coarse_only = coarse_only 202 | if self.coarse_only: 203 | fine_out_ch = 0 204 | self.coarse_out_ch = coarse_out_ch 205 | self.fine_out_ch = fine_out_ch 206 | out_ch = coarse_out_ch + fine_out_ch 207 | 208 | # original 209 | layers = [3, 4, 6, 3] 210 | if norm_layer is None: 211 | # norm_layer = nn.InstanceNorm2d 212 | norm_layer = nn.InstanceNorm2d 213 | self._norm_layer = norm_layer 214 | self.dilation = 1 215 | block = BasicBlock 216 | replace_stride_with_dilation = [False, False, False] 217 | self.inplanes = 64 218 | self.groups = 1 219 | self.base_width = 64 220 | self.conv1 = nn.Conv2d( 221 | 3, 222 | self.inplanes, 223 | kernel_size=7, 224 | stride=2, 225 | padding=3, 226 | bias=False, 227 | padding_mode='reflect', 228 | ) 229 | self.bn1 = norm_layer(self.inplanes, track_running_stats=False, affine=True) 230 | self.relu = nn.ReLU(inplace=True) 231 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 232 | self.layer2 = self._make_layer( 233 | block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] 234 | ) 235 | self.layer3 = self._make_layer( 236 | block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] 237 | ) 238 | 239 | # decoder 240 | self.upconv3 = upconv(filters[2], 128, 3, 2) 241 | self.iconv3 = conv(filters[1] + 128, 128, 3, 1) 242 | self.upconv2 = upconv(128, 64, 3, 2) 243 | self.iconv2 = conv(filters[0] + 64, out_ch, 
3, 1) 244 | 245 | # fine-level conv 246 | self.out_conv = nn.Conv2d(out_ch, out_ch, 1, 1) 247 | 248 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 249 | norm_layer = self._norm_layer 250 | downsample = None 251 | previous_dilation = self.dilation 252 | if dilate: 253 | self.dilation *= stride 254 | stride = 1 255 | if stride != 1 or self.inplanes != planes * block.expansion: 256 | downsample = nn.Sequential( 257 | conv1x1(self.inplanes, planes * block.expansion, stride), 258 | norm_layer( 259 | planes * block.expansion, track_running_stats=False, affine=True 260 | ), 261 | ) 262 | 263 | layers = [] 264 | layers.append( 265 | block( 266 | self.inplanes, 267 | planes, 268 | stride, 269 | downsample, 270 | self.groups, 271 | self.base_width, 272 | previous_dilation, 273 | norm_layer, 274 | ) 275 | ) 276 | self.inplanes = planes * block.expansion 277 | for _ in range(1, blocks): 278 | layers.append( 279 | block( 280 | self.inplanes, 281 | planes, 282 | groups=self.groups, 283 | base_width=self.base_width, 284 | dilation=self.dilation, 285 | norm_layer=norm_layer, 286 | ) 287 | ) 288 | 289 | return nn.Sequential(*layers) 290 | 291 | def skipconnect(self, x1, x2): 292 | diffY = x2.size()[2] - x1.size()[2] 293 | diffX = x2.size()[3] - x1.size()[3] 294 | 295 | x1 = F.pad( 296 | x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2) 297 | ) 298 | 299 | x = torch.cat([x2, x1], dim=1) 300 | return x 301 | 302 | def forward(self, x): 303 | x = self.relu(self.bn1(self.conv1(x))) 304 | 305 | x1 = self.layer1(x) 306 | x_out = self.out_conv(x1) 307 | 308 | x_coarse = x_out[:, : self.coarse_out_ch, :] 309 | x_fine = x_out[:, -self.fine_out_ch :, :] 310 | 311 | return x_coarse, x_fine 312 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """function definition for config function.""" 2 | 3 | import configargparse 4 | 5 | 6 | def config_parser(): 7 | """Configuration function.""" 8 | parser = configargparse.ArgumentParser() 9 | # general 10 | parser.add_argument('--config', is_config_file=True, help='Config file path') 11 | parser.add_argument( 12 | '--rootdir', 13 | type=str, 14 | help=( 15 | 'The path to the project root directory. Replace this path with' 16 | ' yours!' 17 | ), 18 | ) 19 | parser.add_argument( 20 | '--folder_path', 21 | type=str, 22 | help=( 23 | 'The path to the input training data. Replace this path with yours.' 24 | ), 25 | ) 26 | 27 | parser.add_argument( 28 | '--coarse_dir', 29 | type=str, 30 | help=( 31 | 'The directory of coarse model.' 
32 | ), 33 | ) 34 | 35 | parser.add_argument( 36 | '--mask_src_view', 37 | action='store_true', 38 | help=( 39 | 'Using motion segementation to mask src views for rendering static' 40 | ' model' 41 | ), 42 | ) 43 | parser.add_argument( 44 | '--training_height', type=int, default=288, help='Training image height' 45 | ) 46 | parser.add_argument('--expname', type=str, help='Experiment name') 47 | parser.add_argument( 48 | '--distributed', action='store_true', help='Use distributed training' 49 | ) 50 | parser.add_argument( 51 | '--local_rank', type=int, default=0, help='Rank for distributed training' 52 | ) 53 | parser.add_argument( 54 | '-j', 55 | '--workers', 56 | default=16, 57 | type=int, 58 | help='Number of data loading workers (default: 16)', 59 | ) 60 | 61 | parser.add_argument( 62 | '--mask_static', 63 | action='store_true', 64 | help='Using motion mask to mask source views for static model', 65 | ) 66 | 67 | ########## model options ########## 68 | parser.add_argument( 69 | '--N_rand', 70 | type=int, 71 | default=32 * 16, 72 | help='Batch size (number of random rays per gradient step)', 73 | ) 74 | parser.add_argument( 75 | '--sample_mode', 76 | type=str, 77 | default='uniform', 78 | help='How to sample pixels from images for training:uniform|center', 79 | ) 80 | parser.add_argument( 81 | '--lr_multipler', 82 | type=float, 83 | default=1.0, 84 | help='Learning rate ratio for training static component', 85 | ) 86 | parser.add_argument( 87 | '--num_vv', 88 | type=int, 89 | default=3, 90 | help='Number of virtual source views', 91 | ) 92 | parser.add_argument( 93 | '--cycle_factor', 94 | type=float, 95 | default=0.1, 96 | help='Cycle conssitency loss warmup factor', 97 | ) 98 | parser.add_argument( 99 | '--anneal_cycle', 100 | action='store_true', 101 | help='Bootstrap cycle consistency loss', 102 | ) 103 | parser.add_argument( 104 | '--erosion_radius', 105 | type=int, 106 | default=1, 107 | help='Mophorlogical erosion raidus for mask', 108 | ) 109 | parser.add_argument( 110 | '--decay_rate', 111 | type=float, 112 | default=10.0, 113 | help='Decaying rate for data-driven loss', 114 | ) 115 | 116 | ########## dataset options ########## 117 | parser.add_argument( 118 | '--eval_dataset', 119 | type=str, 120 | default='llff_test', 121 | help='The dataset to evaluate', 122 | ) 123 | parser.add_argument( 124 | '--eval_scenes', 125 | nargs='+', 126 | default=[], 127 | help='Optional, specify a subset of scenes from eval_dataset to evaluate', 128 | ) 129 | parser.add_argument( 130 | '--render_idx', type=int, default=-1, help='Frame index for rendering' 131 | ) 132 | parser.add_argument( 133 | '--train_dataset', 134 | type=str, 135 | default='ibrnet_collected', 136 | help=( 137 | 'the training dataset, should either be a single dataset, or multiple' 138 | ' datasets connected with "+", for example,' 139 | ' ibrnet_collected+llff+spaces' 140 | ), 141 | ) 142 | parser.add_argument( 143 | '--train_scenes', 144 | nargs='+', 145 | default=[], 146 | help=( 147 | 'optional, specify a subset of training scenes from training dataset' 148 | ), 149 | ) 150 | 151 | ## others 152 | parser.add_argument( 153 | '--init_decay_epoch', 154 | type=int, 155 | default=150, 156 | help='How many epochs to decay data driven losses', 157 | ) 158 | parser.add_argument( 159 | '--max_range', 160 | type=int, 161 | default=35, 162 | help='Max frame range to sample source views for static model', 163 | ) 164 | 165 | ########## model options ########## 166 | ## ray sampling options 167 | parser.add_argument( 168 | 
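# Descriptive note: rays are rendered in groups of `chunk_size` (see the
# `for i in range(0, N_rays, chunk_size)` loop in ibrnet/render_image.py), so
# this flag only bounds peak GPU memory during rendering; smaller values trade
# speed for memory and should not change the rendered output.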
'--chunk_size', 169 | type=int, 170 | default=1024 * 4, 171 | help=( 172 | 'Number of rays processed in parallel, decrease if running out of' 173 | ' memory' 174 | ), 175 | ) 176 | ## model options 177 | parser.add_argument( 178 | '--coarse_feat_dim', 179 | type=int, 180 | default=32, 181 | help='2D feature dimension for coarse level', 182 | ) 183 | parser.add_argument( 184 | '--fine_feat_dim', 185 | type=int, 186 | default=32, 187 | help='2D feature dimension for fine level', 188 | ) 189 | parser.add_argument( 190 | '--num_source_views', 191 | type=int, 192 | default=7, 193 | help=( 194 | 'The number of input source views for each target view used in' 195 | 'static dynibar model' 196 | ), 197 | ) 198 | parser.add_argument( 199 | '--num_basis', 200 | type=int, 201 | default=6, 202 | help='The number of basis for motion trajectory', 203 | ) 204 | parser.add_argument( 205 | '--anti_alias_pooling', 206 | type=int, 207 | default=1, 208 | help='Use anti-alias pooling', 209 | ) 210 | parser.add_argument( 211 | '--mask_rgb', 212 | type=int, 213 | default=1, 214 | help=( 215 | 'Mask RGB features coresponding to black pixel for rendering from' 216 | ' static model' 217 | ), 218 | ) 219 | 220 | ########## checkpoints ########## 221 | parser.add_argument( 222 | '--no_reload', 223 | action='store_true', 224 | help='do not reload weights from saved ckpt', 225 | ) 226 | parser.add_argument( 227 | '--ckpt_path', 228 | type=str, 229 | default='', 230 | help='specific weights npy file to reload for coarse network', 231 | ) 232 | parser.add_argument( 233 | '--no_load_opt', 234 | action='store_true', 235 | help='do not load optimizer when reloading', 236 | ) 237 | parser.add_argument( 238 | '--no_load_scheduler', 239 | action='store_true', 240 | help='do not load scheduler when reloading', 241 | ) 242 | ########### iterations & learning rate options ########## 243 | parser.add_argument( 244 | '--n_iters', type=int, default=300000, help='Num of iterations' 245 | ) 246 | parser.add_argument( 247 | '--lrate_feature', 248 | type=float, 249 | default=1e-3, 250 | help='Learning rate for feature extractor', 251 | ) 252 | parser.add_argument( 253 | '--lrate_mlp', type=float, default=5e-4, help='Learning rate for mlp' 254 | ) 255 | parser.add_argument( 256 | '--lrate_decay_factor', 257 | type=float, 258 | default=0.5, 259 | help='Decay learning rate by a factor every specified number of steps', 260 | ) 261 | parser.add_argument( 262 | '--lrate_decay_steps', 263 | type=int, 264 | default=50000, 265 | help='Decay learning rate by a factor every number of steps', 266 | ) 267 | parser.add_argument( 268 | '--w_cycle', 269 | type=float, 270 | default=0.1, 271 | help='Weight of cycle consistency loss', 272 | ) 273 | parser.add_argument( 274 | '--w_distortion', 275 | type=float, 276 | default=1e-3, 277 | help='Weight of distortion loss', 278 | ) 279 | parser.add_argument( 280 | '--w_entropy', type=float, default=0.0, help='Weight of entropy loss' 281 | ) 282 | parser.add_argument( 283 | '--w_disp', type=float, default=5e-2, help='Weight of disparty loss' 284 | ) 285 | parser.add_argument( 286 | '--w_flow', type=float, default=5e-3, help='Weight of flow loss' 287 | ) 288 | parser.add_argument( 289 | '--w_skew_entropy', 290 | type=float, 291 | default=1e-3, 292 | help='Weight of entropy loss, assuming there is no skewness.', 293 | ) 294 | parser.add_argument( 295 | '--w_reg', type=float, default=0.05, help='Weight of regularization loss' 296 | ) 297 | parser.add_argument( 298 | '--pretrain_path', type=str, default='', 
help='Pretrained model path' 299 | ) 300 | parser.add_argument( 301 | '--occ_weights_mode', 302 | type=int, 303 | default=0, 304 | help=( 305 | 'Occlusion weight mode during cross-time rendering. 0: mix two models' 306 | ' weights. 1: using weight from dynamic model only 2: using weight' 307 | ' composited from static and dynamic models. ' 308 | ), 309 | ) 310 | 311 | ########## rendering options ########## 312 | parser.add_argument( 313 | '--N_samples', 314 | type=int, 315 | default=64, 316 | help='Number of coarse samples per ray', 317 | ) 318 | parser.add_argument( 319 | '--N_importance', 320 | type=int, 321 | default=64, 322 | help=( 323 | 'Number of fine samples per ray. total number of samples is the sum' 324 | ' of coarse plus fine models' 325 | ), 326 | ) 327 | parser.add_argument( 328 | '--inv_uniform', 329 | action='store_true', 330 | help='If True, uniformly sample in inverse depth space', 331 | ) 332 | parser.add_argument( 333 | '--input_dir', 334 | action='store_true', 335 | help='If True, input global directional with positional encoding', 336 | ) 337 | parser.add_argument( 338 | '--input_xyz', 339 | action='store_true', 340 | help='If True, input global xyz with positional encoding', 341 | ) 342 | parser.add_argument( 343 | '--det', 344 | action='store_true', 345 | help='Deterministic sampling for coarse and fine samples', 346 | ) 347 | parser.add_argument( 348 | '--white_bkgd', 349 | action='store_true', 350 | help='Apply the trick to avoid fitting to white background', 351 | ) 352 | parser.add_argument( 353 | '--render_stride', 354 | type=int, 355 | default=1, 356 | help='Render with large stride for validation to save time', 357 | ) 358 | ########## logging/saving options ########## 359 | parser.add_argument( 360 | '--i_print', type=int, default=100, help='Frequency of terminal printout' 361 | ) 362 | parser.add_argument( 363 | '--i_img', 364 | type=int, 365 | default=1000, 366 | help='Frequency of tensorboard image logging', 367 | ) 368 | parser.add_argument( 369 | '--i_weights', 370 | type=int, 371 | default=10000, 372 | help='Frequency of weight ckpt saving', 373 | ) 374 | 375 | return parser 376 | -------------------------------------------------------------------------------- /render_source_vv.py: -------------------------------------------------------------------------------- 1 | """Rendering virutal source views from video depth, used for monocular video.""" 2 | 3 | import argparse 4 | import glob 5 | import os 6 | 7 | import cv2 8 | import imageio.v2 as imageio 9 | import kornia 10 | import numpy as np 11 | import skimage.morphology 12 | from splatting import splatting_function 13 | import torch 14 | 15 | def render_forward_splat(src_imgs, src_depths, r_cam, t_cam, k_src, k_dst): 16 | '''Point cloud rendering from RGBD images.''' 17 | batch_size = src_imgs.shape[0] 18 | 19 | rot = r_cam 20 | t = t_cam 21 | k_src_inv = k_src.inverse() 22 | 23 | x = np.arange(src_imgs[0].shape[1]) 24 | y = np.arange(src_imgs[0].shape[0]) 25 | coord = np.stack(np.meshgrid(x, y), -1) 26 | coord = np.concatenate((coord, np.ones_like(coord)[:, :, [0]]), -1) 27 | coord = coord.astype(np.float32) 28 | coord = torch.as_tensor(coord, dtype=k_src.dtype, device=k_src.device) 29 | coord = coord[None, ..., None].repeat(batch_size, 1, 1, 1, 1) 30 | 31 | depth = src_depths[:, :, :, None, None] 32 | 33 | # from reference to target viewpoint 34 | pts_3d_ref = depth * k_src_inv[:, None, None, ...] @ coord 35 | pts_3d_tgt = rot[:, None, None, ...] 
@ pts_3d_ref + t[:, None, None, :, None] 36 | points = k_dst[:, None, None, ...] @ pts_3d_tgt 37 | points = points.squeeze(-1) 38 | 39 | new_z = points[:, :, :, [2]].clone().permute(0, 3, 1, 2) # b,1,h,w 40 | points = points / torch.clamp(points[:, :, :, [2]], 1e-8, None) 41 | 42 | src_ims_ = src_imgs.permute(0, 3, 1, 2) 43 | num_channels = src_ims_.shape[1] 44 | 45 | flow = points - coord.squeeze(-1) 46 | flow = flow.permute(0, 3, 1, 2)[:, :2, ...] 47 | 48 | importance = 1.0 / (new_z) 49 | importance_min = importance.amin((1, 2, 3), keepdim=True) 50 | importance_max = importance.amax((1, 2, 3), keepdim=True) 51 | weights = (importance - importance_min) / ( 52 | importance_max - importance_min + 1e-6 53 | ) * 20 - 10 54 | src_mask_ = torch.ones_like(new_z) 55 | 56 | input_data = torch.cat([src_ims_, (1.0 / (new_z)), src_mask_], 1) 57 | 58 | output_data = splatting_function( 59 | 'softmax', input_data.cuda(), flow.cuda(), weights.detach().cuda() 60 | ) 61 | 62 | warp_feature = output_data[:, 0:num_channels, ...] 63 | warp_disp = output_data[:, num_channels : num_channels + 1, ...] 64 | # warp_mask = output_data[:, num_channels + 1 : num_channels + 2, ...] 65 | 66 | return warp_feature, warp_disp#, warp_mask 67 | 68 | def render_wander_path(c2w, hwf, bd_scale, max_disp_=50, xyz=[1, 0, 1]): 69 | """Render nearby virtual source views with displacement in x and z direciton.""" 70 | num_frames = 60 71 | max_disp = max_disp_ * bd_scale 72 | max_trans = ( 73 | max_disp / hwf[2][0] 74 | ) 75 | output_poses = [] 76 | 77 | for i in range(num_frames): 78 | 79 | x_trans = max_trans * np.cos( 80 | 2.0 * np.pi * float(i) / float(num_frames) 81 | ) * xyz[0] 82 | y_trans = max_trans * np.sin( 83 | 2.0 * np.pi * float(i) / float(num_frames) 84 | ) * xyz[1] 85 | z_trans = max_trans * np.cos( 86 | 2.0 * np.pi * float(i) / float(num_frames) 87 | ) * xyz[2] 88 | 89 | i_pose = np.concatenate( 90 | [ 91 | np.concatenate( 92 | [ 93 | np.eye(3), 94 | np.array([x_trans, y_trans, z_trans])[:, np.newaxis], 95 | ], 96 | axis=1, 97 | ), 98 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :], 99 | ], 100 | axis=0, 101 | ) 102 | 103 | i_pose = np.linalg.inv( 104 | i_pose 105 | ) # torch.tensor(np.linalg.inv(i_pose)).float() 106 | 107 | ref_pose = np.concatenate( 108 | [c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0 109 | ) 110 | 111 | render_pose = np.dot(ref_pose, i_pose) 112 | 113 | output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1)) 114 | 115 | return np.array(output_poses + output_poses), num_frames 116 | 117 | 118 | def sobel_fg_alpha(disp, mode='sobel', beta=10.0): 119 | """Create depth boundary mask.""" 120 | sobel_grad = kornia.filters.spatial_gradient( 121 | disp, mode=mode, normalized=False 122 | ) 123 | sobel_mag = torch.sqrt( 124 | sobel_grad[:, :, 0, ...] ** 2 + sobel_grad[:, :, 1, ...] 
** 2 125 | ) 126 | alpha = torch.exp(-1.0 * beta * sobel_mag).detach() 127 | 128 | return alpha 129 | 130 | 131 | FINAL_H = 288 132 | USE_DPT = True 133 | 134 | if __name__ == '__main__': 135 | parser = argparse.ArgumentParser() 136 | # parser.add_argument("--scene_name", type=str, 137 | # help='Scene name') # 'kid-running' 138 | parser.add_argument("--data_dir", type=str, 139 | help='data directory') # '/home/zhengqili/filestore/NSFF/nerf_data/release' 140 | parser.add_argument("--cvd_dir", type=str, 141 | help='video depth directory') # '/home/zhengqili/filestore/dynamic-video-DPT/monocular-results/kid-runningscene_flow_motion_field_shutterstock_epoch_15/epoch0015_test' 142 | 143 | args = parser.parse_args() 144 | 145 | data_path = os.path.join( 146 | args.data_dir, 'dense' 147 | ) 148 | 149 | pt_out_list = sorted( 150 | glob.glob( 151 | os.path.join( 152 | args.cvd_dir, 153 | '*.npz', 154 | ) 155 | ) 156 | ) 157 | 158 | try: 159 | original_img_path = os.path.join(data_path, 'images', '00000.png') 160 | o_img = imageio.imread(original_img_path) 161 | except: 162 | original_img_path = os.path.join(data_path, 'images', '00000.jpg') 163 | o_img = imageio.imread(original_img_path) 164 | 165 | o_ar = float(o_img.shape[1]) / float(o_img.shape[0]) 166 | 167 | final_w, final_h = int(round(FINAL_H * o_ar)), int(FINAL_H) 168 | 169 | save_dir = os.path.join( 170 | data_path, 'source_virtual_views_%dx%d' % (final_w, final_h) 171 | ) 172 | os.makedirs(save_dir, exist_ok=True) 173 | 174 | Ks = [] 175 | mono_depths = [] 176 | c2w_mats = [] 177 | imgs = [] 178 | bounds_mats = [] 179 | points_cloud = [] 180 | 181 | for i in range(0, len(pt_out_list)): 182 | pt_out_path = pt_out_list[i] 183 | out_name = pt_out_path.split('/')[-1] 184 | pt_data = np.load(pt_out_path) 185 | pred_depth = pt_data['depth'][0, 0, ...] 186 | cam_c2w = pt_data['cam_c2w'][0] 187 | img = pt_data['img_1'][0].transpose(1, 2, 0) 188 | 189 | c2w_mats.append(cam_c2w) 190 | bounds_mats.append(np.percentile(pred_depth, 5)) 191 | K = pt_data['K'][0, 0, 0, ...].transpose() 192 | K[0, :] *= final_w / img.shape[1] 193 | K[1, :] *= final_h / img.shape[0] 194 | 195 | h, w, fx, fy = final_h, final_w, K[0, 0], K[1, 1] 196 | ff = (fx + fy) / 2.0 197 | # hwf = np.array([h, w, fx, fy]).reshape([1, 4]) 198 | hwf = np.array([h, w, ff]).reshape([3, 1]) 199 | 200 | c2w_mats = np.stack(c2w_mats, 0) 201 | bounds_mats = np.stack(bounds_mats, 0) 202 | 203 | bd_scale = bounds_mats.min() * 0.75 204 | 205 | poses = c2w_mats[:, :3, :4].transpose([1, 2, 0]) 206 | 207 | # must switch to [-y, x, z] from [x, -y, -z], NOT [r, u, -t] 208 | poses = np.concatenate( 209 | [poses[:, 1:2, :], poses[:, 0:1, :], -poses[:, 2:3, :], poses[:, 3:4, :]], 210 | 1, 211 | ) 212 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 213 | 214 | num_samples = 4 215 | vv_poses_final = np.zeros((poses.shape[0], num_samples * 2, 3, 4)) 216 | 217 | for ii in range(poses.shape[0]): 218 | print(ii) 219 | virtural_poses_0, num_render_0 = render_wander_path( 220 | poses[ii], hwf, bd_scale, 56 * 1.5, 221 | xyz=[0., 1., 1.] # y, x, z 222 | ) 223 | virtural_poses_1, num_render_1 = render_wander_path( 224 | poses[ii], hwf, bd_scale, 48 * 1.5, 225 | xyz=[0.5, 1., 0.] 226 | ) 227 | # this is for fixed viewpoint! 228 | start_idx = np.random.randint(0, num_render_0 // num_samples) 229 | 230 | vv_poses_final[ii, :num_samples, ...] = virtural_poses_0[ 231 | 5 : -1 : int(num_render_0 // num_samples) 232 | ][:num_samples, :3, :4] 233 | vv_poses_final[ii, num_samples:, ...] 
= virtural_poses_1[ 234 | 15 : -1 : int(num_render_1 // num_samples) 235 | ][:num_samples, :3, :4] 236 | 237 | np.save( 238 | os.path.join(data_path, 'source_vv_poses.npy'), 239 | np.moveaxis(vv_poses_final, 0, -1).astype(np.float32), 240 | ) 241 | 242 | # switch back 243 | c2w_mats_vsv = np.concatenate( 244 | [ 245 | vv_poses_final[..., 1:2], 246 | vv_poses_final[..., 0:1], 247 | -vv_poses_final[..., 2:3], 248 | vv_poses_final[..., 3:4], 249 | ], 250 | -1, 251 | ) 252 | 253 | for i in range(0, len(pt_out_list)): 254 | save_sub_dir = os.path.join(save_dir, '%05d' % i) 255 | print(save_sub_dir) 256 | os.makedirs(save_sub_dir, exist_ok=True) 257 | pt_out_path = pt_out_list[i] 258 | 259 | out_name = pt_out_path.split('/')[-1] 260 | pt_data = np.load(pt_out_path) 261 | 262 | K = pt_data['K'][0, 0, 0, ...].transpose() 263 | img = pt_data['img_1'][0].transpose(1, 2, 0) 264 | cam_ref2w = pt_data['cam_c2w'][0] 265 | pred_depth = pt_data['depth'][0, 0, ...] 266 | pred_disp = 1.0 / pred_depth 267 | 268 | K[0, :] *= final_w / img.shape[1] 269 | K[1, :] *= final_h / img.shape[0] 270 | 271 | print('K ', K) 272 | assert abs(K[0, 0] - K[1, 1]) / abs(K[0, 0] + K[1, 1]) < 0.005 273 | 274 | pred_depth_ = cv2.resize( 275 | pred_depth, (final_w, final_h), interpolation=cv2.INTER_NEAREST 276 | ) 277 | 278 | img = cv2.resize(img, (final_w, final_h), interpolation=cv2.INTER_AREA) 279 | pred_disp = cv2.resize( 280 | pred_disp, (final_w, final_h), interpolation=cv2.INTER_LINEAR 281 | ) 282 | 283 | mode = 'sobel' 284 | beta = 0.5 285 | pred_depth = 1.0 / torch.from_numpy(pred_disp[None, None, ...]) 286 | pred_depth = pred_depth / 10.0 287 | cur_alpha = sobel_fg_alpha(pred_depth, mode, beta=beta)[ 288 | 0, 0, ..., None 289 | ].numpy() 290 | 291 | for k in range(num_samples * 2): 292 | # render source view into target viewpoint 293 | rgba_pt = torch.from_numpy( 294 | np.concatenate( 295 | [np.array(img * 255.0), cur_alpha], axis=-1 296 | ) 297 | )[None].float() 298 | disp_pt = torch.from_numpy(np.array(pred_disp))[ 299 | None 300 | ].float() 301 | cam_tgt2w = np.eye(4) 302 | cam_tgt2w[:3, :4] = c2w_mats_vsv[i, k] 303 | T_ref2tgt = np.dot(np.linalg.inv(cam_tgt2w), cam_ref2w) 304 | 305 | fwd_rot = torch.from_numpy(T_ref2tgt[:3, :3])[None].float() 306 | fwd_t = torch.from_numpy(T_ref2tgt[:3, 3])[None].float() # * metric_scale 307 | k_ref = torch.from_numpy(np.array(K))[None].float() 308 | 309 | render_rgba, render_depth = render_forward_splat( 310 | rgba_pt, 1.0 / disp_pt, fwd_rot, fwd_t, k_src=k_ref, k_dst=k_ref 311 | ) 312 | 313 | render_rgb = np.clip( 314 | render_rgba[0, :3, ...].cpu().numpy().transpose(1, 2, 0) / 255.0, 315 | 0.0, 316 | 1.0, 317 | ) 318 | mask = np.clip( 319 | render_rgba[0, 3:4, ...].cpu().numpy().transpose(1, 2, 0), 0.0, 1.0 320 | ) 321 | mask = skimage.morphology.erosion( 322 | mask[..., 0] > 0.5, skimage.morphology.disk(1) 323 | ) 324 | 325 | render_rgb_masked = render_rgb * mask[..., None] 326 | h, w = render_rgb_masked.shape[:2] 327 | imageio.imsave( 328 | os.path.join(save_sub_dir, '%02d.png' % k), 329 | np.uint8(255 * np.clip(render_rgb_masked, 0.0, 1.0)), 330 | ) 331 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /ibrnet/sample_ray.py: -------------------------------------------------------------------------------- 1 | """Utility class for sampling data corresponding to rays from images.""" 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from kornia import create_meshgrid 7 | 8 | rng = np.random.RandomState(234) 9 | 10 | 11 | def parse_camera(params): 12 | H = params[:, 0] 13 | W = params[:, 1] 14 | intrinsics = params[:, 2:18].reshape((-1, 4, 4)) 15 | c2w = params[:, 18:34].reshape((-1, 4, 4)) 16 | return W, H, intrinsics, c2w 17 | 18 | 19 | class RaySamplerSingleImage(object): 20 | """Sampling data corresponding to the rays from a target view. 21 | 22 | This class stores and returns following items at sampled pixel locations 23 | for training Dynibar: 24 | ray_o: ray origin at target view 25 | ray_d: ray direction at target view 26 | depth_range: scene depth bounds at target view 27 | camera: reference time camera parameters 28 | render_camera: rendered target view camera parameters 29 | anchor_camera: camera parameters for input view at nearby cross time 30 | rgb: image at reference time 31 | src_rgbs: source view images w.r.t reference time for dynamic model 32 | src_cameras: source view camera parameters w.r.t reference time for 33 | dynamic model 34 | anchor_src_rgbs: source view images w.r.t nearby cross time for dynamic 35 | model. 36 | anchor_src_cameras: source view camera parameters w.r.t 37 | nearby cross time for dynamic model. 
38 | static_src_rgbs: source view images for static model 39 | static_src_cameras: source view camera parameters for static model 40 | static_src_masks: dynamic masks of source views for static model 41 | disp: disparity map 42 | motion_mask: dynamic mask 43 | static_mask: static masks 44 | uv_grid: 2D pixel coorindate in image space 45 | flows: observed 2D optical flows 46 | masks: optical flow vailid masks 47 | """ 48 | 49 | def __init__(self, data, device, resize_factor=1, render_stride=1): 50 | super().__init__() 51 | self.render_stride = render_stride 52 | self.rgb = data['rgb'] if 'rgb' in data.keys() else None 53 | self.disp = data['disp'] if 'disp' in data.keys() else None 54 | 55 | self.motion_mask = ( 56 | data['motion_mask'] if 'motion_mask' in data.keys() else None 57 | ) 58 | 59 | self.static_mask = ( 60 | data['static_mask'] if 'static_mask' in data.keys() else None 61 | ) 62 | 63 | self.flows = data['flows'].squeeze(0) if 'flows' in data.keys() else None 64 | self.masks = data['masks'].squeeze(0) if 'masks' in data.keys() else None 65 | 66 | self.camera = data['camera'] 67 | self.render_camera = ( 68 | data['render_camera'] if 'render_camera' in data.keys() else None 69 | ) 70 | 71 | self.anchor_camera = ( 72 | data['anchor_camera'] if 'anchor_camera' in data.keys() else None 73 | ) 74 | self.rgb_path = data['rgb_path'] 75 | self.depth_range = data['depth_range'] 76 | self.device = device 77 | W, H, self.intrinsics, self.c2w_mat = parse_camera(self.camera) 78 | 79 | self.batch_size = len(self.camera) 80 | 81 | self.H = int(H[0]) 82 | self.W = int(W[0]) 83 | self.uv_grid = create_meshgrid( 84 | self.H, self.W, normalized_coordinates=False 85 | )[0].to( 86 | self.device 87 | ) # (H, W, 2) 88 | 89 | self.rays_o, self.rays_d = self.get_rays_single_image( 90 | self.H, self.W, self.intrinsics, self.c2w_mat 91 | ) 92 | 93 | if self.rgb is not None: 94 | self.rgb = self.rgb.reshape(-1, 3) 95 | 96 | if self.disp is not None: 97 | self.disp = self.disp.reshape(-1, 1) 98 | 99 | if self.motion_mask is not None: 100 | self.motion_mask = self.motion_mask.reshape(-1, 1) 101 | 102 | if self.static_mask is not None: 103 | self.static_mask = self.static_mask.reshape(-1, 1) 104 | 105 | if self.flows is not None: 106 | self.flows = self.flows.reshape(self.flows.shape[0], -1, 2) 107 | self.masks = self.masks.reshape(self.masks.shape[0], -1, 1) 108 | 109 | self.uv_grid = self.uv_grid.reshape(-1, 2) 110 | 111 | if 'src_rgbs' in data.keys(): 112 | self.src_rgbs = data['src_rgbs'] 113 | else: 114 | self.src_rgbs = None 115 | 116 | if 'src_cameras' in data.keys(): 117 | self.src_cameras = data['src_cameras'] 118 | else: 119 | self.src_cameras = None 120 | 121 | self.anchor_src_rgbs = ( 122 | data['anchor_src_rgbs'] if 'anchor_src_rgbs' in data.keys() else None 123 | ) 124 | self.anchor_src_cameras = ( 125 | data['anchor_src_cameras'] 126 | if 'anchor_src_cameras' in data.keys() 127 | else None 128 | ) 129 | 130 | self.static_src_rgbs = ( 131 | data['static_src_rgbs'] if 'static_src_rgbs' in data.keys() else None 132 | ) 133 | self.static_src_cameras = ( 134 | data['static_src_cameras'] 135 | if 'static_src_cameras' in data.keys() 136 | else None 137 | ) 138 | self.static_src_masks = ( 139 | data['static_src_masks'] if 'static_src_masks' in data.keys() else None 140 | ) 141 | 142 | 143 | def get_rays_single_image(self, H, W, intrinsics, c2w): 144 | """Return ray parameters (origin, direction) from a target view.""" 145 | u, v = np.meshgrid( 146 | np.arange(W)[:: self.render_stride], np.arange(H)[:: 
self.render_stride] 147 | ) 148 | u = u.reshape(-1).astype(dtype=np.float32) 149 | v = v.reshape(-1).astype(dtype=np.float32) 150 | pixels = np.stack((u, v, np.ones_like(u)), axis=0) # (3, H*W) 151 | pixels = torch.from_numpy(pixels) 152 | batched_pixels = pixels.unsqueeze(0).repeat(self.batch_size, 1, 1) 153 | 154 | rays_d = ( 155 | c2w[:, :3, :3] 156 | .bmm(torch.inverse(intrinsics[:, :3, :3])) 157 | .bmm(batched_pixels) 158 | ).transpose(1, 2) 159 | rays_d = rays_d.reshape(-1, 3) 160 | rays_o = ( 161 | c2w[:, :3, 3].unsqueeze(1).repeat(1, rays_d.shape[0], 1).reshape(-1, 3) 162 | ) # B x HW x 3 163 | return rays_o, rays_d 164 | 165 | def get_all(self): 166 | """Return all camera and ray information from a target view.""" 167 | ret = { 168 | 'ray_o': self.rays_o.to(self.device), 169 | 'ray_d': self.rays_d.to(self.device), 170 | 'depth_range': self.depth_range.to(self.device), 171 | 'camera': self.camera.to(self.device), 172 | 'render_camera': ( 173 | self.render_camera.to(self.device) 174 | if self.render_camera is not None 175 | else None 176 | ), 177 | 'anchor_camera': ( 178 | self.anchor_camera.to(self.device) 179 | if self.anchor_camera is not None 180 | else None 181 | ), 182 | 'rgb': self.rgb.to(self.device) if self.rgb is not None else None, 183 | 'src_rgbs': ( 184 | self.src_rgbs.to(self.device) if self.src_rgbs is not None else None 185 | ), 186 | 'src_cameras': ( 187 | self.src_cameras.to(self.device) 188 | if self.src_cameras is not None 189 | else None 190 | ), 191 | 'anchor_src_rgbs': ( 192 | self.anchor_src_rgbs.to(self.device) 193 | if self.anchor_src_rgbs is not None 194 | else None 195 | ), 196 | 'anchor_src_cameras': ( 197 | self.anchor_src_cameras.to(self.device) 198 | if self.anchor_src_cameras is not None 199 | else None 200 | ), 201 | 'static_src_rgbs': ( 202 | self.static_src_rgbs.to(self.device) 203 | if self.static_src_rgbs is not None 204 | else None 205 | ), 206 | 'static_src_cameras': ( 207 | self.static_src_cameras.to(self.device) 208 | if self.static_src_cameras is not None 209 | else None 210 | ), 211 | 'static_src_masks': ( 212 | self.static_src_masks.to(self.device) 213 | if self.static_src_masks is not None 214 | else None 215 | ), 216 | 'disp': ( 217 | self.disp.to(self.device).squeeze() 218 | if self.disp is not None 219 | else None 220 | ), 221 | 'motion_mask': ( 222 | self.motion_mask.to(self.device).squeeze() 223 | if self.motion_mask is not None 224 | else None 225 | ), 226 | 'static_mask': ( 227 | self.static_mask.to(self.device).squeeze() 228 | if self.static_mask is not None 229 | else None 230 | ), 231 | 'uv_grid': self.uv_grid.to(self.device), 232 | 'flows': self.flows.to(self.device) if self.flows is not None else None, 233 | 'masks': self.masks.to(self.device) if self.masks is not None else None, 234 | } 235 | return ret 236 | 237 | def sample_random_pixel(self, N_rand, sample_mode, center_ratio=0.8): 238 | """Sample pixel randomly from the target view.""" 239 | if sample_mode == 'center': 240 | border_H = int(self.H * (1 - center_ratio) / 2.0) 241 | border_W = int(self.W * (1 - center_ratio) / 2.0) 242 | 243 | # pixel coordinates 244 | u, v = np.meshgrid( 245 | np.arange(border_H, self.H - border_H), 246 | np.arange(border_W, self.W - border_W), 247 | ) 248 | u = u.reshape(-1) 249 | v = v.reshape(-1) 250 | 251 | select_inds = rng.choice(u.shape[0], size=(N_rand,), replace=False) 252 | select_inds = v[select_inds] + self.W * u[select_inds] 253 | 254 | elif sample_mode == 'uniform': 255 | # Random from one image 256 | select_inds = 
rng.choice(self.H*self.W, size=(N_rand,), replace=False) 257 | else: 258 | raise NotImplementedError 259 | 260 | return select_inds 261 | 262 | def random_sample(self, N_rand, sample_mode, center_ratio=0.8): 263 | """Randomly sample pixel and pixel data from the target view.""" 264 | select_inds = self.sample_random_pixel(N_rand, sample_mode, center_ratio) 265 | 266 | rays_o = self.rays_o[select_inds] 267 | rays_d = self.rays_d[select_inds] 268 | 269 | if self.rgb is not None: 270 | rgb = self.rgb[select_inds] 271 | disp = self.disp[select_inds].squeeze() 272 | motion_mask = self.motion_mask[select_inds].squeeze() 273 | static_mask = self.static_mask[select_inds].squeeze() 274 | 275 | flows = self.flows[:, select_inds, :] 276 | masks = self.masks[:, select_inds, :] 277 | 278 | uv_grid = self.uv_grid[select_inds] 279 | 280 | else: 281 | raise NotImplementedError 282 | 283 | ret = { 284 | 'ray_o': rays_o.to(self.device), 285 | 'ray_d': rays_d.to(self.device), 286 | 'camera': self.camera.to(self.device), 287 | 'anchor_camera': self.anchor_camera.to(self.device), 288 | 'depth_range': self.depth_range.to(self.device), 289 | 'rgb': rgb.to(self.device) if rgb is not None else None, 290 | 'disp': disp.to(self.device), 291 | 'motion_mask': motion_mask.to(self.device), 292 | 'static_mask': static_mask.to(self.device), 293 | 'uv_grid': uv_grid.to(self.device), 294 | 'flows': flows.to(self.device), 295 | 'masks': masks.to(self.device), 296 | 'src_rgbs': ( 297 | self.src_rgbs.to(self.device) if self.src_rgbs is not None else None 298 | ), 299 | 'src_cameras': ( 300 | self.src_cameras.to(self.device) 301 | if self.src_cameras is not None 302 | else None 303 | ), 304 | 'static_src_rgbs': ( 305 | self.static_src_rgbs.to(self.device) 306 | if self.static_src_rgbs is not None 307 | else None 308 | ), 309 | 'static_src_cameras': ( 310 | self.static_src_cameras.to(self.device) 311 | if self.static_src_cameras is not None 312 | else None 313 | ), 314 | 'static_src_masks': ( 315 | self.static_src_masks.to(self.device) 316 | if self.static_src_masks is not None 317 | else None 318 | ), 319 | 'anchor_src_rgbs': ( 320 | self.anchor_src_rgbs.to(self.device) 321 | if self.anchor_src_rgbs is not None 322 | else None 323 | ), 324 | 'anchor_src_cameras': ( 325 | self.anchor_src_cameras.to(self.device) 326 | if self.anchor_src_cameras is not None 327 | else None 328 | ), 329 | 'selected_inds': select_inds, 330 | } 331 | return ret 332 | -------------------------------------------------------------------------------- /render_monocular_bt.py: -------------------------------------------------------------------------------- 1 | """Script to render novel views from pretrained model.""" 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from torch.utils.data import DataLoader 6 | import imageio.v2 as imageio 7 | from config import config_parser 8 | from ibrnet.sample_ray import RaySamplerSingleImage 9 | from ibrnet.render_image import render_single_image_mono 10 | from ibrnet.model import DynibarMono 11 | from ibrnet.projection import Projector 12 | from ibrnet.data_loaders.data_utils import get_nearest_pose_ids 13 | from ibrnet.data_loaders.data_utils import get_interval_pose_ids 14 | from ibrnet.data_loaders.llff_data_utils import load_mono_data 15 | from ibrnet.data_loaders.llff_data_utils import batch_parse_llff_poses 16 | from ibrnet.data_loaders.llff_data_utils import batch_parse_vv_poses 17 | import time 18 | import os 19 | import numpy as np 20 | import cv2 21 | 22 | 23 | class 
DynamicVideoDataset(Dataset): 24 | """Class for defining monocular video data. 25 | 26 | Attributes: 27 | folder_path: root path 28 | num_source_views: number of source views to sample 29 | mask_src_view: using mask to mask moving objects 30 | render_idx: rendering frame index 31 | max_range: max sampling frame range 32 | render_rgb_files: rendering RGB file path 33 | render_intrinsics: rendering camera intrinsics 34 | render_poses: rendering camera poses 35 | render_depth_range: rendering depth bounds 36 | h: image height 37 | w: image width 38 | train_intrinsics: training camera intrinisc 39 | train_poses: training camera poses 40 | train_rgb_files: training RGB path 41 | num_frames: number of video frames 42 | src_vv_c2w_mats: virtual views camera matrix 43 | """ 44 | 45 | def __init__(self, args, scenes, **kwargs): 46 | self.folder_path = ( 47 | args.folder_path 48 | ) 49 | self.num_source_views = args.num_source_views 50 | self.mask_src_view = args.mask_src_view 51 | self.render_idx = args.render_idx 52 | self.max_range = args.max_range 53 | self.num_vv = args.num_vv 54 | print('num_source_views ', self.num_source_views) 55 | print('loading {} for rendering'.format(scenes)) 56 | assert len(scenes) == 1 57 | 58 | scene = scenes[0] 59 | # for i, scene in enumerate(scenes): 60 | scene_path = os.path.join(self.folder_path, scene, 'dense') 61 | _, poses, src_vv_poses, bds, render_poses, _, rgb_files, _ = ( 62 | load_mono_data( 63 | scene_path, 64 | height=args.training_height, 65 | render_idx=self.render_idx, 66 | load_imgs=False, 67 | ) 68 | ) 69 | near_depth = np.min(bds) 70 | 71 | if np.max(bds) < 10: 72 | far_depth = min(50, np.max(bds) + 15.0) 73 | else: 74 | far_depth = min(50, max(20, np.max(bds))) 75 | 76 | self.num_frames = len(rgb_files) 77 | 78 | intrinsics, c2w_mats = batch_parse_llff_poses(poses) 79 | h, w = poses[0][:2, -1] 80 | render_intrinsics, render_c2w_mats = batch_parse_llff_poses(render_poses) 81 | self.src_vv_c2w_mats = batch_parse_vv_poses(src_vv_poses) 82 | 83 | self.train_intrinsics = intrinsics 84 | self.train_poses = c2w_mats 85 | self.train_rgb_files = rgb_files 86 | 87 | self.render_intrinsics = render_intrinsics 88 | self.render_poses = render_c2w_mats 89 | self.render_depth_range = [[near_depth, far_depth]] * self.num_frames 90 | self.h = [int(h)] * self.num_frames 91 | self.w = [int(w)] * self.num_frames 92 | 93 | def __len__(self): 94 | return len(self.render_poses) 95 | 96 | def __getitem__(self, idx): 97 | render_pose = self.render_poses[idx] 98 | intrinsics = self.render_intrinsics[idx] 99 | depth_range = self.render_depth_range[idx] 100 | 101 | train_rgb_files = self.train_rgb_files 102 | train_poses = self.train_poses 103 | train_intrinsics = self.train_intrinsics 104 | 105 | rgb_file = train_rgb_files[idx] 106 | rgb = imageio.imread(rgb_file).astype(np.float32) / 255.0 107 | 108 | h, w = self.h[idx], self.w[idx] 109 | camera = np.concatenate( 110 | ([h, w], intrinsics.flatten(), render_pose.flatten()) 111 | ).astype(np.float32) 112 | 113 | nearest_pose_ids = np.sort( 114 | [self.render_idx + offset for offset in [1, 2, 3, 0, -1, -2, -3]] 115 | ) 116 | sp_pose_ids = get_nearest_pose_ids( 117 | render_pose, train_poses, tar_id=-1, angular_dist_method='dist' 118 | ) 119 | 120 | static_pose_ids = [] 121 | frame_interval = args.max_range // self.num_source_views 122 | interval_pose_ids = get_interval_pose_ids( 123 | render_pose, 124 | train_poses, 125 | tar_id=-1, 126 | angular_dist_method='dist', 127 | interval=frame_interval, 128 | ) 129 | 130 | for 
sp_pose_id in interval_pose_ids: 131 | if len(static_pose_ids) >= (self.num_source_views * 2 + 1): 132 | break 133 | 134 | if np.abs(sp_pose_id - self.render_idx) > ( 135 | self.max_range + self.num_source_views * 0.5 136 | ): 137 | continue 138 | 139 | static_pose_ids.append(sp_pose_id) 140 | 141 | static_pose_set = set(static_pose_ids) 142 | 143 | # if there is no sufficient src imgs, naively choose the closest images 144 | for sp_pose_id in sp_pose_ids[::5]: 145 | if len(static_pose_ids) >= (self.num_source_views * 2 + 1): 146 | break 147 | 148 | if sp_pose_id in static_pose_set: 149 | continue 150 | 151 | static_pose_ids.append(sp_pose_id) 152 | 153 | static_pose_ids = np.sort(static_pose_ids) 154 | 155 | assert len(static_pose_ids) == (self.num_source_views * 2 + 1) 156 | 157 | src_rgbs = [] 158 | src_cameras = [] 159 | for src_idx in nearest_pose_ids: 160 | src_rgb = ( 161 | imageio.imread(train_rgb_files[src_idx]).astype(np.float32) / 255.0 162 | ) 163 | train_pose = train_poses[src_idx] 164 | train_intrinsics_ = train_intrinsics[src_idx] 165 | src_rgbs.append(src_rgb) 166 | img_size = src_rgb.shape[:2] 167 | src_camera = np.concatenate( 168 | (list(img_size), train_intrinsics_.flatten(), train_pose.flatten()) 169 | ).astype(np.float32) 170 | 171 | src_cameras.append(src_camera) 172 | 173 | # load src virtual views 174 | vv_pose_ids = get_nearest_pose_ids( 175 | render_pose, 176 | self.src_vv_c2w_mats[self.render_idx], 177 | tar_id=-1, 178 | angular_dist_method='dist', 179 | ) 180 | 181 | # load virtual source views 182 | num_vv = self.num_vv 183 | for virtual_idx in vv_pose_ids[:num_vv]: 184 | src_vv_path = os.path.join( 185 | '/'.join( 186 | rgb_file.replace('images', 'source_virtual_views').split('/')[:-1] 187 | ), 188 | '%05d' % self.render_idx, 189 | '%02d.png' % virtual_idx, 190 | ) 191 | src_rgb = imageio.imread(src_vv_path).astype(np.float32) / 255.0 192 | src_rgbs.append(src_rgb) 193 | img_size = src_rgb.shape[:2] 194 | 195 | src_camera = np.concatenate(( 196 | list(img_size), 197 | intrinsics.flatten(), 198 | self.src_vv_c2w_mats[self.render_idx, virtual_idx].flatten(), 199 | )).astype(np.float32) 200 | 201 | src_cameras.append(src_camera) 202 | 203 | src_rgbs = np.stack(src_rgbs, axis=0) 204 | src_cameras = np.stack(src_cameras, axis=0) 205 | 206 | static_src_rgbs = [] 207 | static_src_cameras = [] 208 | # load src rgb for static view 209 | for st_near_id in static_pose_ids: 210 | src_rgb = ( 211 | imageio.imread(train_rgb_files[st_near_id]).astype(np.float32) / 255.0 212 | ) 213 | train_pose = train_poses[st_near_id] 214 | train_intrinsics_ = train_intrinsics[st_near_id] 215 | 216 | if self.mask_src_view: 217 | st_mask_path = os.path.join( 218 | '/'.join(rgb_file.split('/')[:-2]), 219 | 'dynamic_masks', 220 | '%d.png' % st_near_id, 221 | ) 222 | st_mask = imageio.imread(st_mask_path).astype(np.float32) / 255.0 223 | st_mask = cv2.resize( 224 | st_mask, 225 | (src_rgb.shape[1], src_rgb.shape[0]), 226 | interpolation=cv2.INTER_NEAREST, 227 | ) 228 | 229 | if len(st_mask.shape) == 2: 230 | st_mask = st_mask[..., None] 231 | 232 | src_rgb = src_rgb * st_mask 233 | 234 | static_src_rgbs.append(src_rgb) 235 | img_size = src_rgb.shape[:2] 236 | src_camera = np.concatenate( 237 | (list(img_size), train_intrinsics_.flatten(), train_pose.flatten()) 238 | ).astype(np.float32) 239 | 240 | static_src_cameras.append(src_camera) 241 | 242 | static_src_rgbs = np.stack(static_src_rgbs, axis=0) 243 | static_src_cameras = np.stack(static_src_cameras, axis=0) 244 | 245 | depth_range = 
torch.tensor([depth_range[0] * 0.9, depth_range[1] * 1.5]) 246 | 247 | return { 248 | 'camera': torch.from_numpy(camera), 249 | 'rgb_path': '', 250 | 'rgb': torch.from_numpy(rgb), 251 | 'src_rgbs': torch.from_numpy(src_rgbs[..., :3]).float(), 252 | 'src_cameras': torch.from_numpy(src_cameras).float(), 253 | 'static_src_rgbs': torch.from_numpy(static_src_rgbs[..., :3]).float(), 254 | 'static_src_cameras': torch.from_numpy(static_src_cameras).float(), 255 | 'depth_range': depth_range, 256 | 'ref_time': float(self.render_idx / float(self.num_frames)), 257 | 'id': self.render_idx, 258 | 'nearest_pose_ids': nearest_pose_ids 259 | } 260 | 261 | if __name__ == '__main__': 262 | parser = config_parser() 263 | args = parser.parse_args() 264 | args.distributed = False 265 | 266 | test_dataset = DynamicVideoDataset(args, scenes=args.eval_scenes) 267 | args.num_frames = test_dataset.num_frames 268 | 269 | # Create ibrnet model 270 | model = DynibarMono(args) 271 | eval_dataset_name = args.eval_dataset 272 | extra_out_dir = '{}/{}/{}'.format( 273 | eval_dataset_name, args.expname, str(args.render_idx) 274 | ) 275 | print('saving results to {}...'.format(extra_out_dir)) 276 | os.makedirs(extra_out_dir, exist_ok=True) 277 | 278 | projector = Projector(device='cuda:0') 279 | 280 | assert len(args.eval_scenes) == 1, 'only accept single scene' 281 | scene_name = args.eval_scenes[0] 282 | out_scene_dir = os.path.join( 283 | extra_out_dir, '{}_{:06d}'.format(scene_name, model.start_step), 'videos' 284 | ) 285 | print('saving results to {}'.format(out_scene_dir)) 286 | 287 | os.makedirs(out_scene_dir, exist_ok=True) 288 | os.makedirs(os.path.join(out_scene_dir, 'rgb_out'), exist_ok=True) 289 | 290 | save_prefix = scene_name 291 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False) 292 | total_num = len(test_loader) 293 | out_frames = [] 294 | full_frames = [] 295 | crop_ratio = 0.03 296 | 297 | for i, data in enumerate(test_loader): 298 | idx = int(data['id'].item()) 299 | start = time.time() 300 | ref_time_embedding = data['ref_time'].cuda() 301 | ref_frame_idx = int(data['id'].item()) 302 | ref_time_offset = [ 303 | int(near_idx - ref_frame_idx) 304 | for near_idx in data['nearest_pose_ids'].squeeze().tolist() 305 | ] 306 | 307 | model.switch_to_eval() 308 | with torch.no_grad(): 309 | ray_sampler = RaySamplerSingleImage(data, device='cuda:0') 310 | ray_batch = ray_sampler.get_all() 311 | 312 | cb_featmaps_1, cb_featmaps_2 = model.feature_net( 313 | ray_batch['src_rgbs'].squeeze(0).permute(0, 3, 1, 2) 314 | ) 315 | ref_featmaps = cb_featmaps_1 # [0:NUM_DYNAMIC_SRC_VIEWS] 316 | 317 | static_src_rgbs = ( 318 | ray_batch['static_src_rgbs'].squeeze(0).permute(0, 3, 1, 2) 319 | ) 320 | static_featmaps, _ = model.feature_net_st(static_src_rgbs) 321 | 322 | ret = render_single_image_mono( 323 | frame_idx=(ref_frame_idx, None), 324 | time_embedding=(ref_time_embedding, None), 325 | time_offset=(ref_time_offset, None), 326 | ray_sampler=ray_sampler, 327 | ray_batch=ray_batch, 328 | model=model, 329 | projector=projector, 330 | chunk_size=args.chunk_size, 331 | det=True, 332 | N_samples=args.N_samples, 333 | args=args, 334 | inv_uniform=args.inv_uniform, 335 | N_importance=args.N_importance, 336 | white_bkgd=args.white_bkgd, 337 | featmaps=(ref_featmaps, None, static_featmaps), 338 | is_train=False, 339 | num_vv=args.num_vv 340 | ) 341 | 342 | coarse_pred_rgb = ret['outputs_coarse_ref']['rgb'].detach().cpu() 343 | coarse_pred_rgb_st = ret['outputs_coarse_ref']['rgb_static'].detach().cpu() 344 | 
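# The coarse renderer returns the full composite ('rgb') together with its
# static ('rgb_static') and dynamic ('rgb_dy') components; only the composite
# is cropped and written to disk below, while the two component buffers stay
# unused. A commented-out sketch for also saving them (illustrative only: it
# assumes the crop_h / crop_w / h / w values computed further down in this
# loop, and the output file names are made up):
#
#   for name, buf in [('st', coarse_pred_rgb_st), ('dy', coarse_pred_rgb_rgb)]:
#     layer = (255 * np.clip(buf.numpy(), a_min=0, a_max=1.0)).astype(np.uint8)
#     layer = layer[crop_h:h - crop_h, crop_w:w - crop_w, ...]
#     imageio.imwrite(
#         os.path.join(out_scene_dir, 'rgb_out', '{}_{}.png'.format(i, name)),
#         layer)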
coarse_pred_rgb_rgb = ret['outputs_coarse_ref']['rgb_dy'].detach().cpu() 345 | 346 | coarse_pred_rgb = ( 347 | 255 * np.clip(coarse_pred_rgb.numpy(), a_min=0, a_max=1.0) 348 | ).astype(np.uint8) 349 | 350 | h, w = coarse_pred_rgb.shape[:2] 351 | crop_h = int(h * crop_ratio) 352 | crop_w = int(w * crop_ratio) 353 | 354 | coarse_pred_rgb = coarse_pred_rgb[crop_h:h-crop_h, crop_w:w-crop_w, ...] 355 | 356 | gt_rgb = data['rgb'][0, crop_h:h-crop_h, crop_w:w-crop_w, ...] 357 | gt_rgb = (255 * np.clip(gt_rgb.numpy(), a_min=0, a_max=1.)).astype(np.uint8) 358 | 359 | full_rgb = np.concatenate([gt_rgb, coarse_pred_rgb], axis=1) 360 | 361 | full_frames.append(coarse_pred_rgb) 362 | 363 | imageio.imwrite(os.path.join(out_scene_dir, 'rgb_out', '{}.png'.format(i)), 364 | coarse_pred_rgb) 365 | 366 | print('frame {} completed, {}'.format(i, time.time() - start)) 367 | -------------------------------------------------------------------------------- /ibrnet/render_image.py: -------------------------------------------------------------------------------- 1 | """Functions for rendering a target view.""" 2 | 3 | from collections import OrderedDict 4 | from ibrnet.render_ray import render_rays_mono 5 | from ibrnet.render_ray import render_rays_mv 6 | import torch 7 | 8 | 9 | def render_single_image_nvi( 10 | frame_idx, 11 | time_embedding, 12 | time_offset, 13 | ray_sampler, 14 | ray_batch, 15 | model, 16 | projector, 17 | chunk_size, 18 | N_samples, 19 | args, 20 | inv_uniform=False, 21 | N_importance=0, 22 | det=False, 23 | white_bkgd=False, 24 | render_stride=1, 25 | coarse_featmaps=None, 26 | fine_featmaps=None, 27 | is_train=True, 28 | ): 29 | """Render a target view for Nvidia dataset. 30 | 31 | Args: 32 | frame_idx: video frame index 33 | time_embedding: input time embedding 34 | time_offset: offset w.r.t reference time 35 | ray_sampler: target view ray sampler 36 | ray_batch: batch of ray information 37 | model: dynibar model 38 | projector: perspective projection module 39 | chunk_size: processing chunk size 40 | N_samples: number of coarse samples along the ray 41 | args: additional input arguments 42 | inv_uniform: use disparity-based sampling or not 43 | N_importance: number of fine samples along the ray 44 | det: deterministic sampling 45 | white_bkgd: whether background is present 46 | render_stride: pixel stride when rendering images 47 | coarse_featmaps: coarse-stage 2D feature map 48 | fine_featmaps: fine-stage 2D feature map 49 | is_train: is training or not 50 | 51 | Returns: 52 | outputs_fine_anchor: rendered fine images at target view from contents at 53 | nearby time 54 | outputs_fine_ref: rendered fine images at target view from contents at 55 | target time 56 | outputs_coarse_ref: rendered coarse images at target view from contents at 57 | target time 58 | """ 59 | 60 | all_ret = OrderedDict([ 61 | ('outputs_fine_anchor', OrderedDict()), 62 | ('outputs_fine_ref', OrderedDict()), 63 | ('outputs_coarse_ref', OrderedDict()), 64 | ]) 65 | 66 | N_rays = ray_batch['ray_o'].shape[0] 67 | 68 | for i in range(0, N_rays, chunk_size): 69 | chunk = OrderedDict() 70 | for k in ray_batch: 71 | if ray_batch[k] is None: 72 | chunk[k] = None 73 | elif k in [ 74 | 'camera', 75 | 'depth_range', 76 | 'src_rgbs', 77 | 'src_cameras', 78 | 'anchor_src_rgbs', 79 | 'anchor_src_cameras', 80 | 'static_src_rgbs', 81 | 'static_src_cameras', 82 | ]: 83 | chunk[k] = ray_batch[k] 84 | elif len(ray_batch[k].shape) == 3: # flow and mask 85 | chunk[k] = ray_batch[k][:, i : i + chunk_size, ...] 
86 | elif ray_batch[k] is not None: 87 | chunk[k] = ray_batch[k][i : i + chunk_size] 88 | else: 89 | chunk[k] = None 90 | 91 | ret = render_rays_mv( 92 | frame_idx=frame_idx, 93 | time_embedding=time_embedding, 94 | time_offset=time_offset, 95 | ray_batch=chunk, 96 | model=model, 97 | coarse_featmaps=coarse_featmaps, 98 | fine_featmaps=fine_featmaps, 99 | projector=projector, 100 | N_samples=N_samples, 101 | args=args, 102 | inv_uniform=inv_uniform, 103 | N_importance=N_importance, 104 | raw_noise_std=0.0, 105 | det=det, 106 | white_bkgd=white_bkgd, 107 | is_train=is_train, 108 | ) 109 | 110 | # handle both coarse and fine outputs 111 | # cache chunk results on cpu 112 | if i == 0: 113 | for k in ret['outputs_coarse_ref']: 114 | all_ret['outputs_coarse_ref'][k] = [] 115 | 116 | for k in ret['outputs_fine_ref']: 117 | all_ret['outputs_fine_ref'][k] = [] 118 | 119 | if is_train: 120 | for k in ret['outputs_fine_anchor']: 121 | all_ret['outputs_fine_anchor'][k] = [] 122 | 123 | for k in ret['outputs_coarse_ref']: 124 | all_ret['outputs_coarse_ref'][k].append( 125 | ret['outputs_coarse_ref'][k].cpu() 126 | ) 127 | 128 | for k in ret['outputs_fine_ref']: 129 | all_ret['outputs_fine_ref'][k].append(ret['outputs_fine_ref'][k].cpu()) 130 | 131 | if is_train: 132 | for k in ret['outputs_fine_anchor']: 133 | all_ret['outputs_fine_anchor'][k].append( 134 | ret['outputs_fine_anchor'][k].cpu() 135 | ) 136 | 137 | rgb_strided = torch.ones(ray_sampler.H, ray_sampler.W, 3)[ 138 | ::render_stride, ::render_stride, : 139 | ] 140 | # merge chunk results and reshape 141 | for k in all_ret['outputs_coarse_ref']: 142 | if k == 'random_sigma': 143 | continue 144 | 145 | if len(all_ret['outputs_coarse_ref'][k][0].shape) == 4: 146 | continue 147 | 148 | if len(all_ret['outputs_coarse_ref'][k][0].shape) == 3: 149 | tmp = torch.cat(all_ret['outputs_coarse_ref'][k], dim=1).reshape(( 150 | all_ret['outputs_coarse_ref'][k][0].shape[0], 151 | rgb_strided.shape[0], 152 | rgb_strided.shape[1], 153 | -1, 154 | )) 155 | else: 156 | tmp = torch.cat(all_ret['outputs_coarse_ref'][k], dim=0).reshape( 157 | (rgb_strided.shape[0], rgb_strided.shape[1], -1) 158 | ) 159 | all_ret['outputs_coarse_ref'][k] = tmp.squeeze() 160 | 161 | all_ret['outputs_coarse_ref']['rgb'][ 162 | all_ret['outputs_coarse_ref']['mask'] == 0 163 | ] = 0.0 164 | 165 | # merge chunk results and reshape 166 | for k in all_ret['outputs_fine_ref']: 167 | if k == 'random_sigma': 168 | continue 169 | 170 | if len(all_ret['outputs_fine_ref'][k][0].shape) == 4: 171 | continue 172 | 173 | if len(all_ret['outputs_fine_ref'][k][0].shape) == 3: 174 | tmp = torch.cat(all_ret['outputs_fine_ref'][k], dim=1).reshape(( 175 | all_ret['outputs_fine_ref'][k][0].shape[0], 176 | rgb_strided.shape[0], 177 | rgb_strided.shape[1], 178 | -1, 179 | )) 180 | else: 181 | tmp = torch.cat(all_ret['outputs_fine_ref'][k], dim=0).reshape( 182 | (rgb_strided.shape[0], rgb_strided.shape[1], -1) 183 | ) 184 | all_ret['outputs_fine_ref'][k] = tmp.squeeze() 185 | 186 | all_ret['outputs_fine_ref']['rgb'][ 187 | all_ret['outputs_fine_ref']['mask'] == 0 188 | ] = 0.0 189 | 190 | # merge chunk results and reshape 191 | if is_train: 192 | for k in all_ret['outputs_fine_anchor']: 193 | if k == 'random_sigma': 194 | continue 195 | 196 | if len(all_ret['outputs_fine_anchor'][k][0].shape) == 4: 197 | continue 198 | 199 | if len(all_ret['outputs_fine_anchor'][k][0].shape) == 3: 200 | tmp = torch.cat(all_ret['outputs_fine_anchor'][k], dim=1).reshape(( 201 | all_ret['outputs_fine_anchor'][k][0].shape[0], 
202 | rgb_strided.shape[0], 203 | rgb_strided.shape[1], 204 | -1, 205 | )) 206 | else: 207 | tmp = torch.cat(all_ret['outputs_fine_anchor'][k], dim=0).reshape( 208 | (rgb_strided.shape[0], rgb_strided.shape[1], -1) 209 | ) 210 | all_ret['outputs_fine_anchor'][k] = tmp.squeeze() 211 | 212 | all_ret['outputs_fine_anchor']['rgb'][ 213 | all_ret['outputs_fine_anchor']['mask'] == 0 214 | ] = 0.0 215 | 216 | all_ret['outputs_fine'] = None 217 | return all_ret 218 | 219 | 220 | def render_single_image_mono( 221 | frame_idx, 222 | time_embedding, 223 | time_offset, 224 | ray_sampler, 225 | ray_batch, 226 | model, 227 | projector, 228 | chunk_size, 229 | N_samples, 230 | args, 231 | inv_uniform=False, 232 | N_importance=0, 233 | det=False, 234 | white_bkgd=False, 235 | render_stride=1, 236 | featmaps=None, 237 | is_train=True, 238 | num_vv=2, 239 | ): 240 | """Render a target view for Monocular video. 241 | 242 | Args: 243 | frame_idx: video frame index 244 | time_embedding: input time embedding 245 | time_offset: offset w.r.t reference time 246 | ray_sampler: target view ray sampler 247 | ray_batch: batch of ray information 248 | model: dynibar model 249 | projector: perspective projection module 250 | chunk_size: processing chunk size 251 | N_samples: number of coarse samples along the ray 252 | args: additional input arguments 253 | inv_uniform: use disparity-based sampling or not 254 | N_importance: number of fine samples along the ray 255 | det: deterministic sampling 256 | white_bkgd: whether background is present 257 | render_stride: pixel stride when rendering images 258 | featmaps: coarse-stage 2D feature map 259 | is_train: is training or not 260 | num_vv: number of virtual source views used 261 | 262 | Returns: 263 | outputs_coarse_ref: rendered images at target view from combined contents at 264 | target time, coarse model 265 | outputs_coarse_st: rendered images at target view from static 266 | contents at target time, coarse model 267 | outputs_coarse_anchor: cross-rendered images at target view 268 | from combined contents at nearby time, coarse model 269 | 270 | """ 271 | 272 | all_ret = OrderedDict([ 273 | ('outputs_coarse_ref', OrderedDict()), 274 | ('outputs_coarse_st', OrderedDict()), 275 | ('outputs_coarse_anchor', OrderedDict()), 276 | ]) 277 | 278 | N_rays = ray_batch['ray_o'].shape[0] 279 | 280 | for i in range(0, N_rays, chunk_size): 281 | chunk = OrderedDict() 282 | for k in ray_batch: 283 | if ray_batch[k] is None: 284 | chunk[k] = None 285 | elif k in [ 286 | 'camera', 287 | 'anchor_camera', 288 | 'depth_range', 289 | 'src_rgbs', 290 | 'src_cameras', 291 | 'anchor_src_rgbs', 292 | 'anchor_src_cameras', 293 | 'static_src_rgbs', 294 | 'static_src_cameras', 295 | ]: 296 | chunk[k] = ray_batch[k] 297 | elif len(ray_batch[k].shape) == 3: # flow and mask 298 | chunk[k] = ray_batch[k][:, i : i + chunk_size, ...] 
299 | elif ray_batch[k] is not None: 300 | chunk[k] = ray_batch[k][i : i + chunk_size] 301 | else: 302 | chunk[k] = None 303 | 304 | ret = render_rays_mono( 305 | frame_idx=frame_idx, 306 | time_embedding=time_embedding, 307 | time_offset=time_offset, 308 | ray_batch=chunk, 309 | model=model, 310 | featmaps=featmaps, 311 | projector=projector, 312 | N_samples=N_samples, 313 | args=args, 314 | inv_uniform=inv_uniform, 315 | N_importance=N_importance, 316 | raw_noise_std=0.0, 317 | det=det, 318 | white_bkgd=white_bkgd, 319 | is_train=is_train, 320 | num_vv=num_vv, 321 | ) 322 | 323 | # handle both coarse and fine outputs 324 | # cache chunk results on cpu 325 | if i == 0: 326 | for k in ret['outputs_coarse_ref']: 327 | all_ret['outputs_coarse_ref'][k] = [] 328 | 329 | for k in ret['outputs_coarse_st']: 330 | all_ret['outputs_coarse_st'][k] = [] 331 | 332 | if is_train: 333 | for k in ret['outputs_coarse_anchor']: 334 | all_ret['outputs_coarse_anchor'][k] = [] 335 | 336 | if ret['outputs_fine'] is None: 337 | all_ret['outputs_fine'] = None 338 | else: 339 | for k in ret['outputs_fine']: 340 | all_ret['outputs_fine'][k] = [] 341 | 342 | for k in ret['outputs_coarse_ref']: 343 | all_ret['outputs_coarse_ref'][k].append( 344 | ret['outputs_coarse_ref'][k].cpu() 345 | ) 346 | 347 | for k in ret['outputs_coarse_st']: 348 | all_ret['outputs_coarse_st'][k].append(ret['outputs_coarse_st'][k].cpu()) 349 | 350 | if is_train: 351 | for k in ret['outputs_coarse_anchor']: 352 | all_ret['outputs_coarse_anchor'][k].append( 353 | ret['outputs_coarse_anchor'][k].cpu() 354 | ) 355 | 356 | if ret['outputs_fine'] is not None: 357 | for k in ret['outputs_fine']: 358 | all_ret['outputs_fine'][k].append(ret['outputs_fine'][k].cpu()) 359 | 360 | rgb_strided = torch.ones(ray_sampler.H, ray_sampler.W, 3)[ 361 | ::render_stride, ::render_stride, : 362 | ] 363 | # merge chunk results and reshape 364 | for k in all_ret['outputs_coarse_ref']: 365 | if k == 'random_sigma': 366 | continue 367 | 368 | if len(all_ret['outputs_coarse_ref'][k][0].shape) == 4: 369 | continue 370 | 371 | if len(all_ret['outputs_coarse_ref'][k][0].shape) == 3: 372 | tmp = torch.cat(all_ret['outputs_coarse_ref'][k], dim=1).reshape(( 373 | all_ret['outputs_coarse_ref'][k][0].shape[0], 374 | rgb_strided.shape[0], 375 | rgb_strided.shape[1], 376 | -1, 377 | )) 378 | else: 379 | tmp = torch.cat(all_ret['outputs_coarse_ref'][k], dim=0).reshape( 380 | (rgb_strided.shape[0], rgb_strided.shape[1], -1) 381 | ) 382 | all_ret['outputs_coarse_ref'][k] = tmp.squeeze() 383 | 384 | all_ret['outputs_coarse_ref']['rgb'][ 385 | all_ret['outputs_coarse_ref']['mask'] == 0 386 | ] = 0.0 387 | 388 | # merge chunk results and reshape 389 | for k in all_ret['outputs_coarse_st']: 390 | if k == 'random_sigma': 391 | continue 392 | 393 | if len(all_ret['outputs_coarse_st'][k][0].shape) == 4: 394 | continue 395 | 396 | if len(all_ret['outputs_coarse_st'][k][0].shape) == 3: 397 | tmp = torch.cat(all_ret['outputs_coarse_st'][k], dim=1).reshape(( 398 | all_ret['outputs_coarse_st'][k][0].shape[0], 399 | rgb_strided.shape[0], 400 | rgb_strided.shape[1], 401 | -1, 402 | )) 403 | else: 404 | tmp = torch.cat(all_ret['outputs_coarse_st'][k], dim=0).reshape( 405 | (rgb_strided.shape[0], rgb_strided.shape[1], -1) 406 | ) 407 | all_ret['outputs_coarse_st'][k] = tmp.squeeze() 408 | 409 | all_ret['outputs_coarse_st']['rgb'][ 410 | all_ret['outputs_coarse_st']['mask'] == 0 411 | ] = 0.0 412 | 413 | # merge chunk results and reshape 414 | if is_train: 415 | for k in 
all_ret['outputs_coarse_anchor']: 416 | if k == 'random_sigma': 417 | continue 418 | 419 | if len(all_ret['outputs_coarse_anchor'][k][0].shape) == 4: 420 | continue 421 | 422 | if len(all_ret['outputs_coarse_anchor'][k][0].shape) == 3: 423 | tmp = torch.cat(all_ret['outputs_coarse_anchor'][k], dim=1).reshape(( 424 | all_ret['outputs_coarse_anchor'][k][0].shape[0], 425 | rgb_strided.shape[0], 426 | rgb_strided.shape[1], 427 | -1, 428 | )) 429 | else: 430 | tmp = torch.cat(all_ret['outputs_coarse_anchor'][k], dim=0).reshape( 431 | (rgb_strided.shape[0], rgb_strided.shape[1], -1) 432 | ) 433 | all_ret['outputs_coarse_anchor'][k] = tmp.squeeze() 434 | 435 | all_ret['outputs_coarse_anchor']['rgb'][ 436 | all_ret['outputs_coarse_anchor']['mask'] == 0 437 | ] = 0.0 438 | 439 | return all_ret 440 | -------------------------------------------------------------------------------- /ibrnet/data_loaders/llff_data_utils.py: -------------------------------------------------------------------------------- 1 | """Forward-Facing data loading code. 2 | 3 | Modified from IBRNet 4 | github.com/googleinterns/IBRNet/blob/master/ibrnet/data_loaders/llff_data_utils.py 5 | """ 6 | 7 | import os 8 | 9 | import cv2 10 | import imageio 11 | import numpy as np 12 | 13 | 14 | def parse_llff_pose(pose): 15 | """Convert LLFF format pose to 4x4 intrinsics and extrinsics matrices.""" 16 | 17 | h, w, f = pose[:3, -1] 18 | c2w = pose[:3, :4] 19 | c2w_4x4 = np.eye(4) 20 | c2w_4x4[:3] = c2w 21 | c2w_4x4[:, 1:3] *= -1 22 | intrinsics = np.array( 23 | [[f, 0, w / 2.0, 0], [0, f, h / 2.0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] 24 | ) 25 | return intrinsics, c2w_4x4 26 | 27 | 28 | def batch_parse_llff_poses(poses): 29 | """Parse LLFF data format to opencv/colmap format.""" 30 | all_intrinsics = [] 31 | all_c2w_mats = [] 32 | for pose in poses: 33 | intrinsics, c2w_mat = parse_llff_pose(pose) 34 | all_intrinsics.append(intrinsics) 35 | all_c2w_mats.append(c2w_mat) 36 | all_intrinsics = np.stack(all_intrinsics) 37 | all_c2w_mats = np.stack(all_c2w_mats) 38 | return all_intrinsics, all_c2w_mats 39 | 40 | 41 | def batch_parse_vv_poses(poses): 42 | """Parse virtual view poses used for monocular video training.""" 43 | all_c2w_mats = [] 44 | for pose in poses: 45 | t_c2w_mats = [] 46 | for p in pose: 47 | intrinsics, c2w_mat = parse_llff_pose(p) 48 | t_c2w_mats.append(c2w_mat) 49 | t_c2w_mats = np.stack(t_c2w_mats) 50 | all_c2w_mats.append(t_c2w_mats) 51 | 52 | all_c2w_mats = np.stack(all_c2w_mats) 53 | 54 | return all_c2w_mats 55 | 56 | 57 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): 58 | """Function for loading LLFF data.""" 59 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds_cvd.npy')) 60 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0]) 61 | bds = poses_arr[:, -2:].transpose([1, 0]) 62 | 63 | img0 = [ 64 | os.path.join(basedir, 'images', f) 65 | for f in sorted(os.listdir(os.path.join(basedir, 'images'))) 66 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png') 67 | ][0] 68 | sh = imageio.imread(img0).shape 69 | 70 | sfx = '' 71 | 72 | if factor is not None and factor != 1: 73 | sfx = '_{}'.format(factor) 74 | elif height is not None: 75 | factor = sh[0] / float(height) 76 | width = int(round(sh[1] / factor)) 77 | sfx = '_{}x{}'.format(width, height) 78 | elif width is not None: 79 | factor = sh[1] / float(width) 80 | height = int(round(sh[0] / factor)) 81 | sfx = '_{}x{}'.format(width, height) 82 | else: 83 | factor = 1 84 | 85 | imgdir = os.path.join(basedir, 
'images' + sfx) 86 | print('imgdir ', imgdir, ' factor ', factor) 87 | 88 | if not os.path.exists(imgdir): 89 | print(imgdir, 'does not exist, returning') 90 | return 91 | 92 | imgfiles = [ 93 | os.path.join(imgdir, f) 94 | for f in sorted(os.listdir(imgdir)) 95 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png') 96 | ] 97 | 98 | if poses.shape[-1] != len(imgfiles): 99 | print( 100 | '{}: Mismatch between imgs {} and poses {} !!!!'.format( 101 | basedir, len(imgfiles), poses.shape[-1] 102 | ) 103 | ) 104 | raise NotImplementedError 105 | 106 | sh = imageio.imread(imgfiles[0]).shape 107 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) 108 | poses[2, 4, :] = poses[2, 4, :] # * 1. / factor 109 | 110 | def imread(f): 111 | if f.endswith('png'): 112 | return imageio.imread(f, ignoregamma=True) 113 | else: 114 | return imageio.imread(f) 115 | 116 | if not load_imgs: 117 | imgs = None 118 | else: 119 | imgs = [imread(f)[..., :3] / 255.0 for f in imgfiles] 120 | imgs = np.stack(imgs, -1) 121 | print('Loaded image data', imgs.shape, poses[:, -1, 0]) 122 | 123 | return poses, bds, imgs, imgfiles 124 | 125 | 126 | def normalize(x): 127 | return x / np.linalg.norm(x) 128 | 129 | 130 | def viewmatrix(z, up, pos): 131 | vec2 = normalize(z) 132 | vec1_avg = up 133 | vec0 = normalize(np.cross(vec1_avg, vec2)) 134 | vec1 = normalize(np.cross(vec2, vec0)) 135 | m = np.stack([vec0, vec1, vec2, pos], 1) 136 | return m 137 | 138 | 139 | def ptstocam(pts, c2w): 140 | tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0] 141 | return tt 142 | 143 | 144 | def poses_avg(poses): 145 | hwf = poses[0, :3, -1:] 146 | 147 | center = poses[:, :3, 3].mean(0) 148 | vec2 = normalize(poses[:, :3, 2].sum(0)) 149 | up = poses[:, :3, 1].sum(0) 150 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 151 | 152 | return c2w 153 | 154 | 155 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): 156 | """Render a spiral path.""" 157 | 158 | render_poses = [] 159 | rads = np.array(list(rads) + [1.0]) 160 | hwf = c2w[:, 4:5] 161 | 162 | for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]: 163 | c = np.dot( 164 | c2w[:3, :4], 165 | np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]) 166 | * rads, 167 | ) 168 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0]))) 169 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 170 | return render_poses 171 | 172 | 173 | def recenter_poses(poses): 174 | """Recenter camera poses around their centroid.""" 175 | poses_ = poses + 0 176 | bottom = np.reshape([0, 0, 0, 1.0], [1, 4]) 177 | c2w = poses_avg(poses) 178 | c2w = np.concatenate([c2w[:3, :4], bottom], -2) 179 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1]) 180 | poses = np.concatenate([poses[:, :3, :4], bottom], -2) 181 | 182 | poses = np.linalg.inv(c2w) @ poses 183 | poses_[:, :3, :4] = poses[:, :3, :4] 184 | poses = poses_ 185 | return poses 186 | 187 | 188 | def recenter_poses_mono(poses, src_vv_poses): 189 | """Recenter virtual view camera poses around their centroid.""" 190 | hwf = poses[:, :, 4:5] 191 | poses_ = poses + 0 192 | bottom = np.reshape([0, 0, 0, 1.], [1, 4]) 193 | c2w = poses_avg(poses) 194 | c2w = np.concatenate([c2w[:3, :4], bottom], -2) 195 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1]) 196 | poses = np.concatenate([poses[:, :3, :4], bottom], -2) 197 | 198 | poses = np.linalg.inv(c2w) @ poses 199 | poses_[:, :3, :4] = poses[:, :3, :4] 200 | poses = poses_ 
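  # The block below applies the same recentering transform (inv(c2w)) to the
  # per-frame virtual source-view poses, one virtual view at a time; the final
  # np.moveaxis restores the (frame, virtual-view) axis order.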
201 | 202 | src_output_poses = np.zeros(( 203 | src_vv_poses.shape[1], 204 | src_vv_poses.shape[0], 205 | src_vv_poses.shape[2], 206 | src_vv_poses.shape[3] + 1, 207 | )) 208 | for i in range(src_vv_poses.shape[1]): 209 | src_vv_poses_ = np.concatenate([src_vv_poses[:, i, :3, :4], bottom], -2) 210 | src_vv_poses_ = np.linalg.inv(c2w) @ src_vv_poses_ 211 | src_output_poses[i, ...] = np.concatenate([src_vv_poses_[:, :3, :], hwf], 2) 212 | 213 | return poses, np.moveaxis(src_output_poses, 1, 0) 214 | 215 | 216 | def load_llff_data( 217 | basedir, 218 | height, 219 | num_avg_imgs, 220 | factor=8, 221 | render_idx=8, 222 | recenter=True, 223 | bd_factor=0.75, 224 | spherify=False, 225 | load_imgs=True, 226 | ): 227 | """Load LLFF forward-facing data. 228 | 229 | Args: 230 | basedir: base directory 231 | height: training image height 232 | factor: resize factor 233 | render_idx: rendering frame index from the video 234 | recenter: recenter camera poses 235 | bd_factor: scale factor for bounds 236 | spherify: spherify the camera poses 237 | load_imgs: load images from the disk 238 | 239 | Returns: 240 | images: video frames 241 | poses: corresponding camera parameters 242 | bds: bounds 243 | render_poses: rendering camera poses 244 | i_test: test index 245 | imgfiles: list of image path 246 | scale: scene scale 247 | """ 248 | out = _load_data( 249 | basedir, factor=None, load_imgs=load_imgs, height=height 250 | ) 251 | 252 | if out is None: 253 | return 254 | else: 255 | poses, bds, imgs, imgfiles = out 256 | 257 | # Correct rotation matrix ordering and move variable dim to axis 0 258 | poses = np.concatenate( 259 | [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1 260 | ) 261 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 262 | if imgs is not None: 263 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 264 | images = imgs 265 | images = images.astype(np.float32) 266 | else: 267 | images = None 268 | 269 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 270 | 271 | # Rescale if bd_factor is provided 272 | scale = 1.0 if bd_factor is None else 1.0 / (bds.min() * bd_factor) 273 | 274 | poses[:, :3, 3] *= scale 275 | bds *= scale 276 | 277 | if recenter: 278 | poses = recenter_poses(poses) 279 | 280 | spiral = True 281 | if spiral: 282 | print('================= render_path_spiral ==========================') 283 | c2w = poses_avg(poses[0:num_avg_imgs]) 284 | ## Get spiral 285 | # Get average pose 286 | up = normalize(poses[:, :3, 1].sum(0)) 287 | 288 | # Find a reasonable "focus depth" for this dataset 289 | close_depth, inf_depth = bds.min() * 0.9, bds.max() * 2.0 290 | dt = 0.75 291 | mean_dz = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth)) 292 | focal = mean_dz * 1.5 293 | 294 | # Get radii for spiral path 295 | # shrink_factor = 0.8 296 | zdelta = close_depth * 0.2 297 | tt = poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T 298 | rads = np.percentile(np.abs(tt), 80, 0) 299 | c2w_path = c2w 300 | n_views = 120 301 | n_rots = 2 302 | 303 | # Generate poses for spiral path 304 | render_poses = render_path_spiral( 305 | c2w_path, up, rads, focal, zdelta, zrate=0.5, rots=n_rots, N=n_views 306 | ) 307 | else: 308 | raise NotImplementedError 309 | 310 | render_poses = np.array(render_poses).astype(np.float32) 311 | 312 | dists = np.sum(np.square(c2w[:3, 3] - poses[:, :3, 3]), -1) 313 | i_test = np.argmin(dists) 314 | poses = poses.astype(np.float32) 315 | 316 | print('bds ', bds.min(), bds.max()) 317 | 318 | return images, poses, bds, render_poses, i_test, imgfiles, scale 
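# Illustrative usage sketch (not part of the original file): the scene path,
# height, and num_avg_imgs values below are placeholder assumptions. Note that
# load_llff_data returns None when the resized image folder is missing, so the
# unpacking is guarded.
#
#   out = load_llff_data('/path/to/scene/dense', height=288, num_avg_imgs=30,
#                        render_idx=8, load_imgs=False)
#   if out is not None:
#     images, poses, bds, render_poses, i_test, imgfiles, scale = out
#     intrinsics, c2w_mats = batch_parse_llff_poses(poses)  # per-frame 4x4 matrices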
319 | 320 | 321 | def load_mono_data( 322 | basedir, 323 | height=288, 324 | factor=8, 325 | render_idx=-1, 326 | recenter=True, 327 | bd_factor=0.75, 328 | spherify=False, 329 | load_imgs=True, 330 | ): 331 | """Load monocular video data. 332 | 333 | Args: 334 | basedir: base directory 335 | height: training image height 336 | factor: resize factor 337 | render_idx: rendering frame index from the video 338 | recenter: recenter camera poses 339 | bd_factor: scale factor for bounds 340 | spherify: spherify the camera poses 341 | load_imgs: load images from the disk 342 | 343 | Returns: 344 | images: video frames 345 | poses: corresponding camera parameters 346 | src_vv_poses: virtual view camera poses 347 | bds: bounds 348 | render_poses: rendering camera poses 349 | i_test: test index 350 | imgfiles: list of image path 351 | scale: scene scale 352 | """ 353 | out = _load_data(basedir, factor=None, load_imgs=load_imgs, height=height) 354 | 355 | src_vv_poses = np.load(os.path.join(basedir, 'source_vv_poses.npy')) 356 | 357 | if out is None: 358 | return 359 | else: 360 | poses, bds, imgs, imgfiles = out 361 | 362 | # Correct rotation matrix ordering and move variable dim to axis 0 363 | poses = np.concatenate( 364 | [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1 365 | ) 366 | src_vv_poses = np.concatenate( 367 | [ 368 | src_vv_poses[:, :, 1:2, :], 369 | -src_vv_poses[:, :, 0:1, :], 370 | src_vv_poses[:, :, 2:, :], 371 | ], 372 | 2, 373 | ) 374 | 375 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 376 | src_vv_poses = np.moveaxis(src_vv_poses, -1, 0).astype(np.float32) 377 | 378 | if imgs is not None: 379 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 380 | images = imgs 381 | images = images.astype(np.float32) 382 | else: 383 | images = None 384 | 385 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 386 | 387 | # Rescale if bd_factor is provided 388 | scale = 1. if bd_factor is None else 1. / (bds.min() * bd_factor) 389 | 390 | poses[:, :3, 3] *= scale 391 | src_vv_poses[..., :3, 3] *= scale 392 | 393 | bds *= scale 394 | 395 | if recenter: 396 | poses, src_vv_poses = recenter_poses_mono(poses, src_vv_poses) 397 | 398 | if render_idx >= 0: 399 | render_poses = render_wander_path(poses[render_idx]) 400 | else: 401 | render_poses = render_stabilization_path(poses, k_size=45) 402 | 403 | render_poses = np.array(render_poses).astype(np.float32) 404 | 405 | i_test = [] 406 | poses = poses.astype(np.float32) 407 | 408 | print('bds ', bds.min(), bds.max()) 409 | 410 | return images, poses, src_vv_poses, bds, render_poses, i_test, imgfiles, scale 411 | 412 | 413 | def render_wander_path(c2w): 414 | """Render a circular wander path.""" 415 | hwf = c2w[:, 4:5] 416 | num_frames = 50 417 | max_disp = 48.0 418 | 419 | max_trans = max_disp / hwf[2][0] 420 | output_poses = [] 421 | 422 | for i in range(num_frames): 423 | x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames)) 424 | y_trans = 0.#max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 2. 425 | z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 2. 
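    # The block below packs the (x, y, z) offset into a 4x4 camera-space
    # translation, inverts it, and composes it with the reference pose to get
    # one render pose on the circular path.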
426 | 427 | i_pose = np.concatenate( 428 | [ 429 | np.concatenate( 430 | [ 431 | np.eye(3), 432 | np.array([x_trans, y_trans, z_trans])[:, np.newaxis], 433 | ], 434 | axis=1, 435 | ), 436 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :], 437 | ], 438 | axis=0, 439 | ) 440 | 441 | i_pose = np.linalg.inv(i_pose) 442 | 443 | ref_pose = np.concatenate( 444 | [c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0 445 | ) 446 | 447 | render_pose = np.dot(ref_pose, i_pose) 448 | output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1)) 449 | 450 | return output_poses 451 | 452 | 453 | def render_stabilization_path(poses, k_size): 454 | """Render a stabilized camera path.""" 455 | 456 | hwf = poses[0, :, 4:5] 457 | num_frames = poses.shape[0] 458 | output_poses = [] 459 | 460 | input_poses = [] 461 | 462 | for i in range(num_frames): 463 | input_poses.append( 464 | np.concatenate( 465 | [poses[i, :3, 0:1], poses[i, :3, 1:2], poses[i, :3, 3:4]], axis=-1 466 | ) 467 | ) 468 | 469 | input_poses = np.array(input_poses) 470 | 471 | gaussian_kernel = cv2.getGaussianKernel( 472 | ksize=k_size, sigma=-1 473 | ) 474 | output_r1 = cv2.filter2D(input_poses[:, :, 0], -1, gaussian_kernel) 475 | output_r2 = cv2.filter2D(input_poses[:, :, 1], -1, gaussian_kernel) 476 | 477 | output_r1 = output_r1 / np.linalg.norm(output_r1, axis=-1, keepdims=True) 478 | output_r2 = output_r2 / np.linalg.norm(output_r2, axis=-1, keepdims=True) 479 | 480 | output_t = cv2.filter2D(input_poses[:, :, 2], -1, gaussian_kernel) 481 | 482 | for i in range(num_frames): 483 | output_r3 = np.cross(output_r1[i], output_r2[i]) 484 | 485 | render_pose = np.concatenate( 486 | [ 487 | output_r1[i, :, None], 488 | output_r2[i, :, None], 489 | output_r3[:, None], 490 | output_t[i, :, None], 491 | ], 492 | axis=-1, 493 | ) 494 | 495 | output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1)) 496 | 497 | return output_poses 498 | -------------------------------------------------------------------------------- /ibrnet/data_loaders/monocular.py: -------------------------------------------------------------------------------- 1 | """Dataloader class for training on monocular videos.""" 2 | 3 | 4 | import os 5 | import cv2 6 | from ibrnet.data_loaders.data_utils import get_nearest_pose_ids 7 | from ibrnet.data_loaders.llff_data_utils import batch_parse_llff_poses 8 | from ibrnet.data_loaders.llff_data_utils import batch_parse_vv_poses 9 | from ibrnet.data_loaders.llff_data_utils import load_mono_data 10 | import imageio 11 | import numpy as np 12 | import skimage.morphology 13 | import torch 14 | from torch.utils.data import Dataset 15 | 16 | 17 | class MonocularDataset(Dataset): 18 | """This class loads data from monocular video. 
19 | 20 | Each returned item in the dataset has: 21 | id: reference frame index 22 | anchor_id: nearby frame index for cross time rendering 23 | num_frames: number of video frames 24 | ref_time: normalized reference time index 25 | anchor_time: normalized nearby cross-time index 26 | nearest_pose_ids: source view index w.r.t reference time 27 | anchor_nearest_pose_ids: source view index w.r.t nearby time 28 | rgb: [H, W, 3], image at reference time 29 | disp: [H, W], disparity at reference time 30 | motion_mask: [H, W], dynamic mask at reference time 31 | static_mask: [H, W], static mask at reference time 32 | flows: [6, H, W, 2] optical flows from reference time 33 | masks: [6, H, W] optical flow valid masks from reference time 34 | camera: [34] camera parameters at reference time 35 | anchor_camera: [34] camera parameters at nearby cross-time 36 | rgb_path: RGB file path name 37 | src_rgbs: [..., H, W, 3] source views RGB images for dynamic model 38 | src_cameras: [..., 34] source view camera parameters for dynamic model 39 | static_src_rgbs: [..., H, W, 3] source view images for static model 40 | static_src_cameras: [..., 34] source view camera parameters for static model 41 | anchor_src_rgbs: [..., H, W, 3] cross-time view images for dynamic model 42 | anchor_src_cameras: [..., 34] cross-time source view camera parameters for 43 | dynamic model 44 | depth_range: [2] scene near and far bounds 45 | """ 46 | 47 | def __init__(self, args, mode, scenes=(), random_crop=True, **kwargs): 48 | assert len(scenes) == 1 49 | self.folder_path = args.folder_path 50 | self.num_vv = args.num_vv 51 | self.args = args 52 | self.mask_src_view = args.mask_src_view 53 | self.num_frames_sample = args.num_source_views 54 | self.erosion_radius = args.erosion_radius 55 | self.random_crop = random_crop 56 | 57 | self.max_range = args.max_range 58 | 59 | scene = scenes[0] 60 | self.scene_path = os.path.join(self.folder_path, scene, 'dense') 61 | _, poses, src_vv_poses, bds, _, _, rgb_files, scale = ( 62 | load_mono_data( 63 | self.scene_path, height=args.training_height, load_imgs=False 64 | ) 65 | ) 66 | near_depth = np.min(bds) 67 | 68 | # make sure the far bound is at least 15 69 | # so that the static model is able to model view-dependent effects. 
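    # Worked examples of the rule below:
    #   np.max(bds) = 6  -> far_depth = min(20, 6 + 15) = 20
    #   np.max(bds) = 30 -> far_depth = min(50, max(20, 30)) = 30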
70 | if np.max(bds) < 10: 71 | far_depth = min(20, np.max(bds) + 15.0) 72 | else: 73 | far_depth = min(50, max(20, np.max(bds))) 74 | 75 | print('============= FINAL NEAR FAR', near_depth, far_depth) 76 | 77 | intrinsics, c2w_mats = batch_parse_llff_poses(poses) 78 | self.src_vv_c2w_mats = batch_parse_vv_poses(src_vv_poses) 79 | self.num_frames = len(rgb_files) 80 | assert self.num_frames == poses.shape[0] 81 | i_train = np.arange(self.num_frames) 82 | i_render = i_train 83 | self.scale = scale 84 | 85 | num_render = len(i_render) 86 | self.train_rgb_files = rgb_files 87 | self.train_intrinsics = intrinsics 88 | self.train_poses = c2w_mats 89 | self.train_depth_range = [[near_depth, far_depth]] * num_render 90 | 91 | def read_optical_flow(self, basedir, img_i, start_frame, fwd, interval): 92 | flow_dir = os.path.join(basedir, 'flow_i%d' % interval) 93 | 94 | if fwd: 95 | fwd_flow_path = os.path.join( 96 | flow_dir, '%05d_fwd.npz' % (start_frame + img_i) 97 | ) 98 | fwd_data = np.load(fwd_flow_path) # , (w, h)) 99 | fwd_flow, fwd_mask = fwd_data['flow'], fwd_data['mask'] 100 | fwd_mask = np.float32(fwd_mask) 101 | 102 | return fwd_flow, fwd_mask 103 | else: 104 | bwd_flow_path = os.path.join( 105 | flow_dir, '%05d_bwd.npz' % (start_frame + img_i) 106 | ) 107 | 108 | bwd_data = np.load(bwd_flow_path) # , (w, h)) 109 | bwd_flow, bwd_mask = bwd_data['flow'], bwd_data['mask'] 110 | bwd_mask = np.float32(bwd_mask) 111 | 112 | return bwd_flow, bwd_mask 113 | 114 | def __len__(self): 115 | return self.num_frames 116 | 117 | def set_epoch(self, epoch): 118 | self.current_epoch = epoch 119 | 120 | def load_src_view( 121 | self, rgb_file, pose, intrinsics, st_mask_path=None 122 | ): 123 | """Load RGB and camera data from each source views id.""" 124 | 125 | src_rgb = imageio.imread(rgb_file).astype(np.float32) / 255.0 126 | img_size = src_rgb.shape[:2] 127 | src_camera = np.concatenate( 128 | (list(img_size), intrinsics.flatten(), pose.flatten()) 129 | ).astype(np.float32) 130 | 131 | if st_mask_path: 132 | st_mask = imageio.imread(st_mask_path).astype(np.float32) / 255.0 133 | st_mask = cv2.resize( 134 | st_mask, 135 | (src_rgb.shape[1], src_rgb.shape[0]), 136 | interpolation=cv2.INTER_NEAREST, 137 | ) 138 | 139 | if len(st_mask.shape) == 2: 140 | st_mask = st_mask[..., None] 141 | 142 | src_rgb = src_rgb * st_mask 143 | 144 | return src_rgb, src_camera 145 | 146 | def __getitem__(self, idx): 147 | # skip first and last 3 frames 148 | idx = np.random.randint(3, self.num_frames - 3) 149 | rgb_file = self.train_rgb_files[idx] 150 | 151 | render_pose = self.train_poses[idx] 152 | intrinsics = self.train_intrinsics[idx] 153 | depth_range = self.train_depth_range[idx] 154 | 155 | rgb, camera = self.load_src_view(rgb_file, render_pose, intrinsics) 156 | img_size = rgb.shape[:2] 157 | 158 | # load mono-depth 159 | disp_path = os.path.join( 160 | self.scene_path, 'disp', rgb_file.split('/')[-1][:-4] + '.npy' 161 | ) 162 | disp = np.load(disp_path) / self.scale 163 | 164 | # load motion mask 165 | mask_path = os.path.join( 166 | '/'.join(rgb_file.split('/')[:-2]), 'dynamic_masks', '%d.png' % idx 167 | ) 168 | motion_mask = 1.0 - imageio.imread(mask_path).astype(np.float32) / 255.0 169 | 170 | static_mask_path = os.path.join( 171 | '/'.join(rgb_file.split('/')[:-2]), 'static_masks', '%d.png' % idx 172 | ) 173 | static_mask = ( 174 | 1.0 - imageio.imread(static_mask_path).astype(np.float32) / 255.0 175 | ) 176 | 177 | static_mask = cv2.resize( 178 | static_mask, 179 | (disp.shape[1], disp.shape[0]), 180 | 
interpolation=cv2.INTER_NEAREST, 181 | ) 182 | # ensure input dynamic and static mask to have same height before 183 | # running morphological erosion 184 | motion_mask = cv2.resize( 185 | motion_mask, 186 | (int(round(288.0 * disp.shape[1] / disp.shape[0])), 288), 187 | interpolation=cv2.INTER_NEAREST, 188 | ) 189 | 190 | if len(motion_mask.shape) == 2: 191 | motion_mask = motion_mask[..., None] 192 | 193 | motion_mask = skimage.morphology.erosion( 194 | motion_mask[..., 0] > 1e-3, skimage.morphology.disk(self.erosion_radius) 195 | ) 196 | 197 | motion_mask = cv2.resize( 198 | np.float32(motion_mask), 199 | (disp.shape[1], disp.shape[0]), 200 | interpolation=cv2.INTER_NEAREST, 201 | ) 202 | 203 | motion_mask = np.float32(motion_mask) 204 | static_mask = np.float32(static_mask > 1e-3) 205 | 206 | assert disp.shape[0:2] == img_size 207 | assert motion_mask.shape[0:2] == img_size 208 | assert static_mask.shape[0:2] == img_size 209 | 210 | # train_set_id = self.render_train_set_ids[idx] 211 | train_rgb_files = self.train_rgb_files 212 | train_poses = self.train_poses 213 | train_intrinsics = self.train_intrinsics 214 | 215 | # view selection based on time interval 216 | nearest_pose_ids = [idx + offset for offset in [1, 2, 3, -1, -2, -3]] 217 | max_step = min(3, self.current_epoch // (self.args.init_decay_epoch) + 1) 218 | # select a nearby time index for cross time rendering 219 | anchor_pool = [i for i in range(1, max_step + 1)] + [ 220 | -i for i in range(1, max_step + 1) 221 | ] 222 | anchor_idx = idx + anchor_pool[np.random.choice(len(anchor_pool))] 223 | anchor_nearest_pose_ids = [] 224 | 225 | anchor_camera = np.concatenate(( 226 | list(img_size), 227 | self.train_intrinsics[anchor_idx].flatten(), 228 | self.train_poses[anchor_idx].flatten(), 229 | )).astype(np.float32) 230 | 231 | for offset in [3, 2, 1, 0, -1, -2, -3]: 232 | if ( 233 | (anchor_idx + offset) < 0 234 | or (anchor_idx + offset) >= len(train_rgb_files) 235 | or (anchor_idx + offset) == idx 236 | ): 237 | continue 238 | anchor_nearest_pose_ids.append((anchor_idx + offset)) 239 | 240 | # occasionally include render image for anchor time index 241 | if np.random.choice([0, 1], p=[1.0 - 0.005, 0.005]): 242 | anchor_nearest_pose_ids.append(idx) 243 | 244 | anchor_nearest_pose_ids = np.sort(anchor_nearest_pose_ids) 245 | 246 | flows, masks = [], [] 247 | 248 | # load optical flow 249 | for ii in range(len(nearest_pose_ids)): 250 | offset = nearest_pose_ids[ii] - idx 251 | flow, mask = self.read_optical_flow( 252 | self.scene_path, 253 | idx, 254 | start_frame=0, 255 | fwd=True if offset > 0 else False, 256 | interval=np.abs(offset), 257 | ) 258 | 259 | flows.append(flow) 260 | masks.append(mask) 261 | 262 | flows = np.stack(flows) 263 | masks = np.stack(masks) 264 | 265 | assert flows.shape[1:3] == img_size 266 | assert masks.shape[1:3] == img_size 267 | 268 | # load src rgb for ref view 269 | sp_pose_ids = get_nearest_pose_ids( 270 | render_pose, 271 | train_poses, 272 | tar_id=idx, 273 | angular_dist_method='dist', 274 | ) 275 | 276 | static_pose_ids = [] 277 | 278 | max_interval = self.max_range // self.num_frames_sample 279 | interval = np.random.randint(max(2, max_interval - 2), max_interval + 1) 280 | 281 | for ii in range(-self.num_frames_sample, self.num_frames_sample): 282 | rand_j = np.random.randint(1, interval + 1) 283 | static_pose_id = idx + interval * ii + rand_j 284 | 285 | if 0 <= static_pose_id < self.num_frames and static_pose_id != idx: 286 | static_pose_ids.append(static_pose_id) 287 | 288 | 
static_pose_set = set(static_pose_ids) 289 | # if there are no enough image, add nearest images w.r.t camera poses 290 | # choose stride of 5 so that views are not very close to each other. 291 | for sp_pose_id in sp_pose_ids[::5]: 292 | if len(static_pose_ids) >= (self.num_frames_sample * 2): 293 | break 294 | 295 | if sp_pose_id not in static_pose_set: 296 | static_pose_ids.append(sp_pose_id) 297 | 298 | static_pose_ids = np.sort(static_pose_ids) 299 | 300 | src_rgbs = [] 301 | src_cameras = [] 302 | 303 | for near_id in nearest_pose_ids: 304 | src_rgb, src_camera = self.load_src_view( 305 | train_rgb_files[near_id], 306 | train_poses[near_id], 307 | train_intrinsics[near_id], 308 | ) 309 | src_rgbs.append(src_rgb) 310 | src_cameras.append(src_camera) 311 | 312 | # load src virtual views 313 | for virtual_idx in np.random.choice( 314 | list(range(0, 8)), size=self.num_vv, replace=False 315 | ): 316 | src_vv_path = os.path.join( 317 | '/'.join( 318 | rgb_file.replace('images', 'source_virtual_views').split('/')[:-1] 319 | ), 320 | '%05d' % idx, 321 | '%02d.png' % virtual_idx, 322 | ) 323 | src_rgb, src_camera = self.load_src_view( 324 | src_vv_path, 325 | self.src_vv_c2w_mats[idx, virtual_idx], 326 | intrinsics, 327 | ) 328 | src_rgbs.append(src_rgb) 329 | src_cameras.append(src_camera) 330 | 331 | src_rgbs = np.stack(src_rgbs, axis=0) 332 | src_cameras = np.stack(src_cameras, axis=0) 333 | 334 | static_src_rgbs = [] 335 | static_src_cameras = [] 336 | 337 | # load src rgb for static view 338 | for st_near_id in static_pose_ids: 339 | st_mask_path = None 340 | 341 | if self.mask_src_view: 342 | st_mask_path = os.path.join( 343 | '/'.join(rgb_file.split('/')[:-2]), 344 | 'dynamic_masks', 345 | '%d.png' % st_near_id, 346 | ) 347 | 348 | src_rgb, src_camera = self.load_src_view( 349 | train_rgb_files[st_near_id], 350 | train_poses[st_near_id], 351 | train_intrinsics[st_near_id], 352 | st_mask_path=st_mask_path, 353 | ) 354 | 355 | static_src_rgbs.append(src_rgb) 356 | static_src_cameras.append(src_camera) 357 | 358 | static_src_rgbs = np.stack(static_src_rgbs, axis=0) 359 | static_src_cameras = np.stack(static_src_cameras, axis=0) 360 | 361 | # load src rgb for anchor view 362 | anchor_src_rgbs = [] 363 | anchor_src_cameras = [] 364 | 365 | for near_id in anchor_nearest_pose_ids: 366 | src_rgb, src_camera = self.load_src_view( 367 | train_rgb_files[near_id], 368 | train_poses[near_id], 369 | train_intrinsics[near_id], 370 | ) 371 | anchor_src_rgbs.append(src_rgb) 372 | anchor_src_cameras.append(src_camera) 373 | 374 | # load anchor src virtual views 375 | for virtual_idx in np.random.choice( 376 | list(range(0, 8)), size=self.num_vv, replace=False 377 | ): 378 | src_vv_path = os.path.join( 379 | '/'.join( 380 | rgb_file.replace('images', 'source_virtual_views').split('/')[:-1] 381 | ), 382 | '%05d' % anchor_idx, 383 | '%02d.png' % virtual_idx, 384 | ) 385 | src_rgb, src_camera = self.load_src_view( 386 | src_vv_path, 387 | self.src_vv_c2w_mats[anchor_idx, virtual_idx], 388 | intrinsics, 389 | ) 390 | anchor_src_rgbs.append(src_rgb) 391 | anchor_src_cameras.append(src_camera) 392 | 393 | anchor_src_rgbs = np.stack(anchor_src_rgbs, axis=0) 394 | anchor_src_cameras = np.stack(anchor_src_cameras, axis=0) 395 | 396 | depth_range = torch.tensor( 397 | [depth_range[0] * 0.9, depth_range[1] * 1.5] 398 | ).float() 399 | 400 | return { 401 | 'id': idx, 402 | 'anchor_id': anchor_idx, 403 | 'num_frames': self.num_frames, 404 | 'ref_time': float(idx / float(self.num_frames)), 405 | 'anchor_time': 
float(anchor_idx / float(self.num_frames)), 406 | 'nearest_pose_ids': torch.from_numpy(np.array(nearest_pose_ids)), 407 | 'anchor_nearest_pose_ids': torch.from_numpy( 408 | np.array(anchor_nearest_pose_ids) 409 | ), 410 | 'rgb': torch.from_numpy(rgb[..., 0:3]).float(), 411 | 'disp': torch.from_numpy(disp).float(), 412 | 'motion_mask': torch.from_numpy(motion_mask).float(), 413 | 'static_mask': torch.from_numpy(static_mask).float(), 414 | 'flows': torch.from_numpy(flows).float(), 415 | 'masks': torch.from_numpy(masks).float(), 416 | 'camera': torch.from_numpy(camera).float(), 417 | 'anchor_camera': torch.from_numpy(anchor_camera).float(), 418 | 'rgb_path': rgb_file, 419 | 'src_rgbs': torch.from_numpy(src_rgbs[..., :3]).float(), 420 | 'src_cameras': torch.from_numpy(src_cameras).float(), 421 | 'static_src_rgbs': torch.from_numpy(static_src_rgbs[..., :3]).float(), 422 | 'static_src_cameras': torch.from_numpy(static_src_cameras).float(), 423 | 'anchor_src_rgbs': torch.from_numpy(anchor_src_rgbs[..., :3]).float(), 424 | 'anchor_src_cameras': torch.from_numpy(anchor_src_cameras).float(), 425 | 'depth_range': depth_range, 426 | } 427 | -------------------------------------------------------------------------------- /ibrnet/model.py: -------------------------------------------------------------------------------- 1 | """Main Dynibar model class definition.""" 2 | 3 | 4 | import os 5 | from ibrnet.feature_network import ResNet 6 | from ibrnet.mlp_network import DynibarDynamic 7 | from ibrnet.mlp_network import DynibarStatic 8 | from ibrnet.mlp_network import MotionMLP 9 | import numpy as np 10 | import torch 11 | 12 | 13 | def de_parallel(model): 14 | """convert distributed parallel model to single model.""" 15 | return model.module if hasattr(model, 'module') else model 16 | 17 | 18 | def init_dct_basis(num_basis, num_frames): 19 | """Initialize motion basis with DCT coefficient.""" 20 | T = num_frames 21 | K = num_basis 22 | dct_basis = torch.zeros([T, K]) 23 | 24 | for t in range(T): 25 | for k in range(1, K + 1): 26 | dct_basis[t, k - 1] = np.sqrt(2.0 / T) * np.cos( 27 | np.pi / (2.0 * T) * (2 * t + 1) * k 28 | ) 29 | 30 | return dct_basis 31 | 32 | 33 | class DynibarFF(object): 34 | """Dynibar model for forward-facing benchmark.""" 35 | 36 | def __init__(self, args, load_opt=True, load_scheduler=True): 37 | self.args = args 38 | self.device = torch.device('cuda:{}'.format(args.local_rank)) 39 | # create coarse DynIBaR models 40 | self.net_coarse_st = DynibarStatic( 41 | args, 42 | in_feat_ch=self.args.coarse_feat_dim, 43 | n_samples=self.args.N_samples, 44 | ).to(self.device) 45 | self.net_coarse_dy = DynibarDynamic( 46 | args, 47 | in_feat_ch=self.args.coarse_feat_dim, 48 | n_samples=self.args.N_samples, 49 | ).to(self.device) 50 | 51 | # create fine DynIBaR models 52 | self.net_fine_st = DynibarStatic( 53 | args, 54 | in_feat_ch=self.args.fine_feat_dim, 55 | n_samples=self.args.N_samples + self.args.N_importance, 56 | ).to(self.device) 57 | self.net_fine_dy = DynibarDynamic( 58 | args, 59 | in_feat_ch=self.args.fine_feat_dim, 60 | n_samples=self.args.N_samples + self.args.N_importance, 61 | ).to(self.device) 62 | 63 | # create coarse feature extraction network 64 | self.feature_net = ResNet( 65 | coarse_out_ch=self.args.coarse_feat_dim, 66 | fine_out_ch=self.args.fine_feat_dim, 67 | coarse_only=False, 68 | ).to(self.device) 69 | 70 | # create fine feature extraction network 71 | self.feature_net_fine = ResNet( 72 | coarse_out_ch=self.args.coarse_feat_dim, 73 | 
fine_out_ch=self.args.fine_feat_dim, 74 | coarse_only=False, 75 | ).to(self.device) 76 | 77 | # Motion trajectory models with MLPs 78 | self.motion_mlp = ( 79 | MotionMLP(num_basis=args.num_basis).float().to(self.device) 80 | ) 81 | self.motion_mlp_fine = ( 82 | MotionMLP(num_basis=args.num_basis).float().to(self.device) 83 | ) 84 | 85 | # Motion basis 86 | dct_basis = init_dct_basis(args.num_basis, args.num_frames) 87 | self.trajectory_basis = ( 88 | torch.nn.parameter.Parameter(dct_basis) 89 | .float() 90 | .to(self.device) 91 | .detach() 92 | .requires_grad_(True) 93 | ) 94 | self.trajectory_basis_fine = ( 95 | torch.nn.parameter.Parameter(dct_basis) 96 | .float() 97 | .to(self.device) 98 | .detach() 99 | .requires_grad_(True) 100 | ) 101 | 102 | self.load_coarse_from_ckpt(args.coarse_dir) 103 | 104 | out_folder = os.path.join(args.rootdir, 'checkpoints/fine', args.expname) 105 | 106 | self.optimizer = torch.optim.Adam([ 107 | { 108 | 'params': self.net_fine_st.parameters(), 109 | 'lr': args.lrate_mlp * args.lr_multipler, 110 | }, 111 | {'params': self.net_fine_dy.parameters(), 'lr': args.lrate_mlp}, 112 | { 113 | 'params': self.feature_net_fine.parameters(), 114 | 'lr': args.lrate_feature, 115 | }, 116 | {'params': self.motion_mlp_fine.parameters(), 'lr': args.lrate_mlp}, 117 | {'params': self.trajectory_basis_fine, 'lr': args.lrate_mlp * 0.25}, 118 | ]) 119 | 120 | self.scheduler = torch.optim.lr_scheduler.StepLR( 121 | self.optimizer, 122 | step_size=args.lrate_decay_steps, 123 | gamma=args.lrate_decay_factor, 124 | ) 125 | 126 | self.start_step = self.load_fine_from_ckpt( 127 | out_folder, load_opt=True, load_scheduler=True 128 | ) 129 | 130 | device_ids = list(range(torch.cuda.device_count())) 131 | 132 | # convert single model to 133 | # multi-GPU data-parallel mode for coarse networks 134 | self.net_coarse_st = torch.nn.DataParallel( 135 | self.net_coarse_st, device_ids=device_ids 136 | ) 137 | self.net_coarse_dy = torch.nn.DataParallel( 138 | self.net_coarse_dy, device_ids=device_ids 139 | ) 140 | self.feature_net = torch.nn.DataParallel( 141 | self.feature_net, device_ids=device_ids 142 | ) 143 | self.motion_mlp = torch.nn.DataParallel( 144 | self.motion_mlp, device_ids=device_ids 145 | ) 146 | # convert single model to 147 | # multi-GPU data-parallel mode for fine networks 148 | self.net_fine_st = torch.nn.DataParallel( 149 | self.net_fine_st, device_ids=device_ids 150 | ) 151 | self.net_fine_dy = torch.nn.DataParallel( 152 | self.net_fine_dy, device_ids=device_ids 153 | ) 154 | self.feature_net_fine = torch.nn.DataParallel( 155 | self.feature_net_fine, device_ids=device_ids 156 | ) 157 | self.motion_mlp_fine = torch.nn.DataParallel( 158 | self.motion_mlp_fine, device_ids=device_ids 159 | ) 160 | 161 | def switch_to_eval(self): 162 | """Switch to evaluation mode.""" 163 | self.net_fine_st.eval() 164 | self.net_fine_dy.eval() 165 | 166 | self.feature_net_fine.eval() 167 | self.motion_mlp_fine.eval() 168 | 169 | def switch_to_train(self): 170 | """Switch to training mode.""" 171 | self.net_fine_st.train() 172 | self.net_fine_dy.train() 173 | 174 | self.feature_net_fine.train() 175 | self.motion_mlp_fine.train() 176 | 177 | def save_model(self, filename, global_step): 178 | """De-parallel and save current model to local disk.""" 179 | to_save = { 180 | 'optimizer': self.optimizer.state_dict(), 181 | 'scheduler': self.scheduler.state_dict(), 182 | 'net_fine_st': de_parallel(self.net_fine_st).state_dict(), 183 | 'net_fine_dy': de_parallel(self.net_fine_dy).state_dict(), 184 | 
'feature_net_fine': de_parallel(self.feature_net_fine).state_dict(), 185 | 'motion_mlp_fine': de_parallel(self.motion_mlp_fine).state_dict(), 186 | 'traj_basis_fine': self.trajectory_basis_fine, 187 | 'global_step': int(global_step), 188 | } 189 | 190 | torch.save(to_save, filename) 191 | 192 | def load_coarse_model(self, filename): 193 | """Load coarse stage dynibar model.""" 194 | if self.args.distributed: 195 | to_load = torch.load( 196 | filename, map_location='cuda:{}'.format(self.args.local_rank) 197 | ) 198 | else: 199 | to_load = torch.load(filename) 200 | 201 | self.net_coarse_st.load_state_dict(to_load['net_coarse_st']) 202 | self.net_coarse_dy.load_state_dict(to_load['net_coarse_dy']) 203 | 204 | self.feature_net.load_state_dict(to_load['feature_net']) 205 | 206 | self.motion_mlp.load_state_dict(to_load['motion_mlp']) 207 | self.trajectory_basis = to_load['traj_basis'] 208 | 209 | return to_load['global_step'] 210 | 211 | def load_fine_model(self, filename, load_opt=True, load_scheduler=True): 212 | """Load fine stage dynibar model.""" 213 | if self.args.distributed: 214 | to_load = torch.load( 215 | filename, map_location='cuda:{}'.format(self.args.local_rank) 216 | ) 217 | else: 218 | to_load = torch.load(filename) 219 | 220 | if load_opt: 221 | self.optimizer.load_state_dict(to_load['optimizer']) 222 | if load_scheduler: 223 | self.scheduler.load_state_dict(to_load['scheduler']) 224 | 225 | self.net_fine_st.load_state_dict(to_load['net_fine_st']) 226 | self.net_fine_dy.load_state_dict(to_load['net_fine_dy']) 227 | 228 | self.feature_net_fine.load_state_dict(to_load['feature_net_fine']) 229 | 230 | self.motion_mlp_fine.load_state_dict(to_load['motion_mlp_fine']) 231 | self.trajectory_basis_fine = to_load['traj_basis_fine'] 232 | 233 | return to_load['global_step'] 234 | 235 | def load_coarse_from_ckpt( 236 | self, 237 | out_folder 238 | ): 239 | """Load coarse model from existing checkpoints and return the current step.""" 240 | 241 | # all existing ckpts 242 | ckpts = [] 243 | if os.path.exists(out_folder): 244 | ckpts = [ 245 | os.path.join(out_folder, f) 246 | for f in sorted(os.listdir(out_folder)) 247 | if f.endswith('.pth') 248 | ] 249 | 250 | fpath = ckpts[-1] 251 | num_steps = self.load_coarse_model(fpath) 252 | 253 | step = num_steps 254 | print('Reloading from {}, starting at step={}'.format(fpath, step)) 255 | 256 | return step 257 | 258 | def load_fine_from_ckpt( 259 | self, 260 | out_folder, 261 | load_opt=True, 262 | load_scheduler=True 263 | ): 264 | """Load fine model from existing checkpoints and return the current step.""" 265 | 266 | # all existing ckpts 267 | ckpts = [] 268 | if os.path.exists(out_folder): 269 | ckpts = [ 270 | os.path.join(out_folder, f) 271 | for f in sorted(os.listdir(out_folder)) 272 | if f.endswith('.pth') 273 | ] 274 | 275 | if self.args.ckpt_path is not None: 276 | if os.path.isfile(self.args.ckpt_path): # load the specified ckpt 277 | ckpts = [self.args.ckpt_path] 278 | 279 | if len(ckpts) > 0 and not self.args.no_reload: 280 | fpath = ckpts[-1] 281 | num_steps = self.load_fine_model(fpath, load_opt, load_scheduler) 282 | step = num_steps 283 | print('Reloading from {}, starting at step={}'.format(fpath, step)) 284 | else: 285 | print('No ckpts found, training from scratch...') 286 | step = 0 287 | 288 | return step 289 | 290 | 291 | class DynibarMono(object): 292 | """Main Dynibar model for monocular video.""" 293 | 294 | def __init__(self, args): 295 | self.args = args 296 | self.device = 
torch.device('cuda:{}'.format(args.local_rank)) 297 | # create Dynibar models for monocular videos 298 | self.net_coarse_st = DynibarStatic( 299 | args, 300 | in_feat_ch=self.args.coarse_feat_dim, 301 | n_samples=self.args.N_samples, 302 | ).to(self.device) 303 | self.net_coarse_dy = DynibarDynamic( 304 | args, 305 | in_feat_ch=self.args.coarse_feat_dim, 306 | n_samples=self.args.N_samples, 307 | shift=5.0, 308 | ).to(self.device) 309 | 310 | self.net_fine = None 311 | 312 | # create feature extraction network used for dynamic model. 313 | self.feature_net = ResNet( 314 | coarse_out_ch=self.args.coarse_feat_dim, 315 | fine_out_ch=self.args.fine_feat_dim, 316 | coarse_only=False, 317 | ).to(self.device) 318 | 319 | # create feature extraction network used for static model. 320 | self.feature_net_st = ResNet( 321 | coarse_out_ch=self.args.coarse_feat_dim, 322 | fine_out_ch=self.args.fine_feat_dim, 323 | coarse_only=False, 324 | ).to(self.device) 325 | 326 | # Motion trajectory model with MLP. 327 | self.motion_mlp = ( 328 | MotionMLP(num_basis=args.num_basis).float().to(self.device) 329 | ) 330 | 331 | # basis 332 | dct_basis = init_dct_basis(args.num_basis, args.num_frames) 333 | self.trajectory_basis = ( 334 | torch.nn.parameter.Parameter(dct_basis) 335 | .float() 336 | .to(self.device) 337 | .detach() 338 | .requires_grad_(True) 339 | ) 340 | 341 | self.optimizer = torch.optim.Adam([ 342 | {'params': self.net_coarse_st.parameters(), 'lr': args.lrate_mlp * 0.5}, 343 | { 344 | 'params': self.feature_net_st.parameters(), 345 | 'lr': args.lrate_feature * 0.5, 346 | }, 347 | {'params': self.net_coarse_dy.parameters(), 'lr': args.lrate_mlp}, 348 | {'params': self.feature_net.parameters(), 'lr': args.lrate_feature}, 349 | {'params': self.motion_mlp.parameters(), 'lr': args.lrate_mlp}, 350 | {'params': self.trajectory_basis, 'lr': args.lrate_mlp * 0.25}, 351 | ]) 352 | 353 | print( 354 | 'lrate_decay_steps ', 355 | args.lrate_decay_steps, 356 | ' lrate_decay_factor ', 357 | args.lrate_decay_factor, 358 | ) 359 | 360 | self.scheduler = torch.optim.lr_scheduler.StepLR( 361 | self.optimizer, 362 | step_size=args.lrate_decay_steps, 363 | gamma=args.lrate_decay_factor, 364 | ) 365 | 366 | out_folder = os.path.join(args.rootdir, 'out', args.expname) 367 | 368 | self.start_step = 0 369 | 370 | if args.pretrain_path == '': 371 | self.start_step = self.load_from_ckpt( 372 | out_folder, load_opt=True, load_scheduler=True 373 | ) 374 | 375 | else: 376 | self.start_step = self.load_from_ckpt( 377 | args.pretrain_path, load_opt=True, load_scheduler=True 378 | ) 379 | 380 | device_ids = list(range(torch.cuda.device_count())) 381 | 382 | self.net_coarse_st = torch.nn.DataParallel( 383 | self.net_coarse_st, device_ids=device_ids 384 | ) 385 | self.net_coarse_dy = torch.nn.DataParallel( 386 | self.net_coarse_dy, device_ids=device_ids 387 | ) 388 | self.feature_net = torch.nn.DataParallel( 389 | self.feature_net, device_ids=device_ids 390 | ) 391 | self.feature_net_st = torch.nn.DataParallel( 392 | self.feature_net_st, device_ids=device_ids 393 | ) 394 | 395 | self.motion_mlp = torch.nn.DataParallel( 396 | self.motion_mlp, device_ids=device_ids 397 | ) 398 | 399 | def switch_to_eval(self): 400 | """Switch models to evaluation mode.""" 401 | self.net_coarse_st.eval() 402 | self.net_coarse_dy.eval() 403 | 404 | self.feature_net.eval() 405 | self.feature_net_st.eval() 406 | self.motion_mlp.eval() 407 | 408 | if self.net_fine is not None: 409 | self.net_fine.eval() 410 | 411 | def switch_to_train(self): 412 | 
"""Switch models to training mode.""" 413 | 414 | self.net_coarse_st.train() 415 | self.net_coarse_dy.train() 416 | 417 | self.feature_net.train() 418 | self.motion_mlp.train() 419 | self.feature_net_st.train() 420 | 421 | if self.net_fine is not None: 422 | self.net_fine.train() 423 | 424 | def save_model(self, filename, global_step): 425 | """Save Dynibar monocular model.""" 426 | to_save = { 427 | 'optimizer': self.optimizer.state_dict(), 428 | 'scheduler': self.scheduler.state_dict(), 429 | 'net_coarse_st': de_parallel(self.net_coarse_st).state_dict(), 430 | 'net_coarse_dy': de_parallel(self.net_coarse_dy).state_dict(), 431 | 'feature_net': de_parallel(self.feature_net).state_dict(), 432 | 'feature_net_st': de_parallel(self.feature_net_st).state_dict(), 433 | 'motion_mlp': de_parallel(self.motion_mlp).state_dict(), 434 | 'traj_basis': self.trajectory_basis, 435 | 'global_step': int(global_step), 436 | } 437 | 438 | if self.net_fine is not None: 439 | to_save['net_fine'] = de_parallel(self.net_fine).state_dict() 440 | 441 | torch.save(to_save, filename) 442 | 443 | def load_model(self, filename, load_opt=True, load_scheduler=True): 444 | """Load Dynibar monocular model.""" 445 | if self.args.distributed: 446 | to_load = torch.load( 447 | filename, map_location='cuda:{}'.format(self.args.local_rank) 448 | ) 449 | else: 450 | to_load = torch.load(filename) 451 | 452 | if load_opt: 453 | self.optimizer.load_state_dict(to_load['optimizer']) 454 | if load_scheduler: 455 | self.scheduler.load_state_dict(to_load['scheduler']) 456 | 457 | self.net_coarse_st.load_state_dict(to_load['net_coarse_st']) 458 | self.net_coarse_dy.load_state_dict(to_load['net_coarse_dy']) 459 | 460 | self.feature_net.load_state_dict(to_load['feature_net']) 461 | self.feature_net_st.load_state_dict(to_load['feature_net_st']) 462 | 463 | self.motion_mlp.load_state_dict(to_load['motion_mlp']) 464 | self.trajectory_basis = to_load['traj_basis'] 465 | 466 | return to_load['global_step'] 467 | 468 | def load_from_ckpt( 469 | self, 470 | out_folder, 471 | load_opt=True, 472 | load_scheduler=True, 473 | ): 474 | """Load coarse model from existing checkpoints and return the current step.""" 475 | 476 | # all existing ckpts 477 | ckpts = [] 478 | if os.path.exists(out_folder): 479 | ckpts = [ 480 | os.path.join(out_folder, f) 481 | for f in sorted(os.listdir(out_folder)) 482 | if f.endswith('latest.pth') 483 | ] 484 | 485 | if self.args.ckpt_path is not None: 486 | if os.path.isfile(self.args.ckpt_path): # load the specified ckpt 487 | ckpts = [self.args.ckpt_path] 488 | 489 | if len(ckpts) > 0 and not self.args.no_reload: 490 | fpath = ckpts[-1] 491 | num_steps = self.load_model(fpath, True, True) 492 | print('=========== num_steps ', num_steps) 493 | 494 | step = num_steps 495 | print('Reloading from {}, starting at step={}'.format(fpath, step)) 496 | else: 497 | print('No ckpts found, training from scratch...') 498 | step = 0 499 | 500 | return step 501 | 502 | -------------------------------------------------------------------------------- /eval_nvidia.py: -------------------------------------------------------------------------------- 1 | """Evaluation script for the Nvidia Benchmark.""" 2 | 3 | import collections 4 | import math 5 | import os 6 | import time 7 | from config import config_parser 8 | import cv2 9 | from ibrnet.data_loaders.llff_data_utils import batch_parse_llff_poses 10 | from ibrnet.data_loaders.llff_data_utils import load_llff_data 11 | from ibrnet.model import DynibarFF 12 | from ibrnet.projection import 
Projector 13 | from ibrnet.render_image import render_single_image_nvi 14 | from ibrnet.sample_ray import RaySamplerSingleImage 15 | import imageio 16 | import models 17 | import numpy as np 18 | import skimage.metrics 19 | import torch 20 | from torch.utils.data import DataLoader 21 | from torch.utils.data import Dataset 22 | 23 | 24 | class DynamicVideoDataset(Dataset): 25 | """This class loads data from the Nvidia benchmark, including camera, scene, and source-view image information.""" 26 | 27 | def __init__(self, render_idx, args, scenes, **kwargs): 28 | self.folder_path = args.folder_path 29 | self.render_idx = render_idx 30 | self.mask_static = args.mask_static 31 | 32 | print('loading {} for rendering'.format(scenes)) 33 | assert len(scenes) == 1 34 | 35 | scene = scenes[0] 36 | self.scene_path = os.path.join( 37 | self.folder_path, scene, 'dense' 38 | ) 39 | _, poses, bds, _, i_test, rgb_files, _ = load_llff_data( 40 | self.scene_path, 41 | height=288, 42 | num_avg_imgs=12, 43 | render_idx=self.render_idx, 44 | load_imgs=False, 45 | ) 46 | near_depth = np.min(bds) 47 | # Add 15 to the far bound to ensure we cover far scene content 48 | far_depth = np.max(bds) + 15.0 49 | self.num_frames = len(rgb_files) 50 | 51 | intrinsics, c2w_mats = batch_parse_llff_poses(poses) 52 | h, w = poses[0][:2, -1] 53 | render_intrinsics, render_c2w_mats = ( 54 | intrinsics, 55 | c2w_mats, 56 | ) 57 | 58 | self.train_intrinsics = intrinsics 59 | self.train_poses = c2w_mats 60 | self.train_rgb_files = rgb_files 61 | self.render_intrinsics = render_intrinsics 62 | 63 | self.render_poses = render_c2w_mats 64 | self.render_depth_range = [[near_depth, far_depth]] * self.num_frames 65 | self.h = [int(h)] * self.num_frames 66 | self.w = [int(w)] * self.num_frames 67 | 68 | def __len__(self): 69 | return 12 # number of viewpoints 70 | 71 | def __getitem__(self, idx): 72 | render_pose = self.render_poses[idx] 73 | intrinsics = self.render_intrinsics[idx] 74 | depth_range = self.render_depth_range[idx] 75 | 76 | train_rgb_files = self.train_rgb_files 77 | train_poses = self.train_poses 78 | train_intrinsics = self.train_intrinsics 79 | 80 | h, w = self.h[idx], self.w[idx] 81 | camera = np.concatenate( 82 | ([h, w], intrinsics.flatten(), render_pose.flatten()) 83 | ).astype(np.float32) 84 | 85 | gt_img_path = os.path.join( 86 | self.scene_path, 87 | 'mv_images', 88 | '%05d' % self.render_idx, 89 | 'cam%02d.jpg' % (idx + 1), 90 | ) 91 | 92 | nearest_pose_ids = np.sort( 93 | [self.render_idx + offset for offset in [1, 2, 3, 0, -1, -2, -3]] 94 | ) 95 | # 12 is the number of viewpoints we sample from the input cameras 96 | num_imgs_per_cycle = 12 97 | 98 | # Get the camera viewpoint closest to the target view by index, 99 | # since the benchmark cycles through fixed viewpoints in a round-robin manner 100 | static_pose_ids = np.array(list(range(0, train_poses.shape[0]))) 101 | static_id_dict = collections.defaultdict(list) 102 | for static_pose_id in static_pose_ids: 103 | # do not include images taken from the same viewpoint as the target 104 | if ( 105 | static_pose_id % num_imgs_per_cycle 106 | == self.render_idx % num_imgs_per_cycle 107 | ): 108 | continue 109 | 110 | static_id_dict[static_pose_id % num_imgs_per_cycle].append(static_pose_id) 111 | 112 | static_pose_ids = [] 113 | for key in static_id_dict: 114 | min_idx = np.argmin( 115 | np.abs(np.array(static_id_dict[key]) - self.render_idx) 116 | ) 117 | static_pose_ids.append(static_id_dict[key][min_idx]) 118 | 119 | static_pose_ids = np.sort(static_pose_ids) 120 | 121 | src_rgbs = [] 122 | 
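# Load RGB images and camera parameters for the temporally nearest source views; each
# per-view camera vector packs [image size, flattened intrinsics, flattened camera-to-world pose].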
src_cameras = [] 123 | for src_idx in nearest_pose_ids: 124 | src_rgb = ( 125 | imageio.v2.imread(train_rgb_files[src_idx]).astype(np.float32) / 255.0 126 | ) 127 | train_pose = train_poses[src_idx] 128 | train_intrinsics_ = train_intrinsics[src_idx] 129 | src_rgbs.append(src_rgb) 130 | img_size = src_rgb.shape[:2] 131 | src_camera = np.concatenate( 132 | (list(img_size), train_intrinsics_.flatten(), train_pose.flatten()) 133 | ).astype(np.float32) 134 | 135 | src_cameras.append(src_camera) 136 | 137 | src_rgbs = np.stack(src_rgbs, axis=0) 138 | src_cameras = np.stack(src_cameras, axis=0) 139 | 140 | static_src_rgbs = [] 141 | static_src_cameras = [] 142 | static_src_masks = [] 143 | 144 | # load src rgb for static view 145 | for st_near_id in static_pose_ids: 146 | src_rgb = ( 147 | imageio.v2.imread(train_rgb_files[st_near_id]).astype(np.float32) 148 | / 255.0 149 | ) 150 | train_pose = train_poses[st_near_id] 151 | train_intrinsics_ = train_intrinsics[st_near_id] 152 | 153 | static_src_rgbs.append(src_rgb) 154 | 155 | # load coarse mask 156 | if self.mask_static and 3 <= st_near_id < self.num_frames - 3: 157 | st_mask_path = os.path.join( 158 | '/'.join(train_rgb_files[st_near_id].split('/')[:-2]), 159 | 'coarse_masks', 160 | '%05d.png' % st_near_id, 161 | ) 162 | st_mask = imageio.v2.imread(st_mask_path).astype(np.float32) / 255.0 163 | st_mask = cv2.resize( 164 | st_mask, 165 | (src_rgb.shape[1], src_rgb.shape[0]), 166 | interpolation=cv2.INTER_NEAREST, 167 | ) 168 | else: 169 | st_mask = np.ones_like(src_rgb[..., 0]) 170 | 171 | static_src_masks.append(st_mask) 172 | 173 | img_size = src_rgb.shape[:2] 174 | src_camera = np.concatenate( 175 | (list(img_size), train_intrinsics_.flatten(), train_pose.flatten()) 176 | ).astype(np.float32) 177 | 178 | static_src_cameras.append(src_camera) 179 | 180 | static_src_rgbs = np.stack(static_src_rgbs, axis=0) 181 | static_src_cameras = np.stack(static_src_cameras, axis=0) 182 | static_src_masks = np.stack(static_src_masks, axis=0) 183 | 184 | depth_range = torch.tensor([depth_range[0] * 0.9, depth_range[1] * 1.5]) 185 | 186 | return { 187 | 'camera': torch.from_numpy(camera), 188 | 'rgb_path': gt_img_path, 189 | 'src_rgbs': torch.from_numpy(src_rgbs[..., :3]).float(), 190 | 'src_cameras': torch.from_numpy(src_cameras).float(), 191 | 'static_src_rgbs': torch.from_numpy(static_src_rgbs[..., :3]).float(), 192 | 'static_src_cameras': torch.from_numpy(static_src_cameras).float(), 193 | 'static_src_masks': torch.from_numpy(static_src_masks).float(), 194 | 'depth_range': depth_range, 195 | 'ref_time': float(self.render_idx / float(self.num_frames)), 196 | 'id': self.render_idx, 197 | 'nearest_pose_ids': nearest_pose_ids, 198 | } 199 | 200 | 201 | def calculate_psnr(img1, img2, mask): 202 | """Compute PSNR between two images. 203 | 204 | Args: 205 | img1: image 1 206 | img2: image 2 207 | mask: mask indicating which region is valid. 208 | 209 | Returns: 210 | PSNR: PSNR error 211 | """ 212 | 213 | # img1 and img2 have range [0, 1] 214 | img1 = img1.astype(np.float64) 215 | img2 = img2.astype(np.float64) 216 | mask = mask.astype(np.float64) 217 | 218 | num_valid = np.sum(mask) + 1e-8 219 | 220 | mse = np.sum((img1 - img2) ** 2 * mask) / num_valid 221 | 222 | if mse == 0: 223 | return 0 # float('inf') 224 | 225 | return 10 * math.log10(1.0 / mse) 226 | 227 | 228 | def calculate_ssim(img1, img2, mask): 229 | """Compute SSIM between two images. 230 | 231 | Args: 232 | img1: image 1 233 | img2: image 2 234 | mask: mask indicating which region is valid. 
235 | 236 | Returns: 237 | SSIM: SSIM value 238 | """ 239 | if img1.shape != img2.shape: 240 | raise ValueError('Input images must have the same dimensions.') 241 | 242 | _, ssim_map = skimage.metrics.structural_similarity( 243 | img1, img2, multichannel=True, full=True 244 | ) 245 | num_valid = np.sum(mask) + 1e-8 246 | 247 | return np.sum(ssim_map * mask) / num_valid 248 | 249 | 250 | def im2tensor(image, cent=1.0, factor=1.0 / 2.0): 251 | """Convert an image to a PyTorch tensor. 252 | 253 | Args: 254 | image: input image 255 | cent: shift 256 | factor: scale 257 | 258 | Returns: 259 | PyTorch tensor 260 | """ 261 | return torch.Tensor( 262 | (image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1)) 263 | ) 264 | 265 | 266 | if __name__ == '__main__': 267 | parser = config_parser() 268 | args = parser.parse_args() 269 | args.distributed = False 270 | # Construct a dataset to get the number of frames for evaluation 271 | test_dataset = DynamicVideoDataset(0, args, scenes=args.eval_scenes) 272 | args.num_frames = test_dataset.num_frames 273 | print('args.num_frames ', args.num_frames) 274 | # Create the DynibarFF model 275 | model = DynibarFF(args, load_scheduler=False, load_opt=False) 276 | eval_dataset_name = args.eval_dataset 277 | # extra_out_dir = '{}/{}'.format(eval_dataset_name, args.expname) 278 | # print('saving results to {}...'.format(extra_out_dir)) 279 | # os.makedirs(extra_out_dir, exist_ok=True) 280 | 281 | projector = Projector(device='cuda:0') 282 | 283 | assert len(args.eval_scenes) == 1, 'only accept single scene' 284 | scene_name = args.eval_scenes[0] 285 | # out_scene_dir = os.path.join(extra_out_dir, 'renderings') 286 | # print('saving results to {}'.format(out_scene_dir)) 287 | # os.makedirs(out_scene_dir, exist_ok=True) 288 | 289 | lpips_model = models.PerceptualLoss( 290 | model='net-lin', net='alex', use_gpu=True, version=0.1 291 | ) 292 | 293 | psnr_list = [] 294 | ssim_list = [] 295 | lpips_list = [] 296 | 297 | dy_psnr_list = [] 298 | dy_ssim_list = [] 299 | dy_lpips_list = [] 300 | 301 | st_psnr_list = [] 302 | st_ssim_list = [] 303 | st_lpips_list = [] 304 | 305 | for img_i in range(3, args.num_frames - 3): 306 | test_dataset = DynamicVideoDataset(img_i, args, scenes=args.eval_scenes) 307 | save_prefix = scene_name 308 | test_loader = DataLoader( 309 | test_dataset, batch_size=1, num_workers=12, shuffle=False 310 | ) 311 | total_num = len(test_loader) 312 | out_frames = [] 313 | 314 | for i, data in enumerate(test_loader): 315 | print('img_i ', img_i, i) 316 | 317 | if img_i % 12 == i: 318 | continue 319 | 320 | # idx = int(data['id'].item()) 321 | start = time.time() 322 | 323 | ref_time_embedding = data['ref_time'].cuda() 324 | ref_frame_idx = int(data['id'].item()) 325 | ref_time_offset = [ 326 | int(near_idx - ref_frame_idx) 327 | for near_idx in data['nearest_pose_ids'].squeeze().tolist() 328 | ] 329 | 330 | model.switch_to_eval() 331 | with torch.no_grad(): 332 | ray_sampler = RaySamplerSingleImage(data, device='cuda:0') 333 | ray_batch = ray_sampler.get_all() 334 | 335 | cb_featmaps_1, cb_featmaps_2 = model.feature_net( 336 | ray_batch['src_rgbs'].squeeze(0).permute(0, 3, 1, 2) 337 | ) 338 | ref_featmaps = cb_featmaps_1 339 | 340 | static_src_rgbs = ( 341 | ray_batch['static_src_rgbs'].squeeze(0).permute(0, 3, 1, 2) 342 | ) 343 | _, static_featmaps = model.feature_net(static_src_rgbs) 344 | 345 | cb_featmaps_1_fine, _ = model.feature_net_fine( 346 | ray_batch['src_rgbs'].squeeze(0).permute(0, 3, 1, 2) 347 | ) 348 | ref_featmaps_fine = cb_featmaps_1_fine 349 
| 350 | if args.mask_static: 351 | static_src_rgbs_ = ( 352 | static_src_rgbs 353 | * ray_batch['static_src_masks'].squeeze(0)[:, None, ...] 354 | ) 355 | else: 356 | static_src_rgbs_ = static_src_rgbs 357 | 358 | _, static_featmaps_fine = model.feature_net_fine(static_src_rgbs_) 359 | 360 | ret = render_single_image_nvi( 361 | frame_idx=(ref_frame_idx, None), 362 | time_embedding=(ref_time_embedding, None), 363 | time_offset=(ref_time_offset, None), 364 | ray_sampler=ray_sampler, 365 | ray_batch=ray_batch, 366 | model=model, 367 | projector=projector, 368 | chunk_size=args.chunk_size, 369 | det=True, 370 | N_samples=args.N_samples, 371 | args=args, 372 | inv_uniform=args.inv_uniform, 373 | N_importance=args.N_importance, 374 | white_bkgd=args.white_bkgd, 375 | coarse_featmaps=(ref_featmaps, None, static_featmaps), 376 | fine_featmaps=(ref_featmaps_fine, None, static_featmaps_fine), 377 | is_train=False, 378 | ) 379 | 380 | fine_pred_rgb = ret['outputs_fine_ref']['rgb'].detach().cpu().numpy() 381 | fine_pred_depth = ret['outputs_fine_ref']['depth'].detach().cpu().numpy() 382 | 383 | valid_mask = np.float32( 384 | np.sum(fine_pred_rgb, axis=-1, keepdims=True) > 1e-3 385 | ) 386 | valid_mask = np.tile(valid_mask, (1, 1, 3)) 387 | gt_img = cv2.imread(data['rgb_path'][0])[:, :, ::-1] 388 | gt_img = cv2.resize( 389 | gt_img, 390 | dsize=(fine_pred_rgb.shape[1], fine_pred_rgb.shape[0]), 391 | interpolation=cv2.INTER_AREA, 392 | ) 393 | gt_img = np.float32(gt_img) / 255 394 | 395 | gt_img = gt_img * valid_mask 396 | fine_pred_rgb = fine_pred_rgb * valid_mask 397 | 398 | dynamic_mask = valid_mask 399 | ssim = calculate_ssim(gt_img, fine_pred_rgb, dynamic_mask) 400 | psnr = calculate_psnr(gt_img, fine_pred_rgb, dynamic_mask) 401 | 402 | gt_img_0 = im2tensor(gt_img).cuda() 403 | fine_pred_rgb_0 = im2tensor(fine_pred_rgb).cuda() 404 | dynamic_mask_0 = torch.Tensor( 405 | dynamic_mask[:, :, :, np.newaxis].transpose((3, 2, 0, 1)) 406 | ) 407 | 408 | lpips = lpips_model.forward( 409 | gt_img_0, fine_pred_rgb_0, dynamic_mask_0 410 | ).item() 411 | print(psnr, ssim, lpips) 412 | psnr_list.append(psnr) 413 | ssim_list.append(ssim) 414 | lpips_list.append(lpips) 415 | 416 | dynamic_mask_path = os.path.join( 417 | test_dataset.scene_path, 418 | 'mv_masks', 419 | '%05d' % img_i, 420 | 'cam%02d.png' % (i + 1), 421 | ) 422 | 423 | dynamic_mask = np.float32(cv2.imread(dynamic_mask_path) > 1e-3) # /255. 
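# The same PSNR / SSIM / LPIPS metrics are computed three ways: over the full valid
# region above, over the dynamic-object mask loaded here, and over its complement
# (the static region) further below.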
424 | dynamic_mask = cv2.resize( 425 | dynamic_mask, 426 | dsize=(gt_img.shape[1], gt_img.shape[0]), 427 | interpolation=cv2.INTER_NEAREST, 428 | ) 429 | 430 | dynamic_mask_0 = torch.Tensor( 431 | dynamic_mask[:, :, :, np.newaxis].transpose((3, 2, 0, 1)) 432 | ) 433 | dynamic_ssim = calculate_ssim(gt_img, fine_pred_rgb, dynamic_mask) 434 | dynamic_psnr = calculate_psnr(gt_img, fine_pred_rgb, dynamic_mask) 435 | dynamic_lpips = lpips_model.forward( 436 | gt_img_0, fine_pred_rgb_0, dynamic_mask_0 437 | ).item() 438 | print(dynamic_psnr, dynamic_ssim, dynamic_lpips) 439 | 440 | dy_psnr_list.append(dynamic_psnr) 441 | dy_ssim_list.append(dynamic_ssim) 442 | dy_lpips_list.append(dynamic_lpips) 443 | 444 | static_mask = 1 - dynamic_mask 445 | static_mask_0 = torch.Tensor( 446 | static_mask[:, :, :, np.newaxis].transpose((3, 2, 0, 1)) 447 | ) 448 | static_ssim = calculate_ssim(gt_img, fine_pred_rgb, static_mask) 449 | static_psnr = calculate_psnr(gt_img, fine_pred_rgb, static_mask) 450 | static_lpips = lpips_model.forward( 451 | gt_img_0, fine_pred_rgb_0, static_mask_0 452 | ).item() 453 | print(static_psnr, static_ssim, static_lpips) 454 | 455 | st_psnr_list.append(static_psnr) 456 | st_ssim_list.append(static_ssim) 457 | st_lpips_list.append(static_lpips) 458 | 459 | print('MOVING PSNR ', np.mean(np.array(psnr_list))) 460 | print('MOVING SSIM ', np.mean(np.array(ssim_list))) 461 | print('MOVING LPIPS ', np.mean(np.array(lpips_list))) 462 | 463 | print('MOVING DYNAMIC PSNR ', np.mean(np.array(dy_psnr_list))) 464 | print('MOVING DYNAMIC SSIM ', np.mean(np.array(dy_ssim_list))) 465 | print('MOVING DYNAMIC LPIPS ', np.mean(np.array(dy_lpips_list))) 466 | 467 | print('MOVING Static PSNR ', np.mean(np.array(st_psnr_list))) 468 | print('MOVING Static SSIM ', np.mean(np.array(st_ssim_list))) 469 | print('MOVING Static LPIPS ', np.mean(np.array(st_lpips_list))) 470 | 471 | print('AVG PSNR ', np.mean(np.array(psnr_list))) 472 | print('AVG SSIM ', np.mean(np.array(ssim_list))) 473 | print('AVG LPIPS ', np.mean(np.array(lpips_list))) 474 | 475 | print('AVG DYNAMIC PSNR ', np.mean(np.array(dy_psnr_list))) 476 | print('AVG DYNAMIC SSIM ', np.mean(np.array(dy_ssim_list))) 477 | print('AVG DYNAMIC LPIPS ', np.mean(np.array(dy_lpips_list))) 478 | 479 | print('AVG Static PSNR ', np.mean(np.array(st_psnr_list))) 480 | print('AVG Static SSIM ', np.mean(np.array(st_ssim_list))) 481 | print('AVG Static LPIPS ', np.mean(np.array(st_lpips_list))) 482 | -------------------------------------------------------------------------------- /ibrnet/mlp_network.py: -------------------------------------------------------------------------------- 1 | """Class definition for MLP Network.""" 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | torch._C._jit_set_profiling_executor(False) 10 | torch._C._jit_set_profiling_mode(False) 11 | 12 | 13 | class ScaledDotProductAttention(nn.Module): 14 | """Dot-Product Attention Layer.""" 15 | 16 | def __init__(self, temperature, attn_dropout=0.1): 17 | super().__init__() 18 | self.temperature = temperature 19 | 20 | def forward(self, q, k, v, mask=None): 21 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 22 | 23 | if mask is not None: 24 | attn = attn.masked_fill(mask == 0, -1e9) 25 | # attn = attn * mask 26 | 27 | attn = F.softmax(attn, dim=-1) 28 | # attn = self.dropout(F.softmax(attn, dim=-1)) 29 | output = torch.matmul(attn, v) 30 | 31 | return output, attn 32 | 33 | 34 | class 
PositionwiseFeedForward(nn.Module): 35 | """A two-feed-forward-layer module.""" 36 | 37 | def __init__(self, d_in, d_hid, dropout=0.1): 38 | super().__init__() 39 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise 40 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise 41 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 42 | # self.dropout = nn.Dropout(dropout) 43 | 44 | def forward(self, x): 45 | residual = x 46 | 47 | x = self.w_2(F.relu(self.w_1(x))) 48 | # x = self.dropout(x) 49 | x += residual 50 | 51 | x = self.layer_norm(x) 52 | 53 | return x 54 | 55 | 56 | class MultiHeadAttention(nn.Module): 57 | """Multi-Head Attention module.""" 58 | 59 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 60 | super().__init__() 61 | 62 | self.n_head = n_head 63 | self.d_k = d_k 64 | self.d_v = d_v 65 | 66 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 67 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 68 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) 69 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False) 70 | 71 | self.attention = ScaledDotProductAttention(temperature=d_k**0.5) 72 | 73 | # self.dropout = nn.Dropout(dropout) 74 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 75 | 76 | def forward(self, q, k, v, mask=None): 77 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 78 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 79 | 80 | residual = q 81 | 82 | # Pass through the pre-attention projection: b x lq x (n*dv) 83 | # Separate different heads: b x lq x n x dv 84 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 85 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 86 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 87 | 88 | # Transpose for attention dot product: b x n x lq x dv 89 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 90 | 91 | if mask is not None: 92 | mask = mask.unsqueeze(1) # For head axis broadcasting. 
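# Scaled dot-product attention per head; masked positions are filled with -1e9 before
# the softmax inside ScaledDotProductAttention, so they receive (near-)zero attention weight.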
93 | 94 | q, attn = self.attention(q, k, v, mask=mask) 95 | 96 | # Transpose to move the head dimension back: b x lq x n x dv 97 | # Combine the last two dimensions to concatenate all the heads together 98 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 99 | q = self.fc(q) 100 | q += residual 101 | 102 | q = self.layer_norm(q) 103 | 104 | return q, attn 105 | 106 | 107 | def weights_init(m): 108 | """Default initialization of linear layers.""" 109 | if isinstance(m, nn.Linear): 110 | nn.init.kaiming_normal_(m.weight.data) 111 | if m.bias is not None: 112 | nn.init.zeros_(m.bias.data) 113 | 114 | 115 | @torch.jit.script 116 | def fused_mean_variance(x, weight): 117 | mean = torch.sum(x * weight, dim=2, keepdim=True) 118 | var = torch.sum(weight * (x - mean) ** 2, dim=2, keepdim=True) 119 | return mean, var 120 | 121 | 122 | @torch.jit.script 123 | def epipolar_fused_mean_variance(x, weight): 124 | mean = torch.sum(x * weight, dim=1, keepdim=True) 125 | var = torch.sum(weight * (x - mean) ** 2, dim=1, keepdim=True) 126 | return mean, var 127 | 128 | 129 | class DynibarDynamic(nn.Module): 130 | """Dynibar time-varying dynamic model.""" 131 | 132 | def __init__(self, args, in_feat_ch=32, n_samples=64, shift=0.0, **kwargs): 133 | super(DynibarDynamic, self).__init__() 134 | self.args = args 135 | self.anti_alias_pooling = False # args.anti_alias_pooling 136 | self.input_dir = args.input_dir 137 | self.input_xyz = args.input_xyz 138 | 139 | if self.anti_alias_pooling: 140 | self.s = nn.Parameter(torch.tensor(0.2), requires_grad=True) 141 | 142 | activation_func = nn.ELU(inplace=True) 143 | self.shift = shift 144 | t_num_freqs = 10 145 | self.t_embed = PeriodicEmbed( 146 | max_freq=t_num_freqs, N_freq=t_num_freqs, linspace=False 147 | ).float() 148 | dir_num_freqs = 4 149 | self.dir_embed = PeriodicEmbed( 150 | max_freq=dir_num_freqs, N_freq=dir_num_freqs, linspace=False 151 | ).float() 152 | 153 | pts_num_freqs = 5 154 | self.pts_embed = PeriodicEmbed( 155 | max_freq=pts_num_freqs, N_freq=pts_num_freqs, linspace=False 156 | ).float() 157 | 158 | self.n_samples = n_samples 159 | self.ray_dir_fc = nn.Sequential( 160 | nn.Linear(t_num_freqs * 2 + 1, 256), 161 | activation_func, 162 | nn.Linear(256, in_feat_ch + 3), 163 | activation_func, 164 | ) 165 | 166 | self.base_fc = nn.Sequential( 167 | nn.Linear((in_feat_ch + 3) * 3, 256), 168 | activation_func, 169 | nn.Linear(256, 128), 170 | activation_func, 171 | ) 172 | 173 | self.vis_fc = nn.Sequential( 174 | nn.Linear(128, 128), 175 | activation_func, 176 | nn.Linear(128, 128 + 1), 177 | activation_func, 178 | ) 179 | 180 | self.vis_fc2 = nn.Sequential( 181 | nn.Linear(128, 128), activation_func, nn.Linear(128, 1), nn.Sigmoid() 182 | ) 183 | 184 | self.geometry_fc = nn.Sequential( 185 | nn.Linear(128 * 2 + 1, 256), 186 | activation_func, 187 | nn.Linear(256, 128), 188 | activation_func, 189 | ) 190 | 191 | self.ray_attention = MultiHeadAttention(4, 128, 32, 32) 192 | 193 | num_c_xyz = (pts_num_freqs * 2 + 1) * 3 194 | 195 | self.ref_pts_fc = nn.Sequential( 196 | nn.Linear(num_c_xyz + 128, 256), 197 | activation_func, 198 | nn.Linear(256, 128), 199 | activation_func, 200 | ) 201 | 202 | self.out_geometry_fc = nn.Sequential( 203 | nn.Linear(128, 128), activation_func, nn.Linear(128, 1) 204 | ) 205 | 206 | if self.input_dir: 207 | self.rgb_fc = nn.Sequential( 208 | nn.Linear(128 + (dir_num_freqs * 2 + 1) * 3, 128), 209 | activation_func, 210 | nn.Linear(128, 64), 211 | activation_func, 212 | nn.Linear(64, 3), 213 | nn.Sigmoid(), 214 | ) 215 | 
else: 216 | raise NotImplementedError 217 | 218 | self.pos_encoding = self.posenc(d_hid=128, n_samples=self.n_samples) 219 | 220 | def posenc(self, d_hid, n_samples): 221 | def get_position_angle_vec(position): 222 | return [ 223 | position / np.power(10000, 2 * (hid_j // 2) / d_hid) 224 | for hid_j in range(d_hid) 225 | ] 226 | 227 | sinusoid_table = np.array( 228 | [get_position_angle_vec(pos_i) for pos_i in range(n_samples)] 229 | ) 230 | 231 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 232 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 233 | sinusoid_table = torch.from_numpy(sinusoid_table).float().unsqueeze(0) 234 | return sinusoid_table 235 | 236 | def forward( 237 | self, pts_xyz, rgb_feat, glb_ray_dir, ray_diff, time_diff, mask, time 238 | ): 239 | num_views = rgb_feat.shape[2] 240 | time_pe = ( 241 | self.t_embed(time)[..., None, :].repeat(1, 1, num_views, 1).float() 242 | ) 243 | 244 | direction_feat = self.ray_dir_fc(time_pe) 245 | 246 | # rgb_in = rgb_feat[..., :3] 247 | rgb_feat = rgb_feat + direction_feat 248 | 249 | if self.anti_alias_pooling: 250 | _, dot_prod = torch.split(ray_diff, [3, 1], dim=-1) 251 | exp_dot_prod = torch.exp(torch.abs(self.s) * (dot_prod - 1)) 252 | weight = ( 253 | exp_dot_prod - torch.min(exp_dot_prod, dim=2, keepdim=True)[0] 254 | ) * mask 255 | weight = weight / (torch.sum(weight, dim=2, keepdim=True) + 1e-8) 256 | else: 257 | weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8) 258 | 259 | # compute mean and variance across different views for each point 260 | mean, var = fused_mean_variance( 261 | rgb_feat, weight 262 | ) # [n_rays, n_samples, 1, n_feat] 263 | globalfeat = torch.cat( 264 | [mean, var], dim=-1 265 | ) # [n_rays, n_samples, 1, 2*n_feat] 266 | 267 | x = torch.cat( 268 | [globalfeat.expand(-1, -1, num_views, -1), rgb_feat], dim=-1 269 | ) # [n_rays, n_samples, n_views, 3*n_feat] 270 | x = self.base_fc(x) 271 | 272 | x_vis = self.vis_fc(x * weight) 273 | x_res, vis = torch.split(x_vis, [x_vis.shape[-1] - 1, 1], dim=-1) 274 | vis = F.sigmoid(vis) * mask 275 | x = x + x_res 276 | vis = self.vis_fc2(x * vis) * mask 277 | weight = vis / (torch.sum(vis, dim=2, keepdim=True) + 1e-8) 278 | 279 | mean, var = fused_mean_variance(x, weight) 280 | globalfeat = torch.cat( 281 | [mean.squeeze(2), var.squeeze(2), weight.mean(dim=2)], dim=-1 282 | ) # [n_rays, n_samples, 32*2+1] 283 | globalfeat = self.geometry_fc(globalfeat) # [n_rays, n_samples, 16] 284 | num_valid_obs = torch.sum(mask, dim=2) 285 | 286 | globalfeat = globalfeat + self.pos_encoding.to(globalfeat.device) 287 | globalfeat, _ = self.ray_attention( 288 | globalfeat, globalfeat, globalfeat, mask=(num_valid_obs > 1).float() 289 | ) # [n_rays, n_samples, 16] 290 | 291 | pts_xyz_pe = self.pts_embed(pts_xyz) 292 | globalfeat = self.ref_pts_fc(torch.cat([globalfeat, pts_xyz_pe], dim=-1)) 293 | 294 | sigma = ( 295 | self.out_geometry_fc(globalfeat) - self.shift 296 | ) # [n_rays, n_samples, 1] 297 | sigma_out = sigma.masked_fill( 298 | num_valid_obs < 1, -1e9 299 | ) # set the sigma of invalid point to zero 300 | 301 | if self.input_dir: 302 | glb_ray_dir_pe = self.dir_embed(glb_ray_dir).float() 303 | h = torch.cat( 304 | [ 305 | globalfeat, 306 | glb_ray_dir_pe[:, None, :].repeat(1, globalfeat.shape[1], 1), 307 | ], 308 | dim=-1, 309 | ) 310 | else: 311 | h = globalfeat 312 | 313 | rgb_out = self.rgb_fc(h) 314 | rgb_out = rgb_out.masked_fill(torch.sum(mask.repeat(1, 1, 1, 3), 2) == 0, 0) 315 | out = torch.cat([rgb_out, sigma_out], 
dim=-1) 316 | return out 317 | 318 | 319 | class DynibarStatic(nn.Module): 320 | """Dynibar time-invariant static model.""" 321 | 322 | def __init__(self, args, in_feat_ch=32, n_samples=64, **kwargs): 323 | super(DynibarStatic, self).__init__() 324 | self.args = args 325 | self.anti_alias_pooling = args.anti_alias_pooling # CHECK DISCREPENCY 326 | self.mask_rgb = args.mask_rgb 327 | self.input_dir = args.input_dir 328 | self.input_xyz = args.input_xyz 329 | 330 | if self.anti_alias_pooling: 331 | self.s = nn.Parameter(torch.tensor(0.2), requires_grad=True) 332 | 333 | activation_func = nn.ELU(inplace=True) 334 | 335 | ray_num_freqs = 5 336 | self.ray_embed = PeriodicEmbed( 337 | max_freq=ray_num_freqs, N_freq=ray_num_freqs, linspace=False 338 | ) 339 | pts_num_freqs = 5 340 | self.pts_embed = PeriodicEmbed( 341 | max_freq=pts_num_freqs, N_freq=pts_num_freqs, linspace=False 342 | ) 343 | 344 | num_c_xyz = (pts_num_freqs * 2 + 1) * 3 345 | num_c_ray = (ray_num_freqs * 2 + 1) * 6 346 | 347 | self.n_samples = n_samples 348 | 349 | self.ray_dir_fc = nn.Sequential( 350 | nn.Linear(4 + num_c_xyz + num_c_ray, 256), 351 | activation_func, 352 | nn.Linear(256, in_feat_ch + 3), 353 | ) 354 | 355 | self.ref_feature_fc = nn.Sequential(nn.Linear(num_c_ray, in_feat_ch + 3)) 356 | 357 | self.base_fc = nn.Sequential( 358 | nn.Linear((in_feat_ch + 3) * 6, 256), 359 | activation_func, 360 | nn.Linear(256, 128), 361 | activation_func, 362 | ) 363 | 364 | self.vis_fc = nn.Sequential( 365 | nn.Linear(128, 128), 366 | activation_func, 367 | nn.Linear(128, 128 + 1), 368 | activation_func, 369 | ) 370 | 371 | self.vis_fc2 = nn.Sequential( 372 | nn.Linear(128, 128), activation_func, nn.Linear(128, 1), nn.Sigmoid() 373 | ) 374 | 375 | self.geometry_fc = nn.Sequential( 376 | nn.Linear(128 * 2 + 1, 256), 377 | activation_func, 378 | nn.Linear(256, 128), 379 | activation_func, 380 | ) 381 | 382 | self.ray_attention = MultiHeadAttention(4, 128, 32, 32) 383 | self.out_geometry_fc = nn.Sequential( 384 | nn.Linear(128, 128), activation_func, nn.Linear(128, 1) 385 | ) 386 | 387 | if self.input_dir: 388 | self.rgb_fc = nn.Sequential( 389 | nn.Linear(128 * 2 + 1 + 4, 128), 390 | activation_func, 391 | nn.Linear(128, 64), 392 | activation_func, 393 | nn.Linear(64, 1), 394 | ) 395 | 396 | else: 397 | self.rgb_fc = nn.Sequential( 398 | nn.Linear(32 + 1, 32), 399 | activation_func, 400 | nn.Linear(32, 16), 401 | activation_func, 402 | nn.Linear(16, 1), 403 | ) 404 | 405 | self.pos_encoding = self.posenc(d_hid=128, n_samples=self.n_samples) 406 | 407 | def posenc(self, d_hid, n_samples): 408 | def get_position_angle_vec(position): 409 | return [ 410 | position / np.power(10000, 2 * (hid_j // 2) / d_hid) 411 | for hid_j in range(d_hid) 412 | ] 413 | 414 | sinusoid_table = np.array( 415 | [get_position_angle_vec(pos_i) for pos_i in range(n_samples)] 416 | ) 417 | 418 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 419 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 420 | sinusoid_table = torch.from_numpy(sinusoid_table).float().unsqueeze(0) 421 | return sinusoid_table 422 | 423 | def forward( 424 | self, 425 | pts, 426 | ref_rays_coords, 427 | src_rays_coords, 428 | rgb_feat, 429 | glb_ray_dir, 430 | ray_diff, 431 | mask, 432 | ): 433 | num_views = rgb_feat.shape[2] 434 | ref_rays_pe = self.ray_embed(ref_rays_coords) 435 | src_rays_pe = self.ray_embed(src_rays_coords) 436 | pts_pe = self.pts_embed(pts) 437 | 438 | ref_features = ref_rays_pe[:, None, None, :].expand( 439 | -1, 
src_rays_pe.shape[1], src_rays_pe.shape[2], -1 440 | ) 441 | src_features = torch.cat( 442 | [ 443 | pts_pe.unsqueeze(2).expand(-1, -1, src_rays_pe.shape[2], -1), 444 | src_rays_pe, 445 | ], 446 | dim=-1, 447 | ) 448 | 449 | src_feat = self.ray_dir_fc(torch.cat([src_features, ray_diff], dim=-1)) 450 | ref_feat = self.ref_feature_fc(ref_features) 451 | 452 | rgb_in = rgb_feat[..., :3] 453 | 454 | if self.mask_rgb: 455 | rgb_in_sum = torch.sum(rgb_in, dim=-1, keepdim=True) 456 | rgb_mask = (rgb_in_sum > 1e-3).float().detach() 457 | mask = mask * rgb_mask 458 | 459 | rgb_feat = torch.cat([rgb_feat, src_feat * ref_feat], dim=-1) 460 | 461 | if self.anti_alias_pooling: 462 | _, dot_prod = torch.split(ray_diff, [3, 1], dim=-1) 463 | exp_dot_prod = torch.exp(torch.abs(self.s) * (dot_prod - 1)) 464 | weight = ( 465 | exp_dot_prod - torch.min(exp_dot_prod, dim=2, keepdim=True)[0] 466 | ) * mask 467 | weight = weight / (torch.sum(weight, dim=2, keepdim=True) + 1e-8) 468 | else: 469 | weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8) 470 | 471 | # compute mean and variance across different views for each point 472 | mean, var = fused_mean_variance( 473 | rgb_feat, weight 474 | ) # [n_rays, n_samples, 1, n_feat] 475 | globalfeat = torch.cat( 476 | [mean, var], dim=-1 477 | ) # [n_rays, n_samples, 1, 2*n_feat] 478 | 479 | x = torch.cat( 480 | [globalfeat.expand(-1, -1, num_views, -1), rgb_feat], dim=-1 481 | ) # [n_rays, n_samples, n_views, 3*n_feat] 482 | 483 | x = self.base_fc(x) 484 | 485 | x_vis = self.vis_fc(x * weight) 486 | x_res, vis = torch.split(x_vis, [x_vis.shape[-1] - 1, 1], dim=-1) 487 | vis = F.sigmoid(vis) * mask 488 | x = x + x_res 489 | vis = self.vis_fc2(x * vis) * mask 490 | weight = vis / (torch.sum(vis, dim=2, keepdim=True) + 1e-8) 491 | 492 | mean, var = fused_mean_variance(x, weight) 493 | globalfeat = torch.cat( 494 | [mean.squeeze(2), var.squeeze(2), weight.mean(dim=2)], dim=-1 495 | ) # [n_rays, n_samples, 32*2+1] 496 | globalfeat = self.geometry_fc(globalfeat) # [n_rays, n_samples, 16] 497 | num_valid_obs = torch.sum(mask, dim=2) 498 | 499 | # globalfeat = globalfeat #+ self.pos_encoding.to(globalfeat.device) 500 | globalfeat, _ = self.ray_attention( 501 | globalfeat, globalfeat, globalfeat, mask=(num_valid_obs > 1).float() 502 | ) # [n_rays, n_samples, 16] 503 | sigma = self.out_geometry_fc(globalfeat) # [n_rays, n_samples, 1] 504 | sigma_out = sigma.masked_fill( 505 | num_valid_obs < 1, -1e9 506 | ) # set the sigma of invalid point to zero 507 | 508 | if self.input_dir: 509 | x = torch.cat( 510 | [ 511 | globalfeat[:, :, None, :].expand(-1, -1, x.shape[2], -1), 512 | x, 513 | vis, 514 | ray_diff, 515 | ], 516 | dim=-1, 517 | ) 518 | else: 519 | x = torch.cat([globalfeat, vis], dim=-1) 520 | 521 | x = self.rgb_fc(x) 522 | 523 | x = x.masked_fill(mask == 0, -1e9) 524 | blending_weights_valid = F.softmax(x, dim=2) # color blending 525 | rgb_out = torch.sum(rgb_in * blending_weights_valid, dim=2) 526 | out = torch.cat([rgb_out, sigma_out], dim=-1) 527 | return out 528 | 529 | 530 | class PeriodicEmbed(nn.Module): 531 | """Fourier Position encoding module.""" 532 | 533 | def __init__(self, max_freq, N_freq, linspace=True): 534 | """Init function for position encoding. 
535 | 536 | Args: 537 | max_freq: maximum frequency band 538 | N_freq: number of frequency bands 539 | linspace: whether to space the frequencies linearly 540 | """ 541 | super().__init__() 542 | self.embed_functions = [torch.cos, torch.sin] 543 | if linspace: 544 | self.freqs = torch.linspace(1, max_freq + 1, steps=N_freq) 545 | else: 546 | exps = torch.linspace(0, N_freq - 1, steps=N_freq) 547 | self.freqs = 2**exps 548 | 549 | def forward(self, x): 550 | output = [x] 551 | for f in self.embed_functions: 552 | for freq in self.freqs: 553 | output.append(f(freq * x)) 554 | 555 | return torch.cat(output, -1) 556 | 557 | 558 | class MotionMLP(nn.Module): 559 | """Motion trajectory MLP module.""" 560 | 561 | def __init__( 562 | self, 563 | num_basis=4, 564 | D=8, 565 | W=256, 566 | input_ch=4, 567 | num_freqs=16, 568 | skips=[4], 569 | sf_mag_div=1.0, 570 | ): 571 | """Init function for the motion MLP. 572 | 573 | Args: 574 | num_basis: number of motion basis functions 575 | D: number of MLP layers 576 | W: feature dimension of the MLP layers 577 | input_ch: number of input channels 578 | num_freqs: number of frequencies for positional encoding 579 | skips: layers at which to inject skip connections 580 | sf_mag_div: motion scaling factor 581 | """ 582 | super(MotionMLP, self).__init__() 583 | self.D = D 584 | self.W = W 585 | self.input_ch = int(input_ch + input_ch * num_freqs * 2) 586 | self.skips = skips 587 | self.sf_mag_div = sf_mag_div 588 | 589 | self.xyzt_embed = PeriodicEmbed(max_freq=num_freqs, N_freq=num_freqs) 590 | 591 | self.pts_linears = nn.ModuleList( 592 | [nn.Linear(self.input_ch, W)] 593 | + [ 594 | nn.Linear(W, W) 595 | if i not in self.skips 596 | else nn.Linear(W + self.input_ch, W) 597 | for i in range(D - 1) 598 | ] 599 | ) 600 | 601 | self.coeff_linear = nn.Linear(W, num_basis * 3) 602 | self.coeff_linear.weight.data.fill_(0.0) 603 | self.coeff_linear.bias.data.fill_(0.0) 604 | 605 | def forward(self, x): 606 | input_pts = self.xyzt_embed(x) 607 | 608 | h = input_pts 609 | for i, l in enumerate(self.pts_linears): 610 | h = self.pts_linears[i](h) 611 | h = F.relu(h) 612 | if i in self.skips: 613 | h = torch.cat([input_pts, h], -1) 614 | 615 | # sf = nn.functional.tanh(self.sf_linear(h)) 616 | pred_coeff = self.coeff_linear(h) 617 | 618 | return pred_coeff / self.sf_mag_div 619 | --------------------------------------------------------------------------------
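For reference, below is a minimal shape-check sketch for the two modules defined in ibrnet/mlp_network.py (PeriodicEmbed and MotionMLP). It assumes ibrnet is importable in the same way eval_nvidia.py imports it; the batch size, num_basis, and frequency counts are arbitrary illustration values, not settings taken from the released configs.

import torch

from ibrnet.mlp_network import MotionMLP, PeriodicEmbed

# PeriodicEmbed appends cos(freq * x) and sin(freq * x) for each of its N_freq
# frequencies, so a C-channel input grows to C * (1 + 2 * N_freq) channels.
embed = PeriodicEmbed(max_freq=4, N_freq=4, linspace=False)
pts = torch.rand(1024, 3)            # e.g. 3D sample points
print(embed(pts).shape)              # torch.Size([1024, 27]) == 3 * (1 + 2 * 4)

# MotionMLP embeds a 4-channel (x, y, z, t) input with PeriodicEmbed internally and
# predicts num_basis * 3 trajectory-basis coefficients per point. The output is all
# zeros right after construction, since coeff_linear is zero-initialized.
motion_mlp = MotionMLP(num_basis=6, input_ch=4, num_freqs=16)
xyzt = torch.rand(1024, 4)
coeff = motion_mlp(xyzt)
print(coeff.shape)                   # torch.Size([1024, 18]) == 6 * 3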