├── README.md
├── config.yml
├── cvrecon
│   ├── cnn2d.py
│   ├── cnn3d.py
│   ├── collate.py
│   ├── cost_volume.py
│   ├── cvrecon.py
│   ├── data.py
│   ├── lightningmodel.py
│   ├── mv_fusion.py
│   ├── transformer.py
│   ├── tsdf_fusion.py
│   ├── utils.py
│   └── view_direction_encoder.py
├── scripts
│   ├── inference.py
│   └── train.py
└── tools
    ├── generate_gt.py
    ├── preprocess_scannet.py
    └── simple_loader.py

/README.md:
--------------------------------------------------------------------------------
1 | # CVRecon: Rethinking 3D Geometric Feature Learning for Neural Reconstruction
2 | 
3 | This paper has been accepted to [ICCV 2023](https://iccv2023.thecvf.com).
4 | 
5 | By [Ziyue Feng](https://ziyue.cool), [Liang Yang](https://ericlyang.github.io/), [Pengsheng Guo](https://psguo.github.io), and [Bing Li](https://www.clemson.edu/cecas/departments/automotive-engineering/people/li.html).
6 | 
7 | Project Page: [cvrecon.ziyue.cool](https://cvrecon.ziyue.cool)
8 | 
9 | ## Video
10 | 
11 | [![image](https://i.ibb.co/KjpVBN5/Screenshot-2023-09-17-at-11-31-15-PM.png)](https://www.youtube.com/watch?v=AVbbx4TBFf8)
12 | 
13 | Dear readers:
14 | 
15 | Apologies for the late release of the code; I have been very busy recently and have not yet had time to clean it up.
16 | 
17 | I hope this initial release gives you an idea of how CVRecon works. The implementation builds on the cost volume of "SimpleRecon" and the framework of "VoRTX".
18 | 
19 | I will clean up the code as soon as I have time.
20 | 
21 | ### Dependencies
22 | 
23 | ```
24 | conda create -n cvrecon python=3.9 -y
25 | conda activate cvrecon
26 | 
27 | conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
28 | 
29 | pip install \
30 |   pytorch-lightning==1.5 \
31 |   scikit-image==0.18 \
32 |   numba \
33 |   pillow \
34 |   wandb \
35 |   tqdm \
36 |   open3d \
37 |   pyrender \
38 |   ray \
39 |   trimesh \
40 |   pyyaml \
41 |   matplotlib \
42 |   black \
43 |   pycuda \
44 |   opencv-python \
45 |   imageio
46 | 
47 | sudo apt install libsparsehash-dev
48 | pip install git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0
49 | 
50 | pip install -e .
51 | ```
52 | 
53 | 
54 | ### Data
55 | 
56 | The ScanNet data should be downloaded and extracted using the script provided by the ScanNet authors.
57 | 
58 | 
59 | To format ScanNet for cvrecon:
60 | ```
61 | python tools/preprocess_scannet.py --src path/to/scannet_src --dst path/to/new/scannet_dst
62 | ```
63 | In `config.yml`, set `scannet_dir` to the value of `--dst`.
64 | 
65 | To generate the ground-truth TSDFs:
66 | ```
67 | python tools/generate_gt.py --data_path path/to/scannet_src --save_name TSDF_OUTPUT_DIR
68 | # For the test split
69 | python tools/generate_gt.py --test --data_path path/to/scannet_src --save_name TSDF_OUTPUT_DIR
70 | ```
71 | In `config.yml`, set `tsdf_dir` to the value of `TSDF_OUTPUT_DIR`.
72 | 
73 | ## Training
74 | 
75 | ```
76 | python scripts/train.py --config config.yml
77 | ```
78 | Parameters can be adjusted in `config.yml`.
79 | Set `attn_heads: 0` to use direct averaging instead of transformers.
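
The fusion and cost-volume behaviour is controlled by a few keys in `config.yml`. A minimal illustrative excerpt (these values match the shipped defaults; see the full `config.yml` for the remaining options):

```
attn_heads: 2        # number of attention heads; 0 falls back to direct averaging
attn_layers: 2       # transformer layers used for multi-view fusion
use_proj_occ: False  # projective occupancy prediction; the fine-tuning callback turns this on after initial_epochs
cost_volume: True    # build the SimpleRecon-style feature cost volume
cv_dim: 15           # channel width of the cost-volume features
cv_overall: True     # also use the overall (view-aggregated) cost-volume feature
```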
80 | 
81 | ## Inference
82 | 
83 | ```
84 | python scripts/inference.py \
85 |   --ckpt path/to/checkpoint.ckpt \
86 |   --split [train / val / test] \
87 |   --outputdir path/to/desired_output_directory \
88 |   --n-imgs 60 \
89 |   --config config.yml \
90 |   --cropsize 96
91 | ```
92 | 
93 | ## Evaluation
94 | 
95 | Refer to the evaluation protocol of Atlas and TransformerFusion.
96 | 
97 | ## Citation
98 | ```
99 | @misc{feng2023cvrecon,
100 |       title={CVRecon: Rethinking 3D Geometric Feature Learning For Neural Reconstruction},
101 |       author={Ziyue Feng and Leon Yang and Pengsheng Guo and Bing Li},
102 |       year={2023},
103 |       eprint={2304.14633},
104 |       archivePrefix={arXiv},
105 |       primaryClass={cs.CV}
106 | }
107 | ```
108 | 
--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
1 | n_imgs_train: 20
2 | n_imgs_val: 20
3 | n_imgs_test: 20
4 | crop_size_train: [96, 96, 48]
5 | crop_size_val: [96, 96, 48]
6 | crop_size_test: [96, 96, 48]
7 | attn_heads: 2
8 | attn_layers: 2
9 | use_proj_occ: False
10 | 
11 | ckpt: null
12 | wandb_runid: null
13 | 
14 | seed: 0
15 | use_amp: True
16 | nworkers: 8
17 | 
18 | initial_epochs: 350
19 | initial_lr: .001
20 | initial_batch_size: 2
21 | 
22 | finetune_epochs: 150
23 | finetune_lr: .0001
24 | finetune_batch_size: 1
25 | 
26 | wandb_project_name: "cvrecon"
27 | 
28 | scannet_dir: "ScanNet2/cvrecon"
29 | tsdf_dir: "ScanNet2/cvrecon_gt"
30 | 
31 | SRfeat: False
32 | SR_vi_ebd: False
33 | SRCV: False
34 | 
35 | cost_volume: True
36 | cv_dim: 15
37 | cv_overall: True
38 | 
39 | depth_head: False
40 | 
41 | accu_grad: 1
--------------------------------------------------------------------------------
/cvrecon/cnn2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torchvision.models import mnasnet1_0, MNASNet1_0_Weights
4 | 
5 | 
6 | class MnasMulti(torch.nn.Module):
7 |     def __init__(self, output_depths, pretrained=True):
8 |         super().__init__()
9 |         MNASNet = mnasnet1_0(weights=MNASNet1_0_Weights.DEFAULT)
10 |         self.conv0 = torch.nn.Sequential(
11 |             MNASNet.layers._modules["0"],
12 |             MNASNet.layers._modules["1"],
13 |             MNASNet.layers._modules["2"],
14 |             MNASNet.layers._modules["3"],
15 |             MNASNet.layers._modules["4"],
16 |             MNASNet.layers._modules["5"],
17 |             MNASNet.layers._modules["6"],
18 |             MNASNet.layers._modules["7"],
19 |             MNASNet.layers._modules["8"],
20 |         )
21 |         self.conv1 = MNASNet.layers._modules["9"]
22 |         self.conv2 = MNASNet.layers._modules["10"]
23 | 
24 |         final_chs = 80
25 | 
26 |         self.inner1 = torch.nn.Sequential(
27 |             torch.nn.BatchNorm2d(output_depths[1]),
28 |             torch.nn.ReLU(True),
29 |             torch.nn.Conv2d(output_depths[1], final_chs, 1, bias=False),
30 |         )
31 |         self.inner2 = torch.nn.Sequential(
32 |             torch.nn.BatchNorm2d(output_depths[2]),
33 |             torch.nn.ReLU(True),
34 |             torch.nn.Conv2d(output_depths[2], final_chs, 1, bias=False),
35 |         )
36 | 
37 |         self.out1 = torch.nn.Sequential(
38 |             torch.nn.BatchNorm2d(final_chs),
39 |             torch.nn.ReLU(True),
40 |             torch.nn.Conv2d(final_chs, output_depths[0], 1, bias=False),
41 |         )
42 |         self.out2 = torch.nn.Sequential(
43 |             torch.nn.BatchNorm2d(final_chs),
44 |             torch.nn.ReLU(True),
45 |             torch.nn.Conv2d(final_chs, output_depths[1], 3, bias=False, padding=1),
46 |         )
47 |         self.out3 = torch.nn.Sequential(
48 |             torch.nn.BatchNorm2d(final_chs),
49 |             torch.nn.ReLU(True),
50 |             torch.nn.Conv2d(final_chs, output_depths[2], 3, bias=False, padding=1),
51 |         )
52 | 
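        # FPN-style decoder: out1/out2/out3 emit the "coarse"/"medium"/"fine"
        # feature maps (output_depths channels, 80/40/24 as configured in
        # cvrecon.py), while inner1/inner2 are 1x1 lateral projections added to
        # the 2x bilinearly upsampled top-down path in forward().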
53 | torch.nn.init.kaiming_normal_(self.inner1[2].weight) 54 | torch.nn.init.kaiming_normal_(self.inner2[2].weight) 55 | torch.nn.init.kaiming_normal_(self.out1[2].weight) 56 | torch.nn.init.kaiming_normal_(self.out2[2].weight) 57 | torch.nn.init.kaiming_normal_(self.out3[2].weight) 58 | 59 | def forward(self, x): 60 | conv0 = self.conv0(x) 61 | conv1 = self.conv1(conv0) 62 | conv2 = self.conv2(conv1) 63 | 64 | intra_feat = conv2 65 | outputs = {} 66 | out = self.out1(intra_feat) 67 | outputs["coarse"] = out 68 | 69 | intra_feat = torch.nn.functional.interpolate( 70 | intra_feat, scale_factor=2, mode="bilinear", align_corners=False 71 | ) + self.inner1(conv1) 72 | out = self.out2(intra_feat) 73 | outputs["medium"] = out 74 | 75 | intra_feat = torch.nn.functional.interpolate( 76 | intra_feat, scale_factor=2, mode="bilinear", align_corners=False 77 | ) + self.inner2(conv0) 78 | out = self.out3(intra_feat) 79 | outputs["fine"] = out 80 | 81 | return outputs 82 | -------------------------------------------------------------------------------- /cvrecon/cnn3d.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchsparse 3 | import torchsparse.nn as spnn 4 | 5 | import numpy as np 6 | 7 | 8 | class BasicConvolutionBlock(nn.Module): 9 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1): 10 | super().__init__() 11 | self.net = nn.Sequential( 12 | spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride), 13 | spnn.BatchNorm(outc), 14 | spnn.ReLU(True), 15 | ) 16 | 17 | def forward(self, x): 18 | out = self.net(x) 19 | return out 20 | 21 | 22 | class BasicDeconvolutionBlock(nn.Module): 23 | def __init__(self, inc, outc, ks=3, stride=1): 24 | super().__init__() 25 | self.net = nn.Sequential( 26 | spnn.Conv3d(inc, outc, kernel_size=ks, stride=stride, transposed=True), 27 | spnn.BatchNorm(outc), 28 | spnn.ReLU(True), 29 | ) 30 | 31 | def forward(self, x): 32 | return self.net(x) 33 | 34 | 35 | class ResidualBlock(nn.Module): 36 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1): 37 | super().__init__() 38 | self.net = nn.Sequential( 39 | spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride), 40 | spnn.BatchNorm(outc), 41 | spnn.ReLU(True), 42 | spnn.Conv3d(outc, outc, kernel_size=ks, dilation=dilation, stride=1), 43 | spnn.BatchNorm(outc), 44 | ) 45 | 46 | self.downsample = ( 47 | nn.Sequential() 48 | if (inc == outc and stride == 1) 49 | else nn.Sequential( 50 | spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride), 51 | spnn.BatchNorm(outc), 52 | ) 53 | ) 54 | 55 | self.relu = spnn.ReLU(True) 56 | 57 | def forward(self, x): 58 | out = self.relu(self.net(x) + self.downsample(x)) 59 | return out 60 | 61 | 62 | class SPVCNN(nn.Module): 63 | def __init__(self, **kwargs): 64 | super().__init__() 65 | 66 | self.dropout = kwargs["dropout"] 67 | 68 | base_depth = kwargs["base_depth"] 69 | cs = np.array([1, 2, 4, 3, 3]) * base_depth 70 | self.output_depth = cs[-1] 71 | 72 | self.stem = nn.Sequential( 73 | spnn.Conv3d(kwargs["in_channels"], cs[0], kernel_size=3, stride=1), 74 | spnn.BatchNorm(cs[0]), 75 | spnn.ReLU(True), 76 | ) 77 | 78 | self.stage1 = nn.Sequential( 79 | BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1), 80 | ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1), 81 | ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1), 82 | ) 83 | 84 | self.stage2 = nn.Sequential( 85 | BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1), 86 | 
ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1), 87 | ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1), 88 | ) 89 | 90 | self.up1 = nn.ModuleList( 91 | [ 92 | BasicDeconvolutionBlock(cs[2], cs[3], ks=2, stride=2), 93 | nn.Sequential( 94 | ResidualBlock(cs[3] + cs[1], cs[3], ks=3, stride=1, dilation=1), 95 | ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1), 96 | ), 97 | ] 98 | ) 99 | 100 | self.up2 = nn.ModuleList( 101 | [ 102 | BasicDeconvolutionBlock(cs[3], cs[4], ks=2, stride=2), 103 | nn.Sequential( 104 | ResidualBlock(cs[4] + cs[0], cs[4], ks=3, stride=1, dilation=1), 105 | ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1), 106 | ), 107 | ] 108 | ) 109 | 110 | self.weight_initialization() 111 | 112 | if self.dropout: 113 | self.dropout = nn.Dropout(0.3, True) 114 | 115 | def weight_initialization(self): 116 | for m in self.modules(): 117 | if isinstance(m, nn.BatchNorm1d): 118 | nn.init.constant_(m.weight, 1) 119 | nn.init.constant_(m.bias, 0) 120 | 121 | def forward(self, x0): 122 | x0 = self.stem(x0) 123 | 124 | x1 = self.stage1(x0) 125 | x2 = self.stage2(x1) 126 | 127 | y3 = self.up1[0](x2) 128 | y3 = torchsparse.cat([y3, x1]) 129 | y3 = self.up1[1](y3) 130 | 131 | y4 = self.up2[0](y3) 132 | y4 = torchsparse.cat([y4, x0]) 133 | y4 = self.up2[1](y4) 134 | 135 | return y4 136 | -------------------------------------------------------------------------------- /cvrecon/collate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchsparse 4 | import torchsparse.utils 5 | 6 | 7 | def sparse_collate_tensors(tensors): 8 | lens = [len(t.C) for t in tensors] 9 | coords = torch.empty((sum(lens), 4), dtype=torch.int32, device=tensors[0].C.device) 10 | prev = 0 11 | for i, n in enumerate(lens): 12 | coords[prev : prev + n, 3] = i 13 | coords[prev : prev + n, :3] = tensors[i].C 14 | prev += n 15 | 16 | feats = torch.cat([t.F for t in tensors], dim=0) 17 | if feats.dtype is not torch.float32: 18 | raise Exception("features should be float32") 19 | return torchsparse.SparseTensor(feats, coords) 20 | 21 | 22 | def sparse_collate_fn(batch): 23 | if isinstance(batch[0], dict): 24 | batch_size = batch.__len__() 25 | ans_dict = {} 26 | for key in batch[0].keys(): 27 | if isinstance(batch[0][key], torchsparse.SparseTensor): 28 | ans_dict[key] = sparse_collate_tensors( 29 | [sample[key] for sample in batch] 30 | ) 31 | elif isinstance(batch[0][key], np.ndarray): 32 | ans_dict[key] = torch.stack( 33 | [torch.from_numpy(sample[key]) for sample in batch], axis=0 34 | ) 35 | elif isinstance(batch[0][key], torch.Tensor): 36 | ans_dict[key] = torch.stack([sample[key] for sample in batch], axis=0) 37 | elif isinstance(batch[0][key], dict): 38 | ans_dict[key] = sparse_collate_fn([sample[key] for sample in batch]) 39 | else: 40 | ans_dict[key] = [sample[key] for sample in batch] 41 | return ans_dict 42 | else: 43 | batch_size = batch.__len__() 44 | ans_dict = tuple() 45 | for i in range(len(batch[0])): 46 | key = batch[0][i] 47 | if isinstance(key, torchsparse.SparseTensor): 48 | ans_dict += (sparse_collate_tensors([sample[i] for sample in batch]),) 49 | elif isinstance(key, np.ndarray): 50 | ans_dict += ( 51 | torch.stack( 52 | [torch.from_numpy(sample[i]) for sample in batch], axis=0 53 | ), 54 | ) 55 | elif isinstance(key, torch.Tensor): 56 | ans_dict += (torch.stack([sample[i] for sample in batch], axis=0),) 57 | elif isinstance(key, dict): 58 | ans_dict += (sparse_collate_fn([sample[i] for sample 
in batch]),) 59 | else: 60 | ans_dict += ([sample[i] for sample in batch],) 61 | return ans_dict 62 | 63 | 64 | if __name__ == "__main__": 65 | batch = [] 66 | for i in range(5): 67 | n = np.random.randint(100) 68 | feats = torch.from_numpy(np.random.randn(n).astype(np.float32)) 69 | coords = torch.from_numpy(np.random.randint(100, size=(n, 3))) 70 | batch.append({"t": torchsparse.SparseTensor(feats, coords)}) 71 | 72 | tensors = [b["t"] for b in batch] 73 | 74 | a = torchsparse.utils.sparse_collate_fn(batch) 75 | b = sparse_collate_fn(batch) 76 | 77 | assert torch.all(a["t"].C == b["t"].C) 78 | assert torch.all(a["t"].F == b["t"].F) 79 | -------------------------------------------------------------------------------- /cvrecon/cvrecon.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | import torchvision 7 | 8 | import torchsparse 9 | import torchsparse.nn.functional as spf 10 | import torch.nn.functional as F 11 | 12 | from cvrecon import cnn2d, cnn3d, mv_fusion, utils, view_direction_encoder, SR_encoder 13 | from cvrecon.cost_volume import ResnetMatchingEncoder, FastFeatureVolumeManager, tensor_B_to_bM, tensor_bM_to_B, TensorFormatter 14 | 15 | 16 | class cvrecon(torch.nn.Module): 17 | def __init__(self, attn_heads, attn_layers, use_proj_occ, SRfeat, SR_vi_ebd, SRCV, use_cost_volume, cv_dim, cv_overall, depth_head): 18 | super().__init__() 19 | self.use_proj_occ = use_proj_occ 20 | self.n_attn_heads = attn_heads 21 | self.resolutions = collections.OrderedDict( 22 | [ 23 | ["coarse", 0.16], 24 | ["medium", 0.08], 25 | ["fine", 0.04], 26 | ] 27 | ) 28 | self.SRfeat = SRfeat 29 | self.SR_vi_ebd = SR_vi_ebd 30 | self.SRCV = SRCV 31 | self.use_cost_volume = use_cost_volume 32 | self.cv_overall = cv_overall 33 | self.cv_dim = cv_dim 34 | SRcha = [256, 128, 64] 35 | self.max_depth = 5.0 36 | self.min_depth = 0.25 37 | 38 | cnn2d_output_depths = [80, 40, 24] 39 | cnn3d_base_depths = [32, 16, 8] 40 | 41 | self.cnn2d = cnn2d.MnasMulti(cnn2d_output_depths, pretrained=True) 42 | self.upsampler = Upsampler() 43 | 44 | self.output_layers = torch.nn.ModuleDict() 45 | self.cnns3d = torch.nn.ModuleDict() 46 | self.view_embedders = torch.nn.ModuleDict() 47 | self.sr_encoder = torch.nn.ModuleDict() 48 | self.layer_norms = torch.nn.ModuleDict() 49 | self.mv_fusion = torch.nn.ModuleDict() 50 | 51 | if self.use_cost_volume: 52 | self.matching_encoder = ResnetMatchingEncoder(18, 16) # ResNet18, CV feature dim = 16 53 | self.cost_volume = FastFeatureVolumeManager( 54 | matching_height=480 // 8, 55 | matching_width=640 // 8, 56 | num_depth_bins=64, 57 | mlp_channels=[202,128,128,cv_dim - 8 if cv_overall else cv_dim], 58 | matching_dim_size=16, 59 | num_source_views=8 - 1 60 | ) 61 | self.cost_volume.load_state_dict(torch.load('cv02.pth'), strict=False) 62 | self.matching_encoder.load_state_dict(torch.load('me.pth')) 63 | 64 | self.tensor_formatter = TensorFormatter() 65 | 66 | self.cv_global_encoder = torch.nn.ModuleDict() 67 | self.unshared_conv = torch.nn.ModuleDict() 68 | for resname, cha in zip(['coarse', 'medium', 'fine'], [80, 40, 24]): 69 | self.cv_global_encoder[resname] = torch.nn.Sequential( 70 | torch.nn.Conv2d(64+cha, cha+64, 3, padding=1, bias=False), 71 | torch.nn.BatchNorm2d(cha+64), 72 | torch.nn.LeakyReLU(0.2, True), 73 | ) 74 | self.unshared_conv[resname] = torch.nn.Conv2d((cha+7)*64, 7*64, 3, padding=1, groups=64) 75 | 76 | 77 | 78 | if depth_head: 79 | 
self.depth_head = torch.nn.Conv2d(48, 1, 1) 80 | self.depth_loss = torch.nn.L1Loss() 81 | else: self.depth_head = False 82 | 83 | prev_output_depth = 0 84 | for i, (resname, res) in enumerate(self.resolutions.items()): 85 | if self.SRfeat: 86 | self.sr_encoder[resname] = SR_encoder.SR_encoder(SRcha[i], cnn2d_output_depths[i]) # 1 by 1 conv to adapt SimpleRecon feat channel to required channel. 87 | self.view_embedders[resname] = view_direction_encoder.ViewDirectionEncoder( # to encode camera viewing ray into 2dCNN feature 88 | cnn2d_output_depths[i], L=4 89 | ) 90 | self.layer_norms[resname] = torch.nn.LayerNorm(cnn2d_output_depths[i]) 91 | 92 | if self.n_attn_heads > 0: 93 | self.mv_fusion[resname] = mv_fusion.MVFusionTransformer( 94 | cnn2d_output_depths[i], attn_layers, self.n_attn_heads, cv_cha=self.cv_dim, 95 | ) 96 | else: 97 | self.mv_fusion[resname] = mv_fusion.MVFusionMean() 98 | 99 | input_depth = prev_output_depth + cnn2d_output_depths[i] 100 | if i > 0: 101 | # additional channel for the previous level's occupancy prediction 102 | input_depth += 1 103 | conv = cnn3d.SPVCNN( 104 | in_channels=input_depth, 105 | base_depth=cnn3d_base_depths[i], 106 | dropout=False, 107 | ) 108 | output_depth = conv.output_depth 109 | self.cnns3d[resname] = conv 110 | self.output_layers[resname] = torchsparse.nn.Conv3d( 111 | output_depth, 1, kernel_size=1, stride=1 112 | ) 113 | prev_output_depth = conv.output_depth 114 | 115 | def get_img_feats(self, rgb_imgs, proj_mats, cam_positions): 116 | batchsize, n_imgs, _, imheight, imwidth = rgb_imgs.shape 117 | feats = self.cnn2d(rgb_imgs.reshape((batchsize * n_imgs, *rgb_imgs.shape[2:]))) 118 | for resname in self.resolutions: 119 | f = feats[resname] 120 | f = self.view_embedders[resname](f, proj_mats[resname], cam_positions) 121 | f = f.reshape((batchsize, n_imgs, *f.shape[1:])) 122 | feats[resname] = f 123 | return feats 124 | 125 | def get_SR_feats(self, batch_SRfeats0, batch_SRfeats1, batch_SRfeats2, proj_mats, cam_positions): 126 | ''' 127 | batch_SRfeats0: [4, 20, 64, 96, 128] 128 | 129 | return: 130 | SR_feats: [4, 20, 80, 30, 40], [4, 20, 40, 60, 80], [4, 20, 24, 120, 160] 131 | ''' 132 | batchsize, n_imgs = batch_SRfeats0.shape[:2] 133 | batch_SRfeats0 = batch_SRfeats0.reshape((batchsize * n_imgs, *batch_SRfeats0.shape[2:])) # [bs*n_imgs, c, h, w] 134 | batch_SRfeats1 = batch_SRfeats1.reshape((batchsize * n_imgs, *batch_SRfeats1.shape[2:])) 135 | batch_SRfeats2 = batch_SRfeats2.reshape((batchsize * n_imgs, *batch_SRfeats2.shape[2:])) 136 | 137 | batch_SRfeats0 = F.interpolate(batch_SRfeats0, [120, 160], mode='bilinear') 138 | batch_SRfeats1 = F.interpolate(batch_SRfeats1, [60, 80], mode='bilinear') 139 | batch_SRfeats2 = F.interpolate(batch_SRfeats2, [30, 40], mode='bilinear') 140 | 141 | feats = {} 142 | feats['fine'] = self.sr_encoder['fine'](batch_SRfeats0) 143 | feats['medium'] = self.sr_encoder['medium'](batch_SRfeats1) 144 | feats['coarse'] = self.sr_encoder['coarse'](batch_SRfeats2) 145 | 146 | if self.SR_vi_ebd: 147 | feats['fine'] = self.view_embedders['fine'](feats['fine'], proj_mats['fine'], cam_positions) 148 | feats['medium'] = self.view_embedders['medium'](feats['medium'], proj_mats['medium'], cam_positions) 149 | feats['coarse'] = self.view_embedders['coarse'](feats['coarse'], proj_mats['coarse'], cam_positions) 150 | 151 | feats['fine'] = feats['fine'].reshape((batchsize, n_imgs, 24, 120, 160)) 152 | feats['medium'] = feats['medium'].reshape((batchsize, n_imgs, 40, 60, 80)) 153 | feats['coarse'] = 
feats['coarse'].reshape((batchsize, n_imgs, 80, 30, 40))
154 | 
155 |         return feats
156 | 
157 |     def compute_matching_feats(
158 |         self,
159 |         all_frames_bm3hw
160 |     ):
161 |         """
162 |         Computes matching features for the current (reference) image and
163 |         the source images.
164 | 
165 |         Unfortunately on this PyTorch branch we've noticed that the output
166 |         of our ResNet matching encoder is not numerically consistent when
167 |         batching. While this doesn't affect training (the changes are too
168 |         small), it does change results and will affect test scores. To combat
169 |         this we disable batching through this module when testing and instead
170 |         loop through images to compute their features. This is stable and
171 |         produces exactly repeatable results.
172 | 
173 |         Args:
174 |             all_frames_bm3hw: image tensor of shape BM3HW holding the
175 |                 reference (current) image followed by its source images
176 |                 (batching is currently always disabled here; the images are
177 |                 encoded in chunks to keep results repeatable).
178 |         Returns:
179 |             matching_feats: tensor of matching features of size BMcHW
180 |                 covering the reference image and the source images,
181 |                 produced by the ResNet matching encoder and reshaped
182 |                 back to the per-view layout.
183 |         """
184 |         if True:
185 |             batch_size, num_views = all_frames_bm3hw.shape[:2]
186 |             all_frames_B3hw = tensor_bM_to_B(all_frames_bm3hw)
187 |             matching_feats = [self.matching_encoder(f)
188 |                             for f in all_frames_B3hw.split(40, dim=0)]
189 | 
190 |             matching_feats = torch.cat(matching_feats, dim=0)
191 |             matching_feats = tensor_B_to_bM(
192 |                 matching_feats,
193 |                 batch_size=batch_size,
194 |                 num_views=num_views,
195 |             )
196 | 
197 |         else:
198 |             # Compute matching features and batch them to reduce variance from
199 |             # batchnorm when training.
200 |             matching_feats = self.tensor_formatter(all_frames_bm3hw,
201 |                 apply_func=self.matching_encoder,
202 |             )
203 | 
204 |         return matching_feats
205 | 
206 | 
207 |     def construct_cv(self, batch, n_imgs):
208 |         cvs = []
209 |         cv_masks = []
210 |         cur_invK = batch["cv_invK"]
211 |         src_K = batch["cv_k"].unsqueeze(1).repeat(1, 7, 1, 1)
212 |         min_depth = torch.tensor(self.min_depth).type_as(src_K).view(1, 1, 1, 1)
213 |         max_depth = torch.tensor(self.max_depth).type_as(src_K).view(1, 1, 1, 1)
214 |         matching_feats = self.compute_matching_feats(batch["rgb_imgs"])
215 |         matching_src_feats = matching_feats[:, n_imgs:].view([-1, n_imgs, 7] + list(matching_feats.shape[2:]))
216 |         inv_poses = batch['inv_pose'][:, n_imgs:].view([-1, n_imgs, 7, 4, 4])
217 |         poses = batch['pose'][:, n_imgs:].view([-1, n_imgs, 7, 4, 4])
218 | 
219 |         for i in range(n_imgs):
220 |             matching_cur_feats = matching_feats[:, i]
221 |             matching_src_feat = matching_src_feats[:, i]
222 | 
223 |             src_cam_T_world = inv_poses[:, i]
224 |             src_world_T_cam = poses[:, i]
225 |             cur_cam_T_world = batch["inv_pose"][:, i, ...]
226 |             cur_world_T_cam = batch["pose"][:, i, ...]
227 |             with torch.cuda.amp.autocast(False):
228 |                 # Compute src_cam_T_cur_cam, a transformation for going from 3D
229 |                 # coords in current view coordinate frame to source view coords
230 |                 # coordinate frames.
231 |                 src_cam_T_cur_cam = src_cam_T_world @ cur_world_T_cam.unsqueeze(1)
232 | 
233 |                 # Compute cur_cam_T_src_cam the opposite of src_cam_T_cur_cam. From
234 |                 # source view to current view.
235 | cur_cam_T_src_cam = cur_cam_T_world.unsqueeze(1) @ src_world_T_cam 236 | 237 | cost_volume, lowest_cost, _, overall_mask_bhw = self.cost_volume( 238 | cur_feats=matching_cur_feats, 239 | src_feats=matching_src_feat, 240 | src_extrinsics=src_cam_T_cur_cam, 241 | src_poses=cur_cam_T_src_cam, 242 | src_Ks=src_K, 243 | cur_invK=cur_invK, 244 | min_depth=min_depth, 245 | max_depth=max_depth, 246 | return_mask=True, 247 | ) 248 | cvs.append(cost_volume.unsqueeze(1)) 249 | cv_masks.append(overall_mask_bhw.unsqueeze(1)) 250 | cvs = torch.cat(cvs, dim=1) # [b, n, c, d, h, w] 251 | cv_masks = torch.cat(cv_masks, dim=1) 252 | if self.cv_overall: 253 | # ############################### skiped overall feat #################################################### 254 | # overallfeat = cvs[:, :, -1:, ::8, ...].permute(0, 1, 3, 2, 4, 5).expand([-1, -1, -1, cvs.shape[3], -1, -1]) 255 | 256 | # ############################### conv overall feat #################################################### 257 | # overallfeat = cvs[:, :, -1, ...].view([-1] + list(cvs.shape[3:])) 258 | # overallfeat = self.cv_global_encoder(overallfeat).view(list(cvs.shape[:2]) + [8, 1, cvs.shape[-2], cvs.shape[-1]]) 259 | # overallfeat = overallfeat.expand([-1, -1, -1, cvs.shape[3], -1, -1]) 260 | 261 | ############################### complete overall feat #################################################### 262 | # overallfeat = cvs[:, :, -1:, :, ...].permute(0, 1, 3, 2, 4, 5).expand([-1, -1, -1, cvs.shape[3], -1, -1]) 263 | 264 | # # cvs = cvs[:, :, :-1, ...] 265 | # cvs = torch.cat([overallfeat, cvs], dim=2) 266 | pass 267 | return cvs, cv_masks 268 | 269 | 270 | def forward(self, batch, voxel_inds_16): 271 | bs, n_imgs = batch['depth_imgs'].shape[:2] 272 | if self.use_cost_volume: 273 | cost_volume, cv_masks = self.construct_cv(batch, n_imgs) 274 | batch['rgb_imgs'] = batch['rgb_imgs'][:, :n_imgs] 275 | for b in range(bs): 276 | cost_volume[b][batch['cv_invalid_mask'][b].bool()] = 0 277 | 278 | if self.SRfeat: 279 | feats_2d = self.get_SR_feats(batch["SRfeat0"], batch["SRfeat1"], batch["SRfeat2"] 280 | , batch["proj_mats"], batch["cam_positions"]) 281 | else: 282 | feats_2d = self.get_img_feats( 283 | batch["rgb_imgs"], batch["proj_mats"], batch["cam_positions"] 284 | ) 285 | 286 | if not self.depth_head: depth_out = None 287 | 288 | device = voxel_inds_16.device 289 | proj_occ_logits = {} 290 | voxel_outputs = {} 291 | bp_data = {} 292 | n_subsample = { 293 | "medium": 2 ** 14, 294 | "fine": 2 ** 16, 295 | } 296 | 297 | voxel_inds = voxel_inds_16 298 | voxel_features = torch.empty( 299 | (len(voxel_inds), 0), dtype=feats_2d["coarse"].dtype, device=device 300 | ) 301 | voxel_logits = torch.empty( 302 | (len(voxel_inds), 0), dtype=feats_2d["coarse"].dtype, device=device 303 | ) 304 | for resname, res in self.resolutions.items(): 305 | if self.training and resname in n_subsample: # subsample voxels. 
306 | # this saves memory and possibly acts as a data augmentation 307 | subsample_inds = get_subsample_inds(voxel_inds, n_subsample[resname]) 308 | voxel_inds = voxel_inds[subsample_inds] 309 | voxel_features = voxel_features[subsample_inds] 310 | voxel_logits = voxel_logits[subsample_inds] 311 | 312 | voxel_batch_inds = voxel_inds[:, 3].long() 313 | voxel_coords = voxel_inds[:, :3] * res + batch["origin"][voxel_batch_inds] # convert to unit of meters 314 | 315 | featheight, featwidth = feats_2d[resname].shape[-2:] 316 | 317 | feat_cha = {'coarse': 80, 'medium': 40, 'fine':24} 318 | cv_dim = self.cv_dim - 8 319 | if resname != 'medium': 320 | cur_cost_volume = F.interpolate(cost_volume.view([bs*n_imgs*cv_dim, 64, 60, 80]), [featheight, featwidth]).view([bs, n_imgs, cv_dim, 64, featheight, featwidth]) 321 | else: cur_cost_volume = cost_volume.clone() 322 | feats_2d[resname] = self.cv_global_encoder[resname](torch.cat([cur_cost_volume[:,:,-1], feats_2d[resname]], dim=2).view([-1, feat_cha[resname]+64, featheight, featwidth])) 323 | feats_2d[resname] = feats_2d[resname].view([bs, n_imgs, feat_cha[resname]+64, featheight, featwidth]) 324 | overallfeat = feats_2d[resname][:, :, :64].unsqueeze(3).expand([-1, -1, -1, 64, -1, -1]) 325 | feats_2d[resname] = feats_2d[resname][:,:,64:] 326 | # for d in range(64): 327 | # cur_cost_volume[:,:,:,d] = self.unshared_conv[resname][d]( 328 | # torch.cat([cur_cost_volume[:,:,:,d], feats_2d[resname]], dim=2).view([-1, feat_cha[resname]+7, featheight, featwidth]) 329 | # ).view([bs, n_imgs, 7, featheight, featwidth]) 330 | cur_cost_volume = self.unshared_conv[resname]( 331 | torch.cat([feats_2d[resname].unsqueeze(3).expand([-1,-1,-1,64,-1,-1]), cur_cost_volume], dim=2).transpose(2,3).reshape(bs*n_imgs,-1,featheight, featwidth)) 332 | cur_cost_volume = cur_cost_volume.view([bs, n_imgs, 64, 7, featheight, featwidth]).transpose(2,3) 333 | 334 | cur_cost_volume = torch.cat([overallfeat, cur_cost_volume], dim=2) 335 | 336 | bp_uv, bp_depth, bp_mask = self.project_voxels( # project voxels to each image plane 337 | voxel_coords, 338 | voxel_batch_inds, 339 | batch["proj_mats"][resname].transpose(0, 1), 340 | featheight, 341 | featwidth, 342 | ) 343 | bp_data[resname] = { 344 | "voxel_coords": voxel_coords, 345 | "voxel_batch_inds": voxel_batch_inds, 346 | "bp_uv": bp_uv, 347 | "bp_depth": bp_depth, 348 | "bp_mask": bp_mask, 349 | } 350 | bp_feats, cur_proj_occ_logits = self.back_project_features( # put 2dCNN features into voxels. 351 | bp_data[resname], 352 | feats_2d[resname].transpose(0, 1), 353 | self.mv_fusion[resname], 354 | cur_cost_volume if (self.SRCV or self.use_cost_volume) else None, 355 | cv_masks if self.use_cost_volume else None, 356 | ) 357 | proj_occ_logits[resname] = cur_proj_occ_logits 358 | 359 | bp_feats = self.layer_norms[resname](bp_feats) 360 | 361 | voxel_features = torch.cat((voxel_features, bp_feats, voxel_logits), dim=-1) # not understood !!!! 
362 | voxel_features = torchsparse.SparseTensor(voxel_features, voxel_inds) 363 | try: 364 | voxel_features = self.cnns3d[resname](voxel_features) 365 | except Exception as e: 366 | print(e) 367 | return voxel_outputs, proj_occ_logits, bp_data, depth_out 368 | 369 | voxel_logits = self.output_layers[resname](voxel_features) 370 | voxel_outputs[resname] = voxel_logits 371 | 372 | if resname in ["coarse", "medium"]: 373 | # sparsify & upsample 374 | occupancy = voxel_logits.F.squeeze(1) > 0 375 | if not torch.any(occupancy): 376 | return voxel_outputs, proj_occ_logits, bp_data, depth_out 377 | voxel_features = self.upsampler.upsample_feats( 378 | voxel_features.F[occupancy] 379 | ) 380 | voxel_inds = self.upsampler.upsample_inds(voxel_logits.C[occupancy]) 381 | voxel_logits = self.upsampler.upsample_feats(voxel_logits.F[occupancy]) 382 | 383 | return voxel_outputs, proj_occ_logits, bp_data, depth_out 384 | 385 | def losses(self, voxel_logits, voxel_gt, proj_occ_logits, bp_data, depth_imgs, depth_out): 386 | voxel_losses = {} 387 | proj_occ_losses = {} 388 | for resname in voxel_logits: 389 | logits = voxel_logits[resname] 390 | gt = voxel_gt[resname] 391 | cur_loss = torch.zeros(1, device=logits.F.device, dtype=torch.float32) 392 | if len(logits.C) > 0: 393 | pred_hash = spf.sphash(logits.C) 394 | gt_hash = spf.sphash(gt.C) 395 | idx_query = spf.sphashquery(pred_hash, gt_hash) 396 | good_query = idx_query != -1 397 | gt = gt.F[idx_query[good_query]] 398 | logits = logits.F.squeeze(1)[good_query] 399 | if len(logits) > 0: 400 | if resname == "fine": 401 | cur_loss = torch.nn.functional.l1_loss( 402 | utils.log_transform(1.05 * torch.tanh(logits)), 403 | utils.log_transform(gt), 404 | ) 405 | else: 406 | cur_loss = torch.nn.functional.binary_cross_entropy_with_logits( 407 | logits, gt 408 | ) 409 | voxel_losses[resname] = cur_loss 410 | 411 | proj_occ_losses[resname] = compute_proj_occ_loss( 412 | proj_occ_logits[resname], 413 | depth_imgs, 414 | bp_data[resname], 415 | truncation_distance=3 * self.resolutions[resname], 416 | ) 417 | 418 | loss = sum(voxel_losses.values()) + sum(proj_occ_losses.values()) 419 | logs = { 420 | **{ 421 | f"voxel_loss_{resname}": voxel_losses[resname].detach() 422 | for resname in voxel_losses 423 | }, 424 | **{ 425 | f"proj_occ_loss_{resname}": proj_occ_losses[resname].detach() 426 | for resname in proj_occ_losses 427 | }, 428 | } 429 | 430 | if depth_out is not None: 431 | bs, n_imgs = depth_imgs.shape[:2] 432 | depth_out = F.interpolate(depth_out, [480, 640], mode="bilinear", align_corners=False,).view([bs, n_imgs, 480, 640]).float() 433 | mask = ((depth_imgs > 0.001) & (depth_imgs < 10)) 434 | depth_loss = self.depth_loss(depth_out[mask], torch.log(depth_imgs)[mask]) 435 | loss += depth_loss 436 | logs.update({'depth_loss': depth_loss.detach()}) 437 | 438 | return loss, logs 439 | 440 | def project_voxels( 441 | self, voxel_coords, voxel_batch_inds, projmat, imheight, imwidth 442 | ): 443 | device = voxel_coords.device 444 | n_voxels = len(voxel_coords) 445 | n_imgs = len(projmat) 446 | bp_uv = torch.zeros((n_imgs, n_voxels, 2), device=device, dtype=torch.float32) 447 | bp_depth = torch.zeros((n_imgs, n_voxels), device=device, dtype=torch.float32) 448 | bp_mask = torch.zeros((n_imgs, n_voxels), device=device, dtype=torch.bool) 449 | batch_inds = torch.unique(voxel_batch_inds) 450 | for batch_ind in batch_inds: 451 | batch_mask = voxel_batch_inds == batch_ind 452 | if torch.sum(batch_mask) == 0: 453 | continue 454 | cur_voxel_coords = voxel_coords[batch_mask] 
455 | 456 | ones = torch.ones( 457 | (len(cur_voxel_coords), 1), device=device, dtype=torch.float32 458 | ) 459 | voxel_coords_h = torch.cat((cur_voxel_coords, ones), dim=-1) 460 | 461 | im_p = projmat[:, batch_ind] @ voxel_coords_h.t() 462 | im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2] 463 | im_x = im_x / im_z 464 | im_y = im_y / im_z 465 | im_grid = torch.stack( 466 | [2 * im_x / (imwidth - 1) - 1, 2 * im_y / (imheight - 1) - 1], 467 | dim=-1, 468 | ) 469 | im_grid[torch.isinf(im_grid)] = -2 470 | mask = im_grid.abs() <= 1 471 | mask = (mask.sum(dim=-1) == 2) & (im_z > 0) 472 | 473 | bp_uv[:, batch_mask] = im_grid.to(bp_uv.dtype) 474 | bp_depth[:, batch_mask] = im_z.to(bp_uv.dtype) 475 | bp_mask[:, batch_mask] = mask 476 | 477 | return bp_uv, bp_depth, bp_mask 478 | 479 | def back_project_features(self, bp_data, feats, mv_fuser, SRCV=None, cv_masks=None): 480 | n_imgs, batch_size, in_channels, featheight, featwidth = feats.shape 481 | device = feats.device 482 | n_voxels = len(bp_data["voxel_batch_inds"]) 483 | feature_volume_all = torch.zeros( 484 | n_voxels, in_channels, device=device, dtype=torch.float32 485 | ) 486 | # the default proj occ prediction is true everywhere -> logits high 487 | proj_occ_logits = torch.full( 488 | (n_imgs, n_voxels), 100, device=device, dtype=feats.dtype 489 | ) 490 | batch_inds = torch.unique(bp_data["voxel_batch_inds"]) 491 | for batch_ind in batch_inds: 492 | batch_mask = bp_data["voxel_batch_inds"] == batch_ind 493 | if torch.sum(batch_mask) == 0: 494 | continue 495 | 496 | cur_bp_uv = bp_data["bp_uv"][:, batch_mask] # [n_imgs, n_voxels, 2] 497 | cur_bp_depth = bp_data["bp_depth"][:, batch_mask] # [n_imgs, n_voxels] 498 | cur_bp_mask = bp_data["bp_mask"][:, batch_mask] 499 | cur_feats = feats[:, batch_ind].view( 500 | n_imgs, in_channels, featheight, featwidth 501 | ) 502 | 503 | ################# for normal ########################### 504 | # cur_bp_uv = cur_bp_uv.view(n_imgs, 1, -1, 2) # [n_imgs, 1, n_voxels, 2] 505 | # features = torch.nn.functional.grid_sample( 506 | # cur_feats, 507 | # cur_bp_uv.to(cur_feats.dtype), 508 | # padding_mode="reflection", 509 | # align_corners=True, 510 | # ) 511 | # # features [n_imgs, in_channels, 1, n_voxels] 512 | # features = features.view(n_imgs, in_channels, -1) # [n_imgs, in_channels, n_voxels] 513 | ####################################################### 514 | if SRCV is not None: 515 | cur_bp_d = ((torch.log(cur_bp_depth) - torch.log(torch.tensor(self.min_depth))) / torch.log(torch.tensor(self.max_depth/self.min_depth))) * 2.0 - 1.0 516 | cur_bp_d.nan_to_num_(nan = -1) # negative depth will cause nan in log 517 | cur_bp_uv3d = bp_data["bp_uv"][:, batch_mask] 518 | cur_bp_uvd = torch.cat([cur_bp_uv3d, cur_bp_d[...,None]], dim=-1).unsqueeze(1).unsqueeze(1) # [n_imgs, 1, 1, n_voxels, 3] 519 | 520 | ''' cv mask before grid sample, depreciated 521 | cur_srcv = torch.zeros_like(SRCV[batch_ind]) # [n_imgs, ch(128), d(64), h(60), w(80)] 522 | cv_mask = cv_masks[batch_ind][:,None,None].expand(cur_srcv.shape) 523 | # cur_srcv[~cv_mask] = 100 524 | cur_srcv[cv_mask] = SRCV[batch_ind][cv_mask] 525 | ''' 526 | 527 | 528 | # cv_mask = ~cv_masks[batch_ind][:,None].detach() # invalid costs in the cost_volume 529 | # cv_mask = torch.nn.functional.grid_sample( 530 | # cv_mask.to(cur_feats.dtype), 531 | # cur_bp_uv.to(cur_feats.dtype), 532 | # padding_mode="zeros", 533 | # align_corners=True, 534 | # ) # all voxels that are polluted by the invalid costs 535 | 536 | cur_srcv = SRCV[batch_ind] 537 | features_cv = 
torch.nn.functional.grid_sample( 538 | cur_srcv, 539 | cur_bp_uvd.to(cur_srcv.dtype), 540 | padding_mode="zeros", 541 | align_corners=True, 542 | ) 543 | # features_cv [20, C, 1, 1, 6912] 544 | c = features_cv.shape[1] 545 | features_cv = features_cv.view(n_imgs, c, -1) # [n_imgs, C, n_voxels] 546 | 547 | # ################################# concat mask ################################ remember to change depth mlp channel 548 | # cv_mask = cv_mask.view(n_imgs, 1, -1).detach() 549 | # features_cv = torch.cat([cv_mask, features_cv], dim=1) 550 | # ################################################################################### 551 | 552 | # ################################# grid sample mask ################################ 553 | # cv_mask = (cv_mask > 0).view(n_imgs, 1, -1).expand(features_cv.shape).detach() 554 | # features_cv[cv_mask] = 0 555 | # ################################################################################### 556 | 557 | ######################################## atten mask ################################ 558 | #cv_mask = cv_mask.squeeze() > 0 559 | #cur_bp_mask[cv_mask] = False 560 | #################################################################################### 561 | 562 | # features = torch.cat([features, features_cv], dim=1) 563 | features = features_cv 564 | # features[:, -c:, :] = features_cv 565 | 566 | if isinstance(mv_fuser, mv_fusion.MVFusionTransformer): 567 | pooled_features, cur_proj_occ_logits = mv_fuser( 568 | features, 569 | cur_bp_depth, 570 | cur_bp_mask, 571 | self.use_proj_occ, 572 | ) 573 | feature_volume_all[batch_mask] = pooled_features 574 | proj_occ_logits[:, batch_mask] = cur_proj_occ_logits 575 | else: 576 | pooled_features = mv_fuser(features.transpose(1, 2), cur_bp_mask) 577 | feature_volume_all[batch_mask] = pooled_features 578 | 579 | return (feature_volume_all, proj_occ_logits) 580 | 581 | 582 | class Upsampler(torch.nn.Module): 583 | # nearest neighbor 2x upsampling for sparse 3D array 584 | 585 | def __init__(self): 586 | super().__init__() 587 | self.upsample_offsets = torch.nn.Parameter( 588 | torch.Tensor( 589 | [ 590 | [ 591 | [0, 0, 0, 0], 592 | [1, 0, 0, 0], 593 | [0, 1, 0, 0], 594 | [0, 0, 1, 0], 595 | [1, 1, 0, 0], 596 | [0, 1, 1, 0], 597 | [1, 0, 1, 0], 598 | [1, 1, 1, 0], 599 | ] 600 | ] 601 | ).to(torch.int32), 602 | requires_grad=False, 603 | ) 604 | self.upsample_mul = torch.nn.Parameter( 605 | torch.Tensor([[[2, 2, 2, 1]]]).to(torch.int32), requires_grad=False 606 | ) 607 | 608 | def upsample_inds(self, voxel_inds): 609 | return ( 610 | voxel_inds[:, None] * self.upsample_mul + self.upsample_offsets 611 | ).reshape(-1, 4) 612 | 613 | def upsample_feats(self, feats): 614 | return ( 615 | feats[:, None] 616 | .repeat(1, 8, 1) 617 | .reshape(-1, feats.shape[-1]) 618 | .to(torch.float32) 619 | ) 620 | 621 | 622 | def get_subsample_inds(coords, max_per_example): 623 | keep_inds = [] 624 | batch_inds = coords[:, 3].unique() 625 | for batch_ind in batch_inds: 626 | batch_mask = coords[:, -1] == batch_ind 627 | n = torch.sum(batch_mask) 628 | if n > max_per_example: 629 | keep_inds.append(batch_mask.float().multinomial(max_per_example)) 630 | else: 631 | keep_inds.append(torch.where(batch_mask)[0]) 632 | subsample_inds = torch.cat(keep_inds).long() 633 | return subsample_inds 634 | 635 | 636 | def compute_proj_occ_loss(proj_occ_logits, depth_imgs, bp_data, truncation_distance): 637 | batch_inds = torch.unique(bp_data["voxel_batch_inds"]) 638 | for batch_ind in batch_inds: 639 | batch_mask = bp_data["voxel_batch_inds"] == 
batch_ind 640 | cur_bp_uv = bp_data["bp_uv"][:, batch_mask] 641 | cur_bp_depth = bp_data["bp_depth"][:, batch_mask] 642 | cur_bp_mask = bp_data["bp_mask"][:, batch_mask] 643 | cur_proj_occ_logits = proj_occ_logits[:, batch_mask] 644 | 645 | depth = torch.nn.functional.grid_sample( 646 | depth_imgs[batch_ind, :, None], 647 | cur_bp_uv[:, None].to(depth_imgs.dtype), 648 | padding_mode="zeros", 649 | mode="nearest", 650 | align_corners=False, 651 | )[:, 0, 0] 652 | 653 | proj_occ_mask = cur_bp_mask & (depth > 0) 654 | if torch.sum(proj_occ_mask) > 0: 655 | proj_occ_gt = torch.abs(cur_bp_depth - depth) < truncation_distance 656 | return torch.nn.functional.binary_cross_entropy_with_logits( 657 | cur_proj_occ_logits[proj_occ_mask], 658 | proj_occ_gt[proj_occ_mask].to(cur_proj_occ_logits.dtype), 659 | ) 660 | else: 661 | return torch.zeros((), dtype=torch.float32, device=depth_imgs.device) 662 | -------------------------------------------------------------------------------- /cvrecon/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | 5 | import imageio 6 | import numpy as np 7 | import PIL.Image 8 | import skimage.morphology 9 | import torch 10 | import torchsparse 11 | import torchvision 12 | from collections import defaultdict 13 | 14 | from cvrecon import utils 15 | 16 | 17 | img_mean_rgb = np.array([127.71, 114.66, 99.32], dtype=np.float32) 18 | img_std_rgb = np.array([75.31, 73.53, 71.66], dtype=np.float32) 19 | 20 | 21 | def load_tsdf(tsdf_dir, scene_name): 22 | tsdf_fname = os.path.join(tsdf_dir, scene_name, "full_tsdf_layer0.npz") 23 | with np.load(tsdf_fname) as tsdf_04_npz: 24 | tsdf = tsdf_04_npz["arr_0"] 25 | 26 | pkl_fname = os.path.join(tsdf_dir, scene_name, "tsdf_info.pkl") 27 | with open(pkl_fname, "rb") as tsdf_pkl: 28 | tsdf_info = pickle.load(tsdf_pkl) 29 | origin = tsdf_info['vol_origin'] 30 | voxel_size = tsdf_info['voxel_size'] 31 | 32 | return tsdf, origin, voxel_size 33 | 34 | 35 | def reflect_pose(pose, plane_pt=None, plane_normal=None): 36 | pts = pose @ np.array( 37 | [ 38 | [1, 0, 0, 0], 39 | [0, 1, 0, 0], 40 | [0, 0, 1, 0], 41 | [1, 1, 1, 1], 42 | ], 43 | dtype=np.float32, 44 | ) 45 | plane_pt = np.array([*plane_pt, 1], dtype=np.float32) 46 | 47 | pts = pts - plane_pt[None, :, None] 48 | 49 | plane_normal = plane_normal / np.linalg.norm(plane_normal) 50 | m = np.zeros((4, 4), dtype=np.float32) 51 | m[:3, :3] = np.eye(3) - 2 * plane_normal[None].T @ plane_normal[None] 52 | 53 | pts = m @ pts + plane_pt[None, :, None] 54 | 55 | result = np.eye(4, dtype=np.float32)[None].repeat(len(pose), axis=0) 56 | result[:, :, :3] = pts[:, :, :3] - pts[:, :, 3:] 57 | result[:, :, 3] = pts[:, :, 3] 58 | return result 59 | 60 | 61 | def get_proj_mats(intr, pose, factors): 62 | k = np.eye(4, dtype=np.float32) 63 | k[:3, :3] = intr 64 | k[0] = k[0] * factors[0] 65 | k[1] = k[1] * factors[0] 66 | proj_lowres = k @ pose 67 | 68 | k = np.eye(4, dtype=np.float32) 69 | k[:3, :3] = intr 70 | k[0] = k[0] * factors[1] 71 | k[1] = k[1] * factors[1] 72 | proj_midres = k @ pose 73 | 74 | k = np.eye(4, dtype=np.float32) 75 | k[:3, :3] = intr 76 | k[0] = k[0] * factors[2] 77 | k[1] = k[1] * factors[2] 78 | proj_highres = k @ pose 79 | 80 | k = np.eye(4, dtype=np.float32) 81 | k[:3, :3] = intr 82 | proj_depth = k @ pose 83 | 84 | return { 85 | "coarse": proj_lowres, 86 | "medium": proj_midres, 87 | "fine": proj_highres, 88 | "fullres": proj_depth, 89 | } 90 | 91 | 92 | def load_rgb_imgs(rgb_imgfiles, imheight, 
imwidth, augment=False): 93 | if augment: 94 | transforms = [ 95 | ( 96 | torchvision.transforms.functional.adjust_brightness, 97 | np.random.uniform(0.5, 1.5), 98 | ), 99 | ( 100 | torchvision.transforms.functional.adjust_contrast, 101 | np.random.uniform(0.5, 1.5), 102 | ), 103 | ( 104 | torchvision.transforms.functional.adjust_hue, 105 | np.random.uniform(-0.05, 0.05), 106 | ), 107 | ( 108 | torchvision.transforms.functional.adjust_saturation, 109 | np.random.uniform(0.5, 1.5), 110 | ), 111 | ( 112 | torchvision.transforms.functional.gaussian_blur, 113 | 7, 114 | np.random.randint(1, 4), 115 | ), 116 | ] 117 | transforms = [ 118 | transforms[i] 119 | for i in np.random.choice(len(transforms), size=2, replace=False) 120 | ] 121 | rgb_imgs = np.empty((len(rgb_imgfiles), imheight, imwidth, 3), dtype=np.float32) 122 | for i, f in enumerate(rgb_imgfiles): 123 | img = PIL.Image.open(f) 124 | if augment: 125 | for t, *params in transforms: 126 | img = t(img, *params) 127 | rgb_imgs[i] = img 128 | 129 | rgb_imgs -= img_mean_rgb 130 | rgb_imgs /= img_std_rgb 131 | rgb_imgs = np.transpose(rgb_imgs, (0, 3, 1, 2)) 132 | return rgb_imgs 133 | 134 | 135 | def load_SRfeats(scene, frame_inds): 136 | ''' 137 | [64, 96, 128], [128, 48, 64], [256, 24, 32] 138 | ''' 139 | # if scene == 'scene0230_00': 140 | # import pdb; pdb.set_trace() 141 | scale0 = np.empty((len(frame_inds), 64, 96, 128), dtype=np.float32) 142 | scale1 = np.empty((len(frame_inds), 128, 48, 64), dtype=np.float32) 143 | scale2 = np.empty((len(frame_inds), 256, 24, 32), dtype=np.float32) 144 | 145 | for i, frame_ind in enumerate(frame_inds): 146 | fname = '/ScanNet2/SRfeats/' + scene + '/' + frame_ind 147 | if os.path.exists(fname + '_s0.npy'): 148 | scale0[i] = np.load(fname + '_s0.npy') 149 | scale1[i] = np.load(fname + '_s1.npy') 150 | scale2[i] = np.load(fname + '_s2.npy') 151 | else: 152 | scale0[i] = np.zeros((64, 96, 128), dtype=np.float32) 153 | scale1[i] = np.zeros((128, 48, 64), dtype=np.float32) 154 | scale2[i] = np.zeros((256, 24, 32), dtype=np.float32) 155 | return [scale0, scale1, scale2] 156 | 157 | 158 | def load_SRCV(scene, frame_inds): 159 | ''' 160 | [64, 96, 128] 161 | ''' 162 | CV = np.empty((len(frame_inds), 64, 96, 128), dtype=np.float32) 163 | 164 | for i, frame_ind in enumerate(frame_inds): 165 | fname = '/ScanNet2/SRCV/' + scene + '/' + frame_ind 166 | if os.path.exists(fname + '_cv.npy'): 167 | CV[i] = np.load(fname + '_cv.npy') 168 | else: 169 | CV[i] = np.zeros((64, 96, 128), dtype=np.float32) 170 | return CV 171 | 172 | 173 | def pose_distance(pose_b44): 174 | """ 175 | DVMVS frame pose distance. 
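
    For each 4x4 pose this computes a rotation measure
    R_m = sqrt(2 * (1 - min(trace(R), 3) / 3)), a translation measure
    t_m = ||t||, and the combined measure sqrt(t_m**2 + R_m**2), following
    the DeepVideoMVS (DVMVS) keyframe-selection heuristic; all three
    measures are returned.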
176 | """ 177 | 178 | R = pose_b44[:, :3, :3] 179 | t = pose_b44[:, :3, 3] 180 | R_trace = R.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1) 181 | R_measure = torch.sqrt(2 * 182 | (1 - torch.minimum(torch.ones_like(R_trace)*3.0, R_trace) / 3)) 183 | t_measure = torch.norm(t, dim=1) 184 | combined_measure = torch.sqrt(t_measure ** 2 + R_measure ** 2) 185 | 186 | return combined_measure, R_measure, t_measure 187 | 188 | 189 | class Dataset(torch.utils.data.Dataset): 190 | def __init__( 191 | self, 192 | info_files, 193 | tsdf_dir, 194 | n_imgs, 195 | cropsize, 196 | augment=True, 197 | load_extra=False, 198 | split=None, 199 | SRfeat=False, 200 | SRCV=False, 201 | cost_volume=False, 202 | ): 203 | self.info_files = info_files 204 | self.n_imgs = n_imgs 205 | self.cropsize = np.array(cropsize) 206 | self.augment = augment 207 | self.load_extra = load_extra 208 | self.tsdf_dir = tsdf_dir 209 | 210 | self.tmin = 0.1 211 | self.rmin_deg = 15 212 | 213 | self.SRfeat = SRfeat 214 | self.SRCV = SRCV 215 | self.cost_volume = cost_volume 216 | if self.SRfeat or self.SRCV or self.cost_volume: 217 | self.CVDicts = defaultdict(dict) 218 | fname = 'data_splits/ScanNetv2/standard_split/{}_for_cvrecon.txt'.format(split) 219 | if split == 'test': 220 | fname = 'data_splits/ScanNetv2/standard_split/test_eight_view_deepvmvs_dense.txt' 221 | with open(fname, 'r') as f: 222 | lines = f.read().splitlines() 223 | for line in lines: 224 | scan_id, *frame_id = line.split(" ") 225 | self.CVDicts[scan_id][frame_id[0]] = frame_id[1:] 226 | 227 | 228 | 229 | def __len__(self): 230 | return len(self.info_files) 231 | 232 | def getitem(self, ind, **kwargs): 233 | return self.__getitem__(ind, **kwargs) 234 | 235 | def __getitem__(self, ind): # ind is the index of the scene 236 | with open(self.info_files[ind], "r") as f: 237 | info = json.load(f) 238 | 239 | scene_name = info["scene"] 240 | tsdf_04, origin, _ = load_tsdf(self.tsdf_dir, scene_name) 241 | 242 | rgb_imgfiles = [frame["filename_color"] for frame in info["frames"]] 243 | depth_imgfiles = [frame["filename_depth"] for frame in info["frames"]] 244 | pose = np.empty((len(info["frames"]), 4, 4), dtype=np.float32) 245 | for i, frame in enumerate(info["frames"]): 246 | pose[i] = frame["pose"] 247 | intr = np.array(info["intrinsics"], dtype=np.float32) 248 | 249 | test_img = imageio.imread(rgb_imgfiles[0]) 250 | imheight, imwidth, _ = test_img.shape 251 | 252 | assert not np.any(np.isinf(pose) | np.isnan(pose)) 253 | 254 | seen_coords = np.argwhere(np.abs(tsdf_04) < 0.999) * 0.04 + origin 255 | i = np.random.randint(len(seen_coords)) 256 | anchor_pt = seen_coords[i] # anchor of the current fragment 257 | offset = np.array( 258 | [ 259 | np.random.uniform(0.04, self.cropsize[0] * 0.04 - 0.04), 260 | np.random.uniform(0.04, self.cropsize[1] * 0.04 - 0.04), 261 | np.random.uniform(0.04, self.cropsize[2] * 0.04 - 0.04), 262 | ] 263 | ) 264 | minbound = anchor_pt - offset 265 | maxbound = minbound + self.cropsize.astype(np.float32) * 0.04 266 | 267 | # the GT TSDF will be sampled at these points 268 | x = np.arange(minbound[0], maxbound[0], .04, dtype=np.float32) 269 | y = np.arange(minbound[1], maxbound[1], .04, dtype=np.float32) 270 | z = np.arange(minbound[2], maxbound[2], .04, dtype=np.float32) 271 | x = x[: self.cropsize[0]] 272 | y = y[: self.cropsize[0]] 273 | z = z[: self.cropsize[0]] 274 | yy, xx, zz = np.meshgrid(y, x, z) 275 | sample_pts = np.stack([xx, yy, zz], axis=-1) # meter as unit, global coordinate. 
276 | 277 | flip = False 278 | if self.augment: 279 | center = np.zeros((4, 4), dtype=np.float32) 280 | center[:3, 3] = anchor_pt 281 | 282 | # rotate 283 | t = np.random.uniform(0, 2 * np.pi) 284 | R = np.array( 285 | [ 286 | [np.cos(t), -np.sin(t), 0, 0], 287 | [np.sin(t), np.cos(t), 0, 0], 288 | [0, 0, 1, 0], 289 | [0, 0, 0, 1], 290 | ], 291 | dtype=np.float32, 292 | ) 293 | 294 | shape = sample_pts.shape 295 | sample_pts = ( 296 | R[:3, :3] @ (sample_pts.reshape(-1, 3) - center[:3, 3]).T 297 | ).T + center[:3, 3] 298 | sample_pts = sample_pts.reshape(shape) 299 | 300 | # flip 301 | if np.random.uniform() > 0.5: 302 | flip = True 303 | sample_pts[..., 0] = -(sample_pts[..., 0] - center[0, 3]) + center[0, 3] 304 | 305 | selected_frame_inds = np.array( 306 | utils.remove_redundant(pose, self.rmin_deg, self.tmin) # remove redundant (too small pose changes) using the NeuralRecon's strategy 307 | ) 308 | 309 | ############ remove frames that are not in SRlist 310 | if self.SRfeat or self.SRCV or self.cost_volume: 311 | SRlist_inds = [] 312 | for frame_ind in selected_frame_inds: 313 | if '0' + info['frames'][frame_ind]['filename_color'][-9:-4] in self.CVDicts[scene_name]: 314 | SRlist_inds.append(frame_ind) 315 | selected_frame_inds = SRlist_inds 316 | ################################################### 317 | 318 | if self.n_imgs is not None: 319 | if len(selected_frame_inds) < self.n_imgs: 320 | # after redundant frame removal we can end up with too few frames-- 321 | # add some back in 322 | # print('!!!!!!!!!!!!!!!!!', len(selected_frame_inds), scene_name) 323 | avail_inds = list(set(np.arange(len(pose))) - set(selected_frame_inds)) 324 | n_needed = self.n_imgs - len(selected_frame_inds) 325 | extra_inds = np.random.choice(avail_inds, size=n_needed, replace=False) 326 | selected_frame_inds = np.concatenate((selected_frame_inds, extra_inds)) 327 | elif len(selected_frame_inds) == self.n_imgs: 328 | ... 329 | else: 330 | # after redundant frame removal we still have more than the target # images-- 331 | # remove even more. 
332 | pose = pose[selected_frame_inds] 333 | rgb_imgfiles = [rgb_imgfiles[i] for i in selected_frame_inds] 334 | depth_imgfiles = [depth_imgfiles[i] for i in selected_frame_inds] 335 | 336 | selected_frame_inds, score = utils.frame_selection( # First remove frames that has no intersection with the current fragment, then random select n_imgs (20 for train and val) 337 | pose, 338 | intr, 339 | imwidth, 340 | imheight, 341 | sample_pts.reshape(-1, 3)[::100], # every 100th pt for efficiency 342 | self.tmin, 343 | self.rmin_deg, 344 | self.n_imgs, 345 | ) 346 | pose = pose[selected_frame_inds] 347 | rgb_imgfiles = [rgb_imgfiles[i] for i in selected_frame_inds] 348 | depth_imgfiles = [depth_imgfiles[i] for i in selected_frame_inds] 349 | 350 | if self.cost_volume: 351 | cv_invalid_mask = np.zeros(len(rgb_imgfiles), dtype=np.int) 352 | frame2id = {'0'+frame["filename_color"][-9:-4]:i for i, frame in enumerate(info["frames"])} 353 | for i, fname in enumerate(rgb_imgfiles.copy()): 354 | if '0' + fname[-9: -4] in self.CVDicts[scene_name]: 355 | for frameid in self.CVDicts[scene_name]['0' + fname[-9: -4]]: 356 | pose = np.concatenate((pose, np.array(info['frames'][frame2id[frameid]]['pose'], dtype=np.float32)[None,...])) 357 | rgb_imgfiles.append(info['frames'][frame2id[frameid]]['filename_color']) 358 | else: 359 | cv_invalid_mask[i] = 1 360 | pose = np.concatenate((pose, np.array(info['frames'][frame2id['0'+fname[-9: -4]]]['pose'], dtype=np.float32)[None,...].repeat(7,0))) 361 | rgb_imgfiles += [fname] * 7 362 | 363 | if self.augment: 364 | pose = np.linalg.inv(R) @ (pose - center) + center 365 | if flip: 366 | pose = reflect_pose( 367 | pose, 368 | plane_pt=center[:3, 3], 369 | plane_normal=-np.array( 370 | [np.cos(-t), np.sin(-t), 0], dtype=np.float32 371 | ), 372 | ) 373 | pose[:, :3, 3] -= minbound 374 | 375 | # scale the coordinate within current fragment to [-1, 1] 376 | grid = (sample_pts - origin) / ( 377 | (np.array(tsdf_04.shape, dtype=np.float32) - 1) * 0.04 378 | ) * 2 - 1 379 | grid = grid[..., [2, 1, 0]] 380 | 381 | # GT TSDF of the current fragment. 
382 | tsdf_04_n = torch.nn.functional.grid_sample( 383 | torch.from_numpy(tsdf_04)[None, None], 384 | torch.from_numpy(grid[None]), 385 | align_corners=False, 386 | mode="nearest", 387 | )[0, 0].numpy() 388 | 389 | tsdf_04_b = torch.nn.functional.grid_sample( 390 | torch.from_numpy(tsdf_04)[None, None], 391 | torch.from_numpy(grid[None]), 392 | align_corners=False, 393 | mode="bilinear", 394 | )[0, 0].numpy() 395 | 396 | # occupied area use bilinear sample, empty area set to 1 397 | tsdf_04 = tsdf_04_b 398 | inds = np.abs(tsdf_04_n) > 0.999 399 | tsdf_04[inds] = tsdf_04_n[inds] 400 | oob_inds = np.any(np.abs(grid) >= 1, axis=-1) 401 | tsdf_04[oob_inds] = 1 402 | 403 | occ_04 = np.abs(tsdf_04) < 0.999 404 | seen_04 = tsdf_04 < 0.999 405 | 406 | # seems like a bug -- dilation should happen before cropping 407 | occ_08 = skimage.morphology.dilation(occ_04, selem=np.ones((3, 3, 3))) # voxel of size 0.08 meter 408 | not_occ_08 = seen_04 & ~occ_08 409 | occ_08 = occ_08[::2, ::2, ::2] 410 | not_occ_08 = not_occ_08[::2, ::2, ::2] 411 | seen_08 = occ_08 | not_occ_08 412 | 413 | occ_16 = skimage.morphology.dilation(occ_08, selem=np.ones((3, 3, 3))) 414 | not_occ_16 = seen_08 & ~occ_16 415 | occ_16 = occ_16[::2, ::2, ::2] 416 | not_occ_16 = not_occ_16[::2, ::2, ::2] 417 | seen_16 = occ_16 | not_occ_16 418 | 419 | rgb_imgs = load_rgb_imgs(rgb_imgfiles, imheight, imwidth, augment=self.augment) 420 | 421 | depth_imgs = np.empty((len(depth_imgfiles), imheight, imwidth), dtype=np.uint16) 422 | for i, f in enumerate(depth_imgfiles): 423 | depth_imgs[i] = imageio.imread(f) 424 | depth_imgs = depth_imgs / np.float32(1000) 425 | 426 | if self.SRfeat: 427 | SRfeat0, SRfeat1, SRfeat2 = load_SRfeats(scene_name, ['0'+x[-9:-4] for x in rgb_imgfiles]) 428 | if self.augment and flip: 429 | SRfeat0 = np.ascontiguousarray(np.flip(SRfeat0, axis=-1)) 430 | SRfeat1 = np.ascontiguousarray(np.flip(SRfeat1, axis=-1)) 431 | SRfeat2 = np.ascontiguousarray(np.flip(SRfeat2, axis=-1)) 432 | if self.SRCV: 433 | SRCV = load_SRCV(scene_name, ['0'+x[-9:-4] for x in rgb_imgfiles]) 434 | if self.augment and flip: 435 | SRCV = np.ascontiguousarray(np.flip(SRCV, axis=-1)) 436 | 437 | if self.augment and flip: 438 | # flip images 439 | depth_imgs = np.ascontiguousarray(np.flip(depth_imgs, axis=-1)) 440 | rgb_imgs = np.ascontiguousarray(np.flip(rgb_imgs, axis=-1)) 441 | intr[0, 0] *= -1 442 | 443 | inds_04 = np.argwhere( 444 | (tsdf_04 < 0.999) | np.all(tsdf_04 > 0.999, axis=-1, keepdims=True) # from Atlas, penalize the areas outside the room (entire column is unseen(1)) 445 | ) 446 | inds_08 = np.argwhere(seen_08 | np.all(~seen_08, axis=-1, keepdims=True)) 447 | inds_16 = np.argwhere(seen_16 | np.all(~seen_16, axis=-1, keepdims=True)) 448 | 449 | tsdf_04 = tsdf_04[inds_04[:, 0], inds_04[:, 1], inds_04[:, 2]] 450 | occ_08 = occ_08[inds_08[:, 0], inds_08[:, 1], inds_08[:, 2]].astype(np.float32) 451 | occ_16 = occ_16[inds_16[:, 0], inds_16[:, 1], inds_16[:, 2]].astype(np.float32) 452 | 453 | tsdf_04 = torchsparse.SparseTensor( 454 | torch.from_numpy(tsdf_04), torch.from_numpy(inds_04) 455 | ) 456 | occ_08 = torchsparse.SparseTensor( 457 | torch.from_numpy(occ_08), torch.from_numpy(inds_08) 458 | ) 459 | occ_16 = torchsparse.SparseTensor( 460 | torch.from_numpy(occ_16), torch.from_numpy(inds_16) 461 | ) 462 | 463 | cam_positions = pose[:self.n_imgs, :3, 3] 464 | 465 | # world to camera 466 | pose_w2c = np.linalg.inv(pose) 467 | 468 | # refers to the downsampling ratios at various levels of the CNN feature maps 469 | factors = np.array([1 / 
16, 1 / 8, 1 / 4]) 470 | proj_mats = get_proj_mats(intr, pose_w2c[:self.n_imgs], factors) 471 | 472 | # generate dense initial grid 473 | x = torch.arange(seen_16.shape[0], dtype=torch.int32) 474 | y = torch.arange(seen_16.shape[1], dtype=torch.int32) 475 | z = torch.arange(seen_16.shape[2], dtype=torch.int32) 476 | xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij') 477 | input_voxels_16 = torch.stack( 478 | (xx.flatten(), yy.flatten(), zz.flatten()), dim=-1 479 | ) 480 | input_voxels_16 = torchsparse.SparseTensor(torch.zeros(0), input_voxels_16) 481 | 482 | # the scene has been adjusted to origin 0 483 | origin = np.zeros(3, dtype=np.float32) 484 | 485 | scene = { 486 | "input_voxels_16": input_voxels_16, # dense, all zero voxel grid 487 | "rgb_imgs": rgb_imgs, # random selected n_imgs (20) that have enough pose difference and at least 1 intersect with current fragment 488 | "cam_positions": cam_positions, 489 | "proj_mats": proj_mats, 490 | "voxel_gt_fine": tsdf_04, # near surface area, inside objects area, and outside room area (all 1 column) 491 | "voxel_gt_medium": occ_08, 492 | "voxel_gt_coarse": occ_16, 493 | "scene_name": scene_name, 494 | "index": ind, # index of the scene 495 | "depth_imgs": depth_imgs, 496 | "origin": origin, 497 | } 498 | 499 | if self.load_extra: 500 | scene.update( 501 | { 502 | "intr_fullres": intr, 503 | "pose": pose_w2c, 504 | } 505 | ) 506 | 507 | if self.SRfeat: 508 | scene.update( 509 | { 510 | "SRfeat0": SRfeat0, 511 | "SRfeat1": SRfeat1, 512 | "SRfeat2": SRfeat2, 513 | } 514 | ) 515 | if self.SRCV: 516 | scene.update( 517 | { 518 | "SRCV": SRCV, 519 | } 520 | ) 521 | if self.cost_volume: 522 | k = np.eye(4, dtype=np.float32) 523 | k[:3, :3] = intr 524 | k[0] = k[0] * 0.125 525 | k[1] = k[1] * 0.125 526 | invK = np.linalg.inv(k) 527 | scene.update( 528 | { 529 | "cv_k": k, 530 | "cv_invK": invK, 531 | "pose": pose, 532 | "inv_pose": pose_w2c, 533 | "cv_invalid_mask": cv_invalid_mask, 534 | } 535 | ) 536 | return scene 537 | 538 | 539 | if __name__ == "__main__": 540 | 541 | import glob 542 | import yaml 543 | 544 | import matplotlib.pyplot as plt 545 | import open3d as o3d 546 | import skimage.measure 547 | 548 | import collate 549 | 550 | with open("config.yml", "r") as f: 551 | config = yaml.safe_load(f) 552 | 553 | with open(os.path.join(config["scannet_dir"], "scannetv2_train.txt"), "r") as f: 554 | train_split = f.read().split() 555 | 556 | with open(os.path.join(config["scannet_dir"], "scannetv2_val.txt"), "r") as f: 557 | val_split = f.read().split() 558 | 559 | info_files = sorted( 560 | glob.glob(os.path.join(config["scannet_dir"], "scans/*/info.json")) 561 | ) 562 | train_info_files = [ 563 | f for f in info_files if os.path.basename(os.path.dirname(f)) in train_split 564 | ] 565 | val_info_files = [ 566 | f for f in info_files if os.path.basename(os.path.dirname(f)) in val_split 567 | ] 568 | 569 | dset = Dataset( 570 | train_info_files, 571 | config["tsdf_dir"], 572 | 35, 573 | (48, 48, 32), 574 | augment=True, 575 | load_extra=True, 576 | ) 577 | 578 | loader = torch.utils.data.DataLoader( 579 | dset, batch_size=2, collate_fn=collate.sparse_collate_fn 580 | ) 581 | batch = next(iter(loader)) 582 | resolutions = { 583 | "coarse": 0.16, 584 | "medium": 0.08, 585 | "fine": 0.04, 586 | } 587 | 588 | # get voxel gt pcds 589 | batch_ind = 0 590 | voxel_gt_pcds = [] 591 | for resname, res in resolutions.items(): 592 | voxel_gt = batch[f"voxel_gt_{resname}"] 593 | batch_mask = voxel_gt.C[:, 3] == batch_ind 594 | coords = voxel_gt.C[batch_mask, 
:3] * res + batch["origin"][0] 595 | pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(coords.numpy())) 596 | 597 | vals = voxel_gt.F[batch_mask].float() 598 | vals = vals - vals.min() 599 | vals = vals / vals.max() 600 | pcd.colors = o3d.utility.Vector3dVector(plt.cm.jet(vals)[:, :3]) 601 | voxel_gt_pcds.append(pcd) 602 | 603 | # get depth pcd 604 | depth_imgs = batch["depth_imgs"][batch_ind] 605 | imheight = depth_imgs.shape[1] 606 | imwidth = depth_imgs.shape[2] 607 | u = np.arange(imwidth) 608 | v = np.arange(imheight) 609 | uu, vv = np.meshgrid(u, v) 610 | uv = np.c_[uu.flatten(), vv.flatten(), np.ones_like(uu.flatten())] 611 | k = batch["intr_fullres"][batch_ind] 612 | pix_vecs = uv @ np.linalg.inv(k.T) 613 | depth_pts = [] 614 | for i in range(len(depth_imgs)): 615 | pose = batch["pose"][batch_ind][i].numpy() 616 | depth = depth_imgs[i].flatten().numpy() 617 | valid = depth > 0 618 | xyz_cam = pix_vecs[valid] * depth[valid, None] 619 | xyz_world = ( 620 | np.c_[xyz_cam, np.ones((len(xyz_cam), 1))] @ np.linalg.inv(pose).T 621 | )[:, :3] 622 | depth_pts.append(xyz_world) 623 | depth_pts = np.concatenate(depth_pts, axis=0) 624 | depth_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(depth_pts)) 625 | 626 | # get gt high res surface 627 | batch_mask = batch["voxel_gt_fine"].C[:, 3] == batch_ind 628 | coords = batch["voxel_gt_fine"].C[batch_mask, :3].numpy() 629 | tsdf = batch["voxel_gt_fine"].F[batch_mask].numpy() 630 | tsdf = utils.to_vol(coords, tsdf) 631 | mesh = utils.to_mesh( 632 | -tsdf, 633 | level=0, 634 | mask=~np.isnan(tsdf), 635 | origin=batch["origin"][batch_ind].numpy(), 636 | voxel_size=0.04, 637 | ) 638 | mesh.compute_vertex_normals() 639 | 640 | axes = o3d.geometry.TriangleMesh.create_coordinate_frame() 641 | utils.visualize([*voxel_gt_pcds, depth_pcd, mesh, axes]) 642 | -------------------------------------------------------------------------------- /cvrecon/lightningmodel.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import open3d as o3d 5 | import pytorch_lightning as pl 6 | import torch 7 | 8 | from cvrecon import collate, data, utils, cvrecon 9 | 10 | 11 | class FineTuning(pl.callbacks.BaseFinetuning): 12 | def __init__(self, initial_epochs, use_cost_volume=False): 13 | super().__init__() 14 | self.initial_epochs = initial_epochs 15 | self.use_cost_volume = use_cost_volume 16 | 17 | def freeze_before_training(self, pl_module): 18 | modules = [ 19 | pl_module.cvrecon.cnn2d.conv0, 20 | pl_module.cvrecon.cnn2d.conv1, 21 | pl_module.cvrecon.cnn2d.conv2, 22 | pl_module.cvrecon.upsampler, 23 | ] + ([ 24 | pl_module.cvrecon.matching_encoder, 25 | pl_module.cvrecon.cost_volume.mlp.net[:4], 26 | ]if self.use_cost_volume else []) 27 | for mod in modules: 28 | self.freeze(mod, train_bn=False) 29 | 30 | def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx): 31 | if current_epoch >= self.initial_epochs: 32 | self.unfreeze_and_add_param_group( 33 | modules=[ 34 | pl_module.cvrecon.cnn2d.conv0, 35 | pl_module.cvrecon.cnn2d.conv1, 36 | pl_module.cvrecon.cnn2d.conv2, 37 | ] + ([pl_module.cvrecon.matching_encoder, 38 | pl_module.cvrecon.cost_volume.mlp.net[:4], 39 | ]if self.use_cost_volume else []), 40 | optimizer=optimizer, 41 | train_bn=False, 42 | lr=pl_module.config["finetune_lr"], 43 | ) 44 | pl_module.cvrecon.use_proj_occ = True 45 | for group in pl_module.optimizers().param_groups: 46 | group["lr"] = pl_module.config["finetune_lr"] 47 | 48 | 49 | class 
LightningModel(pl.LightningModule): 50 | def __init__(self, config): 51 | super().__init__() 52 | self.cvrecon = cvrecon.cvrecon( 53 | config["attn_heads"], config["attn_layers"], config["use_proj_occ"], config["SRfeat"], 54 | config["SR_vi_ebd"], config["SRCV"], config["cost_volume"], config["cv_dim"], config["cv_overall"], config["depth_head"], 55 | ) 56 | self.config = config 57 | 58 | def configure_optimizers(self): 59 | return torch.optim.Adam( 60 | [param for param in self.parameters() if param.requires_grad], 61 | lr=self.config["initial_lr"], 62 | ) 63 | 64 | # def on_train_epoch_start(self): 65 | # self.epoch_train_logs = [] 66 | 67 | def step(self, batch, batch_idx): 68 | voxel_coords_16 = batch["input_voxels_16"].C 69 | voxel_outputs, proj_occ_logits, bp_data, depth_out = self.cvrecon(batch, voxel_coords_16) 70 | voxel_gt = { 71 | "coarse": batch["voxel_gt_coarse"], 72 | "medium": batch["voxel_gt_medium"], 73 | "fine": batch["voxel_gt_fine"], 74 | } 75 | loss, logs = self.cvrecon.losses( 76 | voxel_outputs, voxel_gt, proj_occ_logits, bp_data, batch["depth_imgs"], depth_out 77 | ) 78 | logs["loss"] = loss.detach() 79 | return loss, logs, voxel_outputs 80 | 81 | def training_step(self, batch, batch_idx): 82 | n_warmup_steps = 2_000 83 | if self.global_step < n_warmup_steps: 84 | target_lr = self.config["initial_lr"] 85 | lr = 1e-10 + self.global_step / n_warmup_steps * target_lr 86 | for group in self.optimizers().param_groups: 87 | group["lr"] = lr 88 | 89 | loss, logs, _ = self.step(batch, batch_idx) 90 | # self.epoch_train_logs.append(logs) 91 | for lossname, lossval in logs.items(): 92 | self.log('train/'+lossname, lossval, on_step=True, on_epoch=True, sync_dist=True, reduce_fx='mean', rank_zero_only=True) 93 | return loss 94 | 95 | # def on_validation_epoch_start(self): 96 | # self.epoch_val_logs = [] 97 | 98 | def validation_step(self, batch, batch_idx): 99 | loss, logs, voxel_outputs = self.step(batch, batch_idx) 100 | # self.epoch_val_logs.append(logs) 101 | for lossname, lossval in logs.items(): 102 | self.log('val/'+lossname, lossval, on_step=False, on_epoch=True, sync_dist=True, reduce_fx='mean', rank_zero_only=True) 103 | 104 | def train_dataloader(self): 105 | return self.dataloader("train", augment=True) 106 | 107 | def val_dataloader(self): 108 | return self.dataloader("test") 109 | 110 | def dataloader(self, split, augment=False): 111 | nworkers = self.config["nworkers"] 112 | if split in ["val", "test"]: 113 | batch_size = 1 114 | nworkers //= 2 115 | elif self.current_epoch < self.config["initial_epochs"]: 116 | batch_size = self.config["initial_batch_size"] 117 | else: 118 | batch_size = self.config["finetune_batch_size"] 119 | 120 | info_files = utils.load_info_files(self.config["scannet_dir"], split) 121 | dset = data.Dataset( 122 | info_files, 123 | self.config["tsdf_dir"], 124 | self.config[f"n_imgs_{split}"], 125 | self.config[f"crop_size_{split}"], 126 | augment=augment, 127 | split=split, 128 | SRfeat=self.config["SRfeat"], 129 | SRCV=self.config["SRCV"], 130 | cost_volume=self.config["cost_volume"], 131 | ) 132 | return torch.utils.data.DataLoader( 133 | dset, 134 | batch_size=batch_size, 135 | num_workers=nworkers, 136 | collate_fn=collate.sparse_collate_fn, 137 | drop_last=True, 138 | #persistent_workers=True, 139 | ) 140 | 141 | 142 | def write_mesh(outfile, logits_04): 143 | batch_mask = logits_04.C[:, 3] == 0 144 | inds = logits_04.C[batch_mask, :3].cpu().numpy() 145 | tsdf_logits = logits_04.F[batch_mask, 0].cpu().numpy() 146 | tsdf = 1.05 * 
np.tanh(tsdf_logits) 147 | tsdf_vol = utils.to_vol(inds, tsdf) 148 | 149 | mesh = utils.to_mesh(tsdf_vol, voxel_size=0.04, level=0, mask=~np.isnan(tsdf_vol)) 150 | o3d.io.write_triangle_mesh(outfile, mesh) 151 | -------------------------------------------------------------------------------- /cvrecon/mv_fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from cvrecon import transformer 4 | 5 | 6 | class MVFusionMean(torch.nn.Module): 7 | def forward(self, features, valid_mask): 8 | return mv_fusion_mean(features, valid_mask) 9 | 10 | 11 | class MVFusionTransformer(torch.nn.Module): 12 | def __init__(self, input_depth, n_layers, n_attn_heads, cv_cha=0): 13 | super().__init__() 14 | self.transformer = transformer.Transformer( 15 | input_depth, 16 | input_depth * 2, 17 | num_layers=n_layers, 18 | num_heads=n_attn_heads, 19 | ) 20 | self.depth_mlp = torch.nn.Linear( 1 + cv_cha + 56, input_depth, bias=True) 21 | self.proj_tsdf_mlp = torch.nn.Linear(input_depth, 1, bias=True) 22 | 23 | for mlp in [self.depth_mlp, self.proj_tsdf_mlp]: 24 | torch.nn.init.kaiming_normal_(mlp.weight) 25 | torch.nn.init.zeros_(mlp.bias) 26 | 27 | def forward(self, features, bp_depth, bp_mask, use_proj_occ): 28 | ''' 29 | features: [n_imgs, in_channels, n_voxels] 30 | bp_depth: [n_imgs, n_voxels] 31 | bp_mask: [n_imgs, n_voxels] 32 | ''' 33 | device = features.device 34 | 35 | # attn_mask is False where attention is allowed. 36 | # set diagonal elements False to avoid nan 37 | attn_mask = bp_mask.transpose(0, 1) # attn_mask [n_voxels, n_imgs] 38 | attn_mask = ~attn_mask[:, None].repeat(1, attn_mask.shape[1], 1).contiguous() # [n_voxels, n_imgs, n_imgs] 39 | torch.diagonal(attn_mask, dim1=1, dim2=2)[:] = False 40 | 41 | im_z_norm = (bp_depth - 1.85) / 0.85 42 | features = torch.cat((features, im_z_norm[:, None]), dim=1) # [n_imgs, in_channels+1, n_voxels] 43 | features = self.depth_mlp(features.transpose(1, 2)) # [n_imgs, n_voxels, in_channels] 44 | 45 | features = self.transformer(features, attn_mask) # after self-attention still [n_imgs, n_voxels, in_channels] 46 | 47 | batchsize, nvoxels, _ = features.shape 48 | proj_occ_logits = self.proj_tsdf_mlp( 49 | features.reshape(batchsize * nvoxels, -1) 50 | ).reshape(batchsize, nvoxels) # [n_imgs, n_voxels] 51 | 52 | if use_proj_occ: 53 | weights = proj_occ_logits.masked_fill(~bp_mask, -9e3) 54 | weights = torch.cat( 55 | ( 56 | weights, 57 | torch.zeros( 58 | (1, weights.shape[1]), 59 | device=device, 60 | dtype=weights.dtype, 61 | ), 62 | ), 63 | dim=0, 64 | ) 65 | features = torch.cat( 66 | ( 67 | features, 68 | torch.zeros( 69 | (1, features.shape[1], features.shape[2]), 70 | device=device, 71 | dtype=features.dtype, 72 | ), 73 | ), 74 | dim=0, 75 | ) 76 | weights = torch.softmax(weights, dim=0) 77 | pooled_features = torch.sum(features * weights[..., None], dim=0) 78 | else: 79 | pooled_features = mv_fusion_mean(features, bp_mask) 80 | 81 | return pooled_features, proj_occ_logits 82 | 83 | 84 | def mv_fusion_mean(features, valid_mask): 85 | ''' 86 | features: [n_imgs, n_voxels, n_channels] 87 | valid_mask: [n_imgs, n_voxels] 88 | 89 | return: 90 | pooled_features: each voxel's feature is the average of all seen pixels' feature 91 | ''' 92 | weights = torch.sum(valid_mask, dim=0) 93 | weights[weights == 0] = 1 94 | pooled_features = ( 95 | torch.sum(features * valid_mask[..., None], dim=0) / weights[:, None] 96 | ) 97 | return pooled_features 98 | 
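A minimal, self-contained sketch of the masked averaging in `mv_fusion_mean` above, using toy values (2 views, 3 voxels, 1 channel) purely for illustration:

```
import torch

features = torch.tensor([[[1.0], [2.0], [3.0]],   # view 0
                         [[5.0], [6.0], [7.0]]])  # view 1
valid_mask = torch.tensor([[1.0, 1.0, 0.0],       # view 0 sees voxels 0 and 1
                           [1.0, 0.0, 0.0]])      # view 1 sees voxel 0 only

weights = torch.sum(valid_mask, dim=0)  # observations per voxel: [2., 1., 0.]
weights[weights == 0] = 1               # unseen voxels: avoid division by zero
pooled = torch.sum(features * valid_mask[..., None], dim=0) / weights[:, None]
print(pooled.squeeze(-1))               # tensor([3., 2., 0.]); unseen voxel stays 0
```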
-------------------------------------------------------------------------------- /cvrecon/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class MlpBlock(torch.nn.Module): 5 | """Transformer Feed-Forward Block""" 6 | 7 | def __init__(self, in_dim, mlp_dim, out_dim, dropout_rate=0.0): 8 | super().__init__() 9 | 10 | # init layers 11 | self.fc1 = torch.nn.Linear(in_dim, mlp_dim) 12 | self.fc2 = torch.nn.Linear(mlp_dim, out_dim) 13 | self.act = torch.nn.ReLU(True) 14 | self.dropout1 = torch.nn.Dropout(dropout_rate) 15 | self.dropout2 = torch.nn.Dropout(dropout_rate) 16 | 17 | torch.nn.init.kaiming_normal_(self.fc1.weight) 18 | torch.nn.init.kaiming_normal_(self.fc2.weight) 19 | torch.nn.init.zeros_(self.fc1.bias) 20 | torch.nn.init.zeros_(self.fc2.bias) 21 | 22 | def forward(self, x): 23 | 24 | out = self.fc1(x) 25 | out = self.act(out) 26 | if self.dropout1: 27 | out = self.dropout1(out) 28 | 29 | out = self.fc2(out) 30 | out = self.dropout2(out) 31 | return out 32 | 33 | 34 | class EncoderBlock(torch.nn.Module): 35 | def __init__( 36 | self, in_dim, num_heads, mlp_dim, dropout_rate=0.0, attn_dropout_rate=0.0 37 | ): 38 | super().__init__() 39 | 40 | self.norm1 = torch.nn.LayerNorm(in_dim) 41 | self.attn = torch.nn.MultiheadAttention(in_dim, num_heads) 42 | if dropout_rate > 0: 43 | self.dropout = torch.nn.Dropout(dropout_rate) 44 | else: 45 | self.dropout = None 46 | self.norm2 = torch.nn.LayerNorm(in_dim) 47 | self.mlp = MlpBlock(in_dim, mlp_dim, in_dim, dropout_rate) 48 | 49 | def forward(self, x, mask=None): 50 | residual = x 51 | x = self.norm1(x) 52 | x, attn_weights = self.attn(x, x, x, attn_mask=mask, need_weights=False) 53 | if self.dropout is not None: 54 | x = self.dropout(x) 55 | x += residual 56 | residual = x 57 | 58 | x = self.norm2(x) 59 | x = self.mlp(x) 60 | x += residual 61 | return x 62 | 63 | 64 | class Transformer(torch.nn.Module): 65 | def __init__( 66 | self, 67 | emb_dim, 68 | mlp_dim, 69 | num_layers=1, 70 | num_heads=1, 71 | dropout_rate=0.0, 72 | attn_dropout_rate=0.0, 73 | ): 74 | super().__init__() 75 | 76 | in_dim = emb_dim 77 | self.encoder_layers = torch.nn.ModuleList() 78 | for i in range(num_layers): 79 | layer = EncoderBlock( 80 | in_dim, num_heads, mlp_dim, dropout_rate, attn_dropout_rate 81 | ) 82 | self.encoder_layers.append(layer) 83 | 84 | self.num_heads = num_heads 85 | 86 | def forward(self, x, mask=None): 87 | ''' 88 | x: [n_imgs, n_voxels, in_channels] 89 | torch.nn.MultiheadAttention: [seq_len, batchsize, dim] -> [seq_len, batchsize, dim] 90 | ''' 91 | if self.num_heads > 1: 92 | b, s, t = mask.shape 93 | mask_n = mask.repeat(1, self.num_heads, 1).reshape(b * self.num_heads, s, t) 94 | else: 95 | mask_n = mask 96 | 97 | for layer in self.encoder_layers: 98 | x = layer(x, mask=mask_n) 99 | 100 | return x 101 | -------------------------------------------------------------------------------- /cvrecon/tsdf_fusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from numba import njit, prange 4 | from skimage import measure 5 | import torch 6 | 7 | 8 | class TSDFVolume: 9 | """Volumetric TSDF Fusion of RGB-D Images.""" 10 | 11 | def __init__(self, vol_bnds, voxel_size, use_gpu=True, margin=5): 12 | """Constructor. 13 | Args: 14 | vol_bnds (ndarray): An ndarray of shape (3, 2). Specifies the 15 | xyz bounds (min/max) in meters. 16 | voxel_size (float): The volume discretization in meters. 
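use_gpu (bool): If True, run the integration with the PyCUDA kernel; if False, use the vectorized CPU implementation.
margin (int): Truncation margin in voxels; the SDF truncation distance is margin * voxel_size.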
17 | """ 18 | # try: 19 | import pycuda.driver as cuda 20 | import pycuda.autoinit 21 | from pycuda.compiler import SourceModule 22 | 23 | FUSION_GPU_MODE = 1 24 | self.cuda = cuda 25 | # except Exception as err: 26 | # print('Warning: {}'.format(err)) 27 | # print('Failed to import PyCUDA. Running fusion in CPU mode.') 28 | # FUSION_GPU_MODE = 0 29 | 30 | vol_bnds = np.asarray(vol_bnds) 31 | assert vol_bnds.shape == (3, 2), "[!] `vol_bnds` should be of shape (3, 2)." 32 | 33 | # Define voxel volume parameters 34 | self._vol_bnds = vol_bnds 35 | self._voxel_size = float(voxel_size) 36 | self._trunc_margin = margin * self._voxel_size # truncation on SDF 37 | self._color_const = 256 * 256 38 | 39 | # Adjust volume bounds and ensure C-order contiguous 40 | self._vol_dim = ( 41 | np.round((self._vol_bnds[:, 1] - self._vol_bnds[:, 0]) / self._voxel_size) 42 | .copy(order="C") 43 | .astype(int) 44 | ) 45 | self._vol_bnds[:, 1] = self._vol_bnds[:, 0] + self._vol_dim * self._voxel_size 46 | self._vol_origin = self._vol_bnds[:, 0].copy(order="C").astype(np.float32) 47 | 48 | # Initialize pointers to voxel volume in CPU memory 49 | self._tsdf_vol_cpu = np.ones(self._vol_dim).astype(np.float32) 50 | # for computing the cumulative moving average of observations per voxel 51 | self._weight_vol_cpu = np.zeros(self._vol_dim).astype(np.float32) 52 | self._color_vol_cpu = np.zeros(self._vol_dim).astype(np.float32) 53 | 54 | self.gpu_mode = use_gpu and FUSION_GPU_MODE 55 | 56 | # Copy voxel volumes to GPU 57 | if self.gpu_mode: 58 | self._tsdf_vol_gpu = cuda.mem_alloc(self._tsdf_vol_cpu.nbytes) 59 | self.cuda.memcpy_htod(self._tsdf_vol_gpu, self._tsdf_vol_cpu) 60 | self._weight_vol_gpu = cuda.mem_alloc(self._weight_vol_cpu.nbytes) 61 | self.cuda.memcpy_htod(self._weight_vol_gpu, self._weight_vol_cpu) 62 | self._color_vol_gpu = cuda.mem_alloc(self._color_vol_cpu.nbytes) 63 | self.cuda.memcpy_htod(self._color_vol_gpu, self._color_vol_cpu) 64 | 65 | # Cuda kernel function (C++) 66 | self._cuda_src_mod = SourceModule( 67 | """ 68 | __global__ void integrate(float * tsdf_vol, 69 | float * weight_vol, 70 | float * color_vol, 71 | float * vol_dim, 72 | float * vol_origin, 73 | float * cam_intr, 74 | float * cam_pose, 75 | float * other_params, 76 | float * color_im, 77 | float * depth_im) { 78 | // Get voxel index 79 | int gpu_loop_idx = (int) other_params[0]; 80 | int max_threads_per_block = blockDim.x; 81 | int block_idx = blockIdx.z*gridDim.y*gridDim.x+blockIdx.y*gridDim.x+blockIdx.x; 82 | int voxel_idx = gpu_loop_idx*gridDim.x*gridDim.y*gridDim.z*max_threads_per_block+block_idx*max_threads_per_block+threadIdx.x; 83 | int vol_dim_x = (int) vol_dim[0]; 84 | int vol_dim_y = (int) vol_dim[1]; 85 | int vol_dim_z = (int) vol_dim[2]; 86 | if (voxel_idx > vol_dim_x*vol_dim_y*vol_dim_z) 87 | return; 88 | // Get voxel grid coordinates (note: be careful when casting) 89 | float voxel_x = floorf(((float)voxel_idx)/((float)(vol_dim_y*vol_dim_z))); 90 | float voxel_y = floorf(((float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z))/((float)vol_dim_z)); 91 | float voxel_z = (float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z-((int)voxel_y)*vol_dim_z); 92 | // Voxel grid coordinates to world coordinates 93 | float voxel_size = other_params[1]; 94 | float pt_x = vol_origin[0]+voxel_x*voxel_size; 95 | float pt_y = vol_origin[1]+voxel_y*voxel_size; 96 | float pt_z = vol_origin[2]+voxel_z*voxel_size; 97 | // World coordinates to camera coordinates 98 | float tmp_pt_x = pt_x-cam_pose[0*4+3]; 99 | float tmp_pt_y = 
pt_y-cam_pose[1*4+3]; 100 | float tmp_pt_z = pt_z-cam_pose[2*4+3]; 101 | float cam_pt_x = cam_pose[0*4+0]*tmp_pt_x+cam_pose[1*4+0]*tmp_pt_y+cam_pose[2*4+0]*tmp_pt_z; 102 | float cam_pt_y = cam_pose[0*4+1]*tmp_pt_x+cam_pose[1*4+1]*tmp_pt_y+cam_pose[2*4+1]*tmp_pt_z; 103 | float cam_pt_z = cam_pose[0*4+2]*tmp_pt_x+cam_pose[1*4+2]*tmp_pt_y+cam_pose[2*4+2]*tmp_pt_z; 104 | // Camera coordinates to image pixels 105 | int pixel_x = (int) roundf(cam_intr[0*3+0]*(cam_pt_x/cam_pt_z)+cam_intr[0*3+2]); 106 | int pixel_y = (int) roundf(cam_intr[1*3+1]*(cam_pt_y/cam_pt_z)+cam_intr[1*3+2]); 107 | // Skip if outside view frustum 108 | int im_h = (int) other_params[2]; 109 | int im_w = (int) other_params[3]; 110 | if (pixel_x < 0 || pixel_x >= im_w || pixel_y < 0 || pixel_y >= im_h || cam_pt_z<0) 111 | return; 112 | // Skip invalid depth 113 | float depth_value = depth_im[pixel_y*im_w+pixel_x]; 114 | if (depth_value == 0) 115 | return; 116 | // Integrate TSDF 117 | float trunc_margin = other_params[4]; 118 | float depth_diff = depth_value-cam_pt_z; 119 | if (depth_diff < -trunc_margin) 120 | return; 121 | float dist = fmin(1.0f,depth_diff/trunc_margin); 122 | float w_old = weight_vol[voxel_idx]; 123 | float obs_weight = other_params[5]; 124 | float w_new = w_old + obs_weight; 125 | weight_vol[voxel_idx] = w_new; 126 | tsdf_vol[voxel_idx] = (tsdf_vol[voxel_idx]*w_old+obs_weight*dist)/w_new; 127 | 128 | // Integrate color 129 | return; 130 | float old_color = color_vol[voxel_idx]; 131 | float old_b = floorf(old_color/(256*256)); 132 | float old_g = floorf((old_color-old_b*256*256)/256); 133 | float old_r = old_color-old_b*256*256-old_g*256; 134 | float new_color = color_im[pixel_y*im_w+pixel_x]; 135 | float new_b = floorf(new_color/(256*256)); 136 | float new_g = floorf((new_color-new_b*256*256)/256); 137 | float new_r = new_color-new_b*256*256-new_g*256; 138 | new_b = fmin(roundf((old_b*w_old+obs_weight*new_b)/w_new),255.0f); 139 | new_g = fmin(roundf((old_g*w_old+obs_weight*new_g)/w_new),255.0f); 140 | new_r = fmin(roundf((old_r*w_old+obs_weight*new_r)/w_new),255.0f); 141 | color_vol[voxel_idx] = new_b*256*256+new_g*256+new_r; 142 | }""" 143 | ) 144 | 145 | self._cuda_integrate = self._cuda_src_mod.get_function("integrate") 146 | 147 | # Determine block/grid size on GPU 148 | gpu_dev = cuda.Device(0) 149 | self._max_gpu_threads_per_block = gpu_dev.MAX_THREADS_PER_BLOCK 150 | n_blocks = int( 151 | np.ceil( 152 | float(np.prod(self._vol_dim)) 153 | / float(self._max_gpu_threads_per_block) 154 | ) 155 | ) 156 | grid_dim_x = min(gpu_dev.MAX_GRID_DIM_X, int(np.floor(np.cbrt(n_blocks)))) 157 | grid_dim_y = min( 158 | gpu_dev.MAX_GRID_DIM_Y, int(np.floor(np.sqrt(n_blocks / grid_dim_x))) 159 | ) 160 | grid_dim_z = min( 161 | gpu_dev.MAX_GRID_DIM_Z, 162 | int(np.ceil(float(n_blocks) / float(grid_dim_x * grid_dim_y))), 163 | ) 164 | self._max_gpu_grid_dim = np.array( 165 | [grid_dim_x, grid_dim_y, grid_dim_z] 166 | ).astype(int) 167 | self._n_gpu_loops = int( 168 | np.ceil( 169 | float(np.prod(self._vol_dim)) 170 | / float( 171 | np.prod(self._max_gpu_grid_dim) 172 | * self._max_gpu_threads_per_block 173 | ) 174 | ) 175 | ) 176 | 177 | else: 178 | # Get voxel grid coordinates 179 | xv, yv, zv = np.meshgrid( 180 | range(self._vol_dim[0]), 181 | range(self._vol_dim[1]), 182 | range(self._vol_dim[2]), 183 | ) 184 | self.vox_coords = ( 185 | np.concatenate( 186 | [xv.reshape(1, -1), yv.reshape(1, -1), zv.reshape(1, -1)], axis=0 187 | ) 188 | .astype(int) 189 | .T 190 | ) 191 | 192 | @staticmethod 193 | 
@njit(parallel=True) 194 | def vox2world(vol_origin, vox_coords, vox_size): 195 | """Convert voxel grid coordinates to world coordinates.""" 196 | vol_origin = vol_origin.astype(np.float32) 197 | vox_coords = vox_coords.astype(np.float32) 198 | cam_pts = np.empty_like(vox_coords, dtype=np.float32) 199 | for i in prange(vox_coords.shape[0]): 200 | for j in range(3): 201 | cam_pts[i, j] = vol_origin[j] + (vox_size * vox_coords[i, j]) 202 | return cam_pts 203 | 204 | @staticmethod 205 | @njit(parallel=True) 206 | def cam2pix(cam_pts, intr): 207 | """Convert camera coordinates to pixel coordinates.""" 208 | intr = intr.astype(np.float32) 209 | fx, fy = intr[0, 0], intr[1, 1] 210 | cx, cy = intr[0, 2], intr[1, 2] 211 | pix = np.empty((cam_pts.shape[0], 2), dtype=np.int64) 212 | for i in prange(cam_pts.shape[0]): 213 | pix[i, 0] = int(np.round((cam_pts[i, 0] * fx / cam_pts[i, 2]) + cx)) 214 | pix[i, 1] = int(np.round((cam_pts[i, 1] * fy / cam_pts[i, 2]) + cy)) 215 | return pix 216 | 217 | @staticmethod 218 | @njit(parallel=True) 219 | def integrate_tsdf(tsdf_vol, dist, w_old, obs_weight): 220 | """Integrate the TSDF volume.""" 221 | tsdf_vol_int = np.empty_like(tsdf_vol, dtype=np.float32) 222 | w_new = np.empty_like(w_old, dtype=np.float32) 223 | for i in prange(len(tsdf_vol)): 224 | w_new[i] = w_old[i] + obs_weight 225 | tsdf_vol_int[i] = (w_old[i] * tsdf_vol[i] + obs_weight * dist[i]) / w_new[i] 226 | return tsdf_vol_int, w_new 227 | 228 | def integrate(self, color_im, depth_im, cam_intr, cam_pose, obs_weight=1.0): 229 | """Integrate an RGB-D frame into the TSDF volume. 230 | Args: 231 | color_im (ndarray): An RGB image of shape (H, W, 3). 232 | depth_im (ndarray): A depth image of shape (H, W). 233 | cam_intr (ndarray): The camera intrinsics matrix of shape (3, 3). 234 | cam_pose (ndarray): The camera pose (i.e. extrinsics) of shape (4, 4). 235 | obs_weight (float): The weight to assign for the current observation. 
A higher 236 | value 237 | """ 238 | im_h, im_w = depth_im.shape 239 | 240 | if color_im is not None: 241 | # Fold RGB color image into a single channel image 242 | color_im = color_im.astype(np.float32) 243 | color_im = np.floor( 244 | color_im[..., 2] * self._color_const 245 | + color_im[..., 1] * 256 246 | + color_im[..., 0] 247 | ) 248 | color_im = color_im.reshape(-1).astype(np.float32) 249 | else: 250 | color_im = np.array(0) 251 | 252 | if self.gpu_mode: # GPU mode: integrate voxel volume (calls CUDA kernel) 253 | for gpu_loop_idx in range(self._n_gpu_loops): 254 | self._cuda_integrate( 255 | self._tsdf_vol_gpu, 256 | self._weight_vol_gpu, 257 | self._color_vol_gpu, 258 | self.cuda.InOut(self._vol_dim.astype(np.float32)), 259 | self.cuda.InOut(self._vol_origin.astype(np.float32)), 260 | self.cuda.InOut(cam_intr.reshape(-1).astype(np.float32)), 261 | self.cuda.InOut(cam_pose.reshape(-1).astype(np.float32)), 262 | self.cuda.InOut( 263 | np.asarray( 264 | [ 265 | gpu_loop_idx, 266 | self._voxel_size, 267 | im_h, 268 | im_w, 269 | self._trunc_margin, 270 | obs_weight, 271 | ], 272 | np.float32, 273 | ) 274 | ), 275 | self.cuda.InOut(color_im), 276 | self.cuda.InOut(depth_im.reshape(-1).astype(np.float32)), 277 | block=(self._max_gpu_threads_per_block, 1, 1), 278 | grid=( 279 | int(self._max_gpu_grid_dim[0]), 280 | int(self._max_gpu_grid_dim[1]), 281 | int(self._max_gpu_grid_dim[2]), 282 | ), 283 | ) 284 | else: # CPU mode: integrate voxel volume (vectorized implementation) 285 | # Convert voxel grid coordinates to pixel coordinates 286 | cam_pts = self.vox2world( 287 | self._vol_origin, self.vox_coords, self._voxel_size 288 | ) 289 | cam_pts = rigid_transform(cam_pts, np.linalg.inv(cam_pose)) 290 | pix_z = cam_pts[:, 2] 291 | pix = self.cam2pix(cam_pts, cam_intr) 292 | pix_x, pix_y = pix[:, 0], pix[:, 1] 293 | 294 | # Eliminate pixels outside view frustum 295 | valid_pix = np.logical_and( 296 | pix_x >= 0, 297 | np.logical_and( 298 | pix_x < im_w, 299 | np.logical_and(pix_y >= 0, np.logical_and(pix_y < im_h, pix_z > 0)), 300 | ), 301 | ) 302 | depth_val = np.zeros(pix_x.shape) 303 | depth_val[valid_pix] = depth_im[pix_y[valid_pix], pix_x[valid_pix]] 304 | 305 | # Integrate TSDF 306 | depth_diff = depth_val - pix_z 307 | valid_pts = np.logical_and(depth_val > 0, depth_diff >= -self._trunc_margin) 308 | dist = np.minimum(1, depth_diff / self._trunc_margin) 309 | valid_vox_x = self.vox_coords[valid_pts, 0] 310 | valid_vox_y = self.vox_coords[valid_pts, 1] 311 | valid_vox_z = self.vox_coords[valid_pts, 2] 312 | w_old = self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] 313 | tsdf_vals = self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] 314 | valid_dist = dist[valid_pts] 315 | tsdf_vol_new, w_new = self.integrate_tsdf( 316 | tsdf_vals, valid_dist, w_old, obs_weight 317 | ) 318 | self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = w_new 319 | self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = tsdf_vol_new 320 | 321 | # Integrate color 322 | old_color = self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] 323 | old_b = np.floor(old_color / self._color_const) 324 | old_g = np.floor((old_color - old_b * self._color_const) / 256) 325 | old_r = old_color - old_b * self._color_const - old_g * 256 326 | new_color = color_im[pix_y[valid_pts], pix_x[valid_pts]] 327 | new_b = np.floor(new_color / self._color_const) 328 | new_g = np.floor((new_color - new_b * self._color_const) / 256) 329 | new_r = new_color - new_b * self._color_const - new_g * 256 330 | new_b 
= np.minimum( 331 | 255.0, np.round((w_old * old_b + obs_weight * new_b) / w_new) 332 | ) 333 | new_g = np.minimum( 334 | 255.0, np.round((w_old * old_g + obs_weight * new_g) / w_new) 335 | ) 336 | new_r = np.minimum( 337 | 255.0, np.round((w_old * old_r + obs_weight * new_r) / w_new) 338 | ) 339 | self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = ( 340 | new_b * self._color_const + new_g * 256 + new_r 341 | ) 342 | 343 | def get_volume(self): 344 | if self.gpu_mode: 345 | self.cuda.memcpy_dtoh(self._tsdf_vol_cpu, self._tsdf_vol_gpu) 346 | self.cuda.memcpy_dtoh(self._color_vol_cpu, self._color_vol_gpu) 347 | self.cuda.memcpy_dtoh(self._weight_vol_cpu, self._weight_vol_gpu) 348 | return self._tsdf_vol_cpu, self._color_vol_cpu, self._weight_vol_cpu 349 | 350 | def get_point_cloud(self): 351 | """Extract a point cloud from the voxel volume.""" 352 | tsdf_vol, color_vol, weight_vol = self.get_volume() 353 | 354 | # Marching cubes 355 | verts = measure.marching_cubes_lewiner(tsdf_vol, level=0)[0] 356 | verts_ind = np.round(verts).astype(int) 357 | verts = verts * self._voxel_size + self._vol_origin 358 | 359 | # Get vertex colors 360 | rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]] 361 | colors_b = np.floor(rgb_vals / self._color_const) 362 | colors_g = np.floor((rgb_vals - colors_b * self._color_const) / 256) 363 | colors_r = rgb_vals - colors_b * self._color_const - colors_g * 256 364 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T 365 | colors = colors.astype(np.uint8) 366 | 367 | pc = np.hstack([verts, colors]) 368 | return pc 369 | 370 | def get_mesh(self): 371 | """Compute a mesh from the voxel volume using marching cubes.""" 372 | tsdf_vol, color_vol, weight_vol = self.get_volume() 373 | 374 | verts, faces, norms, vals = measure.marching_cubes_lewiner(tsdf_vol, level=0) 375 | verts_ind = np.round(verts).astype(int) 376 | verts = ( 377 | verts * self._voxel_size + self._vol_origin 378 | ) # voxel grid coordinates to world coordinates 379 | 380 | # Get vertex colors 381 | rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]] 382 | colors_b = np.floor(rgb_vals / self._color_const) 383 | colors_g = np.floor((rgb_vals - colors_b * self._color_const) / 256) 384 | colors_r = rgb_vals - colors_b * self._color_const - colors_g * 256 385 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T 386 | colors = colors.astype(np.uint8) 387 | return verts, faces, norms, colors 388 | 389 | 390 | def rigid_transform(xyz, transform): 391 | """Applies a rigid transform to an (N, 3) pointcloud.""" 392 | xyz_h = np.hstack([xyz, np.ones((len(xyz), 1), dtype=np.float32)]) 393 | xyz_t_h = np.dot(transform, xyz_h.T).T 394 | return xyz_t_h[:, :3] 395 | 396 | 397 | def get_view_frustum(depth_im, cam_intr, cam_pose): 398 | """Get corners of 3D camera view frustum of depth image""" 399 | im_h = depth_im.shape[0] 400 | im_w = depth_im.shape[1] 401 | max_depth = np.max(depth_im) 402 | view_frust_pts = np.array( 403 | [ 404 | (np.array([0, 0, 0, im_w, im_w]) - cam_intr[0, 2]) 405 | * np.array([0, max_depth, max_depth, max_depth, max_depth]) 406 | / cam_intr[0, 0], 407 | (np.array([0, 0, im_h, 0, im_h]) - cam_intr[1, 2]) 408 | * np.array([0, max_depth, max_depth, max_depth, max_depth]) 409 | / cam_intr[1, 1], 410 | np.array([0, max_depth, max_depth, max_depth, max_depth]), 411 | ] 412 | ) 413 | view_frust_pts = rigid_transform(view_frust_pts.T, cam_pose).T 414 | return view_frust_pts 415 | 416 | 417 | def meshwrite(filename, verts, faces, 
norms, colors): 418 | """Save a 3D mesh to a polygon .ply file.""" 419 | # Write header 420 | ply_file = open(filename, "w") 421 | ply_file.write("ply\n") 422 | ply_file.write("format ascii 1.0\n") 423 | ply_file.write("element vertex %d\n" % (verts.shape[0])) 424 | ply_file.write("property float x\n") 425 | ply_file.write("property float y\n") 426 | ply_file.write("property float z\n") 427 | ply_file.write("property float nx\n") 428 | ply_file.write("property float ny\n") 429 | ply_file.write("property float nz\n") 430 | ply_file.write("property uchar red\n") 431 | ply_file.write("property uchar green\n") 432 | ply_file.write("property uchar blue\n") 433 | ply_file.write("element face %d\n" % (faces.shape[0])) 434 | ply_file.write("property list uchar int vertex_index\n") 435 | ply_file.write("end_header\n") 436 | 437 | # Write vertex list 438 | for i in range(verts.shape[0]): 439 | ply_file.write( 440 | "%f %f %f %f %f %f %d %d %d\n" 441 | % ( 442 | verts[i, 0], 443 | verts[i, 1], 444 | verts[i, 2], 445 | norms[i, 0], 446 | norms[i, 1], 447 | norms[i, 2], 448 | colors[i, 0], 449 | colors[i, 1], 450 | colors[i, 2], 451 | ) 452 | ) 453 | 454 | # Write face list 455 | for i in range(faces.shape[0]): 456 | ply_file.write("3 %d %d %d\n" % (faces[i, 0], faces[i, 1], faces[i, 2])) 457 | 458 | ply_file.close() 459 | 460 | 461 | def pcwrite(filename, xyzrgb): 462 | """Save a point cloud to a polygon .ply file.""" 463 | xyz = xyzrgb[:, :3] 464 | rgb = xyzrgb[:, 3:].astype(np.uint8) 465 | 466 | # Write header 467 | ply_file = open(filename, "w") 468 | ply_file.write("ply\n") 469 | ply_file.write("format ascii 1.0\n") 470 | ply_file.write("element vertex %d\n" % (xyz.shape[0])) 471 | ply_file.write("property float x\n") 472 | ply_file.write("property float y\n") 473 | ply_file.write("property float z\n") 474 | ply_file.write("property uchar red\n") 475 | ply_file.write("property uchar green\n") 476 | ply_file.write("property uchar blue\n") 477 | ply_file.write("end_header\n") 478 | 479 | # Write vertex list 480 | for i in range(xyz.shape[0]): 481 | ply_file.write( 482 | "%f %f %f %d %d %d\n" 483 | % ( 484 | xyz[i, 0], 485 | xyz[i, 1], 486 | xyz[i, 2], 487 | rgb[i, 0], 488 | rgb[i, 1], 489 | rgb[i, 2], 490 | ) 491 | ) 492 | 493 | 494 | def integrate( 495 | depth_im, 496 | cam_intr, 497 | cam_pose, 498 | obs_weight, 499 | world_c, 500 | vox_coords, 501 | weight_vol, 502 | tsdf_vol, 503 | sdf_trunc, 504 | im_h, 505 | im_w, 506 | ): 507 | # Convert world coordinates to camera coordinates 508 | world2cam = torch.inverse(cam_pose) 509 | cam_c = torch.matmul(world2cam, world_c.transpose(1, 0)).transpose(1, 0).float() 510 | 511 | # Convert camera coordinates to pixel coordinates 512 | fx, fy = cam_intr[0, 0], cam_intr[1, 1] 513 | cx, cy = cam_intr[0, 2], cam_intr[1, 2] 514 | pix_z = cam_c[:, 2] 515 | pix_x = torch.round((cam_c[:, 0] * fx / cam_c[:, 2]) + cx).long() 516 | pix_y = torch.round((cam_c[:, 1] * fy / cam_c[:, 2]) + cy).long() 517 | 518 | # Eliminate pixels outside view frustum 519 | valid_pix = ( 520 | (pix_x >= 0) & (pix_x < im_w) & (pix_y >= 0) & (pix_y < im_h) & (pix_z > 0) 521 | ) 522 | valid_vox_x = vox_coords[valid_pix, 0] 523 | valid_vox_y = vox_coords[valid_pix, 1] 524 | valid_vox_z = vox_coords[valid_pix, 2] 525 | depth_val = depth_im[pix_y[valid_pix], pix_x[valid_pix]] 526 | 527 | # Integrate tsdf 528 | depth_diff = depth_val - pix_z[valid_pix] 529 | dist = torch.clamp(depth_diff / sdf_trunc, max=1) 530 | valid_pts = (depth_val > 0) & (depth_diff >= -sdf_trunc) 531 | valid_vox_x = 
valid_vox_x[valid_pts] 532 | valid_vox_y = valid_vox_y[valid_pts] 533 | valid_vox_z = valid_vox_z[valid_pts] 534 | valid_dist = dist[valid_pts] 535 | w_old = weight_vol[valid_vox_x, valid_vox_y, valid_vox_z] 536 | tsdf_vals = tsdf_vol[valid_vox_x, valid_vox_y, valid_vox_z] 537 | w_new = w_old + obs_weight 538 | tsdf_vol[valid_vox_x, valid_vox_y, valid_vox_z] = ( 539 | w_old * tsdf_vals + obs_weight * valid_dist 540 | ) / w_new 541 | weight_vol[valid_vox_x, valid_vox_y, valid_vox_z] = w_new 542 | 543 | return weight_vol, tsdf_vol 544 | 545 | 546 | class TSDFVolumeTorch: 547 | """Volumetric TSDF Fusion of RGB-D Images.""" 548 | 549 | def __init__(self, voxel_dim, origin, voxel_size, margin=3): 550 | """Constructor. 551 | Args: 552 | vol_bnds (ndarray): An ndarray of shape (3, 2). Specifies the 553 | xyz bounds (min/max) in meters. 554 | voxel_size (float): The volume discretization in meters. 555 | """ 556 | # if torch.cuda.is_available(): 557 | # self.device = torch.device("cuda") 558 | # else: 559 | # print("[!] No GPU detected. Defaulting to CPU.") 560 | self.device = torch.device("cpu") 561 | 562 | # Define voxel volume parameters 563 | self._voxel_size = float(voxel_size) 564 | self._sdf_trunc = margin * self._voxel_size 565 | self._const = 256 * 256 566 | self._integrate_func = integrate 567 | 568 | # Adjust volume bounds 569 | self._vol_dim = voxel_dim.long() 570 | self._vol_origin = origin 571 | self._num_voxels = torch.prod(self._vol_dim).item() 572 | 573 | # Get voxel grid coordinates 574 | xv, yv, zv = torch.meshgrid( 575 | torch.arange(0, self._vol_dim[0]), 576 | torch.arange(0, self._vol_dim[1]), 577 | torch.arange(0, self._vol_dim[2]), 578 | indexing='ij' 579 | ) 580 | self._vox_coords = ( 581 | torch.stack([xv.flatten(), yv.flatten(), zv.flatten()], dim=1) 582 | .long() 583 | .to(self.device) 584 | ) 585 | 586 | # Convert voxel coordinates to world coordinates 587 | self._world_c = self._vol_origin + (self._voxel_size * self._vox_coords) 588 | self._world_c = torch.cat( 589 | [self._world_c, torch.ones(len(self._world_c), 1, device=self.device)], 590 | dim=1, 591 | ) 592 | 593 | self.reset() 594 | 595 | # print("[*] voxel volume: {} x {} x {}".format(*self._vol_dim)) 596 | # print("[*] num voxels: {:,}".format(self._num_voxels)) 597 | 598 | def reset(self): 599 | self._tsdf_vol = torch.ones(*self._vol_dim).to(self.device) 600 | self._weight_vol = torch.zeros(*self._vol_dim).to(self.device) 601 | self._color_vol = torch.zeros(*self._vol_dim).to(self.device) 602 | 603 | def integrate(self, depth_im, cam_intr, cam_pose, obs_weight): 604 | """Integrate an RGB-D frame into the TSDF volume. 605 | Args: 606 | color_im (ndarray): An RGB image of shape (H, W, 3). 607 | depth_im (ndarray): A depth image of shape (H, W). 608 | cam_intr (ndarray): The camera intrinsics matrix of shape (3, 3). 609 | cam_pose (ndarray): The camera pose (i.e. extrinsics) of shape (4, 4). 610 | obs_weight (float): The weight to assign to the current observation. 
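A higher value gives this observation more influence in the weighted running average.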
611 | """ 612 | cam_pose = cam_pose.float().to(self.device) 613 | cam_intr = cam_intr.float().to(self.device) 614 | depth_im = depth_im.float().to(self.device) 615 | im_h, im_w = depth_im.shape 616 | weight_vol, tsdf_vol = self._integrate_func( 617 | depth_im, 618 | cam_intr, 619 | cam_pose, 620 | obs_weight, 621 | self._world_c, 622 | self._vox_coords, 623 | self._weight_vol, 624 | self._tsdf_vol, 625 | self._sdf_trunc, 626 | im_h, 627 | im_w, 628 | ) 629 | self._weight_vol = weight_vol 630 | self._tsdf_vol = tsdf_vol 631 | 632 | def get_volume(self): 633 | return self._tsdf_vol, self._weight_vol 634 | 635 | @property 636 | def sdf_trunc(self): 637 | return self._sdf_trunc 638 | 639 | @property 640 | def voxel_size(self): 641 | return self._voxel_size 642 | -------------------------------------------------------------------------------- /cvrecon/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import glob 3 | import itertools 4 | import os 5 | 6 | import numba 7 | import numpy as np 8 | import open3d as o3d 9 | import skimage.measure 10 | import wandb 11 | 12 | 13 | def log_transform(x, shift=1): 14 | """rescales TSDF values to weight voxels near the surface more than close 15 | to the truncation distance""" 16 | return x.sign() * (1 + x.abs() / shift).log() 17 | 18 | 19 | def to_vol(inds, vals): 20 | dims = np.max(inds, axis=0) + 1 21 | vol = np.ones(dims) * np.nan 22 | vol[inds[:, 0], inds[:, 1], inds[:, 2]] = vals 23 | return vol 24 | 25 | 26 | def to_mesh(vol, voxel_size=1, origin=np.zeros(3), level=0, mask=None): 27 | verts, faces, _, _ = skimage.measure.marching_cubes(vol, level=level, mask=mask) 28 | verts *= voxel_size 29 | verts += origin 30 | 31 | bad_face_inds = np.any(np.isnan(verts[faces]), axis=(1, 2)) 32 | faces = faces[~bad_face_inds] 33 | 34 | bad_vert_inds = np.any(np.isnan(verts), axis=-1) 35 | reindex = np.cumsum(~bad_vert_inds) - 1 36 | faces = reindex[faces] 37 | verts = verts[~bad_vert_inds] 38 | 39 | mesh = o3d.geometry.TriangleMesh( 40 | o3d.utility.Vector3dVector(verts), o3d.utility.Vector3iVector(faces) 41 | ) 42 | mesh.compute_vertex_normals() 43 | return mesh 44 | 45 | 46 | @numba.jit(nopython=True) 47 | def remove_redundant(poses, rmin_deg, tmin): 48 | cos_t_max = np.cos(rmin_deg * np.pi / 180) 49 | frame_inds = np.arange(len(poses)) 50 | selected_frame_inds = [frame_inds[0]] 51 | for frame_ind in frame_inds[1:]: 52 | prev_pose = poses[selected_frame_inds[-1]] 53 | candidate_pose = poses[frame_ind] 54 | cos_t = np.sum(prev_pose[:3, 2] * candidate_pose[:3, 2]) 55 | tdist = np.linalg.norm(prev_pose[:3, 3] - candidate_pose[:3, 3]) 56 | if tdist > tmin or cos_t < cos_t_max: 57 | selected_frame_inds.append(frame_ind) 58 | return selected_frame_inds 59 | 60 | 61 | def frame_selection( 62 | poses, 63 | intr, 64 | imwidth, 65 | imheight, 66 | sample_pts, 67 | tmin, 68 | rmin_deg, 69 | n_imgs, 70 | ): 71 | # select randomly among views that see at least one sample point 72 | 73 | intr4x4 = np.eye(4, dtype=np.float32) 74 | intr4x4[:3, :3] = intr 75 | 76 | xyz = np.concatenate( 77 | (sample_pts, np.ones((len(sample_pts), 1), dtype=sample_pts.dtype)), axis=-1 78 | ) 79 | uv = intr4x4 @ np.linalg.inv(poses) @ xyz.T 80 | z = uv[:, 2] 81 | z_valid = z > 1e-10 82 | z[~z_valid] = 1 83 | uv = uv[:, :2] / z[:, None] 84 | valid = ( 85 | (uv[:, 0] > 0) 86 | & (uv[:, 0] < imwidth) 87 | & (uv[:, 1] > 0) 88 | & (uv[:, 1] < imheight) 89 | & z_valid 90 | ) 91 | intersections = np.sum(valid, axis=-1) 92 | intersect_inds 
= np.argwhere(intersections > 0).flatten() 93 | 94 | frame_inds = np.arange(len(poses), dtype=np.int32) 95 | 96 | if n_imgs is None: 97 | score = intersections[intersect_inds] 98 | selected_frame_inds = frame_inds[intersect_inds] 99 | elif len(intersect_inds) >= n_imgs: 100 | selected_frame_inds = np.random.choice( 101 | intersect_inds, size=n_imgs, replace=False 102 | ) 103 | score = intersections[selected_frame_inds] 104 | else: 105 | not_intersect_inds = np.argwhere(intersections == 0).flatten() 106 | n_needed = n_imgs - len(intersect_inds) 107 | extra_inds = np.random.choice(not_intersect_inds, size=n_needed, replace=False) 108 | selected_frame_inds = np.concatenate((intersect_inds, extra_inds)) 109 | score = np.concatenate((intersections[intersect_inds], np.zeros(n_needed))) 110 | 111 | return (selected_frame_inds, score) 112 | 113 | 114 | def load_info_files(scannet_dir, split): 115 | with open(os.path.join(scannet_dir, f"scannetv2_{split}.txt"), "r") as f: 116 | scene_names = f.read().split() 117 | scan_dir = "scans" if split in ["train", "val"] else "scans_test" 118 | info_files = sorted(glob.glob(os.path.join(scannet_dir, scan_dir, "*/info.json"))) 119 | info_files = [ 120 | f for f in info_files if os.path.basename(os.path.dirname(f)) in scene_names 121 | ] 122 | return info_files 123 | 124 | 125 | def visualize(o3d_geoms): 126 | visibility = [True] * len(o3d_geoms) 127 | 128 | def toggle_geom(vis, geom_ind): 129 | if visibility[geom_ind]: 130 | vis.remove_geometry(o3d_geoms[geom_ind], reset_bounding_box=False) 131 | visibility[geom_ind] = False 132 | else: 133 | vis.add_geometry(o3d_geoms[geom_ind], reset_bounding_box=False) 134 | visibility[geom_ind] = True 135 | 136 | callbacks = {} 137 | for i in range(len(o3d_geoms)): 138 | callbacks[ord(str(i + 1))] = functools.partial(toggle_geom, geom_ind=i) 139 | o3d.visualization.draw_geometries_with_key_callbacks(o3d_geoms, callbacks) 140 | -------------------------------------------------------------------------------- /cvrecon/view_direction_encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ViewDirectionEncoder(torch.nn.Module): 6 | def __init__(self, feat_depth, L): 7 | super().__init__() 8 | self.L = L 9 | self.view_embedding_dim = 3 + self.L * 6 10 | self.conv = torch.nn.Sequential( 11 | torch.nn.Conv2d( 12 | feat_depth + self.view_embedding_dim, feat_depth, 1, bias=False 13 | ), 14 | ) 15 | torch.nn.init.xavier_normal_(self.conv[0].weight) 16 | 17 | def forward(self, feats, proj, cam_positions): 18 | device = feats.device 19 | dtype = feats.dtype 20 | featheight, featwidth = feats.shape[2:] 21 | u = torch.arange(featwidth, device=device, dtype=dtype) 22 | v = torch.arange(featheight, device=device, dtype=dtype) 23 | vv, uu = torch.meshgrid(v, u, indexing='ij') 24 | ones = torch.ones_like(uu) 25 | uv = torch.stack((uu, vv, ones, ones), dim=0).to(dtype) 26 | 27 | inv_proj = torch.linalg.inv(proj) 28 | xyz = inv_proj @ uv.reshape(4, -1) 29 | view_vecs = xyz[:, :, :3] - cam_positions[..., None] 30 | view_vecs /= torch.linalg.norm(view_vecs, dim=2, keepdim=True) 31 | view_vecs = view_vecs.to(dtype) 32 | 33 | view_encoding = [view_vecs] 34 | for i in range(self.L): 35 | view_encoding.append(torch.sin(view_vecs * np.pi * 2 ** i)) 36 | view_encoding.append(torch.cos(view_vecs * np.pi * 2 ** i)) 37 | view_encoding = torch.cat(view_encoding, dim=2) 38 | 39 | view_encoding = view_encoding.reshape( 40 | view_encoding.shape[0] * 
view_encoding.shape[1],
41 | view_encoding.shape[2],
42 | featheight,
43 | featwidth,
44 | )
45 | 
46 | feats = torch.cat((feats, view_encoding), dim=1)
47 | feats = self.conv(feats)
48 | return feats
49 | 
--------------------------------------------------------------------------------
/scripts/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import yaml
5 | 
6 | import imageio
7 | import numpy as np
8 | import open3d as o3d
9 | import pytorch_lightning as pl
10 | import skimage.measure
11 | import torch
12 | import torchsparse
13 | import tqdm
14 | import torch.nn.functional as F
15 | 
16 | from cvrecon import data, lightningmodel, utils
17 | 
18 | 
19 | import matplotlib.pyplot as plt
20 | from PIL import Image
21 | def colormap_image(
22 | image_1hw,
23 | mask_1hw=None,
24 | invalid_color=(0.0, 0, 0.0),
25 | flip=True,
26 | vmin=None,
27 | vmax=None,
28 | return_vminvmax=False,
29 | colormap="turbo",
30 | ):
31 | """
32 | Colormaps a one-channel tensor using a matplotlib colormap.
33 | 
34 | Args:
35 | image_1hw: the tensor to colormap.
36 | mask_1hw: an optional float mask where 1.0 denotes valid pixels.
37 | colormap: the colormap to use. Default is turbo.
38 | invalid_color: the color to use for invalid pixels.
39 | flip: should we flip the colormap? True by default.
40 | vmin: if provided uses this as the minimum when normalizing the tensor.
41 | vmax: if provided uses this as the maximum when normalizing the tensor.
42 | When either of vmin or vmax is None, it is computed from the
43 | tensor.
44 | return_vminvmax: when true, returns vmin and vmax.
45 | 
46 | Returns:
47 | image_cm_3hw: image of the colormapped tensor.
48 | vmin, vmax: returned when return_vminvmax is true.
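Example (hypothetical shapes, for illustration only):
depth_1hw = torch.rand(1, 192, 256)
vis_3hw = colormap_image(depth_1hw, vmin=0.25, vmax=5.0)  # [3, 192, 256], values in [0, 1]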
49 | 50 | 51 | """ 52 | valid_vals = image_1hw if mask_1hw is None else image_1hw[mask_1hw.bool()] 53 | if vmin is None: 54 | vmin = valid_vals.min() 55 | if vmax is None: 56 | vmax = valid_vals.max() 57 | 58 | cmap = torch.Tensor( 59 | plt.cm.get_cmap(colormap)( 60 | torch.linspace(0, 1, 256) 61 | )[:, :3] 62 | ).to(image_1hw.device) 63 | if flip: 64 | cmap = torch.flip(cmap, (0,)) 65 | 66 | h, w = image_1hw.shape[1:] 67 | 68 | image_norm_1hw = (image_1hw - vmin) / (vmax - vmin) 69 | image_int_1hw = (torch.clamp(image_norm_1hw * 255, 0, 255)).byte().long() 70 | 71 | image_cm_3hw = cmap[image_int_1hw.flatten(start_dim=1) 72 | ].permute([0, 2, 1]).view([-1, h, w]) 73 | 74 | if mask_1hw is not None: 75 | invalid_color = torch.Tensor(invalid_color).view(3, 1, 1).to(image_1hw.device) 76 | image_cm_3hw = image_cm_3hw * mask_1hw + invalid_color * (1 - mask_1hw) 77 | 78 | if return_vminvmax: 79 | return image_cm_3hw, vmin, vmax 80 | else: 81 | return image_cm_3hw 82 | 83 | 84 | def save_gif(rgb_imgfiles, output_path): 85 | gif = [] 86 | for fname in rgb_imgfiles: 87 | gif.append(Image.open(fname)) 88 | gif[0].save(os.path.join(output_path, 'rgb.gif'), save_all=True,optimize=False, append_images=gif[1:], loop=0) 89 | 90 | 91 | def load_model(ckpt_file, use_proj_occ, config): 92 | model = lightningmodel.LightningModel.load_from_checkpoint( 93 | ckpt_file, 94 | config=config, 95 | ) 96 | model.cvrecon.use_proj_occ = use_proj_occ 97 | model = model.cuda() 98 | model = model.eval() 99 | model.requires_grad_(False) 100 | return model 101 | 102 | 103 | def load_scene(info_file): 104 | with open(info_file, "r") as f: 105 | info = json.load(f) 106 | 107 | rgb_imgfiles = [frame["filename_color"] for frame in info["frames"]] 108 | depth_imgfiles = [frame["filename_depth"] for frame in info["frames"]] 109 | pose = np.empty((len(info["frames"]), 4, 4), dtype=np.float32) 110 | for i, frame in enumerate(info["frames"]): 111 | pose[i] = frame["pose"] 112 | intr = np.array(info["intrinsics"], dtype=np.float32) 113 | return rgb_imgfiles, depth_imgfiles, pose, intr 114 | 115 | 116 | def get_scene_bounds(pose, intr, imheight, imwidth, frustum_depth): 117 | frust_pts_img = np.array( 118 | [ 119 | [0, 0], 120 | [imwidth, 0], 121 | [imwidth, imheight], 122 | [0, imheight], 123 | ] 124 | ) 125 | frust_pts_cam = ( 126 | np.linalg.inv(intr) @ np.c_[frust_pts_img, np.ones(len(frust_pts_img))].T 127 | ).T * frustum_depth 128 | frust_pts_world = ( 129 | pose @ np.c_[frust_pts_cam, np.ones(len(frust_pts_cam))].T 130 | ).transpose(0, 2, 1)[..., :3] 131 | 132 | minbound = np.min(frust_pts_world, axis=(0, 1)) 133 | maxbound = np.max(frust_pts_world, axis=(0, 1)) 134 | return minbound, maxbound 135 | 136 | 137 | def get_tiles(minbound, maxbound, cropsize_voxels_fine, voxel_size_fine): 138 | cropsize_m = cropsize_voxels_fine * voxel_size_fine 139 | 140 | assert np.all(cropsize_voxels_fine % 4 == 0) 141 | cropsize_voxels_coarse = cropsize_voxels_fine // 4 142 | voxel_size_coarse = voxel_size_fine * 4 143 | 144 | ncrops = np.ceil((maxbound - minbound) / cropsize_m).astype(int) 145 | x = np.arange(ncrops[0], dtype=np.int32) * cropsize_voxels_coarse[0] 146 | y = np.arange(ncrops[1], dtype=np.int32) * cropsize_voxels_coarse[1] 147 | z = np.arange(ncrops[2], dtype=np.int32) * cropsize_voxels_coarse[2] 148 | yy, xx, zz = np.meshgrid(y, x, z) 149 | tile_origin_inds = np.c_[xx.flatten(), yy.flatten(), zz.flatten()] 150 | 151 | x = np.arange(0, cropsize_voxels_coarse[0], dtype=np.int32) 152 | y = np.arange(0, cropsize_voxels_coarse[1], 
dtype=np.int32) 153 | z = np.arange(0, cropsize_voxels_coarse[2], dtype=np.int32) 154 | yy, xx, zz = np.meshgrid(y, x, z) 155 | base_voxel_inds = np.c_[xx.flatten(), yy.flatten(), zz.flatten()] 156 | 157 | tiles = [] 158 | for origin_ind in tile_origin_inds: 159 | origin = origin_ind * voxel_size_coarse + minbound 160 | tile = { 161 | "origin_ind": origin_ind, 162 | "origin": origin.astype(np.float32), 163 | "maxbound_ind": origin_ind + cropsize_voxels_coarse, 164 | "voxel_inds": torch.from_numpy(base_voxel_inds + origin_ind), 165 | "voxel_coords": torch.from_numpy( 166 | base_voxel_inds * voxel_size_coarse + origin 167 | ).float(), 168 | "voxel_features": torch.empty( 169 | (len(base_voxel_inds), 0), dtype=torch.float32 170 | ), 171 | "voxel_logits": torch.empty((len(base_voxel_inds), 0), dtype=torch.float32), 172 | } 173 | tiles.append(tile) 174 | return tiles 175 | 176 | 177 | def frame_selection(tiles, pose, intr, imheight, imwidth, n_imgs, rmin_deg, tmin, SRlist, rgb_imgfiles, CVDict): 178 | sparsified_frame_inds = np.array(utils.remove_redundant(pose, rmin_deg, tmin)) 179 | 180 | if SRlist is not None: 181 | SRlist_inds = [] 182 | for frame_ind in sparsified_frame_inds: 183 | if '0' + rgb_imgfiles[frame_ind][-9:-4] in SRlist: 184 | SRlist_inds.append(frame_ind) 185 | if len(sparsified_frame_inds) != len(SRlist_inds): 186 | print('!!!!!!!!!!!!!!!!!', len(sparsified_frame_inds), len(SRlist_inds), scene_name) 187 | sparsified_frame_inds = np.array(SRlist_inds) 188 | 189 | if CVDict is not None: 190 | SRlist_inds = [] 191 | for frame_ind in sparsified_frame_inds: 192 | if '0' + rgb_imgfiles[frame_ind][-9:-4] in CVDict: 193 | SRlist_inds.append(frame_ind) 194 | if len(sparsified_frame_inds) != len(SRlist_inds): 195 | print('!!!!!!!!!!!!!!!!!', len(sparsified_frame_inds), len(SRlist_inds), scene_name) 196 | sparsified_frame_inds = np.array(SRlist_inds) 197 | 198 | if len(sparsified_frame_inds) < n_imgs: 199 | print('@@@@@@@@@@@@@@@@@', scene_name, len(sparsified_frame_inds)) 200 | # after redundant frame removal we can end up with too few frames-- 201 | # add some back in 202 | avail_inds = list(set(np.arange(len(pose))) - set(sparsified_frame_inds)) 203 | n_needed = n_imgs - len(sparsified_frame_inds) 204 | extra_inds = np.random.choice(avail_inds, size=n_needed, replace=False) 205 | selected_frame_inds = np.concatenate((sparsified_frame_inds, extra_inds)) 206 | else: 207 | selected_frame_inds = sparsified_frame_inds 208 | 209 | for i, tile in enumerate(tiles): 210 | if len(selected_frame_inds) > n_imgs: 211 | sample_pts = tile["voxel_coords"].numpy() 212 | cur_frame_inds, score = utils.frame_selection( 213 | pose[selected_frame_inds], 214 | intr, 215 | imwidth, 216 | imheight, 217 | sample_pts, 218 | tmin, 219 | rmin_deg, 220 | n_imgs, 221 | ) 222 | tile["frame_inds"] = selected_frame_inds[cur_frame_inds] 223 | else: 224 | tile["frame_inds"] = selected_frame_inds 225 | return tiles 226 | 227 | 228 | def get_img_feats(cvrecon, imheight, imwidth, proj_mats, rgb_imgfiles, cam_positions, SRfeat, scene_name): 229 | imsize = np.array([imheight, imwidth]) 230 | dims = { 231 | "coarse": imsize // 16, 232 | "medium": imsize // 8, 233 | "fine": imsize // 4, 234 | } 235 | feats_2d = { 236 | "coarse": torch.empty( 237 | (1, len(rgb_imgfiles), 80, *dims["coarse"]), dtype=torch.float16 238 | ), 239 | "medium": torch.empty( 240 | (1, len(rgb_imgfiles), 40, *dims["medium"]), dtype=torch.float16 241 | ), 242 | "fine": torch.empty( 243 | (1, len(rgb_imgfiles), 24, *dims["fine"]), dtype=torch.float16 244 | 
), 245 | } 246 | cam_positions = torch.from_numpy(cam_positions).cuda()[None] 247 | for i in range(len(rgb_imgfiles)): 248 | rgb_img = data.load_rgb_imgs([rgb_imgfiles[i]], imheight, imwidth) 249 | rgb_img = torch.from_numpy(rgb_img).cuda()[None] 250 | cur_proj_mats = {k: v[:, i, None] for k, v in proj_mats.items()} 251 | if SRfeat: 252 | SRfeat0, SRfeat1, SRfeat2 = data.load_SRfeats(scene_name, ['0'+rgb_imgfiles[i][-9:-4]]) 253 | SRfeat0 = torch.from_numpy(SRfeat0).cuda()[None] 254 | SRfeat1 = torch.from_numpy(SRfeat1).cuda()[None] 255 | SRfeat2 = torch.from_numpy(SRfeat2).cuda()[None] 256 | cur_feats_2d = model.cvrecon.get_SR_feats(SRfeat0, SRfeat1, SRfeat2, cur_proj_mats, cam_positions[:, i, None]) 257 | else: 258 | cur_feats_2d = model.cvrecon.get_img_feats(rgb_img, cur_proj_mats, cam_positions[:, i, None]) 259 | 260 | for resname in feats_2d: 261 | feats_2d[resname][0, i] = cur_feats_2d[resname][0, 0].cpu() 262 | return feats_2d 263 | 264 | 265 | def construct_cv(model, cur_feats, ref_feats, intr, pose, ref_pose, n_imgs, rgb_imgfiles): 266 | cvs = [] 267 | cv_masks = [] 268 | 269 | k = np.eye(4, dtype=np.float32) 270 | k[:3, :3] = intr 271 | k[0] = k[0] * 0.125 272 | k[1] = k[1] * 0.125 273 | invK = torch.from_numpy(np.linalg.inv(k)).unsqueeze(0).cuda() 274 | k = torch.from_numpy(k).cuda() 275 | 276 | src_K = k.unsqueeze(0).unsqueeze(0).repeat(1, 7, 1, 1) # [1, 7, 4, 4] 277 | min_depth = torch.tensor(0.25).type_as(src_K).view(1, 1, 1, 1) 278 | max_depth = torch.tensor(5.0).type_as(src_K).view(1, 1, 1, 1) 279 | 280 | inv_pose = torch.from_numpy(np.linalg.inv(pose)).cuda() 281 | inv_ref_pose = torch.from_numpy(np.linalg.inv(ref_pose)).cuda().view(n_imgs, 7, 4, 4) 282 | pose = torch.from_numpy(pose).cuda() 283 | ref_pose = torch.from_numpy(ref_pose).cuda().view(n_imgs, 7, 4, 4) 284 | 285 | if vis_lowest: 286 | output_path = f'/test/{scene_name}' 287 | if not os.path.exists(output_path): os.mkdir(output_path) 288 | gif = [] 289 | 290 | for i in range(n_imgs): 291 | matching_cur_feats = cur_feats[i].unsqueeze(0) 292 | matching_src_feat = ref_feats[i].unsqueeze(0) 293 | 294 | src_cam_T_world = inv_ref_pose[i].unsqueeze(0) 295 | src_world_T_cam = ref_pose[i].unsqueeze(0) 296 | cur_cam_T_world = inv_pose[i].unsqueeze(0) 297 | cur_world_T_cam = pose[i].unsqueeze(0) 298 | with torch.cuda.amp.autocast(False): 299 | # Compute src_cam_T_cur_cam, a transformation for going from 3D 300 | # coords in current view coordinate frame to source view coords 301 | # coordinate frames. 302 | src_cam_T_cur_cam = src_cam_T_world @ cur_world_T_cam.unsqueeze(1) 303 | 304 | # Compute cur_cam_T_src_cam the opposite of src_cam_T_cur_cam. From 305 | # source view to current view. 
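# Both relative transforms are then handed to the cost-volume module below, which matches the
# current frame's encoder features against those of its 7 source frames over depth hypotheses
# between min_depth (0.25 m) and max_depth (5.0 m).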
306 | cur_cam_T_src_cam = cur_cam_T_world.unsqueeze(1) @ src_world_T_cam 307 | 308 | cost_volume, lowest_cost, _, overall_mask_bhw = model.cvrecon.cost_volume( 309 | cur_feats=matching_cur_feats, 310 | src_feats=matching_src_feat, 311 | src_extrinsics=src_cam_T_cur_cam, 312 | src_poses=cur_cam_T_src_cam, 313 | src_Ks=src_K, 314 | cur_invK=invK, 315 | min_depth=min_depth, 316 | max_depth=max_depth, 317 | return_mask=True, 318 | return_lowest=vis_lowest, 319 | ) 320 | 321 | if vis_lowest: 322 | lowest_cost_3hw = colormap_image(lowest_cost, vmin=5, vmax=0.25) 323 | gif.append(Image.fromarray(np.uint8(lowest_cost_3hw.permute(1,2,0).cpu().detach().numpy() * 255))) 324 | 325 | cvs.append(cost_volume.unsqueeze(1)) 326 | cv_masks.append(overall_mask_bhw.unsqueeze(1)) 327 | 328 | if vis_lowest: 329 | gif[0].save(os.path.join(output_path, 'lowest_cost.gif'), save_all=True,optimize=False, append_images=gif[1:], loop=0) 330 | save_gif(rgb_imgfiles, output_path) 331 | 332 | cvs = torch.cat(cvs, dim=1) # [b, n, c, d, h, w] 333 | cv_masks = torch.cat(cv_masks, dim=1) 334 | if config["cv_overall"]: 335 | ############################### skiped overall feat #################################################### 336 | # overallfeat = cvs[:, :, -1:, ::8, ...].permute(0, 1, 3, 2, 4, 5).expand([-1, -1, -1, cvs.shape[3], -1, -1]) 337 | 338 | # ############################### conv overall feat #################################################### 339 | # overallfeat = cvs[:, :, -1, ...].view([-1] + list(cvs.shape[3:])) 340 | # overallfeat = self.cv_global_encoder(overallfeat).view(list(cvs.shape[:2]) + [8, 1, cvs.shape[-2], cvs.shape[-1]]) 341 | # overallfeat = overallfeat.expand([-1, -1, -1, cvs.shape[3], -1, -1]) 342 | 343 | # ############################### complete overall feat #################################################### 344 | # overallfeat = cvs[:, :, -1:, :, ...].permute(0, 1, 3, 2, 4, 5).expand([-1, -1, -1, cvs.shape[3], -1, -1]) 345 | 346 | # cvs = cvs[:, :, :-1, ...] 
347 | # cvs = torch.cat([overallfeat, cvs], dim=2) 348 | pass 349 | return cvs, cv_masks 350 | 351 | 352 | def inference(model, info_file, outfile, n_imgs, cropsize, SRlist=None, scene_name=None, CVDict=None): 353 | rgb_imgfiles, depth_imgfiles, pose, intr = load_scene(info_file) 354 | test_img = imageio.imread(rgb_imgfiles[0]) 355 | imheight, imwidth, _ = test_img.shape 356 | 357 | scene_minbound, scene_maxbound = get_scene_bounds( 358 | pose, intr, imheight, imwidth, frustum_depth=4 359 | ) 360 | 361 | pose_w2c = np.linalg.inv(pose) 362 | tiles = get_tiles( # divide into non-overlapping fragments 363 | scene_minbound, 364 | scene_maxbound, 365 | cropsize_voxels_fine=np.array(cropsize), 366 | voxel_size_fine=0.04, 367 | ) 368 | 369 | # pre-select views for each tile 370 | tiles = frame_selection( 371 | tiles, pose, intr, imheight, imwidth, n_imgs=n_imgs, rmin_deg=15, tmin=0.1, SRlist=SRlist, rgb_imgfiles=rgb_imgfiles, CVDict=CVDict, 372 | ) 373 | 374 | # drop the frames that weren't selected for any tile, re-index the selected frame indices 375 | selected_frame_inds = np.unique( 376 | np.concatenate([tile["frame_inds"] for tile in tiles]) 377 | ) 378 | 379 | all_frame_inds = np.arange(len(pose)) 380 | frame_reindex = np.full(len(all_frame_inds), 100_000) 381 | frame_reindex[selected_frame_inds] = np.arange(len(selected_frame_inds)) 382 | for tile in tiles: 383 | tile["frame_inds"] = frame_reindex[tile["frame_inds"]] 384 | pose_w2c = pose_w2c[selected_frame_inds] 385 | pose = pose[selected_frame_inds] 386 | rgb_imgfiles = np.array(rgb_imgfiles)[selected_frame_inds] 387 | 388 | if CVDict is not None: 389 | with open(info_file, "r") as f: 390 | info = json.load(f) 391 | ref_pose = [] 392 | ref_img = [] 393 | cv_invalid_mask = np.zeros(len(rgb_imgfiles), dtype=np.int32) 394 | frame2id = {'0'+frame["filename_color"][-9:-4]:i for i, frame in enumerate(info["frames"])} 395 | for i, fname in enumerate(rgb_imgfiles.copy()): 396 | if '0' + fname[-9: -4] in CVDict: 397 | for frameid in CVDict['0' + fname[-9: -4]]: 398 | ref_pose.append(np.array(info['frames'][frame2id[frameid]]['pose'], dtype=np.float32)[None,...]) 399 | ref_img.append(info['frames'][frame2id[frameid]]['filename_color']) 400 | else: 401 | print('!!!!!!!!!!!!!! 
invalid cv at ', '0' + fname[-9: -4]) 402 | cv_invalid_mask[i] = 1 403 | for i in range(7): 404 | ref_pose.append(np.array(info['frames'][frame2id['0'+fname[-9: -4]]]['pose'], dtype=np.float32)[None,...]) 405 | ref_img.append(fname) 406 | ref_pose = np.concatenate(ref_pose) 407 | ref_imgs_paths = np.array(ref_img) 408 | 409 | cur_imgs = [] 410 | for i in range(len(rgb_imgfiles)): 411 | cur_img = data.load_rgb_imgs([rgb_imgfiles[i]], imheight, imwidth) 412 | cur_imgs.append(torch.from_numpy(cur_img).cuda()[None]) 413 | cur_imgs = torch.cat(cur_imgs, dim=0) 414 | cur_feats = model.cvrecon.compute_matching_feats(cur_imgs).squeeze() # [n_imgs, c, h, w] 415 | 416 | ref_imgs = [] 417 | for i in range(len(ref_imgs_paths)): 418 | ref_img = data.load_rgb_imgs([ref_imgs_paths[i]], imheight, imwidth) 419 | ref_imgs.append(torch.from_numpy(ref_img).cuda()) 420 | ref_imgs = torch.cat(ref_imgs, dim=0).view(-1, 7, 3, imheight, imwidth) # [n_imgs, 7, 3, h, w] 421 | ref_feats = model.cvrecon.compute_matching_feats(ref_imgs) # [n_imgs, 7, c, h, w] 422 | 423 | cost_volume, cv_masks = construct_cv(model, cur_feats, ref_feats, intr, pose, ref_pose, len(rgb_imgfiles), rgb_imgfiles) 424 | cost_volume[0][cv_invalid_mask.astype(bool)] = 0 425 | 426 | factors = np.array([1 / 16, 1 / 8, 1 / 4]) 427 | proj_mats = data.get_proj_mats(intr, pose_w2c, factors) 428 | proj_mats = {k: torch.from_numpy(v)[None].cuda() for k, v in proj_mats.items()} 429 | img_feats = get_img_feats( 430 | model, 431 | imheight, 432 | imwidth, 433 | proj_mats, 434 | rgb_imgfiles, 435 | cam_positions=pose[:, :3, 3], 436 | SRfeat=SRlist!=None, 437 | scene_name=scene_name, 438 | ) 439 | for resname, res in model.cvrecon.resolutions.items(): 440 | 441 | # populate feature volume independently for each tile 442 | for tile in tiles: 443 | voxel_coords = tile["voxel_coords"].cuda() 444 | voxel_batch_inds = torch.zeros( 445 | len(voxel_coords), dtype=torch.int64, device="cuda" 446 | ) 447 | 448 | cur_img_feats = img_feats[resname][:, tile["frame_inds"]].cuda() 449 | cur_proj_mats = proj_mats[resname][:, tile["frame_inds"]] 450 | cur_cost_volume = cost_volume[:, tile["frame_inds"]].clone() 451 | featheight, featwidth = img_feats[resname].shape[-2:] 452 | 453 | #################################################### 2dfeat & CV group conv ################################################################ 454 | feat_cha = {'coarse': 80, 'medium': 40, 'fine':24} 455 | cv_dim = 15 - 8 456 | bs = 1 457 | if resname != 'medium': 458 | cur_cost_volume = F.interpolate(cur_cost_volume.view([bs*n_imgs*cv_dim, 64, 60, 80]), [featheight, featwidth]).view([bs, n_imgs, cv_dim, 64, featheight, featwidth]) 459 | cur_img_feats = model.cvrecon.cv_global_encoder[resname](torch.cat([cur_cost_volume[:,:,-1], cur_img_feats], dim=2).view([-1, feat_cha[resname]+64, featheight, featwidth])) 460 | cur_img_feats = cur_img_feats.view([bs, n_imgs, feat_cha[resname]+64, featheight, featwidth]) 461 | overallfeat = cur_img_feats[:, :, :64].unsqueeze(3).expand([-1, -1, -1, 64, -1, -1]) 462 | cur_img_feats = cur_img_feats[:,:,64:] 463 | cur_cost_volume = model.cvrecon.unshared_conv[resname]( 464 | torch.cat([cur_img_feats.unsqueeze(3).expand([-1,-1,-1,64,-1,-1]), cur_cost_volume], dim=2).transpose(2,3).reshape(bs*n_imgs,-1,featheight, featwidth)) 465 | cur_cost_volume = cur_cost_volume.view([bs, n_imgs, 64, 7, featheight, featwidth]).transpose(2,3) 466 | cur_cost_volume = torch.cat([overallfeat, cur_cost_volume], dim=2) 467 | 
############################################################################################################################################# 468 | 469 | bp_uv, bp_depth, bp_mask = model.cvrecon.project_voxels( 470 | voxel_coords, 471 | voxel_batch_inds, 472 | cur_proj_mats.transpose(0, 1), 473 | featheight, 474 | featwidth, 475 | ) 476 | bp_data = { 477 | "voxel_batch_inds": voxel_batch_inds, 478 | "bp_uv": bp_uv, 479 | "bp_depth": bp_depth, 480 | "bp_mask": bp_mask, 481 | } 482 | bp_feats, proj_occ_logits = model.cvrecon.back_project_features( 483 | bp_data, 484 | cur_img_feats.transpose(0, 1), 485 | model.cvrecon.mv_fusion[resname], 486 | cur_cost_volume if (CVDict is not None) else None, 487 | cv_masks[:, tile["frame_inds"]] if (CVDict is not None) else None, 488 | ) 489 | bp_feats = model.cvrecon.layer_norms[resname](bp_feats) 490 | 491 | tile["voxel_features"] = torch.cat( 492 | (tile["voxel_features"], bp_feats.cpu(), tile["voxel_logits"]), 493 | dim=-1, 494 | ) 495 | 496 | # combine all tiles into one sparse tensor & run convolution 497 | voxel_inds = torch.cat([tile["voxel_inds"] for tile in tiles], dim=0) 498 | voxel_batch_inds = torch.zeros((len(voxel_inds), 1), dtype=torch.int32) 499 | voxel_features = torchsparse.SparseTensor( 500 | torch.cat([tile["voxel_features"] for tile in tiles], dim=0).cuda(), 501 | torch.cat([voxel_inds, voxel_batch_inds], dim=-1).cuda(), 502 | ) 503 | 504 | voxel_features = model.cvrecon.cnns3d[resname](voxel_features) 505 | voxel_logits = model.cvrecon.output_layers[resname](voxel_features) 506 | 507 | if resname in ["coarse", "medium"]: 508 | # sparsify & upsample 509 | occupancy = voxel_logits.F.squeeze(1) > 0 510 | if not torch.any(occupancy): 511 | raise Exception("um") 512 | voxel_features = model.cvrecon.upsampler.upsample_feats( 513 | voxel_features.F[occupancy] 514 | ) 515 | voxel_inds = model.cvrecon.upsampler.upsample_inds(voxel_logits.C[occupancy]) 516 | voxel_logits = model.cvrecon.upsampler.upsample_feats( 517 | voxel_logits.F[occupancy] 518 | ) 519 | voxel_features = voxel_features.cpu() 520 | voxel_inds = voxel_inds.cpu() 521 | voxel_logits = voxel_logits.cpu() 522 | 523 | # split back up into tiles 524 | for tile in tiles: 525 | tile["origin_ind"] *= 2 526 | tile["maxbound_ind"] *= 2 527 | 528 | tile_voxel_mask = ( 529 | (voxel_inds[:, 0] >= tile["origin_ind"][0]) 530 | & (voxel_inds[:, 1] >= tile["origin_ind"][1]) 531 | & (voxel_inds[:, 2] >= tile["origin_ind"][2]) 532 | & (voxel_inds[:, 0] < tile["maxbound_ind"][0]) 533 | & (voxel_inds[:, 1] < tile["maxbound_ind"][1]) 534 | & (voxel_inds[:, 2] < tile["maxbound_ind"][2]) 535 | ) 536 | 537 | tile["voxel_inds"] = voxel_inds[tile_voxel_mask, :3] 538 | tile["voxel_features"] = voxel_features[tile_voxel_mask] 539 | tile["voxel_logits"] = voxel_logits[tile_voxel_mask] 540 | tile["voxel_coords"] = tile["voxel_inds"] * ( 541 | res / 2 542 | ) + scene_minbound.astype(np.float32) 543 | 544 | tsdf_vol = utils.to_vol( 545 | voxel_logits.C[:, :3].cpu().numpy(), 546 | 1.05 * torch.tanh(voxel_logits.F).squeeze(-1).cpu().numpy(), 547 | ) 548 | mesh = utils.to_mesh( 549 | -tsdf_vol, 550 | voxel_size=0.04, 551 | origin=scene_minbound, 552 | level=0, 553 | mask=~np.isnan(tsdf_vol), 554 | ) 555 | return mesh 556 | 557 | 558 | if __name__ == "__main__": 559 | parser = argparse.ArgumentParser() 560 | parser.add_argument("--ckpt", required=True) 561 | parser.add_argument("--split", default='test', type=str) 562 | parser.add_argument("--outputdir", required=True) 563 | parser.add_argument("--config", 
required=True) 564 | parser.add_argument("--use-proj-occ", default=True, type=bool) 565 | parser.add_argument("--n-imgs", default=60, type=int) 566 | parser.add_argument("--cropsize", default=96, type=int) 567 | parser.add_argument('--vis-lowest', action='store_true') 568 | args = parser.parse_args() 569 | 570 | pl.seed_everything(0) 571 | 572 | with open(args.config, "r") as f: 573 | config = yaml.safe_load(f) 574 | 575 | cropsize = (args.cropsize, args.cropsize, 48) 576 | 577 | SRfeat = config["SRfeat"] 578 | useCV = config["cost_volume"] 579 | vis_lowest = True if args.vis_lowest else False 580 | if SRfeat: 581 | from collections import defaultdict 582 | SRlists = defaultdict(list) 583 | with open('/data_splits/ScanNetv2/standard_split/{}_eight_view_deepvmvs_dense_for_cvrecon.txt'.format(args.split), 'r') as f: 584 | lines = f.read().splitlines() 585 | for line in lines: 586 | scan_id, frame_id = line.split(" ")[:2] 587 | SRlists[scan_id].append(frame_id) 588 | 589 | if useCV: 590 | from collections import defaultdict 591 | CVDicts = defaultdict(dict) 592 | fname = '/data_splits/ScanNetv2/standard_split/{}_for_cvrecon.txt'.format(args.split) 593 | if args.split == 'test': 594 | fname = '/data_splits/ScanNetv2/standard_split/test_eight_view_deepvmvs_dense_for_cvrecon.txt' 595 | with open(fname, 'r') as f: 596 | lines = f.read().splitlines() 597 | for line in lines: 598 | scan_id, *frame_id = line.split(" ") 599 | CVDicts[scan_id][frame_id[0]] = frame_id[1:] 600 | 601 | with torch.cuda.amp.autocast(): 602 | 603 | info_files = utils.load_info_files(config["scannet_dir"], args.split) 604 | model = load_model(args.ckpt, args.use_proj_occ, config) 605 | for info_file in tqdm.tqdm(info_files): 606 | 607 | scene_name = os.path.basename(os.path.dirname(info_file)) 608 | outdir = os.path.join(args.outputdir, scene_name) 609 | os.makedirs(outdir, exist_ok=True) 610 | outfile = os.path.join(outdir, "prediction.ply") 611 | 612 | # if os.path.exists(outfile): 613 | # print(outfile, 'exists, skipping') 614 | # continue 615 | 616 | # try: 617 | if SRfeat: 618 | mesh = inference(model, info_file, outfile, args.n_imgs, cropsize, SRlists[scene_name], scene_name) 619 | elif useCV: 620 | mesh = inference(model, info_file, outfile, args.n_imgs, cropsize, CVDict=CVDicts[scene_name]) 621 | else: 622 | mesh = inference(model, info_file, outfile, args.n_imgs, cropsize) 623 | o3d.io.write_triangle_mesh(outfile, mesh) 624 | # except Exception as e: 625 | # print(e) 626 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import json 4 | import os 5 | import random 6 | import subprocess 7 | 8 | import numpy as np 9 | import pytorch_lightning as pl 10 | import torch 11 | import yaml 12 | from pytorch_lightning.plugins import DDPPlugin 13 | 14 | from cvrecon import collate, data, lightningmodel, utils 15 | 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--config", required=True) 20 | parser.add_argument("--gpus", default=1) 21 | args = parser.parse_args() 22 | 23 | with open(args.config, "r") as f: 24 | config = yaml.safe_load(f) 25 | 26 | pl.seed_everything(config["seed"]) 27 | 28 | if config['wandb_runid'] is not None: 29 | logger = pl.loggers.WandbLogger(project=config["wandb_project_name"], config=config, id=config['wandb_runid'], resume="must") 30 | else: 31 | logger = 
pl.loggers.WandbLogger(project=config["wandb_project_name"], config=config) 32 | subprocess.call( 33 | [ 34 | "zip", 35 | "-q", 36 | os.path.join(str(logger.experiment.dir), "code.zip"), 37 | "config.yml", 38 | *glob.glob("cvrecon/*.py"), 39 | *glob.glob("scripts/*.py"), 40 | ] 41 | ) 42 | 43 | ckpt_dir = os.path.join(str(logger.experiment.dir), "ckpts") 44 | checkpointer = pl.callbacks.ModelCheckpoint( 45 | save_last=True, 46 | dirpath=ckpt_dir, 47 | filename='{epoch}-{val/voxel_loss_medium:.4f}', 48 | verbose=True, 49 | save_top_k=20, 50 | monitor="val/voxel_loss_medium", 51 | ) 52 | callbacks = [checkpointer, lightningmodel.FineTuning(config["initial_epochs"], config["cost_volume"])] 53 | 54 | if config["use_amp"]: 55 | amp_kwargs = {"precision": 16} 56 | else: 57 | amp_kwargs = {} 58 | 59 | model = lightningmodel.LightningModel(config) 60 | 61 | 62 | trainer = pl.Trainer( 63 | gpus=args.gpus, 64 | logger=logger, 65 | benchmark=True, 66 | max_epochs=config["initial_epochs"] + config["finetune_epochs"] + 300, 67 | check_val_every_n_epoch=5, 68 | detect_anomaly=False, 69 | callbacks=callbacks, 70 | reload_dataloaders_every_n_epochs=1, # a hack so batch size can be adjusted for fine tuning 71 | strategy=DDPPlugin(find_unused_parameters=True), 72 | accumulate_grad_batches=1, 73 | num_sanity_val_steps=1, 74 | **amp_kwargs, 75 | ) 76 | trainer.fit(model, ckpt_path=config["ckpt"]) 77 | -------------------------------------------------------------------------------- /tools/generate_gt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | 4 | import time 5 | import pickle 6 | import argparse 7 | from tqdm import tqdm 8 | import ray 9 | import glob 10 | import torch.multiprocessing 11 | from tools.simple_loader import * 12 | 13 | from cvrecon.tsdf_fusion import * 14 | 15 | torch.multiprocessing.set_sharing_strategy('file_system') 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description='Fuse ground truth tsdf') 20 | parser.add_argument("--dataset", default='scannet') 21 | parser.add_argument("--data_path", metavar="DIR", 22 | help="path to raw dataset", default='/data/scannet/output/') 23 | parser.add_argument("--save_name", metavar="DIR", 24 | help="file name", default='all_tsdf') 25 | parser.add_argument('--test', action='store_true', 26 | help='prepare the test set') 27 | parser.add_argument('--max_depth', default=3., type=float, 28 | help='mask out large depth values since they are noisy') 29 | parser.add_argument('--num_layers', default=3, type=int) 30 | parser.add_argument('--margin', default=3, type=int) 31 | parser.add_argument('--voxel_size', default=0.04, type=float) 32 | 33 | parser.add_argument('--window_size', default=9, type=int) 34 | parser.add_argument('--min_angle', default=15, type=float) 35 | parser.add_argument('--min_distance', default=0.1, type=float) 36 | 37 | # ray multi processes 38 | parser.add_argument('--n_proc', type=int, default=4, help='#processes launched to process scenes.') 39 | parser.add_argument('--n_gpu', type=int, default=2, help='#number of gpus') 40 | parser.add_argument('--num_workers', type=int, default=4) 41 | parser.add_argument('--loader_num_workers', type=int, default=0) 42 | return parser.parse_args() 43 | 44 | 45 | args = parse_args() 46 | args.save_path = os.path.join(args.data_path, args.save_name) 47 | 48 | 49 | def save_tsdf_full(args, scene_path, cam_intr, depth_list, cam_pose_list, color_list, save_mesh=False): 50 | # 
======================================================================================================== # 51 | # (Optional) This is an example of how to compute the 3D bounds 52 | # in world coordinates of the convex hull of all camera view 53 | # frustums in the dataset 54 | # ======================================================================================================== # 55 | vol_bnds = np.zeros((3, 2)) 56 | 57 | n_imgs = len(depth_list.keys()) 58 | if n_imgs > 200: 59 | ind = np.linspace(0, n_imgs - 1, 200).astype(np.int32) 60 | image_id = np.array(list(depth_list.keys()))[ind] 61 | else: 62 | image_id = depth_list.keys() 63 | for id in image_id: 64 | depth_im = depth_list[id] 65 | cam_pose = cam_pose_list[id] 66 | 67 | # Compute camera view frustum and extend convex hull 68 | view_frust_pts = get_view_frustum(depth_im, cam_intr, cam_pose) 69 | vol_bnds[:, 0] = np.minimum(vol_bnds[:, 0], np.amin(view_frust_pts, axis=1)) 70 | vol_bnds[:, 1] = np.maximum(vol_bnds[:, 1], np.amax(view_frust_pts, axis=1)) 71 | # ======================================================================================================== # 72 | 73 | # ======================================================================================================== # 74 | # Integrate 75 | # ======================================================================================================== # 76 | # Initialize voxel volume 77 | print("Initializing voxel volume...") 78 | tsdf_vol_list = [] 79 | for l in range(args.num_layers): 80 | tsdf_vol_list.append(TSDFVolume(vol_bnds, voxel_size=args.voxel_size * 2 ** l, margin=args.margin)) 81 | 82 | # Loop through RGB-D images and fuse them together 83 | t0_elapse = time.time() 84 | for id in depth_list.keys(): 85 | if id % 100 == 0: 86 | print("{}: Fusing frame {}/{}".format(scene_path, str(id), str(n_imgs))) 87 | depth_im = depth_list[id] 88 | cam_pose = cam_pose_list[id] 89 | if len(color_list) == 0: 90 | color_image = None 91 | else: 92 | color_image = color_list[id] 93 | 94 | # Integrate observation into voxel volume (assume color aligned with depth) 95 | for l in range(args.num_layers): 96 | tsdf_vol_list[l].integrate(color_image, depth_im, cam_intr, cam_pose, obs_weight=1.) 
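    # Note: layer l above is fused at voxel_size * 2**l (0.04 m, 0.08 m, 0.16 m with
    # the defaults), producing the multi-resolution ground-truth TSDFs. color_all is
    # left empty by process_with_single_worker below, so color_image is None here.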
97 | 98 | fps = n_imgs / (time.time() - t0_elapse) 99 | print("Average FPS: {:.2f}".format(fps)) 100 | 101 | tsdf_info = { 102 | 'vol_origin': tsdf_vol_list[0]._vol_origin, 103 | 'voxel_size': tsdf_vol_list[0]._voxel_size, 104 | } 105 | tsdf_path = os.path.join(args.save_path, scene_path) 106 | if not os.path.exists(tsdf_path): 107 | os.makedirs(tsdf_path) 108 | 109 | with open(os.path.join(args.save_path, scene_path, 'tsdf_info.pkl'), 'wb') as f: 110 | pickle.dump(tsdf_info, f) 111 | 112 | for l in range(args.num_layers): 113 | tsdf_vol, color_vol, weight_vol = tsdf_vol_list[l].get_volume() 114 | ################################################################################################################### 115 | tsdf_vol[tsdf_vol < 1] *= -1 116 | tsdf_vol[(tsdf_vol == 1) & (weight_vol > 0)] = -1 117 | ################################################################################################################### 118 | np.savez_compressed(os.path.join(args.save_path, scene_path, 'full_tsdf_layer{}'.format(str(l))), tsdf_vol) 119 | 120 | if save_mesh: 121 | for l in range(args.num_layers): 122 | print("Saving mesh to mesh{}.ply...".format(str(l))) 123 | verts, faces, norms, colors = tsdf_vol_list[l].get_mesh() 124 | 125 | meshwrite(os.path.join(args.save_path, scene_path, 'mesh_layer{}.ply'.format(str(l))), verts, faces, norms, 126 | colors) 127 | 128 | 129 | 130 | def save_fragment_pkl(args, scene, cam_intr, depth_list, cam_pose_list): 131 | fragments = [] 132 | print('segment: process scene {}'.format(scene)) 133 | 134 | # gather pose 135 | vol_bnds = np.zeros((3, 2)) 136 | vol_bnds[:, 0] = np.inf 137 | vol_bnds[:, 1] = -np.inf 138 | 139 | all_ids = [] 140 | ids = [] 141 | all_bnds = [] 142 | count = 0 143 | last_pose = None 144 | for id in depth_list.keys(): 145 | depth_im = depth_list[id] 146 | cam_pose = cam_pose_list[id] 147 | 148 | if count == 0: 149 | ids.append(id) 150 | vol_bnds = np.zeros((3, 2)) 151 | vol_bnds[:, 0] = np.inf 152 | vol_bnds[:, 1] = -np.inf 153 | last_pose = cam_pose 154 | # Compute camera view frustum and extend convex hull 155 | view_frust_pts = get_view_frustum(depth_im, cam_intr, cam_pose) 156 | vol_bnds[:, 0] = np.minimum(vol_bnds[:, 0], np.amin(view_frust_pts, axis=1)) 157 | vol_bnds[:, 1] = np.maximum(vol_bnds[:, 1], np.amax(view_frust_pts, axis=1)) 158 | count += 1 159 | else: 160 | angle = np.arccos( 161 | ((np.linalg.inv(cam_pose[:3, :3]) @ last_pose[:3, :3] @ np.array([0, 0, 1]).T) * np.array( 162 | [0, 0, 1])).sum()) 163 | dis = np.linalg.norm(cam_pose[:3, 3] - last_pose[:3, 3]) 164 | if angle > (args.min_angle / 180) * np.pi or dis > args.min_distance: 165 | ids.append(id) 166 | last_pose = cam_pose 167 | # Compute camera view frustum and extend convex hull 168 | view_frust_pts = get_view_frustum(depth_im, cam_intr, cam_pose) 169 | vol_bnds[:, 0] = np.minimum(vol_bnds[:, 0], np.amin(view_frust_pts, axis=1)) 170 | vol_bnds[:, 1] = np.maximum(vol_bnds[:, 1], np.amax(view_frust_pts, axis=1)) 171 | count += 1 172 | if count == args.window_size: 173 | all_ids.append(ids) 174 | all_bnds.append(vol_bnds) 175 | ids = [] 176 | count = 0 177 | 178 | with open(os.path.join(args.save_path, scene, 'tsdf_info.pkl'), 'rb') as f: 179 | tsdf_info = pickle.load(f) 180 | 181 | # save fragments 182 | for i, bnds in enumerate(all_bnds): 183 | if not os.path.exists(os.path.join(args.save_path, scene, 'fragments', str(i))): 184 | os.makedirs(os.path.join(args.save_path, scene, 'fragments', str(i))) 185 | fragments.append({ 186 | 'scene': scene, 187 | 'fragment_id': 
i, 188 | 'image_ids': all_ids[i], 189 | 'vol_origin': tsdf_info['vol_origin'], 190 | 'voxel_size': tsdf_info['voxel_size'], 191 | }) 192 | 193 | with open(os.path.join(args.save_path, scene, 'fragments.pkl'), 'wb') as f: 194 | pickle.dump(fragments, f) 195 | 196 | 197 | @ray.remote(num_cpus=args.num_workers + 1, num_gpus=(1 / args.n_proc)) 198 | def process_with_single_worker(args, scannet_files): 199 | for scene in tqdm(scannet_files): 200 | if os.path.exists(os.path.join(args.save_path, scene, 'fragments.pkl')): 201 | continue 202 | print('read from disk') 203 | 204 | depth_all = {} 205 | cam_pose_all = {} 206 | color_all = {} 207 | 208 | if args.dataset == 'scannet': 209 | n_imgs = len(glob.glob(os.path.join(args.data_path, scene, 'sensor_data', '*.color.jpg'))) 210 | intrinsic_dir = os.path.join(args.data_path, scene, 'intrinsic', 'intrinsic_depth.txt') 211 | cam_intr = np.loadtxt(intrinsic_dir, delimiter=' ')[:3, :3] 212 | dataset = ScanNetDataset(n_imgs, scene, args.data_path, args.max_depth) 213 | 214 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, collate_fn=collate_fn, 215 | batch_sampler=None, num_workers=args.loader_num_workers) 216 | 217 | for id, (cam_pose, depth_im, _) in enumerate(dataloader): 218 | if id % 100 == 0: 219 | print("{}: read frame {}/{}".format(scene, str(id), str(n_imgs))) 220 | 221 | if not np.all(np.isfinite(cam_pose)):  # skip frames with invalid (inf/nan) poses 222 | continue 223 | depth_all.update({id: depth_im}) 224 | cam_pose_all.update({id: cam_pose}) 225 | # color_all.update({id: color_image}) 226 | 227 | save_tsdf_full(args, scene, cam_intr, depth_all, cam_pose_all, color_all, save_mesh=False) 228 | save_fragment_pkl(args, scene, cam_intr, depth_all, cam_pose_all) 229 | 230 | 231 | def split_list(_list, n): 232 | assert len(_list) >= n 233 | ret = [[] for _ in range(n)] 234 | for idx, item in enumerate(_list): 235 | ret[idx % n].append(item) 236 | return ret 237 | 238 | 239 | def generate_pkl(args): 240 | all_scenes = sorted(os.listdir(args.save_path)) 241 | # todo: fix for both train/val/test 242 | if not args.test: 243 | splits = ['train', 'val'] 244 | else: 245 | splits = ['test'] 246 | for split in splits: 247 | fragments = [] 248 | with open(os.path.join(args.save_path, 'splits', 'scannetv2_{}.txt'.format(split))) as f: 249 | split_files = f.readlines() 250 | for scene in all_scenes: 251 | if 'scene' not in scene: 252 | continue 253 | if scene + '\n' in split_files: 254 | with open(os.path.join(args.save_path, scene, 'fragments.pkl'), 'rb') as f: 255 | frag_scene = pickle.load(f) 256 | fragments.extend(frag_scene) 257 | 258 | with open(os.path.join(args.save_path, 'fragments_{}.pkl'.format(split)), 'wb') as f: 259 | pickle.dump(fragments, f) 260 | 261 | 262 | if __name__ == "__main__": 263 | all_proc = args.n_proc * args.n_gpu 264 | 265 | ray.init(num_cpus=all_proc * (args.num_workers + 1), num_gpus=args.n_gpu) 266 | 267 | if args.dataset == 'scannet': 268 | if not args.test: 269 | args.data_path = os.path.join(args.data_path, 'scans') 270 | else: 271 | args.data_path = os.path.join(args.data_path, 'scans_test') 272 | files = sorted(os.listdir(args.data_path)) 273 | else: 274 | raise NameError('unknown dataset: {}'.format(args.dataset)) 275 | 276 | files = split_list(files, all_proc) 277 | 278 | ray_worker_ids = [] 279 | for w_idx in range(all_proc): 280 | ray_worker_ids.append(process_with_single_worker.remote(args, files[w_idx])) 281 | 282 | results = ray.get(ray_worker_ids) 283 | 
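A note on the fragment splitting in `save_fragment_pkl` above: a frame is kept as a new keyframe only when its viewing direction has rotated by more than `min_angle` degrees or its camera centre has moved more than `min_distance` metres since the last kept frame, and every `window_size` keyframes close one fragment. The snippet below is a standalone restatement of that test; the function name is illustrative and not part of the repository.

```
import numpy as np

def is_new_keyframe(cam_pose, last_pose, min_angle_deg=15.0, min_distance=0.1):
    # cam_pose / last_pose are 4x4 camera-to-world matrices.
    # Rotate the last camera's optical axis into the current camera frame; its z
    # component is the cosine of the angle between the two viewing directions.
    z_world = last_pose[:3, :3] @ np.array([0.0, 0.0, 1.0])
    cos_angle = (np.linalg.inv(cam_pose[:3, :3]) @ z_world)[2]
    angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))
    # Translation between the two camera centres.
    dist = np.linalg.norm(cam_pose[:3, 3] - last_pose[:3, 3])
    return angle > np.deg2rad(min_angle_deg) or dist > min_distance
```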
-------------------------------------------------------------------------------- /tools/preprocess_scannet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import cv2 5 | import json 6 | import tqdm 7 | import argparse 8 | import shutil 9 | 10 | 11 | def process_color_image(color, depth, K_color, K_depth): 12 | old_height, old_width = np.shape(color)[0:2] 13 | new_height, new_width = np.shape(depth) 14 | 15 | x = np.linspace(0, new_width - 1, num=new_width) 16 | y = np.linspace(0, new_height - 1, num=new_height) 17 | ones = np.ones(shape=(new_height, new_width)) 18 | x_grid, y_grid = np.meshgrid(x, y) 19 | warp_grid = np.stack((x_grid, y_grid, ones), axis=-1) 20 | warp_grid = torch.from_numpy(warp_grid).float() 21 | warp_grid = warp_grid.view(-1, 3).t().unsqueeze(0) 22 | 23 | H = K_color.dot(np.linalg.inv(K_depth)) 24 | H = torch.from_numpy(H).float().unsqueeze(0) 25 | 26 | width_normalizer = old_width / 2.0 27 | height_normalizer = old_height / 2.0 28 | 29 | warping = H.bmm(warp_grid).transpose(dim0=1, dim1=2) 30 | warping = warping[:, :, 0:2] / (warping[:, :, 2].unsqueeze(-1) + 1e-8) 31 | warping = warping.view(1, new_height, new_width, 2) 32 | warping[:, :, :, 0] = (warping[:, :, :, 0] - width_normalizer) / width_normalizer 33 | warping[:, :, :, 1] = (warping[:, :, :, 1] - height_normalizer) / height_normalizer 34 | 35 | image = torch.from_numpy(np.transpose(color, axes=(2, 0, 1))).float().unsqueeze(0) 36 | 37 | warped_image = torch.nn.functional.grid_sample(input=image, 38 | grid=warping, 39 | mode='nearest', 40 | padding_mode='zeros', 41 | align_corners=True) 42 | 43 | warped_image = warped_image.squeeze(0).numpy().astype(np.uint8) 44 | warped_image = np.transpose(warped_image, axes=(1, 2, 0)) 45 | return warped_image 46 | 47 | 48 | def process_scene(scene_dir_src, scene_dir_dst): 49 | scene_name = os.path.basename(scene_dir_src) 50 | data = { 51 | 'scene': scene_name, 52 | 'path': scene_dir_dst, 53 | 'frames': [] 54 | } 55 | 56 | if not os.path.exists(scene_dir_dst): 57 | os.makedirs(scene_dir_dst) 58 | color_dir_dst = os.path.join(scene_dir_dst, 'color') 59 | if not os.path.exists(color_dir_dst): 60 | os.makedirs(color_dir_dst) 61 | depth_dir_dst = os.path.join(scene_dir_dst, 'depth') 62 | if not os.path.exists(depth_dir_dst): 63 | os.makedirs(depth_dir_dst) 64 | 65 | # copy ground truth mesh to new folder 66 | gt_mesh_src = os.path.join(scene_dir_src, '{}_vh_clean_2.ply'.format(scene_name)) 67 | gt_mesh_dst = os.path.join(scene_dir_dst, '{}_vh_clean_2.ply'.format(scene_name)) 68 | shutil.copy(gt_mesh_src, gt_mesh_dst) 69 | data['gt_mesh'] = gt_mesh_dst 70 | 71 | K_color = np.loadtxt(os.path.join(scene_dir_src, 'intrinsic', 'intrinsic_color.txt'))[:3, :3] 72 | K_depth = np.loadtxt(os.path.join(scene_dir_src, 'intrinsic', 'intrinsic_depth.txt'))[:3, :3] 73 | data['intrinsics'] = K_depth.tolist() 74 | 75 | frames = sorted([f for f in os.listdir(os.path.join(scene_dir_src, 'sensor_data')) 76 | if f.endswith('.color.jpg')], key=lambda x: int(x.split('.')[0][-6:])) 77 | 78 | for frame in tqdm.tqdm(frames): 79 | frame_id = int(frame.split('.')[0][-6:]) 80 | fname_color_src = os.path.join(scene_dir_src, 'sensor_data', frame) 81 | fname_depth_src = os.path.join(scene_dir_src, 'sensor_data', 'frame-{0:06d}.depth.png'.format(frame_id)) 82 | fname_color_dst = os.path.join(scene_dir_dst, 'color', '{}.jpg'.format(frame_id).zfill(9)) 83 | fname_depth_dst = os.path.join(scene_dir_dst, 'depth', 
'{}.png'.format(frame_id).zfill(9)) 84 | 85 | color = cv2.imread(fname_color_src) 86 | depth = cv2.imread(fname_depth_src, cv2.IMREAD_ANYDEPTH) 87 | P = np.loadtxt(os.path.join(scene_dir_src, 'sensor_data', 'frame-{0:06d}.pose.txt'.format(frame_id))) 88 | 89 | if not np.all(np.isfinite(P)): # skip invalid poses 90 | continue 91 | 92 | if color.shape[:2] != depth.shape[:2]: # avoid resizing twice 93 | color = process_color_image(color, depth, K_color, K_depth) 94 | cv2.imwrite(fname_color_dst, color) 95 | 96 | elif not os.path.exists(fname_color_dst): 97 | cv2.imwrite(fname_color_dst, color) 98 | 99 | if not os.path.exists(fname_depth_dst): 100 | cv2.imwrite(fname_depth_dst, depth) 101 | 102 | frame = { 103 | 'filename_color': fname_color_dst, 104 | 'filename_depth': fname_depth_dst, 105 | 'pose': P.tolist() 106 | } 107 | data['frames'].append(frame) 108 | json.dump(data, open(os.path.join(scene_dir_dst, 'info.json'), 'w')) 109 | return 110 | 111 | 112 | if __name__ == '__main__': 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument('--src', type=str) 115 | parser.add_argument('--dst', type=str) 116 | args = parser.parse_args() 117 | 118 | if not os.path.exists(args.dst): 119 | os.makedirs(args.dst) 120 | 121 | train_txt = os.path.join(args.dst, 'scannetv2_train.txt') 122 | val_txt = os.path.join(args.dst, 'scannetv2_val.txt') 123 | test_txt = os.path.join(args.dst, 'scannetv2_test.txt') 124 | 125 | if not os.path.exists(train_txt): 126 | shutil.copy(os.path.join(args.src, 'scannetv2_train.txt'), train_txt) 127 | if not os.path.exists(val_txt): 128 | shutil.copy(os.path.join(args.src, 'scannetv2_val.txt'), val_txt) 129 | if not os.path.exists(test_txt): 130 | shutil.copy(os.path.join(args.src, 'scannetv2_test.txt'), test_txt) 131 | 132 | with open(train_txt, 'r') as fp: 133 | train_scenes = [f.strip() for f in fp.readlines()] 134 | with open(val_txt, 'r') as fp: 135 | val_scenes = [f.strip() for f in fp.readlines()] 136 | with open(test_txt, 'r') as fp: 137 | test_scenes = ([f.strip() for f in fp.readlines()]) 138 | 139 | test_dir_src = os.path.join(args.src, 'scans_test') 140 | test_dir_dst = os.path.join(args.dst, 'scans_test') 141 | trainval_dir_src = os.path.join(args.src, 'scans') 142 | trainval_dir_dst = os.path.join(args.dst, 'scans') 143 | 144 | for i, scene_name in enumerate(train_scenes): 145 | print('{} / {}: {}'.format(i+1, 146 | len(train_scenes)+len(val_scenes)+len(test_scenes), scene_name)) 147 | scene_dir_src = os.path.join(trainval_dir_src, scene_name) 148 | scene_dir_dst = os.path.join(trainval_dir_dst, scene_name) 149 | process_scene(scene_dir_src, scene_dir_dst) 150 | 151 | for i, scene_name in enumerate(val_scenes): 152 | print('{} / {}: {}'.format(i+1+len(train_scenes), 153 | len(train_scenes)+len(val_scenes)+len(test_scenes), scene_name)) 154 | scene_dir_src = os.path.join(trainval_dir_src, scene_name) 155 | scene_dir_dst = os.path.join(trainval_dir_dst, scene_name) 156 | process_scene(scene_dir_src, scene_dir_dst) 157 | 158 | for i, scene_name in enumerate(test_scenes): 159 | print('{} / {}: {}'.format(i+1+len(train_scenes)+len(val_scenes), 160 | len(train_scenes)+len(val_scenes)+len(test_scenes), scene_name)) 161 | scene_dir_src = os.path.join(test_dir_src, scene_name) 162 | scene_dir_dst = os.path.join(test_dir_dst, scene_name) 163 | process_scene(scene_dir_src, scene_dir_dst) 164 | -------------------------------------------------------------------------------- /tools/simple_loader.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import cv2 5 | 6 | 7 | def collate_fn(list_data): 8 | cam_pose, depth_im, _ = list_data 9 | # Concatenate all lists 10 | return cam_pose, depth_im, _ 11 | 12 | 13 | class ScanNetDataset(torch.utils.data.Dataset): 14 | """Pytorch Dataset for a single scene. getitem loads individual frames""" 15 | 16 | def __init__(self, n_imgs, scene, data_path, max_depth, id_list=None): 17 | """ 18 | Args: 19 | """ 20 | self.n_imgs = n_imgs 21 | self.scene = scene 22 | self.data_path = data_path 23 | self.max_depth = max_depth 24 | if id_list is None: 25 | self.id_list = [i for i in range(n_imgs)] 26 | else: 27 | self.id_list = id_list 28 | 29 | def __len__(self): 30 | return self.n_imgs 31 | 32 | def __getitem__(self, id): 33 | """ 34 | Returns: 35 | dict of meta data and images for a single frame 36 | """ 37 | id = self.id_list[id] 38 | cam_pose = np.loadtxt(os.path.join(self.data_path, self.scene, "sensor_data", f"frame-{id:06d}.pose.txt"), delimiter=' ') 39 | 40 | # Read depth image and camera pose 41 | depth_im = cv2.imread(os.path.join(self.data_path, self.scene, "sensor_data", f"frame-{id:06d}.depth.png"), -1).astype( 42 | np.float32) 43 | depth_im /= 1000. # depth is saved in 16-bit PNG in millimeters 44 | depth_im[depth_im > self.max_depth] = 0 45 | 46 | # Read RGB image 47 | # print(os.path.join(self.data_path, self.scene, "sensor_data", f"frame-{id:06d}.color.jpg")) 48 | color_image = cv2.cvtColor(cv2.imread(os.path.join(self.data_path, self.scene, "sensor_data", f"frame-{id:06d}.color.jpg")), 49 | cv2.COLOR_BGR2RGB) 50 | color_image = cv2.resize(color_image, (depth_im.shape[1], depth_im.shape[0]), interpolation=cv2.INTER_AREA) 51 | 52 | return cam_pose, depth_im, color_image 53 | --------------------------------------------------------------------------------
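For reference, a minimal sketch of how `ScanNetDataset` and `collate_fn` are consumed, mirroring `process_with_single_worker` in `tools/generate_gt.py`; the scene name and data path are placeholders.

```
import torch
from tools.simple_loader import ScanNetDataset, collate_fn

# Placeholder scene/path; batch_size=None disables batching so each item is one frame.
dataset = ScanNetDataset(n_imgs=100, scene="scene0000_00",
                         data_path="/path/to/scannet/scans", max_depth=3.0)
loader = torch.utils.data.DataLoader(dataset, batch_size=None,
                                     collate_fn=collate_fn, num_workers=0)
for cam_pose, depth_im, color_image in loader:
    # cam_pose: 4x4 pose matrix; depth_im: float32 depth in metres, zeroed beyond
    # max_depth; color_image: RGB frame resized to the depth resolution.
    pass
```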