├── data ├── __init__.py ├── get_datasets.py ├── nerf.py ├── llff.py ├── dtu.py ├── klevr.py └── scannet.py ├── model ├── __init__.py ├── UNet.py ├── geo_reasoner.py └── self_attn_renderer.py ├── utils ├── __init__.py ├── optimizer.py ├── loss.py ├── metrics.py ├── depth_map.py ├── options.py ├── klevr_utils.py ├── rendering.py ├── depth_loss.py └── scannet_utils.py ├── img └── cvpr_poster_final_5120.png ├── configs ├── lists │ ├── replica_test_split.txt │ ├── scannet_test_split.txt │ ├── replica_train_split.txt │ └── scannet_train_split.txt ├── replica.txt └── scannet.txt ├── requirements.txt └── README.md /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /img/cvpr_poster_final_5120.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimChou-ntu/GSNeRF/HEAD/img/cvpr_poster_final_5120.png -------------------------------------------------------------------------------- /configs/lists/replica_test_split.txt: -------------------------------------------------------------------------------- 1 | replica/office_4/Sequence_1/black_320 2 | replica/office_4/Sequence_2/black_320 3 | replica/room_2/Sequence_1/black_320 4 | replica/room_2/Sequence_2/black_320 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning 2 | inplace_abn 3 | imageio 4 | pillow 5 | scikit-image 6 | opencv-python 7 | ConfigArgParse 8 | lpips 9 | kornia 10 | ipdb 11 | scikit-learn 12 | pandas 13 | natsort 14 | segmentation-models-pytorch 15 | wandb -------------------------------------------------------------------------------- /configs/lists/scannet_test_split.txt: -------------------------------------------------------------------------------- 1 | scannet/scene0063_00/black_320 2 | scannet/scene0067_00/black_320 3 | scannet/scene0071_00/black_320 4 | scannet/scene0074_00/black_320 5 | scannet/scene0079_00/black_320 6 | scannet/scene0086_00/black_320 7 | scannet/scene0200_00/black_320 8 | scannet/scene0211_00/black_320 9 | scannet/scene0226_00/black_320 10 | scannet/scene0376_02/black_320 11 | -------------------------------------------------------------------------------- /configs/lists/replica_train_split.txt: -------------------------------------------------------------------------------- 1 | replica/office_0/Sequence_1/black_320 2 | replica/office_0/Sequence_2/black_320 3 | replica/office_1/Sequence_1/black_320 4 | replica/office_1/Sequence_2/black_320 5 | replica/office_2/Sequence_1/black_320 6 | replica/office_2/Sequence_2/black_320 7 | replica/office_3/Sequence_1/black_320 8 | replica/office_3/Sequence_2/black_320 9 | replica/room_0/Sequence_1/black_320 10 | replica/room_0/Sequence_2/black_320 11 | replica/room_1/Sequence_1/black_320 12 | replica/room_1/Sequence_2/black_320 13 | -------------------------------------------------------------------------------- /configs/replica.txt: 
-------------------------------------------------------------------------------- 1 | ### INPUT 2 | expname = replica 3 | logdir = ./logs_replica/ 4 | nb_views = 8 5 | 6 | ### number of class + 1 7 | nb_class = 20 8 | ignore_label = 19 9 | 10 | ## model 11 | using_semantic_global_tokens = 1 12 | only_using_semantic_global_tokens = 0 13 | use_depth_refine_net = False 14 | 15 | ## dataset 16 | dataset_name = replica 17 | replica_path = "/mnt/sdb/timothy/Desktop/2023Spring/Semantic-Ray/data/replica" 18 | scene = None 19 | val_set_list = "configs/lists/replica_val_split.txt" 20 | 21 | ### TESTING 22 | chunk = 4096 ### Reduce it to save memory 23 | 24 | ### TRAINING 25 | ### num_steps = 250000 26 | num_steps = 300000 27 | lrate = 0.0005 28 | logger = wandb 29 | batch_size = 1024 30 | two_stage_training_steps = 0 31 | cross_entropy_weight = 0.5 32 | background_weight = 0.8 33 | use_batch_semantic_feature = True 34 | feat_net = smp_UNet -------------------------------------------------------------------------------- /configs/scannet.txt: -------------------------------------------------------------------------------- 1 | ### INPUT 2 | expname = scannet 3 | logdir = ./logs_scannet/ 4 | nb_views = 8 5 | 6 | ### number of class + 1 7 | nb_class = 21 8 | ignore_label = 20 9 | 10 | ## model 11 | using_semantic_global_tokens = 1 12 | only_using_semantic_global_tokens = 0 13 | use_depth_refine_net = False 14 | feat_net = smp_UNet 15 | 16 | ## dataset 17 | dataset_name = scannet 18 | scannet_path = "/mnt/sdb/timothy/Desktop/2023Spring/Semantic-Ray/data/scannet" 19 | scene = None 20 | val_set_list = "configs/lists/scannet_test_split.txt" 21 | 22 | ### TESTING 23 | chunk = 4096 ### Reduce it to save memory 24 | 25 | ### TRAINING 26 | ### num_steps = 250000 27 | num_steps = 300000 28 | lrate = 0.0005 29 | logger = wandb 30 | batch_size = 1024 31 | two_stage_training_steps = 0 32 | cross_entropy_weight = 0.5 33 | background_weight = 1.0 34 | use_batch_semantic_feature = True -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.optim import SGD, Adam 3 | from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR 4 | 5 | def get_optimizer(hparams, models): 6 | eps = 1e-5 7 | parameters = [] 8 | for model in models: 9 | parameters += list(model.parameters()) 10 | if hparams.optimizer == 'sgd': 11 | optimizer = SGD(parameters, lr=hparams.lrate, 12 | weight_decay=1e-5) 13 | elif hparams.optimizer == 'adam': 14 | optimizer = Adam(parameters, lr=hparams.lrate, 15 | betas=(0.9, 0.999), 16 | weight_decay=1e-5) 17 | else: 18 | raise ValueError('optimizer not recognized!') 19 | 20 | return optimizer 21 | 22 | def get_scheduler(hparams, optimizer): 23 | eps = 1e-5 24 | # if hparams.lr_scheduler == 'steplr': 25 | # scheduler = MultiStepLR(optimizer, milestones=hparams.decay_step, 26 | # gamma=hparams.decay_gamma) 27 | # elif hparams.lr_scheduler == 'cosine': 28 | # scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_steps, eta_min=eps) 29 | 30 | # else: 31 | # raise ValueError('scheduler not recognized!') 32 | num_steps = hparams.num_steps 33 | if hparams.ddp: 34 | num_steps /= 8 35 | return CosineAnnealingLR(optimizer, T_max=num_steps, eta_min=eps) -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | ### from semray 2 | 3 | import 
torch 4 | import torch.nn as nn 5 | 6 | 7 | class Loss: 8 | def __init__(self, keys): 9 | """ 10 | keys are used in multi-gpu model, DummyLoss in train_tools.py 11 | :param keys: the output keys of the dict 12 | """ 13 | self.keys = keys 14 | 15 | def __call__(self, data_pr, data_gt, **kwargs): 16 | pass 17 | 18 | class SemanticLoss(Loss): 19 | def __init__(self, nb_class, ignore_label, weight=None): 20 | super().__init__(['loss_semantic']) 21 | self.nb_class = nb_class 22 | self.ignore_label = ignore_label 23 | self.weight = weight 24 | 25 | def __call__(self, data_pr, data_gt, **kwargs): 26 | def compute_loss(label_pr, label_gt): 27 | label_pr = label_pr.reshape(-1, self.nb_class) 28 | label_gt = label_gt.reshape(-1).long() 29 | valid_mask = (label_gt != self.ignore_label) 30 | label_pr = label_pr[valid_mask] 31 | label_gt = label_gt[valid_mask] 32 | if self.weight != None: 33 | self.weight = self.weight.to(label_pr.device) 34 | return nn.functional.cross_entropy(label_pr, label_gt, reduction='mean', weight=self.weight).unsqueeze(0) 35 | else: 36 | return nn.functional.cross_entropy(label_pr, label_gt, reduction='mean').unsqueeze(0) 37 | 38 | loss = compute_loss(data_pr, data_gt) 39 | 40 | return loss -------------------------------------------------------------------------------- /configs/lists/scannet_train_split.txt: -------------------------------------------------------------------------------- 1 | scannet/scene0000_00/black_320 2 | scannet/scene0001_00/black_320 3 | scannet/scene0002_00/black_320 4 | scannet/scene0003_00/black_320 5 | scannet/scene0004_00/black_320 6 | scannet/scene0005_00/black_320 7 | scannet/scene0006_00/black_320 8 | scannet/scene0007_00/black_320 9 | scannet/scene0008_00/black_320 10 | scannet/scene0009_00/black_320 11 | scannet/scene0010_00/black_320 12 | scannet/scene0011_00/black_320 13 | scannet/scene0012_00/black_320 14 | scannet/scene0013_00/black_320 15 | scannet/scene0014_00/black_320 16 | scannet/scene0015_00/black_320 17 | scannet/scene0016_00/black_320 18 | scannet/scene0017_00/black_320 19 | scannet/scene0018_00/black_320 20 | scannet/scene0019_00/black_320 21 | scannet/scene0020_00/black_320 22 | scannet/scene0021_00/black_320 23 | scannet/scene0022_00/black_320 24 | scannet/scene0023_00/black_320 25 | scannet/scene0024_00/black_320 26 | scannet/scene0025_00/black_320 27 | scannet/scene0026_00/black_320 28 | scannet/scene0027_00/black_320 29 | scannet/scene0028_00/black_320 30 | scannet/scene0029_00/black_320 31 | scannet/scene0030_00/black_320 32 | scannet/scene0031_00/black_320 33 | scannet/scene0032_00/black_320 34 | scannet/scene0033_00/black_320 35 | scannet/scene0034_00/black_320 36 | scannet/scene0035_00/black_320 37 | scannet/scene0036_00/black_320 38 | scannet/scene0037_00/black_320 39 | scannet/scene0038_00/black_320 40 | scannet/scene0039_00/black_320 41 | scannet/scene0040_00/black_320 42 | scannet/scene0041_00/black_320 43 | scannet/scene0042_00/black_320 44 | scannet/scene0043_00/black_320 45 | scannet/scene0044_00/black_320 46 | scannet/scene0045_00/black_320 47 | scannet/scene0046_00/black_320 48 | scannet/scene0047_00/black_320 49 | scannet/scene0048_00/black_320 50 | scannet/scene0049_00/black_320 51 | scannet/scene0050_00/black_320 52 | scannet/scene0051_00/black_320 53 | scannet/scene0052_00/black_320 54 | scannet/scene0053_00/black_320 55 | scannet/scene0054_00/black_320 56 | scannet/scene0055_00/black_320 57 | scannet/scene0056_00/black_320 58 | scannet/scene0057_00/black_320 59 | scannet/scene0058_00/black_320 60 | 
scannet/scene0059_00/black_320 61 | scannet/scene0060_00/black_320 -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.metrics import confusion_matrix 4 | 5 | def nanmean(data, **args): 6 | # This makes it ignore the first 'background' class 7 | return np.ma.masked_array(data, np.isnan(data)).mean(**args) 8 | # In np.ma.masked_array(data, np.isnan(data)), NaN elements of data are marked invalid and are ignored when np.mean() is computed 9 | 10 | 11 | def calculate_segmentation_metrics(true_labels, predicted_labels, number_classes, ignore_label=-1): 12 | ''' 13 | return: 14 | miou: the miou of all classes (without nan classes / not existing classes) 15 | valid_miou: the miou of valid classes (without ignore_class and not existing classes) 16 | total_accuracy: the overall accuracy over all valid pixels 17 | class_average_accuracy: the average accuracy over all classes 18 | ious: per class iou 19 | ''' 20 | 21 | np.seterr(divide='ignore', invalid='ignore') 22 | if (true_labels == ignore_label).all(): 23 | return [0]*5 24 | 25 | true_labels = true_labels.flatten().cpu().numpy() 26 | predicted_labels = predicted_labels.flatten().cpu().numpy() 27 | valid_pix_ids = true_labels!=ignore_label 28 | predicted_labels = predicted_labels[valid_pix_ids] 29 | true_labels = true_labels[valid_pix_ids] 30 | 31 | conf_mat = confusion_matrix(true_labels, predicted_labels, labels=list(range(number_classes))) 32 | norm_conf_mat = np.transpose( 33 | np.transpose(conf_mat) / conf_mat.astype(float).sum(axis=1)) 34 | 35 | missing_class_mask = np.isnan(norm_conf_mat.sum(1)) # missing class will have NaN at corresponding class 36 | exsiting_class_mask = ~ missing_class_mask 37 | 38 | class_average_accuracy = nanmean(np.diagonal(norm_conf_mat)) 39 | total_accuracy = (np.sum(np.diagonal(conf_mat)) / np.sum(conf_mat)) 40 | ious = np.zeros(number_classes) 41 | for class_id in range(number_classes): 42 | ious[class_id] = (conf_mat[class_id, class_id] / ( 43 | np.sum(conf_mat[class_id, :]) + np.sum(conf_mat[:, class_id]) - 44 | conf_mat[class_id, class_id])) 45 | miou = nanmean(ious) 46 | miou_valid_class = np.mean(ious[exsiting_class_mask]) 47 | return miou, miou_valid_class, total_accuracy, class_average_accuracy, ious 48 | 49 | 50 | # From https://github.com/Harry-Zhi/semantic_nerf/blob/a0113bb08dc6499187c7c48c3f784c2764b8abf1/SSR/training/training_utils.py 51 | class IoU(): 52 | 53 | def __init__(self, ignore_label=-1, num_classes=20): 54 | self.ignore_label = ignore_label 55 | self.num_classes = num_classes 56 | 57 | def __call__(self, true_labels, predicted_labels): 58 | np.seterr(divide='ignore', invalid='ignore') 59 | true_labels = true_labels.long().detach().cpu().numpy() 60 | predicted_labels = predicted_labels.long().detach().cpu().numpy() 61 | 62 | if self.ignore_label != -1: 63 | valid_pix_ids = true_labels != self.ignore_label 64 | else: 65 | valid_pix_ids = np.ones_like(true_labels, dtype=bool) 66 | 67 | num_classes = self.num_classes 68 | predicted_labels = predicted_labels[valid_pix_ids] 69 | true_labels = true_labels[valid_pix_ids] 70 | 71 | conf_mat = confusion_matrix( 72 | true_labels, predicted_labels, labels=list(range(num_classes))) 73 | norm_conf_mat = np.transpose(np.transpose( 74 | conf_mat) / conf_mat.astype(float).sum(axis=1)) 75 | 76 | # missing class will have NaN at corresponding class 77 | missing_class_mask = 
np.isnan(norm_conf_mat.sum(1)) 78 | exsiting_class_mask = ~ missing_class_mask 79 | 80 | class_average_accuracy = nanmean(np.diagonal(norm_conf_mat)) 81 | total_accuracy = (np.sum(np.diagonal(conf_mat)) / np.sum(conf_mat)) 82 | ious = np.zeros(num_classes) 83 | for class_id in range(num_classes): 84 | ious[class_id] = (conf_mat[class_id, class_id] / ( 85 | np.sum(conf_mat[class_id, :]) + np.sum(conf_mat[:, class_id]) - 86 | conf_mat[class_id, class_id])) 87 | miou = np.mean(ious[exsiting_class_mask]) 88 | if np.isnan(miou): 89 | miou = 0. 90 | total_accuracy = 0. 91 | class_average_accuracy = 0. 92 | output = { 93 | 'miou': torch.tensor([miou], dtype=torch.float32), 94 | 'total_accuracy': torch.tensor([total_accuracy], dtype=torch.float32), 95 | 'class_average_accuracy': torch.tensor([class_average_accuracy], dtype=torch.float32) 96 | } 97 | return output -------------------------------------------------------------------------------- /utils/depth_map.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | # This depth map seems to be slow and inaccurate 4 | def dense_map(Pts, n, m, grid): 5 | ng = 2 * grid + 1 6 | 7 | mX = 100000*torch.ones((m,n)).to(Pts.dtype).to(Pts.device) 8 | mY = 100000*torch.ones((m,n)).to(Pts.dtype).to(Pts.device) 9 | mD = torch.zeros((m,n)).to(Pts.device) 10 | y_ = Pts[1].to(torch.int32) 11 | x_ = Pts[0].to(torch.int32) 12 | int_x = torch.round(Pts[0]).to(Pts.device) 13 | int_y = torch.round(Pts[1]).to(Pts.device) 14 | mX[y_,x_] = Pts[0] - int_x 15 | mY[y_,x_] = Pts[1] - int_y 16 | mD[y_,x_] = Pts[2] 17 | 18 | KmX = torch.zeros((ng, ng, m - ng, n - ng)).to(Pts.device) 19 | KmY = torch.zeros((ng, ng, m - ng, n - ng)).to(Pts.device) 20 | KmD = torch.zeros((ng, ng, m - ng, n - ng)).to(Pts.device) 21 | 22 | for i in range(ng): 23 | for j in range(ng): 24 | KmX[i,j] = mX[i : (m - ng + i), j : (n - ng + j)] - grid - 1 +i 25 | KmY[i,j] = mY[i : (m - ng + i), j : (n - ng + j)] - grid - 1 +i 26 | KmD[i,j] = mD[i : (m - ng + i), j : (n - ng + j)] 27 | S = torch.zeros_like(KmD[0,0]).to(Pts.device) 28 | Y = torch.zeros_like(KmD[0,0]).to(Pts.device) 29 | 30 | for i in range(ng): 31 | for j in range(ng): 32 | s = 1/torch.sqrt(KmX[i,j] * KmX[i,j] + KmY[i,j] * KmY[i,j]) 33 | Y = Y + s * KmD[i,j] 34 | S = S + s 35 | del s 36 | 37 | S[S == 0] = 1 38 | # out = torch.zeros((m,n)).to(Pts.device) 39 | # set to the far range of the dataset, TODO: change this to be more general 40 | out = torch.ones((m,n)).to(Pts.device)*10.0 41 | # incase Y and S goes too big and becomes inf, we set inf to the biggest value of Y/S respectively 42 | Y = torch.nan_to_num(Y, posinf=Y[Y!=torch.inf].max().item()) 43 | S = torch.nan_to_num(S, posinf=S[S!=torch.inf].max().item()) 44 | out[grid + 1 : -grid, grid + 1 : -grid] = Y/S 45 | out = F.pad(out[grid + 1 : -grid, grid + 1 : -grid].unsqueeze(0), (grid+1, grid, grid+1, grid), mode='replicate').squeeze(0) 46 | if torch.isnan(out).any(): 47 | print('Nan in depth map') 48 | del mX, mY, mD, KmX, KmY, KmD, S, Y, y_, x_, int_x, int_y 49 | return out 50 | 51 | 52 | def get_target_view_depth(source_depths, source_intrinsics, source_c2ws, target_intrinsics, target_w2c, img_wh, grid_size): 53 | ''' 54 | source_depth: [N, H, W] 55 | source_intrinsics: [N, 3, 3] 56 | source_c2ws: [N, 4, 4] 57 | target_intrinsics: [3, 3] 58 | target_w2c: [4, 4] 59 | img_wh: [2] 60 | grid_size: int 61 | return: depth map [H, W] 62 | ''' 63 | W, H = img_wh 64 | N = source_depths.shape[0] 65 | points = [] 
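# The following code builds a pixel grid and, for each source view, back-projects every pixel
# with a valid (> 0) depth into a world-space 3D point using that view's intrinsics and
# camera-to-world pose; the merged point cloud is then reprojected into the target view and
# densified by dense_map() to produce the target-view depth map.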
66 | 67 | ys, xs = torch.meshgrid( 68 | torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W), indexing="ij" 69 | ) # pytorch's meshgrid has indexing='ij' 70 | ys, xs = ys.reshape(-1).to(source_intrinsics.device), xs.reshape(-1).to(source_intrinsics.device) 71 | 72 | for num in range(N): 73 | # Might need to change this to be more general (too small or too big value are not good) 74 | mask = source_depths[num] > 0 75 | 76 | dirs = torch.stack( 77 | [ 78 | (xs - source_intrinsics[num][0, 2]) / source_intrinsics[num][0, 0], 79 | (ys - source_intrinsics[num][1, 2]) / source_intrinsics[num][1, 1], 80 | torch.ones_like(xs), 81 | ], 82 | -1, 83 | ) 84 | rays_dir = ( 85 | dirs @ source_c2ws[num][:3, :3].t() 86 | ) 87 | rays_orig = source_c2ws[num][:3, -1].clone().reshape(1, 3).expand(rays_dir.shape[0], -1) 88 | rays_orig = rays_orig.reshape(H,W,-1)[mask] 89 | rays_depth = source_depths[num].reshape(H,W,-1)[mask] 90 | rays_dir = rays_dir.reshape(H,W,-1)[mask] 91 | ray_pts = rays_orig + rays_depth * rays_dir 92 | points.append(ray_pts.reshape(-1,3)) 93 | 94 | del rays_orig, rays_depth, rays_dir, ray_pts, dirs, mask 95 | 96 | points = torch.cat(points,0).reshape(-1,3) 97 | 98 | R = target_w2c[:3, :3] # (3, 3) 99 | T = target_w2c[:3, 3:] # (3, 1) 100 | ray_pts_transformed = torch.matmul(points, R.t()) + T.reshape(1, 3) 101 | 102 | ray_pts_ndc = ray_pts_transformed @ target_intrinsics.t() 103 | ndc = ray_pts_ndc[:, :2] / ray_pts_ndc[:, -1:] 104 | # ray_pts_ndc[:, 0] = ray_pts_ndc[:, 0] / ray_pts_ndc[:, 2] 105 | # ray_pts_ndc[:, 1] = ray_pts_ndc[:, 1] / ray_pts_ndc[:, 2] 106 | mask = (ndc[:, 0] >= 0) & (ndc[:, 0] <= W-1) & (ndc[:, 1] >= 0) & (ndc[:, 1] <= H-1) 107 | # when doing scannet dataset this is not necessary, cause the ndc depth more than 2 108 | # mask = mask & (ray_pts_transformed[:, 2] > 2) 109 | points_2d = ndc[mask, 0:2] 110 | 111 | lidarOnImage = torch.cat((points_2d, ray_pts_transformed[mask,2].reshape(-1,1)), 1) 112 | if torch.isnan(lidarOnImage).any(): 113 | print("lidarOnImage has nan") 114 | depth_map = dense_map(lidarOnImage.t(), W, H, grid_size) 115 | # del ray_pts_ndc, target_intrinsics, ray_pts_transformed 116 | # del ys, xs, ray_pts_transformed, points, points_2d, ray_pts_ndc, mask, lidarOnImage, R, T 117 | # depth_map = torch.ones_like(source_depths[0]).to("cuda") 118 | 119 | return depth_map 120 | -------------------------------------------------------------------------------- /utils/options.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | 3 | def config_parser(): 4 | parser = configargparse.ArgumentParser() 5 | parser.add_argument("--config", is_config_file=True, help="Config file path") 6 | 7 | # Task options 8 | parser.add_argument("--segmentation", action="store_true", help="Use segmentation mask for training") 9 | parser.add_argument("--nb_class", type=int, default=21, help="Number of classes for segmentation") 10 | parser.add_argument("--ignore_label", type=int, default=20, help="Ignore label for segmentation") 11 | 12 | # Datasets options 13 | parser.add_argument("--dataset_name", type=str, default="llff", choices=["llff", "nerf", "dtu", "klevr", "scannet", "replica"],) 14 | parser.add_argument("--llff_path", type=str, help="Path to llff dataset") 15 | parser.add_argument("--llff_test_path", type=str, help="Path to llff dataset") 16 | parser.add_argument("--dtu_path", type=str, help="Path to dtu dataset") 17 | parser.add_argument("--dtu_pre_path", type=str, help="Path to preprocessed dtu dataset") 18 | 
parser.add_argument("--nerf_path", type=str, help="Path to nerf dataset") 19 | parser.add_argument("--ams_path", type=str, help="Path to ams dataset") 20 | parser.add_argument("--ibrnet1_path", type=str, help="Path to ibrnet1 dataset") 21 | parser.add_argument("--ibrnet2_path", type=str, help="Path to ibrnet2 dataset") 22 | parser.add_argument("--klevr_path", type=str, help="Path to klevr dataset") 23 | parser.add_argument("--scannet_path", type=str, help="Path to scannet dataset") 24 | parser.add_argument("--replica_path", type=str, help="Path to replica dataset") 25 | 26 | # for scannet dataset 27 | parser.add_argument("--val_set_list", type=str, help="Path to scannet val dataset list") 28 | 29 | # Training options 30 | parser.add_argument("--batch_size", type=int, default=512) 31 | parser.add_argument("--num_steps", type=int, default=200000) 32 | parser.add_argument("--nb_views", type=int, default=3) 33 | parser.add_argument("--lrate", type=float, default=5e-4, help="Learning rate") 34 | parser.add_argument("--warmup_steps", type=int, default=500, help="Gradually warm-up learning rate in optimizer") 35 | parser.add_argument("--scene", type=str, default="None", help="Scene for fine-tuning") 36 | parser.add_argument("--cross_entropy_weight", type=float, default=0.1, help="Weight for cross entropy loss") 37 | parser.add_argument("--optimizer", type=str, default="adam", help="select optimizer: adam / sgd") 38 | parser.add_argument("--background_weight", type=float, default=1, help="Weight for background class in cross entropy loss") 39 | parser.add_argument("--two_stage_training_steps", type=int, default=60000, help="Use two stage training, indicating how many steps for first stage") 40 | parser.add_argument("--self_supervised_depth_loss", action="store_true", help="Use self supervised depth loss") 41 | 42 | # Rendering options 43 | parser.add_argument("--chunk", type=int, default=4096, help="Number of rays rendered in parallel") 44 | parser.add_argument("--nb_coarse", type=int, default=96, help="Number of coarse samples per ray") 45 | parser.add_argument("--nb_fine", type=int, default=32, help="Number of additional fine samples per ray",) 46 | 47 | # Other options 48 | parser.add_argument("--expname", type=str, help="Experiment name") 49 | parser.add_argument("--logger", type=str, default="wandb", choices=["wandb", "tensorboard", "none"]) 50 | parser.add_argument("--logdir", type=str, default="./logs/", help="Where to store ckpts and logs") 51 | parser.add_argument("--eval", action="store_true", help="Render and evaluate the test set") 52 | parser.add_argument("--use_depth", action="store_true", help="Use ground truth low-res depth maps in rendering process") 53 | parser.add_argument("--seed", type=int, default=123, help="Random seed") 54 | parser.add_argument("--val_save_img_type", default=["target"], action="append", help="choices=[target, depth, source], Save target comparison images or depth maps or source images") 55 | parser.add_argument("--target_depth_estimation", action="store_true", help="Use target depth estimation in rendering process") 56 | parser.add_argument("--use_depth_refine_net", action="store_true", help="Use depth refine net before rendering process") 57 | parser.add_argument("--using_semantic_global_tokens", type=int, default=0, help="Use only semantic global tokens in rendering process. 0: not use, 1: use") 58 | parser.add_argument("--only_using_semantic_global_tokens", type=int, default=0, help="Use only semantic global tokens in rendering process. 
0: not use, 1: use") 59 | parser.add_argument("--use_batch_semantic_feature", action="store_true", help="Use batch semantic feature in rendering process") 60 | parser.add_argument("--ddp", action="store_true", help="Use distributed data parallel") 61 | parser.add_argument("--feat_net", type=str, default="UNet", choices=["UNet", "smp_UNet"], help="FeatureNet used in depth estimation") 62 | # resume options 63 | parser.add_argument("--ckpt_path", type=str, default=None, help="Path to a checkpoint to resume training") 64 | parser.add_argument("--finetune", action="store_true", help="Finetune the model with a checkpoint") 65 | parser.add_argument("--fintune_scene", type=str, default="None", help="Scene for fine-tuning") 66 | return parser.parse_args() 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GSNeRF: Enhancing 3D Scene Understanding with Generalizable Semantic Neural Radiance Fields
2 | > Zi-Ting Chou, Sheng-Yu Huang, I-Jieh Liu, Yu-Chiang Wang
3 | > [Project Page](https://timchou-ntu.github.io/gsnerf/) | [Paper](https://arxiv.org/abs/2403.03608) 4 | 5 | This repository contains an official PyTorch Lightning implementation of our paper, GSNeRF (CVPR 2024). 6 | 7 |
8 | 9 |
10 | 11 | ## Installation 12 | 13 | #### Tested on NVIDIA GeForce RTX 3090 GPUs with CUDA 11.7, PyTorch 2.0.1, and PyTorch Lightning 2.0.4 14 | 15 | To install the dependencies, in addition to PyTorch, run: 16 | 17 | ``` 18 | git clone --recursive https://github.com/TimChou-ntu/GSNeRF.git 19 | conda create -n gsnerf python=3.9 20 | conda activate gsnerf 21 | pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117 22 | cd GSNeRF 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | ## Evaluation and Training 27 | Following [Semantic Nerf](https://github.com/Harry-Zhi/semantic_nerf) and [Semantic-Ray](https://github.com/liuff19/Semantic-Ray), we conduct experiments on [ScanNet](#scannet-real-world-indoor-scene-dataset) and [Replica](#replica-synthetic-indoor-scene-dataset), respectively. 28 | 29 | Download `scannet` from [here](https://github.com/ScanNet/ScanNet) and set its path as `scannet_path` in the [scannet.txt](./configs/scannet.txt) file. 30 | 31 | Download `Replica` from [here](https://www.dropbox.com/sh/9yu1elddll00sdl/AAC-rSJdLX0C6HhKXGKMOIija?dl=0) and set its path as `replica_path` in the [replica.txt](configs/replica.txt) file. (Thanks to [Semantic Nerf](https://github.com/Harry-Zhi/semantic_nerf) for rendering the 2D images and semantic maps.) 32 | 33 | Organize the data in the following structure: 34 | ``` 35 | ├── data 36 | │ ├── scannet 37 | │ │ ├── scene0000_00 38 | │ │ │ ├── color 39 | │ │ │ │ ├── 0.jpg 40 | │ │ │ │ ├── ... 41 | │ │ │ ├── depth 42 | │ │ │ │ ├── 0.png 43 | │ │ │ │ ├── ... 44 | │ │ │ ├── label-filt 45 | │ │ │ │ ├── 0.png 46 | │ │ │ │ ├── ... 47 | │ │ │ ├── pose 48 | │ │ │ │ ├── 0.txt 49 | │ │ │ │ ├── ... 50 | │ │ │ ├── intrinsic 51 | │ │ │ │ ├── extrinsic_color.txt 52 | │ │ │ │ ├── intrinsic_color.txt 53 | │ │ │ │ ├── ... 54 | │ │ │ ├── ... 55 | │ │ ├── ... 56 | │ │ ├── scannetv2-labels.combined.tsv 57 | │ │ 58 | │ ├── replica 59 | │ │ ├── office_0 60 | │ │ │ ├── Sequence_1 61 | │ │ │ │ ├── depth 62 | │ │ │ │ │ ├── depth_0.png 63 | │ │ │ │ │ ├── ... 64 | │ │ │ │ ├── rgb 65 | │ │ │ │ │ ├── rgb_0.png 66 | │ │ │ │ │ ├── ... 67 | │ │ │ │ ├── semantic_class 68 | │ │ │ │ │ ├── semantic_class_0.png 69 | │ │ │ │ │ ├── ... 70 | │ │ │ │ ├── traj_w_c.txt 71 | │ │ ├── ... 72 | │ │ ├── semantic_info 73 | ``` 74 | ## ScanNet (real-world indoor scene) Dataset 75 | 76 | For training a generalizable model, set the number of source views to 8 (nb_views = 8) in the [scannet.txt](./configs/scannet.txt) file and run the following command: 77 | 78 | ``` 79 | python train.py --config configs/scannet.txt --segmentation --logger wandb --target_depth_estimation 80 | ``` 81 | 82 | For evaluation on a novel scene, run the following command (replace [ckpt path] with the path to your trained checkpoint): 83 | 84 | ``` 85 | python train.py --config configs/scannet.txt --segmentation --logger none --target_depth_estimation --ckpt_path [ckpt path] --eval 86 | ``` 87 | 88 | ## Replica (synthetic indoor scene) Dataset 89 | 90 | For training a generalizable model, set the number of source views to 8 (nb_views = 8) in the [replica.txt](./configs/replica.txt) file and run the following command: 91 | 92 | ``` 93 | python train.py --config configs/replica.txt --segmentation --logger wandb --target_depth_estimation 94 | ``` 95 | 96 | For evaluation on a novel scene, run the following command (replace [ckpt path] with the path to your trained checkpoint): 
97 | 98 | ``` 99 | python train.py --config configs/replica.txt --segmentation --logger none --target_depth_estimation --ckpt_path [ckpt path] --eval 100 | ``` 101 | 102 | ### Self-supervised depth model 103 | Simply append `--self_supervised_depth_loss` to the end of the training command. 104 | 105 | ### Contact 106 | You can contact the author by email: A88551212@gmail.com 107 | 108 | ## Citing 109 | If you find our work useful, please consider citing: 110 | ```BibTeX 111 | @inproceedings{Chou2024gsnerf, 112 | author = {Zi‑Ting Chou* and Sheng‑Yu Huang* and I‑Jieh Liu and Yu‑Chiang Frank Wang}, 113 | title = {GSNeRF: Generalizable Semantic Neural Radiance Fields with Enhanced 3D Scene Understanding}, 114 | booktitle = CVPR, 115 | year = {2024}, 116 | arxiv = {2403.03608}, 117 | } 118 | 119 | ``` 120 | 121 | ### Acknowledgement 122 | 123 | Some portions of the code were derived from [GeoNeRF](https://github.com/idiap/GeoNeRF). 124 | 125 | Additionally, the well-structured codebases of [nerf_pl](https://github.com/kwea123/nerf_pl), [nesf](https://nesf3d.github.io/), and [RC-MVSNet](https://github.com/Boese0601/RC-MVSNet) were extremely helpful during the experiments. Shout out to them for their contributions. -------------------------------------------------------------------------------- /utils/klevr_utils.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial import transform 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import torch 5 | import json 6 | import os 7 | 8 | def blender_quat2rot(quaternion): 9 | """Convert quaternion to rotation matrix. 10 | Equivalent to, but supporting the batched case: 11 | ```python 12 | rot3x3 = mathutils.Quaternion(quaternion).to_matrix() 13 | ``` 14 | Args: 15 | quaternion: 16 | Returns: 17 | rotation matrix 18 | """ 19 | 20 | # Note: Blender first casts to double values for numerical precision while 21 | # we're using float32. 22 | q = np.sqrt(2) * quaternion 23 | 24 | q0 = q[..., 0] 25 | q1 = q[..., 1] 26 | q2 = q[..., 2] 27 | q3 = q[..., 3] 28 | 29 | qda = q0 * q1 30 | qdb = q0 * q2 31 | qdc = q0 * q3 32 | qaa = q1 * q1 33 | qab = q1 * q2 34 | qac = q1 * q3 35 | qbb = q2 * q2 36 | qbc = q2 * q3 37 | qcc = q3 * q3 38 | 39 | # Note: idx are inverted as blender and numpy conventions do not 40 | # match (x, y) -> (y, x) 41 | rotation = np.empty((*quaternion.shape[:-1], 3, 3), dtype=np.float32) 42 | rotation[..., 0, 0] = 1.0 - qbb - qcc 43 | rotation[..., 1, 0] = qdc + qab 44 | rotation[..., 2, 0] = -qdb + qac 45 | 46 | rotation[..., 0, 1] = -qdc + qab 47 | rotation[..., 1, 1] = 1.0 - qaa - qcc 48 | rotation[..., 2, 1] = qda + qbc 49 | 50 | rotation[..., 0, 2] = qdb + qac 51 | rotation[..., 1, 2] = -qda + qbc 52 | rotation[..., 2, 2] = 1.0 - qaa - qbb 53 | return rotation 54 | 55 | def make_transform_matrix(positions,rotations,): 56 | """Create the 4x4 transformation matrix. 57 | Note: This function uses numpy. 58 | Args: 59 | positions: Translation applied after the rotation. 60 | Last column of the transformation matrix 61 | rotations: Rotation. Top-left 3x3 matrix of the transformation matrix. 
62 | Returns: 63 | transformation_matrix: 64 | """ 65 | # Create the 4x4 transformation matrix 66 | rot_pos = np.broadcast_to(np.eye(4), (*positions.shape[:-1], 4, 4)).copy() 67 | rot_pos[..., :3, :3] = rotations 68 | # Note: Blender and numpy use different convensions for the translation 69 | rot_pos[..., :3, 3] = positions 70 | return rot_pos 71 | 72 | def from_position_and_quaternion(positions, quaternions, use_unreal_axes): 73 | if use_unreal_axes: 74 | rotations = transform.Rotation.from_quat(quaternions).as_matrix() 75 | else: 76 | # Rotation matrix that rotates from world to object coordinates. 77 | # Warning: Rotations should be given in blender convensions as 78 | # scipy.transform uses different convensions. 79 | rotations = blender_quat2rot(quaternions) 80 | px2world_transform = make_transform_matrix(positions=positions,rotations=rotations) 81 | return px2world_transform 82 | 83 | def scale_rays(all_rays_o, all_rays_d, scene_boundaries, img_wh): 84 | """Rescale scene boundaries. 85 | rays_o: (len(image_paths)*h*w, 3) 86 | rays_d: (len(image_paths)*h*w, 3) 87 | scene_boundaries: np.array(2 ,3), [min, max] 88 | img_wh: (2) 89 | """ 90 | # Rescale (x, y, z) from [min, max] -> [-1, 1] 91 | # all_rays_o = all_rays_o.reshape(-1, img_wh[0], img_wh[1], 3) # (len(image_paths), h, w, 3)) 92 | # all_rays_d = all_rays_d.reshape(-1, img_wh[0], img_wh[1], 3) 93 | assert all_rays_o.shape[-1] == 3, "all_rays_o should be (chunk, 3)" 94 | assert all_rays_d.shape[-1] == 3, "all_rays_d should be (chunk, 3)" 95 | old_min = torch.from_numpy(scene_boundaries[0]).to(all_rays_o.dtype).to(all_rays_o.device) 96 | old_max = torch.from_numpy(scene_boundaries[1]).to(all_rays_o.dtype).to(all_rays_o.device) 97 | new_min = torch.tensor([-1,-1,-1]).to(all_rays_o.dtype).to(all_rays_o.device) 98 | new_max = torch.tensor([1,1,1]).to(all_rays_o.dtype).to(all_rays_o.device) 99 | # scale = max(scene_boundaries[1] - scene_boundaries[0])/2 100 | # all_rays_o = (all_rays_o - torch.mean(all_rays_o, dim=-1, keepdim=True)) / scale 101 | # This is from jax3d.interp, kind of weird but true 102 | all_rays_o = ((new_min - new_max) / (old_min - old_max))*all_rays_o + (old_min * new_max - new_min * old_max) / (old_min - old_max) 103 | 104 | # We also need to rescale the camera direction by bbox.size. 105 | # The direction can be though of a ray from a point in space (the camera 106 | # origin) to another point in space (say the red light on the lego 107 | # bulldozer). When we scale the scene in a certain way, this direction 108 | # also needs to be scaled in the same way. 
109 | all_rays_d = all_rays_d * 2 / (old_max - old_min) 110 | # (re)-normalize the rays 111 | all_rays_d = all_rays_d / torch.linalg.norm(all_rays_d, dim=-1, keepdims=True) 112 | return all_rays_o.reshape(-1, 3), all_rays_d.reshape(-1, 3) 113 | 114 | 115 | def calculate_near_and_far(rays_o, rays_d, bbox_min=[-1.,-1.,-1.], bbox_max=[1.,1.,1.]): 116 | ''' 117 | rays_o, (len(self.split_ids)*h*w, 3) 118 | rays_d, (len(self.split_ids)*h*w, 3) 119 | bbox_min=[-1,-1,-1], 120 | bbox_max=[1,1,1] 121 | ''' 122 | # map all shape to same (len(self.split_ids)*h*w, 3, 2) 123 | corners = torch.stack((torch.tensor(bbox_min),torch.tensor(bbox_max)), dim=-1).to(rays_o.dtype).to(rays_o.device) 124 | corners = corners.unsqueeze(0).repeat(rays_o.shape[0],1,1) # (len(self.split_ids)*h*w, 3, 2) 125 | corners -= torch.unsqueeze(rays_o, -1).repeat(1,1,2) 126 | intersections = (corners / (torch.unsqueeze(rays_d, -1).repeat(1,1,2))) 127 | 128 | min_intersections = torch.amax(torch.amin(intersections, dim=-1), dim=-1, keepdim=True) 129 | max_intersections = torch.amin(torch.amax(intersections, dim=-1), dim=-1, keepdim=True) 130 | epsilon = 1e-1*torch.ones_like(min_intersections) 131 | near = torch.maximum(epsilon, min_intersections) 132 | # tmp = near 133 | near = torch.where((near > max_intersections), epsilon, near) 134 | far = torch.where(near < max_intersections, max_intersections, near+epsilon) 135 | 136 | return near, far -------------------------------------------------------------------------------- /model/UNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from inplace_abn import InPlaceABN 5 | import segmentation_models_pytorch as smp 6 | 7 | class DoubleConv(nn.Module): 8 | """(convolution => [BN] => ReLU) * 2""" 9 | 10 | def __init__(self, in_channels, out_channels, mid_channels=None): 11 | super().__init__() 12 | if not mid_channels: 13 | mid_channels = out_channels 14 | self.double_conv = nn.Sequential( 15 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 16 | InPlaceABN(mid_channels), 17 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 18 | InPlaceABN(out_channels), 19 | ) 20 | 21 | def forward(self, x): 22 | return self.double_conv(x) 23 | 24 | 25 | class Down(nn.Module): 26 | """Downscaling with maxpool then double conv""" 27 | 28 | def __init__(self, in_channels, out_channels): 29 | super().__init__() 30 | self.maxpool_conv = nn.Sequential( 31 | # nn.MaxPool2d(2), 32 | nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=2, padding=2, bias=False), 33 | InPlaceABN(out_channels), 34 | DoubleConv(out_channels, out_channels) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.maxpool_conv(x) 39 | 40 | 41 | class Up(nn.Module): 42 | """Upscaling then double conv""" 43 | 44 | def __init__(self, in_channels, out_channels, bilinear=True): 45 | super().__init__() 46 | self.bilinear = bilinear 47 | # if bilinear, use the normal convolutions to reduce the number of channels 48 | if bilinear: 49 | self.up = F.interpolate 50 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 51 | else: 52 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 53 | self.conv = DoubleConv(in_channels, out_channels) 54 | 55 | def forward(self, x1, x2): 56 | if self.bilinear: 57 | x1 = self.up(x1, scale_factor=2, mode="bilinear", align_corners=True) 58 | else: 59 | x1 = self.up(x1) 60 | 61 | x = 
torch.cat([x2, x1], dim=1) 62 | return self.conv(x) 63 | 64 | 65 | class OutConv(nn.Module): 66 | def __init__(self, in_channels, out_channels): 67 | super(OutConv, self).__init__() 68 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 69 | 70 | def forward(self, x): 71 | return self.conv(x) 72 | 73 | 74 | class UNet(nn.Module): 75 | def __init__(self, n_channels, n_classes, bilinear=False): 76 | super(UNet, self).__init__() 77 | self.n_channels = n_channels 78 | self.n_classes = n_classes 79 | self.bilinear = bilinear 80 | 81 | self.inc = (DoubleConv(n_channels, 8)) 82 | self.down1 = (Down(8, 16)) 83 | self.down2 = (Down(16, 32)) 84 | self.down3 = (Down(32, 64)) 85 | # TODO: down four times might be too much 86 | self.down4 = (Down(64, 128 )) 87 | self.up1 = (Up(128, 64 , bilinear)) 88 | self.up2 = (Up(64, 32 , bilinear)) 89 | self.up3 = (Up(32, 16 , bilinear)) 90 | self.up4 = (Up(16, 8, bilinear)) 91 | self.outc1 = (OutConv(8, n_classes)) 92 | self.outc2 = (OutConv(n_classes, n_classes)) 93 | 94 | 95 | 96 | self.toplayer = nn.Conv2d(32, 32, 1) 97 | self.lat1 = nn.Conv2d(16, 32, 1) 98 | self.lat0 = nn.Conv2d(8, 32, 1) 99 | 100 | # to reduce channel size of the outputs from FPN 101 | self.smooth1 = nn.Conv2d(32, 16, 3, padding=1) 102 | self.smooth0 = nn.Conv2d(32, 8, 3, padding=1) 103 | 104 | def _upsample_add(self, x, y): 105 | return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + y 106 | 107 | def forward(self, x): 108 | x1 = self.inc(x) 109 | x2 = self.down1(x1) 110 | x3 = self.down2(x2) 111 | x4 = self.down3(x3) 112 | x5 = self.down4(x4) 113 | x = self.up1(x5, x4) 114 | x = self.up2(x, x3) 115 | x = self.up3(x, x2) 116 | x = self.up4(x, x1) 117 | feature = self.outc1(x) 118 | logits = self.outc2(feature) 119 | 120 | # original FeatureNet used in depth estimation 121 | feat2 = self.toplayer(x3) # (B, 32, H//4, W//4) 122 | feat1 = self._upsample_add(feat2, self.lat1(x2)) # (B, 32, H//2, W//2) 123 | feat0 = self._upsample_add(feat1, self.lat0(x1)) # (B, 32, H, W) 124 | 125 | # reduce output channels 126 | feat1 = self.smooth1(feat1) # (B, 16, H//2, W//2) 127 | feat0 = self.smooth0(feat0) # (B, 8, H, W) 128 | 129 | output = {"level_0": feat0, "level_1": feat1, "level_2": feat2, 'logits': logits, 'feature': feature} 130 | 131 | return output 132 | 133 | class smp_UNet(nn.Module): 134 | def __init__(self, n_channels, n_classes, bilinear=False): 135 | super(smp_UNet, self).__init__() 136 | self.n_channels = n_channels 137 | self.n_classes = n_classes 138 | 139 | self.model = smp.Unet( 140 | encoder_name="timm-mobilenetv3_small_minimal_100", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7 141 | in_channels=n_channels, # model input channels (1 for grayscale images, 3 for RGB, etc.) 
142 | classes=self.n_classes, # model output channels (number of classes in your dataset) 143 | encoder_depth=4, 144 | decoder_channels=(128, 64, 64, 32), 145 | ) 146 | del self.model.encoder.model.blocks[4:] 147 | self.toplayer = nn.Conv2d(16, 32, 1) 148 | self.lat1 = nn.Conv2d(16, 32, 1) 149 | self.lat0 = nn.Conv2d(3, 32, 1) 150 | 151 | # to reduce channel size of the outputs from FPN 152 | self.smooth1 = nn.Conv2d(32, 16, 3, padding=1) 153 | self.smooth0 = nn.Conv2d(32, 8, 3, padding=1) 154 | 155 | self.projection = nn.Sequential( 156 | nn.Conv2d(self.n_classes, self.n_classes, 1), 157 | ) 158 | 159 | def _upsample_add(self, x, y): 160 | return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + y 161 | 162 | def forward(self, x): 163 | feat = self.model.encoder(x) 164 | 165 | feature = self.model.decoder(*feat) 166 | feature = self.model.segmentation_head(feature) 167 | 168 | logits = self.projection(feature) 169 | 170 | x3 = feat[2] 171 | x2 = feat[1] 172 | x1 = feat[0] 173 | # original FeatureNet used in depth estimation 174 | feat2 = self.toplayer(x3) # (B, 32, H//4, W//4) 175 | feat1 = self._upsample_add(feat2, self.lat1(x2)) # (B, 32, H//2, W//2) 176 | feat0 = self._upsample_add(feat1, self.lat0(x1)) # (B, 32, H, W) 177 | 178 | # reduce output channels 179 | feat1 = self.smooth1(feat1) # (B, 16, H//2, W//2) 180 | feat0 = self.smooth0(feat0) # (B, 8, H, W) 181 | 182 | output = {"level_0": feat0, "level_1": feat1, "level_2": feat2, 'logits': logits, 'feature': feature} 183 | 184 | return output -------------------------------------------------------------------------------- /utils/rendering.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from utils.utils import normal_vect, interpolate_3D, interpolate_2D 5 | 6 | 7 | class Embedder: 8 | def __init__(self, **kwargs): 9 | self.kwargs = kwargs 10 | self.create_embedding_fn() 11 | 12 | def create_embedding_fn(self): 13 | embed_fns = [] 14 | 15 | if self.kwargs["include_input"]: 16 | embed_fns.append(lambda x: x) 17 | 18 | max_freq = self.kwargs["max_freq_log2"] 19 | N_freqs = self.kwargs["num_freqs"] 20 | 21 | if self.kwargs["log_sampling"]: 22 | freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs) 23 | else: 24 | freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs) 25 | self.freq_bands = freq_bands.reshape(1, -1, 1) 26 | 27 | for freq in freq_bands: 28 | for p_fn in self.kwargs["periodic_fns"]: 29 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 30 | 31 | self.embed_fns = embed_fns 32 | 33 | def embed(self, inputs): 34 | repeat = inputs.dim() - 1 35 | self.freq_bands = self.freq_bands.to(inputs.device) 36 | inputs_scaled = ( 37 | inputs.unsqueeze(-2) * self.freq_bands.view(*[1] * repeat, -1, 1) 38 | ).reshape(*inputs.shape[:-1], -1) 39 | inputs_scaled = torch.cat( 40 | (inputs, torch.sin(inputs_scaled), torch.cos(inputs_scaled)), dim=-1 41 | ) 42 | return inputs_scaled 43 | 44 | 45 | def get_embedder(multires=4): 46 | 47 | embed_kwargs = { 48 | "include_input": True, 49 | "max_freq_log2": multires - 1, 50 | "num_freqs": multires, 51 | "log_sampling": True, 52 | "periodic_fns": [torch.sin, torch.cos], 53 | } 54 | 55 | embedder_obj = Embedder(**embed_kwargs) 56 | embed = lambda x, eo=embedder_obj: eo.embed(x) 57 | return embed 58 | 59 | 60 | def sigma2weights(sigma): 61 | alpha = 1.0 - torch.exp(-sigma) 62 | T = torch.cumprod( 63 | torch.cat( 64 | [torch.ones(alpha.shape[0], 
1).to(alpha.device), 1.0 - alpha + 1e-10], -1 65 | ), 66 | -1, 67 | )[:, :-1] 68 | weights = alpha * T 69 | 70 | return weights 71 | 72 | 73 | def volume_rendering(rgb_sigma, pts_depth): 74 | rgb = rgb_sigma[..., :3] 75 | weights = sigma2weights(rgb_sigma[..., 3]) 76 | 77 | rendered_rgb = torch.sum(weights[..., None] * rgb, -2) 78 | rendered_depth = torch.sum(weights * pts_depth, -1) 79 | 80 | return rendered_rgb, rendered_depth 81 | 82 | 83 | def get_angle_wrt_src_cams(c2ws, rays_pts, rays_dir_unit): 84 | nb_rays = rays_pts.shape[0] 85 | ## Unit vectors from source cameras to the points on the ray 86 | dirs = normal_vect(rays_pts.unsqueeze(2) - c2ws[:, :3, 3][None, None]) 87 | ## Cosine of the angle between two directions 88 | angle_cos = torch.sum( 89 | dirs * rays_dir_unit.reshape(nb_rays, 1, 1, 3), dim=-1, keepdim=True 90 | ) 91 | # Cosine to Sine and approximating it as the angle (angle << 1 => sin(angle) = angle) 92 | angle = (1 - (angle_cos**2)).abs().sqrt() 93 | 94 | return angle 95 | 96 | 97 | def interpolate_pts_feats(imgs, feats_fpn, semantic_feat, feats_vol, rays_pts_ndc, padding_mode='border', use_batch_semantic_features=False): 98 | nb_views = feats_fpn.shape[1] 99 | interpolated_feats = [] 100 | 101 | for i in range(nb_views): 102 | ray_feats_0 = interpolate_3D(feats_vol[f"level_0"][:, i], rays_pts_ndc[f"level_0"][:, :, i], padding_mode=padding_mode) 103 | ray_feats_1 = interpolate_3D( 104 | feats_vol[f"level_1"][:, i], rays_pts_ndc[f"level_1"][:, :, i], padding_mode=padding_mode 105 | ) 106 | ray_feats_2 = interpolate_3D( 107 | feats_vol[f"level_2"][:, i], rays_pts_ndc[f"level_2"][:, :, i], padding_mode=padding_mode 108 | ) 109 | ray_feats_fpn, ray_colors, ray_semantic_feats_fpn, ray_masks = interpolate_2D( 110 | feats_fpn[:, i], imgs[:, i], semantic_feat[:, i], rays_pts_ndc[f"level_0"][:, :, i], padding_mode=padding_mode, use_batch_semantic_features=use_batch_semantic_features 111 | ) 112 | # When using only one point per ray, all features except ray masks are 2D (N_rays, C), while ray masks are 3D (N_rays, C, 1), so we need to squeeze it 113 | if ray_colors.dim() == 2 and ray_masks.dim() == 3: 114 | ray_masks = ray_masks.squeeze(-1) 115 | 116 | interpolated_feats.append( 117 | torch.cat( 118 | [ 119 | ray_feats_0, # (N_rays, N_samples, 8) 120 | ray_feats_1, # (N_rays, N_samples, 8) 121 | ray_feats_2, # (N_rays, N_samples, 8) 122 | ray_feats_fpn, # (N_rays, N_samples, 8) 123 | ray_colors, # (N_rays, N_samples, 3) 124 | ray_semantic_feats_fpn, # (N_rays, N_samples, nb_classes(21)) / if use_batch_semantic_features, (N_rays, N_samples, 9*nb_classes(21)) 125 | ray_masks, # (N_rays, N_samples, 1) 126 | ], 127 | dim=-1, 128 | ) 129 | ) 130 | interpolated_feats = torch.stack(interpolated_feats, dim=2) 131 | if torch.isnan(interpolated_feats).any(): 132 | print("interpolated_feats has nan values") 133 | return interpolated_feats 134 | 135 | 136 | def get_occ_masks(depth_map_norm, rays_pts_ndc, visibility_thr=0.2): 137 | nb_views = depth_map_norm["level_0"].shape[1] 138 | z_diff = [] 139 | for i in range(nb_views): 140 | ## Interpolate depth maps corresponding to each sample point 141 | # [1 H W 3] (x,y,z) 142 | grid = rays_pts_ndc[f"level_0"][None, :, :, i, :2] * 2 - 1.0 143 | rays_depths = F.grid_sample( 144 | depth_map_norm["level_0"][:, i : i + 1], 145 | grid, 146 | align_corners=True, 147 | mode="bilinear", 148 | padding_mode="border", 149 | )[0, 0] 150 | z_diff.append(rays_pts_ndc["level_0"][:, :, i, 2] - rays_depths) 151 | z_diff = torch.stack(z_diff, dim=2) 152 | 153 | 
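# A sample counts as visible in a source view when its NDC depth exceeds that view's predicted
# depth by less than visibility_thr; a larger positive z_diff means the point lies behind the
# observed surface and is treated as occluded for that view.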
occ_masks = z_diff.unsqueeze(-1) < visibility_thr 154 | 155 | return occ_masks 156 | 157 | 158 | def render_rays( 159 | c2ws, 160 | rays_pts, 161 | rays_pts_ndc, 162 | pts_depth, 163 | rays_dir, 164 | feats_vol, 165 | feats_fpn, 166 | imgs, 167 | depth_map_norm, 168 | renderer_net, 169 | middle_pts_mask, 170 | semantic_feat=None, 171 | use_batch_semantic_features=False, 172 | ): 173 | ## The angles between the ray and source camera vectors 174 | rays_dir_unit = rays_dir / torch.norm(rays_dir, dim=-1, keepdim=True) 175 | angles = get_angle_wrt_src_cams(c2ws, rays_pts, rays_dir_unit) 176 | 177 | ## Positional encoding 178 | embedded_angles = get_embedder()(angles) 179 | 180 | ## Interpolate all features for sample points (N_rays, N_samples, source_view_num, 8+8+8+8+3+21+1) 181 | pts_feat = interpolate_pts_feats(imgs, feats_fpn, semantic_feat, feats_vol, rays_pts_ndc, use_batch_semantic_features=use_batch_semantic_features) 182 | 183 | ## Getting Occlusion Masks based on predicted depths (N_rays, N_samples, source_view_num, 1) 184 | occ_masks = get_occ_masks(depth_map_norm, rays_pts_ndc) 185 | 186 | ## rendering sigma and RGB values 187 | rgb_sigma, rendered_semantic = renderer_net(embedded_angles, pts_feat, occ_masks, middle_pts_mask) 188 | 189 | rendered_rgb, rendered_depth = volume_rendering(rgb_sigma, pts_depth) 190 | 191 | if torch.isnan(rendered_semantic).sum() > 0: 192 | print("NaN in rendered_semantic") 193 | 194 | return rendered_rgb, rendered_semantic, rendered_depth 195 | -------------------------------------------------------------------------------- /data/get_datasets.py: -------------------------------------------------------------------------------- 1 | # GeoNeRF is a generalizable NeRF model that renders novel views 2 | # without requiring per-scene optimization. This software is the 3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with 4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin, 5 | # and Francois Fleuret. 6 | 7 | # Copyright (c) 2022 ams International AG 8 | 9 | # This file is part of GeoNeRF. 10 | # GeoNeRF is free software: you can redistribute it and/or modify 11 | # it under the terms of the GNU General Public License version 3 as 12 | # published by the Free Software Foundation. 13 | 14 | # GeoNeRF is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | 19 | # You should have received a copy of the GNU General Public License 20 | # along with GeoNeRF. If not, see . 
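# Usage sketch: a minimal illustration of how the (dataset, sampler) pair returned by
# get_training_dataset below might be wired into a standard PyTorch DataLoader. The exact
# wiring lives in train.py and may differ; batch_size=1 and num_workers=4 are assumptions
# made only for this sketch.
#
#   from torch.utils.data import DataLoader
#   from data.get_datasets import get_training_dataset
#   from utils.options import config_parser
#
#   args = config_parser()
#   train_set, train_sampler = get_training_dataset(args, downsample=1.0)
#   train_loader = DataLoader(
#       train_set,
#       batch_size=1,                     # one multi-view sample per step (assumption)
#       sampler=train_sampler,            # WeightedRandomSampler for the mixed-dataset case, else None
#       shuffle=(train_sampler is None),  # DataLoader forbids shuffle together with a sampler
#       num_workers=4,
#   )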
21 | 22 | import torch 23 | from torch.utils.data import ConcatDataset, WeightedRandomSampler 24 | import numpy as np 25 | 26 | from data.llff import LLFF_Dataset 27 | from data.dtu import DTU_Dataset 28 | from data.nerf import NeRF_Dataset 29 | from data.klevr import KlevrDataset 30 | from data.scannet import RendererDataset 31 | 32 | def get_training_dataset(args, downsample=1.0): 33 | 34 | if args.dataset_name == "klevr": 35 | train_dataset = KlevrDataset( 36 | root_dir=args.klevr_path, 37 | split="train", 38 | max_len=-1, 39 | downSample=downsample, 40 | nb_views=args.nb_views, 41 | get_semantic=args.segmentation, 42 | ) 43 | train_sampler = None 44 | return train_dataset, train_sampler 45 | 46 | elif args.dataset_name == "scannet": 47 | if args.finetune: 48 | cfg = { 49 | 'val_database_name': args.fintune_scene, 50 | 'min_wn': args.nb_views, 'max_wn': args.nb_views+1, 51 | 'type2sample_weights': {'scannet_single': 1}, 52 | 'train_database_types': ['scannet_single'], 53 | } 54 | else: 55 | cfg = {'min_wn': args.nb_views, 'max_wn': args.nb_views+1} 56 | train_dataset = RendererDataset( 57 | root_dir=args.scannet_path, 58 | is_train=True, 59 | cfg=cfg, 60 | ) 61 | train_sampler = None 62 | return train_dataset, train_sampler 63 | elif args.dataset_name == "replica": 64 | cfg = {"resolution_type": "hr", "type2sample_weights": {"replica": 1}, "train_database_types": ['replica'], 'min_wn': args.nb_views, 'max_wn': args.nb_views+1} 65 | train_dataset = RendererDataset( 66 | cfg=cfg, 67 | root_dir=args.replica_path, 68 | is_train=True 69 | ) 70 | train_sampler = None 71 | return train_dataset, train_sampler 72 | 73 | train_datasets = [ 74 | DTU_Dataset( 75 | original_root_dir=args.dtu_path, 76 | preprocessed_root_dir=args.dtu_pre_path, 77 | split="train", 78 | max_len=-1, 79 | downSample=downsample, 80 | nb_views=args.nb_views, 81 | ), 82 | LLFF_Dataset( 83 | root_dir=args.ibrnet1_path, 84 | split="train", 85 | max_len=-1, 86 | downSample=downsample, 87 | nb_views=args.nb_views, 88 | imgs_folder_name="images", 89 | ), 90 | LLFF_Dataset( 91 | root_dir=args.ibrnet2_path, 92 | split="train", 93 | max_len=-1, 94 | downSample=downsample, 95 | nb_views=args.nb_views, 96 | imgs_folder_name="images", 97 | ), 98 | LLFF_Dataset( 99 | root_dir=args.llff_path, 100 | split="train", 101 | max_len=-1, 102 | downSample=downsample, 103 | nb_views=args.nb_views, 104 | imgs_folder_name="images_4", 105 | ), 106 | ] 107 | weights = [0.5, 0.22, 0.12, 0.16] 108 | 109 | train_weights_samples = [] 110 | for dataset, weight in zip(train_datasets, weights): 111 | num_samples = len(dataset) 112 | weight_each_sample = weight / num_samples 113 | train_weights_samples.extend([weight_each_sample] * num_samples) 114 | 115 | train_dataset = ConcatDataset(train_datasets) 116 | train_weights = torch.from_numpy(np.array(train_weights_samples)) 117 | train_sampler = WeightedRandomSampler(train_weights, len(train_weights)) 118 | 119 | return train_dataset, train_sampler 120 | 121 | 122 | def get_finetuning_dataset(args, downsample=1.0): 123 | if args.dataset_name == "dtu": 124 | train_dataset = DTU_Dataset( 125 | original_root_dir=args.dtu_path, 126 | preprocessed_root_dir=args.dtu_pre_path, 127 | split="train", 128 | max_len=-1, 129 | downSample=downsample, 130 | nb_views=args.nb_views, 131 | scene=args.scene, 132 | ) 133 | elif args.dataset_name == "llff": 134 | train_dataset = LLFF_Dataset( 135 | root_dir=args.llff_path, 136 | split="train", 137 | max_len=-1, 138 | downSample=downsample, 139 | nb_views=args.nb_views, 140 | 
scene=args.scene, 141 | imgs_folder_name="images_4", 142 | ) 143 | elif args.dataset_name == "nerf": 144 | train_dataset = NeRF_Dataset( 145 | root_dir=args.nerf_path, 146 | split="train", 147 | max_len=-1, 148 | downSample=downsample, 149 | nb_views=args.nb_views, 150 | scene=args.scene, 151 | ) 152 | 153 | train_sampler = None 154 | 155 | return train_dataset, train_sampler 156 | 157 | 158 | def get_validation_dataset(args, downsample=1.0): 159 | if args.scene == "None": 160 | max_len = 10 161 | else: 162 | max_len = -1 163 | 164 | if args.dataset_name == "dtu": 165 | val_dataset = DTU_Dataset( 166 | original_root_dir=args.dtu_path, 167 | preprocessed_root_dir=args.dtu_pre_path, 168 | split="val", 169 | max_len=max_len, 170 | downSample=downsample, 171 | nb_views=args.nb_views, 172 | scene=args.scene, 173 | ) 174 | elif args.dataset_name == "llff": 175 | val_dataset = LLFF_Dataset( 176 | root_dir=args.llff_test_path if not args.llff_test_path is None else args.llff_path, 177 | split="val", 178 | max_len=max_len, 179 | downSample=downsample, 180 | nb_views=args.nb_views, 181 | scene=args.scene, 182 | imgs_folder_name="images_4", 183 | ) 184 | elif args.dataset_name == "nerf": 185 | val_dataset = NeRF_Dataset( 186 | root_dir=args.nerf_path, 187 | split="val", 188 | max_len=max_len, 189 | downSample=downsample, 190 | nb_views=args.nb_views, 191 | scene=args.scene, 192 | ) 193 | elif args.dataset_name == "klevr": 194 | val_dataset = KlevrDataset( 195 | root_dir=args.klevr_path, 196 | split="val", 197 | max_len=max_len, 198 | downSample=downsample, 199 | nb_views=args.nb_views, 200 | get_semantic=args.segmentation, 201 | ) 202 | elif args.dataset_name == "scannet": 203 | val_set_list, val_set_names = [], [] 204 | if args.finetune: 205 | val_cfg = {'val_database_name': args.fintune_scene,'min_wn': args.nb_views, 'max_wn': args.nb_views+1} 206 | val_set = RendererDataset(cfg=val_cfg, is_train=False, root_dir=args.scannet_path) 207 | return val_set 208 | if isinstance(args.val_set_list, str): 209 | val_scenes = np.loadtxt(args.val_set_list, dtype=str).tolist() 210 | for name in val_scenes: 211 | val_cfg = {'val_database_name': name,'min_wn': args.nb_views, 'max_wn': args.nb_views+1} 212 | val_set = RendererDataset(cfg=val_cfg, is_train=False, root_dir=args.scannet_path) 213 | val_set_list.append(val_set) 214 | val_set_names.append(name) 215 | print(f'{name} val set len {len(val_set)}') 216 | # print("only one scene") 217 | # break 218 | val_dataset = ConcatDataset(val_set_list) 219 | elif args.dataset_name == "replica": 220 | val_set_list, val_set_names = [], [] 221 | if isinstance(args.val_set_list, str): 222 | val_scenes = np.loadtxt(args.val_set_list, dtype=str).tolist() 223 | for name in val_scenes: 224 | val_cfg = {'val_database_name': name,'min_wn': args.nb_views, 'max_wn': args.nb_views+1} 225 | val_set = RendererDataset(cfg=val_cfg, is_train=False, root_dir=args.replica_path) 226 | val_set_list.append(val_set) 227 | val_set_names.append(name) 228 | print(f'{name} val set len {len(val_set)}') 229 | val_dataset = ConcatDataset(val_set_list) 230 | else: 231 | raise NotImplementedError 232 | 233 | return val_dataset 234 | -------------------------------------------------------------------------------- /data/nerf.py: -------------------------------------------------------------------------------- 1 | # GeoNeRF is a generalizable NeRF model that renders novel views 2 | # without requiring per-scene optimization. 
This software is the 3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with 4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin, 5 | # and Francois Fleuret. 6 | 7 | # Copyright (c) 2022 ams International AG 8 | 9 | # This file is part of GeoNeRF. 10 | # GeoNeRF is free software: you can redistribute it and/or modify 11 | # it under the terms of the GNU General Public License version 3 as 12 | # published by the Free Software Foundation. 13 | 14 | # GeoNeRF is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | 19 | # You should have received a copy of the GNU General Public License 20 | # along with GeoNeRF. If not, see . 21 | 22 | # This file incorporates work covered by the following copyright and 23 | # permission notice: 24 | 25 | # MIT License 26 | 27 | # Copyright (c) 2021 apchenstu 28 | 29 | # Permission is hereby granted, free of charge, to any person obtaining a copy 30 | # of this software and associated documentation files (the "Software"), to deal 31 | # in the Software without restriction, including without limitation the rights 32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 33 | # copies of the Software, and to permit persons to whom the Software is 34 | # furnished to do so, subject to the following conditions: 35 | 36 | # The above copyright notice and this permission notice shall be included in all 37 | # copies or substantial portions of the Software. 38 | 39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 45 | # SOFTWARE. 
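# Note: the loader below converts Blender-convention poses (x right, y up,
# z backward) to OpenCV convention (x right, y down, z forward) by
# right-multiplying each camera-to-world matrix with a fixed axis flip, and it
# derives the pinhole focal length from the horizontal field of view stored in
# transforms_*.json. A minimal standalone sketch of both steps, with a
# hypothetical `camera_angle_x` and a placeholder pose `c2w_blender`:
#
#   import numpy as np
#   blender2opencv = np.diag([1.0, -1.0, -1.0, 1.0])   # flip the y and z axes
#   camera_angle_x = 0.69                              # example value only
#   focal = 0.5 * 800 / np.tan(0.5 * camera_angle_x)   # 800 px image width
#   c2w_blender = np.eye(4)                            # placeholder 4x4 pose
#   c2w_opencv = c2w_blender @ blender2opencv          # per-frame conversion
#
# The same blender2opencv matrix is defined verbatim in NeRF_Dataset.__init__.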
46 | 47 | from torch.utils.data import Dataset 48 | from torchvision import transforms as T 49 | 50 | import os 51 | import json 52 | import numpy as np 53 | from PIL import Image 54 | 55 | from utils.utils import get_nearest_pose_ids 56 | 57 | class NeRF_Dataset(Dataset): 58 | def __init__( 59 | self, 60 | root_dir, 61 | split, 62 | nb_views, 63 | downSample=1.0, 64 | max_len=-1, 65 | scene="None", 66 | ): 67 | self.root_dir = root_dir 68 | self.split = split 69 | self.nb_views = nb_views 70 | self.scene = scene 71 | 72 | self.downsample = downSample 73 | self.max_len = max_len 74 | 75 | self.img_wh = (int(800 * self.downsample), int(800 * self.downsample)) 76 | 77 | self.define_transforms() 78 | self.blender2opencv = np.array( 79 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] 80 | ) 81 | 82 | self.build_metas() 83 | 84 | def define_transforms(self): 85 | self.transform = T.ToTensor() 86 | 87 | self.src_transform = T.Compose( 88 | [ 89 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 90 | ] 91 | ) 92 | 93 | def build_metas(self): 94 | self.meta = {} 95 | with open( 96 | os.path.join(self.root_dir, self.scene, "transforms_train.json"), "r" 97 | ) as f: 98 | self.meta["train"] = json.load(f) 99 | 100 | with open( 101 | os.path.join(self.root_dir, self.scene, "transforms_test.json"), "r" 102 | ) as f: 103 | self.meta["val"] = json.load(f) 104 | 105 | w, h = self.img_wh 106 | 107 | # original focal length 108 | focal = 0.5 * 800 / np.tan(0.5 * self.meta["train"]["camera_angle_x"]) 109 | 110 | # modify focal length to match size self.img_wh 111 | focal *= self.img_wh[0] / 800 112 | 113 | self.near_far = np.array([2.0, 6.0]) 114 | 115 | self.image_paths = {"train": [], "val": []} 116 | self.c2ws = {"train": [], "val": []} 117 | self.w2cs = {"train": [], "val": []} 118 | self.intrinsics = {"train": [], "val": []} 119 | 120 | for frame in self.meta["train"]["frames"]: 121 | self.image_paths["train"].append( 122 | os.path.join(self.root_dir, self.scene, f"{frame['file_path']}.png") 123 | ) 124 | 125 | c2w = np.array(frame["transform_matrix"]) @ self.blender2opencv 126 | w2c = np.linalg.inv(c2w) 127 | self.c2ws["train"].append(c2w) 128 | self.w2cs["train"].append(w2c) 129 | 130 | intrinsic = np.array([[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]]) 131 | self.intrinsics["train"].append(intrinsic.copy()) 132 | 133 | self.c2ws["train"] = np.stack(self.c2ws["train"], axis=0) 134 | self.w2cs["train"] = np.stack(self.w2cs["train"], axis=0) 135 | self.intrinsics["train"] = np.stack(self.intrinsics["train"], axis=0) 136 | 137 | for frame in self.meta["val"]["frames"]: 138 | self.image_paths["val"].append( 139 | os.path.join(self.root_dir, self.scene, f"{frame['file_path']}.png") 140 | ) 141 | 142 | c2w = np.array(frame["transform_matrix"]) @ self.blender2opencv 143 | w2c = np.linalg.inv(c2w) 144 | self.c2ws["val"].append(c2w) 145 | self.w2cs["val"].append(w2c) 146 | 147 | intrinsic = np.array([[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]]) 148 | self.intrinsics["val"].append(intrinsic.copy()) 149 | 150 | self.c2ws["val"] = np.stack(self.c2ws["val"], axis=0) 151 | self.w2cs["val"] = np.stack(self.w2cs["val"], axis=0) 152 | self.intrinsics["val"] = np.stack(self.intrinsics["val"], axis=0) 153 | 154 | def __len__(self): 155 | return len(self.image_paths[self.split]) if self.max_len <= 0 else self.max_len 156 | 157 | def __getitem__(self, idx): 158 | target_frame = self.meta[self.split]["frames"][idx] 159 | c2w = np.array(target_frame["transform_matrix"]) @ 
self.blender2opencv 160 | w2c = np.linalg.inv(c2w) 161 | 162 | if self.split == "train": 163 | src_views = get_nearest_pose_ids( 164 | c2w, 165 | ref_poses=self.c2ws["train"], 166 | num_select=self.nb_views + 1, 167 | angular_dist_method="dist", 168 | )[1:] 169 | else: 170 | src_views = get_nearest_pose_ids( 171 | c2w, 172 | ref_poses=self.c2ws["train"], 173 | num_select=self.nb_views, 174 | angular_dist_method="dist", 175 | ) 176 | 177 | imgs, depths, depths_h, depths_aug = [], [], [], [] 178 | intrinsics, w2cs, c2ws, near_fars = [], [], [], [] 179 | affine_mats, affine_mats_inv = [], [] 180 | 181 | w, h = self.img_wh 182 | 183 | for vid in src_views: 184 | img_filename = self.image_paths["train"][vid] 185 | img = Image.open(img_filename) 186 | if img.size != (w, h): 187 | img = img.resize((w, h), Image.BICUBIC) 188 | 189 | img = self.transform(img) 190 | img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB 191 | imgs.append(self.src_transform(img)) 192 | 193 | intrinsic = self.intrinsics["train"][vid] 194 | intrinsics.append(intrinsic) 195 | 196 | w2c = self.w2cs["train"][vid] 197 | w2cs.append(w2c) 198 | c2ws.append(self.c2ws["train"][vid]) 199 | 200 | aff = [] 201 | aff_inv = [] 202 | for l in range(3): 203 | proj_mat_l = np.eye(4) 204 | intrinsic_temp = intrinsic.copy() 205 | intrinsic_temp[:2] = intrinsic_temp[:2] / (2**l) 206 | proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4] 207 | aff.append(proj_mat_l.copy()) 208 | aff_inv.append(np.linalg.inv(proj_mat_l)) 209 | aff = np.stack(aff, axis=-1) 210 | aff_inv = np.stack(aff_inv, axis=-1) 211 | 212 | affine_mats.append(aff) 213 | affine_mats_inv.append(aff_inv) 214 | 215 | near_fars.append(self.near_far) 216 | 217 | depths_h.append(np.zeros([h, w])) 218 | depths.append(np.zeros([h // 4, w // 4])) 219 | depths_aug.append(np.zeros([h // 4, w // 4])) 220 | 221 | ## Adding target data 222 | img_filename = self.image_paths[self.split][idx] 223 | img = Image.open(img_filename) 224 | if img.size != (w, h): 225 | img = img.resize((w, h), Image.BICUBIC) 226 | 227 | img = self.transform(img) # (4, h, w) 228 | img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB 229 | imgs.append(self.src_transform(img)) 230 | 231 | intrinsic = self.intrinsics[self.split][idx] 232 | intrinsics.append(intrinsic) 233 | 234 | w2c = self.w2cs[self.split][idx] 235 | w2cs.append(w2c) 236 | c2ws.append(self.c2ws[self.split][idx]) 237 | 238 | near_fars.append(self.near_far) 239 | 240 | depths_h.append(np.zeros([h, w])) 241 | depths.append(np.zeros([h // 4, w // 4])) 242 | depths_aug.append(np.zeros([h // 4, w // 4])) 243 | 244 | ## Stacking 245 | imgs = np.stack(imgs) 246 | depths = np.stack(depths) 247 | depths_h = np.stack(depths_h) 248 | depths_aug = np.stack(depths_aug) 249 | affine_mats = np.stack(affine_mats) 250 | affine_mats_inv = np.stack(affine_mats_inv) 251 | intrinsics = np.stack(intrinsics) 252 | w2cs = np.stack(w2cs) 253 | c2ws = np.stack(c2ws) 254 | near_fars = np.stack(near_fars) 255 | 256 | closest_idxs = [] 257 | for pose in c2ws[:-1]: 258 | closest_idxs.append( 259 | get_nearest_pose_ids( 260 | pose, ref_poses=c2ws[:-1], num_select=5, angular_dist_method="dist" 261 | ) 262 | ) 263 | closest_idxs = np.stack(closest_idxs, axis=0) 264 | 265 | sample = {} 266 | sample["images"] = imgs 267 | sample["depths"] = depths 268 | sample["depths_h"] = depths_h 269 | sample["depths_aug"] = depths_aug 270 | sample["w2cs"] = w2cs.astype("float32") 271 | sample["c2ws"] = c2ws.astype("float32") 272 | sample["near_fars"] = near_fars 273 | 
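        # Note on the two entries below: for every *source* view v and pyramid
        # level l in {0, 1, 2}, affine_mats[v, :, :, l] holds the 4x4
        # projection built in the loop above (level-l intrinsics, i.e. K halved
        # l times, composed with w2c), and affine_mats_inv holds its inverse;
        # the cost-volume builder in model/geo_reasoner.py uses these for
        # homography warping. The target view contributes an image and pose but
        # no affine matrices, so the leading dimension is nb_views rather than
        # nb_views + 1 (shapes inferred from the stacking above).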
sample["affine_mats"] = affine_mats 274 | sample["affine_mats_inv"] = affine_mats_inv 275 | sample["intrinsics"] = intrinsics.astype("float32") 276 | sample["closest_idxs"] = closest_idxs 277 | 278 | return sample 279 | -------------------------------------------------------------------------------- /data/llff.py: -------------------------------------------------------------------------------- 1 | # GeoNeRF is a generalizable NeRF model that renders novel views 2 | # without requiring per-scene optimization. This software is the 3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with 4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin, 5 | # and Francois Fleuret. 6 | 7 | # Copyright (c) 2022 ams International AG 8 | 9 | # This file is part of GeoNeRF. 10 | # GeoNeRF is free software: you can redistribute it and/or modify 11 | # it under the terms of the GNU General Public License version 3 as 12 | # published by the Free Software Foundation. 13 | 14 | # GeoNeRF is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | 19 | # You should have received a copy of the GNU General Public License 20 | # along with GeoNeRF. If not, see . 21 | 22 | # This file incorporates work covered by the following copyright and 23 | # permission notice: 24 | 25 | # MIT License 26 | 27 | # Copyright (c) 2021 apchenstu 28 | 29 | # Permission is hereby granted, free of charge, to any person obtaining a copy 30 | # of this software and associated documentation files (the "Software"), to deal 31 | # in the Software without restriction, including without limitation the rights 32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 33 | # copies of the Software, and to permit persons to whom the Software is 34 | # furnished to do so, subject to the following conditions: 35 | 36 | # The above copyright notice and this permission notice shall be included in all 37 | # copies or substantial portions of the Software. 38 | 39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 45 | # SOFTWARE. 46 | 47 | from torch.utils.data import Dataset 48 | from torchvision import transforms as T 49 | 50 | import os 51 | import glob 52 | import numpy as np 53 | from PIL import Image 54 | 55 | from utils.utils import get_nearest_pose_ids 56 | 57 | def normalize(v): 58 | return v / np.linalg.norm(v) 59 | 60 | 61 | def average_poses(poses): 62 | # 1. Compute the center 63 | center = poses[..., 3].mean(0) # (3) 64 | 65 | # 2. Compute the z axis 66 | z = normalize(poses[..., 2].mean(0)) # (3) 67 | 68 | # 3. Compute axis y' (no need to normalize as it's not the final output) 69 | y_ = poses[..., 1].mean(0) # (3) 70 | 71 | # 4. Compute the x axis 72 | x = normalize(np.cross(y_, z)) # (3) 73 | 74 | # 5. 
Compute the y axis (as z and x are normalized, y is already of norm 1) 75 | y = np.cross(z, x) # (3) 76 | 77 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 78 | 79 | return pose_avg 80 | 81 | 82 | def center_poses(poses, blender2opencv): 83 | pose_avg = average_poses(poses) # (3, 4) 84 | pose_avg_homo = np.eye(4) 85 | 86 | # convert to homogeneous coordinate for faster computation 87 | # by simply adding 0, 0, 0, 1 as the last row 88 | pose_avg_homo[:3] = pose_avg 89 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) 90 | 91 | # (N_images, 4, 4) homogeneous coordinate 92 | poses_homo = np.concatenate([poses, last_row], 1) 93 | 94 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4) 95 | poses_centered = poses_centered @ blender2opencv 96 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 97 | 98 | return poses_centered, np.linalg.inv(pose_avg_homo) @ blender2opencv 99 | 100 | 101 | class LLFF_Dataset(Dataset): 102 | def __init__( 103 | self, 104 | root_dir, 105 | split, 106 | nb_views, 107 | downSample=1.0, 108 | max_len=-1, 109 | scene="None", 110 | imgs_folder_name="images", 111 | ): 112 | self.root_dir = root_dir 113 | self.split = split 114 | self.nb_views = nb_views 115 | self.scene = scene 116 | self.imgs_folder_name = imgs_folder_name 117 | 118 | self.downsample = downSample 119 | self.max_len = max_len 120 | self.img_wh = (int(960 * self.downsample), int(720 * self.downsample)) 121 | 122 | self.define_transforms() 123 | self.blender2opencv = np.array( 124 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] 125 | ) 126 | 127 | self.build_metas() 128 | 129 | def define_transforms(self): 130 | self.transform = T.Compose( 131 | [ 132 | T.ToTensor(), 133 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 134 | ] 135 | ) 136 | 137 | def build_metas(self): 138 | if self.scene != "None": 139 | self.scans = [ 140 | os.path.basename(scan_dir) 141 | for scan_dir in sorted( 142 | glob.glob(os.path.join(self.root_dir, self.scene)) 143 | ) 144 | ] 145 | else: 146 | self.scans = [ 147 | os.path.basename(scan_dir) 148 | for scan_dir in sorted(glob.glob(os.path.join(self.root_dir, "*"))) 149 | ] 150 | 151 | self.meta = [] 152 | self.image_paths = {} 153 | self.near_far = {} 154 | self.id_list = {} 155 | self.closest_idxs = {} 156 | self.c2ws = {} 157 | self.w2cs = {} 158 | self.intrinsics = {} 159 | self.affine_mats = {} 160 | self.affine_mats_inv = {} 161 | for scan in self.scans: 162 | self.image_paths[scan] = sorted( 163 | glob.glob(os.path.join(self.root_dir, scan, self.imgs_folder_name, "*")) 164 | ) 165 | poses_bounds = np.load( 166 | os.path.join(self.root_dir, scan, "poses_bounds.npy") 167 | ) # (N_images, 17) 168 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5) 169 | bounds = poses_bounds[:, -2:] # (N_images, 2) 170 | 171 | # Step 1: rescale focal length according to training resolution 172 | H, W, focal = poses[0, :, -1] # original intrinsics, same for all images 173 | 174 | focal = [focal * self.img_wh[0] / W, focal * self.img_wh[1] / H] 175 | 176 | # Step 2: correct poses 177 | poses = np.concatenate( 178 | [poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1 179 | ) 180 | poses, _ = center_poses(poses, self.blender2opencv) 181 | # poses = poses @ self.blender2opencv 182 | 183 | # Step 3: correct scale so that the nearest depth is at a little more than 1.0 184 | near_original = bounds.min() 185 | scale_factor = near_original * 0.75 # 0.75 is the default parameter 
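            # Note: dividing both the depth bounds and the camera translations
            # by scale_factor = 0.75 * near_original rescales the scene so the
            # nearest depth lands at 1 / 0.75 ~= 1.33, i.e. "a little more than
            # 1.0" as the comment above says. Worked example (numbers assumed):
            #   near_original = 2.4  ->  scale_factor = 1.8
            #   nearest bound becomes 2.4 / 1.8 = 1.33, and the translations
            #   shrink by the same factor, keeping poses and depths consistent.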
186 | bounds /= scale_factor 187 | poses[..., 3] /= scale_factor 188 | 189 | self.near_far[scan] = bounds.astype('float32') 190 | 191 | num_viewpoint = len(self.image_paths[scan]) 192 | val_ids = [idx for idx in range(0, num_viewpoint, 8)] 193 | w, h = self.img_wh 194 | 195 | self.id_list[scan] = [] 196 | self.closest_idxs[scan] = [] 197 | self.c2ws[scan] = [] 198 | self.w2cs[scan] = [] 199 | self.intrinsics[scan] = [] 200 | self.affine_mats[scan] = [] 201 | self.affine_mats_inv[scan] = [] 202 | for idx in range(num_viewpoint): 203 | if ( 204 | (self.split == "val" and idx in val_ids) 205 | or ( 206 | self.split == "train" 207 | and self.scene != "None" 208 | and idx not in val_ids 209 | ) 210 | or (self.split == "train" and self.scene == "None") 211 | ): 212 | self.meta.append({"scan": scan, "target_idx": idx}) 213 | 214 | view_ids = get_nearest_pose_ids( 215 | poses[idx, :, :], 216 | ref_poses=poses[..., :], 217 | num_select=self.nb_views + 1, 218 | angular_dist_method="dist", 219 | ) 220 | 221 | self.id_list[scan].append(view_ids) 222 | 223 | closest_idxs = [] 224 | source_views = view_ids[1:] 225 | for vid in source_views: 226 | closest_idxs.append( 227 | get_nearest_pose_ids( 228 | poses[vid, :, :], 229 | ref_poses=poses[source_views], 230 | num_select=5, 231 | angular_dist_method="dist", 232 | ) 233 | ) 234 | self.closest_idxs[scan].append(np.stack(closest_idxs, axis=0)) 235 | 236 | c2w = np.eye(4).astype('float32') 237 | c2w[:3] = poses[idx] 238 | w2c = np.linalg.inv(c2w) 239 | self.c2ws[scan].append(c2w) 240 | self.w2cs[scan].append(w2c) 241 | 242 | intrinsic = np.array([[focal[0], 0, w / 2], [0, focal[1], h / 2], [0, 0, 1]]).astype('float32') 243 | self.intrinsics[scan].append(intrinsic) 244 | 245 | def __len__(self): 246 | return len(self.meta) if self.max_len <= 0 else self.max_len 247 | 248 | def __getitem__(self, idx): 249 | if self.split == "train" and self.scene == "None": 250 | noisy_factor = float(np.random.choice([1.0, 0.75, 0.5], 1)) 251 | close_views = int(np.random.choice([3, 4, 5], 1)) 252 | else: 253 | noisy_factor = 1.0 254 | close_views = 5 255 | 256 | scan = self.meta[idx]["scan"] 257 | target_idx = self.meta[idx]["target_idx"] 258 | 259 | view_ids = self.id_list[scan][target_idx] 260 | target_view = view_ids[0] 261 | src_views = view_ids[1:] 262 | view_ids = [vid for vid in src_views] + [target_view] 263 | 264 | closest_idxs = self.closest_idxs[scan][target_idx][:, :close_views] 265 | 266 | imgs, depths, depths_h, depths_aug = [], [], [], [] 267 | intrinsics, w2cs, c2ws, near_fars = [], [], [], [] 268 | affine_mats, affine_mats_inv = [], [] 269 | 270 | w, h = self.img_wh 271 | w, h = int(w * noisy_factor), int(h * noisy_factor) 272 | 273 | for vid in view_ids: 274 | img_filename = self.image_paths[scan][vid] 275 | img = Image.open(img_filename).convert("RGB") 276 | if img.size != (w, h): 277 | img = img.resize((w, h), Image.BICUBIC) 278 | img = self.transform(img) 279 | imgs.append(img) 280 | 281 | intrinsic = self.intrinsics[scan][vid].copy() 282 | intrinsic[:2] = intrinsic[:2] * noisy_factor 283 | intrinsics.append(intrinsic) 284 | 285 | w2c = self.w2cs[scan][vid] 286 | w2cs.append(w2c) 287 | c2ws.append(self.c2ws[scan][vid]) 288 | 289 | aff = [] 290 | aff_inv = [] 291 | for l in range(3): 292 | proj_mat_l = np.eye(4) 293 | intrinsic_temp = intrinsic.copy() 294 | intrinsic_temp[:2] = intrinsic_temp[:2] / (2**l) 295 | proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4] 296 | aff.append(proj_mat_l.copy()) 297 | aff_inv.append(np.linalg.inv(proj_mat_l)) 298 | aff = 
np.stack(aff, axis=-1) 299 | aff_inv = np.stack(aff_inv, axis=-1) 300 | 301 | affine_mats.append(aff) 302 | affine_mats_inv.append(aff_inv) 303 | 304 | near_fars.append(self.near_far[scan][vid]) 305 | 306 | depths_h.append(np.zeros([h, w])) 307 | depths.append(np.zeros([h // 4, w // 4])) 308 | depths_aug.append(np.zeros([h // 4, w // 4])) 309 | 310 | imgs = np.stack(imgs) 311 | depths = np.stack(depths) 312 | depths_h = np.stack(depths_h) 313 | depths_aug = np.stack(depths_aug) 314 | affine_mats = np.stack(affine_mats) 315 | affine_mats_inv = np.stack(affine_mats_inv) 316 | intrinsics = np.stack(intrinsics) 317 | w2cs = np.stack(w2cs) 318 | c2ws = np.stack(c2ws) 319 | near_fars = np.stack(near_fars) 320 | 321 | sample = {} 322 | sample["images"] = imgs 323 | sample["depths"] = depths 324 | sample["depths_h"] = depths_h 325 | sample["depths_aug"] = depths_aug 326 | sample["w2cs"] = w2cs 327 | sample["c2ws"] = c2ws 328 | sample["near_fars"] = near_fars 329 | sample["affine_mats"] = affine_mats 330 | sample["affine_mats_inv"] = affine_mats_inv 331 | sample["intrinsics"] = intrinsics 332 | sample["closest_idxs"] = closest_idxs 333 | 334 | return sample 335 | -------------------------------------------------------------------------------- /data/dtu.py: -------------------------------------------------------------------------------- 1 | # GeoNeRF is a generalizable NeRF model that renders novel views 2 | # without requiring per-scene optimization. This software is the 3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with 4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin, 5 | # and Francois Fleuret. 6 | 7 | # Copyright (c) 2022 ams International AG 8 | 9 | # This file is part of GeoNeRF. 10 | # GeoNeRF is free software: you can redistribute it and/or modify 11 | # it under the terms of the GNU General Public License version 3 as 12 | # published by the Free Software Foundation. 13 | 14 | # GeoNeRF is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | 19 | # You should have received a copy of the GNU General Public License 20 | # along with GeoNeRF. If not, see . 21 | 22 | # This file incorporates work covered by the following copyright and 23 | # permission notice: 24 | 25 | # MIT License 26 | 27 | # Copyright (c) 2021 apchenstu 28 | 29 | # Permission is hereby granted, free of charge, to any person obtaining a copy 30 | # of this software and associated documentation files (the "Software"), to deal 31 | # in the Software without restriction, including without limitation the rights 32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 33 | # copies of the Software, and to permit persons to whom the Software is 34 | # furnished to do so, subject to the following conditions: 35 | 36 | # The above copyright notice and this permission notice shall be included in all 37 | # copies or substantial portions of the Software. 38 | 39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 45 | # SOFTWARE. 46 | 47 | from torch.utils.data import Dataset 48 | from torchvision import transforms as T 49 | 50 | import os 51 | import cv2 52 | import numpy as np 53 | from PIL import Image 54 | 55 | from utils.utils import read_pfm, get_nearest_pose_ids 56 | 57 | class DTU_Dataset(Dataset): 58 | def __init__( 59 | self, 60 | original_root_dir, 61 | preprocessed_root_dir, 62 | split, 63 | nb_views, 64 | downSample=1.0, 65 | max_len=-1, 66 | scene="None", 67 | ): 68 | self.original_root_dir = original_root_dir 69 | self.preprocessed_root_dir = preprocessed_root_dir 70 | self.split = split 71 | self.scene = scene 72 | 73 | self.downSample = downSample 74 | self.scale_factor = 1.0 / 200 75 | self.interval_scale = 1.06 76 | self.max_len = max_len 77 | self.nb_views = nb_views 78 | 79 | self.build_metas() 80 | self.build_proj_mats() 81 | self.define_transforms() 82 | 83 | def define_transforms(self): 84 | self.transform = T.Compose( 85 | [ 86 | T.ToTensor(), 87 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 88 | ] 89 | ) 90 | 91 | def build_metas(self): 92 | self.metas = [] 93 | with open(f"configs/lists/dtu_{self.split}_all.txt") as f: 94 | self.scans = [line.rstrip() for line in f.readlines()] 95 | if self.scene != "None": 96 | self.scans = [self.scene] 97 | 98 | # light conditions 2-5 for training 99 | # light condition 3 for testing (the brightest?) 100 | light_idxs = ( 101 | [3] if "train" != self.split or self.scene != "None" else range(2, 5) 102 | ) 103 | 104 | self.id_list = [] 105 | 106 | if self.split == "train": 107 | if self.scene == "None": 108 | pair_file = f"configs/lists/dtu_pairs.txt" 109 | else: 110 | pair_file = f"configs/lists/dtu_pairs_ft.txt" 111 | else: 112 | pair_file = f"configs/lists/dtu_pairs_val.txt" 113 | 114 | for scan in self.scans: 115 | with open(pair_file) as f: 116 | num_viewpoint = int(f.readline()) 117 | for _ in range(num_viewpoint): 118 | ref_view = int(f.readline().rstrip()) 119 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 120 | for light_idx in light_idxs: 121 | self.metas += [(scan, light_idx, ref_view, src_views)] 122 | self.id_list.append([ref_view] + src_views) 123 | 124 | self.id_list = np.unique(self.id_list) 125 | self.build_remap() 126 | 127 | def build_proj_mats(self): 128 | near_fars, intrinsics, world2cams, cam2worlds = [], [], [], [] 129 | for vid in self.id_list: 130 | proj_mat_filename = os.path.join( 131 | self.preprocessed_root_dir, f"Cameras/train/{vid:08d}_cam.txt" 132 | ) 133 | intrinsic, extrinsic, near_far = self.read_cam_file(proj_mat_filename) 134 | intrinsic[:2] *= 4 135 | extrinsic[:3, 3] *= self.scale_factor 136 | 137 | intrinsic[:2] = intrinsic[:2] * self.downSample 138 | intrinsics += [intrinsic.copy()] 139 | 140 | near_fars += [near_far] 141 | world2cams += [extrinsic] 142 | cam2worlds += [np.linalg.inv(extrinsic)] 143 | 144 | self.near_fars, self.intrinsics = np.stack(near_fars), np.stack(intrinsics) 145 | self.world2cams, self.cam2worlds = np.stack(world2cams), np.stack(cam2worlds) 146 | 147 | def read_cam_file(self, filename): 148 | with open(filename) as f: 149 | lines = [line.rstrip() for line in f.readlines()] 150 | # extrinsics: line [1,5), 4x4 matrix 151 | extrinsics = np.fromstring(" 
".join(lines[1:5]), dtype=np.float32, sep=" ") 152 | extrinsics = extrinsics.reshape((4, 4)) 153 | # intrinsics: line [7-10), 3x3 matrix 154 | intrinsics = np.fromstring(" ".join(lines[7:10]), dtype=np.float32, sep=" ") 155 | intrinsics = intrinsics.reshape((3, 3)) 156 | # depth_min & depth_interval: line 11 157 | depth_min, depth_interval = lines[11].split() 158 | depth_min = float(depth_min) * self.scale_factor 159 | depth_max = depth_min + float(depth_interval) * 192 * self.interval_scale * self.scale_factor 160 | 161 | intrinsics[0, 2] = intrinsics[0, 2] + 80.0 / 4.0 162 | intrinsics[1, 2] = intrinsics[1, 2] + 44.0 / 4.0 163 | intrinsics[:2] = intrinsics[:2] 164 | 165 | return intrinsics, extrinsics, [depth_min, depth_max] 166 | 167 | def read_depth(self, filename, far_bound, noisy_factor=1.0): 168 | depth_h = self.scale_factor * np.array( 169 | read_pfm(filename)[0], dtype=np.float32 170 | ) 171 | depth_h = cv2.resize( 172 | depth_h, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_NEAREST 173 | ) 174 | 175 | depth_h = cv2.resize( 176 | depth_h, 177 | None, 178 | fx=self.downSample * noisy_factor, 179 | fy=self.downSample * noisy_factor, 180 | interpolation=cv2.INTER_NEAREST, 181 | ) 182 | 183 | ## Exclude points beyond the bounds 184 | depth_h[depth_h > far_bound * 0.95] = 0.0 185 | 186 | depth = {} 187 | for l in range(3): 188 | depth[f"level_{l}"] = cv2.resize( 189 | depth_h, 190 | None, 191 | fx=1.0 / (2**l), 192 | fy=1.0 / (2**l), 193 | interpolation=cv2.INTER_NEAREST, 194 | ) 195 | 196 | if self.split == "train": 197 | cutout = np.ones_like(depth[f"level_2"]) 198 | h0 = int(np.random.randint(0, high=cutout.shape[0] // 5, size=1)) 199 | h1 = int( 200 | np.random.randint( 201 | 4 * cutout.shape[0] // 5, high=cutout.shape[0], size=1 202 | ) 203 | ) 204 | w0 = int(np.random.randint(0, high=cutout.shape[1] // 5, size=1)) 205 | w1 = int( 206 | np.random.randint( 207 | 4 * cutout.shape[1] // 5, high=cutout.shape[1], size=1 208 | ) 209 | ) 210 | cutout[h0:h1, w0:w1] = 0 211 | depth_aug = depth[f"level_2"] * cutout 212 | else: 213 | depth_aug = depth[f"level_2"].copy() 214 | 215 | return depth, depth_h, depth_aug 216 | 217 | def build_remap(self): 218 | self.remap = np.zeros(np.max(self.id_list) + 1).astype("int") 219 | for i, item in enumerate(self.id_list): 220 | self.remap[item] = i 221 | 222 | def __len__(self): 223 | return len(self.metas) if self.max_len <= 0 else self.max_len 224 | 225 | def __getitem__(self, idx): 226 | if self.split == "train" and self.scene == "None": 227 | noisy_factor = float(np.random.choice([1.0, 0.5], 1)) 228 | close_views = int(np.random.choice([3, 4, 5], 1)) 229 | else: 230 | noisy_factor = 1.0 231 | close_views = 5 232 | 233 | scan, light_idx, target_view, src_views = self.metas[idx] 234 | view_ids = src_views[:self.nb_views] + [target_view] 235 | 236 | affine_mats, affine_mats_inv = [], [] 237 | imgs, depths_h, depths_aug = [], [], [] 238 | depths = {"level_0": [], "level_1": [], "level_2": []} 239 | intrinsics, w2cs, c2ws, near_fars = [], [], [], [] 240 | 241 | for vid in view_ids: 242 | # Note that the id in image file names is from 1 to 49 (not 0~48) 243 | img_filename = os.path.join( 244 | self.original_root_dir, 245 | f"Rectified/{scan}/rect_{vid + 1:03d}_{light_idx}_r5000.png", 246 | ) 247 | depth_filename = os.path.join( 248 | self.preprocessed_root_dir, f"Depths/{scan}/depth_map_{vid:04d}.pfm" 249 | ) 250 | img = Image.open(img_filename) 251 | img_wh = np.round( 252 | np.array(img.size) / 2.0 * self.downSample * noisy_factor 253 | ).astype("int") 
254 | img = img.resize(img_wh, Image.BICUBIC) 255 | img = self.transform(img) 256 | imgs += [img] 257 | 258 | index_mat = self.remap[vid] 259 | 260 | intrinsic = self.intrinsics[index_mat].copy() 261 | intrinsic[:2] = intrinsic[:2] * noisy_factor 262 | intrinsics.append(intrinsic) 263 | 264 | w2c = self.world2cams[index_mat] 265 | w2cs.append(w2c) 266 | c2ws.append(self.cam2worlds[index_mat]) 267 | 268 | aff = [] 269 | aff_inv = [] 270 | for l in range(3): 271 | proj_mat_l = np.eye(4) 272 | intrinsic_temp = intrinsic.copy() 273 | intrinsic_temp[:2] = intrinsic_temp[:2] / (2**l) 274 | proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4] 275 | aff.append(proj_mat_l.copy()) 276 | aff_inv.append(np.linalg.inv(proj_mat_l)) 277 | aff = np.stack(aff, axis=-1) 278 | aff_inv = np.stack(aff_inv, axis=-1) 279 | 280 | affine_mats.append(aff) 281 | affine_mats_inv.append(aff_inv) 282 | 283 | near_far = self.near_fars[index_mat] 284 | 285 | depth, depth_h, depth_aug = self.read_depth( 286 | depth_filename, near_far[1], noisy_factor 287 | ) 288 | 289 | depths["level_0"].append(depth["level_0"]) 290 | depths["level_1"].append(depth["level_1"]) 291 | depths["level_2"].append(depth["level_2"]) 292 | depths_h.append(depth_h) 293 | depths_aug.append(depth_aug) 294 | 295 | near_fars.append(near_far) 296 | 297 | imgs = np.stack(imgs) 298 | depths_h, depths_aug = np.stack(depths_h), np.stack(depths_aug) 299 | depths["level_0"] = np.stack(depths["level_0"]) 300 | depths["level_1"] = np.stack(depths["level_1"]) 301 | depths["level_2"] = np.stack(depths["level_2"]) 302 | affine_mats, affine_mats_inv = np.stack(affine_mats), np.stack(affine_mats_inv) 303 | intrinsics = np.stack(intrinsics) 304 | w2cs = np.stack(w2cs) 305 | c2ws = np.stack(c2ws) 306 | near_fars = np.stack(near_fars) 307 | 308 | closest_idxs = [] 309 | for pose in c2ws[:-1]: 310 | closest_idxs.append( 311 | get_nearest_pose_ids( 312 | pose, 313 | ref_poses=c2ws[:-1], 314 | num_select=close_views, 315 | angular_dist_method="dist", 316 | ) 317 | ) 318 | closest_idxs = np.stack(closest_idxs, axis=0) 319 | 320 | sample = {} 321 | sample["images"] = imgs 322 | sample["depths"] = depths 323 | sample["depths_h"] = depths_h 324 | sample["depths_aug"] = depths_aug 325 | sample["w2cs"] = w2cs 326 | sample["c2ws"] = c2ws 327 | sample["near_fars"] = near_fars 328 | sample["intrinsics"] = intrinsics 329 | sample["affine_mats"] = affine_mats 330 | sample["affine_mats_inv"] = affine_mats_inv 331 | sample["closest_idxs"] = closest_idxs 332 | 333 | return sample 334 | -------------------------------------------------------------------------------- /data/klevr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import json 4 | import numpy as np 5 | import os 6 | from glob import glob as glob 7 | from PIL import Image 8 | from torchvision import transforms as T 9 | 10 | from utils.klevr_utils import from_position_and_quaternion, scale_rays, calculate_near_and_far 11 | from utils.utils import read_pfm, get_nearest_pose_ids, get_rays 12 | 13 | # Nesf Klevr 14 | class KlevrDataset(Dataset): 15 | def __init__( 16 | self, 17 | root_dir, 18 | nb_views, 19 | split='train', 20 | get_rgb=True, 21 | get_semantic=False, 22 | max_len=-1, 23 | scene=None, 24 | downSample=1.0 25 | ): 26 | 27 | # super().__init__() 28 | self.root_dir = root_dir 29 | self.get_rgb = get_rgb 30 | self.get_semantic = get_semantic 31 | self.nb_views = nb_views 32 | self.max_len = max_len 33 | self.scene = scene 34 
| self.downSample = downSample 35 | self.blender2opencv = np.array( 36 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] 37 | ) 38 | 39 | # This is hard coded for Klevr as all scans are in the same range (according to the metadata) 40 | self.scene_boundaries = np.array([[-3.1,-3.1,-0.1],[3.1,3.1,3.1]]) 41 | if split == 'train': 42 | self.split = split 43 | elif split =='val': 44 | self.split = 'test' 45 | else: 46 | raise KeyError("only train/val split works") 47 | 48 | self.define_transforms() 49 | self.read_meta() 50 | self.white_back = True 51 | self.buid_proj_mats() 52 | 53 | def define_transforms(self): 54 | # this normalize is for imagenet pretrained resnet for CasMVSNet pretrained weights 55 | self.transform = T.Compose( 56 | [ 57 | T.ToTensor(), 58 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 59 | ] 60 | ) 61 | 62 | def read_meta(self): 63 | self.metas = [] 64 | # for now, we only use the first 10 scans to train, 11~20 scans to test 65 | if self.split == 'train': 66 | self.scans = sorted(glob(os.path.join(self.root_dir, '*')))[:100] 67 | else: 68 | self.scans = sorted(glob(os.path.join(self.root_dir, '*')))[100:120] 69 | 70 | # remap the scan_idx to the scan name 71 | self.scan_idx_to_name = {} 72 | 73 | # record ids that being used of each scan 74 | self.id_list = [] 75 | 76 | # read the pair file, here use the same pair list as dtu dataset; 77 | # could be 6x more pairs since each klevr scene have 300 views, dtu have 50 views 78 | if self.split == "train": 79 | if self.scene == "None": 80 | pair_file = f"configs/lists/dtu_pairs.txt" 81 | else: 82 | pair_file = f"configs/lists/dtu_pairs_ft.txt" 83 | else: 84 | pair_file = f"configs/lists/dtu_pairs_val.txt" 85 | 86 | for scan_idx, meta_filename in enumerate(self.scans): 87 | with open(pair_file) as f: 88 | scan = meta_filename.split('/')[-1] 89 | self.scan_idx_to_name[scan_idx] = scan 90 | num_viewpoint = int(f.readline()) 91 | for _ in range(num_viewpoint): 92 | ref_view = int(f.readline().rstrip()) 93 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 94 | self.metas += [(scan_idx, ref_view, src_views[:self.nb_views])] 95 | self.id_list.append([ref_view] + src_views) 96 | 97 | self.id_list = np.unique(self.id_list) 98 | self.build_remap() 99 | 100 | def build_remap(self): 101 | self.remap = np.zeros(np.max(self.id_list) + 1).astype("int") 102 | for i, item in enumerate(self.id_list): 103 | self.remap[item] = i 104 | 105 | def buid_proj_mats(self): 106 | # maybe not calculate near_far now. 
Do it when creating the rays 107 | self.near_fars, self.intrinsics, self.world2cams, self.cam2worlds = None, {}, {}, {} 108 | for scan_idx, meta_fileprefix in enumerate(self.scans): 109 | meta_filename = os.path.join(meta_fileprefix, 'metadata.json') 110 | intrinsic, world2cam, cam2world = self.read_cam_file(meta_filename, scan_idx) 111 | self.intrinsics[scan_idx], self.world2cams[scan_idx], self.cam2worlds[scan_idx] = np.array(intrinsic), np.array(world2cam), np.array(cam2world) 112 | 113 | def read_cam_file(self, filename, scan_idx): 114 | ''' 115 | read the metadata file and return the near/far, intrinsic, world2cam, cam2world 116 | filename(str): the metadata file 117 | scan_idx(int): the index of the scan 118 | 119 | return: 120 | intrinsic: the intrinsic of the scan [N,3,3] 121 | world2cam: the world2cam of the scan [N,4,4] 122 | cam2world: the cam2world of the scan [N,4,4] 123 | ''' 124 | intrinsic, world2cam, cam2world = [], [], [] 125 | with open(filename, 'r') as f: 126 | meta = json.load(f) 127 | w, h = meta['metadata']['width'], meta['metadata']['height'] 128 | focal = meta['camera']['focal_length']*w/meta['camera']['sensor_width'] 129 | 130 | camera_positions = np.array(meta['camera']['positions']) 131 | camera_quaternions = np.array(meta['camera']['quaternions']) 132 | # calculate camera pose of each frame that will be used (in this scan idx) 133 | for frame_id in self.id_list: 134 | c2w = from_position_and_quaternion(camera_positions[frame_id], camera_quaternions[frame_id], False).tolist() 135 | # not sure 136 | c2w = np.array(c2w) @ self.blender2opencv 137 | cam2world += [c2w.tolist()] 138 | world2cam += [np.linalg.inv(c2w).tolist()] 139 | intrinsic += [[[focal, 0, w/2], [0, focal, h/2], [0, 0, 1]]] 140 | 141 | return intrinsic, world2cam, cam2world 142 | 143 | def read_depth(self, filename, far_bound, noisy_factor=1.0): 144 | # read depth image, currently not using it 145 | depth_h = Image.open(filename) 146 | depth_wh = np.round(np.array(depth_h.size) * self.downSample * noisy_factor).astype(np.int32) 147 | 148 | # originally NeRF use Image.Resampling.LANCZOS, not sure if BICUBIC is better 149 | depth_h = depth_h.resize(depth_wh, Image.BILINEAR) 150 | 151 | ## Exclude points beyond the bounds 152 | depth_h_filter = np.array(depth_h) 153 | depth_h_filter[depth_h_filter > far_bound * 0.95] = 0.0 154 | 155 | 156 | depth = {} 157 | for l in range(3): 158 | depth_wh = np.round(np.array(depth_h.size) / (2**l)).astype(np.int32) 159 | depth[f"level_{l}"] = np.array(depth_h.resize(depth_wh, Image.BILINEAR)) 160 | depth[f"level_{l}"][depth[f"level_{l}"] > far_bound * 0.95] = 0.0 161 | 162 | if self.split == "train": 163 | cutout = np.ones_like(depth[f"level_2"]) 164 | h0 = int(np.random.randint(0, high=cutout.shape[0] // 5, size=1)) 165 | h1 = int( 166 | np.random.randint( 167 | 4 * cutout.shape[0] // 5, high=cutout.shape[0], size=1 168 | ) 169 | ) 170 | w0 = int(np.random.randint(0, high=cutout.shape[1] // 5, size=1)) 171 | w1 = int( 172 | np.random.randint( 173 | 4 * cutout.shape[1] // 5, high=cutout.shape[1], size=1 174 | ) 175 | ) 176 | cutout[h0:h1, w0:w1] = 0 177 | depth_aug = depth[f"level_2"] * cutout 178 | else: 179 | depth_aug = depth[f"level_2"].copy() 180 | 181 | return depth, depth_h_filter, depth_aug 182 | 183 | 184 | def __len__(self): 185 | return len(self.metas) if self.max_len <= 0 else self.max_len 186 | 187 | def __getitem__(self, idx): 188 | # haven't used the depth image and noisy factor yet 189 | if self.split == "train" and self.scene == "None": 190 | 
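            # Note: when training across scenes, the GeoNeRF-style loaders
            # randomly jitter the input resolution (noisy_factor) and the
            # number of neighbouring views given to the renderer (close_views)
            # as augmentation; outside this branch both stay fixed. As the
            # comment above says, noisy_factor is not actually applied to the
            # Klevr images below (read_depth is called with noisy_factor=1),
            # so only close_views has an effect in this dataset.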
noisy_factor = float(np.random.choice([1.0, 0.5], 1)) 191 | close_views = int(np.random.choice([3, 4, 5], 1)) 192 | else: 193 | noisy_factor = 1.0 194 | close_views = 5 195 | 196 | scan_idx, ref_id, src_ids = self.metas[idx] 197 | 198 | # notice that the ref_id is in the last position 199 | view_ids = src_ids + [ref_id] 200 | 201 | affine_mats, affine_mats_inv = [], [] 202 | imgs, depths_h, depths_aug = [], [], [] 203 | depths = {"level_0": [], "level_1": [], "level_2": []} 204 | semantics = [] 205 | intrinsics, w2cs, c2ws = [], [], [] 206 | 207 | # intrinsic now every frame has its own intrinsic, but actually it is the same for all frames in a scan 208 | # # every scan has only one intrinsic, here actually is focal 209 | # intrinsic = self.intrinsics[scan_idx] 210 | 211 | for vid in view_ids: 212 | img_filename = os.path.join(self.root_dir, self.scan_idx_to_name[scan_idx],f'rgba_{vid:05d}.png') 213 | depth_filename = os.path.join(self.root_dir, self.scan_idx_to_name[scan_idx],f'depth_{vid:05d}.tiff') 214 | 215 | img = Image.open(img_filename) 216 | img_wh = np.round(np.array(img.size)*self.downSample).astype(np.int32) 217 | 218 | # originally NeRF use Image.Resampling.LANCZOS, not sure if BICUBIC is better 219 | img = img.resize(img_wh, Image.LANCZOS) 220 | # discard the alpha channel, only use rgb. Maybe need "valid_mask = img[-1]>0" 221 | img = self.transform(np.array(img)[:,:,:3]) 222 | imgs += [img] 223 | 224 | # semantic part 225 | if self.get_semantic: 226 | semantic_filename = os.path.join(self.root_dir, self.scan_idx_to_name[scan_idx],f'segmentation_{vid:05d}.png') 227 | semantic = Image.open(semantic_filename) 228 | semantic = semantic.resize(img_wh, Image.LANCZOS) 229 | semantic = torch.from_numpy(np.array(semantic)).long() #(h,w) 230 | semantics += [semantic] 231 | 232 | index = self.remap[vid] 233 | # # debug 234 | # print("vid: ", vid, "index: ", index, "scan_idx: ", scan_idx) 235 | # print("self.remap[scan_idx]: ", self.remap[scan_idx]) 236 | # print("self.cam2worlds[scan_idx]: ", np.array(self.cam2worlds[scan_idx]).shape) 237 | # raise Exception("debug") 238 | c2ws.append(self.cam2worlds[scan_idx][index]) 239 | 240 | w2c = self.world2cams[scan_idx][index] 241 | w2cs.append(w2c) 242 | 243 | intrinsic = self.intrinsics[scan_idx][index] 244 | intrinsics.append(intrinsic) 245 | 246 | aff = [] 247 | aff_inv = [] 248 | # if using the depth image, there should be for l in range(3) 249 | for l in range(3): 250 | proj_mat_l = np.eye(4) 251 | intrinsic_temp = intrinsic.copy() 252 | intrinsic_temp[:2] = intrinsic_temp[:2]/(2**l) 253 | proj_mat_l[:3,:4] = intrinsic_temp @ w2c[:3,:4] 254 | aff.append(proj_mat_l) 255 | aff_inv.append(np.linalg.inv(proj_mat_l)) 256 | 257 | aff = np.stack(aff, axis=-1) 258 | aff_inv = np.stack(aff_inv, axis=-1) 259 | 260 | affine_mats.append(aff) 261 | affine_mats_inv.append(aff_inv) 262 | # currently hardcode far bound to be 17 263 | depth, depth_h, depth_aug = self.read_depth(depth_filename, far_bound=17, noisy_factor=1) 264 | 265 | depths["level_0"].append(depth["level_0"]) 266 | depths["level_1"].append(depth["level_1"]) 267 | depths["level_2"].append(depth["level_2"]) 268 | depths_h.append(depth_h) 269 | depths_aug.append(depth_aug) 270 | 271 | 272 | imgs = np.stack(imgs) 273 | semantics = np.stack(semantics) 274 | affine_mats = np.stack(affine_mats) 275 | affine_mats_inv = np.stack(affine_mats_inv) 276 | intrinsics = np.stack(intrinsics) 277 | w2cs = np.stack(w2cs) 278 | c2ws = np.stack(c2ws) 279 | depths_h, depths_aug = np.stack(depths_h), 
np.stack(depths_aug) 280 | depths["level_0"] = np.stack(depths["level_0"]) 281 | depths["level_1"] = np.stack(depths["level_1"]) 282 | depths["level_2"] = np.stack(depths["level_2"]) 283 | 284 | 285 | close_idxs = [] 286 | for pose in c2ws[:-1]: 287 | close_idxs.append( 288 | get_nearest_pose_ids( 289 | pose, 290 | c2ws[:-1], 291 | close_views, 292 | angular_dist_method="dist" 293 | ) 294 | ) 295 | close_idxs = np.stack(close_idxs, axis=0) 296 | self.near_fars = [] 297 | for i in range(imgs.shape[0]): 298 | rays_orig, rays_dir, _ = get_rays( 299 | H=imgs.shape[2], 300 | W=imgs.shape[3], 301 | # hard code to cuda 302 | intrinsics_target=torch.tensor(intrinsics[i].astype(imgs.dtype)), 303 | c2w_target=torch.tensor(c2ws[i].astype(imgs.dtype)), 304 | # train=False with chunk=-1, will return all rays of the image 305 | train=False, 306 | ) 307 | near, far = calculate_near_and_far(rays_orig, rays_dir, bbox_min=[-3.1, -3.1, -0.1], bbox_max=[3.1, 3.1, 3.1]) 308 | near = near.min().item() 309 | far = far.max().item() 310 | self.near_fars.append([near, far]) 311 | 312 | sample = {} 313 | sample['images'] = imgs 314 | if self.get_semantic: 315 | sample['semantics'] = semantics 316 | sample['w2cs'] = w2cs.astype(imgs.dtype) 317 | sample['c2ws'] = c2ws.astype(imgs.dtype) 318 | sample['intrinsics'] = intrinsics.astype(imgs.dtype) 319 | sample['affine_mats'] = affine_mats.astype(imgs.dtype) 320 | sample['affine_mats_inv'] = affine_mats_inv.astype(imgs.dtype) 321 | sample['closest_idxs'] = close_idxs 322 | # depth aug seems to be a must to give, but if set to None doesn't matter (the use_depth is False) 323 | sample['depths_aug'] = depths_aug 324 | sample['depths_h'] = depths_h 325 | # depth should be a dict, but if set to None doesn't matter (the use_depth is False) (original code still use depth loss) 326 | sample['depths'] = depths 327 | # near_fars is now just using constant value 328 | sample['near_fars'] = np.array(self.near_fars).astype(imgs.dtype) 329 | return sample -------------------------------------------------------------------------------- /model/geo_reasoner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.checkpoint import checkpoint 5 | 6 | from model.UNet import UNet, smp_UNet 7 | from utils.utils import homo_warp 8 | from inplace_abn import InPlaceABN 9 | 10 | 11 | def get_depth_values(current_depth, n_depths, depth_interval): 12 | depth_min = torch.clamp_min(current_depth - n_depths / 2 * depth_interval, 1e-3) 13 | depth_values = ( 14 | depth_min 15 | + depth_interval 16 | * torch.arange( 17 | 0, n_depths, device=current_depth.device, dtype=current_depth.dtype 18 | )[None, :, None, None] 19 | ) 20 | return depth_values 21 | 22 | 23 | class ConvBnReLU(nn.Module): 24 | def __init__( 25 | self, 26 | in_channels, 27 | out_channels, 28 | kernel_size=3, 29 | stride=1, 30 | pad=1, 31 | norm_act=InPlaceABN, 32 | ): 33 | super(ConvBnReLU, self).__init__() 34 | self.conv = nn.Conv2d( 35 | in_channels, 36 | out_channels, 37 | kernel_size, 38 | stride=stride, 39 | padding=pad, 40 | bias=False, 41 | ) 42 | self.bn = norm_act(out_channels) 43 | 44 | def forward(self, x): 45 | return self.bn(self.conv(x)) 46 | 47 | 48 | class ConvBnReLU3D(nn.Module): 49 | def __init__( 50 | self, 51 | in_channels, 52 | out_channels, 53 | kernel_size=3, 54 | stride=1, 55 | pad=1, 56 | norm_act=InPlaceABN, 57 | ): 58 | super(ConvBnReLU3D, self).__init__() 59 | self.conv = nn.Conv3d( 
60 | in_channels, 61 | out_channels, 62 | kernel_size, 63 | stride=stride, 64 | padding=pad, 65 | bias=False, 66 | ) 67 | self.bn = norm_act(out_channels) 68 | 69 | def forward(self, x): 70 | return self.bn(self.conv(x)) 71 | 72 | 73 | class FeatureNet(nn.Module): 74 | def __init__(self, norm_act=InPlaceABN): 75 | super(FeatureNet, self).__init__() 76 | 77 | self.conv0 = nn.Sequential( 78 | ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act), 79 | ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act), 80 | ) 81 | 82 | self.conv1 = nn.Sequential( 83 | ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act), 84 | ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act), 85 | ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act), 86 | ) 87 | 88 | self.conv2 = nn.Sequential( 89 | ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act), 90 | ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act), 91 | ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act), 92 | ) 93 | 94 | self.toplayer = nn.Conv2d(32, 32, 1) 95 | self.lat1 = nn.Conv2d(16, 32, 1) 96 | self.lat0 = nn.Conv2d(8, 32, 1) 97 | 98 | # to reduce channel size of the outputs from FPN 99 | self.smooth1 = nn.Conv2d(32, 16, 3, padding=1) 100 | self.smooth0 = nn.Conv2d(32, 8, 3, padding=1) 101 | 102 | def _upsample_add(self, x, y): 103 | return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + y 104 | 105 | def forward(self, x, dummy=None): 106 | # x: (B, 3, H, W) 107 | conv0 = self.conv0(x) # (B, 8, H, W) 108 | conv1 = self.conv1(conv0) # (B, 16, H//2, W//2) 109 | conv2 = self.conv2(conv1) # (B, 32, H//4, W//4) 110 | feat2 = self.toplayer(conv2) # (B, 32, H//4, W//4) 111 | feat1 = self._upsample_add(feat2, self.lat1(conv1)) # (B, 32, H//2, W//2) 112 | feat0 = self._upsample_add(feat1, self.lat0(conv0)) # (B, 32, H, W) 113 | 114 | # reduce output channels 115 | feat1 = self.smooth1(feat1) # (B, 16, H//2, W//2) 116 | feat0 = self.smooth0(feat0) # (B, 8, H, W) 117 | 118 | feats = {"level_0": feat0, "level_1": feat1, "level_2": feat2} 119 | 120 | return feats 121 | 122 | 123 | class CostRegNet(nn.Module): 124 | def __init__(self, in_channels, norm_act=InPlaceABN): 125 | super(CostRegNet, self).__init__() 126 | self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act) 127 | 128 | self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act) 129 | self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act) 130 | 131 | self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act) 132 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act) 133 | 134 | self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act) 135 | self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act) 136 | 137 | self.conv7 = nn.Sequential( 138 | nn.ConvTranspose3d( 139 | 64, 32, 3, padding=1, output_padding=1, stride=2, bias=False 140 | ), 141 | norm_act(32), 142 | ) 143 | 144 | self.conv9 = nn.Sequential( 145 | nn.ConvTranspose3d( 146 | 32, 16, 3, padding=1, output_padding=1, stride=2, bias=False 147 | ), 148 | norm_act(16), 149 | ) 150 | 151 | self.conv11 = nn.Sequential( 152 | nn.ConvTranspose3d( 153 | 16, 8, 3, padding=1, output_padding=1, stride=2, bias=False 154 | ), 155 | norm_act(8), 156 | ) 157 | 158 | self.br1 = ConvBnReLU3D(8, 8, norm_act=norm_act) 159 | self.br2 = ConvBnReLU3D(8, 8, norm_act=norm_act) 160 | 161 | self.prob = nn.Conv3d(8, 1, 3, stride=1, padding=1) 162 | 163 | def forward(self, x): 164 | if x.shape[-2] % 8 != 0 or x.shape[-1] % 8 != 0: 165 | pad_h = 8 * (x.shape[-2] // 8 + 1) - x.shape[-2] 166 | pad_w = 8 * (x.shape[-1] // 8 + 1) - x.shape[-1] 167 | x = F.pad(x, (0, pad_w, 0, pad_h), 
mode="constant", value=0) 168 | else: 169 | pad_h = 0 170 | pad_w = 0 171 | 172 | conv0 = self.conv0(x) 173 | conv2 = self.conv2(self.conv1(conv0)) 174 | conv4 = self.conv4(self.conv3(conv2)) 175 | 176 | x = self.conv6(self.conv5(conv4)) 177 | x = conv4 + self.conv7(x) 178 | del conv4 179 | x = conv2 + self.conv9(x) 180 | del conv2 181 | x = conv0 + self.conv11(x) 182 | del conv0 183 | #################### 184 | # x1 = self.br1(x) 185 | # with torch.enable_grad(): 186 | # x2 = self.br2(x) 187 | x1 = self.br1(x) 188 | x2 = self.br2(x) 189 | #################### 190 | p = self.prob(x1) 191 | 192 | if pad_h > 0 or pad_w > 0: 193 | x2 = x2[..., :-pad_h, :-pad_w] 194 | p = p[..., :-pad_h, :-pad_w] 195 | 196 | return x2, p 197 | 198 | 199 | class CasMVSNet(nn.Module): 200 | def __init__(self, num_groups=8, norm_act=InPlaceABN, levels=3, use_depth=False, nb_class=1, feat_net=None): 201 | super(CasMVSNet, self).__init__() 202 | self.levels = levels # 3 depth levels 203 | self.n_depths = [8, 32, 48] 204 | self.interval_ratios = [1, 2, 4] 205 | self.use_depth = use_depth 206 | 207 | self.G = num_groups # number of groups in groupwise correlation 208 | # self.feature = FeatureNet() 209 | # change to UNet 210 | if feat_net is None: 211 | raise ValueError("feat_net must be specified") 212 | elif feat_net == 'UNet': 213 | self.feature = UNet(3,nb_class) 214 | elif feat_net == 'smp_UNet': 215 | self.feature = smp_UNet(3,nb_class) 216 | 217 | for l in range(self.levels): 218 | if l == self.levels - 1 and self.use_depth: 219 | cost_reg_l = CostRegNet(self.G + 1, norm_act) 220 | else: 221 | cost_reg_l = CostRegNet(self.G, norm_act) 222 | 223 | setattr(self, f"cost_reg_{l}", cost_reg_l) 224 | 225 | def build_cost_volumes(self, feats, affine_mats, affine_mats_inv, depth_values, idx, spikes): 226 | B, V, C, H, W = feats.shape 227 | D = depth_values.shape[1] 228 | 229 | ref_feats, src_feats = feats[:, idx[0]], feats[:, idx[1:]] 230 | src_feats = src_feats.permute(1, 0, 2, 3, 4) # (V-1, B, C, h, w) 231 | 232 | affine_mats_inv = affine_mats_inv[:, idx[0]] 233 | affine_mats = affine_mats[:, idx[1:]] 234 | affine_mats = affine_mats.permute(1, 0, 2, 3) # (V-1, B, 3, 4) 235 | 236 | ref_volume = ref_feats.unsqueeze(2).repeat(1, 1, D, 1, 1) # (B, C, D, h, w) 237 | 238 | ref_volume = ref_volume.view(B, self.G, C // self.G, *ref_volume.shape[-3:]) 239 | volume_sum = 0 240 | 241 | for i in range(len(idx) - 1): 242 | proj_mat = (affine_mats[i].double() @ affine_mats_inv.double()).float()[ 243 | :, :3 244 | ] # shape (1,3,4) 245 | warped_volume, grid = homo_warp(src_feats[i], proj_mat, depth_values) 246 | 247 | warped_volume = warped_volume.view_as(ref_volume) 248 | volume_sum = volume_sum + warped_volume # (B, G, C//G, D, h, w) 249 | if torch.isnan(volume_sum).sum()>0: 250 | print("nan in volume_sum") 251 | 252 | volume = (volume_sum * ref_volume).mean(dim=2) / (V - 1) 253 | 254 | if spikes is None: 255 | output = volume 256 | else: 257 | output = torch.cat([volume, spikes], dim=1) 258 | 259 | return output 260 | 261 | def create_neural_volume( 262 | self, 263 | feats, 264 | affine_mats, 265 | affine_mats_inv, 266 | idx, 267 | init_depth_min, 268 | depth_interval, 269 | gt_depths, 270 | ): 271 | if feats["level_0"].shape[-1] >= 800: 272 | hres_input = True 273 | else: 274 | hres_input = False 275 | 276 | B, V = affine_mats.shape[:2] 277 | 278 | v_feat = {} 279 | depth_maps = {} 280 | depth_values = {} 281 | for l in reversed(range(self.levels)): # (2, 1, 0) 282 | feats_l = feats[f"level_{l}"] # (B*V, C, h, w) 283 | feats_l 
= feats_l.view(B, V, *feats_l.shape[1:]) # (B, V, C, h, w) 284 | h, w = feats_l.shape[-2:] 285 | depth_interval_l = depth_interval * self.interval_ratios[l] 286 | D = self.n_depths[l] 287 | if l == self.levels - 1: # coarsest level 288 | depth_values_l = init_depth_min + depth_interval_l * torch.arange( 289 | 0, D, device=feats_l.device, dtype=feats_l.dtype 290 | ) # (D) 291 | depth_values_l = depth_values_l[None, :, None, None].expand( 292 | -1, -1, h, w 293 | ) 294 | 295 | if self.use_depth: 296 | gt_mask = gt_depths > 0 297 | sp_idx_float = ( 298 | gt_mask * (gt_depths - init_depth_min) / (depth_interval_l) 299 | )[:, :, None] 300 | spikes = ( 301 | torch.arange(D).view(1, 1, -1, 1, 1).to(gt_mask.device) 302 | == sp_idx_float.floor().long() 303 | ) * (1 - sp_idx_float.frac()) 304 | spikes = spikes + ( 305 | torch.arange(D).view(1, 1, -1, 1, 1).to(gt_mask.device) 306 | == sp_idx_float.ceil().long() 307 | ) * (sp_idx_float.frac()) 308 | spikes = (spikes * gt_mask[:, :, None]).float() 309 | else: 310 | depth_lm1 = depth_l.detach() # the depth of previous level 311 | depth_lm1 = F.interpolate( 312 | depth_lm1, scale_factor=2, mode="bilinear", align_corners=True 313 | ) # (B, 1, h, w) 314 | depth_values_l = get_depth_values(depth_lm1, D, depth_interval_l) 315 | 316 | affine_mats_l = affine_mats[..., l] 317 | affine_mats_inv_l = affine_mats_inv[..., l] 318 | 319 | if l == self.levels - 1 and self.use_depth: 320 | spikes_ = spikes 321 | else: 322 | spikes_ = None 323 | 324 | if hres_input: 325 | v_feat_l = checkpoint( 326 | self.build_cost_volumes, 327 | feats_l, 328 | affine_mats_l, 329 | affine_mats_inv_l, 330 | depth_values_l, 331 | idx, 332 | spikes_, 333 | preserve_rng_state=False, 334 | ) 335 | else: 336 | v_feat_l = self.build_cost_volumes( 337 | feats_l, 338 | affine_mats_l, 339 | affine_mats_inv_l, 340 | depth_values_l, 341 | idx, 342 | spikes_, 343 | ) 344 | 345 | cost_reg_l = getattr(self, f"cost_reg_{l}") 346 | v_feat_l_, depth_prob = cost_reg_l(v_feat_l) # (B, 1, D, h, w) 347 | 348 | depth_l = (F.softmax(depth_prob, dim=2) * depth_values_l[:, None]).sum( 349 | dim=2 350 | ) 351 | # v_feat_l have 8 nan values, go debug build_cost_volumes 352 | if torch.isnan(v_feat_l_).sum()>0: 353 | print("nan in v_feat_l_") 354 | v_feat[f"level_{l}"] = v_feat_l_ 355 | depth_maps[f"level_{l}"] = depth_l 356 | depth_values[f"level_{l}"] = depth_values_l 357 | 358 | return v_feat, depth_maps, depth_values 359 | 360 | def forward( 361 | self, imgs, affine_mats, affine_mats_inv, near_far, closest_idxs, gt_depths=None 362 | ): 363 | B, V, _, H, W = imgs.shape 364 | 365 | ## Feature Pyramid 366 | feats = self.feature( 367 | imgs.reshape(B * V, 3, H, W) 368 | ) # (B*V, 8, H, W), (B*V, 16, H//2, W//2), (B*V, 32, H//4, W//4) 369 | feats_fpn = feats[f"level_0"].reshape(B, V, *feats[f"level_0"].shape[1:]) 370 | semantic_logits = feats[f'logits'].reshape(B, V, *feats[f'logits'].shape[1:]) 371 | semantic_feature = feats[f'feature'].reshape(B, V, *feats[f'feature'].shape[1:]) 372 | 373 | feats_vol = {"level_0": [], "level_1": [], "level_2": []} 374 | depth_map = {"level_0": [], "level_1": [], "level_2": []} 375 | depth_values = {"level_0": [], "level_1": [], "level_2": []} 376 | ## Create cost volumes for each view 377 | for i in range(0, V): 378 | permuted_idx = closest_idxs[0, i].clone().detach().to(feats['level_0'].device) 379 | # if near_far.sum() == 0: 380 | # init_depth_min = 1 381 | # depth_interval = 0.1 382 | # else: 383 | init_depth_min = near_far[0, i, 0] 384 | depth_interval = ( 385 | (near_far[0, 
i, 1] - near_far[0, i, 0]) 386 | / self.n_depths[-1] 387 | / self.interval_ratios[-1] 388 | ) 389 | 390 | v_feat, d_map, d_values = self.create_neural_volume( 391 | feats, 392 | affine_mats, 393 | affine_mats_inv, 394 | idx=permuted_idx, 395 | init_depth_min=init_depth_min, 396 | depth_interval=depth_interval, 397 | gt_depths=gt_depths[:, i : i + 1], 398 | ) 399 | 400 | for l in range(3): 401 | feats_vol[f"level_{l}"].append(v_feat[f"level_{l}"]) 402 | depth_map[f"level_{l}"].append(d_map[f"level_{l}"]) 403 | depth_values[f"level_{l}"].append(d_values[f"level_{l}"]) 404 | 405 | for l in range(3): 406 | feats_vol[f"level_{l}"] = torch.stack(feats_vol[f"level_{l}"], dim=1) 407 | depth_map[f"level_{l}"] = torch.cat(depth_map[f"level_{l}"], dim=1) 408 | depth_values[f"level_{l}"] = torch.stack(depth_values[f"level_{l}"], dim=1) 409 | 410 | return feats_vol, feats_fpn, depth_map, depth_values, semantic_logits, semantic_feature 411 | -------------------------------------------------------------------------------- /model/self_attn_renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import math 6 | 7 | def weights_init(m): 8 | if isinstance(m, nn.Linear): 9 | stdv = 1.0 / math.sqrt(m.weight.size(1)) 10 | m.weight.data.uniform_(-stdv, stdv) 11 | if m.bias is not None: 12 | m.bias.data.uniform_(stdv, stdv) 13 | 14 | 15 | def masked_softmax(x, mask, **kwargs): 16 | x_masked = x.masked_fill(mask == 0, -float("inf")) 17 | 18 | return torch.softmax(x_masked, **kwargs) 19 | 20 | 21 | ## Auto-encoder network 22 | class ConvAutoEncoder(nn.Module): 23 | def __init__(self, num_ch, S): 24 | super(ConvAutoEncoder, self).__init__() 25 | 26 | # Encoder 27 | self.conv1 = nn.Sequential( 28 | nn.Conv1d(num_ch, num_ch * 2, 3, stride=1, padding=1), 29 | # nn.LayerNorm(S, elementwise_affine=False), 30 | nn.ELU(alpha=1.0, inplace=True), 31 | nn.MaxPool1d(2), 32 | ) 33 | self.conv2 = nn.Sequential( 34 | nn.Conv1d(num_ch * 2, num_ch * 4, 3, stride=1, padding=1), 35 | # nn.LayerNorm(S // 2, elementwise_affine=False), 36 | nn.ELU(alpha=1.0, inplace=True), 37 | nn.MaxPool1d(2), 38 | ) 39 | self.conv3 = nn.Sequential( 40 | nn.Conv1d(num_ch * 4, num_ch * 4, 3, stride=1, padding=1), 41 | # nn.LayerNorm(S // 4, elementwise_affine=False), 42 | nn.ELU(alpha=1.0, inplace=True), 43 | nn.MaxPool1d(2), 44 | ) 45 | 46 | # Decoder 47 | self.t_conv1 = nn.Sequential( 48 | nn.ConvTranspose1d(num_ch * 4, num_ch * 4, 4, stride=2, padding=1), 49 | # nn.LayerNorm(S // 4, elementwise_affine=False), 50 | nn.ELU(alpha=1.0, inplace=True), 51 | ) 52 | self.t_conv2 = nn.Sequential( 53 | nn.ConvTranspose1d(num_ch * 8, num_ch * 2, 4, stride=2, padding=1), 54 | # nn.LayerNorm(S // 2, elementwise_affine=False), 55 | nn.ELU(alpha=1.0, inplace=True), 56 | ) 57 | self.t_conv3 = nn.Sequential( 58 | nn.ConvTranspose1d(num_ch * 4, num_ch, 4, stride=2, padding=1), 59 | # nn.LayerNorm(S, elementwise_affine=False), 60 | nn.ELU(alpha=1.0, inplace=True), 61 | ) 62 | # Output 63 | self.conv_out = nn.Sequential( 64 | nn.Conv1d(num_ch * 2, num_ch, 3, stride=1, padding=1), 65 | # nn.LayerNorm(S, elementwise_affine=False), 66 | nn.ELU(alpha=1.0, inplace=True), 67 | ) 68 | 69 | def forward(self, x): 70 | input = x 71 | x = self.conv1(x) 72 | conv1_out = x 73 | x = self.conv2(x) 74 | conv2_out = x 75 | x = self.conv3(x) 76 | 77 | x = self.t_conv1(x) 78 | x = self.t_conv2(torch.cat([x, conv2_out], dim=1)) 79 | x = self.t_conv3(torch.cat([x, 
conv1_out], dim=1)) 80 | 81 | x = self.conv_out(torch.cat([x, input], dim=1)) 82 | 83 | return x 84 | 85 | 86 | class ScaledDotProductAttention(nn.Module): 87 | def __init__(self, temperature, attn_dropout=0.1): 88 | super().__init__() 89 | self.temperature = temperature 90 | 91 | def forward(self, q, k, v, mask=None): 92 | 93 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 94 | 95 | if mask is not None: 96 | attn = masked_softmax(attn, mask, dim=-1) 97 | else: 98 | attn = F.softmax(attn, dim=-1) 99 | 100 | output = torch.matmul(attn, v) 101 | 102 | return output, attn 103 | 104 | 105 | class PositionwiseFeedForward(nn.Module): 106 | def __init__(self, d_in, d_hid, dropout=0.1): 107 | super().__init__() 108 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise 109 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise 110 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 111 | 112 | def forward(self, x): 113 | 114 | residual = x 115 | 116 | x = self.w_2(F.relu(self.w_1(x))) 117 | x += residual 118 | 119 | x = self.layer_norm(x) 120 | 121 | return x 122 | 123 | 124 | class MultiHeadAttention(nn.Module): 125 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 126 | super().__init__() 127 | 128 | self.n_head = n_head 129 | self.d_k = d_k 130 | self.d_v = d_v 131 | 132 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 133 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 134 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) 135 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False) 136 | 137 | self.attention = ScaledDotProductAttention(temperature=d_k**0.5) 138 | 139 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 140 | 141 | def forward(self, q, k, v, mask=None): 142 | 143 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 144 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 145 | 146 | residual = q 147 | 148 | # Pass through the pre-attention projection: b x lq x (n*dv) 149 | # Separate different heads: b x lq x n x dv 150 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 151 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 152 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 153 | 154 | # Transpose for attention dot product: b x n x lq x dv 155 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 156 | 157 | if mask is not None: 158 | mask = mask.transpose(1, 2).unsqueeze(1) # For head axis broadcasting. 
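        # Editor's note (descriptive comment, not part of the original file): the call below
        # evaluates masked scaled dot-product attention,
        #     attn = softmax(q @ k.T / sqrt(d_k)) @ v,
        # with temperature d_k ** 0.5 set in __init__; entries where mask == 0 are filled
        # with -inf before the softmax (see masked_softmax above), so fully masked-out
        # tokens, e.g. occluded source views, receive zero attention weight.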
159 | 160 | q, attn = self.attention(q, k, v, mask=mask) 161 | 162 | # Transpose to move the head dimension back: b x lq x n x dv 163 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) 164 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 165 | q = self.fc(q) 166 | q += residual 167 | 168 | q = self.layer_norm(q) 169 | 170 | return q, attn 171 | 172 | 173 | class EncoderLayer(nn.Module): 174 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0): 175 | super(EncoderLayer, self).__init__() 176 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 177 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 178 | 179 | def forward(self, enc_input, slf_attn_mask=None): 180 | enc_output, enc_slf_attn = self.slf_attn( 181 | enc_input, enc_input, enc_input, mask=slf_attn_mask 182 | ) 183 | enc_output = self.pos_ffn(enc_output) 184 | return enc_output, enc_slf_attn 185 | 186 | 187 | class Renderer(nn.Module): 188 | def __init__(self, nb_samples_per_ray, nb_view=8, nb_class=21, 189 | using_semantic_global_tokens=True, only_using_semantic_global_tokens=True, use_batch_semantic_feature=False): 190 | super(Renderer, self).__init__() 191 | 192 | self.nb_view = nb_view 193 | self.nb_class = nb_class 194 | self.using_semantic_global_tokens = using_semantic_global_tokens 195 | self.only_using_semantic_global_tokens = only_using_semantic_global_tokens 196 | self.dim = 32 197 | 198 | if use_batch_semantic_feature: 199 | self.nb_class = self.nb_class * 9 200 | else: 201 | self.nb_class = self.nb_class 202 | 203 | self.semantic_token_gen = nn.Linear(1 + self.nb_class, self.dim) 204 | 205 | self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim) 206 | 207 | ## Self-Attention Settings 208 | d_inner = self.dim 209 | n_head = 4 210 | d_k = self.dim // n_head 211 | d_v = self.dim // n_head 212 | num_layers = 4 213 | self.attn_layers = nn.ModuleList( 214 | [ 215 | EncoderLayer(self.dim, d_inner, n_head, d_k, d_v) 216 | for i in range(num_layers) 217 | ] 218 | ) 219 | 220 | self.semantic_attn_layers = nn.ModuleList( 221 | [ 222 | EncoderLayer(self.dim, d_inner, n_head, d_k, d_v) 223 | for i in range(num_layers) 224 | ] 225 | ) 226 | 227 | # +1 because we add the mean and variance of input features as global features 228 | if using_semantic_global_tokens and only_using_semantic_global_tokens: 229 | self.semantic_dim = self.dim 230 | ## Processing the mean and variance of semantic features 231 | self.semantic_var_mean_fc1 = nn.Linear(self.nb_class*2, self.dim) 232 | self.semantic_var_mean_fc2 = nn.Linear(self.dim, self.dim) 233 | elif using_semantic_global_tokens: 234 | self.semantic_dim = self.dim * (nb_view + 1) 235 | ## Processing the mean and variance of semantic features 236 | self.semantic_var_mean_fc1 = nn.Linear(self.nb_class*2, self.dim) 237 | self.semantic_var_mean_fc2 = nn.Linear(self.dim, self.dim) 238 | else: 239 | self.semantic_dim = self.dim * nb_view 240 | 241 | self.semantic_fc1 = nn.Linear(self.semantic_dim, self.semantic_dim) 242 | self.semantic_fc2 = nn.Linear(self.semantic_dim, self.semantic_dim // 2) 243 | self.semantic_fc3 = nn.Linear(self.semantic_dim // 2, nb_class) 244 | 245 | ## Processing the mean and variance of input features 246 | self.var_mean_fc1 = nn.Linear(16, self.dim) 247 | self.var_mean_fc2 = nn.Linear(self.dim, self.dim) 248 | 249 | 250 | ## Setting mask of var_mean always enabled 251 | self.var_mean_mask = torch.tensor([1]) 252 | self.var_mean_mask.requires_grad = False 253 
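        # Editor's note (descriptive comment, not part of the original file): var_mean_mask is
        # a constant "always visible" entry; forward() concatenates it to vis_mask so the
        # appended variance/mean global token is never masked out during self-attention, and
        # moves it to the right device there via .to(vis_mask.device), since it is a plain
        # tensor attribute rather than a registered buffer.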
| 254 | ## For aggregating data along ray samples 255 | self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray) 256 | 257 | self.sigma_fc1 = nn.Linear(self.dim, self.dim) 258 | self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2) 259 | self.sigma_fc3 = nn.Linear(self.dim // 2, 1) 260 | 261 | 262 | self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim) 263 | self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2) 264 | self.rgb_fc3 = nn.Linear(self.dim // 2, 1) 265 | 266 | ## Initialization 267 | self.sigma_fc1.apply(weights_init) 268 | self.sigma_fc2.apply(weights_init) 269 | self.sigma_fc3.apply(weights_init) 270 | self.rgb_fc1.apply(weights_init) 271 | self.rgb_fc2.apply(weights_init) 272 | self.rgb_fc3.apply(weights_init) 273 | 274 | def forward(self, viewdirs, feat, occ_masks, middle_pts_mask): 275 | ## Viewing samples regardless of batch or ray 276 | N, S, V = feat.shape[:3] 277 | feat = feat.view(-1, *feat.shape[2:]) 278 | v_feat = feat[..., :24] 279 | s_feat = feat[..., 24 : 24 + 8] 280 | colors = feat[..., 24 + 8 : 24 + 8 + 3] 281 | semantic_feat = feat[..., 24 + 8 + 3 : -1] 282 | vis_mask = feat[..., -1:].detach() 283 | 284 | occ_masks = occ_masks.view(-1, *occ_masks.shape[2:]) 285 | viewdirs = viewdirs.view(-1, *viewdirs.shape[2:]) 286 | 287 | ## Mean and variance of 2D features provide view-independent tokens 288 | var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True) 289 | var_mean = torch.cat(var_mean, dim=-1) 290 | var_mean = F.elu(self.var_mean_fc1(var_mean)) 291 | var_mean = F.elu(self.var_mean_fc2(var_mean)) 292 | 293 | ## Converting the input features to tokens (view-dependent) before self-attention 294 | tokens = F.elu( 295 | self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1)) 296 | ) 297 | tokens = torch.cat([tokens, var_mean], dim=1) 298 | 299 | # by adding middle_pts_mask, we can only take the predicted depth's points into account 300 | 301 | if self.using_semantic_global_tokens: 302 | semantic_var_mean = torch.var_mean(semantic_feat[middle_pts_mask.view(-1)], dim=1, unbiased=False, keepdim=True) 303 | semantic_var_mean = torch.cat(semantic_var_mean, dim=-1) 304 | semantic_var_mean = F.elu(self.semantic_var_mean_fc1(semantic_var_mean)) 305 | semantic_var_mean = F.elu(self.semantic_var_mean_fc2(semantic_var_mean)) 306 | # (N_rays, 1, views, feat_dim) 307 | semantic_tokens = F.elu( 308 | self.semantic_token_gen(torch.cat([semantic_feat[middle_pts_mask.view(-1)], vis_mask[middle_pts_mask.view(-1)]], dim=-1)) 309 | ) 310 | 311 | if self.using_semantic_global_tokens: 312 | semantic_tokens = torch.cat([semantic_tokens, semantic_var_mean], dim=1) 313 | 314 | ## Adding a new channel to mask for var_mean 315 | vis_mask = torch.cat( 316 | [vis_mask, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, -1).to(vis_mask.device)], dim=1 317 | ) 318 | ## If a point is not visible by any source view, force its masks to enabled 319 | vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1) 320 | 321 | ## Taking occ_masks into account, but remembering if there were any visibility before that 322 | mask_cloned = vis_mask.clone() 323 | vis_mask[:, :-1] *= occ_masks 324 | vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1) 325 | masks = vis_mask * mask_cloned 326 | 327 | ## Performing self-attention 328 | for layer in self.attn_layers: 329 | tokens, _ = layer(tokens, masks) 330 | 331 | for semantic_layer in self.semantic_attn_layers: 332 | if self.using_semantic_global_tokens: 333 | # mask has shape (N_rays*N_points, 
nb_views+1, 1), because of the var_mean_mask, semantic not using that 334 | semantic_tokens, _ = semantic_layer(semantic_tokens, masks[middle_pts_mask.view(-1)]) 335 | else: 336 | semantic_tokens, _ = semantic_layer(semantic_tokens, masks[middle_pts_mask.view(-1)][:, :-1]) 337 | 338 | ## Predicting sigma with an Auto-Encoder and MLP 339 | sigma_tokens = tokens[:, -1:] 340 | sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2) 341 | sigma_tokens = self.auto_enc(sigma_tokens) 342 | sigma_tokens = sigma_tokens.transpose(1, 2).reshape(N * S, 1, self.dim) 343 | 344 | sigma_tokens_ = F.elu(self.sigma_fc1(sigma_tokens)) 345 | sigma_tokens_ = F.elu(self.sigma_fc2(sigma_tokens_)) 346 | # sigma shape (N_rays*N_points, 1) 347 | sigma = torch.relu(self.sigma_fc3(sigma_tokens_[:, 0])) 348 | 349 | if self.using_semantic_global_tokens and self.only_using_semantic_global_tokens: 350 | semantic_global_tokens = semantic_tokens[:, -1:] 351 | elif self.using_semantic_global_tokens: 352 | semantic_global_tokens = semantic_tokens.reshape(-1, self.semantic_dim) 353 | else: 354 | semantic_global_tokens = semantic_tokens.reshape(-1, self.semantic_dim) 355 | semantic_tokens_ = F.elu(self.semantic_fc1(semantic_global_tokens)) 356 | semantic_tokens_ = F.elu(self.semantic_fc2(semantic_tokens_)) 357 | semantic_tokens_ = torch.relu(self.semantic_fc3(semantic_tokens_)) 358 | 359 | semantic = semantic_tokens_.reshape(N, -1).unsqueeze(1) 360 | 361 | ## Concatenating positional encodings and predicting RGB weights 362 | rgb_tokens = torch.cat([tokens[:, :-1], viewdirs], dim=-1) 363 | rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens)) 364 | rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens)) 365 | rgb_w = self.rgb_fc3(rgb_tokens) 366 | rgb_w = masked_softmax(rgb_w, masks[:, :-1], dim=1) 367 | 368 | rgb = (colors * rgb_w).sum(1) 369 | 370 | outputs = torch.cat([rgb, sigma], -1) 371 | outputs = outputs.reshape(N, S, -1) 372 | 373 | return outputs, semantic 374 | 375 | 376 | class Semantic_predictor(nn.Module): 377 | def __init__(self, nb_view=6, nb_class=0): 378 | super(Semantic_predictor, self).__init__() 379 | self.nb_class = nb_class 380 | self.dim = 32 381 | # self.attn_token_gen = nn.Linear(24 + 1 + self.nb_class, self.dim) 382 | self.attn_token_gen = nn.Linear(1 + self.nb_class, self.dim) 383 | self.semantic_dim = self.dim * nb_view 384 | 385 | # Self-Attention Settings, This attention is cross-view attention for a point, which represent a pixel in target view 386 | d_inner = self.dim 387 | n_head = 4 388 | d_k = self.dim // n_head 389 | d_v = self.dim // n_head 390 | num_layers = 4 391 | self.attn_layers = nn.ModuleList( 392 | [ 393 | EncoderLayer(self.dim, d_inner, n_head, d_k, d_v) 394 | for i in range(num_layers) 395 | ] 396 | ) 397 | self.semantic_fc1 = nn.Linear(self.semantic_dim, self.semantic_dim) 398 | self.semantic_fc2 = nn.Linear(self.semantic_dim, self.semantic_dim // 2) 399 | self.semantic_fc3 = nn.Linear(self.semantic_dim // 2, nb_class) 400 | 401 | def forward(self, feat, occ_masks): 402 | if feat.dim() == 3: 403 | feat = feat.unsqueeze(1) 404 | if occ_masks.dim() == 3: 405 | occ_masks = occ_masks.unsqueeze(1) 406 | N, S, V, C = feat.shape # (num_rays, num_samples, num_views, feat_dim), S should be 1 here 407 | 408 | feat = feat.view(-1, *feat.shape[2:]) # (num_rays * num_samples, num_views, feat_dim) 409 | v_feat = feat[..., :24] 410 | s_feat = feat[..., 24 : 24 + self.nb_class] 411 | colors = feat[..., 24 + self.nb_class : -1] 412 | vis_mask = feat[..., -1:].detach() 413 | 414 | occ_masks = 
occ_masks.view(-1, *occ_masks.shape[2:]) 415 | 416 | tokens = F.elu( 417 | # self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1)) 418 | self.attn_token_gen(torch.cat([vis_mask, s_feat], dim=-1)) 419 | ) 420 | 421 | ## If a point is not visible by any source view, force its masks to enabled 422 | vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 0, 1) 423 | 424 | ## Taking occ_masks into account, but remembering if there were any visibility before that 425 | mask_cloned = vis_mask.clone() 426 | vis_mask *= occ_masks 427 | vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 0, 1) 428 | masks = vis_mask * mask_cloned 429 | 430 | ## Performing self-attention on source view features, 431 | for layer in self.attn_layers: 432 | tokens, _ = layer(tokens, masks) 433 | # tokens, _ = layer(tokens, vis_mask) 434 | 435 | ## Predicting semantic with MLP 436 | ## tokens shape: (N*S, V, dim), S = 1 437 | tokens = tokens.reshape(N, V*self.dim) 438 | semantic_tokens_ = F.elu(self.semantic_fc1(tokens)) 439 | semantic_tokens_ = F.elu(self.semantic_fc2(semantic_tokens_)) 440 | semantic_tokens_ = torch.relu(self.semantic_fc3(semantic_tokens_)) 441 | 442 | semantic = semantic_tokens_.reshape(N, S, -1) 443 | 444 | return semantic -------------------------------------------------------------------------------- /utils/depth_loss.py: -------------------------------------------------------------------------------- 1 | # reference: RC-MVSNet 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def inverse_warping(img, left_cam, right_cam, depth): 9 | # img: [batch_size, height, width, channels] 10 | 11 | # cameras (K, R, t) 12 | # print('left_cam: {}'.format(left_cam.shape)) 13 | R_left = left_cam[:, 0:1, 0:3, 0:3] # [B, 1, 3, 3] 14 | R_right = right_cam[:, 0:1, 0:3, 0:3] # [B, 1, 3, 3] 15 | t_left = left_cam[:, 0:1, 0:3, 3:4] # [B, 1, 3, 1] 16 | t_right = right_cam[:, 0:1, 0:3, 3:4] # [B, 1, 3, 1] 17 | K_left = left_cam[:, 1:2, 0:3, 0:3] # [B, 1, 3, 3] 18 | K_right = right_cam[:, 1:2, 0:3, 0:3] # [B, 1, 3, 3] 19 | 20 | K_left = K_left.squeeze(1) # [B, 3, 3] 21 | K_left_inv = torch.inverse(K_left) # [B, 3, 3] 22 | R_left_trans = R_left.squeeze(1).permute(0, 2, 1) # [B, 3, 3] 23 | R_right_trans = R_right.squeeze(1).permute(0, 2, 1) # [B, 3, 3] 24 | 25 | R_left = R_left.squeeze(1) 26 | t_left = t_left.squeeze(1) 27 | R_right = R_right.squeeze(1) 28 | t_right = t_right.squeeze(1) 29 | 30 | ## estimate egomotion by inverse composing R1,R2 and t1,t2 31 | R_rel = torch.matmul(R_right, R_left_trans) # [B, 3, 3] 32 | t_rel = t_right - torch.matmul(R_rel, t_left) # [B, 3, 1] 33 | ## now convert R and t to transform mat, as in SFMlearner 34 | batch_size = R_left.shape[0] 35 | # filler = torch.Tensor([0.0, 0.0, 0.0, 1.0]).to(device).reshape(1, 1, 4) # [1, 1, 4] 36 | filler = torch.Tensor([0.0, 0.0, 0.0, 1.0]).cuda().reshape(1, 1, 4) # [1, 1, 4] 37 | filler = filler.repeat(batch_size, 1, 1) # [B, 1, 4] 38 | transform_mat = torch.cat([R_rel, t_rel], dim=2) # [B, 3, 4] 39 | transform_mat = torch.cat([transform_mat.float(), filler.float()], dim=1) # [B, 4, 4] 40 | # print(img.shape) 41 | batch_size, img_height, img_width, _ = img.shape 42 | # print(depth.shape) 43 | # print('depth: {}'.format(depth.shape)) 44 | depth = depth.reshape(batch_size, 1, img_height * img_width) # [batch_size, 1, height * width] 45 | 46 | grid = _meshgrid_abs(img_height, img_width) # [3, height * width] 47 | grid = grid.unsqueeze(0).repeat(batch_size, 1, 1) # 
[batch_size, 3, height * width] 48 | cam_coords = _pixel2cam(depth, grid, K_left_inv) # [batch_size, 3, height * width] 49 | # ones = torch.ones([batch_size, 1, img_height * img_width], device=device) # [batch_size, 1, height * width] 50 | ones = torch.ones([batch_size, 1, img_height * img_width]).cuda() # [batch_size, 1, height * width] 51 | cam_coords_hom = torch.cat([cam_coords, ones], dim=1) # [batch_size, 4, height * width] 52 | 53 | # Get projection matrix for target camera frame to source pixel frame 54 | # hom_filler = torch.Tensor([0.0, 0.0, 0.0, 1.0]).to(device).reshape(1, 1, 4) # [1, 1, 4] 55 | hom_filler = torch.Tensor([0.0, 0.0, 0.0, 1.0]).cuda().reshape(1, 1, 4) # [1, 1, 4] 56 | hom_filler = hom_filler.repeat(batch_size, 1, 1) # [B, 1, 4] 57 | intrinsic_mat_hom = torch.cat([K_left.float(), torch.zeros([batch_size, 3, 1]).cuda()], dim=2) # [B, 3, 4] 58 | intrinsic_mat_hom = torch.cat([intrinsic_mat_hom, hom_filler], dim=1) # [B, 4, 4] 59 | proj_target_cam_to_source_pixel = torch.matmul(intrinsic_mat_hom, transform_mat) # [B, 4, 4] 60 | source_pixel_coords = _cam2pixel(cam_coords_hom, proj_target_cam_to_source_pixel) # [batch_size, 2, height * width] 61 | source_pixel_coords = source_pixel_coords.reshape(batch_size, 2, img_height, img_width) # [batch_size, 2, height, width] 62 | source_pixel_coords = source_pixel_coords.permute(0, 2, 3, 1) # [batch_size, height, width, 2] 63 | warped_right, mask = _spatial_transformer(img, source_pixel_coords) 64 | return warped_right, mask 65 | 66 | 67 | def _meshgrid_abs(height, width): 68 | """Meshgrid in the absolute coordinates.""" 69 | x_t = torch.matmul( 70 | torch.ones([height, 1]), 71 | torch.linspace(-1.0, 1.0, width).unsqueeze(1).permute(1, 0) 72 | ) # [height, width] 73 | y_t = torch.matmul( 74 | torch.linspace(-1.0, 1.0, height).unsqueeze(1), 75 | torch.ones([1, width]) 76 | ) 77 | x_t = (x_t + 1.0) * 0.5 * (width - 1) 78 | y_t = (y_t + 1.0) * 0.5 * (height - 1) 79 | x_t_flat = x_t.reshape(1, -1) 80 | y_t_flat = y_t.reshape(1, -1) 81 | ones = torch.ones_like(x_t_flat) 82 | grid = torch.cat([x_t_flat, y_t_flat, ones], dim=0) # [3, height * width] 83 | # return grid.to(device) 84 | return grid.cuda() 85 | 86 | 87 | def _pixel2cam(depth, pixel_coords, intrinsic_mat_inv): 88 | """Transform coordinates in the pixel frame to the camera frame.""" 89 | cam_coords = torch.matmul(intrinsic_mat_inv.float(), pixel_coords.float()) * depth.float() 90 | return cam_coords 91 | 92 | 93 | def _cam2pixel(cam_coords, proj_c2p): 94 | """Transform coordinates in the camera frame to the pixel frame.""" 95 | pcoords = torch.matmul(proj_c2p, cam_coords) # [batch_size, 4, height * width] 96 | x = pcoords[:, 0:1, :] # [batch_size, 1, height * width] 97 | y = pcoords[:, 1:2, :] # [batch_size, 1, height * width] 98 | z = pcoords[:, 2:3, :] # [batch_size, 1, height * width] 99 | x_norm = x / (z + 1e-10) 100 | y_norm = y / (z + 1e-10) 101 | pixel_coords = torch.cat([x_norm, y_norm], dim=1) 102 | return pixel_coords # [batch_size, 2, height * width] 103 | 104 | 105 | def _spatial_transformer(img, coords): 106 | """A wrapper over binlinear_sampler(), taking absolute coords as input.""" 107 | # img: [B, H, W, C] 108 | img_height = img.shape[1] 109 | img_width = img.shape[2] 110 | px = coords[:, :, :, :1] # [batch_size, height, width, 1] 111 | py = coords[:, :, :, 1:] # [batch_size, height, width, 1] 112 | # Normalize coordinates to [-1, 1] to send to _bilinear_sampler. 
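    # Editor's note (descriptive comment, not part of the original file): the two lines below
    # map absolute pixel coordinates in [0, W-1] x [0, H-1] to the normalized range [-1, 1];
    # _bilinear_sample then rescales them back to absolute indices, so the round trip exists
    # only to match the normalized-coordinate convention documented in its docstring.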
113 | px = px / (img_width - 1) * 2.0 - 1.0 # [batch_size, height, width, 1] 114 | py = py / (img_height - 1) * 2.0 - 1.0 # [batch_size, height, width, 1] 115 | output_img, mask = _bilinear_sample(img, px, py) 116 | return output_img, mask 117 | 118 | 119 | def _bilinear_sample(im, x, y, name='bilinear_sampler'): 120 | """Perform bilinear sampling on im given list of x, y coordinates. 121 | Implements the differentiable sampling mechanism with bilinear kernel 122 | in https://arxiv.org/abs/1506.02025. 123 | x,y are tensors specifying normalized coordinates [-1, 1] to be sampled on im. 124 | For example, (-1, -1) in (x, y) corresponds to pixel location (0, 0) in im, 125 | and (1, 1) in (x, y) corresponds to the bottom right pixel in im. 126 | Args: 127 | im: Batch of images with shape [B, h, w, channels]. 128 | x: Tensor of normalized x coordinates in [-1, 1], with shape [B, h, w, 1]. 129 | y: Tensor of normalized y coordinates in [-1, 1], with shape [B, h, w, 1]. 130 | name: Name scope for ops. 131 | Returns: 132 | Sampled image with shape [B, h, w, channels]. 133 | Principled mask with shape [B, h, w, 1], dtype:float32. A value of 1.0 134 | in the mask indicates that the corresponding coordinate in the sampled 135 | image is valid. 136 | """ 137 | x = x.reshape(-1) # [batch_size * height * width] 138 | y = y.reshape(-1) # [batch_size * height * width] 139 | 140 | # Constants. 141 | batch_size, height, width, channels = im.shape 142 | 143 | x, y = x.float(), y.float() 144 | max_y = int(height - 1) 145 | max_x = int(width - 1) 146 | 147 | # Scale indices from [-1, 1] to [0, width - 1] or [0, height - 1]. 148 | x = (x + 1.0) * (width - 1.0) / 2.0 149 | y = (y + 1.0) * (height - 1.0) / 2.0 150 | 151 | # Compute the coordinates of the 4 pixels to sample from. 152 | x0 = torch.floor(x).int() 153 | x1 = x0 + 1 154 | y0 = torch.floor(y).int() 155 | y1 = y0 + 1 156 | 157 | mask = (x0 >= 0) & (x1 <= max_x) & (y0 >= 0) & (y0 <= max_y) 158 | mask = mask.float() 159 | 160 | x0 = torch.clamp(x0, 0, max_x) 161 | x1 = torch.clamp(x1, 0, max_x) 162 | y0 = torch.clamp(y0, 0, max_y) 163 | y1 = torch.clamp(y1, 0, max_y) 164 | dim2 = width 165 | dim1 = width * height 166 | 167 | # Create base index. 168 | base = torch.arange(batch_size) * dim1 169 | base = base.reshape(-1, 1) 170 | base = base.repeat(1, height * width) 171 | base = base.reshape(-1) # [batch_size * height * width] 172 | # base = base.long().to(device) 173 | base = base.long().cuda() 174 | 175 | base_y0 = base + y0.long() * dim2 176 | base_y1 = base + y1.long() * dim2 177 | idx_a = base_y0 + x0.long() 178 | idx_b = base_y1 + x0.long() 179 | idx_c = base_y0 + x1.long() 180 | idx_d = base_y1 + x1.long() 181 | 182 | # Use indices to lookup pixels in the flat image and restore channels dim. 
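    # Editor's note (descriptive comment, not part of the original file): the four corner
    # pixels gathered below are blended with standard bilinear weights
    #     wa = (x1 - x) * (y1 - y), wb = (x1 - x) * (y - y0),
    #     wc = (x - x0) * (y1 - y), wd = (x - x0) * (y - y0),
    # which sum to 1 whenever x1 - x0 = y1 - y0 = 1, i.e. away from the clamped borders.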
183 | im_flat = im.reshape(-1, channels).float() # [batch_size * height * width, channels] 184 | # pixel_a = tf.gather(im_flat, idx_a) 185 | # pixel_b = tf.gather(im_flat, idx_b) 186 | # pixel_c = tf.gather(im_flat, idx_c) 187 | # pixel_d = tf.gather(im_flat, idx_d) 188 | pixel_a = im_flat[idx_a] 189 | pixel_b = im_flat[idx_b] 190 | pixel_c = im_flat[idx_c] 191 | pixel_d = im_flat[idx_d] 192 | 193 | wa = (x1.float() - x) * (y1.float() - y) 194 | wb = (x1.float() - x) * (1.0 - (y1.float() - y)) 195 | wc = (1.0 - (x1.float() - x)) * (y1.float() - y) 196 | wd = (1.0 - (x1.float() - x)) * (1.0 - (y1.float() - y)) 197 | wa, wb, wc, wd = wa.unsqueeze(1), wb.unsqueeze(1), wc.unsqueeze(1), wd.unsqueeze(1) 198 | 199 | output = wa * pixel_a + wb * pixel_b + wc * pixel_c + wd * pixel_d 200 | output = output.reshape(batch_size, height, width, channels) 201 | mask = mask.reshape(batch_size, height, width, 1) 202 | return output, mask 203 | 204 | 205 | class SSIM(nn.Module): 206 | """Layer to compute the SSIM loss between a pair of images 207 | """ 208 | def __init__(self): 209 | super(SSIM, self).__init__() 210 | self.mu_x_pool = nn.AvgPool2d(3, 1) 211 | self.mu_y_pool = nn.AvgPool2d(3, 1) 212 | self.sig_x_pool = nn.AvgPool2d(3, 1) 213 | self.sig_y_pool = nn.AvgPool2d(3, 1) 214 | self.sig_xy_pool = nn.AvgPool2d(3, 1) 215 | self.mask_pool = nn.AvgPool2d(3, 1) 216 | # self.refl = nn.ReflectionPad2d(1) 217 | 218 | self.C1 = 0.01 ** 2 219 | self.C2 = 0.03 ** 2 220 | 221 | def forward(self, x, y, mask): 222 | # print('mask: {}'.format(mask.shape)) 223 | # print('x: {}'.format(x.shape)) 224 | # print('y: {}'.format(y.shape)) 225 | x = x.permute(0, 3, 1, 2) # [B, H, W, C] --> [B, C, H, W] 226 | y = y.permute(0, 3, 1, 2) 227 | mask = mask.permute(0, 3, 1, 2) 228 | 229 | # x = self.refl(x) 230 | # y = self.refl(y) 231 | mu_x = self.mu_x_pool(x) 232 | mu_y = self.mu_y_pool(y) 233 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2 234 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2 235 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y 236 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2) 237 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2) 238 | SSIM_mask = self.mask_pool(mask) 239 | output = SSIM_mask * torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1) 240 | return output.permute(0, 2, 3, 1) # [B, C, H, W] --> [B, H, W, C] 241 | 242 | 243 | def gradient_x(img): 244 | return img[:, :, :-1, :] - img[:, :, 1:, :] 245 | 246 | def gradient_y(img): 247 | return img[:, :-1, :, :] - img[:, 1:, :, :] 248 | 249 | def gradient(pred): 250 | D_dy = pred[:, 1:, :, :] - pred[:, :-1, :, :] 251 | D_dx = pred[:, :, 1:, :] - pred[:, :, :-1, :] 252 | return D_dx, D_dy 253 | 254 | 255 | def depth_smoothness(depth, img,lambda_wt=1): 256 | """Computes image-aware depth smoothness loss.""" 257 | # print('depth: {} img: {}'.format(depth.shape, img.shape)) 258 | depth_dx = gradient_x(depth) 259 | depth_dy = gradient_y(depth) 260 | image_dx = gradient_x(img) 261 | image_dy = gradient_y(img) 262 | weights_x = torch.exp(-(lambda_wt * torch.mean(torch.abs(image_dx), 3, keepdim=True))) 263 | weights_y = torch.exp(-(lambda_wt * torch.mean(torch.abs(image_dy), 3, keepdim=True))) 264 | # print('depth_dx: {} weights_x: {}'.format(depth_dx.shape, weights_x.shape)) 265 | # print('depth_dy: {} weights_y: {}'.format(depth_dy.shape, weights_y.shape)) 266 | smoothness_x = depth_dx * weights_x 267 | smoothness_y = depth_dy * weights_y 268 | return torch.mean(torch.abs(smoothness_x)) + 
torch.mean(torch.abs(smoothness_y)) 269 | 270 | 271 | def compute_reconstr_loss(warped, ref, mask, simple=True): 272 | if simple: 273 | return F.smooth_l1_loss(warped*mask, ref*mask, reduction='mean') 274 | else: 275 | alpha = 0.5 276 | ref_dx, ref_dy = gradient(ref * mask) 277 | warped_dx, warped_dy = gradient(warped * mask) 278 | photo_loss = F.smooth_l1_loss(warped*mask, ref*mask, reduction='mean') 279 | grad_loss = F.smooth_l1_loss(warped_dx, ref_dx, reduction='mean') + \ 280 | F.smooth_l1_loss(warped_dy, ref_dy, reduction='mean') 281 | return (1 - alpha) * photo_loss + alpha * grad_loss 282 | 283 | class UnSupLoss(nn.Module): 284 | def __init__(self): 285 | super(UnSupLoss, self).__init__() 286 | self.ssim = SSIM() 287 | 288 | def forward(self, imgs, cams, depth, stage_idx): 289 | # print('imgs: {}'.format(imgs.shape)) 290 | # print('cams: {}'.format(cams.shape)) 291 | # print('depth: {}'.format(depth.shape)) 292 | 293 | imgs = torch.unbind(imgs, 1) 294 | cams = torch.unbind(cams, 1) 295 | assert len(imgs) == len(cams), "Different number of images and projection matrices" 296 | img_height, img_width = imgs[0].shape[2], imgs[0].shape[3] 297 | num_views = len(imgs) 298 | 299 | ref_img = imgs[0] 300 | 301 | if stage_idx == 2: 302 | ref_img = F.interpolate(ref_img, scale_factor=0.25,recompute_scale_factor=True) 303 | elif stage_idx == 1: 304 | ref_img = F.interpolate(ref_img, scale_factor=0.5,recompute_scale_factor=True) 305 | else: 306 | pass 307 | ref_img = ref_img.permute(0, 2, 3, 1) # [B, C, H, W] --> [B, H, W, C] 308 | ref_cam = cams[0] 309 | # print('ref_cam: {}'.format(ref_cam.shape)) 310 | 311 | # depth reshape 312 | # depth = depth.unsqueeze(dim=1) # [B, 1, H, W] 313 | # depth = F.interpolate(depth, size=[img_height, img_width]) 314 | # depth = depth.squeeze(dim=1) # [B, H, W] 315 | 316 | self.reconstr_loss = 0 317 | self.ssim_loss = 0 318 | self.smooth_loss = 0 319 | 320 | warped_img_list = [] 321 | mask_list = [] 322 | reprojection_losses = [] 323 | for view in range(1, num_views): 324 | view_img = imgs[view] 325 | view_cam = cams[view] 326 | # print('view_cam: {}'.format(view_cam.shape)) 327 | # view_img = F.interpolate(view_img, scale_factor=0.25, mode='bilinear') 328 | if stage_idx == 2: 329 | view_img = F.interpolate(view_img, scale_factor=0.25,recompute_scale_factor=True) 330 | elif stage_idx == 1: 331 | view_img = F.interpolate(view_img, scale_factor=0.5,recompute_scale_factor=True) 332 | else: 333 | pass 334 | view_img = view_img.permute(0, 2, 3, 1) # [B, C, H, W] --> [B, H, W, C] 335 | # warp view_img to the ref_img using the dmap of the ref_img 336 | warped_img, mask = inverse_warping(view_img, ref_cam, view_cam, depth) 337 | warped_img_list.append(warped_img) 338 | mask_list.append(mask) 339 | 340 | reconstr_loss = compute_reconstr_loss(warped_img, ref_img, mask, simple=False) 341 | valid_mask = 1 - mask # replace all 0 values with INF 342 | reprojection_losses.append(reconstr_loss + 1e4 * valid_mask) 343 | 344 | # SSIM loss## 345 | if view < 3: 346 | self.ssim_loss += torch.mean(self.ssim(ref_img, warped_img, mask)) 347 | 348 | ##smooth loss## 349 | self.smooth_loss += depth_smoothness(depth.unsqueeze(dim=-1), ref_img, 1.0) 350 | 351 | # top-k operates along the last dimension, so swap the axes accordingly 352 | reprojection_volume = torch.stack(reprojection_losses).permute(1, 2, 3, 4, 0) 353 | # print('reprojection_volume: {}'.format(reprojection_volume.shape)) 354 | # by default, it'll return top-k largest entries, hence sorted=False to get smallest entries 355 | # 
top_vals, top_inds = torch.topk(torch.neg(reprojection_volume), k=3, sorted=False) 356 | top_vals, top_inds = torch.topk(torch.neg(reprojection_volume), k=1, sorted=False) 357 | top_vals = torch.neg(top_vals) 358 | # top_mask = top_vals < (1e4 * torch.ones_like(top_vals, device=device)) 359 | top_mask = top_vals < (1e4 * torch.ones_like(top_vals).cuda()) 360 | top_mask = top_mask.float() 361 | top_vals = torch.mul(top_vals, top_mask) 362 | # print('top_vals: {}'.format(top_vals.shape)) 363 | 364 | self.reconstr_loss = torch.mean(torch.sum(top_vals, dim=-1)) 365 | self.unsup_loss = 12 * self.reconstr_loss + 6 * self.ssim_loss + 0.18 * self.smooth_loss 366 | # 按照un_mvsnet和M3VSNet的设置 367 | # self.unsup_loss = (0.8 * self.reconstr_loss + 0.2 * self.ssim_loss + 0.067 * self.smooth_loss) * 15 368 | return self.unsup_loss 369 | 370 | class UnsupLossMultiStage(nn.Module): 371 | def __init__(self): 372 | super(UnsupLossMultiStage, self).__init__() 373 | self.unsup_loss = UnSupLoss() 374 | 375 | def forward(self, inputs, imgs, cams, depth_loss_weights=[2.0,1.0,0.5]): 376 | ''' 377 | inputs: dict {"level_0": (1, 8, 240, 320), "level_1": (1,, 8, 120, 160), "level_2": (1, 8, 60, 80))} 378 | imgs: (1, 8, 3, 480, 640) 379 | cams: (1, 8, 3, 2, 4, 4) 380 | ''' 381 | # depth_loss_weights = kwargs.get("dlossw", None) 382 | 383 | total_loss = torch.tensor(0.0, dtype=torch.float32, device=imgs.device, requires_grad=False) 384 | 385 | scalar_outputs = {} 386 | for (stage_inputs, stage_key) in [(inputs[k], k) for k in inputs.keys() if "level_" in k]: 387 | stage_idx = int(stage_key.replace("level_", "")) # (0,1,2; 0 is biggest; 2 is smallest) 388 | 389 | depth_est = stage_inputs 390 | depth_loss = 0 391 | for i in range(depth_est.shape[1]): 392 | depth_loss += self.unsup_loss(imgs.roll(-i,1), cams.roll(-i,1)[:,:,stage_idx], depth_est[:, i], stage_idx) 393 | 394 | 395 | if depth_loss_weights is not None: 396 | total_loss += depth_loss_weights[stage_idx] * depth_loss 397 | else: 398 | total_loss += 1.0 * depth_loss 399 | 400 | scalar_outputs["depth_loss_stage{}".format(stage_idx + 1)] = depth_loss 401 | scalar_outputs["reconstr_loss_stage{}".format(stage_idx + 1)] = self.unsup_loss.reconstr_loss 402 | scalar_outputs["ssim_loss_stage{}".format(stage_idx + 1)] = self.unsup_loss.ssim_loss 403 | scalar_outputs["smooth_loss_stage{}".format(stage_idx + 1)] = self.unsup_loss.smooth_loss 404 | 405 | return total_loss, scalar_outputs -------------------------------------------------------------------------------- /utils/scannet_utils.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import glob 3 | import os 4 | import torch 5 | import cv2 6 | import numpy as np 7 | import pandas as pd 8 | from PIL import Image 9 | from skimage.io import imread 10 | from natsort import natsorted 11 | 12 | from utils.utils import downsample_gaussian_blur, pose_inverse 13 | 14 | 15 | # From https://github.com/open-mmlab/mmdetection3d/blob/fcb4545ce719ac121348cab59bac9b69dd1b1b59/mmdet3d/datasets/scannet_dataset.py 16 | class PointSegClassMapping(object): 17 | """Map original semantic class to valid category ids. 18 | Map valid classes as 0~len(valid_cat_ids)-1 and 19 | others as len(valid_cat_ids). 20 | Args: 21 | valid_cat_ids (tuple[int]): A tuple of valid category. 22 | max_cat_id (int, optional): The max possible cat_id in input 23 | segmentation mask. Defaults to 40. 
24 | """ 25 | 26 | def __init__(self, valid_cat_ids, max_cat_id=40): 27 | assert max_cat_id >= np.max(valid_cat_ids), \ 28 | 'max_cat_id should be greater than maximum id in valid_cat_ids' 29 | 30 | self.valid_cat_ids = valid_cat_ids 31 | self.max_cat_id = int(max_cat_id) 32 | 33 | # build cat_id to class index mapping 34 | neg_cls = len(valid_cat_ids) 35 | self.cat_id2class = np.ones( 36 | self.max_cat_id + 1, dtype=int) * neg_cls 37 | for cls_idx, cat_id in enumerate(valid_cat_ids): 38 | self.cat_id2class[cat_id] = cls_idx 39 | for i in range(self.cat_id2class.shape[0]): 40 | value = self.cat_id2class[i] 41 | if value == 19: 42 | self.cat_id2class[i] = 6 43 | elif value == 20: 44 | self.cat_id2class[i] = 19 45 | 46 | def __call__(self, seg_label): 47 | """Call function to map original semantic class to valid category ids. 48 | Args: 49 | results (dict): Result dict containing point semantic masks. 50 | Returns: 51 | dict: The result dict containing the mapped category ids. 52 | Updated key and value are described below. 53 | - pts_semantic_mask (np.ndarray): Mapped semantic masks. 54 | """ 55 | seg_label = np.clip(seg_label, 0, self.max_cat_id) 56 | return self.cat_id2class[seg_label] 57 | 58 | 59 | class BaseDatabase(abc.ABC): 60 | def __init__(self, database_name): 61 | self.database_name = database_name 62 | 63 | @abc.abstractmethod 64 | def get_image(self, img_id): 65 | pass 66 | 67 | @abc.abstractmethod 68 | def get_K(self, img_id): 69 | pass 70 | 71 | @abc.abstractmethod 72 | def get_pose(self, img_id): 73 | pass 74 | 75 | @abc.abstractmethod 76 | def get_img_ids(self, check_depth_exist=False): 77 | pass 78 | 79 | @abc.abstractmethod 80 | def get_bbox(self, img_id): 81 | pass 82 | 83 | @abc.abstractmethod 84 | def get_depth(self, img_id): 85 | pass 86 | 87 | @abc.abstractmethod 88 | def get_mask(self, img_id): 89 | pass 90 | 91 | @abc.abstractmethod 92 | def get_depth_range(self, img_id): 93 | pass 94 | 95 | 96 | 97 | class ScannetDatabase(BaseDatabase): 98 | def __init__(self, database_name, root_dir='data/scannet'): 99 | super().__init__(database_name) 100 | _, self.scene_name, background_size = database_name.split('/') 101 | background, image_size = background_size.split('_') 102 | image_size = int(image_size) 103 | self.image_size = image_size 104 | self.background = background 105 | self.root_dir = f'{root_dir}/{self.scene_name}' 106 | self.ratio = image_size / 1296 107 | self.h, self.w = int(self.ratio*972), int(image_size) 108 | 109 | rgb_paths = [x for x in glob.glob(os.path.join( 110 | self.root_dir, "color", "*")) if (x.endswith(".jpg") or x.endswith(".png"))] 111 | rgb_paths = sorted(rgb_paths) 112 | 113 | K = np.loadtxt( 114 | f'{self.root_dir}/intrinsic/intrinsic_color.txt').reshape([4, 4])[:3, :3] 115 | # After resize, we need to change the intrinsic matrix 116 | K[:2, :] *= self.ratio 117 | self.K = K 118 | 119 | self.img_ids = [] 120 | for i, rgb_path in enumerate(rgb_paths): 121 | pose = self.get_pose(i) 122 | if np.isinf(pose).any() or np.isnan(pose).any(): 123 | continue 124 | self.img_ids.append(f'{i}') 125 | 126 | self.img_id2imgs = {} 127 | # mapping from scanntet class id to nyu40 class id 128 | # mapping_file = 'data/scannet/scannetv2-labels.combined.tsv' 129 | mapping_file = os.path.join(root_dir, 'scannetv2-labels.combined.tsv') 130 | mapping_file = pd.read_csv(mapping_file, sep='\t', header=0) 131 | scan_ids = mapping_file['id'].values 132 | nyu40_ids = mapping_file['nyu40id'].values 133 | scan2nyu = np.zeros(max(scan_ids) + 1, dtype=np.int32) 134 | for i 
in range(len(scan_ids)): 135 | scan2nyu[scan_ids[i]] = nyu40_ids[i] 136 | self.scan2nyu = scan2nyu 137 | self.label_mapping = PointSegClassMapping( 138 | valid_cat_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 139 | 11, 12, 14, 16, 24, 28, 33, 34, 36, 39], 140 | max_cat_id=40 141 | ) 142 | 143 | def get_image(self, img_id): 144 | if img_id in self.img_id2imgs: 145 | return self.img_id2imgs[img_id] 146 | img = imread(os.path.join( 147 | self.root_dir, 'color', f'{int(img_id)}.jpg')) 148 | if self.w != 1296: 149 | img = cv2.resize(downsample_gaussian_blur( 150 | img, self.ratio), (self.w, self.h), interpolation=cv2.INTER_LINEAR) 151 | 152 | return img 153 | 154 | def get_K(self, img_id): 155 | return self.K.astype(np.float32) 156 | 157 | def get_pose(self, img_id): 158 | transf = np.diag(np.asarray([1, -1, -1, 1])) 159 | pose = np.loadtxt( 160 | f'{self.root_dir}/pose/{int(img_id)}.txt').reshape([4, 4]) 161 | # pose = transf @ pose 162 | # c2w in files, change to w2c 163 | # pose = pose_inverse(pose) 164 | return pose.copy() 165 | 166 | def get_img_ids(self, check_depth_exist=False): 167 | return self.img_ids 168 | 169 | def get_bbox(self, img_id): 170 | raise NotImplementedError 171 | 172 | def get_depth(self, img_id): 173 | h, w, _ = self.get_image(img_id).shape 174 | img = Image.open(f'{self.root_dir}/depth/{int(img_id)}.png') 175 | depth = np.asarray(img, dtype=np.float32) / 1000.0 # mm -> m 176 | # depth = np.asarray(img, dtype=np.float32) 177 | depth = np.ascontiguousarray(depth, dtype=np.float32) 178 | depth = cv2.resize(depth, (w, h), interpolation=cv2.INTER_NEAREST) 179 | return depth 180 | 181 | def get_mask(self, img_id): 182 | h, w, _ = self.get_image(img_id).shape 183 | return np.ones([h, w], bool) 184 | 185 | def get_depth_range(self, img_id): 186 | return np.asarray((0.1, 10.0), np.float32) 187 | 188 | def get_label(self, img_id): 189 | h, w, _ = self.get_image(img_id).shape 190 | img = Image.open(f'{self.root_dir}/label-filt/{int(img_id)}.png') 191 | label = np.asarray(img, dtype=np.int32) 192 | label = np.ascontiguousarray(label) 193 | label = cv2.resize(label, (w, h), interpolation=cv2.INTER_NEAREST) 194 | label = label.astype(np.int32) 195 | label = self.scan2nyu[label] 196 | return self.label_mapping(label) 197 | 198 | def parse_database_name(database_name, root_dir) -> BaseDatabase: 199 | name2database = { 200 | 'scannet': ScannetDatabase, 201 | 'replica': ReplicaDatabase, 202 | } 203 | database_type = database_name.split('/')[0] 204 | if database_type in name2database: 205 | return name2database[database_type](database_name, root_dir) 206 | else: 207 | raise NotImplementedError 208 | 209 | def get_database_split(database: BaseDatabase, split_type='val'): 210 | database_name = database.database_name 211 | if split_type.startswith('val'): 212 | splits = split_type.split('_') 213 | depth_valid = not(len(splits) > 1 and splits[1] == 'all') 214 | if database_name.startswith('scannet'): 215 | img_ids = database.get_img_ids() 216 | train_ids = img_ids[:700:5] 217 | val_ids = img_ids[2:700:20] 218 | if len(val_ids) > 10: 219 | val_ids = val_ids[:10] 220 | elif database_name.startswith('replica'): 221 | img_ids = database.get_img_ids() 222 | train_ids = img_ids[:700:5] 223 | val_ids = img_ids[2:700:20] 224 | if len(val_ids) > 10: 225 | val_ids = val_ids[:10] 226 | else: 227 | raise NotImplementedError 228 | elif split_type.startswith('test'): 229 | splits = split_type.split('_') 230 | depth_valid = not(len(splits) > 1 and splits[1] == 'all') 231 | if database_name.startswith('scannet'): 
232 | img_ids = database.get_img_ids() 233 | train_ids = img_ids[:700:5] 234 | val_ids = img_ids[2:700:20] 235 | if len(val_ids) > 10: 236 | val_ids = val_ids[:10] 237 | elif database_name.startswith('replica'): 238 | img_ids = database.get_img_ids() 239 | train_ids = img_ids[:700:5] 240 | val_ids = img_ids[2:700:20] 241 | if len(val_ids) > 10: 242 | val_ids = val_ids[:10] 243 | else: 244 | raise NotImplementedError 245 | elif split_type.startswith('video'): 246 | img_ids = database.get_img_ids() 247 | train_ids = img_ids[::2] 248 | val_ids = img_ids[25:-25:2] 249 | else: 250 | raise NotImplementedError 251 | print('train_ids:\n', train_ids) 252 | print('val_ids:\n', val_ids) 253 | return train_ids, val_ids 254 | 255 | 256 | def get_coords_mask(que_mask, train_ray_num, foreground_ratio): 257 | min_pos_num = int(train_ray_num * foreground_ratio) 258 | y0, x0 = np.nonzero(que_mask) 259 | y1, x1 = np.nonzero(~que_mask) 260 | xy0 = np.stack([x0, y0], 1).astype(np.float32) 261 | xy1 = np.stack([x1, y1], 1).astype(np.float32) 262 | idx = np.arange(xy0.shape[0]) 263 | np.random.shuffle(idx) 264 | xy0 = xy0[idx] 265 | coords0 = xy0[:min_pos_num] 266 | # still remain pixels 267 | if min_pos_num < train_ray_num: 268 | xy1 = np.concatenate([xy1, xy0[min_pos_num:]], 0) 269 | idx = np.arange(xy1.shape[0]) 270 | np.random.shuffle(idx) 271 | coords1 = xy1[idx[:(train_ray_num - min_pos_num)]] 272 | coords = np.concatenate([coords0, coords1], 0) 273 | else: 274 | coords = coords0 275 | return coords 276 | 277 | 278 | def color_map_forward(rgb): 279 | return rgb.astype(np.float32) / 255 280 | 281 | 282 | def pad_img_end(img, th, tw, padding_mode='edge', constant_values=0): 283 | h, w = img.shape[:2] 284 | hp = th - h 285 | wp = tw - w 286 | if hp != 0 or wp != 0: 287 | if padding_mode == 'constant': 288 | img = np.pad(img, ((0, hp), (0, wp), (0, 0)), padding_mode, constant_values=constant_values) 289 | else: 290 | img = np.pad(img, ((0, hp), (0, wp), (0, 0)), padding_mode) 291 | return img 292 | 293 | def random_crop(ref_imgs_info, que_imgs_info, target_size): 294 | imgs = ref_imgs_info['imgs'] 295 | n, _, h, w = imgs.shape 296 | out_h, out_w = target_size[0], target_size[1] 297 | if out_w >= w or out_h >= h: 298 | return ref_imgs_info 299 | 300 | center_h = np.random.randint(low=out_h // 2 + 1, high=h - out_h // 2 - 1) 301 | center_w = np.random.randint(low=out_w // 2 + 1, high=w - out_w // 2 - 1) 302 | 303 | def crop(tensor): 304 | tensor = tensor[:, :, center_h - out_h // 2:center_h + out_h // 2, 305 | center_w - out_w // 2:center_w + out_w // 2] 306 | return tensor 307 | 308 | def crop_imgs_info(imgs_info): 309 | imgs_info['imgs'] = crop(imgs_info['imgs']) 310 | if 'depth' in imgs_info: imgs_info['depth'] = crop(imgs_info['depth']) 311 | if 'true_depth' in imgs_info: imgs_info['true_depth'] = crop(imgs_info['true_depth']) 312 | if 'masks' in imgs_info: imgs_info['masks'] = crop(imgs_info['masks']) 313 | 314 | Ks = imgs_info['Ks'] # n, 3, 3 315 | h_init = center_h - out_h // 2 316 | w_init = center_w - out_w // 2 317 | Ks[:,0,2]-=w_init 318 | Ks[:,1,2]-=h_init 319 | imgs_info['Ks']=Ks 320 | return imgs_info 321 | 322 | return crop_imgs_info(ref_imgs_info), crop_imgs_info(que_imgs_info) 323 | 324 | def random_flip(ref_imgs_info,que_imgs_info): 325 | def flip(tensor): 326 | tensor = np.flip(tensor.transpose([0, 2, 3, 1]), 2) # n,h,w,3 327 | tensor = np.ascontiguousarray(tensor.transpose([0, 3, 1, 2])) 328 | return tensor 329 | 330 | def flip_imgs_info(imgs_info): 331 | imgs_info['imgs'] = 
flip(imgs_info['imgs']) 332 | if 'depth' in imgs_info: imgs_info['depth'] = flip(imgs_info['depth']) 333 | if 'true_depth' in imgs_info: imgs_info['true_depth'] = flip(imgs_info['true_depth']) 334 | if 'masks' in imgs_info: imgs_info['masks'] = flip(imgs_info['masks']) 335 | 336 | Ks = imgs_info['Ks'] # n, 3, 3 337 | Ks[:, 0, :] *= -1 338 | w = imgs_info['imgs'].shape[-1] 339 | Ks[:, 0, 2] += w - 1 340 | imgs_info['Ks'] = Ks 341 | return imgs_info 342 | 343 | ref_imgs_info = flip_imgs_info(ref_imgs_info) 344 | que_imgs_info = flip_imgs_info(que_imgs_info) 345 | return ref_imgs_info, que_imgs_info 346 | 347 | def pad_imgs_info(ref_imgs_info,pad_interval): 348 | ref_imgs, ref_depths, ref_masks = ref_imgs_info['imgs'], ref_imgs_info['depth'], ref_imgs_info['masks'] 349 | ref_depth_gt = ref_imgs_info['true_depth'] if 'true_depth' in ref_imgs_info else None 350 | rfn, _, h, w = ref_imgs.shape 351 | ph = (pad_interval - (h % pad_interval)) % pad_interval 352 | pw = (pad_interval - (w % pad_interval)) % pad_interval 353 | if ph != 0 or pw != 0: 354 | ref_imgs = np.pad(ref_imgs, ((0, 0), (0, 0), (0, ph), (0, pw)), 'reflect') 355 | ref_depths = np.pad(ref_depths, ((0, 0), (0, 0), (0, ph), (0, pw)), 'reflect') 356 | ref_masks = np.pad(ref_masks, ((0, 0), (0, 0), (0, ph), (0, pw)), 'reflect') 357 | if ref_depth_gt is not None: 358 | ref_depth_gt = np.pad(ref_depth_gt, ((0, 0), (0, 0), (0, ph), (0, pw)), 'reflect') 359 | ref_imgs_info['imgs'], ref_imgs_info['depth'], ref_imgs_info['masks'] = ref_imgs, ref_depths, ref_masks 360 | if ref_depth_gt is not None: 361 | ref_imgs_info['true_depth'] = ref_depth_gt 362 | return ref_imgs_info 363 | 364 | def build_imgs_info(database, ref_ids, pad_interval=-1, is_aligned=True, align_depth_range=False, has_depth=True, replace_none_depth=False, add_label=True, num_classes=0): 365 | if not is_aligned: 366 | assert has_depth 367 | rfn = len(ref_ids) 368 | ref_imgs, ref_labels, ref_masks, ref_depths, shapes = [], [], [], [], [] 369 | for ref_id in ref_ids: 370 | img = database.get_image(ref_id) 371 | if add_label: 372 | label = database.get_label(ref_id) 373 | ref_labels.append(label) 374 | shapes.append([img.shape[0], img.shape[1]]) 375 | ref_imgs.append(img) 376 | ref_masks.append(database.get_mask(ref_id)) 377 | ref_depths.append(database.get_depth(ref_id)) 378 | 379 | shapes = np.asarray(shapes) 380 | th, tw = np.max(shapes, 0) 381 | for rfi in range(rfn): 382 | ref_imgs[rfi] = pad_img_end(ref_imgs[rfi], th, tw, 'reflect') 383 | ref_labels[rfi] = pad_img_end(ref_labels[rfi], th, tw, 'reflect') 384 | ref_masks[rfi] = pad_img_end(ref_masks[rfi][:, :, None], th, tw, 'constant', 0)[..., 0] 385 | ref_depths[rfi] = pad_img_end(ref_depths[rfi][:, :, None], th, tw, 'constant', 0)[..., 0] 386 | ref_imgs = color_map_forward(np.stack(ref_imgs, 0)).transpose([0, 3, 1, 2]) 387 | ref_labels = np.stack(ref_labels, 0).transpose([0, 3, 1, 2]) 388 | ref_masks = np.stack(ref_masks, 0)[:, None, :, :] 389 | ref_depths = np.stack(ref_depths, 0)[:, None, :, :] 390 | else: 391 | ref_imgs = color_map_forward(np.asarray([database.get_image(ref_id) for ref_id in ref_ids])).transpose([0, 3, 1, 2]) 392 | ref_labels = np.asarray([database.get_label(ref_id) for ref_id in ref_ids])[:, None, :, :] 393 | ref_masks = np.asarray([database.get_mask(ref_id) for ref_id in ref_ids], dtype=np.float32)[:, None, :, :] 394 | if has_depth: 395 | ref_depths = [database.get_depth(ref_id) for ref_id in ref_ids] 396 | if replace_none_depth: 397 | b, _, h, w = ref_imgs.shape 398 | for i, depth in 
enumerate(ref_depths): 399 | if depth is None: ref_depths[i] = np.zeros([h, w], dtype=np.float32) 400 | ref_depths = np.asarray(ref_depths, dtype=np.float32)[:, None, :, :] 401 | else: ref_depths = None 402 | 403 | ref_poses = np.asarray([database.get_pose(ref_id) for ref_id in ref_ids], dtype=np.float32) 404 | ref_Ks = np.asarray([database.get_K(ref_id) for ref_id in ref_ids], dtype=np.float32) 405 | ref_depth_range = np.asarray([database.get_depth_range(ref_id) for ref_id in ref_ids], dtype=np.float32) 406 | if align_depth_range: 407 | ref_depth_range[:,0]=np.min(ref_depth_range[:,0]) 408 | ref_depth_range[:,1]=np.max(ref_depth_range[:,1]) 409 | ref_imgs_info = {'imgs': ref_imgs, 'poses': ref_poses, 'Ks': ref_Ks, 'depth_range': ref_depth_range, 'masks': ref_masks, 'labels': ref_labels} 410 | if has_depth: ref_imgs_info['depth'] = ref_depths 411 | if pad_interval!=-1: 412 | ref_imgs_info = pad_imgs_info(ref_imgs_info, pad_interval) 413 | return ref_imgs_info 414 | 415 | def build_render_imgs_info(que_pose,que_K,que_shape,que_depth_range): 416 | h, w = que_shape 417 | h, w = int(h), int(w) 418 | que_coords = np.stack(np.meshgrid(np.arange(w), np.arange(h), indexing='xy'), -1) 419 | que_coords = que_coords.reshape([1, -1, 2]).astype(np.float32) 420 | return {'poses': que_pose.astype(np.float32)[None,:,:], # 1,3,4 421 | 'Ks': que_K.astype(np.float32)[None,:,:], # 1,3,3 422 | 'coords': que_coords, 423 | 'depth_range': np.asarray(que_depth_range, np.float32)[None, :], 424 | 'shape': (h,w)} 425 | 426 | def imgs_info_to_torch(imgs_info): 427 | for k, v in imgs_info.items(): 428 | if isinstance(v,np.ndarray): 429 | imgs_info[k] = torch.from_numpy(v) 430 | return imgs_info 431 | 432 | def imgs_info_slice(imgs_info, indices): 433 | imgs_info_out={} 434 | for k, v in imgs_info.items(): 435 | imgs_info_out[k] = v[indices] 436 | return imgs_info_out 437 | 438 | 439 | class ReplicaDatabase(BaseDatabase): 440 | def __init__(self, database_name, root_dir='data/scannet'): 441 | super().__init__(database_name) 442 | _, self.scene_name, self.seq_id, background_size = database_name.split('/') 443 | background, image_size = background_size.split('_') 444 | self.image_size = int(image_size) 445 | self.background = background 446 | self.root_dir = f'{root_dir}/{self.scene_name}/{self.seq_id}' 447 | # self.root_dir = f'data/replica/{self.scene_name}/{self.seq_id}' 448 | self.ratio = self.image_size / 640 449 | self.h, self.w = int(self.ratio*480), int(self.image_size) 450 | 451 | rgb_paths = [x for x in glob.glob(os.path.join( 452 | self.root_dir, "rgb", "*")) if (x.endswith(".jpg") or x.endswith(".png"))] 453 | self.rgb_paths = natsorted(rgb_paths) 454 | # DO NOT use sorted() here!!! 
it will sort the name in a wrong way since the name is like rgb_1.jpg 455 | 456 | depth_paths = [x for x in glob.glob(os.path.join( 457 | self.root_dir, "depth", "*")) if (x.endswith(".jpg") or x.endswith(".png"))] 458 | self.depth_paths = natsorted(depth_paths) 459 | 460 | label_paths = [x for x in glob.glob(os.path.join( 461 | self.root_dir, "semantic_class", "*")) if (x.endswith(".jpg") or x.endswith(".png"))] 462 | self.label_paths = natsorted(label_paths) 463 | 464 | # Replica camera intrinsics 465 | # Pinhole Camera Model 466 | fx, fy, cx, cy, s = 320.0, 320.0, 319.5, 229.5, 0.0 467 | if self.ratio != 1.0: 468 | fx, fy, cx, cy = fx * self.ratio, fy * self.ratio, cx * self.ratio, cy * self.ratio 469 | self.K = np.array([[fx, s, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) 470 | c2ws = np.loadtxt(f'{self.root_dir}/traj_w_c.txt', 471 | delimiter=' ').reshape(-1, 4, 4).astype(np.float32) 472 | self.poses = [] 473 | transf = np.diag(np.asarray([1, -1, -1])) 474 | num_poses = c2ws.shape[0] 475 | for i in range(num_poses): 476 | pose = c2ws[i] 477 | # Change the pose to OpenGL coordinate system 478 | # TODO: check if this is correct, our code is using OpenCV coordinate system 479 | # pose = transf @ pose 480 | # pose = pose_inverse(pose) 481 | self.poses.append(pose) 482 | 483 | self.img_ids = [] 484 | for i, rgb_path in enumerate(self.rgb_paths): 485 | self.img_ids.append(i) 486 | 487 | self.label_mapping = PointSegClassMapping( 488 | valid_cat_ids=[12, 17, 20, 22, 31, 37, 40, 44, 47, 56, 489 | 64, 79, 80, 87, 91, 92, 93, 95, 97], 490 | max_cat_id=101 491 | ) 492 | 493 | def get_image(self, img_id): 494 | img = imread(self.rgb_paths[img_id]) 495 | if self.w != 640: 496 | img = cv2.resize(downsample_gaussian_blur( 497 | img, self.ratio), (self.w, self.h), interpolation=cv2.INTER_LINEAR) 498 | return img 499 | 500 | def get_K(self, img_id): 501 | return self.K 502 | 503 | def get_pose(self, img_id): 504 | pose = self.poses[img_id] 505 | return pose.copy() 506 | 507 | def get_img_ids(self, check_depth_exist=False): 508 | return self.img_ids 509 | 510 | def get_bbox(self, img_id): 511 | raise NotImplementedError 512 | 513 | def get_depth(self, img_id): 514 | h, w, _ = self.get_image(img_id).shape 515 | img = Image.open(self.depth_paths[img_id]) 516 | depth = np.asarray(img, dtype=np.float32) / 1000.0 # mm to m 517 | depth = np.ascontiguousarray(depth, dtype=np.float32) 518 | depth = cv2.resize(depth, (w, h), interpolation=cv2.INTER_NEAREST) 519 | return depth 520 | 521 | def get_mask(self, img_id): 522 | h, w, _ = self.get_image(img_id).shape 523 | return np.ones([h, w], bool) 524 | 525 | def get_depth_range(self, img_id): 526 | return np.asarray((0.1, 6.0), np.float32) 527 | 528 | def get_label(self, img_id): 529 | h, w, _ = self.get_image(img_id).shape 530 | img = Image.open(self.label_paths[img_id]) 531 | label = np.asarray(img, dtype=np.int32) 532 | label = np.ascontiguousarray(label) 533 | label = cv2.resize(label, (w, h), interpolation=cv2.INTER_NEAREST) 534 | label = label.astype(np.int32) 535 | return self.label_mapping(label) 536 | 537 | -------------------------------------------------------------------------------- /data/scannet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import json 4 | import numpy as np 5 | import random 6 | import time 7 | import os 8 | import cv2 9 | from glob import glob as glob 10 | from PIL import Image 11 | from torchvision import transforms as T 12 | 
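# Editor's note (descriptive comment, not part of the original file): this dataset module
# builds training/validation samples on top of the ScanNet/Replica databases implemented in
# utils/scannet_utils.py; set_seed below mixes the sample index with wall-clock time during
# training so parallel workers draw different augmentations, and falls back to a
# deterministic per-index seed for evaluation.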
13 | from utils.utils import read_pfm, get_nearest_pose_ids, get_rays, compute_nearest_camera_indices 14 | from utils.scannet_utils import parse_database_name, get_database_split, get_coords_mask, random_crop, random_flip, build_imgs_info, pad_imgs_info, imgs_info_to_torch, imgs_info_slice 15 | 16 | def set_seed(index,is_train): 17 | if is_train: 18 | np.random.seed((index+int(time.time()))%(2**16)) 19 | random.seed((index+int(time.time()))%(2**16)+1) 20 | torch.random.manual_seed((index+int(time.time()))%(2**16)+1) 21 | else: 22 | np.random.seed(index % (2 ** 16)) 23 | random.seed(index % (2 ** 16) + 1) 24 | torch.random.manual_seed(index % (2 ** 16) + 1) 25 | 26 | def add_depth_offset(depth,mask,region_min,region_max,offset_min,offset_max,noise_ratio,depth_length): 27 | coords = np.stack(np.nonzero(mask), -1)[:, (1, 0)] 28 | length = np.max(coords, 0) - np.min(coords, 0) 29 | center = coords[np.random.randint(0, coords.shape[0])] 30 | lx, ly = np.random.uniform(region_min, region_max, 2) * length 31 | diff = coords - center[None, :] 32 | mask0 = np.abs(diff[:, 0]) < lx 33 | mask1 = np.abs(diff[:, 1]) < ly 34 | masked_coords = coords[mask0 & mask1] 35 | global_offset = np.random.uniform(offset_min, offset_max) * depth_length 36 | if np.random.random() < 0.5: 37 | global_offset = -global_offset 38 | local_offset = np.random.uniform(-noise_ratio, noise_ratio, masked_coords.shape[0]) * depth_length + global_offset 39 | depth[masked_coords[:, 1], masked_coords[:, 0]] += local_offset 40 | 41 | def build_src_imgs_info_select(database, ref_ids, ref_ids_all, cost_volume_nn_num, pad_interval=-1): 42 | # ref_ids - selected ref ids for rendering 43 | ref_idx_exp = compute_nearest_camera_indices(database, ref_ids, ref_ids_all) 44 | ref_idx_exp = ref_idx_exp[:, 1:1 + cost_volume_nn_num] 45 | ref_ids_all = np.asarray(ref_ids_all) 46 | ref_ids_exp = ref_ids_all[ref_idx_exp] # rfn,nn 47 | ref_ids_exp_ = ref_ids_exp.flatten() 48 | ref_ids = np.asarray(ref_ids) 49 | ref_ids_in = np.unique(np.concatenate([ref_ids_exp_, ref_ids])) # rfn' 50 | mask0 = ref_ids_in[None, :] == ref_ids[:, None] # rfn,rfn' 51 | ref_idx_, ref_idx = np.nonzero(mask0) 52 | ref_real_idx = ref_idx[np.argsort(ref_idx_)] # sort 53 | 54 | rfn, nn = ref_ids_exp.shape 55 | mask1 = ref_ids_in[None, :] == ref_ids_exp.flatten()[:, None] # nn*rfn,rfn' 56 | ref_cv_idx_, ref_cv_idx = np.nonzero(mask1) 57 | ref_cv_idx = ref_cv_idx[np.argsort(ref_cv_idx_)] # sort 58 | ref_cv_idx = ref_cv_idx.reshape([rfn, nn]) 59 | is_aligned = not database.database_name.startswith('space') 60 | ref_imgs_info = build_imgs_info(database, ref_ids_in, pad_interval, is_aligned) 61 | return ref_imgs_info, ref_cv_idx, ref_real_idx 62 | 63 | 64 | class RendererDataset(Dataset): 65 | default_cfg={ 66 | 'train_database_types':['scannet'], 67 | 'type2sample_weights': {'scannet': 1}, 68 | 'val_database_name': 'scannet/scene0200_00/black_320', 69 | 'val_database_split_type': 'val', 70 | 71 | 'min_wn': 8, 72 | 'max_wn': 9, 73 | 'ref_pad_interval': 16, 74 | 'train_ray_num': 512, 75 | 'foreground_ratio': 0.5, 76 | 'resolution_type': 'lr', 77 | "use_consistent_depth_range": True, 78 | 'use_depth_loss_for_all': False, 79 | "use_depth": True, 80 | "use_src_imgs": False, 81 | "cost_volume_nn_num": 3, 82 | 83 | "aug_depth_range_prob": 0.05, 84 | 'aug_depth_range_min': 0.95, 85 | 'aug_depth_range_max': 1.05, 86 | "aug_use_depth_offset": True, 87 | "aug_depth_offset_prob": 0.25, 88 | "aug_depth_offset_region_min": 0.05, 89 | "aug_depth_offset_region_max": 0.1, 90 | 
'aug_depth_offset_min': 0.5, 91 | 'aug_depth_offset_max': 1.0, 92 | 'aug_depth_offset_local': 0.1, 93 | "aug_use_depth_small_offset": True, 94 | "aug_use_global_noise": True, 95 | "aug_global_noise_prob": 0.5, 96 | "aug_depth_small_offset_prob": 0.5, 97 | "aug_forward_crop_size": (400,600), 98 | "aug_pixel_center_sample": True, 99 | "aug_view_select_type": "easy", 100 | 101 | "use_consistent_min_max": False, 102 | "revise_depth_range": False, 103 | } 104 | def __init__(self, root_dir, is_train, cfg=None): 105 | if cfg is not None: 106 | self.cfg={**self.default_cfg,**cfg} 107 | else: 108 | self.cfg={**self.default_cfg} 109 | self.root_dir = root_dir 110 | self.is_train = is_train 111 | if is_train: 112 | self.num=999999 113 | self.type2scene_names,self.database_types,self.database_weights = {}, [], [] 114 | if self.cfg['resolution_type']=='hr': 115 | type2scene_names={ 116 | 'replica': np.loadtxt('configs/lists/replica_train_split.txt',dtype=str).tolist(), 117 | } 118 | elif self.cfg['resolution_type']=='lr': 119 | type2scene_names={ 120 | 'scannet': np.loadtxt('configs/lists/scannet_train_split.txt',dtype=str).tolist(), 121 | 'scannet_single': [self.cfg['val_database_name']], 122 | } 123 | else: 124 | raise NotImplementedError 125 | 126 | for database_type in self.cfg['train_database_types']: 127 | self.type2scene_names[database_type] = type2scene_names[database_type] 128 | self.database_types.append(database_type) 129 | self.database_weights.append(self.cfg['type2sample_weights'][database_type]) 130 | assert(len(self.database_types)>0) 131 | # normalize weights 132 | self.database_weights=np.asarray(self.database_weights) 133 | self.database_weights=self.database_weights/np.sum(self.database_weights) 134 | else: 135 | self.database = parse_database_name(self.cfg['val_database_name'], root_dir=self.root_dir) 136 | self.ref_ids, self.que_ids = get_database_split(self.database,self.cfg['val_database_split_type']) 137 | self.num=len(self.que_ids) 138 | self.database_statistics = {} 139 | 140 | self.blender2opencv = torch.tensor( 141 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] 142 | ).to(torch.float32) 143 | 144 | def get_database_ref_que_ids(self, index): 145 | if self.is_train: 146 | database_type = np.random.choice(self.database_types,1,False,p=self.database_weights)[0] 147 | database_scene_name = np.random.choice(self.type2scene_names[database_type]) 148 | database = parse_database_name(database_scene_name, root_dir=self.root_dir) 149 | # if no view in this scene has depth, keep resampling until we find a scene with depth 150 | while True: 151 | ref_ids = database.get_img_ids(check_depth_exist=True) 152 | if len(ref_ids)==0: 153 | database_type = np.random.choice(self.database_types, 1, False, self.database_weights)[0] 154 | database_scene_name = np.random.choice(self.type2scene_names[database_type]) 155 | database = parse_database_name(database_scene_name, root_dir=self.root_dir) 156 | else: break 157 | que_id = np.random.choice(ref_ids) 158 | # if database.database_name.startswith('real_estate'): 159 | # que_id, ref_ids = select_train_ids_for_real_estate(ref_ids) 160 | else: 161 | database = self.database 162 | que_id, ref_ids = self.que_ids[index], self.ref_ids 163 | return database, que_id, np.asarray(ref_ids) 164 | 165 | def select_working_views_impl(self, database_name, dist_idx, ref_num): 166 | if self.cfg['aug_view_select_type']=='default': 167 | if database_name.startswith('space') or database_name.startswith('real_estate'): 168 | pass 169 | elif
database_name.startswith('gso'): 170 | pool_ratio = np.random.randint(1, 5) 171 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 32)] 172 | elif database_name.startswith('real_iconic'): 173 | pool_ratio = np.random.randint(1, 5) 174 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 32)] 175 | elif database_name.startswith('dtu_train'): 176 | pool_ratio = np.random.randint(1, 3) 177 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 12)] 178 | elif database_name.startswith('scannet'): 179 | pool_ratio = np.random.randint(1, 3) 180 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 12)] 181 | elif database_name.startswith('replica'): 182 | pool_ratio = np.random.randint(1, 3) 183 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 12)] 184 | else: 185 | raise NotImplementedError 186 | elif self.cfg['aug_view_select_type']=='easy': 187 | if database_name.startswith('space') or database_name.startswith('real_estate'): 188 | pass 189 | elif database_name.startswith('gso'): 190 | pool_ratio = 3 191 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 24)] 192 | elif database_name.startswith('real_iconic'): 193 | pool_ratio = np.random.randint(1, 4) 194 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 20)] 195 | elif database_name.startswith('dtu_train'): 196 | pool_ratio = np.random.randint(1, 3) 197 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 12)] 198 | elif database_name.startswith('scannet'): 199 | pool_ratio = np.random.randint(1, 3) 200 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 12)] 201 | elif database_name.startswith('replica'): 202 | pool_ratio = np.random.randint(1, 3) 203 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 12)] 204 | else: 205 | raise NotImplementedError 206 | 207 | return dist_idx 208 | 209 | def select_working_views(self, database, que_id, ref_ids): 210 | database_name = database.database_name 211 | dist_idx = compute_nearest_camera_indices(database, [que_id], ref_ids)[0] 212 | if self.is_train: 213 | if np.random.random()>0.02: # 2% chance to include que image 214 | dist_idx = dist_idx[ref_ids[dist_idx]!=que_id] 215 | ref_num = np.random.randint(self.cfg['min_wn'], self.cfg['max_wn']) 216 | dist_idx = self.select_working_views_impl(database_name,dist_idx,ref_num) 217 | if not database_name.startswith('real_estate'): 218 | # we already select working views for real estate dataset 219 | np.random.shuffle(dist_idx) 220 | dist_idx = dist_idx[:ref_num] 221 | ref_ids = ref_ids[dist_idx] 222 | else: 223 | ref_ids = ref_ids[:ref_num] 224 | else: 225 | dist_idx = dist_idx[:self.cfg['min_wn']] 226 | ref_ids = ref_ids[dist_idx] 227 | return ref_ids 228 | 229 | def random_change_depth_range(self, depth_range): 230 | depth_range_new = depth_range.copy() 231 | if np.random.random()0 273 | coords = get_coords_mask(que_mask_cur, self.cfg['train_ray_num'], self.cfg['foreground_ratio']).reshape([1,-1,2]) 274 | return coords 275 | 276 | def consistent_depth_range(self, ref_imgs_info, que_imgs_info): 277 | depth_range_all = np.concatenate([ref_imgs_info['depth_range'], que_imgs_info['depth_range']], 0) 278 | if self.cfg['use_consistent_min_max']: 279 | depth_range_all[:, 0] = np.min(depth_range_all) 280 | depth_range_all[:, 1] = np.max(depth_range_all) 281 | else: 282 | range_len = depth_range_all[:, 1] - depth_range_all[:, 0] 283 | max_len = np.max(range_len) 284 | range_margin = (max_len - range_len) / 2 285 | ref_near = depth_range_all[:, 0] - range_margin 286 | ref_near = np.max(np.stack([ref_near, depth_range_all[:, 0] * 0.5], -1), 1) 287 | depth_range_all[:, 0] = ref_near 288 | 
depth_range_all[:, 1] = ref_near + max_len 289 | ref_imgs_info['depth_range'] = depth_range_all[:-1] 290 | que_imgs_info['depth_range'] = depth_range_all[-1:] 291 | 292 | 293 | def multi_scale_depth(self, depth_h): 294 | ''' 295 | This follows the Klevr dataset implementation and was moved here to keep the dataset format consistent 296 | ''' 297 | 298 | depth = {} 299 | for l in range(3): 300 | 301 | depth[f"level_{l}"] = cv2.resize( 302 | depth_h, 303 | None, 304 | fx=1.0 / (2**l), 305 | fy=1.0 / (2**l), 306 | interpolation=cv2.INTER_NEAREST, 307 | ) 308 | # depth[f"level_{l}"][depth[f"level_{l}"] > far_bound * 0.95] = 0.0 309 | 310 | if self.is_train: 311 | cutout = np.ones_like(depth[f"level_2"]) 312 | h0 = int(np.random.randint(0, high=cutout.shape[0] // 5, size=1)) 313 | h1 = int( 314 | np.random.randint( 315 | 4 * cutout.shape[0] // 5, high=cutout.shape[0], size=1 316 | ) 317 | ) 318 | w0 = int(np.random.randint(0, high=cutout.shape[1] // 5, size=1)) 319 | w1 = int( 320 | np.random.randint( 321 | 4 * cutout.shape[1] // 5, high=cutout.shape[1], size=1 322 | ) 323 | ) 324 | cutout[h0:h1, w0:w1] = 0 325 | depth_aug = depth[f"level_2"] * cutout 326 | else: 327 | depth_aug = depth[f"level_2"].copy() 328 | 329 | return depth, depth_aug 330 | 331 | 332 | 333 | 334 | def __getitem__(self, index): 335 | set_seed(index, self.is_train) 336 | database, que_id, ref_ids_all = self.get_database_ref_que_ids(index) 337 | ref_ids = self.select_working_views(database, que_id, ref_ids_all) 338 | if self.cfg['use_src_imgs']: 339 | # src_imgs_info used in construction of cost volume 340 | ref_imgs_info, ref_cv_idx, ref_real_idx = build_src_imgs_info_select(database,ref_ids,ref_ids_all,self.cfg['cost_volume_nn_num']) 341 | else: 342 | ref_idx = compute_nearest_camera_indices(database, ref_ids)[:,0:4] # used in cost volume construction 343 | is_aligned = not database.database_name.startswith('space') 344 | ref_imgs_info = build_imgs_info(database, ref_ids, -1, is_aligned) 345 | # in Semantic-Ray's implementation the query image cannot access depth; we load depth here but use it neither as input nor as supervision 346 | # que_imgs_info = build_imgs_info(database, [que_id], has_depth=self.is_train) 347 | que_imgs_info = build_imgs_info(database, [que_id]) 348 | 349 | if self.is_train: 350 | # data augmentation 351 | depth_range_all = np.concatenate([ref_imgs_info['depth_range'],que_imgs_info['depth_range']],0) 352 | 353 | depth_range_all = self.random_change_depth_range(depth_range_all) 354 | ref_imgs_info['depth_range'] = depth_range_all[:-1] 355 | que_imgs_info['depth_range'] = depth_range_all[-1:] 356 | 357 | 358 | 359 | if database.database_name.startswith('real_estate') \ 360 | or database.database_name.startswith('real_iconic') \ 361 | or database.database_name.startswith('space'): 362 | # crop all datasets 363 | ref_imgs_info, que_imgs_info = random_crop(ref_imgs_info, que_imgs_info, self.cfg['aug_forward_crop_size']) 364 | if np.random.random()<0.5: 365 | ref_imgs_info, que_imgs_info = random_flip(ref_imgs_info, que_imgs_info) 366 | 367 | if self.cfg['use_depth_loss_for_all'] and self.cfg['use_depth']: 368 | if not database.database_name.startswith('gso'): 369 | ref_imgs_info['true_depth'] = ref_imgs_info['depth'] 370 | 371 | if self.cfg['use_consistent_depth_range']: 372 | self.consistent_depth_range(ref_imgs_info, que_imgs_info) 373 | 374 | 375 | ref_imgs_info = pad_imgs_info(ref_imgs_info,self.cfg['ref_pad_interval']) 376 | 377 | # don't feed depth to gpu 378 | if not self.cfg['use_depth']: 379 | if 'depth' in ref_imgs_info:
ref_imgs_info.pop('depth') 380 | if 'depth' in que_imgs_info: que_imgs_info.pop('depth') 381 | if 'true_depth' in ref_imgs_info: ref_imgs_info.pop('true_depth') 382 | 383 | if self.cfg['use_src_imgs']: 384 | src_imgs_info = ref_imgs_info.copy() 385 | ref_imgs_info = imgs_info_slice(ref_imgs_info, ref_real_idx) 386 | ref_imgs_info['nn_ids'] = ref_cv_idx 387 | else: 388 | # 'nn_ids' used in constructing cost volume (specify source image ids) 389 | ref_imgs_info['nn_ids'] = ref_idx.astype(np.int64) 390 | 391 | # ref_imgs_info = imgs_info_to_torch(ref_imgs_info) 392 | # que_imgs_info = imgs_info_to_torch(que_imgs_info) 393 | 394 | # outputs = {'ref_imgs_info': ref_imgs_info, 'que_imgs_info': que_imgs_info, 'scene_name': database.database_name} 395 | # if self.cfg['use_src_imgs']: outputs['src_imgs_info'] = imgs_info_to_torch(src_imgs_info) 396 | 397 | same_format_as_klevr = True 398 | if same_format_as_klevr: 399 | sample = {} 400 | sample['images'] = np.concatenate((ref_imgs_info['imgs'], que_imgs_info['imgs']), 0) 401 | sample['semantics'] = np.concatenate((ref_imgs_info['labels'], que_imgs_info['labels']), 0).squeeze(1) 402 | 403 | 404 | # sample['w2cs'] = np.concatenate((ref_imgs_info['poses'], que_imgs_info['poses']), 0) # (1+nb_views, 3, 4) 405 | # sample['w2cs'] = np.concatenate((sample['w2cs'], torch.ones_like(sample['w2cs'])[:,0:1,:]), 1) # (1+nb_views, 4, 4) 406 | 407 | sample['c2ws'] = np.concatenate((ref_imgs_info['poses'], que_imgs_info['poses']), 0) # (1+nb_views, 4, 4) 408 | # sample['c2ws'] = np.concatenate((sample['c2ws'], torch.ones_like(sample['c2ws'])[:,0:1,:]), 1) # (1+nb_views, 4, 4) 409 | # sample['c2ws'] = sample['c2ws'] @ self.blender2opencv 410 | 411 | 412 | sample['intrinsics'] = np.concatenate((ref_imgs_info['Ks'], que_imgs_info['Ks']), 0) 413 | sample['near_fars'] = np.concatenate((ref_imgs_info['depth_range'], que_imgs_info['depth_range']), 0) 414 | sample['depths_h'] = np.concatenate((ref_imgs_info['depth'], que_imgs_info['depth']), 0).squeeze(1) 415 | sample['closest_idxs'] = ref_imgs_info['nn_ids'] # (nb_view, 4) # used in cost volume construction # hard code to [0:4] 416 | 417 | # affine_mats (1+nb_views, 4, 4, 3) 418 | # affine_mats_inv (1+nb_views, 4, 4, 3) 419 | # depths_aug (1+nb_views, 1, H/4, W/4) 420 | # depths {dict} {'level_0': (1+nb_views, 1, H, W), 'level_1': (1+nb_views, 1, H/2, W/2), 'level_2': (1+nb_views, 1, H/4, W/4)} 421 | 422 | sample['w2cs'] = [] 423 | affine_mats, affine_mats_inv, depths_aug = [], [], [] 424 | project_mats = [] 425 | depths = {"level_0": [], "level_1": [], "level_2": []} 426 | 427 | for i in range(sample['c2ws'].shape[0]): 428 | sample['w2cs'].append(np.linalg.inv(sample['c2ws'][i])) 429 | # sample['w2cs'].append(torch.asarray(np.linalg.inv(np.asarray(sample['c2ws'][i])))) 430 | 431 | aff = [] 432 | aff_inv = [] 433 | proj_matrices = [] 434 | 435 | for l in range(3): 436 | proj_mat_l = np.eye(4) 437 | intrinsic_temp = sample['intrinsics'][i].copy() 438 | intrinsic_temp[:2] = intrinsic_temp[:2]/(2**l) 439 | proj_mat_l[:3,:4] = intrinsic_temp @ sample['w2cs'][i][:3,:4] 440 | aff.append(proj_mat_l) 441 | aff_inv.append(np.linalg.inv(proj_mat_l)) 442 | # For unsupervised depth loss 443 | proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32) 444 | proj_mat[0, :4, :4] = sample['w2cs'][i][:4,:4] 445 | proj_mat[1, :3, :3] = intrinsic_temp 446 | proj_matrices.append(proj_mat) 447 | 448 | aff = np.stack(aff, axis=-1) 449 | aff_inv = np.stack(aff_inv, axis=-1) 450 | proj_matrices = np.stack(proj_matrices) 451 | 452 | 
affine_mats.append(aff) 453 | affine_mats_inv.append(aff_inv) 454 | project_mats.append(proj_matrices) 455 | 456 | depth, depth_aug = self.multi_scale_depth(np.asarray(sample['depths_h'][i])) 457 | depths["level_0"].append(depth["level_0"]) 458 | depths["level_1"].append(depth["level_1"]) 459 | depths["level_2"].append(depth["level_2"]) 460 | depths_aug.append(depth_aug) 461 | 462 | affine_mats = np.stack(affine_mats) 463 | affine_mats_inv = np.stack(affine_mats_inv) 464 | project_mats = np.stack(project_mats) 465 | depths_aug = np.stack(depths_aug) 466 | depths["level_0"] = np.stack(depths["level_0"]) 467 | depths["level_1"] = np.stack(depths["level_1"]) 468 | depths["level_2"] = np.stack(depths["level_2"]) 469 | 470 | 471 | sample['w2cs'] = np.stack(sample['w2cs'], 0) # (1+nb_views, 4, 4) 472 | sample['affine_mats'] = affine_mats 473 | sample['affine_mats_inv'] = affine_mats_inv 474 | sample['depths_aug'] = depths_aug 475 | sample['depths'] = depths 476 | sample['project_mats'] = project_mats 477 | 478 | return sample 479 | 480 | # if same_format_as_klevr == False: 481 | # return outputs 482 | 483 | def __len__(self): 484 | return self.num 485 | 486 | 487 | --------------------------------------------------------------------------------
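
A minimal usage sketch (not part of the repository) showing how the RendererDataset above could be instantiated for validation and wrapped in a standard PyTorch DataLoader. The root_dir path, the cfg override, and the loader settings are illustrative assumptions; the constructor signature and the sample keys follow the __getitem__ code above.

# Hypothetical usage sketch; paths and loader settings are assumptions.
from torch.utils.data import DataLoader
from data.scannet import RendererDataset

# Validation mode: the dataset reads the scene named by 'val_database_name'
# and iterates over its query views (see __init__ and __len__ above).
val_set = RendererDataset(
    root_dir='data/scannet',  # assumed dataset location on disk
    is_train=False,
    cfg={'val_database_name': 'scannet/scene0200_00/black_320'},
)
val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4)

sample = next(iter(val_loader))
# Keys assembled in __getitem__ (klevr-style format): images, semantics, c2ws,
# w2cs, intrinsics, near_fars, depths_h, closest_idxs, affine_mats,
# affine_mats_inv, depths_aug, depths (per-level dict), project_mats.
print(sample['images'].shape, sample['depths']['level_0'].shape)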