├── data
│   ├── __init__.py
│   ├── get_datasets.py
│   ├── nerf.py
│   ├── llff.py
│   ├── dtu.py
│   ├── klevr.py
│   └── scannet.py
├── model
│   ├── __init__.py
│   ├── UNet.py
│   ├── geo_reasoner.py
│   └── self_attn_renderer.py
├── utils
│   ├── __init__.py
│   ├── optimizer.py
│   ├── loss.py
│   ├── metrics.py
│   ├── depth_map.py
│   ├── options.py
│   ├── klevr_utils.py
│   ├── rendering.py
│   ├── depth_loss.py
│   └── scannet_utils.py
├── img
│   └── cvpr_poster_final_5120.png
├── configs
│   ├── lists
│   │   ├── replica_test_split.txt
│   │   ├── scannet_test_split.txt
│   │   ├── replica_train_split.txt
│   │   └── scannet_train_split.txt
│   ├── replica.txt
│   └── scannet.txt
├── requirements.txt
└── README.md
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/img/cvpr_poster_final_5120.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimChou-ntu/GSNeRF/HEAD/img/cvpr_poster_final_5120.png
--------------------------------------------------------------------------------
/configs/lists/replica_test_split.txt:
--------------------------------------------------------------------------------
1 | replica/office_4/Sequence_1/black_320
2 | replica/office_4/Sequence_2/black_320
3 | replica/room_2/Sequence_1/black_320
4 | replica/room_2/Sequence_2/black_320
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning
2 | inplace_abn
3 | imageio
4 | pillow
5 | scikit-image
6 | opencv-python
7 | ConfigArgParse
8 | lpips
9 | kornia
10 | ipdb
11 | scikit-learn
12 | pandas
13 | natsort
14 | segmentation-models-pytorch
15 | wandb
--------------------------------------------------------------------------------
/configs/lists/scannet_test_split.txt:
--------------------------------------------------------------------------------
1 | scannet/scene0063_00/black_320
2 | scannet/scene0067_00/black_320
3 | scannet/scene0071_00/black_320
4 | scannet/scene0074_00/black_320
5 | scannet/scene0079_00/black_320
6 | scannet/scene0086_00/black_320
7 | scannet/scene0200_00/black_320
8 | scannet/scene0211_00/black_320
9 | scannet/scene0226_00/black_320
10 | scannet/scene0376_02/black_320
11 |
--------------------------------------------------------------------------------
/configs/lists/replica_train_split.txt:
--------------------------------------------------------------------------------
1 | replica/office_0/Sequence_1/black_320
2 | replica/office_0/Sequence_2/black_320
3 | replica/office_1/Sequence_1/black_320
4 | replica/office_1/Sequence_2/black_320
5 | replica/office_2/Sequence_1/black_320
6 | replica/office_2/Sequence_2/black_320
7 | replica/office_3/Sequence_1/black_320
8 | replica/office_3/Sequence_2/black_320
9 | replica/room_0/Sequence_1/black_320
10 | replica/room_0/Sequence_2/black_320
11 | replica/room_1/Sequence_1/black_320
12 | replica/room_1/Sequence_2/black_320
13 |
--------------------------------------------------------------------------------
/configs/replica.txt:
--------------------------------------------------------------------------------
1 | ### INPUT
2 | expname = replica
3 | logdir = ./logs_replica/
4 | nb_views = 8
5 |
6 | ### number of classes + 1
7 | nb_class = 20
8 | ignore_label = 19
9 |
10 | ## model
11 | using_semantic_global_tokens = 1
12 | only_using_semantic_global_tokens = 0
13 | use_depth_refine_net = False
14 |
15 | ## dataset
16 | dataset_name = replica
17 | replica_path = "/mnt/sdb/timothy/Desktop/2023Spring/Semantic-Ray/data/replica"
18 | scene = None
19 | val_set_list = "configs/lists/replica_val_split.txt"
20 |
21 | ### TESTING
22 | chunk = 4096 ### Reduce it to save memory
23 |
24 | ### TRAINING
25 | ### num_steps = 250000
26 | num_steps = 300000
27 | lrate = 0.0005
28 | logger = wandb
29 | batch_size = 1024
30 | two_stage_training_steps = 0
31 | cross_entropy_weight = 0.5
32 | background_weight = 0.8
33 | use_batch_semantic_feature = True
34 | feat_net = smp_UNet
--------------------------------------------------------------------------------
/configs/scannet.txt:
--------------------------------------------------------------------------------
1 | ### INPUT
2 | expname = scannet
3 | logdir = ./logs_scannet/
4 | nb_views = 8
5 |
6 | ### number of classes + 1
7 | nb_class = 21
8 | ignore_label = 20
9 |
10 | ## model
11 | using_semantic_global_tokens = 1
12 | only_using_semantic_global_tokens = 0
13 | use_depth_refine_net = False
14 | feat_net = smp_UNet
15 |
16 | ## dataset
17 | dataset_name = scannet
18 | scannet_path = "/mnt/sdb/timothy/Desktop/2023Spring/Semantic-Ray/data/scannet"
19 | scene = None
20 | val_set_list = "configs/lists/scannet_test_split.txt"
21 |
22 | ### TESTING
23 | chunk = 4096 ### Reduce it to save memory
24 |
25 | ### TRAINING
26 | ### num_steps = 250000
27 | num_steps = 300000
28 | lrate = 0.0005
29 | logger = wandb
30 | batch_size = 1024
31 | two_stage_training_steps = 0
32 | cross_entropy_weight = 0.5
33 | background_weight = 1.0
34 | use_batch_semantic_feature = True
--------------------------------------------------------------------------------
/utils/optimizer.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.optim import SGD, Adam
3 | from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
4 |
5 | def get_optimizer(hparams, models):
6 | eps = 1e-5
7 | parameters = []
8 | for model in models:
9 | parameters += list(model.parameters())
10 | if hparams.optimizer == 'sgd':
11 | optimizer = SGD(parameters, lr=hparams.lrate,
12 | weight_decay=1e-5)
13 | elif hparams.optimizer == 'adam':
14 | optimizer = Adam(parameters, lr=hparams.lrate,
15 | betas=(0.9, 0.999),
16 | weight_decay=1e-5)
17 | else:
18 | raise ValueError('optimizer not recognized!')
19 |
20 | return optimizer
21 |
22 | def get_scheduler(hparams, optimizer):
23 | eps = 1e-5
24 | # if hparams.lr_scheduler == 'steplr':
25 | # scheduler = MultiStepLR(optimizer, milestones=hparams.decay_step,
26 | # gamma=hparams.decay_gamma)
27 | # elif hparams.lr_scheduler == 'cosine':
28 | # scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_steps, eta_min=eps)
29 |
30 | # else:
31 | # raise ValueError('scheduler not recognized!')
32 | num_steps = hparams.num_steps
33 | if hparams.ddp:
34 | num_steps /= 8
35 | return CosineAnnealingLR(optimizer, T_max=num_steps, eta_min=eps)
--------------------------------------------------------------------------------
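A minimal usage sketch (not part of the repository) for the two helpers above, assuming an `hparams` namespace carrying the fields they read (`optimizer`, `lrate`, `num_steps`, `ddp`):

```python
from types import SimpleNamespace

import torch.nn as nn

from utils.optimizer import get_optimizer, get_scheduler

# hypothetical hyper-parameters mirroring configs/scannet.txt
hparams = SimpleNamespace(optimizer="adam", lrate=5e-4, num_steps=300000, ddp=False)
models = [nn.Linear(8, 8), nn.Linear(8, 4)]    # any list of nn.Modules

optimizer = get_optimizer(hparams, models)     # Adam over all parameters, weight_decay=1e-5
scheduler = get_scheduler(hparams, optimizer)  # cosine annealing down to eta_min=1e-5
```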
/utils/loss.py:
--------------------------------------------------------------------------------
1 | ### from semray
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class Loss:
8 | def __init__(self, keys):
9 | """
10 | keys are used in multi-gpu model, DummyLoss in train_tools.py
11 | :param keys: the output keys of the dict
12 | """
13 | self.keys = keys
14 |
15 | def __call__(self, data_pr, data_gt, **kwargs):
16 | pass
17 |
18 | class SemanticLoss(Loss):
19 | def __init__(self, nb_class, ignore_label, weight=None):
20 | super().__init__(['loss_semantic'])
21 | self.nb_class = nb_class
22 | self.ignore_label = ignore_label
23 | self.weight = weight
24 |
25 | def __call__(self, data_pr, data_gt, **kwargs):
26 | def compute_loss(label_pr, label_gt):
27 | label_pr = label_pr.reshape(-1, self.nb_class)
28 | label_gt = label_gt.reshape(-1).long()
29 | valid_mask = (label_gt != self.ignore_label)
30 | label_pr = label_pr[valid_mask]
31 | label_gt = label_gt[valid_mask]
32 |             if self.weight is not None:
33 | self.weight = self.weight.to(label_pr.device)
34 | return nn.functional.cross_entropy(label_pr, label_gt, reduction='mean', weight=self.weight).unsqueeze(0)
35 | else:
36 | return nn.functional.cross_entropy(label_pr, label_gt, reduction='mean').unsqueeze(0)
37 |
38 | loss = compute_loss(data_pr, data_gt)
39 |
40 | return loss
--------------------------------------------------------------------------------
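A small sketch (dummy logits and labels, not repository code) showing how `SemanticLoss` drops the ignore label before computing cross entropy:

```python
import torch

from utils.loss import SemanticLoss

nb_class, ignore_label = 21, 20                  # values used in configs/scannet.txt
criterion = SemanticLoss(nb_class, ignore_label)

logits = torch.randn(1024, nb_class)             # per-ray class logits
labels = torch.randint(0, nb_class, (1024,))     # ground-truth labels; 20 is ignored
loss = criterion(logits, labels)                 # tensor of shape (1,)
```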
/configs/lists/scannet_train_split.txt:
--------------------------------------------------------------------------------
1 | scannet/scene0000_00/black_320
2 | scannet/scene0001_00/black_320
3 | scannet/scene0002_00/black_320
4 | scannet/scene0003_00/black_320
5 | scannet/scene0004_00/black_320
6 | scannet/scene0005_00/black_320
7 | scannet/scene0006_00/black_320
8 | scannet/scene0007_00/black_320
9 | scannet/scene0008_00/black_320
10 | scannet/scene0009_00/black_320
11 | scannet/scene0010_00/black_320
12 | scannet/scene0011_00/black_320
13 | scannet/scene0012_00/black_320
14 | scannet/scene0013_00/black_320
15 | scannet/scene0014_00/black_320
16 | scannet/scene0015_00/black_320
17 | scannet/scene0016_00/black_320
18 | scannet/scene0017_00/black_320
19 | scannet/scene0018_00/black_320
20 | scannet/scene0019_00/black_320
21 | scannet/scene0020_00/black_320
22 | scannet/scene0021_00/black_320
23 | scannet/scene0022_00/black_320
24 | scannet/scene0023_00/black_320
25 | scannet/scene0024_00/black_320
26 | scannet/scene0025_00/black_320
27 | scannet/scene0026_00/black_320
28 | scannet/scene0027_00/black_320
29 | scannet/scene0028_00/black_320
30 | scannet/scene0029_00/black_320
31 | scannet/scene0030_00/black_320
32 | scannet/scene0031_00/black_320
33 | scannet/scene0032_00/black_320
34 | scannet/scene0033_00/black_320
35 | scannet/scene0034_00/black_320
36 | scannet/scene0035_00/black_320
37 | scannet/scene0036_00/black_320
38 | scannet/scene0037_00/black_320
39 | scannet/scene0038_00/black_320
40 | scannet/scene0039_00/black_320
41 | scannet/scene0040_00/black_320
42 | scannet/scene0041_00/black_320
43 | scannet/scene0042_00/black_320
44 | scannet/scene0043_00/black_320
45 | scannet/scene0044_00/black_320
46 | scannet/scene0045_00/black_320
47 | scannet/scene0046_00/black_320
48 | scannet/scene0047_00/black_320
49 | scannet/scene0048_00/black_320
50 | scannet/scene0049_00/black_320
51 | scannet/scene0050_00/black_320
52 | scannet/scene0051_00/black_320
53 | scannet/scene0052_00/black_320
54 | scannet/scene0053_00/black_320
55 | scannet/scene0054_00/black_320
56 | scannet/scene0055_00/black_320
57 | scannet/scene0056_00/black_320
58 | scannet/scene0057_00/black_320
59 | scannet/scene0058_00/black_320
60 | scannet/scene0059_00/black_320
61 | scannet/scene0060_00/black_320
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn.metrics import confusion_matrix
4 |
5 | def nanmean(data, **args):
6 | # This makes it ignore the first 'background' class
7 | return np.ma.masked_array(data, np.isnan(data)).mean(**args)
8 |     # In np.ma.masked_array(data, np.isnan(data)), elements where data is NaN are marked invalid and ignored when computing the mean
9 |
10 |
11 | def calculate_segmentation_metrics(true_labels, predicted_labels, number_classes, ignore_label=-1):
12 | '''
13 | return:
14 | miou: the miou of all classes (without nan classes / not existing classes)
15 | valid_miou: the miou of valid classes (without ignore_class and not existing classes)
16 | class_average_accuracy: the average accuracy of all classes
17 | total_accuracy: the accuracy of all classes
18 | ious: per class iou
19 | '''
20 |
21 | np.seterr(divide='ignore', invalid='ignore')
22 | if (true_labels == ignore_label).all():
23 | return [0]*5
24 |
25 | true_labels = true_labels.flatten().cpu().numpy()
26 | predicted_labels = predicted_labels.flatten().cpu().numpy()
27 | valid_pix_ids = true_labels!=ignore_label
28 | predicted_labels = predicted_labels[valid_pix_ids]
29 | true_labels = true_labels[valid_pix_ids]
30 |
31 | conf_mat = confusion_matrix(true_labels, predicted_labels, labels=list(range(number_classes)))
32 | norm_conf_mat = np.transpose(
33 |         np.transpose(conf_mat) / conf_mat.astype(float).sum(axis=1))  # np.float was removed in NumPy 1.24
34 |
35 | missing_class_mask = np.isnan(norm_conf_mat.sum(1)) # missing class will have NaN at corresponding class
36 |     existing_class_mask = ~missing_class_mask
37 |
38 | class_average_accuracy = nanmean(np.diagonal(norm_conf_mat))
39 | total_accuracy = (np.sum(np.diagonal(conf_mat)) / np.sum(conf_mat))
40 | ious = np.zeros(number_classes)
41 | for class_id in range(number_classes):
42 | ious[class_id] = (conf_mat[class_id, class_id] / (
43 | np.sum(conf_mat[class_id, :]) + np.sum(conf_mat[:, class_id]) -
44 | conf_mat[class_id, class_id]))
45 | miou = nanmean(ious)
46 |     miou_valid_class = np.mean(ious[existing_class_mask])
47 | return miou, miou_valid_class, total_accuracy, class_average_accuracy, ious
48 |
49 |
50 | # From https://github.com/Harry-Zhi/semantic_nerf/blob/a0113bb08dc6499187c7c48c3f784c2764b8abf1/SSR/training/training_utils.py
51 | class IoU():
52 |
53 | def __init__(self, ignore_label=-1, num_classes=20):
54 | self.ignore_label = ignore_label
55 | self.num_classes = num_classes
56 |
57 | def __call__(self, true_labels, predicted_labels):
58 | np.seterr(divide='ignore', invalid='ignore')
59 | true_labels = true_labels.long().detach().cpu().numpy()
60 | predicted_labels = predicted_labels.long().detach().cpu().numpy()
61 |
62 | if self.ignore_label != -1:
63 | valid_pix_ids = true_labels != self.ignore_label
64 | else:
65 | valid_pix_ids = np.ones_like(true_labels, dtype=bool)
66 |
67 | num_classes = self.num_classes
68 | predicted_labels = predicted_labels[valid_pix_ids]
69 | true_labels = true_labels[valid_pix_ids]
70 |
71 | conf_mat = confusion_matrix(
72 | true_labels, predicted_labels, labels=list(range(num_classes)))
73 | norm_conf_mat = np.transpose(np.transpose(
74 | conf_mat) / conf_mat.astype(float).sum(axis=1))
75 |
76 | # missing class will have NaN at corresponding class
77 | missing_class_mask = np.isnan(norm_conf_mat.sum(1))
78 |         existing_class_mask = ~missing_class_mask
79 |
80 | class_average_accuracy = nanmean(np.diagonal(norm_conf_mat))
81 | total_accuracy = (np.sum(np.diagonal(conf_mat)) / np.sum(conf_mat))
82 | ious = np.zeros(num_classes)
83 | for class_id in range(num_classes):
84 | ious[class_id] = (conf_mat[class_id, class_id] / (
85 | np.sum(conf_mat[class_id, :]) + np.sum(conf_mat[:, class_id]) -
86 | conf_mat[class_id, class_id]))
87 |         miou = np.mean(ious[existing_class_mask])
88 | if np.isnan(miou):
89 | miou = 0.
90 | total_accuracy = 0.
91 | class_average_accuracy = 0.
92 | output = {
93 | 'miou': torch.tensor([miou], dtype=torch.float32),
94 | 'total_accuracy': torch.tensor([total_accuracy], dtype=torch.float32),
95 | 'class_average_accuracy': torch.tensor([class_average_accuracy], dtype=torch.float32)
96 | }
97 | return output
--------------------------------------------------------------------------------
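A quick sanity-check sketch for the metric helpers with random labels (it assumes the `astype(float)` fix above; inputs are torch tensors, since both helpers call `.cpu()` internally):

```python
import torch

from utils.metrics import IoU, calculate_segmentation_metrics

gt   = torch.randint(0, 21, (2, 240, 320))       # 21 classes, label 20 treated as ignore
pred = torch.randint(0, 20, (2, 240, 320))

miou, miou_valid, total_acc, class_acc, ious = calculate_segmentation_metrics(
    gt, pred, number_classes=21, ignore_label=20
)

iou_metric = IoU(ignore_label=20, num_classes=21)
print(iou_metric(gt, pred))   # {'miou': ..., 'total_accuracy': ..., 'class_average_accuracy': ...}
```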
/utils/depth_map.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | # This depth map seems to be slow and inaccurate
4 | def dense_map(Pts, n, m, grid):
5 | ng = 2 * grid + 1
6 |
7 | mX = 100000*torch.ones((m,n)).to(Pts.dtype).to(Pts.device)
8 | mY = 100000*torch.ones((m,n)).to(Pts.dtype).to(Pts.device)
9 | mD = torch.zeros((m,n)).to(Pts.device)
10 | y_ = Pts[1].to(torch.int32)
11 | x_ = Pts[0].to(torch.int32)
12 | int_x = torch.round(Pts[0]).to(Pts.device)
13 | int_y = torch.round(Pts[1]).to(Pts.device)
14 | mX[y_,x_] = Pts[0] - int_x
15 | mY[y_,x_] = Pts[1] - int_y
16 | mD[y_,x_] = Pts[2]
17 |
18 | KmX = torch.zeros((ng, ng, m - ng, n - ng)).to(Pts.device)
19 | KmY = torch.zeros((ng, ng, m - ng, n - ng)).to(Pts.device)
20 | KmD = torch.zeros((ng, ng, m - ng, n - ng)).to(Pts.device)
21 |
22 | for i in range(ng):
23 | for j in range(ng):
24 |             KmX[i,j] = mX[i : (m - ng + i), j : (n - ng + j)] - grid - 1 + i
25 |             KmY[i,j] = mY[i : (m - ng + i), j : (n - ng + j)] - grid - 1 + j  # offset by the column index j
26 | KmD[i,j] = mD[i : (m - ng + i), j : (n - ng + j)]
27 | S = torch.zeros_like(KmD[0,0]).to(Pts.device)
28 | Y = torch.zeros_like(KmD[0,0]).to(Pts.device)
29 |
30 | for i in range(ng):
31 | for j in range(ng):
32 | s = 1/torch.sqrt(KmX[i,j] * KmX[i,j] + KmY[i,j] * KmY[i,j])
33 | Y = Y + s * KmD[i,j]
34 | S = S + s
35 | del s
36 |
37 | S[S == 0] = 1
38 | # out = torch.zeros((m,n)).to(Pts.device)
39 | # set to the far range of the dataset, TODO: change this to be more general
40 | out = torch.ones((m,n)).to(Pts.device)*10.0
41 |     # in case Y or S overflows to inf, replace inf entries with the largest finite value of Y / S respectively
42 | Y = torch.nan_to_num(Y, posinf=Y[Y!=torch.inf].max().item())
43 | S = torch.nan_to_num(S, posinf=S[S!=torch.inf].max().item())
44 | out[grid + 1 : -grid, grid + 1 : -grid] = Y/S
45 | out = F.pad(out[grid + 1 : -grid, grid + 1 : -grid].unsqueeze(0), (grid+1, grid, grid+1, grid), mode='replicate').squeeze(0)
46 | if torch.isnan(out).any():
47 | print('Nan in depth map')
48 | del mX, mY, mD, KmX, KmY, KmD, S, Y, y_, x_, int_x, int_y
49 | return out
50 |
51 |
52 | def get_target_view_depth(source_depths, source_intrinsics, source_c2ws, target_intrinsics, target_w2c, img_wh, grid_size):
53 | '''
54 | source_depth: [N, H, W]
55 | source_intrinsics: [N, 3, 3]
56 | source_c2ws: [N, 4, 4]
57 | target_intrinsics: [3, 3]
58 | target_w2c: [4, 4]
59 | img_wh: [2]
60 | grid_size: int
61 | return: depth map [H, W]
62 | '''
63 | W, H = img_wh
64 | N = source_depths.shape[0]
65 | points = []
66 |
67 | ys, xs = torch.meshgrid(
68 | torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W), indexing="ij"
69 | ) # pytorch's meshgrid has indexing='ij'
70 | ys, xs = ys.reshape(-1).to(source_intrinsics.device), xs.reshape(-1).to(source_intrinsics.device)
71 |
72 | for num in range(N):
73 | # Might need to change this to be more general (too small or too big value are not good)
74 | mask = source_depths[num] > 0
75 |
76 | dirs = torch.stack(
77 | [
78 | (xs - source_intrinsics[num][0, 2]) / source_intrinsics[num][0, 0],
79 | (ys - source_intrinsics[num][1, 2]) / source_intrinsics[num][1, 1],
80 | torch.ones_like(xs),
81 | ],
82 | -1,
83 | )
84 | rays_dir = (
85 | dirs @ source_c2ws[num][:3, :3].t()
86 | )
87 | rays_orig = source_c2ws[num][:3, -1].clone().reshape(1, 3).expand(rays_dir.shape[0], -1)
88 | rays_orig = rays_orig.reshape(H,W,-1)[mask]
89 | rays_depth = source_depths[num].reshape(H,W,-1)[mask]
90 | rays_dir = rays_dir.reshape(H,W,-1)[mask]
91 | ray_pts = rays_orig + rays_depth * rays_dir
92 | points.append(ray_pts.reshape(-1,3))
93 |
94 | del rays_orig, rays_depth, rays_dir, ray_pts, dirs, mask
95 |
96 | points = torch.cat(points,0).reshape(-1,3)
97 |
98 | R = target_w2c[:3, :3] # (3, 3)
99 | T = target_w2c[:3, 3:] # (3, 1)
100 | ray_pts_transformed = torch.matmul(points, R.t()) + T.reshape(1, 3)
101 |
102 | ray_pts_ndc = ray_pts_transformed @ target_intrinsics.t()
103 | ndc = ray_pts_ndc[:, :2] / ray_pts_ndc[:, -1:]
104 | # ray_pts_ndc[:, 0] = ray_pts_ndc[:, 0] / ray_pts_ndc[:, 2]
105 | # ray_pts_ndc[:, 1] = ray_pts_ndc[:, 1] / ray_pts_ndc[:, 2]
106 | mask = (ndc[:, 0] >= 0) & (ndc[:, 0] <= W-1) & (ndc[:, 1] >= 0) & (ndc[:, 1] <= H-1)
107 |     # for the scannet dataset this is not necessary, since the projected depth is larger than 2
108 | # mask = mask & (ray_pts_transformed[:, 2] > 2)
109 | points_2d = ndc[mask, 0:2]
110 |
111 | lidarOnImage = torch.cat((points_2d, ray_pts_transformed[mask,2].reshape(-1,1)), 1)
112 | if torch.isnan(lidarOnImage).any():
113 | print("lidarOnImage has nan")
114 | depth_map = dense_map(lidarOnImage.t(), W, H, grid_size)
115 | # del ray_pts_ndc, target_intrinsics, ray_pts_transformed
116 | # del ys, xs, ray_pts_transformed, points, points_2d, ray_pts_ndc, mask, lidarOnImage, R, T
117 | # depth_map = torch.ones_like(source_depths[0]).to("cuda")
118 |
119 | return depth_map
120 |
--------------------------------------------------------------------------------
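A tiny synthetic example of `dense_map` (not repository code): scattered points with sub-pixel coordinates are splatted into an H×W depth map; the sizes are arbitrary.

```python
import torch

from utils.depth_map import dense_map

H, W, grid = 48, 64, 2
pts = torch.rand(3, 500)
pts[0] = pts[0] * (W - 1)          # x pixel coordinates in [0, W-1)
pts[1] = pts[1] * (H - 1)          # y pixel coordinates in [0, H-1)
pts[2] = pts[2] * 5.0 + 1.0        # depths in [1, 6)

depth = dense_map(pts, W, H, grid)  # densified (H, W) depth map
print(depth.shape)
```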
/utils/options.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 |
3 | def config_parser():
4 | parser = configargparse.ArgumentParser()
5 | parser.add_argument("--config", is_config_file=True, help="Config file path")
6 |
7 | # Task options
8 | parser.add_argument("--segmentation", action="store_true", help="Use segmentation mask for training")
9 | parser.add_argument("--nb_class", type=int, default=21, help="Number of classes for segmentation")
10 | parser.add_argument("--ignore_label", type=int, default=20, help="Ignore label for segmentation")
11 |
12 | # Datasets options
13 | parser.add_argument("--dataset_name", type=str, default="llff", choices=["llff", "nerf", "dtu", "klevr", "scannet", "replica"],)
14 | parser.add_argument("--llff_path", type=str, help="Path to llff dataset")
15 |     parser.add_argument("--llff_test_path", type=str, help="Path to llff test dataset")
16 | parser.add_argument("--dtu_path", type=str, help="Path to dtu dataset")
17 | parser.add_argument("--dtu_pre_path", type=str, help="Path to preprocessed dtu dataset")
18 | parser.add_argument("--nerf_path", type=str, help="Path to nerf dataset")
19 | parser.add_argument("--ams_path", type=str, help="Path to ams dataset")
20 | parser.add_argument("--ibrnet1_path", type=str, help="Path to ibrnet1 dataset")
21 | parser.add_argument("--ibrnet2_path", type=str, help="Path to ibrnet2 dataset")
22 | parser.add_argument("--klevr_path", type=str, help="Path to klevr dataset")
23 | parser.add_argument("--scannet_path", type=str, help="Path to scannet dataset")
24 | parser.add_argument("--replica_path", type=str, help="Path to replica dataset")
25 |
26 | # for scannet dataset
27 | parser.add_argument("--val_set_list", type=str, help="Path to scannet val dataset list")
28 |
29 | # Training options
30 | parser.add_argument("--batch_size", type=int, default=512)
31 | parser.add_argument("--num_steps", type=int, default=200000)
32 | parser.add_argument("--nb_views", type=int, default=3)
33 | parser.add_argument("--lrate", type=float, default=5e-4, help="Learning rate")
34 | parser.add_argument("--warmup_steps", type=int, default=500, help="Gradually warm-up learning rate in optimizer")
35 | parser.add_argument("--scene", type=str, default="None", help="Scene for fine-tuning")
36 | parser.add_argument("--cross_entropy_weight", type=float, default=0.1, help="Weight for cross entropy loss")
37 | parser.add_argument("--optimizer", type=str, default="adam", help="select optimizer: adam / sgd")
38 | parser.add_argument("--background_weight", type=float, default=1, help="Weight for background class in cross entropy loss")
39 | parser.add_argument("--two_stage_training_steps", type=int, default=60000, help="Use two stage training, indicating how many steps for first stage")
40 | parser.add_argument("--self_supervised_depth_loss", action="store_true", help="Use self supervised depth loss")
41 |
42 | # Rendering options
43 | parser.add_argument("--chunk", type=int, default=4096, help="Number of rays rendered in parallel")
44 | parser.add_argument("--nb_coarse", type=int, default=96, help="Number of coarse samples per ray")
45 | parser.add_argument("--nb_fine", type=int, default=32, help="Number of additional fine samples per ray",)
46 |
47 | # Other options
48 | parser.add_argument("--expname", type=str, help="Experiment name")
49 | parser.add_argument("--logger", type=str, default="wandb", choices=["wandb", "tensorboard", "none"])
50 | parser.add_argument("--logdir", type=str, default="./logs/", help="Where to store ckpts and logs")
51 | parser.add_argument("--eval", action="store_true", help="Render and evaluate the test set")
52 | parser.add_argument("--use_depth", action="store_true", help="Use ground truth low-res depth maps in rendering process")
53 | parser.add_argument("--seed", type=int, default=123, help="Random seed")
54 |     parser.add_argument("--val_save_img_type", default=["target"], action="append", help="Which validation images to save (choices: target, depth, source): target comparison images, depth maps, and/or source images")
55 | parser.add_argument("--target_depth_estimation", action="store_true", help="Use target depth estimation in rendering process")
56 | parser.add_argument("--use_depth_refine_net", action="store_true", help="Use depth refine net before rendering process")
57 |     parser.add_argument("--using_semantic_global_tokens", type=int, default=0, help="Use semantic global tokens in rendering process. 0: not use, 1: use")
58 | parser.add_argument("--only_using_semantic_global_tokens", type=int, default=0, help="Use only semantic global tokens in rendering process. 0: not use, 1: use")
59 | parser.add_argument("--use_batch_semantic_feature", action="store_true", help="Use batch semantic feature in rendering process")
60 | parser.add_argument("--ddp", action="store_true", help="Use distributed data parallel")
61 | parser.add_argument("--feat_net", type=str, default="UNet", choices=["UNet", "smp_UNet"], help="FeatureNet used in depth estimation")
62 | # resume options
63 | parser.add_argument("--ckpt_path", type=str, default=None, help="Path to a checkpoint to resume training")
64 | parser.add_argument("--finetune", action="store_true", help="Finetune the model with a checkpoint")
65 | parser.add_argument("--fintune_scene", type=str, default="None", help="Scene for fine-tuning")
66 | return parser.parse_args()
67 |
--------------------------------------------------------------------------------
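Because the parser is built with ConfigArgParse, values from the `--config` file and explicit command-line flags can be mixed, with command-line flags taking precedence. A hypothetical invocation from Python (the same thing the README does from the shell; the flag values shown are only an example):

```python
import sys

from utils.options import config_parser

# simulate the command line used in the README, overriding chunk on the fly
sys.argv = ["train.py", "--config", "configs/scannet.txt",
            "--segmentation", "--target_depth_estimation", "--chunk", "2048"]
args = config_parser()
print(args.nb_views, args.nb_class, args.chunk)  # 8 and 21 from the config file, 2048 from the flag
```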
/README.md:
--------------------------------------------------------------------------------
1 | # GSNeRF: Enhancing 3D Scene Understanding with Generalizable Semantic Neural Radiance Fields
2 | > Zi-Ting Chou, Sheng-Yu Huang, I-Jieh Liu, Yu-Chiang Wang
3 | > [Project Page](https://timchou-ntu.github.io/gsnerf/) | [Paper](https://arxiv.org/abs/2403.03608)
4 |
5 | This repository contains the official PyTorch Lightning implementation of our paper, GSNeRF (CVPR 2024).
6 |
7 |
8 |
9 |
10 |
11 | ## Installation
12 |
13 | #### Tested on NVIDIA GeForce RTX 3090 GPUs with CUDA 11.7, PyTorch 2.0.1 and PyTorch Lightning 2.0.4
14 |
15 | To set up the environment and install PyTorch and the remaining dependencies, run:
16 |
17 | ```
18 | git clone --recursive https://github.com/TimChou-ntu/GSNeRF.git
19 | conda create -n gsnerf python=3.9
20 | conda activate gsnerf
21 | pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117
22 | cd GSNeRF
23 | pip install -r requirements.txt
24 | ```
25 |
26 | ## Evaluation and Training
27 | Following [Semantic NeRF](https://github.com/Harry-Zhi/semantic_nerf) and [Semantic-Ray](https://github.com/liuff19/Semantic-Ray), we conduct experiments on [ScanNet](#scannet-real-world-indoor-scene-dataset) and [Replica](#replica-synthetic-indoor-scene-dataset), respectively.
28 |
29 | Download the ScanNet dataset from [here](https://github.com/ScanNet/ScanNet) and set its path as `scannet_path` in the [scannet.txt](./configs/scannet.txt) file.
30 |
31 | Download the Replica dataset from [here](https://www.dropbox.com/sh/9yu1elddll00sdl/AAC-rSJdLX0C6HhKXGKMOIija?dl=0) and set its path as `replica_path` in the [replica.txt](configs/replica.txt) file. (Thanks to [Semantic NeRF](https://github.com/Harry-Zhi/semantic_nerf) for rendering the 2D images and semantic maps.)
32 |
33 | Organize the data in the following structure:
34 | ```
35 | ├── data
36 | │ ├── scannet
37 | │ │ ├── scene0000_00
38 | │ │ │ ├── color
39 | │ │ │ │ ├── 0.jpg
40 | │ │ │ │ ├── ...
41 | │ │ │ ├── depth
42 | │ │ │ │ ├── 0.png
43 | │ │ │ │ ├── ...
44 | │ │ │ ├── label-filt
45 | │ │ │ │ ├── 0.png
46 | │ │ │ │ ├── ...
47 | │ │ │ ├── pose
48 | │ │ │ │ ├── 0.txt
49 | │ │ │ │ ├── ...
50 | │ │ │ ├── intrinsic
51 | │ │ │ │ ├── extrinsic_color.txt
52 | │ │ │ │ ├── intrinsic_color.txt
53 | │ │ │ │ ├── ...
54 | │ │ │ ├── ...
55 | │ │ ├── ...
56 | │ │ ├── scannetv2-labels.combined.tsv
57 | | |
58 | │ ├── replica
59 | │ │ ├── office_0
60 | │ │ │ ├── Sequence_1
61 | │ │ │ │ ├── depth
62 | | │ │ │ │ ├── depth_0.png
63 | | │ │ │ │ ├── ...
64 | │ │ │ │ ├── rgb
65 | | │ │ │ │ ├── rgb_0.png
66 | | │ │ │ │ ├── ...
67 | │ │ │ │ ├── semantic_class
68 | | │ │ │ │ ├── semantic_class_0.png
69 | | │ │ │ │ ├── ...
70 | │ │ │ │ ├── traj_w_c.txt
71 | │ │ ├── ...
72 | │ │ ├── semantic_info
73 | ```
74 | ## ScanNet (real-world indoor scene) Dataset
75 |
76 | For training a generalizable model, set the number of source views to 8 (nb_views = 8) in the [scannet.txt](./configs/scannet.txt) file and run the following command:
77 |
78 | ```
79 | python train.py --config configs/scannet.txt --segmentation --logger wandb --target_depth_estimation
80 | ```
81 |
82 | For evaluation on a novel scene, run the following command (replace `[ckpt path]` with the path to your trained checkpoint):
83 |
84 | ```
85 | python train.py --config configs/scannet.txt --segmentation --logger none --target_depth_estimation --ckpt_path [ckpt path] --eval
86 | ```
87 |
88 | ## Replica (synthetic indoor scene) Dataset
89 |
90 | For training a generalizable model, set the number of source views to 8 (nb_views = 8) in the [replica.txt](./configs/replica.txt) file and run the following command:
91 |
92 | ```
93 | python train.py --config configs/replica.txt --segmentation --logger wandb --target_depth_estimation
94 | ```
95 |
96 | For evaluation on a novel scene, run the following command (replace `[ckpt path]` with the path to your trained checkpoint):
97 |
98 | ```
99 | python train.py --config configs/replica.txt --segmentation --logger none --target_depth_estimation --ckpt_path [ckpt path] --eval
100 | ```
101 |
102 | ### Self-supervised depth model
103 | Simply add `--self_supervised_depth_loss` to the end of the training command.
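For example, to train on ScanNet with the self-supervised depth loss, the training command above becomes:

```
python train.py --config configs/scannet.txt --segmentation --logger wandb --target_depth_estimation --self_supervised_depth_loss
```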
104 |
105 | ### Contact
106 | You can contact the author via email: A88551212@gmail.com
107 |
108 | ## Citing
109 | If you find our work useful, please consider citing:
110 | ```BibTeX
111 | @inproceedings{Chou2024gsnerf,
112 |   author = {Zi-Ting Chou* and Sheng-Yu Huang* and I-Jieh Liu and Yu-Chiang Frank Wang},
113 |   title = {GSNeRF: Generalizable Semantic Neural Radiance Fields with Enhanced 3D Scene Understanding},
114 |   booktitle = {CVPR},
115 | year = {2024},
116 | arxiv = {2403.03608},
117 | }
118 | ```
119 |
120 | ### Acknowledgement
121 |
122 | Some portions of the code were derived from [GeoNeRF](https://github.com/idiap/GeoNeRF).
123 |
124 | Additionally, the well-structured codebases of [nerf_pl](https://github.com/kwea123/nerf_pl), [nesf](https://nesf3d.github.io/), and [RC-MVSNet](https://github.com/Boese0601/RC-MVSNet) were extremely helpful during the experiments. Shout out to them for their contributions.
--------------------------------------------------------------------------------
/utils/klevr_utils.py:
--------------------------------------------------------------------------------
1 | from scipy.spatial import transform
2 | import numpy as np
3 | import torch.nn.functional as F
4 | import torch
5 | import json
6 | import os
7 |
8 | def blender_quat2rot(quaternion):
9 | """Convert quaternion to rotation matrix.
10 |     Equivalent to, but supports the batched case:
11 | ```python
12 | rot3x3 = mathutils.Quaternion(quaternion).to_matrix()
13 | ```
14 | Args:
15 | quaternion:
16 | Returns:
17 | rotation matrix
18 | """
19 |
20 |   # Note: Blender first casts to double precision for numerical accuracy,
21 |   # while we're using float32.
22 | q = np.sqrt(2) * quaternion
23 |
24 | q0 = q[..., 0]
25 | q1 = q[..., 1]
26 | q2 = q[..., 2]
27 | q3 = q[..., 3]
28 |
29 | qda = q0 * q1
30 | qdb = q0 * q2
31 | qdc = q0 * q3
32 | qaa = q1 * q1
33 | qab = q1 * q2
34 | qac = q1 * q3
35 | qbb = q2 * q2
36 | qbc = q2 * q3
37 | qcc = q3 * q3
38 |
39 |   # Note: indices are inverted, as Blender and numpy conventions do not
40 |   # match: (x, y) -> (y, x)
41 | rotation = np.empty((*quaternion.shape[:-1], 3, 3), dtype=np.float32)
42 | rotation[..., 0, 0] = 1.0 - qbb - qcc
43 | rotation[..., 1, 0] = qdc + qab
44 | rotation[..., 2, 0] = -qdb + qac
45 |
46 | rotation[..., 0, 1] = -qdc + qab
47 | rotation[..., 1, 1] = 1.0 - qaa - qcc
48 | rotation[..., 2, 1] = qda + qbc
49 |
50 | rotation[..., 0, 2] = qdb + qac
51 | rotation[..., 1, 2] = -qda + qbc
52 | rotation[..., 2, 2] = 1.0 - qaa - qbb
53 | return rotation
54 |
55 | def make_transform_matrix(positions,rotations,):
56 | """Create the 4x4 transformation matrix.
57 | Note: This function uses numpy.
58 | Args:
59 | positions: Translation applied after the rotation.
60 | Last column of the transformation matrix
61 | rotations: Rotation. Top-left 3x3 matrix of the transformation matrix.
62 | Returns:
63 | transformation_matrix:
64 | """
65 | # Create the 4x4 transformation matrix
66 | rot_pos = np.broadcast_to(np.eye(4), (*positions.shape[:-1], 4, 4)).copy()
67 | rot_pos[..., :3, :3] = rotations
68 |   # Note: Blender and numpy use different conventions for the translation
69 | rot_pos[..., :3, 3] = positions
70 | return rot_pos
71 |
72 | def from_position_and_quaternion(positions, quaternions, use_unreal_axes):
73 | if use_unreal_axes:
74 | rotations = transform.Rotation.from_quat(quaternions).as_matrix()
75 | else:
76 | # Rotation matrix that rotates from world to object coordinates.
77 |     # Warning: Rotations should be given in Blender conventions, as
78 |     # scipy.spatial.transform uses different conventions.
79 | rotations = blender_quat2rot(quaternions)
80 | px2world_transform = make_transform_matrix(positions=positions,rotations=rotations)
81 | return px2world_transform
82 |
83 | def scale_rays(all_rays_o, all_rays_d, scene_boundaries, img_wh):
84 |     """Rescale rays so the scene boundaries map to [-1, 1].
85 | rays_o: (len(image_paths)*h*w, 3)
86 | rays_d: (len(image_paths)*h*w, 3)
87 | scene_boundaries: np.array(2 ,3), [min, max]
88 | img_wh: (2)
89 | """
90 | # Rescale (x, y, z) from [min, max] -> [-1, 1]
91 | # all_rays_o = all_rays_o.reshape(-1, img_wh[0], img_wh[1], 3) # (len(image_paths), h, w, 3))
92 | # all_rays_d = all_rays_d.reshape(-1, img_wh[0], img_wh[1], 3)
93 | assert all_rays_o.shape[-1] == 3, "all_rays_o should be (chunk, 3)"
94 | assert all_rays_d.shape[-1] == 3, "all_rays_d should be (chunk, 3)"
95 | old_min = torch.from_numpy(scene_boundaries[0]).to(all_rays_o.dtype).to(all_rays_o.device)
96 | old_max = torch.from_numpy(scene_boundaries[1]).to(all_rays_o.dtype).to(all_rays_o.device)
97 | new_min = torch.tensor([-1,-1,-1]).to(all_rays_o.dtype).to(all_rays_o.device)
98 | new_max = torch.tensor([1,1,1]).to(all_rays_o.dtype).to(all_rays_o.device)
99 | # scale = max(scene_boundaries[1] - scene_boundaries[0])/2
100 | # all_rays_o = (all_rays_o - torch.mean(all_rays_o, dim=-1, keepdim=True)) / scale
101 | # This is from jax3d.interp, kind of weird but true
102 | all_rays_o = ((new_min - new_max) / (old_min - old_max))*all_rays_o + (old_min * new_max - new_min * old_max) / (old_min - old_max)
103 |
104 |     # The direction can be thought of as a ray from a point in space (the camera
105 | # The direction can be though of a ray from a point in space (the camera
106 | # origin) to another point in space (say the red light on the lego
107 | # bulldozer). When we scale the scene in a certain way, this direction
108 | # also needs to be scaled in the same way.
109 | all_rays_d = all_rays_d * 2 / (old_max - old_min)
110 | # (re)-normalize the rays
111 | all_rays_d = all_rays_d / torch.linalg.norm(all_rays_d, dim=-1, keepdims=True)
112 | return all_rays_o.reshape(-1, 3), all_rays_d.reshape(-1, 3)
113 |
114 |
115 | def calculate_near_and_far(rays_o, rays_d, bbox_min=[-1.,-1.,-1.], bbox_max=[1.,1.,1.]):
116 | '''
117 | rays_o, (len(self.split_ids)*h*w, 3)
118 | rays_d, (len(self.split_ids)*h*w, 3)
119 | bbox_min=[-1,-1,-1],
120 | bbox_max=[1,1,1]
121 | '''
122 | # map all shape to same (len(self.split_ids)*h*w, 3, 2)
123 | corners = torch.stack((torch.tensor(bbox_min),torch.tensor(bbox_max)), dim=-1).to(rays_o.dtype).to(rays_o.device)
124 | corners = corners.unsqueeze(0).repeat(rays_o.shape[0],1,1) # (len(self.split_ids)*h*w, 3, 2)
125 | corners -= torch.unsqueeze(rays_o, -1).repeat(1,1,2)
126 | intersections = (corners / (torch.unsqueeze(rays_d, -1).repeat(1,1,2)))
127 |
128 | min_intersections = torch.amax(torch.amin(intersections, dim=-1), dim=-1, keepdim=True)
129 | max_intersections = torch.amin(torch.amax(intersections, dim=-1), dim=-1, keepdim=True)
130 | epsilon = 1e-1*torch.ones_like(min_intersections)
131 | near = torch.maximum(epsilon, min_intersections)
132 | # tmp = near
133 | near = torch.where((near > max_intersections), epsilon, near)
134 | far = torch.where(near < max_intersections, max_intersections, near+epsilon)
135 |
136 | return near, far
--------------------------------------------------------------------------------
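A short sketch of the camera and ray helpers above, using an identity quaternion and a handful of random rays (shapes follow the docstrings; not repository code):

```python
import numpy as np
import torch

from utils.klevr_utils import calculate_near_and_far, from_position_and_quaternion

# identity rotation in Blender's (w, x, y, z) convention, two cameras at the origin
positions = np.zeros((2, 3), dtype=np.float32)
quaternions = np.array([[1, 0, 0, 0], [1, 0, 0, 0]], dtype=np.float32)
c2ws = from_position_and_quaternion(positions, quaternions, use_unreal_axes=False)  # (2, 4, 4)

rays_o = torch.zeros(4, 3)                                          # rays starting at the origin
rays_d = torch.nn.functional.normalize(torch.randn(4, 3), dim=-1)   # unit directions
near, far = calculate_near_and_far(rays_o, rays_d)                  # each (4, 1), clipped to the unit cube
```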
/model/UNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from inplace_abn import InPlaceABN
5 | import segmentation_models_pytorch as smp
6 |
7 | class DoubleConv(nn.Module):
8 | """(convolution => [BN] => ReLU) * 2"""
9 |
10 | def __init__(self, in_channels, out_channels, mid_channels=None):
11 | super().__init__()
12 | if not mid_channels:
13 | mid_channels = out_channels
14 | self.double_conv = nn.Sequential(
15 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
16 | InPlaceABN(mid_channels),
17 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
18 | InPlaceABN(out_channels),
19 | )
20 |
21 | def forward(self, x):
22 | return self.double_conv(x)
23 |
24 |
25 | class Down(nn.Module):
26 | """Downscaling with maxpool then double conv"""
27 |
28 | def __init__(self, in_channels, out_channels):
29 | super().__init__()
30 | self.maxpool_conv = nn.Sequential(
31 | # nn.MaxPool2d(2),
32 | nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=2, padding=2, bias=False),
33 | InPlaceABN(out_channels),
34 | DoubleConv(out_channels, out_channels)
35 | )
36 |
37 | def forward(self, x):
38 | return self.maxpool_conv(x)
39 |
40 |
41 | class Up(nn.Module):
42 | """Upscaling then double conv"""
43 |
44 | def __init__(self, in_channels, out_channels, bilinear=True):
45 | super().__init__()
46 | self.bilinear = bilinear
47 | # if bilinear, use the normal convolutions to reduce the number of channels
48 | if bilinear:
49 | self.up = F.interpolate
50 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
51 | else:
52 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
53 | self.conv = DoubleConv(in_channels, out_channels)
54 |
55 | def forward(self, x1, x2):
56 | if self.bilinear:
57 | x1 = self.up(x1, scale_factor=2, mode="bilinear", align_corners=True)
58 | else:
59 | x1 = self.up(x1)
60 |
61 | x = torch.cat([x2, x1], dim=1)
62 | return self.conv(x)
63 |
64 |
65 | class OutConv(nn.Module):
66 | def __init__(self, in_channels, out_channels):
67 | super(OutConv, self).__init__()
68 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
69 |
70 | def forward(self, x):
71 | return self.conv(x)
72 |
73 |
74 | class UNet(nn.Module):
75 | def __init__(self, n_channels, n_classes, bilinear=False):
76 | super(UNet, self).__init__()
77 | self.n_channels = n_channels
78 | self.n_classes = n_classes
79 | self.bilinear = bilinear
80 |
81 | self.inc = (DoubleConv(n_channels, 8))
82 | self.down1 = (Down(8, 16))
83 | self.down2 = (Down(16, 32))
84 | self.down3 = (Down(32, 64))
85 | # TODO: down four times might be too much
86 | self.down4 = (Down(64, 128 ))
87 | self.up1 = (Up(128, 64 , bilinear))
88 | self.up2 = (Up(64, 32 , bilinear))
89 | self.up3 = (Up(32, 16 , bilinear))
90 | self.up4 = (Up(16, 8, bilinear))
91 | self.outc1 = (OutConv(8, n_classes))
92 | self.outc2 = (OutConv(n_classes, n_classes))
93 |
94 |
95 |
96 | self.toplayer = nn.Conv2d(32, 32, 1)
97 | self.lat1 = nn.Conv2d(16, 32, 1)
98 | self.lat0 = nn.Conv2d(8, 32, 1)
99 |
100 | # to reduce channel size of the outputs from FPN
101 | self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
102 | self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)
103 |
104 | def _upsample_add(self, x, y):
105 | return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + y
106 |
107 | def forward(self, x):
108 | x1 = self.inc(x)
109 | x2 = self.down1(x1)
110 | x3 = self.down2(x2)
111 | x4 = self.down3(x3)
112 | x5 = self.down4(x4)
113 | x = self.up1(x5, x4)
114 | x = self.up2(x, x3)
115 | x = self.up3(x, x2)
116 | x = self.up4(x, x1)
117 | feature = self.outc1(x)
118 | logits = self.outc2(feature)
119 |
120 | # original FeatureNet used in depth estimation
121 | feat2 = self.toplayer(x3) # (B, 32, H//4, W//4)
122 | feat1 = self._upsample_add(feat2, self.lat1(x2)) # (B, 32, H//2, W//2)
123 | feat0 = self._upsample_add(feat1, self.lat0(x1)) # (B, 32, H, W)
124 |
125 | # reduce output channels
126 | feat1 = self.smooth1(feat1) # (B, 16, H//2, W//2)
127 | feat0 = self.smooth0(feat0) # (B, 8, H, W)
128 |
129 | output = {"level_0": feat0, "level_1": feat1, "level_2": feat2, 'logits': logits, 'feature': feature}
130 |
131 | return output
132 |
133 | class smp_UNet(nn.Module):
134 | def __init__(self, n_channels, n_classes, bilinear=False):
135 | super(smp_UNet, self).__init__()
136 | self.n_channels = n_channels
137 | self.n_classes = n_classes
138 |
139 | self.model = smp.Unet(
140 | encoder_name="timm-mobilenetv3_small_minimal_100", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
141 | in_channels=n_channels, # model input channels (1 for grayscale images, 3 for RGB, etc.)
142 | classes=self.n_classes, # model output channels (number of classes in your dataset)
143 | encoder_depth=4,
144 | decoder_channels=(128, 64, 64, 32),
145 | )
146 | del self.model.encoder.model.blocks[4:]
147 | self.toplayer = nn.Conv2d(16, 32, 1)
148 | self.lat1 = nn.Conv2d(16, 32, 1)
149 | self.lat0 = nn.Conv2d(3, 32, 1)
150 |
151 | # to reduce channel size of the outputs from FPN
152 | self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
153 | self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)
154 |
155 | self.projection = nn.Sequential(
156 | nn.Conv2d(self.n_classes, self.n_classes, 1),
157 | )
158 |
159 | def _upsample_add(self, x, y):
160 | return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + y
161 |
162 | def forward(self, x):
163 | feat = self.model.encoder(x)
164 |
165 | feature = self.model.decoder(*feat)
166 | feature = self.model.segmentation_head(feature)
167 |
168 | logits = self.projection(feature)
169 |
170 | x3 = feat[2]
171 | x2 = feat[1]
172 | x1 = feat[0]
173 | # original FeatureNet used in depth estimation
174 | feat2 = self.toplayer(x3) # (B, 32, H//4, W//4)
175 | feat1 = self._upsample_add(feat2, self.lat1(x2)) # (B, 32, H//2, W//2)
176 | feat0 = self._upsample_add(feat1, self.lat0(x1)) # (B, 32, H, W)
177 |
178 | # reduce output channels
179 | feat1 = self.smooth1(feat1) # (B, 16, H//2, W//2)
180 | feat0 = self.smooth0(feat0) # (B, 8, H, W)
181 |
182 | output = {"level_0": feat0, "level_1": feat1, "level_2": feat2, 'logits': logits, 'feature': feature}
183 |
184 | return output
--------------------------------------------------------------------------------
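A shape-check sketch for the `UNet` backbone above (not repository code). It assumes a CUDA device, since `inplace_abn` is built as a CUDA extension, and an input resolution divisible by 16 because of the four stride-2 downsampling stages:

```python
import torch

from model.UNet import UNet

net = UNet(n_channels=3, n_classes=21).cuda().eval()
x = torch.randn(1, 3, 64, 80, device="cuda")

with torch.no_grad():
    out = net(x)
for k, v in out.items():
    print(k, tuple(v.shape))
# level_0: (1, 8, 64, 80)   level_1: (1, 16, 32, 40)   level_2: (1, 32, 16, 20)
# feature / logits: (1, 21, 64, 80)
```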
/utils/rendering.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from utils.utils import normal_vect, interpolate_3D, interpolate_2D
5 |
6 |
7 | class Embedder:
8 | def __init__(self, **kwargs):
9 | self.kwargs = kwargs
10 | self.create_embedding_fn()
11 |
12 | def create_embedding_fn(self):
13 | embed_fns = []
14 |
15 | if self.kwargs["include_input"]:
16 | embed_fns.append(lambda x: x)
17 |
18 | max_freq = self.kwargs["max_freq_log2"]
19 | N_freqs = self.kwargs["num_freqs"]
20 |
21 | if self.kwargs["log_sampling"]:
22 | freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
23 | else:
24 | freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs)
25 | self.freq_bands = freq_bands.reshape(1, -1, 1)
26 |
27 | for freq in freq_bands:
28 | for p_fn in self.kwargs["periodic_fns"]:
29 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
30 |
31 | self.embed_fns = embed_fns
32 |
33 | def embed(self, inputs):
34 | repeat = inputs.dim() - 1
35 | self.freq_bands = self.freq_bands.to(inputs.device)
36 | inputs_scaled = (
37 | inputs.unsqueeze(-2) * self.freq_bands.view(*[1] * repeat, -1, 1)
38 | ).reshape(*inputs.shape[:-1], -1)
39 | inputs_scaled = torch.cat(
40 | (inputs, torch.sin(inputs_scaled), torch.cos(inputs_scaled)), dim=-1
41 | )
42 | return inputs_scaled
43 |
44 |
45 | def get_embedder(multires=4):
46 |
47 | embed_kwargs = {
48 | "include_input": True,
49 | "max_freq_log2": multires - 1,
50 | "num_freqs": multires,
51 | "log_sampling": True,
52 | "periodic_fns": [torch.sin, torch.cos],
53 | }
54 |
55 | embedder_obj = Embedder(**embed_kwargs)
56 | embed = lambda x, eo=embedder_obj: eo.embed(x)
57 | return embed
58 |
59 |
60 | def sigma2weights(sigma):
61 | alpha = 1.0 - torch.exp(-sigma)
62 | T = torch.cumprod(
63 | torch.cat(
64 | [torch.ones(alpha.shape[0], 1).to(alpha.device), 1.0 - alpha + 1e-10], -1
65 | ),
66 | -1,
67 | )[:, :-1]
68 | weights = alpha * T
69 |
70 | return weights
71 |
72 |
73 | def volume_rendering(rgb_sigma, pts_depth):
74 | rgb = rgb_sigma[..., :3]
75 | weights = sigma2weights(rgb_sigma[..., 3])
76 |
77 | rendered_rgb = torch.sum(weights[..., None] * rgb, -2)
78 | rendered_depth = torch.sum(weights * pts_depth, -1)
79 |
80 | return rendered_rgb, rendered_depth
81 |
82 |
83 | def get_angle_wrt_src_cams(c2ws, rays_pts, rays_dir_unit):
84 | nb_rays = rays_pts.shape[0]
85 | ## Unit vectors from source cameras to the points on the ray
86 | dirs = normal_vect(rays_pts.unsqueeze(2) - c2ws[:, :3, 3][None, None])
87 | ## Cosine of the angle between two directions
88 | angle_cos = torch.sum(
89 | dirs * rays_dir_unit.reshape(nb_rays, 1, 1, 3), dim=-1, keepdim=True
90 | )
91 |     # Convert cosine to sine and approximate the angle by it (angle << 1 => sin(angle) = angle)
92 | angle = (1 - (angle_cos**2)).abs().sqrt()
93 |
94 | return angle
95 |
96 |
97 | def interpolate_pts_feats(imgs, feats_fpn, semantic_feat, feats_vol, rays_pts_ndc, padding_mode='border', use_batch_semantic_features=False):
98 | nb_views = feats_fpn.shape[1]
99 | interpolated_feats = []
100 |
101 | for i in range(nb_views):
102 | ray_feats_0 = interpolate_3D(feats_vol[f"level_0"][:, i], rays_pts_ndc[f"level_0"][:, :, i], padding_mode=padding_mode)
103 | ray_feats_1 = interpolate_3D(
104 | feats_vol[f"level_1"][:, i], rays_pts_ndc[f"level_1"][:, :, i], padding_mode=padding_mode
105 | )
106 | ray_feats_2 = interpolate_3D(
107 | feats_vol[f"level_2"][:, i], rays_pts_ndc[f"level_2"][:, :, i], padding_mode=padding_mode
108 | )
109 | ray_feats_fpn, ray_colors, ray_semantic_feats_fpn, ray_masks = interpolate_2D(
110 | feats_fpn[:, i], imgs[:, i], semantic_feat[:, i], rays_pts_ndc[f"level_0"][:, :, i], padding_mode=padding_mode, use_batch_semantic_features=use_batch_semantic_features
111 | )
112 | # When using only one point per ray, all features except ray masks are 2D (N_rays, C), while ray masks are 3D (N_rays, C, 1), so we need to squeeze it
113 | if ray_colors.dim() == 2 and ray_masks.dim() == 3:
114 | ray_masks = ray_masks.squeeze(-1)
115 |
116 | interpolated_feats.append(
117 | torch.cat(
118 | [
119 | ray_feats_0, # (N_rays, N_samples, 8)
120 | ray_feats_1, # (N_rays, N_samples, 8)
121 | ray_feats_2, # (N_rays, N_samples, 8)
122 | ray_feats_fpn, # (N_rays, N_samples, 8)
123 | ray_colors, # (N_rays, N_samples, 3)
124 | ray_semantic_feats_fpn, # (N_rays, N_samples, nb_classes(21)) / if use_batch_semantic_features, (N_rays, N_samples, 9*nb_classes(21))
125 | ray_masks, # (N_rays, N_samples, 1)
126 | ],
127 | dim=-1,
128 | )
129 | )
130 | interpolated_feats = torch.stack(interpolated_feats, dim=2)
131 | if torch.isnan(interpolated_feats).any():
132 | print("interpolated_feats has nan values")
133 | return interpolated_feats
134 |
135 |
136 | def get_occ_masks(depth_map_norm, rays_pts_ndc, visibility_thr=0.2):
137 | nb_views = depth_map_norm["level_0"].shape[1]
138 | z_diff = []
139 | for i in range(nb_views):
140 | ## Interpolate depth maps corresponding to each sample point
141 | # [1 H W 3] (x,y,z)
142 | grid = rays_pts_ndc[f"level_0"][None, :, :, i, :2] * 2 - 1.0
143 | rays_depths = F.grid_sample(
144 | depth_map_norm["level_0"][:, i : i + 1],
145 | grid,
146 | align_corners=True,
147 | mode="bilinear",
148 | padding_mode="border",
149 | )[0, 0]
150 | z_diff.append(rays_pts_ndc["level_0"][:, :, i, 2] - rays_depths)
151 | z_diff = torch.stack(z_diff, dim=2)
152 |
153 | occ_masks = z_diff.unsqueeze(-1) < visibility_thr
154 |
155 | return occ_masks
156 |
157 |
158 | def render_rays(
159 | c2ws,
160 | rays_pts,
161 | rays_pts_ndc,
162 | pts_depth,
163 | rays_dir,
164 | feats_vol,
165 | feats_fpn,
166 | imgs,
167 | depth_map_norm,
168 | renderer_net,
169 | middle_pts_mask,
170 | semantic_feat=None,
171 | use_batch_semantic_features=False,
172 | ):
173 | ## The angles between the ray and source camera vectors
174 | rays_dir_unit = rays_dir / torch.norm(rays_dir, dim=-1, keepdim=True)
175 | angles = get_angle_wrt_src_cams(c2ws, rays_pts, rays_dir_unit)
176 |
177 | ## Positional encoding
178 | embedded_angles = get_embedder()(angles)
179 |
180 | ## Interpolate all features for sample points (N_rays, N_samples, source_view_num, 8+8+8+8+3+21+1)
181 | pts_feat = interpolate_pts_feats(imgs, feats_fpn, semantic_feat, feats_vol, rays_pts_ndc, use_batch_semantic_features=use_batch_semantic_features)
182 |
183 | ## Getting Occlusion Masks based on predicted depths (N_rays, N_samples, source_view_num, 1)
184 | occ_masks = get_occ_masks(depth_map_norm, rays_pts_ndc)
185 |
186 | ## rendering sigma and RGB values
187 | rgb_sigma, rendered_semantic = renderer_net(embedded_angles, pts_feat, occ_masks, middle_pts_mask)
188 |
189 | rendered_rgb, rendered_depth = volume_rendering(rgb_sigma, pts_depth)
190 |
191 | if torch.isnan(rendered_semantic).sum() > 0:
192 | print("NaN in rendered_semantic")
193 |
194 | return rendered_rgb, rendered_semantic, rendered_depth
195 |
--------------------------------------------------------------------------------
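A self-contained sketch of the positional encoding and the volume-rendering step above, using random tensors shaped like the comments in `interpolate_pts_feats` (N_rays, N_samples, ...); not repository code:

```python
import torch

from utils.rendering import get_embedder, volume_rendering

embed = get_embedder(multires=4)
angles = torch.rand(1024, 96, 8, 1)                    # (N_rays, N_samples, nb_views, 1)
print(embed(angles).shape)                             # (1024, 96, 8, 1 + 2 * 4)

rgb_sigma = torch.rand(1024, 96, 4)                    # per-sample RGB + sigma along each ray
pts_depth = torch.rand(1024, 96)                       # per-sample depths
rgb, depth = volume_rendering(rgb_sigma, pts_depth)    # (1024, 3) colors, (1024,) depths
```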
/data/get_datasets.py:
--------------------------------------------------------------------------------
1 | # GeoNeRF is a generalizable NeRF model that renders novel views
2 | # without requiring per-scene optimization. This software is the
3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with
4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
5 | # and Francois Fleuret.
6 |
7 | # Copyright (c) 2022 ams International AG
8 |
9 | # This file is part of GeoNeRF.
10 | # GeoNeRF is free software: you can redistribute it and/or modify
11 | # it under the terms of the GNU General Public License version 3 as
12 | # published by the Free Software Foundation.
13 |
14 | # GeoNeRF is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 |
19 | # You should have received a copy of the GNU General Public License
20 | # along with GeoNeRF. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | import torch
23 | from torch.utils.data import ConcatDataset, WeightedRandomSampler
24 | import numpy as np
25 |
26 | from data.llff import LLFF_Dataset
27 | from data.dtu import DTU_Dataset
28 | from data.nerf import NeRF_Dataset
29 | from data.klevr import KlevrDataset
30 | from data.scannet import RendererDataset
31 |
32 | def get_training_dataset(args, downsample=1.0):
33 |
34 | if args.dataset_name == "klevr":
35 | train_dataset = KlevrDataset(
36 | root_dir=args.klevr_path,
37 | split="train",
38 | max_len=-1,
39 | downSample=downsample,
40 | nb_views=args.nb_views,
41 | get_semantic=args.segmentation,
42 | )
43 | train_sampler = None
44 | return train_dataset, train_sampler
45 |
46 | elif args.dataset_name == "scannet":
47 | if args.finetune:
48 | cfg = {
49 | 'val_database_name': args.fintune_scene,
50 | 'min_wn': args.nb_views, 'max_wn': args.nb_views+1,
51 | 'type2sample_weights': {'scannet_single': 1},
52 | 'train_database_types': ['scannet_single'],
53 | }
54 | else:
55 | cfg = {'min_wn': args.nb_views, 'max_wn': args.nb_views+1}
56 | train_dataset = RendererDataset(
57 | root_dir=args.scannet_path,
58 | is_train=True,
59 | cfg=cfg,
60 | )
61 | train_sampler = None
62 | return train_dataset, train_sampler
63 | elif args.dataset_name == "replica":
64 | cfg = {"resolution_type": "hr", "type2sample_weights": {"replica": 1}, "train_database_types": ['replica'], 'min_wn': args.nb_views, 'max_wn': args.nb_views+1}
65 | train_dataset = RendererDataset(
66 | cfg=cfg,
67 | root_dir=args.replica_path,
68 | is_train=True
69 | )
70 | train_sampler = None
71 | return train_dataset, train_sampler
72 |
73 | train_datasets = [
74 | DTU_Dataset(
75 | original_root_dir=args.dtu_path,
76 | preprocessed_root_dir=args.dtu_pre_path,
77 | split="train",
78 | max_len=-1,
79 | downSample=downsample,
80 | nb_views=args.nb_views,
81 | ),
82 | LLFF_Dataset(
83 | root_dir=args.ibrnet1_path,
84 | split="train",
85 | max_len=-1,
86 | downSample=downsample,
87 | nb_views=args.nb_views,
88 | imgs_folder_name="images",
89 | ),
90 | LLFF_Dataset(
91 | root_dir=args.ibrnet2_path,
92 | split="train",
93 | max_len=-1,
94 | downSample=downsample,
95 | nb_views=args.nb_views,
96 | imgs_folder_name="images",
97 | ),
98 | LLFF_Dataset(
99 | root_dir=args.llff_path,
100 | split="train",
101 | max_len=-1,
102 | downSample=downsample,
103 | nb_views=args.nb_views,
104 | imgs_folder_name="images_4",
105 | ),
106 | ]
107 | weights = [0.5, 0.22, 0.12, 0.16]
108 |
109 | train_weights_samples = []
110 | for dataset, weight in zip(train_datasets, weights):
111 | num_samples = len(dataset)
112 | weight_each_sample = weight / num_samples
113 | train_weights_samples.extend([weight_each_sample] * num_samples)
114 |
115 | train_dataset = ConcatDataset(train_datasets)
116 | train_weights = torch.from_numpy(np.array(train_weights_samples))
117 | train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
118 |
119 | return train_dataset, train_sampler
120 |
121 |
122 | def get_finetuning_dataset(args, downsample=1.0):
123 | if args.dataset_name == "dtu":
124 | train_dataset = DTU_Dataset(
125 | original_root_dir=args.dtu_path,
126 | preprocessed_root_dir=args.dtu_pre_path,
127 | split="train",
128 | max_len=-1,
129 | downSample=downsample,
130 | nb_views=args.nb_views,
131 | scene=args.scene,
132 | )
133 | elif args.dataset_name == "llff":
134 | train_dataset = LLFF_Dataset(
135 | root_dir=args.llff_path,
136 | split="train",
137 | max_len=-1,
138 | downSample=downsample,
139 | nb_views=args.nb_views,
140 | scene=args.scene,
141 | imgs_folder_name="images_4",
142 | )
143 | elif args.dataset_name == "nerf":
144 | train_dataset = NeRF_Dataset(
145 | root_dir=args.nerf_path,
146 | split="train",
147 | max_len=-1,
148 | downSample=downsample,
149 | nb_views=args.nb_views,
150 | scene=args.scene,
151 | )
152 |
153 | train_sampler = None
154 |
155 | return train_dataset, train_sampler
156 |
157 |
158 | def get_validation_dataset(args, downsample=1.0):
159 | if args.scene == "None":
160 | max_len = 10
161 | else:
162 | max_len = -1
163 |
164 | if args.dataset_name == "dtu":
165 | val_dataset = DTU_Dataset(
166 | original_root_dir=args.dtu_path,
167 | preprocessed_root_dir=args.dtu_pre_path,
168 | split="val",
169 | max_len=max_len,
170 | downSample=downsample,
171 | nb_views=args.nb_views,
172 | scene=args.scene,
173 | )
174 | elif args.dataset_name == "llff":
175 | val_dataset = LLFF_Dataset(
176 |             root_dir=args.llff_test_path if args.llff_test_path is not None else args.llff_path,
177 | split="val",
178 | max_len=max_len,
179 | downSample=downsample,
180 | nb_views=args.nb_views,
181 | scene=args.scene,
182 | imgs_folder_name="images_4",
183 | )
184 | elif args.dataset_name == "nerf":
185 | val_dataset = NeRF_Dataset(
186 | root_dir=args.nerf_path,
187 | split="val",
188 | max_len=max_len,
189 | downSample=downsample,
190 | nb_views=args.nb_views,
191 | scene=args.scene,
192 | )
193 | elif args.dataset_name == "klevr":
194 | val_dataset = KlevrDataset(
195 | root_dir=args.klevr_path,
196 | split="val",
197 | max_len=max_len,
198 | downSample=downsample,
199 | nb_views=args.nb_views,
200 | get_semantic=args.segmentation,
201 | )
202 | elif args.dataset_name == "scannet":
203 | val_set_list, val_set_names = [], []
204 | if args.finetune:
205 | val_cfg = {'val_database_name': args.fintune_scene,'min_wn': args.nb_views, 'max_wn': args.nb_views+1}
206 | val_set = RendererDataset(cfg=val_cfg, is_train=False, root_dir=args.scannet_path)
207 | return val_set
208 | if isinstance(args.val_set_list, str):
209 | val_scenes = np.loadtxt(args.val_set_list, dtype=str).tolist()
210 | for name in val_scenes:
211 |                 val_cfg = {'val_database_name': name, 'min_wn': args.nb_views, 'max_wn': args.nb_views + 1}
212 | val_set = RendererDataset(cfg=val_cfg, is_train=False, root_dir=args.scannet_path)
213 | val_set_list.append(val_set)
214 | val_set_names.append(name)
215 | print(f'{name} val set len {len(val_set)}')
216 | # print("only one scene")
217 | # break
218 | val_dataset = ConcatDataset(val_set_list)
219 | elif args.dataset_name == "replica":
220 | val_set_list, val_set_names = [], []
221 | if isinstance(args.val_set_list, str):
222 | val_scenes = np.loadtxt(args.val_set_list, dtype=str).tolist()
223 | for name in val_scenes:
224 |                 val_cfg = {'val_database_name': name, 'min_wn': args.nb_views, 'max_wn': args.nb_views + 1}
225 | val_set = RendererDataset(cfg=val_cfg, is_train=False, root_dir=args.replica_path)
226 | val_set_list.append(val_set)
227 | val_set_names.append(name)
228 | print(f'{name} val set len {len(val_set)}')
229 | val_dataset = ConcatDataset(val_set_list)
230 | else:
231 | raise NotImplementedError
232 |
233 | return val_dataset
234 |
--------------------------------------------------------------------------------
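A note on the weighted training sampler built above (just before `get_finetuning_dataset`): dividing each dataset's weight by its length makes every dataset contribute its listed fraction of the draws, regardless of how many samples it holds. The snippet below is a minimal, self-contained sketch of that idea using toy tensor datasets with made-up sizes and weights, not the repo's loaders:

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset, WeightedRandomSampler

datasets = [TensorDataset(torch.zeros(n, 1)) for n in (100, 400)]  # toy datasets of different sizes
weights = [0.5, 0.5]                                               # target sampling fractions

per_sample = []
for ds, w in zip(datasets, weights):
    per_sample.extend([w / len(ds)] * len(ds))  # sample weight = dataset weight / dataset size

train_dataset = ConcatDataset(datasets)
sampler = WeightedRandomSampler(torch.tensor(per_sample), num_samples=len(per_sample))

# roughly half of the drawn indices land in the first (4x smaller) dataset
idxs = list(sampler)
print(sum(int(i) < 100 for i in idxs) / len(idxs))
```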
/data/nerf.py:
--------------------------------------------------------------------------------
1 | # GeoNeRF is a generalizable NeRF model that renders novel views
2 | # without requiring per-scene optimization. This software is the
3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with
4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
5 | # and Francois Fleuret.
6 |
7 | # Copyright (c) 2022 ams International AG
8 |
9 | # This file is part of GeoNeRF.
10 | # GeoNeRF is free software: you can redistribute it and/or modify
11 | # it under the terms of the GNU General Public License version 3 as
12 | # published by the Free Software Foundation.
13 |
14 | # GeoNeRF is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 |
19 | # You should have received a copy of the GNU General Public License
20 | # along with GeoNeRF. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | # This file incorporates work covered by the following copyright and
23 | # permission notice:
24 |
25 | # MIT License
26 |
27 | # Copyright (c) 2021 apchenstu
28 |
29 | # Permission is hereby granted, free of charge, to any person obtaining a copy
30 | # of this software and associated documentation files (the "Software"), to deal
31 | # in the Software without restriction, including without limitation the rights
32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
33 | # copies of the Software, and to permit persons to whom the Software is
34 | # furnished to do so, subject to the following conditions:
35 |
36 | # The above copyright notice and this permission notice shall be included in all
37 | # copies or substantial portions of the Software.
38 |
39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
45 | # SOFTWARE.
46 |
47 | from torch.utils.data import Dataset
48 | from torchvision import transforms as T
49 |
50 | import os
51 | import json
52 | import numpy as np
53 | from PIL import Image
54 |
55 | from utils.utils import get_nearest_pose_ids
56 |
57 | class NeRF_Dataset(Dataset):
58 | def __init__(
59 | self,
60 | root_dir,
61 | split,
62 | nb_views,
63 | downSample=1.0,
64 | max_len=-1,
65 | scene="None",
66 | ):
67 | self.root_dir = root_dir
68 | self.split = split
69 | self.nb_views = nb_views
70 | self.scene = scene
71 |
72 | self.downsample = downSample
73 | self.max_len = max_len
74 |
75 | self.img_wh = (int(800 * self.downsample), int(800 * self.downsample))
76 |
77 | self.define_transforms()
78 | self.blender2opencv = np.array(
79 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
80 | )
81 |
82 | self.build_metas()
83 |
84 | def define_transforms(self):
85 | self.transform = T.ToTensor()
86 |
87 | self.src_transform = T.Compose(
88 | [
89 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
90 | ]
91 | )
92 |
93 | def build_metas(self):
94 | self.meta = {}
95 | with open(
96 | os.path.join(self.root_dir, self.scene, "transforms_train.json"), "r"
97 | ) as f:
98 | self.meta["train"] = json.load(f)
99 |
100 | with open(
101 | os.path.join(self.root_dir, self.scene, "transforms_test.json"), "r"
102 | ) as f:
103 | self.meta["val"] = json.load(f)
104 |
105 | w, h = self.img_wh
106 |
107 | # original focal length
108 | focal = 0.5 * 800 / np.tan(0.5 * self.meta["train"]["camera_angle_x"])
109 |
110 | # modify focal length to match size self.img_wh
111 | focal *= self.img_wh[0] / 800
112 |
113 | self.near_far = np.array([2.0, 6.0])
114 |
115 | self.image_paths = {"train": [], "val": []}
116 | self.c2ws = {"train": [], "val": []}
117 | self.w2cs = {"train": [], "val": []}
118 | self.intrinsics = {"train": [], "val": []}
119 |
120 | for frame in self.meta["train"]["frames"]:
121 | self.image_paths["train"].append(
122 | os.path.join(self.root_dir, self.scene, f"{frame['file_path']}.png")
123 | )
124 |
125 | c2w = np.array(frame["transform_matrix"]) @ self.blender2opencv
126 | w2c = np.linalg.inv(c2w)
127 | self.c2ws["train"].append(c2w)
128 | self.w2cs["train"].append(w2c)
129 |
130 | intrinsic = np.array([[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]])
131 | self.intrinsics["train"].append(intrinsic.copy())
132 |
133 | self.c2ws["train"] = np.stack(self.c2ws["train"], axis=0)
134 | self.w2cs["train"] = np.stack(self.w2cs["train"], axis=0)
135 | self.intrinsics["train"] = np.stack(self.intrinsics["train"], axis=0)
136 |
137 | for frame in self.meta["val"]["frames"]:
138 | self.image_paths["val"].append(
139 | os.path.join(self.root_dir, self.scene, f"{frame['file_path']}.png")
140 | )
141 |
142 | c2w = np.array(frame["transform_matrix"]) @ self.blender2opencv
143 | w2c = np.linalg.inv(c2w)
144 | self.c2ws["val"].append(c2w)
145 | self.w2cs["val"].append(w2c)
146 |
147 | intrinsic = np.array([[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]])
148 | self.intrinsics["val"].append(intrinsic.copy())
149 |
150 | self.c2ws["val"] = np.stack(self.c2ws["val"], axis=0)
151 | self.w2cs["val"] = np.stack(self.w2cs["val"], axis=0)
152 | self.intrinsics["val"] = np.stack(self.intrinsics["val"], axis=0)
153 |
154 | def __len__(self):
155 | return len(self.image_paths[self.split]) if self.max_len <= 0 else self.max_len
156 |
157 | def __getitem__(self, idx):
158 | target_frame = self.meta[self.split]["frames"][idx]
159 | c2w = np.array(target_frame["transform_matrix"]) @ self.blender2opencv
160 | w2c = np.linalg.inv(c2w)
161 |
162 | if self.split == "train":
163 | src_views = get_nearest_pose_ids(
164 | c2w,
165 | ref_poses=self.c2ws["train"],
166 | num_select=self.nb_views + 1,
167 | angular_dist_method="dist",
168 | )[1:]
169 | else:
170 | src_views = get_nearest_pose_ids(
171 | c2w,
172 | ref_poses=self.c2ws["train"],
173 | num_select=self.nb_views,
174 | angular_dist_method="dist",
175 | )
176 |
177 | imgs, depths, depths_h, depths_aug = [], [], [], []
178 | intrinsics, w2cs, c2ws, near_fars = [], [], [], []
179 | affine_mats, affine_mats_inv = [], []
180 |
181 | w, h = self.img_wh
182 |
183 | for vid in src_views:
184 | img_filename = self.image_paths["train"][vid]
185 | img = Image.open(img_filename)
186 | if img.size != (w, h):
187 | img = img.resize((w, h), Image.BICUBIC)
188 |
189 | img = self.transform(img)
190 |             img = img[:3] * img[-1:] + (1 - img[-1:])  # composite RGBA onto a white background
191 | imgs.append(self.src_transform(img))
192 |
193 | intrinsic = self.intrinsics["train"][vid]
194 | intrinsics.append(intrinsic)
195 |
196 | w2c = self.w2cs["train"][vid]
197 | w2cs.append(w2c)
198 | c2ws.append(self.c2ws["train"][vid])
199 |
200 | aff = []
201 | aff_inv = []
202 | for l in range(3):
203 | proj_mat_l = np.eye(4)
204 | intrinsic_temp = intrinsic.copy()
205 | intrinsic_temp[:2] = intrinsic_temp[:2] / (2**l)
206 | proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4]
207 | aff.append(proj_mat_l.copy())
208 | aff_inv.append(np.linalg.inv(proj_mat_l))
209 | aff = np.stack(aff, axis=-1)
210 | aff_inv = np.stack(aff_inv, axis=-1)
211 |
212 | affine_mats.append(aff)
213 | affine_mats_inv.append(aff_inv)
214 |
215 | near_fars.append(self.near_far)
216 |
217 | depths_h.append(np.zeros([h, w]))
218 | depths.append(np.zeros([h // 4, w // 4]))
219 | depths_aug.append(np.zeros([h // 4, w // 4]))
220 |
221 | ## Adding target data
222 | img_filename = self.image_paths[self.split][idx]
223 | img = Image.open(img_filename)
224 | if img.size != (w, h):
225 | img = img.resize((w, h), Image.BICUBIC)
226 |
227 | img = self.transform(img) # (4, h, w)
228 |         img = img[:3] * img[-1:] + (1 - img[-1:])  # composite RGBA onto a white background
229 | imgs.append(self.src_transform(img))
230 |
231 | intrinsic = self.intrinsics[self.split][idx]
232 | intrinsics.append(intrinsic)
233 |
234 | w2c = self.w2cs[self.split][idx]
235 | w2cs.append(w2c)
236 | c2ws.append(self.c2ws[self.split][idx])
237 |
238 | near_fars.append(self.near_far)
239 |
240 | depths_h.append(np.zeros([h, w]))
241 | depths.append(np.zeros([h // 4, w // 4]))
242 | depths_aug.append(np.zeros([h // 4, w // 4]))
243 |
244 | ## Stacking
245 | imgs = np.stack(imgs)
246 | depths = np.stack(depths)
247 | depths_h = np.stack(depths_h)
248 | depths_aug = np.stack(depths_aug)
249 | affine_mats = np.stack(affine_mats)
250 | affine_mats_inv = np.stack(affine_mats_inv)
251 | intrinsics = np.stack(intrinsics)
252 | w2cs = np.stack(w2cs)
253 | c2ws = np.stack(c2ws)
254 | near_fars = np.stack(near_fars)
255 |
256 | closest_idxs = []
257 | for pose in c2ws[:-1]:
258 | closest_idxs.append(
259 | get_nearest_pose_ids(
260 | pose, ref_poses=c2ws[:-1], num_select=5, angular_dist_method="dist"
261 | )
262 | )
263 | closest_idxs = np.stack(closest_idxs, axis=0)
264 |
265 | sample = {}
266 | sample["images"] = imgs
267 | sample["depths"] = depths
268 | sample["depths_h"] = depths_h
269 | sample["depths_aug"] = depths_aug
270 | sample["w2cs"] = w2cs.astype("float32")
271 | sample["c2ws"] = c2ws.astype("float32")
272 | sample["near_fars"] = near_fars
273 | sample["affine_mats"] = affine_mats
274 | sample["affine_mats_inv"] = affine_mats_inv
275 | sample["intrinsics"] = intrinsics.astype("float32")
276 | sample["closest_idxs"] = closest_idxs
277 |
278 | return sample
279 |
--------------------------------------------------------------------------------
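Two details in `NeRF_Dataset` above are worth spelling out: the focal length is recovered from Blender's `camera_angle_x` and rescaled to the working resolution, and RGBA frames are composited onto a white background before normalization. A small standalone sketch of both steps (the `camera_angle_x` value and the random image are placeholders):

```python
import numpy as np
import torch

# focal length from Blender's horizontal field of view, rescaled to the working resolution
camera_angle_x = 0.6911112070083618  # placeholder value in the style of transforms_train.json
width = 800
focal = 0.5 * width / np.tan(0.5 * camera_angle_x)
focal *= 400 / width  # e.g. downsample = 0.5 gives img_wh[0] = 400

# compositing an RGBA tensor (4, H, W) in [0, 1] onto a white background, as in __getitem__
rgba = torch.rand(4, 8, 8)
rgb = rgba[:3] * rgba[-1:] + (1 - rgba[-1:])

print(round(float(focal), 2), rgb.shape)  # rescaled focal, torch.Size([3, 8, 8])
```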
/data/llff.py:
--------------------------------------------------------------------------------
1 | # GeoNeRF is a generalizable NeRF model that renders novel views
2 | # without requiring per-scene optimization. This software is the
3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with
4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
5 | # and Francois Fleuret.
6 |
7 | # Copyright (c) 2022 ams International AG
8 |
9 | # This file is part of GeoNeRF.
10 | # GeoNeRF is free software: you can redistribute it and/or modify
11 | # it under the terms of the GNU General Public License version 3 as
12 | # published by the Free Software Foundation.
13 |
14 | # GeoNeRF is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 |
19 | # You should have received a copy of the GNU General Public License
20 | # along with GeoNeRF. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | # This file incorporates work covered by the following copyright and
23 | # permission notice:
24 |
25 | # MIT License
26 |
27 | # Copyright (c) 2021 apchenstu
28 |
29 | # Permission is hereby granted, free of charge, to any person obtaining a copy
30 | # of this software and associated documentation files (the "Software"), to deal
31 | # in the Software without restriction, including without limitation the rights
32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
33 | # copies of the Software, and to permit persons to whom the Software is
34 | # furnished to do so, subject to the following conditions:
35 |
36 | # The above copyright notice and this permission notice shall be included in all
37 | # copies or substantial portions of the Software.
38 |
39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
45 | # SOFTWARE.
46 |
47 | from torch.utils.data import Dataset
48 | from torchvision import transforms as T
49 |
50 | import os
51 | import glob
52 | import numpy as np
53 | from PIL import Image
54 |
55 | from utils.utils import get_nearest_pose_ids
56 |
57 | def normalize(v):
58 | return v / np.linalg.norm(v)
59 |
60 |
61 | def average_poses(poses):
62 | # 1. Compute the center
63 | center = poses[..., 3].mean(0) # (3)
64 |
65 | # 2. Compute the z axis
66 | z = normalize(poses[..., 2].mean(0)) # (3)
67 |
68 | # 3. Compute axis y' (no need to normalize as it's not the final output)
69 | y_ = poses[..., 1].mean(0) # (3)
70 |
71 | # 4. Compute the x axis
72 | x = normalize(np.cross(y_, z)) # (3)
73 |
74 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
75 | y = np.cross(z, x) # (3)
76 |
77 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4)
78 |
79 | return pose_avg
80 |
81 |
82 | def center_poses(poses, blender2opencv):
83 | pose_avg = average_poses(poses) # (3, 4)
84 | pose_avg_homo = np.eye(4)
85 |
86 | # convert to homogeneous coordinate for faster computation
87 | # by simply adding 0, 0, 0, 1 as the last row
88 | pose_avg_homo[:3] = pose_avg
89 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4)
90 |
91 | # (N_images, 4, 4) homogeneous coordinate
92 | poses_homo = np.concatenate([poses, last_row], 1)
93 |
94 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4)
95 | poses_centered = poses_centered @ blender2opencv
96 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4)
97 |
98 | return poses_centered, np.linalg.inv(pose_avg_homo) @ blender2opencv
99 |
100 |
101 | class LLFF_Dataset(Dataset):
102 | def __init__(
103 | self,
104 | root_dir,
105 | split,
106 | nb_views,
107 | downSample=1.0,
108 | max_len=-1,
109 | scene="None",
110 | imgs_folder_name="images",
111 | ):
112 | self.root_dir = root_dir
113 | self.split = split
114 | self.nb_views = nb_views
115 | self.scene = scene
116 | self.imgs_folder_name = imgs_folder_name
117 |
118 | self.downsample = downSample
119 | self.max_len = max_len
120 | self.img_wh = (int(960 * self.downsample), int(720 * self.downsample))
121 |
122 | self.define_transforms()
123 | self.blender2opencv = np.array(
124 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
125 | )
126 |
127 | self.build_metas()
128 |
129 | def define_transforms(self):
130 | self.transform = T.Compose(
131 | [
132 | T.ToTensor(),
133 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
134 | ]
135 | )
136 |
137 | def build_metas(self):
138 | if self.scene != "None":
139 | self.scans = [
140 | os.path.basename(scan_dir)
141 | for scan_dir in sorted(
142 | glob.glob(os.path.join(self.root_dir, self.scene))
143 | )
144 | ]
145 | else:
146 | self.scans = [
147 | os.path.basename(scan_dir)
148 | for scan_dir in sorted(glob.glob(os.path.join(self.root_dir, "*")))
149 | ]
150 |
151 | self.meta = []
152 | self.image_paths = {}
153 | self.near_far = {}
154 | self.id_list = {}
155 | self.closest_idxs = {}
156 | self.c2ws = {}
157 | self.w2cs = {}
158 | self.intrinsics = {}
159 | self.affine_mats = {}
160 | self.affine_mats_inv = {}
161 | for scan in self.scans:
162 | self.image_paths[scan] = sorted(
163 | glob.glob(os.path.join(self.root_dir, scan, self.imgs_folder_name, "*"))
164 | )
165 | poses_bounds = np.load(
166 | os.path.join(self.root_dir, scan, "poses_bounds.npy")
167 | ) # (N_images, 17)
168 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5)
169 | bounds = poses_bounds[:, -2:] # (N_images, 2)
170 |
171 | # Step 1: rescale focal length according to training resolution
172 | H, W, focal = poses[0, :, -1] # original intrinsics, same for all images
173 |
174 | focal = [focal * self.img_wh[0] / W, focal * self.img_wh[1] / H]
175 |
176 | # Step 2: correct poses
177 | poses = np.concatenate(
178 | [poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1
179 | )
180 | poses, _ = center_poses(poses, self.blender2opencv)
181 | # poses = poses @ self.blender2opencv
182 |
183 |             # Step 3: rescale so that the nearest depth is slightly greater than 1.0
184 | near_original = bounds.min()
185 | scale_factor = near_original * 0.75 # 0.75 is the default parameter
186 | bounds /= scale_factor
187 | poses[..., 3] /= scale_factor
188 |
189 | self.near_far[scan] = bounds.astype('float32')
190 |
191 | num_viewpoint = len(self.image_paths[scan])
192 | val_ids = [idx for idx in range(0, num_viewpoint, 8)]
193 | w, h = self.img_wh
194 |
195 | self.id_list[scan] = []
196 | self.closest_idxs[scan] = []
197 | self.c2ws[scan] = []
198 | self.w2cs[scan] = []
199 | self.intrinsics[scan] = []
200 | self.affine_mats[scan] = []
201 | self.affine_mats_inv[scan] = []
202 | for idx in range(num_viewpoint):
203 | if (
204 | (self.split == "val" and idx in val_ids)
205 | or (
206 | self.split == "train"
207 | and self.scene != "None"
208 | and idx not in val_ids
209 | )
210 | or (self.split == "train" and self.scene == "None")
211 | ):
212 | self.meta.append({"scan": scan, "target_idx": idx})
213 |
214 | view_ids = get_nearest_pose_ids(
215 | poses[idx, :, :],
216 | ref_poses=poses[..., :],
217 | num_select=self.nb_views + 1,
218 | angular_dist_method="dist",
219 | )
220 |
221 | self.id_list[scan].append(view_ids)
222 |
223 | closest_idxs = []
224 | source_views = view_ids[1:]
225 | for vid in source_views:
226 | closest_idxs.append(
227 | get_nearest_pose_ids(
228 | poses[vid, :, :],
229 | ref_poses=poses[source_views],
230 | num_select=5,
231 | angular_dist_method="dist",
232 | )
233 | )
234 | self.closest_idxs[scan].append(np.stack(closest_idxs, axis=0))
235 |
236 | c2w = np.eye(4).astype('float32')
237 | c2w[:3] = poses[idx]
238 | w2c = np.linalg.inv(c2w)
239 | self.c2ws[scan].append(c2w)
240 | self.w2cs[scan].append(w2c)
241 |
242 | intrinsic = np.array([[focal[0], 0, w / 2], [0, focal[1], h / 2], [0, 0, 1]]).astype('float32')
243 | self.intrinsics[scan].append(intrinsic)
244 |
245 | def __len__(self):
246 | return len(self.meta) if self.max_len <= 0 else self.max_len
247 |
248 | def __getitem__(self, idx):
249 | if self.split == "train" and self.scene == "None":
250 | noisy_factor = float(np.random.choice([1.0, 0.75, 0.5], 1))
251 | close_views = int(np.random.choice([3, 4, 5], 1))
252 | else:
253 | noisy_factor = 1.0
254 | close_views = 5
255 |
256 | scan = self.meta[idx]["scan"]
257 | target_idx = self.meta[idx]["target_idx"]
258 |
259 | view_ids = self.id_list[scan][target_idx]
260 | target_view = view_ids[0]
261 | src_views = view_ids[1:]
262 | view_ids = [vid for vid in src_views] + [target_view]
263 |
264 | closest_idxs = self.closest_idxs[scan][target_idx][:, :close_views]
265 |
266 | imgs, depths, depths_h, depths_aug = [], [], [], []
267 | intrinsics, w2cs, c2ws, near_fars = [], [], [], []
268 | affine_mats, affine_mats_inv = [], []
269 |
270 | w, h = self.img_wh
271 | w, h = int(w * noisy_factor), int(h * noisy_factor)
272 |
273 | for vid in view_ids:
274 | img_filename = self.image_paths[scan][vid]
275 | img = Image.open(img_filename).convert("RGB")
276 | if img.size != (w, h):
277 | img = img.resize((w, h), Image.BICUBIC)
278 | img = self.transform(img)
279 | imgs.append(img)
280 |
281 | intrinsic = self.intrinsics[scan][vid].copy()
282 | intrinsic[:2] = intrinsic[:2] * noisy_factor
283 | intrinsics.append(intrinsic)
284 |
285 | w2c = self.w2cs[scan][vid]
286 | w2cs.append(w2c)
287 | c2ws.append(self.c2ws[scan][vid])
288 |
289 | aff = []
290 | aff_inv = []
291 | for l in range(3):
292 | proj_mat_l = np.eye(4)
293 | intrinsic_temp = intrinsic.copy()
294 | intrinsic_temp[:2] = intrinsic_temp[:2] / (2**l)
295 | proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4]
296 | aff.append(proj_mat_l.copy())
297 | aff_inv.append(np.linalg.inv(proj_mat_l))
298 | aff = np.stack(aff, axis=-1)
299 | aff_inv = np.stack(aff_inv, axis=-1)
300 |
301 | affine_mats.append(aff)
302 | affine_mats_inv.append(aff_inv)
303 |
304 | near_fars.append(self.near_far[scan][vid])
305 |
306 | depths_h.append(np.zeros([h, w]))
307 | depths.append(np.zeros([h // 4, w // 4]))
308 | depths_aug.append(np.zeros([h // 4, w // 4]))
309 |
310 | imgs = np.stack(imgs)
311 | depths = np.stack(depths)
312 | depths_h = np.stack(depths_h)
313 | depths_aug = np.stack(depths_aug)
314 | affine_mats = np.stack(affine_mats)
315 | affine_mats_inv = np.stack(affine_mats_inv)
316 | intrinsics = np.stack(intrinsics)
317 | w2cs = np.stack(w2cs)
318 | c2ws = np.stack(c2ws)
319 | near_fars = np.stack(near_fars)
320 |
321 | sample = {}
322 | sample["images"] = imgs
323 | sample["depths"] = depths
324 | sample["depths_h"] = depths_h
325 | sample["depths_aug"] = depths_aug
326 | sample["w2cs"] = w2cs
327 | sample["c2ws"] = c2ws
328 | sample["near_fars"] = near_fars
329 | sample["affine_mats"] = affine_mats
330 | sample["affine_mats_inv"] = affine_mats_inv
331 | sample["intrinsics"] = intrinsics
332 | sample["closest_idxs"] = closest_idxs
333 |
334 | return sample
335 |
--------------------------------------------------------------------------------
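For reference, the `poses_bounds.npy` file consumed by `build_metas` above stores one row of 17 floats per image: a flattened 3x5 matrix whose first four columns are the camera-to-world pose and whose fifth column holds `[H, W, focal]`, followed by the near/far depth bounds. A minimal sketch of the unpacking, using a random array as a stand-in for a real file:

```python
import numpy as np

# random stand-in for np.load(os.path.join(root_dir, scan, "poses_bounds.npy"))
poses_bounds = np.random.rand(20, 17).astype(np.float32)

poses = poses_bounds[:, :15].reshape(-1, 3, 5)  # (N_images, 3, 5)
bounds = poses_bounds[:, -2:]                   # (N_images, 2) near/far bounds
H, W, focal = poses[0, :, -1]                   # intrinsics column, shared by all images

# rescale the focal length to the training resolution, as in build_metas
img_wh = (960, 720)
focal = [focal * img_wh[0] / W, focal * img_wh[1] / H]
print(poses.shape, bounds.shape, focal)
```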
/data/dtu.py:
--------------------------------------------------------------------------------
1 | # GeoNeRF is a generalizable NeRF model that renders novel views
2 | # without requiring per-scene optimization. This software is the
3 | # implementation of the paper "GeoNeRF: Generalizing NeRF with
4 | # Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
5 | # and Francois Fleuret.
6 |
7 | # Copyright (c) 2022 ams International AG
8 |
9 | # This file is part of GeoNeRF.
10 | # GeoNeRF is free software: you can redistribute it and/or modify
11 | # it under the terms of the GNU General Public License version 3 as
12 | # published by the Free Software Foundation.
13 |
14 | # GeoNeRF is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 |
19 | # You should have received a copy of the GNU General Public License
20 | # along with GeoNeRF. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | # This file incorporates work covered by the following copyright and
23 | # permission notice:
24 |
25 | # MIT License
26 |
27 | # Copyright (c) 2021 apchenstu
28 |
29 | # Permission is hereby granted, free of charge, to any person obtaining a copy
30 | # of this software and associated documentation files (the "Software"), to deal
31 | # in the Software without restriction, including without limitation the rights
32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
33 | # copies of the Software, and to permit persons to whom the Software is
34 | # furnished to do so, subject to the following conditions:
35 |
36 | # The above copyright notice and this permission notice shall be included in all
37 | # copies or substantial portions of the Software.
38 |
39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
45 | # SOFTWARE.
46 |
47 | from torch.utils.data import Dataset
48 | from torchvision import transforms as T
49 |
50 | import os
51 | import cv2
52 | import numpy as np
53 | from PIL import Image
54 |
55 | from utils.utils import read_pfm, get_nearest_pose_ids
56 |
57 | class DTU_Dataset(Dataset):
58 | def __init__(
59 | self,
60 | original_root_dir,
61 | preprocessed_root_dir,
62 | split,
63 | nb_views,
64 | downSample=1.0,
65 | max_len=-1,
66 | scene="None",
67 | ):
68 | self.original_root_dir = original_root_dir
69 | self.preprocessed_root_dir = preprocessed_root_dir
70 | self.split = split
71 | self.scene = scene
72 |
73 | self.downSample = downSample
74 | self.scale_factor = 1.0 / 200
75 | self.interval_scale = 1.06
76 | self.max_len = max_len
77 | self.nb_views = nb_views
78 |
79 | self.build_metas()
80 | self.build_proj_mats()
81 | self.define_transforms()
82 |
83 | def define_transforms(self):
84 | self.transform = T.Compose(
85 | [
86 | T.ToTensor(),
87 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
88 | ]
89 | )
90 |
91 | def build_metas(self):
92 | self.metas = []
93 | with open(f"configs/lists/dtu_{self.split}_all.txt") as f:
94 | self.scans = [line.rstrip() for line in f.readlines()]
95 | if self.scene != "None":
96 | self.scans = [self.scene]
97 |
98 |         # light conditions 2-4 (range(2, 5)) for training
99 | # light condition 3 for testing (the brightest?)
100 | light_idxs = (
101 | [3] if "train" != self.split or self.scene != "None" else range(2, 5)
102 | )
103 |
104 | self.id_list = []
105 |
106 | if self.split == "train":
107 | if self.scene == "None":
108 | pair_file = f"configs/lists/dtu_pairs.txt"
109 | else:
110 | pair_file = f"configs/lists/dtu_pairs_ft.txt"
111 | else:
112 | pair_file = f"configs/lists/dtu_pairs_val.txt"
113 |
114 | for scan in self.scans:
115 | with open(pair_file) as f:
116 | num_viewpoint = int(f.readline())
117 | for _ in range(num_viewpoint):
118 | ref_view = int(f.readline().rstrip())
119 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
120 | for light_idx in light_idxs:
121 | self.metas += [(scan, light_idx, ref_view, src_views)]
122 | self.id_list.append([ref_view] + src_views)
123 |
124 | self.id_list = np.unique(self.id_list)
125 | self.build_remap()
126 |
127 | def build_proj_mats(self):
128 | near_fars, intrinsics, world2cams, cam2worlds = [], [], [], []
129 | for vid in self.id_list:
130 | proj_mat_filename = os.path.join(
131 | self.preprocessed_root_dir, f"Cameras/train/{vid:08d}_cam.txt"
132 | )
133 | intrinsic, extrinsic, near_far = self.read_cam_file(proj_mat_filename)
134 | intrinsic[:2] *= 4
135 | extrinsic[:3, 3] *= self.scale_factor
136 |
137 | intrinsic[:2] = intrinsic[:2] * self.downSample
138 | intrinsics += [intrinsic.copy()]
139 |
140 | near_fars += [near_far]
141 | world2cams += [extrinsic]
142 | cam2worlds += [np.linalg.inv(extrinsic)]
143 |
144 | self.near_fars, self.intrinsics = np.stack(near_fars), np.stack(intrinsics)
145 | self.world2cams, self.cam2worlds = np.stack(world2cams), np.stack(cam2worlds)
146 |
147 | def read_cam_file(self, filename):
148 | with open(filename) as f:
149 | lines = [line.rstrip() for line in f.readlines()]
150 | # extrinsics: line [1,5), 4x4 matrix
151 | extrinsics = np.fromstring(" ".join(lines[1:5]), dtype=np.float32, sep=" ")
152 | extrinsics = extrinsics.reshape((4, 4))
153 | # intrinsics: line [7-10), 3x3 matrix
154 | intrinsics = np.fromstring(" ".join(lines[7:10]), dtype=np.float32, sep=" ")
155 | intrinsics = intrinsics.reshape((3, 3))
156 | # depth_min & depth_interval: line 11
157 | depth_min, depth_interval = lines[11].split()
158 | depth_min = float(depth_min) * self.scale_factor
159 | depth_max = depth_min + float(depth_interval) * 192 * self.interval_scale * self.scale_factor
160 |
161 | intrinsics[0, 2] = intrinsics[0, 2] + 80.0 / 4.0
162 | intrinsics[1, 2] = intrinsics[1, 2] + 44.0 / 4.0
163 |         intrinsics[:2] = intrinsics[:2]  # note: this line is a no-op
164 |
165 | return intrinsics, extrinsics, [depth_min, depth_max]
166 |
167 | def read_depth(self, filename, far_bound, noisy_factor=1.0):
168 | depth_h = self.scale_factor * np.array(
169 | read_pfm(filename)[0], dtype=np.float32
170 | )
171 | depth_h = cv2.resize(
172 | depth_h, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_NEAREST
173 | )
174 |
175 | depth_h = cv2.resize(
176 | depth_h,
177 | None,
178 | fx=self.downSample * noisy_factor,
179 | fy=self.downSample * noisy_factor,
180 | interpolation=cv2.INTER_NEAREST,
181 | )
182 |
183 | ## Exclude points beyond the bounds
184 | depth_h[depth_h > far_bound * 0.95] = 0.0
185 |
186 | depth = {}
187 | for l in range(3):
188 | depth[f"level_{l}"] = cv2.resize(
189 | depth_h,
190 | None,
191 | fx=1.0 / (2**l),
192 | fy=1.0 / (2**l),
193 | interpolation=cv2.INTER_NEAREST,
194 | )
195 |
196 | if self.split == "train":
197 | cutout = np.ones_like(depth[f"level_2"])
198 | h0 = int(np.random.randint(0, high=cutout.shape[0] // 5, size=1))
199 | h1 = int(
200 | np.random.randint(
201 | 4 * cutout.shape[0] // 5, high=cutout.shape[0], size=1
202 | )
203 | )
204 | w0 = int(np.random.randint(0, high=cutout.shape[1] // 5, size=1))
205 | w1 = int(
206 | np.random.randint(
207 | 4 * cutout.shape[1] // 5, high=cutout.shape[1], size=1
208 | )
209 | )
210 | cutout[h0:h1, w0:w1] = 0
211 | depth_aug = depth[f"level_2"] * cutout
212 | else:
213 | depth_aug = depth[f"level_2"].copy()
214 |
215 | return depth, depth_h, depth_aug
216 |
217 | def build_remap(self):
218 | self.remap = np.zeros(np.max(self.id_list) + 1).astype("int")
219 | for i, item in enumerate(self.id_list):
220 | self.remap[item] = i
221 |
222 | def __len__(self):
223 | return len(self.metas) if self.max_len <= 0 else self.max_len
224 |
225 | def __getitem__(self, idx):
226 | if self.split == "train" and self.scene == "None":
227 | noisy_factor = float(np.random.choice([1.0, 0.5], 1))
228 | close_views = int(np.random.choice([3, 4, 5], 1))
229 | else:
230 | noisy_factor = 1.0
231 | close_views = 5
232 |
233 | scan, light_idx, target_view, src_views = self.metas[idx]
234 | view_ids = src_views[:self.nb_views] + [target_view]
235 |
236 | affine_mats, affine_mats_inv = [], []
237 | imgs, depths_h, depths_aug = [], [], []
238 | depths = {"level_0": [], "level_1": [], "level_2": []}
239 | intrinsics, w2cs, c2ws, near_fars = [], [], [], []
240 |
241 | for vid in view_ids:
242 | # Note that the id in image file names is from 1 to 49 (not 0~48)
243 | img_filename = os.path.join(
244 | self.original_root_dir,
245 | f"Rectified/{scan}/rect_{vid + 1:03d}_{light_idx}_r5000.png",
246 | )
247 | depth_filename = os.path.join(
248 | self.preprocessed_root_dir, f"Depths/{scan}/depth_map_{vid:04d}.pfm"
249 | )
250 | img = Image.open(img_filename)
251 | img_wh = np.round(
252 | np.array(img.size) / 2.0 * self.downSample * noisy_factor
253 | ).astype("int")
254 | img = img.resize(img_wh, Image.BICUBIC)
255 | img = self.transform(img)
256 | imgs += [img]
257 |
258 | index_mat = self.remap[vid]
259 |
260 | intrinsic = self.intrinsics[index_mat].copy()
261 | intrinsic[:2] = intrinsic[:2] * noisy_factor
262 | intrinsics.append(intrinsic)
263 |
264 | w2c = self.world2cams[index_mat]
265 | w2cs.append(w2c)
266 | c2ws.append(self.cam2worlds[index_mat])
267 |
268 | aff = []
269 | aff_inv = []
270 | for l in range(3):
271 | proj_mat_l = np.eye(4)
272 | intrinsic_temp = intrinsic.copy()
273 | intrinsic_temp[:2] = intrinsic_temp[:2] / (2**l)
274 | proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4]
275 | aff.append(proj_mat_l.copy())
276 | aff_inv.append(np.linalg.inv(proj_mat_l))
277 | aff = np.stack(aff, axis=-1)
278 | aff_inv = np.stack(aff_inv, axis=-1)
279 |
280 | affine_mats.append(aff)
281 | affine_mats_inv.append(aff_inv)
282 |
283 | near_far = self.near_fars[index_mat]
284 |
285 | depth, depth_h, depth_aug = self.read_depth(
286 | depth_filename, near_far[1], noisy_factor
287 | )
288 |
289 | depths["level_0"].append(depth["level_0"])
290 | depths["level_1"].append(depth["level_1"])
291 | depths["level_2"].append(depth["level_2"])
292 | depths_h.append(depth_h)
293 | depths_aug.append(depth_aug)
294 |
295 | near_fars.append(near_far)
296 |
297 | imgs = np.stack(imgs)
298 | depths_h, depths_aug = np.stack(depths_h), np.stack(depths_aug)
299 | depths["level_0"] = np.stack(depths["level_0"])
300 | depths["level_1"] = np.stack(depths["level_1"])
301 | depths["level_2"] = np.stack(depths["level_2"])
302 | affine_mats, affine_mats_inv = np.stack(affine_mats), np.stack(affine_mats_inv)
303 | intrinsics = np.stack(intrinsics)
304 | w2cs = np.stack(w2cs)
305 | c2ws = np.stack(c2ws)
306 | near_fars = np.stack(near_fars)
307 |
308 | closest_idxs = []
309 | for pose in c2ws[:-1]:
310 | closest_idxs.append(
311 | get_nearest_pose_ids(
312 | pose,
313 | ref_poses=c2ws[:-1],
314 | num_select=close_views,
315 | angular_dist_method="dist",
316 | )
317 | )
318 | closest_idxs = np.stack(closest_idxs, axis=0)
319 |
320 | sample = {}
321 | sample["images"] = imgs
322 | sample["depths"] = depths
323 | sample["depths_h"] = depths_h
324 | sample["depths_aug"] = depths_aug
325 | sample["w2cs"] = w2cs
326 | sample["c2ws"] = c2ws
327 | sample["near_fars"] = near_fars
328 | sample["intrinsics"] = intrinsics
329 | sample["affine_mats"] = affine_mats
330 | sample["affine_mats_inv"] = affine_mats_inv
331 | sample["closest_idxs"] = closest_idxs
332 |
333 | return sample
334 |
--------------------------------------------------------------------------------
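Every loader in this repo builds the same three-level `affine_mats`: at level l the intrinsics are halved l times and composed with the world-to-camera extrinsics into a 4x4 projection, which `homo_warp` later uses to build the cost volumes. A compact sketch of that construction with placeholder camera parameters:

```python
import numpy as np

def build_affine_mats(intrinsic, w2c, levels=3):
    """Per-level projection matrices, mirroring the loop in each __getitem__."""
    aff, aff_inv = [], []
    for l in range(levels):
        proj = np.eye(4)
        K = intrinsic.copy()
        K[:2] = K[:2] / (2 ** l)          # intrinsics at 1/2^l of the base resolution
        proj[:3, :4] = K @ w2c[:3, :4]    # world -> level-l pixel coordinates
        aff.append(proj)
        aff_inv.append(np.linalg.inv(proj))
    return np.stack(aff, axis=-1), np.stack(aff_inv, axis=-1)  # each (4, 4, levels)

# placeholder pinhole camera at the world origin
K = np.array([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])
w2c = np.eye(4)
aff, aff_inv = build_affine_mats(K, w2c)
print(aff.shape, aff_inv.shape)  # (4, 4, 3) (4, 4, 3)
```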
/data/klevr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import json
4 | import numpy as np
5 | import os
6 | from glob import glob as glob
7 | from PIL import Image
8 | from torchvision import transforms as T
9 |
10 | from utils.klevr_utils import from_position_and_quaternion, scale_rays, calculate_near_and_far
11 | from utils.utils import read_pfm, get_nearest_pose_ids, get_rays
12 |
13 | # Nesf Klevr
14 | class KlevrDataset(Dataset):
15 | def __init__(
16 | self,
17 | root_dir,
18 | nb_views,
19 | split='train',
20 | get_rgb=True,
21 | get_semantic=False,
22 | max_len=-1,
23 | scene=None,
24 | downSample=1.0
25 | ):
26 |
27 | # super().__init__()
28 | self.root_dir = root_dir
29 | self.get_rgb = get_rgb
30 | self.get_semantic = get_semantic
31 | self.nb_views = nb_views
32 | self.max_len = max_len
33 | self.scene = scene
34 | self.downSample = downSample
35 | self.blender2opencv = np.array(
36 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
37 | )
38 |
39 | # This is hard coded for Klevr as all scans are in the same range (according to the metadata)
40 | self.scene_boundaries = np.array([[-3.1,-3.1,-0.1],[3.1,3.1,3.1]])
41 | if split == 'train':
42 | self.split = split
43 | elif split =='val':
44 | self.split = 'test'
45 | else:
46 | raise KeyError("only train/val split works")
47 |
48 | self.define_transforms()
49 | self.read_meta()
50 | self.white_back = True
51 |         self.build_proj_mats()
52 |
53 | def define_transforms(self):
54 |         # this normalization uses ImageNet statistics, matching the pretrained CasMVSNet weights
55 | self.transform = T.Compose(
56 | [
57 | T.ToTensor(),
58 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
59 | ]
60 | )
61 |
62 | def read_meta(self):
63 | self.metas = []
64 |         # for now, we only use the first 100 scans for training and the next 20 scans for testing
65 | if self.split == 'train':
66 | self.scans = sorted(glob(os.path.join(self.root_dir, '*')))[:100]
67 | else:
68 | self.scans = sorted(glob(os.path.join(self.root_dir, '*')))[100:120]
69 |
70 | # remap the scan_idx to the scan name
71 | self.scan_idx_to_name = {}
72 |
73 |         # record the view ids used in each scan
74 | self.id_list = []
75 |
76 |         # read the pair file; we reuse the same pair list as the DTU dataset
77 |         # (there could be 6x more pairs, since each KLEVR scene has 300 views while DTU has 50)
78 | if self.split == "train":
79 | if self.scene == "None":
80 | pair_file = f"configs/lists/dtu_pairs.txt"
81 | else:
82 | pair_file = f"configs/lists/dtu_pairs_ft.txt"
83 | else:
84 | pair_file = f"configs/lists/dtu_pairs_val.txt"
85 |
86 | for scan_idx, meta_filename in enumerate(self.scans):
87 | with open(pair_file) as f:
88 | scan = meta_filename.split('/')[-1]
89 | self.scan_idx_to_name[scan_idx] = scan
90 | num_viewpoint = int(f.readline())
91 | for _ in range(num_viewpoint):
92 | ref_view = int(f.readline().rstrip())
93 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
94 | self.metas += [(scan_idx, ref_view, src_views[:self.nb_views])]
95 | self.id_list.append([ref_view] + src_views)
96 |
97 | self.id_list = np.unique(self.id_list)
98 | self.build_remap()
99 |
100 | def build_remap(self):
101 | self.remap = np.zeros(np.max(self.id_list) + 1).astype("int")
102 | for i, item in enumerate(self.id_list):
103 | self.remap[item] = i
104 |
105 |     def build_proj_mats(self):
106 |         # near/far bounds are not computed here; they are computed per ray in __getitem__
107 | self.near_fars, self.intrinsics, self.world2cams, self.cam2worlds = None, {}, {}, {}
108 | for scan_idx, meta_fileprefix in enumerate(self.scans):
109 | meta_filename = os.path.join(meta_fileprefix, 'metadata.json')
110 | intrinsic, world2cam, cam2world = self.read_cam_file(meta_filename, scan_idx)
111 | self.intrinsics[scan_idx], self.world2cams[scan_idx], self.cam2worlds[scan_idx] = np.array(intrinsic), np.array(world2cam), np.array(cam2world)
112 |
113 | def read_cam_file(self, filename, scan_idx):
114 | '''
115 |         read the metadata file and return the intrinsics, world2cam, and cam2world matrices
116 | filename(str): the metadata file
117 | scan_idx(int): the index of the scan
118 |
119 | return:
120 | intrinsic: the intrinsic of the scan [N,3,3]
121 | world2cam: the world2cam of the scan [N,4,4]
122 | cam2world: the cam2world of the scan [N,4,4]
123 | '''
124 | intrinsic, world2cam, cam2world = [], [], []
125 | with open(filename, 'r') as f:
126 | meta = json.load(f)
127 | w, h = meta['metadata']['width'], meta['metadata']['height']
128 | focal = meta['camera']['focal_length']*w/meta['camera']['sensor_width']
129 |
130 | camera_positions = np.array(meta['camera']['positions'])
131 | camera_quaternions = np.array(meta['camera']['quaternions'])
132 | # calculate camera pose of each frame that will be used (in this scan idx)
133 | for frame_id in self.id_list:
134 | c2w = from_position_and_quaternion(camera_positions[frame_id], camera_quaternions[frame_id], False).tolist()
135 |                 # not sure: applying the Blender -> OpenCV convention flip, as in the other loaders
136 | c2w = np.array(c2w) @ self.blender2opencv
137 | cam2world += [c2w.tolist()]
138 | world2cam += [np.linalg.inv(c2w).tolist()]
139 | intrinsic += [[[focal, 0, w/2], [0, focal, h/2], [0, 0, 1]]]
140 |
141 | return intrinsic, world2cam, cam2world
142 |
143 | def read_depth(self, filename, far_bound, noisy_factor=1.0):
144 |         # read the depth image (currently unused downstream)
145 | depth_h = Image.open(filename)
146 | depth_wh = np.round(np.array(depth_h.size) * self.downSample * noisy_factor).astype(np.int32)
147 |
148 |         # the original NeRF code uses Image.Resampling.LANCZOS; BILINEAR is used here instead
149 | depth_h = depth_h.resize(depth_wh, Image.BILINEAR)
150 |
151 | ## Exclude points beyond the bounds
152 | depth_h_filter = np.array(depth_h)
153 | depth_h_filter[depth_h_filter > far_bound * 0.95] = 0.0
154 |
155 |
156 | depth = {}
157 | for l in range(3):
158 | depth_wh = np.round(np.array(depth_h.size) / (2**l)).astype(np.int32)
159 | depth[f"level_{l}"] = np.array(depth_h.resize(depth_wh, Image.BILINEAR))
160 | depth[f"level_{l}"][depth[f"level_{l}"] > far_bound * 0.95] = 0.0
161 |
162 | if self.split == "train":
163 | cutout = np.ones_like(depth[f"level_2"])
164 | h0 = int(np.random.randint(0, high=cutout.shape[0] // 5, size=1))
165 | h1 = int(
166 | np.random.randint(
167 | 4 * cutout.shape[0] // 5, high=cutout.shape[0], size=1
168 | )
169 | )
170 | w0 = int(np.random.randint(0, high=cutout.shape[1] // 5, size=1))
171 | w1 = int(
172 | np.random.randint(
173 | 4 * cutout.shape[1] // 5, high=cutout.shape[1], size=1
174 | )
175 | )
176 | cutout[h0:h1, w0:w1] = 0
177 | depth_aug = depth[f"level_2"] * cutout
178 | else:
179 | depth_aug = depth[f"level_2"].copy()
180 |
181 | return depth, depth_h_filter, depth_aug
182 |
183 |
184 | def __len__(self):
185 | return len(self.metas) if self.max_len <= 0 else self.max_len
186 |
187 | def __getitem__(self, idx):
188 | # haven't used the depth image and noisy factor yet
189 | if self.split == "train" and self.scene == "None":
190 | noisy_factor = float(np.random.choice([1.0, 0.5], 1))
191 | close_views = int(np.random.choice([3, 4, 5], 1))
192 | else:
193 | noisy_factor = 1.0
194 | close_views = 5
195 |
196 | scan_idx, ref_id, src_ids = self.metas[idx]
197 |
198 | # notice that the ref_id is in the last position
199 | view_ids = src_ids + [ref_id]
200 |
201 | affine_mats, affine_mats_inv = [], []
202 | imgs, depths_h, depths_aug = [], [], []
203 | depths = {"level_0": [], "level_1": [], "level_2": []}
204 | semantics = []
205 | intrinsics, w2cs, c2ws = [], [], []
206 |
207 |         # each frame carries its own intrinsic here, although it is identical for all frames within a scan
208 | # # every scan has only one intrinsic, here actually is focal
209 | # intrinsic = self.intrinsics[scan_idx]
210 |
211 | for vid in view_ids:
212 | img_filename = os.path.join(self.root_dir, self.scan_idx_to_name[scan_idx],f'rgba_{vid:05d}.png')
213 | depth_filename = os.path.join(self.root_dir, self.scan_idx_to_name[scan_idx],f'depth_{vid:05d}.tiff')
214 |
215 | img = Image.open(img_filename)
216 | img_wh = np.round(np.array(img.size)*self.downSample).astype(np.int32)
217 |
218 |             # the original NeRF code uses Image.Resampling.LANCZOS, which is also used here
219 | img = img.resize(img_wh, Image.LANCZOS)
220 |             # discard the alpha channel and keep only RGB; a valid_mask = img[-1] > 0 may be needed later
221 | img = self.transform(np.array(img)[:,:,:3])
222 | imgs += [img]
223 |
224 | # semantic part
225 | if self.get_semantic:
226 | semantic_filename = os.path.join(self.root_dir, self.scan_idx_to_name[scan_idx],f'segmentation_{vid:05d}.png')
227 | semantic = Image.open(semantic_filename)
228 | semantic = semantic.resize(img_wh, Image.LANCZOS)
229 | semantic = torch.from_numpy(np.array(semantic)).long() #(h,w)
230 | semantics += [semantic]
231 |
232 | index = self.remap[vid]
233 | # # debug
234 | # print("vid: ", vid, "index: ", index, "scan_idx: ", scan_idx)
235 | # print("self.remap[scan_idx]: ", self.remap[scan_idx])
236 | # print("self.cam2worlds[scan_idx]: ", np.array(self.cam2worlds[scan_idx]).shape)
237 | # raise Exception("debug")
238 | c2ws.append(self.cam2worlds[scan_idx][index])
239 |
240 | w2c = self.world2cams[scan_idx][index]
241 | w2cs.append(w2c)
242 |
243 | intrinsic = self.intrinsics[scan_idx][index]
244 | intrinsics.append(intrinsic)
245 |
246 | aff = []
247 | aff_inv = []
248 |             # three levels are built to match the depth/feature pyramids (only strictly needed when depth images are used)
249 | for l in range(3):
250 | proj_mat_l = np.eye(4)
251 | intrinsic_temp = intrinsic.copy()
252 | intrinsic_temp[:2] = intrinsic_temp[:2]/(2**l)
253 | proj_mat_l[:3,:4] = intrinsic_temp @ w2c[:3,:4]
254 | aff.append(proj_mat_l)
255 | aff_inv.append(np.linalg.inv(proj_mat_l))
256 |
257 | aff = np.stack(aff, axis=-1)
258 | aff_inv = np.stack(aff_inv, axis=-1)
259 |
260 | affine_mats.append(aff)
261 | affine_mats_inv.append(aff_inv)
262 |             # the far bound is currently hard-coded to 17
263 | depth, depth_h, depth_aug = self.read_depth(depth_filename, far_bound=17, noisy_factor=1)
264 |
265 | depths["level_0"].append(depth["level_0"])
266 | depths["level_1"].append(depth["level_1"])
267 | depths["level_2"].append(depth["level_2"])
268 | depths_h.append(depth_h)
269 | depths_aug.append(depth_aug)
270 |
271 |
272 | imgs = np.stack(imgs)
273 | semantics = np.stack(semantics)
274 | affine_mats = np.stack(affine_mats)
275 | affine_mats_inv = np.stack(affine_mats_inv)
276 | intrinsics = np.stack(intrinsics)
277 | w2cs = np.stack(w2cs)
278 | c2ws = np.stack(c2ws)
279 | depths_h, depths_aug = np.stack(depths_h), np.stack(depths_aug)
280 | depths["level_0"] = np.stack(depths["level_0"])
281 | depths["level_1"] = np.stack(depths["level_1"])
282 | depths["level_2"] = np.stack(depths["level_2"])
283 |
284 |
285 | close_idxs = []
286 | for pose in c2ws[:-1]:
287 | close_idxs.append(
288 | get_nearest_pose_ids(
289 | pose,
290 | c2ws[:-1],
291 | close_views,
292 | angular_dist_method="dist"
293 | )
294 | )
295 | close_idxs = np.stack(close_idxs, axis=0)
296 | self.near_fars = []
297 | for i in range(imgs.shape[0]):
298 | rays_orig, rays_dir, _ = get_rays(
299 | H=imgs.shape[2],
300 | W=imgs.shape[3],
301 |                 # converted to tensors here; no device transfer is performed
302 | intrinsics_target=torch.tensor(intrinsics[i].astype(imgs.dtype)),
303 | c2w_target=torch.tensor(c2ws[i].astype(imgs.dtype)),
304 |                 # train=False (with chunk=-1) returns all rays of the image
305 | train=False,
306 | )
307 | near, far = calculate_near_and_far(rays_orig, rays_dir, bbox_min=[-3.1, -3.1, -0.1], bbox_max=[3.1, 3.1, 3.1])
308 | near = near.min().item()
309 | far = far.max().item()
310 | self.near_fars.append([near, far])
311 |
312 | sample = {}
313 | sample['images'] = imgs
314 | if self.get_semantic:
315 | sample['semantics'] = semantics
316 | sample['w2cs'] = w2cs.astype(imgs.dtype)
317 | sample['c2ws'] = c2ws.astype(imgs.dtype)
318 | sample['intrinsics'] = intrinsics.astype(imgs.dtype)
319 | sample['affine_mats'] = affine_mats.astype(imgs.dtype)
320 | sample['affine_mats_inv'] = affine_mats_inv.astype(imgs.dtype)
321 | sample['closest_idxs'] = close_idxs
322 |         # depths_aug must be provided, but its values are ignored when use_depth is False
323 | sample['depths_aug'] = depths_aug
324 | sample['depths_h'] = depths_h
325 |         # depths must be a dict, but its values are ignored when use_depth is False (the original code still uses a depth loss)
326 | sample['depths'] = depths
327 |         # near_fars are computed per view from the scene bounding box intersection above
328 | sample['near_fars'] = np.array(self.near_fars).astype(imgs.dtype)
329 | return sample
--------------------------------------------------------------------------------
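The per-view near/far bounds at the end of `__getitem__` come from intersecting every pixel ray with the hard-coded scene box `[-3.1, -3.1, -0.1]` to `[3.1, 3.1, 3.1]`. The repo's own routine is `calculate_near_and_far` in `utils/klevr_utils.py`; the sketch below is a generic slab-test version of the same idea, not the project's exact implementation:

```python
import torch

def ray_aabb_near_far(origins, dirs, bbox_min, bbox_max, eps=1e-8):
    """Standard slab test: per-ray entry/exit distances for an axis-aligned box."""
    bbox_min = torch.as_tensor(bbox_min, dtype=origins.dtype)
    bbox_max = torch.as_tensor(bbox_max, dtype=origins.dtype)
    inv_d = 1.0 / (dirs + eps)                 # avoid division by zero on axis-aligned rays
    t0 = (bbox_min - origins) * inv_d
    t1 = (bbox_max - origins) * inv_d
    near = torch.minimum(t0, t1).amax(dim=-1)  # latest entry over the three slabs
    far = torch.maximum(t0, t1).amin(dim=-1)   # earliest exit over the three slabs
    return near.clamp(min=0.0), far            # rays that miss the box end up with far < near

origins = torch.tensor([[0.0, 0.0, 5.0]]).repeat(4, 1)
dirs = torch.nn.functional.normalize(torch.randn(4, 3), dim=-1)
near, far = ray_aabb_near_far(origins, dirs, [-3.1, -3.1, -0.1], [3.1, 3.1, 3.1])
print(near.shape, far.shape)  # torch.Size([4]) torch.Size([4])
```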
/model/geo_reasoner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.utils.checkpoint import checkpoint
5 |
6 | from model.UNet import UNet, smp_UNet
7 | from utils.utils import homo_warp
8 | from inplace_abn import InPlaceABN
9 |
10 |
11 | def get_depth_values(current_depth, n_depths, depth_interval):
12 | depth_min = torch.clamp_min(current_depth - n_depths / 2 * depth_interval, 1e-3)
13 | depth_values = (
14 | depth_min
15 | + depth_interval
16 | * torch.arange(
17 | 0, n_depths, device=current_depth.device, dtype=current_depth.dtype
18 | )[None, :, None, None]
19 | )
20 | return depth_values
21 |
22 |
23 | class ConvBnReLU(nn.Module):
24 | def __init__(
25 | self,
26 | in_channels,
27 | out_channels,
28 | kernel_size=3,
29 | stride=1,
30 | pad=1,
31 | norm_act=InPlaceABN,
32 | ):
33 | super(ConvBnReLU, self).__init__()
34 | self.conv = nn.Conv2d(
35 | in_channels,
36 | out_channels,
37 | kernel_size,
38 | stride=stride,
39 | padding=pad,
40 | bias=False,
41 | )
42 | self.bn = norm_act(out_channels)
43 |
44 | def forward(self, x):
45 | return self.bn(self.conv(x))
46 |
47 |
48 | class ConvBnReLU3D(nn.Module):
49 | def __init__(
50 | self,
51 | in_channels,
52 | out_channels,
53 | kernel_size=3,
54 | stride=1,
55 | pad=1,
56 | norm_act=InPlaceABN,
57 | ):
58 | super(ConvBnReLU3D, self).__init__()
59 | self.conv = nn.Conv3d(
60 | in_channels,
61 | out_channels,
62 | kernel_size,
63 | stride=stride,
64 | padding=pad,
65 | bias=False,
66 | )
67 | self.bn = norm_act(out_channels)
68 |
69 | def forward(self, x):
70 | return self.bn(self.conv(x))
71 |
72 |
73 | class FeatureNet(nn.Module):
74 | def __init__(self, norm_act=InPlaceABN):
75 | super(FeatureNet, self).__init__()
76 |
77 | self.conv0 = nn.Sequential(
78 | ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act),
79 | ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act),
80 | )
81 |
82 | self.conv1 = nn.Sequential(
83 | ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act),
84 | ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act),
85 | ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act),
86 | )
87 |
88 | self.conv2 = nn.Sequential(
89 | ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act),
90 | ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act),
91 | ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act),
92 | )
93 |
94 | self.toplayer = nn.Conv2d(32, 32, 1)
95 | self.lat1 = nn.Conv2d(16, 32, 1)
96 | self.lat0 = nn.Conv2d(8, 32, 1)
97 |
98 | # to reduce channel size of the outputs from FPN
99 | self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
100 | self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)
101 |
102 | def _upsample_add(self, x, y):
103 | return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + y
104 |
105 | def forward(self, x, dummy=None):
106 | # x: (B, 3, H, W)
107 | conv0 = self.conv0(x) # (B, 8, H, W)
108 | conv1 = self.conv1(conv0) # (B, 16, H//2, W//2)
109 | conv2 = self.conv2(conv1) # (B, 32, H//4, W//4)
110 | feat2 = self.toplayer(conv2) # (B, 32, H//4, W//4)
111 | feat1 = self._upsample_add(feat2, self.lat1(conv1)) # (B, 32, H//2, W//2)
112 | feat0 = self._upsample_add(feat1, self.lat0(conv0)) # (B, 32, H, W)
113 |
114 | # reduce output channels
115 | feat1 = self.smooth1(feat1) # (B, 16, H//2, W//2)
116 | feat0 = self.smooth0(feat0) # (B, 8, H, W)
117 |
118 | feats = {"level_0": feat0, "level_1": feat1, "level_2": feat2}
119 |
120 | return feats
121 |
122 |
123 | class CostRegNet(nn.Module):
124 | def __init__(self, in_channels, norm_act=InPlaceABN):
125 | super(CostRegNet, self).__init__()
126 | self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)
127 |
128 | self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
129 | self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)
130 |
131 | self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
132 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)
133 |
134 | self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act)
135 | self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act)
136 |
137 | self.conv7 = nn.Sequential(
138 | nn.ConvTranspose3d(
139 | 64, 32, 3, padding=1, output_padding=1, stride=2, bias=False
140 | ),
141 | norm_act(32),
142 | )
143 |
144 | self.conv9 = nn.Sequential(
145 | nn.ConvTranspose3d(
146 | 32, 16, 3, padding=1, output_padding=1, stride=2, bias=False
147 | ),
148 | norm_act(16),
149 | )
150 |
151 | self.conv11 = nn.Sequential(
152 | nn.ConvTranspose3d(
153 | 16, 8, 3, padding=1, output_padding=1, stride=2, bias=False
154 | ),
155 | norm_act(8),
156 | )
157 |
158 | self.br1 = ConvBnReLU3D(8, 8, norm_act=norm_act)
159 | self.br2 = ConvBnReLU3D(8, 8, norm_act=norm_act)
160 |
161 | self.prob = nn.Conv3d(8, 1, 3, stride=1, padding=1)
162 |
163 | def forward(self, x):
164 | if x.shape[-2] % 8 != 0 or x.shape[-1] % 8 != 0:
165 | pad_h = 8 * (x.shape[-2] // 8 + 1) - x.shape[-2]
166 | pad_w = 8 * (x.shape[-1] // 8 + 1) - x.shape[-1]
167 | x = F.pad(x, (0, pad_w, 0, pad_h), mode="constant", value=0)
168 | else:
169 | pad_h = 0
170 | pad_w = 0
171 |
172 | conv0 = self.conv0(x)
173 | conv2 = self.conv2(self.conv1(conv0))
174 | conv4 = self.conv4(self.conv3(conv2))
175 |
176 | x = self.conv6(self.conv5(conv4))
177 | x = conv4 + self.conv7(x)
178 | del conv4
179 | x = conv2 + self.conv9(x)
180 | del conv2
181 | x = conv0 + self.conv11(x)
182 | del conv0
183 | ####################
184 | # x1 = self.br1(x)
185 | # with torch.enable_grad():
186 | # x2 = self.br2(x)
187 | x1 = self.br1(x)
188 | x2 = self.br2(x)
189 | ####################
190 | p = self.prob(x1)
191 |
192 | if pad_h > 0 or pad_w > 0:
193 | x2 = x2[..., :-pad_h, :-pad_w]
194 | p = p[..., :-pad_h, :-pad_w]
195 |
196 | return x2, p
197 |
198 |
199 | class CasMVSNet(nn.Module):
200 | def __init__(self, num_groups=8, norm_act=InPlaceABN, levels=3, use_depth=False, nb_class=1, feat_net=None):
201 | super(CasMVSNet, self).__init__()
202 | self.levels = levels # 3 depth levels
203 | self.n_depths = [8, 32, 48]
204 | self.interval_ratios = [1, 2, 4]
205 | self.use_depth = use_depth
206 |
207 | self.G = num_groups # number of groups in groupwise correlation
208 |         # self.feature = FeatureNet()
209 |         # the original FeatureNet is replaced by a UNet-based feature extractor
210 | if feat_net is None:
211 | raise ValueError("feat_net must be specified")
212 | elif feat_net == 'UNet':
213 | self.feature = UNet(3,nb_class)
214 | elif feat_net == 'smp_UNet':
215 | self.feature = smp_UNet(3,nb_class)
216 |
217 | for l in range(self.levels):
218 | if l == self.levels - 1 and self.use_depth:
219 | cost_reg_l = CostRegNet(self.G + 1, norm_act)
220 | else:
221 | cost_reg_l = CostRegNet(self.G, norm_act)
222 |
223 | setattr(self, f"cost_reg_{l}", cost_reg_l)
224 |
225 | def build_cost_volumes(self, feats, affine_mats, affine_mats_inv, depth_values, idx, spikes):
226 | B, V, C, H, W = feats.shape
227 | D = depth_values.shape[1]
228 |
229 | ref_feats, src_feats = feats[:, idx[0]], feats[:, idx[1:]]
230 | src_feats = src_feats.permute(1, 0, 2, 3, 4) # (V-1, B, C, h, w)
231 |
232 | affine_mats_inv = affine_mats_inv[:, idx[0]]
233 | affine_mats = affine_mats[:, idx[1:]]
234 | affine_mats = affine_mats.permute(1, 0, 2, 3) # (V-1, B, 3, 4)
235 |
236 | ref_volume = ref_feats.unsqueeze(2).repeat(1, 1, D, 1, 1) # (B, C, D, h, w)
237 |
238 | ref_volume = ref_volume.view(B, self.G, C // self.G, *ref_volume.shape[-3:])
239 | volume_sum = 0
240 |
241 | for i in range(len(idx) - 1):
242 | proj_mat = (affine_mats[i].double() @ affine_mats_inv.double()).float()[
243 | :, :3
244 | ] # shape (1,3,4)
245 | warped_volume, grid = homo_warp(src_feats[i], proj_mat, depth_values)
246 |
247 | warped_volume = warped_volume.view_as(ref_volume)
248 | volume_sum = volume_sum + warped_volume # (B, G, C//G, D, h, w)
249 |             if torch.isnan(volume_sum).sum() > 0:
250 | print("nan in volume_sum")
251 |
252 | volume = (volume_sum * ref_volume).mean(dim=2) / (V - 1)
253 |
254 | if spikes is None:
255 | output = volume
256 | else:
257 | output = torch.cat([volume, spikes], dim=1)
258 |
259 | return output
260 |
261 | def create_neural_volume(
262 | self,
263 | feats,
264 | affine_mats,
265 | affine_mats_inv,
266 | idx,
267 | init_depth_min,
268 | depth_interval,
269 | gt_depths,
270 | ):
271 | if feats["level_0"].shape[-1] >= 800:
272 | hres_input = True
273 | else:
274 | hres_input = False
275 |
276 | B, V = affine_mats.shape[:2]
277 |
278 | v_feat = {}
279 | depth_maps = {}
280 | depth_values = {}
281 | for l in reversed(range(self.levels)): # (2, 1, 0)
282 | feats_l = feats[f"level_{l}"] # (B*V, C, h, w)
283 | feats_l = feats_l.view(B, V, *feats_l.shape[1:]) # (B, V, C, h, w)
284 | h, w = feats_l.shape[-2:]
285 | depth_interval_l = depth_interval * self.interval_ratios[l]
286 | D = self.n_depths[l]
287 | if l == self.levels - 1: # coarsest level
288 | depth_values_l = init_depth_min + depth_interval_l * torch.arange(
289 | 0, D, device=feats_l.device, dtype=feats_l.dtype
290 | ) # (D)
291 | depth_values_l = depth_values_l[None, :, None, None].expand(
292 | -1, -1, h, w
293 | )
294 |
295 | if self.use_depth:
296 | gt_mask = gt_depths > 0
297 | sp_idx_float = (
298 | gt_mask * (gt_depths - init_depth_min) / (depth_interval_l)
299 | )[:, :, None]
300 | spikes = (
301 | torch.arange(D).view(1, 1, -1, 1, 1).to(gt_mask.device)
302 | == sp_idx_float.floor().long()
303 | ) * (1 - sp_idx_float.frac())
304 | spikes = spikes + (
305 | torch.arange(D).view(1, 1, -1, 1, 1).to(gt_mask.device)
306 | == sp_idx_float.ceil().long()
307 | ) * (sp_idx_float.frac())
308 | spikes = (spikes * gt_mask[:, :, None]).float()
309 | else:
310 | depth_lm1 = depth_l.detach() # the depth of previous level
311 | depth_lm1 = F.interpolate(
312 | depth_lm1, scale_factor=2, mode="bilinear", align_corners=True
313 | ) # (B, 1, h, w)
314 | depth_values_l = get_depth_values(depth_lm1, D, depth_interval_l)
315 |
316 | affine_mats_l = affine_mats[..., l]
317 | affine_mats_inv_l = affine_mats_inv[..., l]
318 |
319 | if l == self.levels - 1 and self.use_depth:
320 | spikes_ = spikes
321 | else:
322 | spikes_ = None
323 |
324 | if hres_input:
325 | v_feat_l = checkpoint(
326 | self.build_cost_volumes,
327 | feats_l,
328 | affine_mats_l,
329 | affine_mats_inv_l,
330 | depth_values_l,
331 | idx,
332 | spikes_,
333 | preserve_rng_state=False,
334 | )
335 | else:
336 | v_feat_l = self.build_cost_volumes(
337 | feats_l,
338 | affine_mats_l,
339 | affine_mats_inv_l,
340 | depth_values_l,
341 | idx,
342 | spikes_,
343 | )
344 |
345 | cost_reg_l = getattr(self, f"cost_reg_{l}")
346 | v_feat_l_, depth_prob = cost_reg_l(v_feat_l) # (B, 1, D, h, w)
347 |
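            # Soft-argmax depth regression: the expected depth under the softmax
            # distribution over the D depth hypotheses of this level.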
348 | depth_l = (F.softmax(depth_prob, dim=2) * depth_values_l[:, None]).sum(
349 | dim=2
350 | )
351 |             # NaNs appearing in v_feat_l_ typically originate in build_cost_volumes
352 |             if torch.isnan(v_feat_l_).sum() > 0:
353 |                 print("nan in v_feat_l_")
354 | v_feat[f"level_{l}"] = v_feat_l_
355 | depth_maps[f"level_{l}"] = depth_l
356 | depth_values[f"level_{l}"] = depth_values_l
357 |
358 | return v_feat, depth_maps, depth_values
359 |
360 | def forward(
361 | self, imgs, affine_mats, affine_mats_inv, near_far, closest_idxs, gt_depths=None
362 | ):
363 | B, V, _, H, W = imgs.shape
364 |
365 | ## Feature Pyramid
366 | feats = self.feature(
367 | imgs.reshape(B * V, 3, H, W)
368 | ) # (B*V, 8, H, W), (B*V, 16, H//2, W//2), (B*V, 32, H//4, W//4)
369 |         feats_fpn = feats["level_0"].reshape(B, V, *feats["level_0"].shape[1:])
370 |         semantic_logits = feats["logits"].reshape(B, V, *feats["logits"].shape[1:])
371 |         semantic_feature = feats["feature"].reshape(B, V, *feats["feature"].shape[1:])
372 |
373 | feats_vol = {"level_0": [], "level_1": [], "level_2": []}
374 | depth_map = {"level_0": [], "level_1": [], "level_2": []}
375 | depth_values = {"level_0": [], "level_1": [], "level_2": []}
376 | ## Create cost volumes for each view
377 | for i in range(0, V):
378 | permuted_idx = closest_idxs[0, i].clone().detach().to(feats['level_0'].device)
379 | # if near_far.sum() == 0:
380 | # init_depth_min = 1
381 | # depth_interval = 0.1
382 | # else:
383 | init_depth_min = near_far[0, i, 0]
384 | depth_interval = (
385 | (near_far[0, i, 1] - near_far[0, i, 0])
386 | / self.n_depths[-1]
387 | / self.interval_ratios[-1]
388 | )
389 |
390 | v_feat, d_map, d_values = self.create_neural_volume(
391 | feats,
392 | affine_mats,
393 | affine_mats_inv,
394 | idx=permuted_idx,
395 | init_depth_min=init_depth_min,
396 | depth_interval=depth_interval,
397 | gt_depths=gt_depths[:, i : i + 1],
398 | )
399 |
400 | for l in range(3):
401 | feats_vol[f"level_{l}"].append(v_feat[f"level_{l}"])
402 | depth_map[f"level_{l}"].append(d_map[f"level_{l}"])
403 | depth_values[f"level_{l}"].append(d_values[f"level_{l}"])
404 |
405 | for l in range(3):
406 | feats_vol[f"level_{l}"] = torch.stack(feats_vol[f"level_{l}"], dim=1)
407 | depth_map[f"level_{l}"] = torch.cat(depth_map[f"level_{l}"], dim=1)
408 | depth_values[f"level_{l}"] = torch.stack(depth_values[f"level_{l}"], dim=1)
409 |
410 | return feats_vol, feats_fpn, depth_map, depth_values, semantic_logits, semantic_feature
411 |
--------------------------------------------------------------------------------
/model/self_attn_renderer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import math
6 |
7 | def weights_init(m):
8 | if isinstance(m, nn.Linear):
9 | stdv = 1.0 / math.sqrt(m.weight.size(1))
10 | m.weight.data.uniform_(-stdv, stdv)
11 | if m.bias is not None:
12 |             m.bias.data.uniform_(-stdv, stdv)
13 |
14 |
15 | def masked_softmax(x, mask, **kwargs):
16 | x_masked = x.masked_fill(mask == 0, -float("inf"))
17 |
18 | return torch.softmax(x_masked, **kwargs)
19 |
20 |
21 | ## Auto-encoder network
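## A 1D convolutional encoder-decoder with skip connections, applied along the S
## ray samples (Conv1d over the sample dimension); the output keeps the input's
## channel count and length, so it can aggregate information across samples.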
22 | class ConvAutoEncoder(nn.Module):
23 | def __init__(self, num_ch, S):
24 | super(ConvAutoEncoder, self).__init__()
25 |
26 | # Encoder
27 | self.conv1 = nn.Sequential(
28 | nn.Conv1d(num_ch, num_ch * 2, 3, stride=1, padding=1),
29 | # nn.LayerNorm(S, elementwise_affine=False),
30 | nn.ELU(alpha=1.0, inplace=True),
31 | nn.MaxPool1d(2),
32 | )
33 | self.conv2 = nn.Sequential(
34 | nn.Conv1d(num_ch * 2, num_ch * 4, 3, stride=1, padding=1),
35 | # nn.LayerNorm(S // 2, elementwise_affine=False),
36 | nn.ELU(alpha=1.0, inplace=True),
37 | nn.MaxPool1d(2),
38 | )
39 | self.conv3 = nn.Sequential(
40 | nn.Conv1d(num_ch * 4, num_ch * 4, 3, stride=1, padding=1),
41 | # nn.LayerNorm(S // 4, elementwise_affine=False),
42 | nn.ELU(alpha=1.0, inplace=True),
43 | nn.MaxPool1d(2),
44 | )
45 |
46 | # Decoder
47 | self.t_conv1 = nn.Sequential(
48 | nn.ConvTranspose1d(num_ch * 4, num_ch * 4, 4, stride=2, padding=1),
49 | # nn.LayerNorm(S // 4, elementwise_affine=False),
50 | nn.ELU(alpha=1.0, inplace=True),
51 | )
52 | self.t_conv2 = nn.Sequential(
53 | nn.ConvTranspose1d(num_ch * 8, num_ch * 2, 4, stride=2, padding=1),
54 | # nn.LayerNorm(S // 2, elementwise_affine=False),
55 | nn.ELU(alpha=1.0, inplace=True),
56 | )
57 | self.t_conv3 = nn.Sequential(
58 | nn.ConvTranspose1d(num_ch * 4, num_ch, 4, stride=2, padding=1),
59 | # nn.LayerNorm(S, elementwise_affine=False),
60 | nn.ELU(alpha=1.0, inplace=True),
61 | )
62 | # Output
63 | self.conv_out = nn.Sequential(
64 | nn.Conv1d(num_ch * 2, num_ch, 3, stride=1, padding=1),
65 | # nn.LayerNorm(S, elementwise_affine=False),
66 | nn.ELU(alpha=1.0, inplace=True),
67 | )
68 |
69 | def forward(self, x):
70 | input = x
71 | x = self.conv1(x)
72 | conv1_out = x
73 | x = self.conv2(x)
74 | conv2_out = x
75 | x = self.conv3(x)
76 |
77 | x = self.t_conv1(x)
78 | x = self.t_conv2(torch.cat([x, conv2_out], dim=1))
79 | x = self.t_conv3(torch.cat([x, conv1_out], dim=1))
80 |
81 | x = self.conv_out(torch.cat([x, input], dim=1))
82 |
83 | return x
84 |
85 |
86 | class ScaledDotProductAttention(nn.Module):
87 | def __init__(self, temperature, attn_dropout=0.1):
88 | super().__init__()
89 | self.temperature = temperature
90 |
91 | def forward(self, q, k, v, mask=None):
92 |
93 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
94 |
95 | if mask is not None:
96 | attn = masked_softmax(attn, mask, dim=-1)
97 | else:
98 | attn = F.softmax(attn, dim=-1)
99 |
100 | output = torch.matmul(attn, v)
101 |
102 | return output, attn
103 |
104 |
105 | class PositionwiseFeedForward(nn.Module):
106 | def __init__(self, d_in, d_hid, dropout=0.1):
107 | super().__init__()
108 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise
109 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise
110 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
111 |
112 | def forward(self, x):
113 |
114 | residual = x
115 |
116 | x = self.w_2(F.relu(self.w_1(x)))
117 | x += residual
118 |
119 | x = self.layer_norm(x)
120 |
121 | return x
122 |
123 |
124 | class MultiHeadAttention(nn.Module):
125 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
126 | super().__init__()
127 |
128 | self.n_head = n_head
129 | self.d_k = d_k
130 | self.d_v = d_v
131 |
132 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
133 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
134 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
135 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
136 |
137 | self.attention = ScaledDotProductAttention(temperature=d_k**0.5)
138 |
139 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
140 |
141 | def forward(self, q, k, v, mask=None):
142 |
143 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
144 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
145 |
146 | residual = q
147 |
148 | # Pass through the pre-attention projection: b x lq x (n*dv)
149 | # Separate different heads: b x lq x n x dv
150 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
151 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
152 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
153 |
154 | # Transpose for attention dot product: b x n x lq x dv
155 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
156 |
157 | if mask is not None:
158 | mask = mask.transpose(1, 2).unsqueeze(1) # For head axis broadcasting.
159 |
160 | q, attn = self.attention(q, k, v, mask=mask)
161 |
162 | # Transpose to move the head dimension back: b x lq x n x dv
163 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
164 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
165 | q = self.fc(q)
166 | q += residual
167 |
168 | q = self.layer_norm(q)
169 |
170 | return q, attn
171 |
172 |
173 | class EncoderLayer(nn.Module):
174 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0):
175 | super(EncoderLayer, self).__init__()
176 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
177 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
178 |
179 | def forward(self, enc_input, slf_attn_mask=None):
180 | enc_output, enc_slf_attn = self.slf_attn(
181 | enc_input, enc_input, enc_input, mask=slf_attn_mask
182 | )
183 | enc_output = self.pos_ffn(enc_output)
184 | return enc_output, enc_slf_attn
185 |
186 |
187 | class Renderer(nn.Module):
188 | def __init__(self, nb_samples_per_ray, nb_view=8, nb_class=21,
189 | using_semantic_global_tokens=True, only_using_semantic_global_tokens=True, use_batch_semantic_feature=False):
190 | super(Renderer, self).__init__()
191 |
192 | self.nb_view = nb_view
193 | self.nb_class = nb_class
194 | self.using_semantic_global_tokens = using_semantic_global_tokens
195 | self.only_using_semantic_global_tokens = only_using_semantic_global_tokens
196 | self.dim = 32
197 |
198 | if use_batch_semantic_feature:
199 | self.nb_class = self.nb_class * 9
200 | else:
201 | self.nb_class = self.nb_class
202 |
203 | self.semantic_token_gen = nn.Linear(1 + self.nb_class, self.dim)
204 |
205 | self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim)
206 |
207 | ## Self-Attention Settings
208 | d_inner = self.dim
209 | n_head = 4
210 | d_k = self.dim // n_head
211 | d_v = self.dim // n_head
212 | num_layers = 4
213 | self.attn_layers = nn.ModuleList(
214 | [
215 | EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
216 | for i in range(num_layers)
217 | ]
218 | )
219 |
220 | self.semantic_attn_layers = nn.ModuleList(
221 | [
222 | EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
223 | for i in range(num_layers)
224 | ]
225 | )
226 |
227 | # +1 because we add the mean and variance of input features as global features
228 | if using_semantic_global_tokens and only_using_semantic_global_tokens:
229 | self.semantic_dim = self.dim
230 | ## Processing the mean and variance of semantic features
231 | self.semantic_var_mean_fc1 = nn.Linear(self.nb_class*2, self.dim)
232 | self.semantic_var_mean_fc2 = nn.Linear(self.dim, self.dim)
233 | elif using_semantic_global_tokens:
234 | self.semantic_dim = self.dim * (nb_view + 1)
235 | ## Processing the mean and variance of semantic features
236 | self.semantic_var_mean_fc1 = nn.Linear(self.nb_class*2, self.dim)
237 | self.semantic_var_mean_fc2 = nn.Linear(self.dim, self.dim)
238 | else:
239 | self.semantic_dim = self.dim * nb_view
240 |
241 | self.semantic_fc1 = nn.Linear(self.semantic_dim, self.semantic_dim)
242 | self.semantic_fc2 = nn.Linear(self.semantic_dim, self.semantic_dim // 2)
243 | self.semantic_fc3 = nn.Linear(self.semantic_dim // 2, nb_class)
244 |
245 | ## Processing the mean and variance of input features
246 | self.var_mean_fc1 = nn.Linear(16, self.dim)
247 | self.var_mean_fc2 = nn.Linear(self.dim, self.dim)
248 |
249 |
250 | ## Setting mask of var_mean always enabled
251 | self.var_mean_mask = torch.tensor([1])
252 | self.var_mean_mask.requires_grad = False
253 |
254 | ## For aggregating data along ray samples
255 | self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray)
256 |
257 | self.sigma_fc1 = nn.Linear(self.dim, self.dim)
258 | self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2)
259 | self.sigma_fc3 = nn.Linear(self.dim // 2, 1)
260 |
261 |
262 | self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim)
263 | self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
264 | self.rgb_fc3 = nn.Linear(self.dim // 2, 1)
265 |
266 | ## Initialization
267 | self.sigma_fc1.apply(weights_init)
268 | self.sigma_fc2.apply(weights_init)
269 | self.sigma_fc3.apply(weights_init)
270 | self.rgb_fc1.apply(weights_init)
271 | self.rgb_fc2.apply(weights_init)
272 | self.rgb_fc3.apply(weights_init)
273 |
274 | def forward(self, viewdirs, feat, occ_masks, middle_pts_mask):
275 | ## Viewing samples regardless of batch or ray
276 | N, S, V = feat.shape[:3]
277 | feat = feat.view(-1, *feat.shape[2:])
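        # The last dimension packs, per source view (as sliced below):
        #   [0:24] volume features, [24:32] 2D image features, [32:35] RGB colors,
        #   [35:-1] semantic features, [-1] a visibility mask.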
278 | v_feat = feat[..., :24]
279 | s_feat = feat[..., 24 : 24 + 8]
280 | colors = feat[..., 24 + 8 : 24 + 8 + 3]
281 | semantic_feat = feat[..., 24 + 8 + 3 : -1]
282 | vis_mask = feat[..., -1:].detach()
283 |
284 | occ_masks = occ_masks.view(-1, *occ_masks.shape[2:])
285 | viewdirs = viewdirs.view(-1, *viewdirs.shape[2:])
286 |
287 | ## Mean and variance of 2D features provide view-independent tokens
288 | var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True)
289 | var_mean = torch.cat(var_mean, dim=-1)
290 | var_mean = F.elu(self.var_mean_fc1(var_mean))
291 | var_mean = F.elu(self.var_mean_fc2(var_mean))
292 |
293 | ## Converting the input features to tokens (view-dependent) before self-attention
294 | tokens = F.elu(
295 | self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
296 | )
297 | tokens = torch.cat([tokens, var_mean], dim=1)
298 |
299 |         # middle_pts_mask keeps only the samples around the predicted surface depth, so the semantic branch reasons over those points only
300 |
301 | if self.using_semantic_global_tokens:
302 | semantic_var_mean = torch.var_mean(semantic_feat[middle_pts_mask.view(-1)], dim=1, unbiased=False, keepdim=True)
303 | semantic_var_mean = torch.cat(semantic_var_mean, dim=-1)
304 | semantic_var_mean = F.elu(self.semantic_var_mean_fc1(semantic_var_mean))
305 | semantic_var_mean = F.elu(self.semantic_var_mean_fc2(semantic_var_mean))
306 | # (N_rays, 1, views, feat_dim)
307 | semantic_tokens = F.elu(
308 | self.semantic_token_gen(torch.cat([semantic_feat[middle_pts_mask.view(-1)], vis_mask[middle_pts_mask.view(-1)]], dim=-1))
309 | )
310 |
311 | if self.using_semantic_global_tokens:
312 | semantic_tokens = torch.cat([semantic_tokens, semantic_var_mean], dim=1)
313 |
314 | ## Adding a new channel to mask for var_mean
315 | vis_mask = torch.cat(
316 | [vis_mask, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, -1).to(vis_mask.device)], dim=1
317 | )
318 |         ## If a point is not visible in any source view, force its mask to be enabled
319 | vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)
320 |
321 |         ## Apply occ_masks, while remembering whether the point had any visibility before that
322 | mask_cloned = vis_mask.clone()
323 | vis_mask[:, :-1] *= occ_masks
324 | vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)
325 | masks = vis_mask * mask_cloned
326 |
327 | ## Performing self-attention
328 | for layer in self.attn_layers:
329 | tokens, _ = layer(tokens, masks)
330 |
331 | for semantic_layer in self.semantic_attn_layers:
332 | if self.using_semantic_global_tokens:
333 |                 # masks has shape (N_rays*N_points, nb_views+1, 1); the extra entry comes from var_mean_mask and here also covers the global semantic token
334 | semantic_tokens, _ = semantic_layer(semantic_tokens, masks[middle_pts_mask.view(-1)])
335 | else:
336 | semantic_tokens, _ = semantic_layer(semantic_tokens, masks[middle_pts_mask.view(-1)][:, :-1])
337 |
338 | ## Predicting sigma with an Auto-Encoder and MLP
339 | sigma_tokens = tokens[:, -1:]
340 | sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2)
341 | sigma_tokens = self.auto_enc(sigma_tokens)
342 | sigma_tokens = sigma_tokens.transpose(1, 2).reshape(N * S, 1, self.dim)
343 |
344 | sigma_tokens_ = F.elu(self.sigma_fc1(sigma_tokens))
345 | sigma_tokens_ = F.elu(self.sigma_fc2(sigma_tokens_))
346 | # sigma shape (N_rays*N_points, 1)
347 | sigma = torch.relu(self.sigma_fc3(sigma_tokens_[:, 0]))
348 |
349 | if self.using_semantic_global_tokens and self.only_using_semantic_global_tokens:
350 | semantic_global_tokens = semantic_tokens[:, -1:]
351 | elif self.using_semantic_global_tokens:
352 | semantic_global_tokens = semantic_tokens.reshape(-1, self.semantic_dim)
353 | else:
354 | semantic_global_tokens = semantic_tokens.reshape(-1, self.semantic_dim)
355 | semantic_tokens_ = F.elu(self.semantic_fc1(semantic_global_tokens))
356 | semantic_tokens_ = F.elu(self.semantic_fc2(semantic_tokens_))
357 | semantic_tokens_ = torch.relu(self.semantic_fc3(semantic_tokens_))
358 |
359 | semantic = semantic_tokens_.reshape(N, -1).unsqueeze(1)
360 |
361 | ## Concatenating positional encodings and predicting RGB weights
362 | rgb_tokens = torch.cat([tokens[:, :-1], viewdirs], dim=-1)
363 | rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens))
364 | rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens))
365 | rgb_w = self.rgb_fc3(rgb_tokens)
366 | rgb_w = masked_softmax(rgb_w, masks[:, :-1], dim=1)
367 |
368 | rgb = (colors * rgb_w).sum(1)
369 |
370 | outputs = torch.cat([rgb, sigma], -1)
371 | outputs = outputs.reshape(N, S, -1)
372 |
373 | return outputs, semantic
374 |
375 |
376 | class Semantic_predictor(nn.Module):
377 | def __init__(self, nb_view=6, nb_class=0):
378 | super(Semantic_predictor, self).__init__()
379 | self.nb_class = nb_class
380 | self.dim = 32
381 | # self.attn_token_gen = nn.Linear(24 + 1 + self.nb_class, self.dim)
382 | self.attn_token_gen = nn.Linear(1 + self.nb_class, self.dim)
383 | self.semantic_dim = self.dim * nb_view
384 |
385 |         # Self-attention settings. This is cross-view attention over a single point, which corresponds to one pixel in the target view.
386 | d_inner = self.dim
387 | n_head = 4
388 | d_k = self.dim // n_head
389 | d_v = self.dim // n_head
390 | num_layers = 4
391 | self.attn_layers = nn.ModuleList(
392 | [
393 | EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
394 | for i in range(num_layers)
395 | ]
396 | )
397 | self.semantic_fc1 = nn.Linear(self.semantic_dim, self.semantic_dim)
398 | self.semantic_fc2 = nn.Linear(self.semantic_dim, self.semantic_dim // 2)
399 | self.semantic_fc3 = nn.Linear(self.semantic_dim // 2, nb_class)
400 |
401 | def forward(self, feat, occ_masks):
402 | if feat.dim() == 3:
403 | feat = feat.unsqueeze(1)
404 | if occ_masks.dim() == 3:
405 | occ_masks = occ_masks.unsqueeze(1)
406 | N, S, V, C = feat.shape # (num_rays, num_samples, num_views, feat_dim), S should be 1 here
407 |
408 | feat = feat.view(-1, *feat.shape[2:]) # (num_rays * num_samples, num_views, feat_dim)
409 | v_feat = feat[..., :24]
410 | s_feat = feat[..., 24 : 24 + self.nb_class]
411 | colors = feat[..., 24 + self.nb_class : -1]
412 | vis_mask = feat[..., -1:].detach()
413 |
414 | occ_masks = occ_masks.view(-1, *occ_masks.shape[2:])
415 |
416 | tokens = F.elu(
417 | # self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
418 | self.attn_token_gen(torch.cat([vis_mask, s_feat], dim=-1))
419 | )
420 |
421 |         ## If a point is not visible in any source view, force its mask to be enabled
422 | vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 0, 1)
423 |
424 |         ## Apply occ_masks, while remembering whether the point had any visibility before that
425 | mask_cloned = vis_mask.clone()
426 | vis_mask *= occ_masks
427 | vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 0, 1)
428 | masks = vis_mask * mask_cloned
429 |
430 | ## Performing self-attention on source view features,
431 | for layer in self.attn_layers:
432 | tokens, _ = layer(tokens, masks)
433 | # tokens, _ = layer(tokens, vis_mask)
434 |
435 | ## Predicting semantic with MLP
436 | ## tokens shape: (N*S, V, dim), S = 1
437 | tokens = tokens.reshape(N, V*self.dim)
438 | semantic_tokens_ = F.elu(self.semantic_fc1(tokens))
439 | semantic_tokens_ = F.elu(self.semantic_fc2(semantic_tokens_))
440 | semantic_tokens_ = torch.relu(self.semantic_fc3(semantic_tokens_))
441 |
442 | semantic = semantic_tokens_.reshape(N, S, -1)
443 |
444 | return semantic
--------------------------------------------------------------------------------
/utils/depth_loss.py:
--------------------------------------------------------------------------------
1 | # reference: RC-MVSNet
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | def inverse_warping(img, left_cam, right_cam, depth):
9 | # img: [batch_size, height, width, channels]
10 |
11 | # cameras (K, R, t)
12 | # print('left_cam: {}'.format(left_cam.shape))
13 | R_left = left_cam[:, 0:1, 0:3, 0:3] # [B, 1, 3, 3]
14 | R_right = right_cam[:, 0:1, 0:3, 0:3] # [B, 1, 3, 3]
15 | t_left = left_cam[:, 0:1, 0:3, 3:4] # [B, 1, 3, 1]
16 | t_right = right_cam[:, 0:1, 0:3, 3:4] # [B, 1, 3, 1]
17 | K_left = left_cam[:, 1:2, 0:3, 0:3] # [B, 1, 3, 3]
18 | K_right = right_cam[:, 1:2, 0:3, 0:3] # [B, 1, 3, 3]
19 |
20 | K_left = K_left.squeeze(1) # [B, 3, 3]
21 | K_left_inv = torch.inverse(K_left) # [B, 3, 3]
22 | R_left_trans = R_left.squeeze(1).permute(0, 2, 1) # [B, 3, 3]
23 | R_right_trans = R_right.squeeze(1).permute(0, 2, 1) # [B, 3, 3]
24 |
25 | R_left = R_left.squeeze(1)
26 | t_left = t_left.squeeze(1)
27 | R_right = R_right.squeeze(1)
28 | t_right = t_right.squeeze(1)
29 |
30 | ## estimate egomotion by inverse composing R1,R2 and t1,t2
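    # Assuming world-to-camera extrinsics (x_cam = R x_world + t), the left-to-right
    # relative motion is R_rel = R_right @ R_left^T and t_rel = t_right - R_rel @ t_left,
    # i.e. T_rel = T_right @ T_left^-1.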
31 | R_rel = torch.matmul(R_right, R_left_trans) # [B, 3, 3]
32 | t_rel = t_right - torch.matmul(R_rel, t_left) # [B, 3, 1]
33 | ## now convert R and t to transform mat, as in SFMlearner
34 | batch_size = R_left.shape[0]
35 | # filler = torch.Tensor([0.0, 0.0, 0.0, 1.0]).to(device).reshape(1, 1, 4) # [1, 1, 4]
36 | filler = torch.Tensor([0.0, 0.0, 0.0, 1.0]).cuda().reshape(1, 1, 4) # [1, 1, 4]
37 | filler = filler.repeat(batch_size, 1, 1) # [B, 1, 4]
38 | transform_mat = torch.cat([R_rel, t_rel], dim=2) # [B, 3, 4]
39 | transform_mat = torch.cat([transform_mat.float(), filler.float()], dim=1) # [B, 4, 4]
40 | # print(img.shape)
41 | batch_size, img_height, img_width, _ = img.shape
42 | # print(depth.shape)
43 | # print('depth: {}'.format(depth.shape))
44 | depth = depth.reshape(batch_size, 1, img_height * img_width) # [batch_size, 1, height * width]
45 |
46 | grid = _meshgrid_abs(img_height, img_width) # [3, height * width]
47 | grid = grid.unsqueeze(0).repeat(batch_size, 1, 1) # [batch_size, 3, height * width]
48 | cam_coords = _pixel2cam(depth, grid, K_left_inv) # [batch_size, 3, height * width]
49 | # ones = torch.ones([batch_size, 1, img_height * img_width], device=device) # [batch_size, 1, height * width]
50 | ones = torch.ones([batch_size, 1, img_height * img_width]).cuda() # [batch_size, 1, height * width]
51 | cam_coords_hom = torch.cat([cam_coords, ones], dim=1) # [batch_size, 4, height * width]
52 |
53 | # Get projection matrix for target camera frame to source pixel frame
54 | # hom_filler = torch.Tensor([0.0, 0.0, 0.0, 1.0]).to(device).reshape(1, 1, 4) # [1, 1, 4]
55 | hom_filler = torch.Tensor([0.0, 0.0, 0.0, 1.0]).cuda().reshape(1, 1, 4) # [1, 1, 4]
56 | hom_filler = hom_filler.repeat(batch_size, 1, 1) # [B, 1, 4]
57 | intrinsic_mat_hom = torch.cat([K_left.float(), torch.zeros([batch_size, 3, 1]).cuda()], dim=2) # [B, 3, 4]
58 | intrinsic_mat_hom = torch.cat([intrinsic_mat_hom, hom_filler], dim=1) # [B, 4, 4]
59 | proj_target_cam_to_source_pixel = torch.matmul(intrinsic_mat_hom, transform_mat) # [B, 4, 4]
60 | source_pixel_coords = _cam2pixel(cam_coords_hom, proj_target_cam_to_source_pixel) # [batch_size, 2, height * width]
61 | source_pixel_coords = source_pixel_coords.reshape(batch_size, 2, img_height, img_width) # [batch_size, 2, height, width]
62 | source_pixel_coords = source_pixel_coords.permute(0, 2, 3, 1) # [batch_size, height, width, 2]
63 | warped_right, mask = _spatial_transformer(img, source_pixel_coords)
64 | return warped_right, mask
65 |
66 |
67 | def _meshgrid_abs(height, width):
68 | """Meshgrid in the absolute coordinates."""
69 | x_t = torch.matmul(
70 | torch.ones([height, 1]),
71 | torch.linspace(-1.0, 1.0, width).unsqueeze(1).permute(1, 0)
72 | ) # [height, width]
73 | y_t = torch.matmul(
74 | torch.linspace(-1.0, 1.0, height).unsqueeze(1),
75 | torch.ones([1, width])
76 | )
77 | x_t = (x_t + 1.0) * 0.5 * (width - 1)
78 | y_t = (y_t + 1.0) * 0.5 * (height - 1)
79 | x_t_flat = x_t.reshape(1, -1)
80 | y_t_flat = y_t.reshape(1, -1)
81 | ones = torch.ones_like(x_t_flat)
82 | grid = torch.cat([x_t_flat, y_t_flat, ones], dim=0) # [3, height * width]
83 | # return grid.to(device)
84 | return grid.cuda()
85 |
86 |
87 | def _pixel2cam(depth, pixel_coords, intrinsic_mat_inv):
88 | """Transform coordinates in the pixel frame to the camera frame."""
89 | cam_coords = torch.matmul(intrinsic_mat_inv.float(), pixel_coords.float()) * depth.float()
90 | return cam_coords
91 |
92 |
93 | def _cam2pixel(cam_coords, proj_c2p):
94 | """Transform coordinates in the camera frame to the pixel frame."""
95 | pcoords = torch.matmul(proj_c2p, cam_coords) # [batch_size, 4, height * width]
96 | x = pcoords[:, 0:1, :] # [batch_size, 1, height * width]
97 | y = pcoords[:, 1:2, :] # [batch_size, 1, height * width]
98 | z = pcoords[:, 2:3, :] # [batch_size, 1, height * width]
99 | x_norm = x / (z + 1e-10)
100 | y_norm = y / (z + 1e-10)
101 | pixel_coords = torch.cat([x_norm, y_norm], dim=1)
102 | return pixel_coords # [batch_size, 2, height * width]
103 |
104 |
105 | def _spatial_transformer(img, coords):
106 |     """A wrapper over _bilinear_sample(), taking absolute coords as input."""
107 | # img: [B, H, W, C]
108 | img_height = img.shape[1]
109 | img_width = img.shape[2]
110 | px = coords[:, :, :, :1] # [batch_size, height, width, 1]
111 | py = coords[:, :, :, 1:] # [batch_size, height, width, 1]
112 | # Normalize coordinates to [-1, 1] to send to _bilinear_sampler.
113 | px = px / (img_width - 1) * 2.0 - 1.0 # [batch_size, height, width, 1]
114 | py = py / (img_height - 1) * 2.0 - 1.0 # [batch_size, height, width, 1]
115 | output_img, mask = _bilinear_sample(img, px, py)
116 | return output_img, mask
117 |
118 |
119 | def _bilinear_sample(im, x, y, name='bilinear_sampler'):
120 | """Perform bilinear sampling on im given list of x, y coordinates.
121 | Implements the differentiable sampling mechanism with bilinear kernel
122 | in https://arxiv.org/abs/1506.02025.
123 | x,y are tensors specifying normalized coordinates [-1, 1] to be sampled on im.
124 | For example, (-1, -1) in (x, y) corresponds to pixel location (0, 0) in im,
125 | and (1, 1) in (x, y) corresponds to the bottom right pixel in im.
126 | Args:
127 | im: Batch of images with shape [B, h, w, channels].
128 | x: Tensor of normalized x coordinates in [-1, 1], with shape [B, h, w, 1].
129 | y: Tensor of normalized y coordinates in [-1, 1], with shape [B, h, w, 1].
130 | name: Name scope for ops.
131 | Returns:
132 | Sampled image with shape [B, h, w, channels].
133 | Principled mask with shape [B, h, w, 1], dtype:float32. A value of 1.0
134 | in the mask indicates that the corresponding coordinate in the sampled
135 | image is valid.
136 | """
137 | x = x.reshape(-1) # [batch_size * height * width]
138 | y = y.reshape(-1) # [batch_size * height * width]
139 |
140 | # Constants.
141 | batch_size, height, width, channels = im.shape
142 |
143 | x, y = x.float(), y.float()
144 | max_y = int(height - 1)
145 | max_x = int(width - 1)
146 |
147 | # Scale indices from [-1, 1] to [0, width - 1] or [0, height - 1].
148 | x = (x + 1.0) * (width - 1.0) / 2.0
149 | y = (y + 1.0) * (height - 1.0) / 2.0
150 |
151 | # Compute the coordinates of the 4 pixels to sample from.
152 | x0 = torch.floor(x).int()
153 | x1 = x0 + 1
154 | y0 = torch.floor(y).int()
155 | y1 = y0 + 1
156 |
157 |     mask = (x0 >= 0) & (x1 <= max_x) & (y0 >= 0) & (y1 <= max_y)
158 | mask = mask.float()
159 |
160 | x0 = torch.clamp(x0, 0, max_x)
161 | x1 = torch.clamp(x1, 0, max_x)
162 | y0 = torch.clamp(y0, 0, max_y)
163 | y1 = torch.clamp(y1, 0, max_y)
164 | dim2 = width
165 | dim1 = width * height
166 |
167 | # Create base index.
168 | base = torch.arange(batch_size) * dim1
169 | base = base.reshape(-1, 1)
170 | base = base.repeat(1, height * width)
171 | base = base.reshape(-1) # [batch_size * height * width]
172 | # base = base.long().to(device)
173 | base = base.long().cuda()
174 |
175 | base_y0 = base + y0.long() * dim2
176 | base_y1 = base + y1.long() * dim2
177 | idx_a = base_y0 + x0.long()
178 | idx_b = base_y1 + x0.long()
179 | idx_c = base_y0 + x1.long()
180 | idx_d = base_y1 + x1.long()
181 |
182 | # Use indices to lookup pixels in the flat image and restore channels dim.
183 | im_flat = im.reshape(-1, channels).float() # [batch_size * height * width, channels]
184 | # pixel_a = tf.gather(im_flat, idx_a)
185 | # pixel_b = tf.gather(im_flat, idx_b)
186 | # pixel_c = tf.gather(im_flat, idx_c)
187 | # pixel_d = tf.gather(im_flat, idx_d)
188 | pixel_a = im_flat[idx_a]
189 | pixel_b = im_flat[idx_b]
190 | pixel_c = im_flat[idx_c]
191 | pixel_d = im_flat[idx_d]
192 |
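    # Bilinear interpolation weights: pixel_a/b/c/d are the (x0, y0), (x0, y1),
    # (x1, y0), (x1, y1) neighbours, each weighted by the opposite partial area.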
193 | wa = (x1.float() - x) * (y1.float() - y)
194 | wb = (x1.float() - x) * (1.0 - (y1.float() - y))
195 | wc = (1.0 - (x1.float() - x)) * (y1.float() - y)
196 | wd = (1.0 - (x1.float() - x)) * (1.0 - (y1.float() - y))
197 | wa, wb, wc, wd = wa.unsqueeze(1), wb.unsqueeze(1), wc.unsqueeze(1), wd.unsqueeze(1)
198 |
199 | output = wa * pixel_a + wb * pixel_b + wc * pixel_c + wd * pixel_d
200 | output = output.reshape(batch_size, height, width, channels)
201 | mask = mask.reshape(batch_size, height, width, 1)
202 | return output, mask
203 |
204 |
205 | class SSIM(nn.Module):
206 | """Layer to compute the SSIM loss between a pair of images
207 | """
208 | def __init__(self):
209 | super(SSIM, self).__init__()
210 | self.mu_x_pool = nn.AvgPool2d(3, 1)
211 | self.mu_y_pool = nn.AvgPool2d(3, 1)
212 | self.sig_x_pool = nn.AvgPool2d(3, 1)
213 | self.sig_y_pool = nn.AvgPool2d(3, 1)
214 | self.sig_xy_pool = nn.AvgPool2d(3, 1)
215 | self.mask_pool = nn.AvgPool2d(3, 1)
216 | # self.refl = nn.ReflectionPad2d(1)
217 |
218 | self.C1 = 0.01 ** 2
219 | self.C2 = 0.03 ** 2
220 |
221 | def forward(self, x, y, mask):
222 | # print('mask: {}'.format(mask.shape))
223 | # print('x: {}'.format(x.shape))
224 | # print('y: {}'.format(y.shape))
225 | x = x.permute(0, 3, 1, 2) # [B, H, W, C] --> [B, C, H, W]
226 | y = y.permute(0, 3, 1, 2)
227 | mask = mask.permute(0, 3, 1, 2)
228 |
229 | # x = self.refl(x)
230 | # y = self.refl(y)
231 | mu_x = self.mu_x_pool(x)
232 | mu_y = self.mu_y_pool(y)
233 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
234 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
235 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
236 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
237 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
238 | SSIM_mask = self.mask_pool(mask)
239 | output = SSIM_mask * torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
240 | return output.permute(0, 2, 3, 1) # [B, C, H, W] --> [B, H, W, C]
241 |
242 |
243 | def gradient_x(img):
244 | return img[:, :, :-1, :] - img[:, :, 1:, :]
245 |
246 | def gradient_y(img):
247 | return img[:, :-1, :, :] - img[:, 1:, :, :]
248 |
249 | def gradient(pred):
250 | D_dy = pred[:, 1:, :, :] - pred[:, :-1, :, :]
251 | D_dx = pred[:, :, 1:, :] - pred[:, :, :-1, :]
252 | return D_dx, D_dy
253 |
254 |
255 | def depth_smoothness(depth, img,lambda_wt=1):
256 | """Computes image-aware depth smoothness loss."""
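    # loss = mean(|d(depth)/dx| * exp(-lambda * mean_c|d(img)/dx|))
    #      + mean(|d(depth)/dy| * exp(-lambda * mean_c|d(img)/dy|)),
    # so depth gradients are penalised less across strong image edges.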
257 | # print('depth: {} img: {}'.format(depth.shape, img.shape))
258 | depth_dx = gradient_x(depth)
259 | depth_dy = gradient_y(depth)
260 | image_dx = gradient_x(img)
261 | image_dy = gradient_y(img)
262 | weights_x = torch.exp(-(lambda_wt * torch.mean(torch.abs(image_dx), 3, keepdim=True)))
263 | weights_y = torch.exp(-(lambda_wt * torch.mean(torch.abs(image_dy), 3, keepdim=True)))
264 | # print('depth_dx: {} weights_x: {}'.format(depth_dx.shape, weights_x.shape))
265 | # print('depth_dy: {} weights_y: {}'.format(depth_dy.shape, weights_y.shape))
266 | smoothness_x = depth_dx * weights_x
267 | smoothness_y = depth_dy * weights_y
268 | return torch.mean(torch.abs(smoothness_x)) + torch.mean(torch.abs(smoothness_y))
269 |
270 |
271 | def compute_reconstr_loss(warped, ref, mask, simple=True):
272 | if simple:
273 | return F.smooth_l1_loss(warped*mask, ref*mask, reduction='mean')
274 | else:
275 | alpha = 0.5
276 | ref_dx, ref_dy = gradient(ref * mask)
277 | warped_dx, warped_dy = gradient(warped * mask)
278 | photo_loss = F.smooth_l1_loss(warped*mask, ref*mask, reduction='mean')
279 | grad_loss = F.smooth_l1_loss(warped_dx, ref_dx, reduction='mean') + \
280 | F.smooth_l1_loss(warped_dy, ref_dy, reduction='mean')
281 | return (1 - alpha) * photo_loss + alpha * grad_loss
282 |
283 | class UnSupLoss(nn.Module):
284 | def __init__(self):
285 | super(UnSupLoss, self).__init__()
286 | self.ssim = SSIM()
287 |
288 | def forward(self, imgs, cams, depth, stage_idx):
289 | # print('imgs: {}'.format(imgs.shape))
290 | # print('cams: {}'.format(cams.shape))
291 | # print('depth: {}'.format(depth.shape))
292 |
293 | imgs = torch.unbind(imgs, 1)
294 | cams = torch.unbind(cams, 1)
295 | assert len(imgs) == len(cams), "Different number of images and projection matrices"
296 | img_height, img_width = imgs[0].shape[2], imgs[0].shape[3]
297 | num_views = len(imgs)
298 |
299 | ref_img = imgs[0]
300 |
301 | if stage_idx == 2:
302 | ref_img = F.interpolate(ref_img, scale_factor=0.25,recompute_scale_factor=True)
303 | elif stage_idx == 1:
304 | ref_img = F.interpolate(ref_img, scale_factor=0.5,recompute_scale_factor=True)
305 | else:
306 | pass
307 | ref_img = ref_img.permute(0, 2, 3, 1) # [B, C, H, W] --> [B, H, W, C]
308 | ref_cam = cams[0]
309 | # print('ref_cam: {}'.format(ref_cam.shape))
310 |
311 | # depth reshape
312 | # depth = depth.unsqueeze(dim=1) # [B, 1, H, W]
313 | # depth = F.interpolate(depth, size=[img_height, img_width])
314 | # depth = depth.squeeze(dim=1) # [B, H, W]
315 |
316 | self.reconstr_loss = 0
317 | self.ssim_loss = 0
318 | self.smooth_loss = 0
319 |
320 | warped_img_list = []
321 | mask_list = []
322 | reprojection_losses = []
323 | for view in range(1, num_views):
324 | view_img = imgs[view]
325 | view_cam = cams[view]
326 | # print('view_cam: {}'.format(view_cam.shape))
327 | # view_img = F.interpolate(view_img, scale_factor=0.25, mode='bilinear')
328 | if stage_idx == 2:
329 | view_img = F.interpolate(view_img, scale_factor=0.25,recompute_scale_factor=True)
330 | elif stage_idx == 1:
331 | view_img = F.interpolate(view_img, scale_factor=0.5,recompute_scale_factor=True)
332 | else:
333 | pass
334 | view_img = view_img.permute(0, 2, 3, 1) # [B, C, H, W] --> [B, H, W, C]
335 | # warp view_img to the ref_img using the dmap of the ref_img
336 | warped_img, mask = inverse_warping(view_img, ref_cam, view_cam, depth)
337 | warped_img_list.append(warped_img)
338 | mask_list.append(mask)
339 |
340 | reconstr_loss = compute_reconstr_loss(warped_img, ref_img, mask, simple=False)
341 |             valid_mask = 1 - mask  # out-of-view pixels get a large penalty below, so the top-k selection ignores them
342 | reprojection_losses.append(reconstr_loss + 1e4 * valid_mask)
343 |
344 |             ## SSIM loss ##
345 | if view < 3:
346 | self.ssim_loss += torch.mean(self.ssim(ref_img, warped_img, mask))
347 |
348 |             ## smooth loss ##
349 | self.smooth_loss += depth_smoothness(depth.unsqueeze(dim=-1), ref_img, 1.0)
350 |
351 | # top-k operates along the last dimension, so swap the axes accordingly
352 | reprojection_volume = torch.stack(reprojection_losses).permute(1, 2, 3, 4, 0)
353 | # print('reprojection_volume: {}'.format(reprojection_volume.shape))
354 | # by default, it'll return top-k largest entries, hence sorted=False to get smallest entries
355 | # top_vals, top_inds = torch.topk(torch.neg(reprojection_volume), k=3, sorted=False)
356 | top_vals, top_inds = torch.topk(torch.neg(reprojection_volume), k=1, sorted=False)
357 | top_vals = torch.neg(top_vals)
358 | # top_mask = top_vals < (1e4 * torch.ones_like(top_vals, device=device))
359 | top_mask = top_vals < (1e4 * torch.ones_like(top_vals).cuda())
360 | top_mask = top_mask.float()
361 | top_vals = torch.mul(top_vals, top_mask)
362 | # print('top_vals: {}'.format(top_vals.shape))
363 |
364 | self.reconstr_loss = torch.mean(torch.sum(top_vals, dim=-1))
365 | self.unsup_loss = 12 * self.reconstr_loss + 6 * self.ssim_loss + 0.18 * self.smooth_loss
366 |         # following the weighting used in un_mvsnet and M3VSNet:
367 | # self.unsup_loss = (0.8 * self.reconstr_loss + 0.2 * self.ssim_loss + 0.067 * self.smooth_loss) * 15
368 | return self.unsup_loss
369 |
370 | class UnsupLossMultiStage(nn.Module):
371 | def __init__(self):
372 | super(UnsupLossMultiStage, self).__init__()
373 | self.unsup_loss = UnSupLoss()
374 |
375 | def forward(self, inputs, imgs, cams, depth_loss_weights=[2.0,1.0,0.5]):
376 | '''
377 |         inputs: dict {"level_0": (1, 8, 240, 320), "level_1": (1, 8, 120, 160), "level_2": (1, 8, 60, 80)}
378 | imgs: (1, 8, 3, 480, 640)
379 | cams: (1, 8, 3, 2, 4, 4)
380 | '''
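        # Each of the V views is treated as the reference in turn: roll(-i, 1) moves
        # view i to the front, and its estimated depth map depth_est[:, i] is
        # supervised by photometric consistency with the remaining (rolled) views.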
381 | # depth_loss_weights = kwargs.get("dlossw", None)
382 |
383 | total_loss = torch.tensor(0.0, dtype=torch.float32, device=imgs.device, requires_grad=False)
384 |
385 | scalar_outputs = {}
386 | for (stage_inputs, stage_key) in [(inputs[k], k) for k in inputs.keys() if "level_" in k]:
387 |             stage_idx = int(stage_key.replace("level_", ""))  # 0, 1, 2; level 0 is the finest (highest-resolution) stage, level 2 the coarsest
388 |
389 | depth_est = stage_inputs
390 | depth_loss = 0
391 | for i in range(depth_est.shape[1]):
392 | depth_loss += self.unsup_loss(imgs.roll(-i,1), cams.roll(-i,1)[:,:,stage_idx], depth_est[:, i], stage_idx)
393 |
394 |
395 | if depth_loss_weights is not None:
396 | total_loss += depth_loss_weights[stage_idx] * depth_loss
397 | else:
398 | total_loss += 1.0 * depth_loss
399 |
400 | scalar_outputs["depth_loss_stage{}".format(stage_idx + 1)] = depth_loss
401 | scalar_outputs["reconstr_loss_stage{}".format(stage_idx + 1)] = self.unsup_loss.reconstr_loss
402 | scalar_outputs["ssim_loss_stage{}".format(stage_idx + 1)] = self.unsup_loss.ssim_loss
403 | scalar_outputs["smooth_loss_stage{}".format(stage_idx + 1)] = self.unsup_loss.smooth_loss
404 |
405 | return total_loss, scalar_outputs
--------------------------------------------------------------------------------
/utils/scannet_utils.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import glob
3 | import os
4 | import torch
5 | import cv2
6 | import numpy as np
7 | import pandas as pd
8 | from PIL import Image
9 | from skimage.io import imread
10 | from natsort import natsorted
11 |
12 | from utils.utils import downsample_gaussian_blur, pose_inverse
13 |
14 |
15 | # From https://github.com/open-mmlab/mmdetection3d/blob/fcb4545ce719ac121348cab59bac9b69dd1b1b59/mmdet3d/datasets/scannet_dataset.py
16 | class PointSegClassMapping(object):
17 | """Map original semantic class to valid category ids.
18 | Map valid classes as 0~len(valid_cat_ids)-1 and
19 | others as len(valid_cat_ids).
20 | Args:
21 | valid_cat_ids (tuple[int]): A tuple of valid category.
22 | max_cat_id (int, optional): The max possible cat_id in input
23 | segmentation mask. Defaults to 40.
24 | """
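    # Example: valid_cat_ids=(1, 3), max_cat_id=4 gives cat_id2class = [2, 0, 2, 1, 2],
    # i.e. ids 1 and 3 map to 0 and 1 and every other id maps to len(valid_cat_ids) = 2
    # (before the additional re-indexing applied in __init__ below).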
25 |
26 | def __init__(self, valid_cat_ids, max_cat_id=40):
27 |         assert max_cat_id >= np.max(valid_cat_ids), \
28 |             'max_cat_id should be greater than or equal to the maximum id in valid_cat_ids'
29 |
30 | self.valid_cat_ids = valid_cat_ids
31 | self.max_cat_id = int(max_cat_id)
32 |
33 | # build cat_id to class index mapping
34 | neg_cls = len(valid_cat_ids)
35 | self.cat_id2class = np.ones(
36 | self.max_cat_id + 1, dtype=int) * neg_cls
37 | for cls_idx, cat_id in enumerate(valid_cat_ids):
38 | self.cat_id2class[cat_id] = cls_idx
39 | for i in range(self.cat_id2class.shape[0]):
40 | value = self.cat_id2class[i]
41 | if value == 19:
42 | self.cat_id2class[i] = 6
43 | elif value == 20:
44 | self.cat_id2class[i] = 19
45 |
46 | def __call__(self, seg_label):
47 |         """Map original semantic class ids in a segmentation mask to valid category ids.
48 |         Args:
49 |             seg_label (np.ndarray): Segmentation mask containing original category ids.
50 |         Returns:
51 |             np.ndarray: Mask of the same shape with ids mapped to
52 |                 0..len(valid_cat_ids)-1 for valid categories and to
53 |                 len(valid_cat_ids) for all others.
54 |         """
55 | seg_label = np.clip(seg_label, 0, self.max_cat_id)
56 | return self.cat_id2class[seg_label]
57 |
58 |
59 | class BaseDatabase(abc.ABC):
60 | def __init__(self, database_name):
61 | self.database_name = database_name
62 |
63 | @abc.abstractmethod
64 | def get_image(self, img_id):
65 | pass
66 |
67 | @abc.abstractmethod
68 | def get_K(self, img_id):
69 | pass
70 |
71 | @abc.abstractmethod
72 | def get_pose(self, img_id):
73 | pass
74 |
75 | @abc.abstractmethod
76 | def get_img_ids(self, check_depth_exist=False):
77 | pass
78 |
79 | @abc.abstractmethod
80 | def get_bbox(self, img_id):
81 | pass
82 |
83 | @abc.abstractmethod
84 | def get_depth(self, img_id):
85 | pass
86 |
87 | @abc.abstractmethod
88 | def get_mask(self, img_id):
89 | pass
90 |
91 | @abc.abstractmethod
92 | def get_depth_range(self, img_id):
93 | pass
94 |
95 |
96 |
97 | class ScannetDatabase(BaseDatabase):
98 | def __init__(self, database_name, root_dir='data/scannet'):
99 | super().__init__(database_name)
100 | _, self.scene_name, background_size = database_name.split('/')
101 | background, image_size = background_size.split('_')
102 | image_size = int(image_size)
103 | self.image_size = image_size
104 | self.background = background
105 | self.root_dir = f'{root_dir}/{self.scene_name}'
106 | self.ratio = image_size / 1296
107 | self.h, self.w = int(self.ratio*972), int(image_size)
108 |
109 | rgb_paths = [x for x in glob.glob(os.path.join(
110 | self.root_dir, "color", "*")) if (x.endswith(".jpg") or x.endswith(".png"))]
111 | rgb_paths = sorted(rgb_paths)
112 |
113 | K = np.loadtxt(
114 | f'{self.root_dir}/intrinsic/intrinsic_color.txt').reshape([4, 4])[:3, :3]
115 | # After resize, we need to change the intrinsic matrix
116 | K[:2, :] *= self.ratio
117 | self.K = K
118 |
119 | self.img_ids = []
120 | for i, rgb_path in enumerate(rgb_paths):
121 | pose = self.get_pose(i)
122 | if np.isinf(pose).any() or np.isnan(pose).any():
123 | continue
124 | self.img_ids.append(f'{i}')
125 |
126 | self.img_id2imgs = {}
127 |         # mapping from scannet class id to nyu40 class id
128 | # mapping_file = 'data/scannet/scannetv2-labels.combined.tsv'
129 | mapping_file = os.path.join(root_dir, 'scannetv2-labels.combined.tsv')
130 | mapping_file = pd.read_csv(mapping_file, sep='\t', header=0)
131 | scan_ids = mapping_file['id'].values
132 | nyu40_ids = mapping_file['nyu40id'].values
133 | scan2nyu = np.zeros(max(scan_ids) + 1, dtype=np.int32)
134 | for i in range(len(scan_ids)):
135 | scan2nyu[scan_ids[i]] = nyu40_ids[i]
136 | self.scan2nyu = scan2nyu
137 | self.label_mapping = PointSegClassMapping(
138 | valid_cat_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
139 | 11, 12, 14, 16, 24, 28, 33, 34, 36, 39],
140 | max_cat_id=40
141 | )
142 |
143 | def get_image(self, img_id):
144 | if img_id in self.img_id2imgs:
145 | return self.img_id2imgs[img_id]
146 | img = imread(os.path.join(
147 | self.root_dir, 'color', f'{int(img_id)}.jpg'))
148 | if self.w != 1296:
149 | img = cv2.resize(downsample_gaussian_blur(
150 | img, self.ratio), (self.w, self.h), interpolation=cv2.INTER_LINEAR)
151 |
152 | return img
153 |
154 | def get_K(self, img_id):
155 | return self.K.astype(np.float32)
156 |
157 | def get_pose(self, img_id):
158 | transf = np.diag(np.asarray([1, -1, -1, 1]))
159 | pose = np.loadtxt(
160 | f'{self.root_dir}/pose/{int(img_id)}.txt').reshape([4, 4])
161 | # pose = transf @ pose
162 | # c2w in files, change to w2c
163 | # pose = pose_inverse(pose)
164 | return pose.copy()
165 |
166 | def get_img_ids(self, check_depth_exist=False):
167 | return self.img_ids
168 |
169 | def get_bbox(self, img_id):
170 | raise NotImplementedError
171 |
172 | def get_depth(self, img_id):
173 | h, w, _ = self.get_image(img_id).shape
174 | img = Image.open(f'{self.root_dir}/depth/{int(img_id)}.png')
175 | depth = np.asarray(img, dtype=np.float32) / 1000.0 # mm -> m
176 | # depth = np.asarray(img, dtype=np.float32)
177 | depth = np.ascontiguousarray(depth, dtype=np.float32)
178 | depth = cv2.resize(depth, (w, h), interpolation=cv2.INTER_NEAREST)
179 | return depth
180 |
181 | def get_mask(self, img_id):
182 | h, w, _ = self.get_image(img_id).shape
183 | return np.ones([h, w], bool)
184 |
185 | def get_depth_range(self, img_id):
186 | return np.asarray((0.1, 10.0), np.float32)
187 |
188 | def get_label(self, img_id):
189 | h, w, _ = self.get_image(img_id).shape
190 | img = Image.open(f'{self.root_dir}/label-filt/{int(img_id)}.png')
191 | label = np.asarray(img, dtype=np.int32)
192 | label = np.ascontiguousarray(label)
193 | label = cv2.resize(label, (w, h), interpolation=cv2.INTER_NEAREST)
194 | label = label.astype(np.int32)
195 | label = self.scan2nyu[label]
196 | return self.label_mapping(label)
197 |
198 | def parse_database_name(database_name, root_dir) -> BaseDatabase:
199 | name2database = {
200 | 'scannet': ScannetDatabase,
201 | 'replica': ReplicaDatabase,
202 | }
203 | database_type = database_name.split('/')[0]
204 | if database_type in name2database:
205 | return name2database[database_type](database_name, root_dir)
206 | else:
207 | raise NotImplementedError
208 |
209 | def get_database_split(database: BaseDatabase, split_type='val'):
210 | database_name = database.database_name
211 | if split_type.startswith('val'):
212 | splits = split_type.split('_')
213 | depth_valid = not(len(splits) > 1 and splits[1] == 'all')
214 | if database_name.startswith('scannet'):
215 | img_ids = database.get_img_ids()
216 | train_ids = img_ids[:700:5]
217 | val_ids = img_ids[2:700:20]
218 | if len(val_ids) > 10:
219 | val_ids = val_ids[:10]
220 | elif database_name.startswith('replica'):
221 | img_ids = database.get_img_ids()
222 | train_ids = img_ids[:700:5]
223 | val_ids = img_ids[2:700:20]
224 | if len(val_ids) > 10:
225 | val_ids = val_ids[:10]
226 | else:
227 | raise NotImplementedError
228 | elif split_type.startswith('test'):
229 | splits = split_type.split('_')
230 | depth_valid = not(len(splits) > 1 and splits[1] == 'all')
231 | if database_name.startswith('scannet'):
232 | img_ids = database.get_img_ids()
233 | train_ids = img_ids[:700:5]
234 | val_ids = img_ids[2:700:20]
235 | if len(val_ids) > 10:
236 | val_ids = val_ids[:10]
237 | elif database_name.startswith('replica'):
238 | img_ids = database.get_img_ids()
239 | train_ids = img_ids[:700:5]
240 | val_ids = img_ids[2:700:20]
241 | if len(val_ids) > 10:
242 | val_ids = val_ids[:10]
243 | else:
244 | raise NotImplementedError
245 | elif split_type.startswith('video'):
246 | img_ids = database.get_img_ids()
247 | train_ids = img_ids[::2]
248 | val_ids = img_ids[25:-25:2]
249 | else:
250 | raise NotImplementedError
251 | print('train_ids:\n', train_ids)
252 | print('val_ids:\n', val_ids)
253 | return train_ids, val_ids
254 |
255 |
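# Sample train_ray_num pixel coordinates for ray casting: at least foreground_ratio
# of them come from the foreground mask que_mask, and the remainder is drawn from
# background pixels plus any leftover foreground pixels.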
256 | def get_coords_mask(que_mask, train_ray_num, foreground_ratio):
257 | min_pos_num = int(train_ray_num * foreground_ratio)
258 | y0, x0 = np.nonzero(que_mask)
259 | y1, x1 = np.nonzero(~que_mask)
260 | xy0 = np.stack([x0, y0], 1).astype(np.float32)
261 | xy1 = np.stack([x1, y1], 1).astype(np.float32)
262 | idx = np.arange(xy0.shape[0])
263 | np.random.shuffle(idx)
264 | xy0 = xy0[idx]
265 | coords0 = xy0[:min_pos_num]
266 |     # fill the remaining rays from background pixels (plus any leftover foreground pixels)
267 | if min_pos_num < train_ray_num:
268 | xy1 = np.concatenate([xy1, xy0[min_pos_num:]], 0)
269 | idx = np.arange(xy1.shape[0])
270 | np.random.shuffle(idx)
271 | coords1 = xy1[idx[:(train_ray_num - min_pos_num)]]
272 | coords = np.concatenate([coords0, coords1], 0)
273 | else:
274 | coords = coords0
275 | return coords
276 |
277 |
278 | def color_map_forward(rgb):
279 | return rgb.astype(np.float32) / 255
280 |
281 |
282 | def pad_img_end(img, th, tw, padding_mode='edge', constant_values=0):
283 | h, w = img.shape[:2]
284 | hp = th - h
285 | wp = tw - w
286 | if hp != 0 or wp != 0:
287 | if padding_mode == 'constant':
288 | img = np.pad(img, ((0, hp), (0, wp), (0, 0)), padding_mode, constant_values=constant_values)
289 | else:
290 | img = np.pad(img, ((0, hp), (0, wp), (0, 0)), padding_mode)
291 | return img
292 |
293 | def random_crop(ref_imgs_info, que_imgs_info, target_size):
294 | imgs = ref_imgs_info['imgs']
295 | n, _, h, w = imgs.shape
296 | out_h, out_w = target_size[0], target_size[1]
297 | if out_w >= w or out_h >= h:
298 | return ref_imgs_info
299 |
300 | center_h = np.random.randint(low=out_h // 2 + 1, high=h - out_h // 2 - 1)
301 | center_w = np.random.randint(low=out_w // 2 + 1, high=w - out_w // 2 - 1)
302 |
303 | def crop(tensor):
304 | tensor = tensor[:, :, center_h - out_h // 2:center_h + out_h // 2,
305 | center_w - out_w // 2:center_w + out_w // 2]
306 | return tensor
307 |
308 | def crop_imgs_info(imgs_info):
309 | imgs_info['imgs'] = crop(imgs_info['imgs'])
310 | if 'depth' in imgs_info: imgs_info['depth'] = crop(imgs_info['depth'])
311 | if 'true_depth' in imgs_info: imgs_info['true_depth'] = crop(imgs_info['true_depth'])
312 | if 'masks' in imgs_info: imgs_info['masks'] = crop(imgs_info['masks'])
313 |
314 | Ks = imgs_info['Ks'] # n, 3, 3
315 | h_init = center_h - out_h // 2
316 | w_init = center_w - out_w // 2
317 | Ks[:,0,2]-=w_init
318 | Ks[:,1,2]-=h_init
319 | imgs_info['Ks']=Ks
320 | return imgs_info
321 |
322 | return crop_imgs_info(ref_imgs_info), crop_imgs_info(que_imgs_info)
323 |
324 | def random_flip(ref_imgs_info,que_imgs_info):
325 | def flip(tensor):
326 | tensor = np.flip(tensor.transpose([0, 2, 3, 1]), 2) # n,h,w,3
327 | tensor = np.ascontiguousarray(tensor.transpose([0, 3, 1, 2]))
328 | return tensor
329 |
330 | def flip_imgs_info(imgs_info):
331 | imgs_info['imgs'] = flip(imgs_info['imgs'])
332 | if 'depth' in imgs_info: imgs_info['depth'] = flip(imgs_info['depth'])
333 | if 'true_depth' in imgs_info: imgs_info['true_depth'] = flip(imgs_info['true_depth'])
334 | if 'masks' in imgs_info: imgs_info['masks'] = flip(imgs_info['masks'])
335 |
336 | Ks = imgs_info['Ks'] # n, 3, 3
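        # A horizontal flip maps x -> (w - 1) - x, so the first row of K is negated
        # and the principal point is shifted by w - 1 below.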
337 | Ks[:, 0, :] *= -1
338 | w = imgs_info['imgs'].shape[-1]
339 | Ks[:, 0, 2] += w - 1
340 | imgs_info['Ks'] = Ks
341 | return imgs_info
342 |
343 | ref_imgs_info = flip_imgs_info(ref_imgs_info)
344 | que_imgs_info = flip_imgs_info(que_imgs_info)
345 | return ref_imgs_info, que_imgs_info
346 |
347 | def pad_imgs_info(ref_imgs_info,pad_interval):
348 | ref_imgs, ref_depths, ref_masks = ref_imgs_info['imgs'], ref_imgs_info['depth'], ref_imgs_info['masks']
349 | ref_depth_gt = ref_imgs_info['true_depth'] if 'true_depth' in ref_imgs_info else None
350 | rfn, _, h, w = ref_imgs.shape
351 | ph = (pad_interval - (h % pad_interval)) % pad_interval
352 | pw = (pad_interval - (w % pad_interval)) % pad_interval
353 | if ph != 0 or pw != 0:
354 | ref_imgs = np.pad(ref_imgs, ((0, 0), (0, 0), (0, ph), (0, pw)), 'reflect')
355 | ref_depths = np.pad(ref_depths, ((0, 0), (0, 0), (0, ph), (0, pw)), 'reflect')
356 | ref_masks = np.pad(ref_masks, ((0, 0), (0, 0), (0, ph), (0, pw)), 'reflect')
357 | if ref_depth_gt is not None:
358 | ref_depth_gt = np.pad(ref_depth_gt, ((0, 0), (0, 0), (0, ph), (0, pw)), 'reflect')
359 | ref_imgs_info['imgs'], ref_imgs_info['depth'], ref_imgs_info['masks'] = ref_imgs, ref_depths, ref_masks
360 | if ref_depth_gt is not None:
361 | ref_imgs_info['true_depth'] = ref_depth_gt
362 | return ref_imgs_info
363 |
364 | def build_imgs_info(database, ref_ids, pad_interval=-1, is_aligned=True, align_depth_range=False, has_depth=True, replace_none_depth=False, add_label=True, num_classes=0):
365 | if not is_aligned:
366 | assert has_depth
367 | rfn = len(ref_ids)
368 | ref_imgs, ref_labels, ref_masks, ref_depths, shapes = [], [], [], [], []
369 | for ref_id in ref_ids:
370 | img = database.get_image(ref_id)
371 | if add_label:
372 | label = database.get_label(ref_id)
373 | ref_labels.append(label)
374 | shapes.append([img.shape[0], img.shape[1]])
375 | ref_imgs.append(img)
376 | ref_masks.append(database.get_mask(ref_id))
377 | ref_depths.append(database.get_depth(ref_id))
378 |
379 | shapes = np.asarray(shapes)
380 | th, tw = np.max(shapes, 0)
381 | for rfi in range(rfn):
382 | ref_imgs[rfi] = pad_img_end(ref_imgs[rfi], th, tw, 'reflect')
383 | ref_labels[rfi] = pad_img_end(ref_labels[rfi], th, tw, 'reflect')
384 | ref_masks[rfi] = pad_img_end(ref_masks[rfi][:, :, None], th, tw, 'constant', 0)[..., 0]
385 | ref_depths[rfi] = pad_img_end(ref_depths[rfi][:, :, None], th, tw, 'constant', 0)[..., 0]
386 | ref_imgs = color_map_forward(np.stack(ref_imgs, 0)).transpose([0, 3, 1, 2])
387 | ref_labels = np.stack(ref_labels, 0).transpose([0, 3, 1, 2])
388 | ref_masks = np.stack(ref_masks, 0)[:, None, :, :]
389 | ref_depths = np.stack(ref_depths, 0)[:, None, :, :]
390 | else:
391 | ref_imgs = color_map_forward(np.asarray([database.get_image(ref_id) for ref_id in ref_ids])).transpose([0, 3, 1, 2])
392 | ref_labels = np.asarray([database.get_label(ref_id) for ref_id in ref_ids])[:, None, :, :]
393 | ref_masks = np.asarray([database.get_mask(ref_id) for ref_id in ref_ids], dtype=np.float32)[:, None, :, :]
394 | if has_depth:
395 | ref_depths = [database.get_depth(ref_id) for ref_id in ref_ids]
396 | if replace_none_depth:
397 | b, _, h, w = ref_imgs.shape
398 | for i, depth in enumerate(ref_depths):
399 | if depth is None: ref_depths[i] = np.zeros([h, w], dtype=np.float32)
400 | ref_depths = np.asarray(ref_depths, dtype=np.float32)[:, None, :, :]
401 | else: ref_depths = None
402 |
403 | ref_poses = np.asarray([database.get_pose(ref_id) for ref_id in ref_ids], dtype=np.float32)
404 | ref_Ks = np.asarray([database.get_K(ref_id) for ref_id in ref_ids], dtype=np.float32)
405 | ref_depth_range = np.asarray([database.get_depth_range(ref_id) for ref_id in ref_ids], dtype=np.float32)
406 | if align_depth_range:
407 | ref_depth_range[:,0]=np.min(ref_depth_range[:,0])
408 | ref_depth_range[:,1]=np.max(ref_depth_range[:,1])
409 | ref_imgs_info = {'imgs': ref_imgs, 'poses': ref_poses, 'Ks': ref_Ks, 'depth_range': ref_depth_range, 'masks': ref_masks, 'labels': ref_labels}
410 | if has_depth: ref_imgs_info['depth'] = ref_depths
411 | if pad_interval!=-1:
412 | ref_imgs_info = pad_imgs_info(ref_imgs_info, pad_interval)
413 | return ref_imgs_info
414 |
415 | def build_render_imgs_info(que_pose,que_K,que_shape,que_depth_range):
416 | h, w = que_shape
417 | h, w = int(h), int(w)
418 | que_coords = np.stack(np.meshgrid(np.arange(w), np.arange(h), indexing='xy'), -1)
419 | que_coords = que_coords.reshape([1, -1, 2]).astype(np.float32)
420 | return {'poses': que_pose.astype(np.float32)[None,:,:], # 1,3,4
421 | 'Ks': que_K.astype(np.float32)[None,:,:], # 1,3,3
422 | 'coords': que_coords,
423 | 'depth_range': np.asarray(que_depth_range, np.float32)[None, :],
424 | 'shape': (h,w)}
425 |
426 | def imgs_info_to_torch(imgs_info):
427 | for k, v in imgs_info.items():
428 | if isinstance(v,np.ndarray):
429 | imgs_info[k] = torch.from_numpy(v)
430 | return imgs_info
431 |
432 | def imgs_info_slice(imgs_info, indices):
433 | imgs_info_out={}
434 | for k, v in imgs_info.items():
435 | imgs_info_out[k] = v[indices]
436 | return imgs_info_out
437 |
438 |
439 | class ReplicaDatabase(BaseDatabase):
440 | def __init__(self, database_name, root_dir='data/scannet'):
441 | super().__init__(database_name)
442 | _, self.scene_name, self.seq_id, background_size = database_name.split('/')
443 | background, image_size = background_size.split('_')
444 | self.image_size = int(image_size)
445 | self.background = background
446 | self.root_dir = f'{root_dir}/{self.scene_name}/{self.seq_id}'
447 | # self.root_dir = f'data/replica/{self.scene_name}/{self.seq_id}'
448 | self.ratio = self.image_size / 640
449 | self.h, self.w = int(self.ratio*480), int(self.image_size)
450 |
451 | rgb_paths = [x for x in glob.glob(os.path.join(
452 | self.root_dir, "rgb", "*")) if (x.endswith(".jpg") or x.endswith(".png"))]
453 | self.rgb_paths = natsorted(rgb_paths)
454 |         # Do NOT use sorted() here: plain lexicographic sorting mis-orders names like rgb_1.jpg, rgb_10.jpg, rgb_2.jpg
455 |
456 | depth_paths = [x for x in glob.glob(os.path.join(
457 | self.root_dir, "depth", "*")) if (x.endswith(".jpg") or x.endswith(".png"))]
458 | self.depth_paths = natsorted(depth_paths)
459 |
460 | label_paths = [x for x in glob.glob(os.path.join(
461 | self.root_dir, "semantic_class", "*")) if (x.endswith(".jpg") or x.endswith(".png"))]
462 | self.label_paths = natsorted(label_paths)
463 |
464 | # Replica camera intrinsics
465 | # Pinhole Camera Model
466 | fx, fy, cx, cy, s = 320.0, 320.0, 319.5, 229.5, 0.0
467 | if self.ratio != 1.0:
468 | fx, fy, cx, cy = fx * self.ratio, fy * self.ratio, cx * self.ratio, cy * self.ratio
469 | self.K = np.array([[fx, s, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
470 | c2ws = np.loadtxt(f'{self.root_dir}/traj_w_c.txt',
471 | delimiter=' ').reshape(-1, 4, 4).astype(np.float32)
472 | self.poses = []
473 | transf = np.diag(np.asarray([1, -1, -1]))
474 | num_poses = c2ws.shape[0]
475 | for i in range(num_poses):
476 | pose = c2ws[i]
477 | # The commented lines below would convert the pose to the OpenGL coordinate system.
478 | # TODO: verify this is unnecessary; the rest of this codebase assumes the OpenCV convention.
479 | # pose = transf @ pose
480 | # pose = pose_inverse(pose)
481 | self.poses.append(pose)
482 |
483 | self.img_ids = []
484 | for i, rgb_path in enumerate(self.rgb_paths):
485 | self.img_ids.append(i)
486 |
487 | self.label_mapping = PointSegClassMapping(
488 | valid_cat_ids=[12, 17, 20, 22, 31, 37, 40, 44, 47, 56,
489 | 64, 79, 80, 87, 91, 92, 93, 95, 97],
490 | max_cat_id=101
491 | )
492 |
493 | def get_image(self, img_id):
494 | img = imread(self.rgb_paths[img_id])
495 | if self.w != 640:
496 | img = cv2.resize(downsample_gaussian_blur(
497 | img, self.ratio), (self.w, self.h), interpolation=cv2.INTER_LINEAR)
498 | return img
499 |
500 | def get_K(self, img_id):
501 | return self.K
502 |
503 | def get_pose(self, img_id):
504 | pose = self.poses[img_id]
505 | return pose.copy()
506 |
507 | def get_img_ids(self, check_depth_exist=False):
508 | return self.img_ids
509 |
510 | def get_bbox(self, img_id):
511 | raise NotImplementedError
512 |
513 | def get_depth(self, img_id):
514 | h, w, _ = self.get_image(img_id).shape
515 | img = Image.open(self.depth_paths[img_id])
516 | depth = np.asarray(img, dtype=np.float32) / 1000.0 # mm to m
517 | depth = np.ascontiguousarray(depth, dtype=np.float32)
518 | depth = cv2.resize(depth, (w, h), interpolation=cv2.INTER_NEAREST)
519 | return depth
520 |
521 | def get_mask(self, img_id):
522 | h, w, _ = self.get_image(img_id).shape
523 | return np.ones([h, w], bool)
524 |
525 | def get_depth_range(self, img_id):
526 | return np.asarray((0.1, 6.0), np.float32)
527 |
528 | def get_label(self, img_id):
529 | h, w, _ = self.get_image(img_id).shape
530 | img = Image.open(self.label_paths[img_id])
531 | label = np.asarray(img, dtype=np.int32)
532 | label = np.ascontiguousarray(label)
533 | label = cv2.resize(label, (w, h), interpolation=cv2.INTER_NEAREST)
534 | label = label.astype(np.int32)
535 | return self.label_mapping(label)
536 |
537 |
--------------------------------------------------------------------------------
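For orientation, here is a minimal usage sketch of the helpers defined above. The scene id and root path are illustrative assumptions (the real call sites are in data/scannet.py below), and the positional arguments to build_imgs_info mirror the pad_interval / is_aligned values passed there:

    from utils.scannet_utils import ReplicaDatabase, build_imgs_info, pad_imgs_info, imgs_info_to_torch

    # Hypothetical scene id and data layout; adjust root_dir to wherever the Replica sequences live.
    db = ReplicaDatabase('replica/room_0/Sequence_1/black_320', root_dir='data/replica')
    ref_ids = db.get_img_ids()[:8]                     # take a handful of reference views
    ref_info = build_imgs_info(db, ref_ids, -1, True)  # imgs / poses / Ks / depth_range / masks / labels
    ref_info = pad_imgs_info(ref_info, 16)             # pad H and W to a multiple of 16
    ref_info = imgs_info_to_torch(ref_info)            # numpy arrays -> torch tensors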
/data/scannet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import json
4 | import numpy as np
5 | import random
6 | import time
7 | import os
8 | import cv2
9 | from glob import glob
10 | from PIL import Image
11 | from torchvision import transforms as T
12 |
13 | from utils.utils import read_pfm, get_nearest_pose_ids, get_rays, compute_nearest_camera_indices
14 | from utils.scannet_utils import parse_database_name, get_database_split, get_coords_mask, random_crop, random_flip, build_imgs_info, pad_imgs_info, imgs_info_to_torch, imgs_info_slice
15 |
16 | def set_seed(index,is_train):
17 | if is_train:
18 | np.random.seed((index+int(time.time()))%(2**16))
19 | random.seed((index+int(time.time()))%(2**16)+1)
20 | torch.random.manual_seed((index+int(time.time()))%(2**16)+1)
21 | else:
22 | np.random.seed(index % (2 ** 16))
23 | random.seed(index % (2 ** 16) + 1)
24 | torch.random.manual_seed(index % (2 ** 16) + 1)
25 |
26 | def add_depth_offset(depth,mask,region_min,region_max,offset_min,offset_max,noise_ratio,depth_length):
27 | coords = np.stack(np.nonzero(mask), -1)[:, (1, 0)]
28 | length = np.max(coords, 0) - np.min(coords, 0)
29 | center = coords[np.random.randint(0, coords.shape[0])]
30 | lx, ly = np.random.uniform(region_min, region_max, 2) * length
31 | diff = coords - center[None, :]
32 | mask0 = np.abs(diff[:, 0]) < lx
33 | mask1 = np.abs(diff[:, 1]) < ly
34 | masked_coords = coords[mask0 & mask1]
35 | global_offset = np.random.uniform(offset_min, offset_max) * depth_length
36 | if np.random.random() < 0.5:
37 | global_offset = -global_offset
38 | local_offset = np.random.uniform(-noise_ratio, noise_ratio, masked_coords.shape[0]) * depth_length + global_offset
39 | depth[masked_coords[:, 1], masked_coords[:, 0]] += local_offset
40 |
41 | def build_src_imgs_info_select(database, ref_ids, ref_ids_all, cost_volume_nn_num, pad_interval=-1):
42 | # ref_ids - selected ref ids for rendering
43 | ref_idx_exp = compute_nearest_camera_indices(database, ref_ids, ref_ids_all)
44 | ref_idx_exp = ref_idx_exp[:, 1:1 + cost_volume_nn_num]
45 | ref_ids_all = np.asarray(ref_ids_all)
46 | ref_ids_exp = ref_ids_all[ref_idx_exp] # rfn,nn
47 | ref_ids_exp_ = ref_ids_exp.flatten()
48 | ref_ids = np.asarray(ref_ids)
49 | ref_ids_in = np.unique(np.concatenate([ref_ids_exp_, ref_ids])) # rfn'
50 | mask0 = ref_ids_in[None, :] == ref_ids[:, None] # rfn,rfn'
51 | ref_idx_, ref_idx = np.nonzero(mask0)
52 | ref_real_idx = ref_idx[np.argsort(ref_idx_)] # sort
53 |
54 | rfn, nn = ref_ids_exp.shape
55 | mask1 = ref_ids_in[None, :] == ref_ids_exp.flatten()[:, None] # nn*rfn,rfn'
56 | ref_cv_idx_, ref_cv_idx = np.nonzero(mask1)
57 | ref_cv_idx = ref_cv_idx[np.argsort(ref_cv_idx_)] # sort
58 | ref_cv_idx = ref_cv_idx.reshape([rfn, nn])
59 | is_aligned = not database.database_name.startswith('space')
60 | ref_imgs_info = build_imgs_info(database, ref_ids_in, pad_interval, is_aligned)
61 | return ref_imgs_info, ref_cv_idx, ref_real_idx
62 |
63 |
64 | class RendererDataset(Dataset):
65 | default_cfg={
66 | 'train_database_types':['scannet'],
67 | 'type2sample_weights': {'scannet': 1},
68 | 'val_database_name': 'scannet/scene0200_00/black_320',
69 | 'val_database_split_type': 'val',
70 |
71 | 'min_wn': 8,
72 | 'max_wn': 9,
73 | 'ref_pad_interval': 16,
74 | 'train_ray_num': 512,
75 | 'foreground_ratio': 0.5,
76 | 'resolution_type': 'lr',
77 | "use_consistent_depth_range": True,
78 | 'use_depth_loss_for_all': False,
79 | "use_depth": True,
80 | "use_src_imgs": False,
81 | "cost_volume_nn_num": 3,
82 |
83 | "aug_depth_range_prob": 0.05,
84 | 'aug_depth_range_min': 0.95,
85 | 'aug_depth_range_max': 1.05,
86 | "aug_use_depth_offset": True,
87 | "aug_depth_offset_prob": 0.25,
88 | "aug_depth_offset_region_min": 0.05,
89 | "aug_depth_offset_region_max": 0.1,
90 | 'aug_depth_offset_min': 0.5,
91 | 'aug_depth_offset_max': 1.0,
92 | 'aug_depth_offset_local': 0.1,
93 | "aug_use_depth_small_offset": True,
94 | "aug_use_global_noise": True,
95 | "aug_global_noise_prob": 0.5,
96 | "aug_depth_small_offset_prob": 0.5,
97 | "aug_forward_crop_size": (400,600),
98 | "aug_pixel_center_sample": True,
99 | "aug_view_select_type": "easy",
100 |
101 | "use_consistent_min_max": False,
102 | "revise_depth_range": False,
103 | }
104 | def __init__(self, root_dir, is_train, cfg=None):
105 | if cfg is not None:
106 | self.cfg={**self.default_cfg,**cfg}
107 | else:
108 | self.cfg={**self.default_cfg}
109 | self.root_dir = root_dir
110 | self.is_train = is_train
111 | if is_train:
112 | self.num=999999
113 | self.type2scene_names,self.database_types,self.database_weights = {}, [], []
114 | if self.cfg['resolution_type']=='hr':
115 | type2scene_names={
116 | 'replica': np.loadtxt('configs/lists/replica_train_split.txt',dtype=str).tolist(),
117 | }
118 | elif self.cfg['resolution_type']=='lr':
119 | type2scene_names={
120 | 'scannet': np.loadtxt('configs/lists/scannet_train_split.txt',dtype=str).tolist(),
121 | 'scannet_single': [self.cfg['val_database_name']],
122 | }
123 | else:
124 | raise NotImplementedError
125 |
126 | for database_type in self.cfg['train_database_types']:
127 | self.type2scene_names[database_type] = type2scene_names[database_type]
128 | self.database_types.append(database_type)
129 | self.database_weights.append(self.cfg['type2sample_weights'][database_type])
130 | assert(len(self.database_types)>0)
131 | # normalize weights
132 | self.database_weights=np.asarray(self.database_weights)
133 | self.database_weights=self.database_weights/np.sum(self.database_weights)
134 | else:
135 | self.database = parse_database_name(self.cfg['val_database_name'], root_dir=self.root_dir)
136 | self.ref_ids, self.que_ids = get_database_split(self.database,self.cfg['val_database_split_type'])
137 | self.num=len(self.que_ids)
138 | self.database_statistics = {}
139 |
140 | self.blender2opencv = torch.tensor(
141 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
142 | ).to(torch.float32)
143 |
144 | def get_database_ref_que_ids(self, index):
145 | if self.is_train:
146 | database_type = np.random.choice(self.database_types,1,False,p=self.database_weights)[0]
147 | database_scene_name = np.random.choice(self.type2scene_names[database_type])
148 | database = parse_database_name(database_scene_name, root_dir=self.root_dir)
149 | # if no view of the scene has depth, keep re-sampling until we find a scene with depth
150 | while True:
151 | ref_ids = database.get_img_ids(check_depth_exist=True)
152 | if len(ref_ids)==0:
153 | database_type = np.random.choice(self.database_types, 1, False, self.database_weights)[0]
154 | database_scene_name = np.random.choice(self.type2scene_names[database_type])
155 | database = parse_database_name(database_scene_name, root_dir=self.root_dir)
156 | else: break
157 | que_id = np.random.choice(ref_ids)
158 | # if database.database_name.startswith('real_estate'):
159 | # que_id, ref_ids = select_train_ids_for_real_estate(ref_ids)
160 | else:
161 | database = self.database
162 | que_id, ref_ids = self.que_ids[index], self.ref_ids
163 | return database, que_id, np.asarray(ref_ids)
164 |
165 | def select_working_views_impl(self, database_name, dist_idx, ref_num):
166 | if self.cfg['aug_view_select_type']=='default':
167 | if database_name.startswith('space') or database_name.startswith('real_estate'):
168 | pass
169 | elif database_name.startswith('gso'):
170 | pool_ratio = np.random.randint(1, 5)
171 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 32)]
172 | elif database_name.startswith('real_iconic'):
173 | pool_ratio = np.random.randint(1, 5)
174 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 32)]
175 | elif database_name.startswith('dtu_train'):
176 | pool_ratio = np.random.randint(1, 3)
177 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 12)]
178 | elif database_name.startswith('scannet'):
179 | pool_ratio = np.random.randint(1, 3)
180 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 12)]
181 | elif database_name.startswith('replica'):
182 | pool_ratio = np.random.randint(1, 3)
183 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 12)]
184 | else:
185 | raise NotImplementedError
186 | elif self.cfg['aug_view_select_type']=='easy':
187 | if database_name.startswith('space') or database_name.startswith('real_estate'):
188 | pass
189 | elif database_name.startswith('gso'):
190 | pool_ratio = 3
191 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 24)]
192 | elif database_name.startswith('real_iconic'):
193 | pool_ratio = np.random.randint(1, 4)
194 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 20)]
195 | elif database_name.startswith('dtu_train'):
196 | pool_ratio = np.random.randint(1, 3)
197 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 12)]
198 | elif database_name.startswith('scannet'):
199 | pool_ratio = np.random.randint(1, 3)
200 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 12)]
201 | elif database_name.startswith('replica'):
202 | pool_ratio = np.random.randint(1, 3)
203 | dist_idx = dist_idx[:min(ref_num * pool_ratio, 12)]
204 | else:
205 | raise NotImplementedError
206 |
207 | return dist_idx
208 |
209 | def select_working_views(self, database, que_id, ref_ids):
210 | database_name = database.database_name
211 | dist_idx = compute_nearest_camera_indices(database, [que_id], ref_ids)[0]
212 | if self.is_train:
213 | if np.random.random()>0.02: # 2% chance to include que image
214 | dist_idx = dist_idx[ref_ids[dist_idx]!=que_id]
215 | ref_num = np.random.randint(self.cfg['min_wn'], self.cfg['max_wn'])
216 | dist_idx = self.select_working_views_impl(database_name,dist_idx,ref_num)
217 | if not database_name.startswith('real_estate'):
218 | # working views have already been selected for the real_estate dataset
219 | np.random.shuffle(dist_idx)
220 | dist_idx = dist_idx[:ref_num]
221 | ref_ids = ref_ids[dist_idx]
222 | else:
223 | ref_ids = ref_ids[:ref_num]
224 | else:
225 | dist_idx = dist_idx[:self.cfg['min_wn']]
226 | ref_ids = ref_ids[dist_idx]
227 | return ref_ids
228 |
229 | def random_change_depth_range(self, depth_range):
230 | depth_range_new = depth_range.copy()
231 | if np.random.random() < self.cfg['aug_depth_range_prob']:
272 | que_mask_cur = que_imgs_info['masks'][0, 0] > 0
273 | coords = get_coords_mask(que_mask_cur, self.cfg['train_ray_num'], self.cfg['foreground_ratio']).reshape([1,-1,2])
274 | return coords
275 |
276 | def consistent_depth_range(self, ref_imgs_info, que_imgs_info):
277 | depth_range_all = np.concatenate([ref_imgs_info['depth_range'], que_imgs_info['depth_range']], 0)
278 | if self.cfg['use_consistent_min_max']:
279 | depth_range_all[:, 0] = np.min(depth_range_all)
280 | depth_range_all[:, 1] = np.max(depth_range_all)
281 | else:
282 | range_len = depth_range_all[:, 1] - depth_range_all[:, 0]
283 | max_len = np.max(range_len)
284 | range_margin = (max_len - range_len) / 2
285 | ref_near = depth_range_all[:, 0] - range_margin
286 | ref_near = np.max(np.stack([ref_near, depth_range_all[:, 0] * 0.5], -1), 1)
287 | depth_range_all[:, 0] = ref_near
288 | depth_range_all[:, 1] = ref_near + max_len
289 | ref_imgs_info['depth_range'] = depth_range_all[:-1]
290 | que_imgs_info['depth_range'] = depth_range_all[-1:]
291 |
292 |
293 | def multi_scale_depth(self, depth_h):
294 | '''
295 | Adapted from the Klevr dataset implementation and moved here so that both datasets share the same output format.
296 | '''
297 |
298 | depth = {}
299 | for l in range(3):
300 |
301 | depth[f"level_{l}"] = cv2.resize(
302 | depth_h,
303 | None,
304 | fx=1.0 / (2**l),
305 | fy=1.0 / (2**l),
306 | interpolation=cv2.INTER_NEAREST,
307 | )
308 | # depth[f"level_{l}"][depth[f"level_{l}"] > far_bound * 0.95] = 0.0
309 |
310 | if self.is_train:
311 | cutout = np.ones_like(depth[f"level_2"])
312 | h0 = int(np.random.randint(0, high=cutout.shape[0] // 5, size=1))
313 | h1 = int(
314 | np.random.randint(
315 | 4 * cutout.shape[0] // 5, high=cutout.shape[0], size=1
316 | )
317 | )
318 | w0 = int(np.random.randint(0, high=cutout.shape[1] // 5, size=1))
319 | w1 = int(
320 | np.random.randint(
321 | 4 * cutout.shape[1] // 5, high=cutout.shape[1], size=1
322 | )
323 | )
324 | cutout[h0:h1, w0:w1] = 0
325 | depth_aug = depth[f"level_2"] * cutout
326 | else:
327 | depth_aug = depth[f"level_2"].copy()
328 |
329 | return depth, depth_aug
330 |
331 |
332 |
333 |
334 | def __getitem__(self, index):
335 | set_seed(index, self.is_train)
336 | database, que_id, ref_ids_all = self.get_database_ref_que_ids(index)
337 | ref_ids = self.select_working_views(database, que_id, ref_ids_all)
338 | if self.cfg['use_src_imgs']:
339 | # src_imgs_info used in construction of cost volume
340 | ref_imgs_info, ref_cv_idx, ref_real_idx = build_src_imgs_info_select(database,ref_ids,ref_ids_all,self.cfg['cost_volume_nn_num'])
341 | else:
342 | ref_idx = compute_nearest_camera_indices(database, ref_ids)[:,0:4] # used in cost volume construction
343 | is_aligned = not database.database_name.startswith('space')
344 | ref_imgs_info = build_imgs_info(database, ref_ids, -1, is_aligned)
345 | # In SemRay's implementation the query image has no access to depth; we load it here but use it neither as input nor as supervision.
346 | # que_imgs_info = build_imgs_info(database, [que_id], has_depth=self.is_train)
347 | que_imgs_info = build_imgs_info(database, [que_id])
348 |
349 | if self.is_train:
350 | # data augmentation
351 | depth_range_all = np.concatenate([ref_imgs_info['depth_range'],que_imgs_info['depth_range']],0)
352 |
353 | depth_range_all = self.random_change_depth_range(depth_range_all)
354 | ref_imgs_info['depth_range'] = depth_range_all[:-1]
355 | que_imgs_info['depth_range'] = depth_range_all[-1:]
356 |
357 |
358 |
359 | if database.database_name.startswith('real_estate') \
360 | or database.database_name.startswith('real_iconic') \
361 | or database.database_name.startswith('space'):
362 | # random crop is only applied to the forward-facing datasets (real_estate / real_iconic / space)
363 | ref_imgs_info, que_imgs_info = random_crop(ref_imgs_info, que_imgs_info, self.cfg['aug_forward_crop_size'])
364 | if np.random.random()<0.5:
365 | ref_imgs_info, que_imgs_info = random_flip(ref_imgs_info, que_imgs_info)
366 |
367 | if self.cfg['use_depth_loss_for_all'] and self.cfg['use_depth']:
368 | if not database.database_name.startswith('gso'):
369 | ref_imgs_info['true_depth'] = ref_imgs_info['depth']
370 |
371 | if self.cfg['use_consistent_depth_range']:
372 | self.consistent_depth_range(ref_imgs_info, que_imgs_info)
373 |
374 |
375 | ref_imgs_info = pad_imgs_info(ref_imgs_info,self.cfg['ref_pad_interval'])
376 |
377 | # don't feed depth to gpu
378 | if not self.cfg['use_depth']:
379 | if 'depth' in ref_imgs_info: ref_imgs_info.pop('depth')
380 | if 'depth' in que_imgs_info: que_imgs_info.pop('depth')
381 | if 'true_depth' in ref_imgs_info: ref_imgs_info.pop('true_depth')
382 |
383 | if self.cfg['use_src_imgs']:
384 | src_imgs_info = ref_imgs_info.copy()
385 | ref_imgs_info = imgs_info_slice(ref_imgs_info, ref_real_idx)
386 | ref_imgs_info['nn_ids'] = ref_cv_idx
387 | else:
388 | # 'nn_ids' specifies the source image ids used when constructing the cost volume
389 | ref_imgs_info['nn_ids'] = ref_idx.astype(np.int64)
390 |
391 | # ref_imgs_info = imgs_info_to_torch(ref_imgs_info)
392 | # que_imgs_info = imgs_info_to_torch(que_imgs_info)
393 |
394 | # outputs = {'ref_imgs_info': ref_imgs_info, 'que_imgs_info': que_imgs_info, 'scene_name': database.database_name}
395 | # if self.cfg['use_src_imgs']: outputs['src_imgs_info'] = imgs_info_to_torch(src_imgs_info)
396 |
397 | same_format_as_klevr = True
398 | if same_format_as_klevr:
399 | sample = {}
400 | sample['images'] = np.concatenate((ref_imgs_info['imgs'], que_imgs_info['imgs']), 0)
401 | sample['semantics'] = np.concatenate((ref_imgs_info['labels'], que_imgs_info['labels']), 0).squeeze(1)
402 |
403 |
404 | # sample['w2cs'] = np.concatenate((ref_imgs_info['poses'], que_imgs_info['poses']), 0) # (1+nb_views, 3, 4)
405 | # sample['w2cs'] = np.concatenate((sample['w2cs'], torch.ones_like(sample['w2cs'])[:,0:1,:]), 1) # (1+nb_views, 4, 4)
406 |
407 | sample['c2ws'] = np.concatenate((ref_imgs_info['poses'], que_imgs_info['poses']), 0) # (1+nb_views, 4, 4)
408 | # sample['c2ws'] = np.concatenate((sample['c2ws'], torch.ones_like(sample['c2ws'])[:,0:1,:]), 1) # (1+nb_views, 4, 4)
409 | # sample['c2ws'] = sample['c2ws'] @ self.blender2opencv
410 |
411 |
412 | sample['intrinsics'] = np.concatenate((ref_imgs_info['Ks'], que_imgs_info['Ks']), 0)
413 | sample['near_fars'] = np.concatenate((ref_imgs_info['depth_range'], que_imgs_info['depth_range']), 0)
414 | sample['depths_h'] = np.concatenate((ref_imgs_info['depth'], que_imgs_info['depth']), 0).squeeze(1)
415 | sample['closest_idxs'] = ref_imgs_info['nn_ids'] # (nb_views, 4): nearest source image ids used for cost-volume construction (hard-coded to the 4 nearest)
416 |
417 | # affine_mats (1+nb_views, 4, 4, 3)
418 | # affine_mats_inv (1+nb_views, 4, 4, 3)
419 | # depths_aug (1+nb_views, 1, H/4, W/4)
420 | # depths {dict} {'level_0': (1+nb_views, 1, H, W), 'level_1': (1+nb_views, 1, H/2, W/2), 'level_2': (1+nb_views, 1, H/4, W/4)}
421 |
422 | sample['w2cs'] = []
423 | affine_mats, affine_mats_inv, depths_aug = [], [], []
424 | project_mats = []
425 | depths = {"level_0": [], "level_1": [], "level_2": []}
426 |
427 | for i in range(sample['c2ws'].shape[0]):
428 | sample['w2cs'].append(np.linalg.inv(sample['c2ws'][i]))
429 | # sample['w2cs'].append(torch.asarray(np.linalg.inv(np.asarray(sample['c2ws'][i]))))
430 |
431 | aff = []
432 | aff_inv = []
433 | proj_matrices = []
434 |
435 | for l in range(3):
436 | proj_mat_l = np.eye(4)
437 | intrinsic_temp = sample['intrinsics'][i].copy()
438 | intrinsic_temp[:2] = intrinsic_temp[:2]/(2**l)
439 | proj_mat_l[:3,:4] = intrinsic_temp @ sample['w2cs'][i][:3,:4]
440 | aff.append(proj_mat_l)
441 | aff_inv.append(np.linalg.inv(proj_mat_l))
442 | # For unsupervised depth loss
443 | proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32)
444 | proj_mat[0, :4, :4] = sample['w2cs'][i][:4,:4]
445 | proj_mat[1, :3, :3] = intrinsic_temp
446 | proj_matrices.append(proj_mat)
447 |
448 | aff = np.stack(aff, axis=-1)
449 | aff_inv = np.stack(aff_inv, axis=-1)
450 | proj_matrices = np.stack(proj_matrices)
451 |
452 | affine_mats.append(aff)
453 | affine_mats_inv.append(aff_inv)
454 | project_mats.append(proj_matrices)
455 |
456 | depth, depth_aug = self.multi_scale_depth(np.asarray(sample['depths_h'][i]))
457 | depths["level_0"].append(depth["level_0"])
458 | depths["level_1"].append(depth["level_1"])
459 | depths["level_2"].append(depth["level_2"])
460 | depths_aug.append(depth_aug)
461 |
462 | affine_mats = np.stack(affine_mats)
463 | affine_mats_inv = np.stack(affine_mats_inv)
464 | project_mats = np.stack(project_mats)
465 | depths_aug = np.stack(depths_aug)
466 | depths["level_0"] = np.stack(depths["level_0"])
467 | depths["level_1"] = np.stack(depths["level_1"])
468 | depths["level_2"] = np.stack(depths["level_2"])
469 |
470 |
471 | sample['w2cs'] = np.stack(sample['w2cs'], 0) # (1+nb_views, 4, 4)
472 | sample['affine_mats'] = affine_mats
473 | sample['affine_mats_inv'] = affine_mats_inv
474 | sample['depths_aug'] = depths_aug
475 | sample['depths'] = depths
476 | sample['project_mats'] = project_mats
477 |
478 | return sample
479 |
480 | # if same_format_as_klevr == False:
481 | # return outputs
482 |
483 | def __len__(self):
484 | return self.num
485 |
486 |
487 |
--------------------------------------------------------------------------------
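As a quick sanity check, the RendererDataset above can be driven with a standard PyTorch DataLoader. The root path and the cfg override below are illustrative assumptions, and anything not overridden falls back to default_cfg; the actual training pipeline presumably configures this elsewhere:

    from torch.utils.data import DataLoader
    from data.scannet import RendererDataset

    # root_dir is an assumption about where the preprocessed scenes live on disk.
    train_set = RendererDataset(root_dir='data/scannet', is_train=True, cfg={'train_ray_num': 1024})
    loader = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=4)

    sample = next(iter(loader))
    # Each sample contains: images, semantics, c2ws, w2cs, intrinsics, near_fars, depths_h,
    # closest_idxs, affine_mats, affine_mats_inv, depths (3 levels), depths_aug and project_mats.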