├── src ├── __init__.py ├── models │ ├── fusion │ │ ├── __init__.py │ │ ├── embedder.py │ │ └── utils.py │ ├── __init__.py │ ├── models.py │ ├── tcnn_config.json │ ├── model_utils.py │ ├── modules.py │ └── common.py ├── utils │ ├── import_utils.py │ ├── shapenet_helper.py │ ├── torchvision_utils.py │ ├── mesh_helper.py │ ├── fusion_utils.py │ ├── hydra_utils.py │ ├── scannet_helper.py │ ├── common.py │ ├── sample_utils.py │ ├── pangolin_helper.py │ ├── vis_utils.py │ ├── o3d_helper.py │ ├── motion_utils.py │ └── pointnet_utils.py ├── datasets │ ├── __init__.py │ ├── datasets.py │ ├── base_dataset.py │ ├── sampler.py │ ├── fusion_pointnet_dataset.py │ └── arkitscene_dataset.py ├── scripts │ ├── postprocess_meshes.py │ ├── run_inference_on_scene3d.py │ ├── run_inference_on_arkit.py │ ├── run_inference_on_icl_nuim.py │ ├── compute_chamfer.py │ ├── run_inference_on_scannet.py │ ├── evaluate_bnvf.py │ ├── generate_fusion_data_arkit.py │ ├── generate_fusion_data_scene3d.py │ ├── run_rgbd_intergration.py │ ├── generate_fusion_data_scannet.py │ └── generate_fusion_data_icl_nuim.py ├── test.py ├── train.py └── run_e2e.py ├── pretrained ├── pointnet.ckpt └── pointnet_tcnn.ckpt ├── .gitignore ├── configs ├── dataset │ ├── fusion_pointnet_dataset.yaml │ ├── default_dataset.yaml │ ├── fusion_inference_dataset_arkit.yaml │ ├── fusion_refiner_dataset.yaml │ ├── fusion_refiner_scannet_dataset.yaml │ ├── fusion_dataset.yaml │ └── fusion_inference_dataset.yaml ├── trainer │ └── default_trainer.yaml ├── optimizer │ └── adam.yaml ├── callbacks │ └── default_callbacks.yaml ├── model │ ├── fusion_model.yaml │ ├── fusion_refiner_model.yaml │ ├── fusion_pointnet_refiner.yaml │ └── fusion_pointnet_model.yaml └── config.yaml ├── LICENSE ├── Dockerfile └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/fusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pretrained/pointnet.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/likojack/bnv_fusion/HEAD/pretrained/pointnet.ckpt -------------------------------------------------------------------------------- /pretrained/pointnet_tcnn.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/likojack/bnv_fusion/HEAD/pretrained/pointnet_tcnn.ckpt -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import register, get_model 2 | from .fusion.fusion_refiner import LitFusionRefiner 3 | from .fusion.local_point_fusion import LitFusionPointNet -------------------------------------------------------------------------------- /src/utils/import_utils.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | 3 | 4 | def import_from(module, name): 5 | """ import "name" from "module" 6 | """ 7 | 8 | return getattr(import_module(module), name) 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /DeepSDF/ 2 | /exps/ 3 | 
/data/ 4 | /src/investigate 5 | **/__pycache__ 6 | /test_sdf/ 7 | /logs/ 8 | /.vscode/ 9 | /test_*.py 10 | /render_visualization 11 | /render_out 12 | /animations 13 | /final_animations -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import register, get_dataset 2 | from .fusion_dataset import FusionRefinerDataset 3 | from .fusion_dataset import FusionRefinerScanNetDataset 4 | from .fusion_inference_dataset import FusionInferenceDataset -------------------------------------------------------------------------------- /configs/dataset/fusion_pointnet_dataset.yaml: -------------------------------------------------------------------------------- 1 | data_dir: ${data_dir} # data_dir is specified in config.yaml 2 | 3 | subdomain: "local_shapes" 4 | 5 | name: "fusion_pointnet_dataset" 6 | 7 | n_local_samples: 64 8 | train_batch_size: 100 9 | eval_batch_size: 1 10 | shuffle: True 11 | num_workers: 4 12 | -------------------------------------------------------------------------------- /src/models/models.py: -------------------------------------------------------------------------------- 1 | MODELS = {} 2 | 3 | 4 | def register(name): 5 | def decorator(cls): 6 | MODELS[name] = cls 7 | return cls 8 | return decorator 9 | 10 | 11 | def get_model(cfg, **kwargs): 12 | model = MODELS[cfg.model.name](cfg, **kwargs) 13 | return model 14 | -------------------------------------------------------------------------------- /src/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | datasets = {} 2 | 3 | 4 | def register(name): 5 | def decorator(cls): 6 | datasets[name] = cls 7 | return cls 8 | return decorator 9 | 10 | 11 | def get_dataset(cfg, subset): 12 | dataset = datasets[cfg.dataset.name](cfg, subset) 13 | return dataset 14 | -------------------------------------------------------------------------------- /configs/dataset/default_dataset.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | _target_: src.datasets.shapenet.ShapeNetDataset 4 | 5 | data_dir: ${data_dir} # data_dir is specified in config.yaml 6 | 7 | subdomain: "ShapeNet" 8 | 9 | name: "shapenet_hierarchical" 10 | num_sdf_samples: 20000 11 | 12 | batch_size: 4 13 | shuffle: True 14 | num_workers: 4 -------------------------------------------------------------------------------- /configs/trainer/default_trainer.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # set `1` to train on GPU, `0` to train on CPU only 4 | gpus: 1 5 | 6 | seed: 12345 7 | 8 | min_epochs: 1 9 | max_epochs: 10 10 | 11 | weights_summary: null 12 | progress_bar_refresh_rate: 10 13 | check_val_every_n_epoch: 10 14 | 15 | terminate_on_nan: True 16 | fast_dev_run: False 17 | 18 | checkpoint: null 19 | weight_only: False 20 | 21 | dense_volume: False 22 | post_process: True -------------------------------------------------------------------------------- /configs/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_package_: torch.optim 2 | _class_: Adam 3 | eps: 1e-08 4 | weight_decay: 0 5 | betas: [0.9, 0.999] 6 | 7 | lr: 8 | name: base_lr 9 | initial: 0.001 10 | interval: epoch 11 | scheduler: StepLR 12 | lr_scheduler: 13 | step_size: 20000 14 | gamma: 0.5 15 | 
last_epoch: -1 16 | 17 | # _convert_ is hydra syntax needed to make betas parameter work, learn more here: 18 | # https://hydra.cc/docs/next/advanced/instantiate_objects/overview/#parameter-conversion-strategies 19 | _convert_: "partial" -------------------------------------------------------------------------------- /configs/dataset/fusion_inference_dataset_arkit.yaml: -------------------------------------------------------------------------------- 1 | data_dir: ${data_dir} # data_dir is specified in config.yaml 2 | 3 | subdomain: "fusion" 4 | 5 | name: "fusion_inference_dataset_arkit" 6 | img_res: [192, 256] 7 | num_pixels: 5000 8 | scan_id: "OVERWRITE_THIS" 9 | skip_images: 1 10 | sample_shift: 0 11 | confidence_level: 2 12 | 13 | downsample_scale: 1. 14 | downsample_mode: null # or sparse 15 | train_batch_size: 1 16 | eval_batch_size: 1 17 | max_eval_imgs: 1 18 | shuffle: True 19 | num_workers: 8 20 | 21 | # for end-to-end 22 | depth_scale: 1000. -------------------------------------------------------------------------------- /configs/dataset/fusion_refiner_dataset.yaml: -------------------------------------------------------------------------------- 1 | data_dir: ${data_dir} # data_dir is specified in config.yaml 2 | 3 | subdomain: "fusion" 4 | 5 | name: "fusion_refiner_dataset" 6 | img_res: [480, 640] 7 | num_pixels: 5000 8 | scan_id: 0/475 9 | num_images: 300 10 | skip_images: 10 11 | sample_shift: 0 12 | depth_scale: 1000. 13 | 14 | downsample_scale: 0. 15 | downsample_mode: null # or sparse 16 | train_batch_size: 1 17 | eval_batch_size: 1 18 | max_eval_imgs: 1 19 | shuffle: True 20 | num_workers: 4 21 | first_k: 2 22 | max_neighbor_images: 0 23 | 24 | out_root: null -------------------------------------------------------------------------------- /configs/dataset/fusion_refiner_scannet_dataset.yaml: -------------------------------------------------------------------------------- 1 | data_dir: ${data_dir} # data_dir is specified in config.yaml 2 | 3 | subdomain: "ScanNet" 4 | 5 | name: "fusion_refiner_scannet_dataset" 6 | img_res: [480, 640] 7 | num_pixels: 5000 8 | scan_id: scene0575_00 9 | num_images: 300 10 | skip_images: 10 11 | depth_scale: 1000. 12 | 13 | downsample_scale: 0.5 14 | downsample_mode: null # or sparse 15 | train_batch_size: 1 16 | eval_batch_size: 1 17 | max_eval_imgs: 1 18 | shuffle: True 19 | num_workers: 4 20 | first_k: 2 21 | sample_shift: 0 22 | 23 | out_root: null -------------------------------------------------------------------------------- /configs/dataset/fusion_dataset.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | _target_: src.datasets.idr_dataset.IDRDataset 4 | 5 | data_dir: ${data_dir} # data_dir is specified in config.yaml 6 | 7 | subdomain: "fusion" 8 | 9 | name: "fusion_dataset" 10 | img_res: [240, 320] 11 | train_cameras: False 12 | num_pixels: 1024 13 | scan_id: 0/2 14 | num_images: 300 15 | skip_images: 10 16 | sample_shift: 0 17 | 18 | downsample_scale: 0. 
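# The fields in this dataset config group are normally overridden on the command line via
# Hydra rather than edited in place; the batch scripts under src/scripts/ follow that
# pattern. A minimal sketch of such an invocation (values illustrative, taken from
# src/scripts/run_inference_on_scene3d.py):
#   python src/test.py model=fusion_pointnet_model dataset=fusion_inference_dataset \
#     dataset.scan_id=scene3d/lounge dataset.skip_images=10 dataset.downsample_scale=1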
19 | downsample_mode: null # or sparse 20 | train_batch_size: 1 21 | eval_batch_size: 1 22 | max_eval_imgs: 1 23 | shuffle: True 24 | num_workers: 4 25 | first_k: 2 26 | max_neighbor_images: 5 27 | -------------------------------------------------------------------------------- /configs/dataset/fusion_inference_dataset.yaml: -------------------------------------------------------------------------------- 1 | data_dir: ${data_dir} # data_dir is specified in config.yaml 2 | 3 | subdomain: "fusion" 4 | 5 | name: "fusion_inference_dataset" 6 | img_res: [480, 640] 7 | train_cameras: False 8 | num_pixels: 5000 9 | scan_id: "scene3d/lounge" 10 | num_images: 300 11 | skip_images: 10 12 | sample_shift: 0 13 | 14 | downsample_scale: 1. 15 | downsample_mode: null # or sparse 16 | train_batch_size: 1 17 | eval_batch_size: 1 18 | max_eval_imgs: 1 19 | shuffle: True 20 | num_workers: 8 21 | first_k: 2 22 | max_neighbor_images: 5 23 | 24 | out_root: null 25 | 26 | # for end-to-end 27 | depth_scale: 1000. -------------------------------------------------------------------------------- /src/models/tcnn_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "loss": { 3 | "otype": "RelativeL2" 4 | }, 5 | "optimizer": { 6 | "otype": "Adam", 7 | "learning_rate": 1e-2, 8 | "beta1": 0.9, 9 | "beta2": 0.99, 10 | "epsilon": 1e-15, 11 | "l2_reg": 1e-6 12 | }, 13 | "encoding": { 14 | // "otype": "HashGrid", 15 | // "n_levels": 16, 16 | // "n_features_per_level": 2, 17 | // "log2_hashmap_size": 15, 18 | // "base_resolution": 16, 19 | // "per_level_scale": 1.5 20 | "otype": "Identity", 21 | "scale": 1.0, 22 | "offset": 0.0 23 | }, 24 | "network": { 25 | "otype": "FullyFusedMLP", 26 | "activation": "ReLU", 27 | "output_activation": "None", 28 | "n_neurons": 64, 29 | "n_hidden_layers": 3, 30 | } 31 | } -------------------------------------------------------------------------------- /src/utils/shapenet_helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.transform import Rotation 3 | 4 | 5 | def read_pose(img): 6 | """ see render_depths.py for translation and intrinsic 7 | """ 8 | img = img[:-1] 9 | x_rot, y_rot = [float(f) for f in img.split("_")] 10 | T_wo = np.eye(4) 11 | T_wo[2, 3] = -1 12 | rot_mat_0 = Rotation.from_euler( 13 | "y", y_rot, degrees=True).as_matrix() 14 | rot_mat_1 = Rotation.from_euler( 15 | "x", x_rot, degrees=True).as_matrix() 16 | T_wo[:3, :3] = rot_mat_1 @ rot_mat_0 17 | intr_mat = np.eye(3) 18 | intr_mat[0, 0] = 128 19 | intr_mat[1, 1] = 128 20 | intr_mat[0, 2] = 128 21 | intr_mat[1, 2] = 128 22 | T_ow = np.linalg.inv(T_wo) 23 | return T_ow.astype(np.float32), intr_mat.astype(np.float32) 24 | -------------------------------------------------------------------------------- /configs/callbacks/default_callbacks.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: "val/acc" # name of the logged metric which determines when model is improving 4 | save_top_k: 1 # save k best models (determined by above metric) 5 | save_last: True # additionaly always save model from last epoch 6 | mode: "max" # can be "max" or "min" 7 | verbose: False 8 | dirpath: 'checkpoints/' 9 | filename: '{epoch:02d}' 10 | 11 | 12 | early_stopping: 13 | _target_: pytorch_lightning.callbacks.EarlyStopping 14 | monitor: "val/acc" # name of the logged metric which determines 
when model is improving 15 | patience: 100 # how many epochs of not improving until training stops 16 | mode: "max" # can be "max" or "min" 17 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement -------------------------------------------------------------------------------- /configs/model/fusion_model.yaml: -------------------------------------------------------------------------------- 1 | _class_: LitFusionNet 2 | name: "lit_fusion_net" 3 | 4 | feature_vector_size: 16 5 | train_ray_splits: 200 6 | voxel_size: 0.04 7 | 8 | update_network: 9 | depth: 2 10 | contraction: 32 11 | n_input_channels: 4 12 | double_factor: 1.5 13 | 14 | ray_tracer: 15 | ray_max_dist: 5 16 | 17 | nerf: 18 | hidden_size: 256 19 | num_layers: 4 20 | num_encoding_fn_xyz: 1 21 | num_encoding_fn_dir: 6 22 | include_input_xyz: True 23 | include_input_dir: True 24 | 25 | refine_net: 26 | n_input_channels: 1 27 | n_output_channels: 1 28 | contraction: 64 29 | depth: 1 30 | loss: 31 | name: combined 32 | lambda_unc: 0.03 33 | crop_fraction: 0. 34 | vmin: 0. 35 | vmax: 10. 36 | weight_scale: 1. 37 | limit: 10. 38 | 39 | loss: 40 | rgb_loss: 0. 41 | l1_loss: 0. 42 | depth_bce_loss: 1. 43 | zero_level_loss: 0.1 44 | mask_loss: 1. 45 | reg_loss: 0.01 46 | refine_depth_loss: 1. -------------------------------------------------------------------------------- /configs/model/fusion_refiner_model.yaml: -------------------------------------------------------------------------------- 1 | _class_: LitFusionRefiner 2 | name: "lit_fusion_refiner" 3 | 4 | feature_vector_size: 8 5 | train_ray_splits: 500 6 | voxel_size: 0.02 7 | min_pts_in_grid: 8 8 | sdf_delta_weight: 0.1 9 | 10 | use_refine: False 11 | use_pretrained: True 12 | freeze_pretrained_weights: True 13 | 14 | pretrained_model: /home/kejie/repository/fast_sdf/remote_test_sdf.ckpt 15 | 16 | volume_dir: /home/kejie/repository/fast_sdf/logs/test/2021-09-18/11-18-27/plots 17 | 18 | ray_tracer: 19 | ray_max_dist: 5 20 | truncated_units: 10 21 | 22 | nerf: 23 | hidden_size: 256 24 | num_layers: 4 25 | num_encoding_fn_xyz: 1 26 | num_encoding_fn_dir: 6 27 | include_input_xyz: True 28 | include_input_dir: True 29 | xyz_agnostic: False 30 | interpolate_decode: True 31 | global_coords: False 32 | 33 | loss: 34 | rgb_loss: 0. 35 | l1_loss: 0. 36 | depth_bce_loss: 1. 37 | sdf_delta_loss: 1. 38 | zero_level_loss: 0. 39 | mask_loss: 0. 40 | reg_loss: 0. 41 | -------------------------------------------------------------------------------- /configs/model/fusion_pointnet_refiner.yaml: -------------------------------------------------------------------------------- 1 | _class_: LitFusionRefiner 2 | name: "lit_fusion_pointnet_refiner" 3 | 4 | feature_vector_size: 8 5 | feature_resolution: 128 6 | train_split: 10000 7 | train_ray_splits: 200 8 | voxel_size: 0.025 9 | bound_max: [1, 1, 1] 10 | bound_min: [-1, -1, -1] 11 | min_pts_in_grid: 0 12 | n_levels: 1 13 | 14 | use_pretrained: True 15 | use_pretrained_weights: True 16 | global_coords: True 17 | training_global: True 18 | 19 | pretrained_model: /home/kejie/repository/fast_sdf/remote_test_sdf.ckpt 20 | volume_dir: /home/kejie/repository/fast_sdf/logs/test/2021-09-18/11-18-27/plots 21 | 22 | ray_tracer: 23 | ray_max_dist: 5 24 | 25 | point_net: 26 | in_channels: 6 27 | 28 | nerf: 29 | hidden_size: 256 30 | num_layers: 4 31 | num_encoding_fn_xyz: 1 32 | num_encoding_fn_dir: 6 33 | include_input_xyz: True 34 | include_input_dir: True 35 | 36 | loss: 37 | rgb_loss: 0. 38 | l1_loss: 0. 
39 | depth_bce_loss: 1. 40 | zero_level_loss: 0. 41 | mask_loss: 0. 42 | reg_loss: 0. 43 | -------------------------------------------------------------------------------- /src/scripts/postprocess_meshes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import subprocess 5 | 6 | from src.utils.common import get_file_paths 7 | 8 | 9 | def main(): 10 | arg_parser = argparse.ArgumentParser() 11 | arg_parser.add_argument("--input_dir", required=True) 12 | arg_parser.add_argument("--file_type", required=True) 13 | args = arg_parser.parse_args() 14 | 15 | paths = get_file_paths(args.input_dir, file_type=args.file_type) 16 | 17 | for path in paths: 18 | """ 19 | meshlabserver -i $IN_PATH -o $OUT_PATH -s /home/kejie/repository/fast_sdf/clean_mesh.mlx 20 | """ 21 | print(path) 22 | out_path = path[:-4] + "_post_process.ply" 23 | commands = f"meshlabserver -i {path} -o {out_path} -s ./clean_mesh.mlx" 24 | commands = commands.split(" ") 25 | try: 26 | subprocess.run(commands, check=True) 27 | except subprocess.CalledProcessError: 28 | import pdb 29 | pdb.set_trace() 30 | 31 | 32 | if __name__ == "__main__": 33 | main() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 kejieli 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/model/fusion_pointnet_model.yaml: -------------------------------------------------------------------------------- 1 | _class_: LitFusionPointnet 2 | name: "lit_fusion_pointnet" 3 | 4 | feature_vector_size: 8 5 | voxel_size: 0.01 6 | train_split: 10000 7 | train_ray_splits: 1000 8 | 9 | tiny_cuda: True 10 | tcnn_config: ./src/models/tcnn_config.json 11 | 12 | training_global: False 13 | global_coords: False 14 | interpolate_decode: True # use interpolation when decoding points. 
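# tiny_cuda / tcnn_config above switch the decoder to the tiny-cuda-nn network defined in
# src/models/tcnn_config.json (FullyFusedMLP with 64 neurons, 3 hidden layers, Identity
# encoding). Presumably pretrained/pointnet_tcnn.ckpt corresponds to this setting and
# pretrained/pointnet.ckpt to tiny_cuda: False; the inference scripts override the
# checkpoint explicitly (trainer.checkpoint=...), together with model.voxel_size and
# model.min_pts_in_grid (see src/scripts/run_inference_on_*.py).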
15 | 16 | # for training local embedding only 17 | bound_max: [1, 1, 1] 18 | bound_min: [-1, -1, -1] 19 | 20 | min_pts_in_grid: 8 21 | 22 | point_net: 23 | in_channels: 6 24 | 25 | nerf: 26 | hidden_size: 256 27 | num_layers: 4 28 | num_encoding_fn_xyz: 1 29 | num_encoding_fn_dir: 6 30 | include_input_xyz: True 31 | include_input_dir: True 32 | interpolate_decode: True 33 | global_coords: False 34 | xyz_agnostic: False 35 | 36 | loss: 37 | bce_loss: 1. 38 | reg_loss: 0.001 39 | # for end-to-end 40 | depth_bce_loss: 1. 41 | 42 | ray_tracer: 43 | ray_max_dist: 3 44 | truncated_units: 10 45 | 46 | # for end-to-end 47 | sdf_delta_weight: 0.1 48 | optim_interval: 100 49 | mode: eval 50 | freeze_pretrained_weights: True 51 | pretrained_model: /home/kejie/repository/fast_sdf/logs/train/2021-10-21/22-37-03/lightning_logs/version_0/checkpoints/last.ckpt 52 | depth_scale: 1000 -------------------------------------------------------------------------------- /src/utils/torchvision_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as T 3 | import numpy as np 4 | import cv2 5 | from PIL import Image 6 | import matplotlib as mpl 7 | import matplotlib.cm as cm 8 | 9 | 10 | def visualize_depth(depth, cmap=cv2.COLORMAP_JET): 11 | """ 12 | depth: (H, W) 13 | """ 14 | x = depth.astype(np.uint8) 15 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) 16 | x_ = T.ToTensor()(x_) # (3, H, W) 17 | return x_ 18 | 19 | 20 | def visualize_prob(prob, cmap=cv2.COLORMAP_BONE): 21 | """ 22 | prob: (H, W) 0~1 23 | """ 24 | x = (255*prob).astype(np.uint8) 25 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) 26 | x_ = T.ToTensor()(x_) # (3, H, W) 27 | return x_ 28 | 29 | 30 | def depth_visualizer(data, min_depth, max_depth): 31 | """ 32 | Args: 33 | data (HxW): depth data 34 | Returns: 35 | vis_data (HxWx3): depth visualization (RGB) 36 | """ 37 | 38 | mask = np.logical_and(data > min_depth, data < max_depth) 39 | inv_depth = 1 / (data + 1e-6) 40 | vmax = np.percentile(1/(data[mask]+1e-6), 90) 41 | normalizer = mpl.colors.Normalize(vmin=inv_depth.min(), vmax=vmax) 42 | mapper = cm.ScalarMappable(norm=normalizer, cmap='magma') 43 | vis_data = (mapper.to_rgba(inv_depth)[:, :, :3] * 255).astype(np.uint8) 44 | return vis_data -------------------------------------------------------------------------------- /src/models/model_utils.py: -------------------------------------------------------------------------------- 1 | from src.utils.import_utils import import_from 2 | 3 | 4 | def set_optimizer_and_lr(cfg, parameters): 5 | """ 6 | 7 | Args: 8 | config 9 | parameters to be optimized 10 | 11 | Return: 12 | A tuple of an optimizer and a learning rate dict 13 | a learning rate dict: 14 | { 15 | 'scheduler': lr_scheduler, # The LR scheduler instance (required) 16 | 'interval': 'epoch', # The unit of the scheduler's step size 17 | 'frequency': 1, # The frequency of the scheduler 18 | 'reduce_on_plateau': False, # For ReduceLROnPlateau scheduler 19 | 'monitor': 'val_loss', # Metric for ReduceLROnPlateau to monitor 20 | 'strict': True, # Whether to crash the training if `monitor` is not found 21 | 'name': None, # Custom name for LearningRateMonitor to use 22 | } 23 | """ 24 | 25 | optimizer = import_from( 26 | module=cfg.optimizer._target_package_, 27 | name=cfg.optimizer._class_ 28 | )(parameters, lr=cfg.optimizer.lr.initial) 29 | lr_scheduler = import_from( 30 | module="torch.optim.lr_scheduler", 31 | name=cfg.optimizer.lr.scheduler 32 | )( 33 | 
optimizer=optimizer, 34 | **cfg.optimizer.lr_scheduler 35 | ) 36 | 37 | return optimizer, lr_scheduler 38 | -------------------------------------------------------------------------------- /src/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | from torch.utils.data import Dataset 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class DatasetBase(Dataset): 11 | """ Abstract base dataset class 12 | """ 13 | def __init__(self, cfg, subset): 14 | self.dataset_dir = cfg.dataset.data_dir 15 | self.dataset_name = cfg.dataset.name 16 | if cfg.dataset.categories is None: 17 | categories = os.listdir(cfg.dataset.data_dir) 18 | categories = [c for c in categories 19 | if osp.isdir(osp.join(cfg.dataset.data_dir, c))] 20 | else: 21 | categories = cfg.dataset.categories 22 | 23 | self.file_list = [] 24 | for c_idx, c in enumerate(categories): 25 | subpath = osp.join(cfg.dataset.data_dir, "splits", c) 26 | if not osp.isdir(subpath): 27 | logger.warning('Category %s does not exist in dataset.' % c) 28 | 29 | split_file = osp.join(subpath, subset + '.lst') 30 | with open(split_file, 'r') as f: 31 | models_c = f.read().split('\n') 32 | 33 | self.file_list += [ 34 | {'category': c, 'model': m} 35 | for m in models_c 36 | ] 37 | 38 | def __len__(self): 39 | return len(self.file_list) 40 | 41 | def __getitem__(self): 42 | raise NotImplementedError() 43 | -------------------------------------------------------------------------------- /src/utils/mesh_helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import trimesh 3 | 4 | 5 | def as_mesh(scene_or_mesh): 6 | """ 7 | Convert a possible scene to a mesh. 8 | 9 | If conversion occurs, the returned mesh has only vertex and face data. 
10 | """ 11 | if isinstance(scene_or_mesh, trimesh.Scene): 12 | if len(scene_or_mesh.geometry) == 0: 13 | mesh = None # empty scene 14 | else: 15 | # we lose texture information here 16 | mesh = trimesh.util.concatenate( 17 | tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces) 18 | for g in scene_or_mesh.geometry.values())) 19 | else: 20 | assert(isinstance(scene_or_mesh, trimesh.Trimesh)) 21 | mesh = scene_or_mesh 22 | return mesh 23 | 24 | 25 | def merge_meshes(mesh_list): 26 | vertices = [] 27 | faces = [] 28 | face_offset = 0 29 | for mesh in mesh_list: 30 | num_vertices = len(mesh.vertices) 31 | vertices.append(mesh.vertices) 32 | faces.append(mesh.faces + face_offset) 33 | face_offset += num_vertices 34 | vertices = np.concatenate(vertices, axis=0) 35 | faces = np.concatenate(faces, axis=0) 36 | merged_mesh = trimesh.Trimesh(vertices=vertices, faces=faces) 37 | return merged_mesh 38 | 39 | 40 | def scale_to_unit_sphere(mesh): 41 | if isinstance(mesh, trimesh.Scene): 42 | mesh = mesh.dump().sum() 43 | 44 | vertices = mesh.vertices - mesh.bounding_box.centroid 45 | distances = np.linalg.norm(vertices, axis=1) 46 | vertices /= np.max(distances) 47 | 48 | return trimesh.Trimesh(vertices=vertices, faces=mesh.faces) 49 | 50 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # specify here default training configuration 2 | defaults: 3 | - trainer: default_trainer.yaml 4 | - model: default_model.yaml 5 | - optimizer: adam.yaml 6 | - dataset: default_dataset.yaml 7 | # - callbacks: default_callbacks.yaml # set this to null if you don't want to use callbacks 8 | - loggers: null # set logger here or use command line (e.g. `python run.py logger=wandb`) 9 | 10 | # enable color logging 11 | # - override hydra/hydra_logging: colorlog 12 | # - override hydra/job_logging: colorlog 13 | 14 | disable_warnings: False 15 | debug: False 16 | 17 | # path to original working directory 18 | # hydra hijacks working directory by changing it to the current log directory, 19 | # so it's useful to have this path as a special variable 20 | # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 21 | work_dir: ${hydra:runtime.cwd} 22 | # path to folder with data 23 | data_dir: ${work_dir}/data/ 24 | 25 | device_type: "cuda" 26 | 27 | hydra: 28 | # output paths for hydra logs 29 | run: 30 | dir: logs/${hydra.job.name}/ 31 | sweep: 32 | dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S} 33 | subdir: ${hydra.job.num} 34 | 35 | job: 36 | # you can set here environment variables that are universal for all users 37 | # for system specific variables (like data paths) it's better to use .env file! 38 | env_set: 39 | # currently there are some issues with running sweeps alongside wandb 40 | # https://github.com/wandb/client/issues/1314 41 | # this env var fixes that 42 | WANDB_START_METHOD: thread -------------------------------------------------------------------------------- /src/models/fusion/embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | """ Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. 
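For a given `multires` L (see get_embedder below), the embedding maps a point x in R^3 to
[x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)], so the output dimension
reported by Embedder.out_dim is 3 + 3 * 2 * L.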
""" 4 | 5 | class Embedder: 6 | def __init__(self, **kwargs): 7 | self.kwargs = kwargs 8 | self.create_embedding_fn() 9 | 10 | def create_embedding_fn(self): 11 | embed_fns = [] 12 | d = self.kwargs['input_dims'] 13 | out_dim = 0 14 | if self.kwargs['include_input']: 15 | embed_fns.append(lambda x: x) 16 | out_dim += d 17 | 18 | max_freq = self.kwargs['max_freq_log2'] 19 | N_freqs = self.kwargs['num_freqs'] 20 | 21 | if self.kwargs['log_sampling']: 22 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs) 23 | else: 24 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) 25 | 26 | for freq in freq_bands: 27 | for p_fn in self.kwargs['periodic_fns']: 28 | embed_fns.append(lambda x, p_fn=p_fn, 29 | freq=freq: p_fn(x * freq)) 30 | out_dim += d 31 | 32 | self.embed_fns = embed_fns 33 | self.out_dim = out_dim 34 | 35 | def embed(self, inputs): 36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 37 | 38 | def get_embedder(multires): 39 | embed_kwargs = { 40 | 'include_input': True, 41 | 'input_dims': 3, 42 | 'max_freq_log2': multires-1, 43 | 'num_freqs': multires, 44 | 'log_sampling': True, 45 | 'periodic_fns': [torch.sin, torch.cos], 46 | } 47 | 48 | embedder_obj = Embedder(**embed_kwargs) 49 | def embed(x, eo=embedder_obj): return eo.embed(x) 50 | return embed, embedder_obj.out_dim 51 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import DataLoader 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.callbacks import LearningRateMonitor 6 | from pytorch_lightning import seed_everything 7 | import hydra 8 | from omegaconf import DictConfig 9 | 10 | from src.models.models import MODELS 11 | from src.models.models import get_model 12 | from src.datasets import datasets 13 | import src.utils.hydra_utils as hydra_utils 14 | from src.utils.common import override_weights 15 | 16 | 17 | log = hydra_utils.get_logger(__name__) 18 | 19 | 20 | @hydra.main(config_path="../configs/", config_name="config.yaml") 21 | def main(config: DictConfig): 22 | 23 | if "seed" in config.trainer: 24 | seed_everything(config.trainer.seed) 25 | 26 | hydra_utils.extras(config) 27 | hydra_utils.print_config(config, resolve=True) 28 | 29 | # setup dataset 30 | log.info("initializing dataset") 31 | test_dataset = datasets.get_dataset(config, "test") 32 | test_loader = DataLoader( 33 | test_dataset, 34 | batch_size=config.dataset.eval_batch_size, 35 | shuffle=False, 36 | num_workers=config.dataset.num_workers, 37 | collate_fn=test_dataset.collate_fn if hasattr(test_dataset, "collate_fn") else None 38 | ) 39 | 40 | # setup model 41 | log.info("initializing model") 42 | model = MODELS[config.model.name].load_from_checkpoint( 43 | config.trainer.checkpoint, 44 | **{ 45 | "cfg": config 46 | } 47 | ) 48 | 49 | # start training 50 | trainer = pl.Trainer( 51 | gpus=config.trainer.gpus, 52 | ) 53 | trainer.test(model, test_dataloaders=test_loader) 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /src/utils/fusion_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def groupby_reduce(sample_indexer: torch.Tensor, sample_values: torch.Tensor, op: str = "max"): 5 | """ 6 | Group-By and Reduce sample_values according to their indices, the reduction operation is 
defined in `op`. 7 | :param sample_indexer: (N,). An index, must start from 0 and go to the (max-1), can be obtained using torch.unique. 8 | :param sample_values: (N, L) 9 | :param op: have to be in 'max', 'mean' 10 | :return: reduced values: (C, L) 11 | """ 12 | C = sample_indexer.max() + 1 13 | n_samples = sample_indexer.size(0) 14 | 15 | assert n_samples == sample_values.size(0), "Indexer and Values must agree on sample count!" 16 | 17 | sample_values = sample_values.contiguous() 18 | sample_indexer = sample_indexer.contiguous() 19 | if op == 'mean': 20 | from src.ext import groupby_sum 21 | values_sum, values_count = groupby_sum(sample_values, sample_indexer, C) 22 | return values_sum / values_count.unsqueeze(-1) 23 | elif op == 'sum': 24 | from src.ext import groupby_sum 25 | values_sum, _ = groupby_sum(sample_values, sample_indexer, C) 26 | return values_sum 27 | else: 28 | raise NotImplementedError 29 | 30 | 31 | def get_samples(r: int, device: torch.device, a: float = 0.0, b: float = None): 32 | """ 33 | Get samples within a cube, the voxel size is (b-a)/(r-1). range is from [a, b] 34 | :param r: num samples 35 | :param a: bound min 36 | :param b: bound max 37 | :return: (r*r*r, 3) 38 | """ 39 | overall_index = torch.arange(0, r ** 3, 1, device=device, dtype=torch.long) 40 | r = int(r) 41 | 42 | if b is None: 43 | b = 1. - 1. / r 44 | 45 | vsize = (b - a) / (r - 1) 46 | samples = torch.zeros(r ** 3, 3, device=device, dtype=torch.float32) 47 | samples[:, 0] = (overall_index // (r * r)) * vsize + a 48 | samples[:, 1] = ((overall_index // r) % r) * vsize + a 49 | samples[:, 2] = (overall_index % r) * vsize + a 50 | 51 | return samples -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.1-cudnn8-devel-ubuntu18.04 2 | 3 | RUN apt-get update && yes|apt-get upgrade 4 | 5 | RUN apt-get install -y wget bzip2 6 | RUN apt-get -y install sudo 7 | 8 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 9 | ENV PATH /opt/conda/bin:$PATH 10 | 11 | RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata 12 | 13 | RUN apt-get update --fix-missing && \ 14 | apt-get install -y wget bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 git mercurial subversion && \ 15 | apt-get clean 16 | 17 | RUN wget --quiet https://repo.anaconda.com/archive/Anaconda3-2020.11-Linux-x86_64.sh -O ~/anaconda.sh && \ 18 | /bin/bash ~/anaconda.sh -b -p /opt/conda && \ 19 | rm ~/anaconda.sh && \ 20 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ 21 | echo ". 
/opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ 22 | echo "conda activate base" >> ~/.bashrc && \ 23 | find /opt/conda/ -follow -type f -name '*.a' -delete && \ 24 | find /opt/conda/ -follow -type f -name '*.js.map' -delete && \ 25 | /opt/conda/bin/conda clean -afy 26 | 27 | CMD [ "/bin/bash" ] 28 | 29 | RUN conda install pip 30 | 31 | # install pytorch related 32 | RUN conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.1 -c pytorch 33 | RUN pip install torch-scatter -f https://data.pyg.org/whl/torch-1.8.0+cu101.html 34 | 35 | # install some extra packages 36 | # RUN conda install -c conda-forge kornia 37 | RUN pip install -U kornia==0.5.0 38 | RUN pip install pytorch-lightning==1.2.6 39 | RUN pip install hydra-core==1.1.1 40 | RUN pip install multidict 41 | RUN pip install pyquaternion 42 | RUN pip install pillow 43 | RUN pip install rich 44 | RUN pip install opencv-python 45 | RUN pip install -U scikit-learn 46 | RUN conda install -c conda-forge trimesh 47 | RUN pip install opencv-python 48 | RUN apt-get install ffmpeg libsm6 libxext6 -y 49 | RUN pip install numpy-quaternion 50 | RUN pip install plyfile 51 | RUN pip install --upgrade PyMCubes 52 | RUN python3 -m pip install --no-cache-dir --upgrade open3d==0.14.1 --ignore-installed PyYAML -------------------------------------------------------------------------------- /src/utils/hydra_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | from typing import List, Sequence 4 | from omegaconf import DictConfig, OmegaConf 5 | from rich import print 6 | from rich.syntax import Syntax 7 | from rich.tree import Tree 8 | from pytorch_lightning.utilities import rank_zero_only 9 | 10 | 11 | log = logging.getLogger(__name__) 12 | 13 | 14 | def get_logger(name=__name__, level=logging.INFO): 15 | """Initializes python logger.""" 16 | 17 | logger = logging.getLogger(name) 18 | logger.setLevel(level) 19 | 20 | # this ensures all logging levels get marked with the rank zero decorator 21 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 22 | for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): 23 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 24 | 25 | return logger 26 | 27 | 28 | def extras(config: DictConfig) -> None: 29 | if config.get("disable_warnings"): 30 | log.info(f"Disabling python warnings! ") 31 | warnings.filterwarnings("ignore") 32 | 33 | if config.get("debug"): 34 | log.info("Running in debug mode! ") 35 | config.trainer.fast_dev_run = True 36 | 37 | # force debugger friendly configuration if 38 | if config.trainer.get("fast_dev_run"): 39 | log.info("Forcing debugger friendly configuration! ") 40 | # Debuggers don't like GPUs or multiprocessing 41 | if config.trainer.get("gpus"): 42 | config.trainer.gpus = 0 43 | if config.dataset.get("num_workers"): 44 | config.dataset.num_workers = 0 45 | 46 | @rank_zero_only 47 | def print_config( 48 | config: DictConfig, 49 | fields: Sequence[str] = ( 50 | "trainer", 51 | "model", 52 | "optimizer", 53 | "dataset" 54 | ), 55 | resolve: bool = True, 56 | ) -> None: 57 | """Prints content of DictConfig using Rich library and its tree structure. 58 | Args: 59 | config (DictConfig): Config. 60 | fields (Sequence[str], optional): Determines which main fields from config will be printed 61 | and in what order. 62 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 
63 | """ 64 | 65 | style = "dim" 66 | tree = Tree(f":gear: CONFIG", style=style, guide_style=style) 67 | 68 | for field in fields: 69 | branch = tree.add(field, style=style, guide_style=style) 70 | 71 | config_section = config.get(field) 72 | branch_content = str(config_section) 73 | if isinstance(config_section, DictConfig): 74 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 75 | 76 | branch.add(Syntax(branch_content, "yaml")) 77 | 78 | print(tree) -------------------------------------------------------------------------------- /src/scripts/run_inference_on_scene3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import subprocess 4 | import trimesh 5 | 6 | import src.utils.geometry as geometry 7 | import src.utils.scannet_helper as scannet_helper 8 | 9 | sequences = ["lounge", "cactusgarden", "stonewall", "copyroom", "burghers"] 10 | 11 | for seq in sequences: 12 | for skip in [10]: 13 | for shift in range(0, 15, 5): 14 | exp_name = f"scene3d_{skip}_{shift}" 15 | out_root = f"/home/kejie/repository/fast_sdf/logs/test/{exp_name}/" 16 | commands = ( 17 | "python src/test.py model=fusion_pointnet_model dataset=fusion_inference_dataset " 18 | "dataset.num_workers=0 " 19 | "dataset.downsample_scale=1 " 20 | "model.ray_tracer.ray_max_dist=3 " 21 | "model.voxel_size=0.01 " 22 | "model.min_pts_in_grid=8 " 23 | "trainer.checkpoint=/home/kejie/repository/fast_sdf/logs/train/2021-10-21/22-37-03/lightning_logs/version_0/checkpoints/last.ckpt" 24 | ) 25 | commands = commands.split(" ") 26 | commands += [f"dataset.scan_id=scene3d/{seq}"] 27 | commands += [f"dataset.out_root={out_root}"] 28 | commands += [f"dataset.skip_images={skip}"] 29 | commands += [f"dataset.sample_shift={shift}"] 30 | 31 | try: 32 | subprocess.run(commands, check=True) 33 | print(f"finish {seq}") 34 | except subprocess.CalledProcessError: 35 | import pdb 36 | pdb.set_trace() 37 | 38 | commands = ( 39 | "python src/train.py model=fusion_refiner_model dataset=fusion_refiner_dataset " 40 | "dataset.num_workers=4 " 41 | "dataset.downsample_scale=1 " 42 | "model.ray_tracer.ray_max_dist=3 " 43 | "model.voxel_size=0.01 " 44 | "model.min_pts_in_grid=8 " 45 | "trainer.max_epochs=30 " 46 | "trainer.check_val_every_n_epoch=10 " 47 | "dataset.num_pixels=5000 " 48 | "model.train_ray_splits=2500 " 49 | "model.sdf_delta_weight=0.1 " 50 | "model.pretrained_model=/home/kejie/repository/fast_sdf/logs/train/2021-10-21/22-37-03/lightning_logs/version_0/checkpoints/last.ckpt" 51 | ) 52 | commands = commands.split(" ") 53 | commands += [f"dataset.scan_id=scene3d/{seq}"] 54 | 55 | volume_path = f"/home/kejie/repository/fast_sdf/logs/test/{exp_name}/scene3d/{seq}" 56 | commands += [f"model.volume_dir={volume_path}"] 57 | 58 | commands += [f"dataset.skip_images={skip}"] 59 | commands += [f"dataset.sample_shift={shift}"] 60 | 61 | out_root = f"/home/kejie/repository/fast_sdf/logs/train/{exp_name}/" 62 | commands += [f"dataset.out_root={out_root}"] 63 | try: 64 | subprocess.run(commands, check=True) 65 | print(f"finish {seq}") 66 | except subprocess.CalledProcessError: 67 | import pdb 68 | pdb.set_trace() -------------------------------------------------------------------------------- /src/scripts/run_inference_on_arkit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import subprocess 4 | import trimesh 5 | 6 | import src.utils.geometry as geometry 7 | import 
src.utils.scannet_helper as scannet_helper 8 | 9 | sequences = ["41048190", "41048265", "41159630", "47670149", "47431043"] 10 | for seq in sequences: 11 | for skip in [1]: 12 | for shift in range(1): 13 | out_root = f"/home/kejie/repository/fast_sdf/logs/test/arkit/" 14 | commands = ( 15 | "python src/test.py model=fusion_pointnet_model dataset=fusion_inference_dataset " 16 | "dataset.num_workers=0 " 17 | "dataset.downsample_scale=1 " 18 | "model.ray_tracer.ray_max_dist=3 " 19 | "model.voxel_size=0.02 " 20 | "model.min_pts_in_grid=8 " 21 | "trainer.checkpoint=/home/kejie/repository/fast_sdf/logs/train/2021-10-21/22-37-03/lightning_logs/version_0/checkpoints/last.ckpt" 22 | ) 23 | commands = commands.split(" ") 24 | commands += [f"dataset.scan_id=arkit/{seq}"] 25 | commands += [f"dataset.out_root={out_root}"] 26 | commands += [f"dataset.skip_images={skip}"] 27 | # commands += [f"dataset.sample_shift={shift}"] 28 | 29 | try: 30 | subprocess.run(commands, check=True) 31 | print(f"finish {seq}") 32 | except subprocess.CalledProcessError: 33 | import pdb 34 | pdb.set_trace() 35 | 36 | commands = ( 37 | "python src/train.py model=fusion_refiner_model dataset=fusion_refiner_dataset " 38 | "dataset.num_workers=4 " 39 | "dataset.downsample_scale=1 " 40 | "dataset.num_pixels=5000 " 41 | "model.ray_tracer.ray_max_dist=3 " 42 | "model.voxel_size=0.02 " 43 | "model.min_pts_in_grid=8 " 44 | "dataset.img_res=[192,256] " 45 | "trainer.max_epochs=10 " 46 | "trainer.check_val_every_n_epoch=5 " 47 | "trainer.dense_volume=False " 48 | "model.sdf_delta_weight=0. " 49 | "model.train_ray_splits=2500 " 50 | "model.pretrained_model=/home/kejie/repository/fast_sdf/logs/train/2021-10-21/22-37-03/lightning_logs/version_0/checkpoints/last.ckpt" 51 | ) 52 | commands = commands.split(" ") 53 | commands += [f"dataset.scan_id=arkit/{seq}"] 54 | 55 | volume_path = f"/home/kejie/repository/fast_sdf/logs/test/arkit/arkit/{seq}" 56 | commands += [f"model.volume_dir={volume_path}"] 57 | 58 | commands += [f"dataset.skip_images={skip}"] 59 | # commands += [f"dataset.sample_shift={shift}"] 60 | 61 | out_root = f"/home/kejie/repository/fast_sdf/logs/train/arkit/" 62 | commands += [f"dataset.out_root={out_root}"] 63 | try: 64 | subprocess.run(commands, check=True) 65 | print(f"finish {seq}") 66 | except subprocess.CalledProcessError: 67 | import pdb 68 | pdb.set_trace() -------------------------------------------------------------------------------- /src/scripts/run_inference_on_icl_nuim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import subprocess 4 | import trimesh 5 | 6 | import src.utils.geometry as geometry 7 | import src.utils.scannet_helper as scannet_helper 8 | 9 | sequences = ["livingroom1_noise", "livingroom2_noise", "office1_noise", "office2_noise"] 10 | 11 | for seq in sequences: 12 | for skip in [10]: 13 | for shift in range(0, 15, 5): 14 | exp_name = f"icl_nuim_{skip}_{shift}" 15 | out_root = f"/home/kejie/repository/fast_sdf/logs/test/{exp_name}/" 16 | commands = ( 17 | "python src/test.py model=fusion_pointnet_model dataset=fusion_inference_dataset " 18 | "dataset.num_workers=0 " 19 | "dataset.downsample_scale=1 " 20 | "model.ray_tracer.ray_max_dist=5 " 21 | "model.voxel_size=0.02 " 22 | "model.min_pts_in_grid=8 " 23 | "trainer.checkpoint=/home/kejie/repository/fast_sdf/logs/train/2021-10-21/22-37-03/lightning_logs/version_0/checkpoints/last.ckpt" 24 | ) 25 | commands = commands.split(" ") 26 | commands += 
[f"dataset.scan_id=icl_nuim/{seq}"] 27 | commands += [f"dataset.out_root={out_root}"] 28 | commands += [f"dataset.skip_images={skip}"] 29 | commands += [f"dataset.sample_shift={shift}"] 30 | 31 | try: 32 | subprocess.run(commands, check=True) 33 | print(f"finish {seq}") 34 | except subprocess.CalledProcessError: 35 | import pdb 36 | pdb.set_trace() 37 | 38 | commands = ( 39 | "python src/train.py model=fusion_refiner_model dataset=fusion_refiner_dataset " 40 | "dataset.num_workers=4 " 41 | "dataset.downsample_scale=1 " 42 | "model.ray_tracer.ray_max_dist=5 " 43 | "model.voxel_size=0.02 " 44 | "model.min_pts_in_grid=8 " 45 | "trainer.max_epochs=30 " 46 | "trainer.check_val_every_n_epoch=10 " 47 | "dataset.num_pixels=5000 " 48 | "model.train_ray_splits=2500 " 49 | "model.sdf_delta_weight=1 " 50 | "model.pretrained_model=/home/kejie/repository/fast_sdf/logs/train/2021-10-21/22-37-03/lightning_logs/version_0/checkpoints/last.ckpt" 51 | ) 52 | commands = commands.split(" ") 53 | commands += [f"dataset.scan_id=icl_nuim/{seq}"] 54 | 55 | volume_path = f"/home/kejie/repository/fast_sdf/logs/test/{exp_name}/icl_nuim/{seq}" 56 | commands += [f"model.volume_dir={volume_path}"] 57 | 58 | commands += [f"dataset.skip_images={skip}"] 59 | commands += [f"dataset.sample_shift={shift}"] 60 | 61 | out_root = f"/home/kejie/repository/fast_sdf/logs/train/{exp_name}/" 62 | commands += [f"dataset.out_root={out_root}"] 63 | try: 64 | subprocess.run(commands, check=True) 65 | print(f"finish {seq}") 66 | except subprocess.CalledProcessError: 67 | import pdb 68 | pdb.set_trace() -------------------------------------------------------------------------------- /src/scripts/compute_chamfer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import open3d as o3d 3 | import trimesh 4 | import numpy as np 5 | from sklearn.neighbors import NearestNeighbors 6 | 7 | import src.utils.o3d_helper as o3d_helper 8 | 9 | 10 | def visualize_errors(distances, gt_pts, max_dist=0.05, file_out=None): 11 | import matplotlib.pyplot as plt 12 | cmap = plt.cm.get_cmap("plasma") 13 | distances = np.clip(distances, a_min=0, a_max=max_dist) 14 | colors = cmap(distances / max_dist) 15 | mesh = trimesh.Trimesh( 16 | vertices=gt_pts, 17 | faces=np.zeros_like(gt_pts), 18 | process=False 19 | ) 20 | mesh.visual.vertex_colors = (colors * 255).astype(np.uint8) 21 | if file_out is not None: 22 | mesh.export(file_out) 23 | pts_o3d = o3d_helper.np2pc(mesh.vertices, mesh.visual.vertex_colors[:, :3] / 255.) 
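    # Per-point error visualization: distances were clipped to max_dist above, so the
    # "plasma" colormap saturates at max_dist (0.05 m by default). o3d_helper.np2pc is
    # assumed here to wrap the Nx3 vertices and their 0-1 RGB colors into an open3d
    # point cloud for draw_geometries below.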
24 | # mesh_o3d = o3d_helper.trimesh2o3d(mesh) 25 | o3d.visualization.draw_geometries([pts_o3d]) 26 | 27 | 28 | args_parser = argparse.ArgumentParser() 29 | args_parser.add_argument("--pred") 30 | args_parser.add_argument("--gt") 31 | args_parser.add_argument("--vertice_only", action="store_true") 32 | args_parser.add_argument("--compute_normal", action="store_true") 33 | args = args_parser.parse_args() 34 | 35 | 36 | pred_mesh = trimesh.load(args.pred) 37 | gt_mesh = trimesh.load(args.gt) 38 | n_samples = 100000 39 | threshold = 0.025 40 | if args.vertice_only: 41 | gt_points = np.random.permutation(gt_mesh.vertices)[:n_samples] 42 | else: 43 | gt_points, gt_face_id = trimesh.sample.sample_surface(gt_mesh, count=n_samples) 44 | gt_normal = gt_mesh.face_normals[gt_face_id] 45 | gt_face = gt_mesh.faces[gt_face_id] 46 | pred_points, pred_face_id = trimesh.sample.sample_surface(pred_mesh, count=n_samples) 47 | pred_normal = pred_mesh.face_normals[pred_face_id] 48 | pred_face = pred_mesh.faces[pred_face_id] 49 | 50 | nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(gt_points) 51 | distances, indices = nbrs.kneighbors(pred_points) 52 | # distances = np.clip(distances, a_min=0, a_max=0.05) 53 | pred_gt_dist = np.mean(distances) 54 | print("pred -> gt: ", pred_gt_dist) 55 | precision = np.sum(distances < threshold) / len(distances) 56 | print(f"precision @ {threshold}:", precision) 57 | # pred_mesh_out = os.path.join( 58 | # "/".join(args.pred.split("/")[:-1]), 59 | # "pred_error.ply" 60 | # ) 61 | # visualize_errors(distances[:, 0], pred_points, file_out=pred_mesh_out) 62 | 63 | nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(pred_points) 64 | distances, indices = nbrs.kneighbors(gt_points) 65 | gt_pred_dist = np.mean(distances) 66 | print("gt -> pred: ", gt_pred_dist) 67 | recall = np.sum(distances < threshold) / len(distances) 68 | print(f"recall @ {threshold}:", recall) 69 | F1 = 2 * precision * recall / (precision + recall) 70 | print("F1: ", F1) 71 | print("{:.3f}/{:.4f}/{:.3f}/{:.4f}/{:.4f}".format(pred_gt_dist, precision, gt_pred_dist, recall, F1)) 72 | pred_normal = pred_normal[indices[:, 0]] 73 | if args.compute_normal: 74 | assert not args.vertice_only 75 | print(np.mean(np.sum(gt_normal * pred_normal, axis=-1))) 76 | 77 | # gt_mesh_out = os.path.join( 78 | # "/".join(args.pred.split("/")[:-1]), 79 | # "gt_error.ply" 80 | # ) 81 | # visualize_errors(distances[:, 0], gt_points, max_dist=0.05, file_out=gt_mesh_out) 82 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import DataLoader 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.callbacks import LearningRateMonitor 6 | from pytorch_lightning import seed_everything 7 | import hydra 8 | from omegaconf import DictConfig 9 | 10 | 11 | from src.models.models import get_model 12 | from src.datasets import datasets 13 | import src.utils.hydra_utils as hydra_utils 14 | from src.utils.common import override_weights 15 | 16 | 17 | log = hydra_utils.get_logger(__name__) 18 | 19 | 20 | @hydra.main(config_path="../configs/", config_name="config.yaml") 21 | def main(config: DictConfig): 22 | 23 | if "seed" in config.trainer: 24 | seed_everything(config.trainer.seed) 25 | 26 | hydra_utils.extras(config) 27 | hydra_utils.print_config(config, resolve=True) 28 | 29 | 30 | # setup dataset 31 | log.info("initializing 
dataset") 32 | train_dataset = datasets.get_dataset(config, "train") 33 | train_loader = DataLoader( 34 | train_dataset, 35 | batch_size=config.dataset.train_batch_size, 36 | shuffle=config.dataset.shuffle, 37 | num_workers=config.dataset.num_workers, 38 | collate_fn=train_dataset.collate_fn if hasattr(train_dataset, "collate_fn") else None 39 | ) 40 | val_dataset = datasets.get_dataset(config, "val") 41 | val_loader = DataLoader( 42 | val_dataset, 43 | batch_size=config.dataset.eval_batch_size, 44 | shuffle=False, 45 | num_workers=config.dataset.num_workers, 46 | collate_fn=val_dataset.collate_fn if hasattr(val_dataset, "collate_fn") else None 47 | ) 48 | 49 | # setup model 50 | log.info("initializing model") 51 | model_dynamic_cfg = { 52 | "num_samples": len(train_dataset), 53 | } 54 | if hasattr(train_dataset, "dimensions"): 55 | model_dynamic_cfg = { 56 | "dimensions": train_dataset.dimensions 57 | } 58 | model = get_model(config, **model_dynamic_cfg) 59 | 60 | log.info("setup checkpoint callback") 61 | # setup checkpoint callback 62 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 63 | dirpath=None, # the path is set by hydra in config.yaml TODO: add flexibility? 64 | monitor="val_loss", # TODO: add monitor 65 | save_top_k=50, 66 | period=1, 67 | save_last=True 68 | ) 69 | lr_monitor = LearningRateMonitor(logging_interval='step') 70 | 71 | if config.trainer.weight_only: 72 | checkpoint = None 73 | else: 74 | checkpoint = config.trainer.checkpoint 75 | 76 | # start training 77 | trainer = pl.Trainer( 78 | gpus=config.trainer.gpus, 79 | max_epochs=config.trainer.max_epochs, 80 | callbacks=[checkpoint_callback, lr_monitor], 81 | log_every_n_steps=5, 82 | resume_from_checkpoint=checkpoint, 83 | check_val_every_n_epoch=config.trainer.check_val_every_n_epoch, 84 | precision=16, 85 | ) 86 | # load pretrained data 87 | if (config.trainer.checkpoint is not None) and config.trainer.weight_only: 88 | pretrained_weights = torch.load( 89 | config.trainer.checkpoint 90 | )['state_dict'] 91 | override_weights( 92 | model, pretrained_weights, keys=['decoder'] 93 | ) 94 | trainer.fit(model, train_loader, val_loader) 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /src/scripts/run_inference_on_scannet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import subprocess 4 | import trimesh 5 | 6 | import src.utils.geometry as geometry 7 | import src.utils.scannet_helper as scannet_helper 8 | 9 | with open("/home/kejie/Datasets/ScanNet/server/Data/ScanNet/ScanNet/Tasks/Benchmark/scannetv2_val.txt", "r") as f: 10 | sequences = f.read().splitlines() 11 | 12 | sequences = sorted(sequences) 13 | 14 | # "scene0647_00" out of memory 15 | 16 | # run till scene0693_00 17 | 18 | for seq in sequences[286:]: 19 | for skip in [10]: 20 | for shift in range(1): 21 | exp_name = f"scannet_{skip}_{shift}" 22 | out_root = f"/home/kejie/repository/fast_sdf/logs/test/{exp_name}/" 23 | commands = ( 24 | "python src/test.py model=fusion_pointnet_model dataset=fusion_inference_scannet_dataset " 25 | "dataset.num_workers=0 " 26 | "dataset.downsample_scale=1 " 27 | "model.ray_tracer.ray_max_dist=5 " 28 | "model.voxel_size=0.02 " 29 | "model.min_pts_in_grid=8 " 30 | "trainer.checkpoint=/home/kejie/repository/fast_sdf/logs/train/2021-10-21/22-37-03/lightning_logs/version_0/checkpoints/last.ckpt" 31 | ) 32 | commands = commands.split(" ") 33 | commands += 
[f"dataset.scan_id={seq}"] 34 | commands += [f"dataset.out_root={out_root}"] 35 | commands += [f"dataset.skip_images={skip}"] 36 | # commands += [f"dataset.sample_shift={shift}"] 37 | 38 | try: 39 | subprocess.run(commands, check=True) 40 | print(f"finish {seq}") 41 | except subprocess.CalledProcessError: 42 | import pdb 43 | pdb.set_trace() 44 | 45 | commands = ( 46 | "python src/train.py model=fusion_refiner_model dataset=fusion_refiner_scannet_dataset " 47 | "dataset.num_workers=4 " 48 | "dataset.downsample_scale=1 " 49 | "model.ray_tracer.ray_max_dist=5 " 50 | "model.voxel_size=0.02 " 51 | "model.min_pts_in_grid=8 " 52 | "trainer.max_epochs=20 " 53 | "trainer.check_val_every_n_epoch=10 " 54 | "dataset.num_pixels=5000 " 55 | "model.train_ray_splits=2500 " 56 | "model.sdf_delta_weight=0.1 " 57 | "model.pretrained_model=/home/kejie/repository/fast_sdf/logs/train/2021-10-21/22-37-03/lightning_logs/version_0/checkpoints/last.ckpt" 58 | ) 59 | commands = commands.split(" ") 60 | commands += [f"dataset.scan_id={seq}"] 61 | 62 | volume_path = f"/home/kejie/repository/fast_sdf/logs/test/{exp_name}/{seq}" 63 | commands += [f"model.volume_dir={volume_path}"] 64 | 65 | commands += [f"dataset.skip_images={skip}"] 66 | # commands += [f"dataset.sample_shift={shift}"] 67 | 68 | out_root = f"/home/kejie/repository/fast_sdf/logs/train/{exp_name}/" 69 | commands += [f"dataset.out_root={out_root}"] 70 | try: 71 | subprocess.run(commands, check=True) 72 | print(f"finish {seq}") 73 | except subprocess.CalledProcessError: 74 | import pdb 75 | pdb.set_trace() -------------------------------------------------------------------------------- /src/scripts/evaluate_bnvf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import open3d as o3d 4 | import trimesh 5 | import numpy as np 6 | from sklearn.neighbors import NearestNeighbors 7 | 8 | 9 | def evaluate(pred_points, gt_points, threshold, out, verbose=False): 10 | nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(gt_points) 11 | distances, indices = nbrs.kneighbors(pred_points) 12 | pred_gt_dist = np.mean(distances) 13 | precision = np.sum(distances < threshold) / len(distances) 14 | 15 | nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(pred_points) 16 | distances, indices = nbrs.kneighbors(gt_points) 17 | gt_pred_dist = np.mean(distances) 18 | recall = np.sum(distances < threshold) / len(distances) 19 | F1 = 2 * precision * recall / (precision + recall) 20 | if verbose: 21 | print("pred -> gt: ", pred_gt_dist) 22 | print(f"precision @ {threshold}:", precision) 23 | print("gt -> pred: ", gt_pred_dist) 24 | print(f"recall @ {threshold}:", recall) 25 | print("F1: ", F1) 26 | print("{:.3f}/{:.4f}/{:.3f}/{:.4f}/{:.4f}".format(pred_gt_dist, precision, gt_pred_dist, recall, F1)) 27 | out['pred_gt'].append(pred_gt_dist) 28 | out['accuracy'].append(precision) 29 | out['gt_pred'].append(gt_pred_dist) 30 | out['recall'].append(recall) 31 | out['F1'].append(F1) 32 | 33 | 34 | def main(): 35 | arg_parser = argparse.ArgumentParser() 36 | arg_parser.add_argument("--pred_dir", required=True) 37 | arg_parser.add_argument("--gt_dir", required=True) 38 | arg_parser.add_argument("--file_name", required=True) 39 | args = arg_parser.parse_args() 40 | 41 | gt_dir = args.gt_dir 42 | pred_dir = args.pred_dir 43 | sequences = sorted(os.listdir(pred_dir)) 44 | 45 | thresholds = [0.025] 46 | n_samples = 100000 47 | 48 | out = { 49 | "pred_gt": [], 50 | "accuracy": [], 51 | "gt_pred": 
[], 52 | "recall": [], 53 | "F1": [] 54 | } 55 | for threshold in thresholds: 56 | for seq in sequences: 57 | print(f"{seq}:") 58 | gt_path = os.path.join(gt_dir, seq, "gt_mesh.ply") 59 | gt_mesh = o3d.io.read_triangle_mesh(gt_path) 60 | # gt_pts_poisson = np.asarray(gt_mesh.sample_points_poisson_disk(n_samples).points) 61 | gt_pts_uniform, _ = trimesh.sample.sample_surface(trimesh.load(gt_path), count=n_samples) 62 | # gt_pts_uniform = np.asarray(gt_mesh.sample_points_uniformly(n_samples).points) 63 | pred_seq_dir = os.path.join(pred_dir, seq) 64 | pred_file = [f for f in os.listdir(pred_seq_dir) if args.file_name in f] 65 | if len(pred_file) != 1: 66 | continue 67 | pred_path = os.path.join(pred_seq_dir, pred_file[0]) 68 | pred_mesh = o3d.io.read_triangle_mesh(pred_path) 69 | # pred_pts_poisson = np.asarray(pred_mesh.sample_points_poisson_disk(n_samples).points) 70 | pred_pts_uniform, _ = trimesh.sample.sample_surface(trimesh.load(pred_path), count=n_samples) 71 | # pred_pts_uniform = np.asarray(pred_mesh.sample_points_uniformly(n_samples).points) 72 | evaluate(pred_pts_uniform, gt_pts_uniform, threshold, out) 73 | # evaluate(pred_pts_poisson, gt_pts_poisson, threshold, out) 74 | print("sequence result:") 75 | print(out["pred_gt"][-1], out["accuracy"][-1], out["gt_pred"][-1], out["recall"][-1], out['F1'][-1]) 76 | print("average result:") 77 | print(np.mean(out["pred_gt"]), np.mean(out["accuracy"]), np.mean(out["gt_pred"]), np.mean(out["recall"]), np.mean(out['F1'])) 78 | 79 | 80 | if __name__ == "__main__": 81 | main() -------------------------------------------------------------------------------- /src/utils/scannet_helper.py: -------------------------------------------------------------------------------- 1 | ''' Ref: https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts ''' 2 | import os 3 | import sys 4 | import json 5 | import csv 6 | import numpy as np 7 | import quaternion 8 | from plyfile import PlyData 9 | 10 | import src.utils.geometry as geo_utils 11 | 12 | 13 | def read_meta_file(meta_file): 14 | lines = open(meta_file).readlines() 15 | for line in lines: 16 | if 'axisAlignment' in line: 17 | axis_align_matrix = [float(x) \ 18 | for x in line.rstrip().strip('axisAlignment = ').split(' ')] 19 | break 20 | axis_align_matrix = np.array(axis_align_matrix).reshape((4,4)) 21 | return axis_align_matrix 22 | 23 | 24 | def read_intrinsic(file_path): 25 | with open(file_path, "r") as f: 26 | intrinsic = np.asarray( 27 | [list(map(lambda x: float(x), f.split())) for f in f.read().splitlines()] 28 | ) 29 | return intrinsic 30 | 31 | 32 | def read_extrinsic(file_path): 33 | with open(file_path, "r") as f: 34 | T_cam_scan = np.linalg.inv( 35 | np.asarray( 36 | [list(map(lambda x: float(x), f.split())) for f in f.read().splitlines()] 37 | ) 38 | ) 39 | return T_cam_scan 40 | 41 | 42 | def read_mesh_vertices(filename): 43 | assert os.path.isfile(filename) 44 | with open(filename, 'rb') as f: 45 | plydata = PlyData.read(f) 46 | num_verts = plydata['vertex'].count 47 | vertices = np.zeros(shape=[num_verts, 3], dtype=np.float32) 48 | vertices[:,0] = plydata['vertex'].data['x'] 49 | vertices[:,1] = plydata['vertex'].data['y'] 50 | vertices[:,2] = plydata['vertex'].data['z'] 51 | return vertices 52 | 53 | 54 | def read_aggregation(filename): 55 | assert os.path.isfile(filename) 56 | object_id_to_segs = {} 57 | object_id_to_class = {} 58 | with open(filename) as f: 59 | data = json.load(f) 60 | num_objects = len(data['segGroups']) 61 | for i in range(num_objects): 62 | object_id = 
data['segGroups'][i]['objectId'] + 1 # instance ids should be 1-indexed 63 | label = data['segGroups'][i]['label'] 64 | segs = data['segGroups'][i]['segments'] 65 | object_id_to_segs[object_id] = segs 66 | object_id_to_class[object_id] = label 67 | return object_id_to_segs, object_id_to_class 68 | 69 | 70 | def read_segmentation(filename): 71 | assert os.path.isfile(filename) 72 | seg_to_verts = {} 73 | with open(filename) as f: 74 | data = json.load(f) 75 | num_verts = len(data['segIndices']) 76 | for i in range(num_verts): 77 | seg_id = data['segIndices'][i] 78 | if seg_id in seg_to_verts: 79 | seg_to_verts[seg_id].append(i) 80 | else: 81 | seg_to_verts[seg_id] = [i] 82 | return seg_to_verts, num_verts 83 | 84 | 85 | def get_cam_azi(T_wc): 86 | """ get camera azimuth rotation. Assume z axis is upright 87 | """ 88 | cam_orientation = np.array([[0, 0, 1], [0, 0, 0]]) 89 | cam_orientation = (geo_utils.get_homogeneous(cam_orientation) @ T_wc.T)[:, :3] 90 | cam_orientation = cam_orientation[0] - cam_orientation[1] 91 | cam_orientation[2] = 0 92 | cam_orientation = cam_orientation / np.linalg.norm(cam_orientation) 93 | theta = np.arctan2(cam_orientation[1], cam_orientation[0]) 94 | return theta 95 | 96 | 97 | def make_M_from_tqs(t, q, s): 98 | q = np.quaternion(q[0], q[1], q[2], q[3]) 99 | T = np.eye(4) 100 | T[0:3, 3] = t 101 | R = np.eye(4) 102 | R[0:3, 0:3] = quaternion.as_rotation_matrix(q) 103 | S = np.eye(4) 104 | S[0:3, 0:3] = np.diag(s) 105 | 106 | M = T.dot(R).dot(S) 107 | return M -------------------------------------------------------------------------------- /src/datasets/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torch_scatter import scatter_mean 5 | 6 | 7 | class SampleManager(): 8 | def __init__(self, img_h, img_w, patch_size=40): 9 | self.img_length = img_h * img_w 10 | self.img_h = img_h 11 | self.img_w = img_w 12 | self.caching_error_map = torch.zeros((img_h, img_w)) - 1. 
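        # entries of -1 mean "no error logged yet"; log_error() asserts this before writing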
13 | self.uniform_sample_ids = torch.randperm(self.img_length) 14 | self.num_uniform_samples = 0 15 | 16 | self.weighted_error_map = torch.ones((img_h, img_w)) * 5 17 | 18 | self.patch_size = patch_size 19 | self.coarse_img_h = int(img_h / patch_size) 20 | self.coarse_img_w = int(img_w / patch_size) 21 | self.weighted_error_map_coarse = torch.ones( 22 | (self.coarse_img_h, self.coarse_img_w)) * 5 23 | 24 | def reset(self): 25 | self.caching_error_map = torch.zeros((self.img_h, self.img_w)) - 1 26 | self.uniform_sample_ids = torch.randperm(self.img_length) 27 | self.num_uniform_samples = 0 28 | 29 | def log_error(self, uv, errors): 30 | """ store error in the cached error map 31 | 32 | Args: 33 | uv: [n, 2] 34 | errors: [n] 35 | """ 36 | 37 | # make sure we are not overwriting 38 | assert np.all(self.caching_error_map[uv[:, 1], uv[:, 0]] == -1) 39 | self.caching_error_map[uv[:, 1], uv[:, 0]] = errors 40 | 41 | def log_weighted_error(self, uv, errors): 42 | self.weighted_error_map[uv[:, 1], uv[:, 0]] = errors 43 | 44 | coarse_uv = torch.floor(uv / self.patch_size).long() 45 | uv_1d = coarse_uv[:, 1] * self.coarse_img_w + coarse_uv[:, 0] 46 | unique_flat_ids, pinds, pcounts = torch.unique( 47 | uv_1d, return_inverse=True, return_counts=True) 48 | unique_y = unique_flat_ids // self.coarse_img_w 49 | unique_x = unique_flat_ids % self.coarse_img_w 50 | assert torch.max(unique_y) < self.coarse_img_h 51 | assert torch.max(unique_x) < self.coarse_img_w 52 | unique_uv = torch.stack([unique_x, unique_y], dim=-1) 53 | 54 | error_mean = scatter_mean(errors, pinds) 55 | self.weighted_error_map_coarse[unique_uv[:, 1], unique_uv[:, 0]] = error_mean 56 | 57 | def uniform_sample(self, num_samples): 58 | if self.num_uniform_samples + num_samples > self.img_length: 59 | print("reset sampler: {}".format(self.num_uniform_samples)) 60 | self.reset() 61 | sampled_ids = self.uniform_sample_ids[ 62 | self.num_uniform_samples: self.num_uniform_samples+num_samples] 63 | self.num_uniform_samples += num_samples 64 | return sampled_ids 65 | 66 | def weighted_sample(self, num_samples): 67 | """ sampling weighted by the error_maps 68 | """ 69 | weights = self.weighted_error_map_coarse / torch.sum(self.weighted_error_map_coarse) + 1e-5 70 | 71 | coarse_uv_1d = torch.multinomial( 72 | weights.reshape(-1), num_samples=num_samples, replacement=True 73 | ) # [num_samples] 74 | coarse_y = coarse_uv_1d // self.coarse_img_w 75 | coarse_x = coarse_uv_1d % self.coarse_img_w 76 | rand_x = torch.randint(low=0, high=self.patch_size, size=(num_samples,)) 77 | rand_y = torch.randint(low=0, high=self.patch_size, size=(num_samples,)) 78 | 79 | x = coarse_x * self.patch_size + rand_x 80 | y = coarse_y * self.patch_size + rand_y 81 | 82 | uv_1d = y * self.img_w + x 83 | return uv_1d 84 | 85 | # uv = torch.zeros((num_samples, 2)) 86 | 87 | # # y == uv_1d // img_w 88 | # uv[:, 1] = uv_1d // self.img_w 89 | # assert torch.max(uv[:, 1]) < self.img_h 90 | 91 | # uv[:, 0] = uv_1d - uv[:, 1] * self.img_w 92 | # assert torch.max(uv[:, 0]) < self.img_w 93 | 94 | # return uv 95 | -------------------------------------------------------------------------------- /src/utils/common.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imageio 3 | import numpy as np 4 | import os 5 | import os.path as osp 6 | import torch 7 | import time 8 | import skimage 9 | from skimage import transform 10 | 11 | 12 | class Timer: 13 | def __init__(self, names): 14 | self.times = {n: 0 for n in names} 15 | self.t0 
= {n: 0 for n in names} 16 | 17 | def start(self, name): 18 | self.t0[name] = time.time() 19 | 20 | def log(self, name): 21 | self.times[name] += time.time() - self.t0[name] 22 | 23 | 24 | def to_cuda(in_dict): 25 | for k in in_dict: 26 | if isinstance(in_dict[k], torch.Tensor): 27 | in_dict[k] = in_dict[k].to("cuda") 28 | 29 | 30 | def to_cpu(in_dict): 31 | for k in in_dict: 32 | if isinstance(in_dict[k], torch.Tensor): 33 | in_dict[k] = in_dict[k].cpu() 34 | 35 | 36 | def override_weights(model, pretrained_weights, keys): 37 | """ 38 | Args: 39 | model: pytorch nn module 40 | pretrained_weights: OrderedDict of state_dict 41 | keys: a list of keyword. the weights to be overrided if matched 42 | """ 43 | 44 | pretrained_dict = {} 45 | for model_key in model.state_dict().keys(): 46 | if any([(key in model_key) for key in keys]): 47 | if model_key not in pretrained_weights: 48 | print(f"[warning]: {model_key} not in pretrained weight") 49 | continue 50 | pretrained_dict[model_key] = pretrained_weights[model_key] 51 | model.load_state_dict(pretrained_dict, strict=False) 52 | 53 | 54 | def get_file_paths(dir, file_type=None): 55 | names = sorted(os.listdir(dir)) 56 | out = [] 57 | for n in names: 58 | if os.path.isdir(osp.join(dir, n)): 59 | paths = get_file_paths(osp.join(dir, n), file_type) 60 | out.extend(paths) 61 | else: 62 | if file_type is not None: 63 | if n.endswith(file_type): 64 | out.append(osp.join(dir, n)) 65 | else: 66 | out.append(osp.join(dir, n)) 67 | return out 68 | 69 | 70 | def inverse_sigmoid(x): 71 | return np.log(x) - np.log(1-x) 72 | 73 | 74 | def load_rgb(path, downsample_scale=0): 75 | img = imageio.imread(path) 76 | img = skimage.img_as_float32(img) 77 | if downsample_scale > 0: 78 | img = transform.rescale(img, (downsample_scale, downsample_scale, 1)) 79 | # pixel values between [-1,1] 80 | img -= 0.5 81 | img *= 2. 82 | img = img.transpose(2, 0, 1) 83 | return img 84 | 85 | 86 | def load_depth( 87 | path, 88 | downsample_scale, 89 | downsample_mode="dense", 90 | max_depth=None, 91 | add_noise=False 92 | ): 93 | depth = cv2.imread(path, -1) / 1000. 
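    # depth PNGs store depth in millimeters; the division converts to meters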
94 | if downsample_scale > 0: 95 | img_h, img_w = depth.shape 96 | if downsample_mode == "dense": 97 | reduced_w = int(img_w * downsample_scale) 98 | reduced_h = int(img_h * downsample_scale) 99 | depth = cv2.resize( 100 | depth, 101 | dsize=(reduced_w, reduced_h), 102 | interpolation=cv2.INTER_NEAREST 103 | ) 104 | else: 105 | assert downsample_mode == "sparse" 106 | downsample_mask = np.zeros_like(depth) 107 | interval = int(1 / downsample_scale) 108 | downsample_mask[::interval, ::interval] = 1 109 | depth = depth * downsample_mask 110 | mask = depth > 0 111 | if max_depth is not None: 112 | mask *= depth < max_depth 113 | depth = depth * mask 114 | if add_noise: 115 | noise_depth = noise_simulator.simulate(depth) 116 | # noise_depth = add_depth_noise(depth) 117 | noise_depth = noise_depth * mask 118 | return depth, noise_depth, mask 119 | else: 120 | return depth, depth, mask 121 | -------------------------------------------------------------------------------- /src/models/modules.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | MODULES = {} 8 | 9 | 10 | def register(name): 11 | def decorator(cls): 12 | MODULES[name] = cls 13 | return cls 14 | return decorator 15 | 16 | 17 | def get_modules(cfg, **kwargs): 18 | model = MODULES[cfg.name](cfg, **kwargs) 19 | return model 20 | 21 | 22 | def positional_encoding( 23 | tensor, num_encoding_functions=6, include_input=True, log_sampling=True 24 | ) -> torch.Tensor: 25 | r"""Apply positional encoding to the input. 26 | Args: 27 | tensor (torch.Tensor): Input tensor to be positionally encoded. 28 | encoding_size (optional, int): Number of encoding functions used to 29 | compute a positional encoding (default: 6). 30 | include_input (optional, bool): Whether or not to include the input 31 | in the positional encoding (default: True). 32 | Returns: 33 | (torch.Tensor): Positional encoding of the input tensor. 34 | """ 35 | # TESTED 36 | # Trivially, the input tensor is added to the positional encoding. 
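    # output has 2 * num_encoding_functions * d_in channels (sin and cos per frequency band), plus d_in more when include_input=True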
37 | if num_encoding_functions == 0: 38 | return tensor 39 | encoding = [tensor] if include_input else [] 40 | frequency_bands = None 41 | if log_sampling: 42 | frequency_bands = 2.0 ** torch.linspace( 43 | 0.0, 44 | num_encoding_functions - 1, 45 | num_encoding_functions, 46 | dtype=tensor.dtype, 47 | device=tensor.device, 48 | ) 49 | else: 50 | frequency_bands = torch.linspace( 51 | 2.0 ** 0.0, 52 | 2.0 ** (num_encoding_functions - 1), 53 | num_encoding_functions, 54 | dtype=tensor.dtype, 55 | device=tensor.device, 56 | ) 57 | 58 | for freq in frequency_bands: 59 | for func in [torch.sin, torch.cos]: 60 | encoding.append(func(tensor * freq)) 61 | 62 | # Special case, for no positional encoding 63 | if len(encoding) == 1: 64 | return encoding[0] 65 | else: 66 | return torch.cat(encoding, dim=-1) 67 | 68 | 69 | def get_norm_layer(layer_type='inst'): 70 | if layer_type == 'batch': 71 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 72 | elif layer_type == 'batch3d': 73 | norm_layer = functools.partial(nn.BatchNorm3d, affine=True) 74 | elif layer_type == 'inst': 75 | norm_layer = functools.partial( 76 | nn.InstanceNorm2d, affine=False, track_running_stats=False 77 | ) 78 | elif layer_type == 'inst3d': 79 | norm_layer = functools.partial( 80 | nn.InstanceNorm3d, affine=False, track_running_stats=False 81 | ) 82 | elif layer_type == 'none': 83 | norm_layer = None 84 | else: 85 | raise NotImplementedError( 86 | f'normalization layer {layer_type} is not found' 87 | ) 88 | return norm_layer 89 | 90 | 91 | def maxpool(x, dim=-1, keepdim=False): 92 | out, _ = x.max(dim=dim, keepdim=keepdim) 93 | return out 94 | 95 | 96 | def deconvBlock(input_nc, output_nc, bias, norm_layer=None, nl='relu'): 97 | layers = [nn.ConvTranspose3d(input_nc, output_nc, 4, 2, 1, bias=bias)] 98 | 99 | if norm_layer is not None: 100 | layers += [norm_layer(output_nc)] 101 | if nl == 'relu': 102 | layers += [nn.ReLU(True)] 103 | elif nl == 'lrelu': 104 | layers += [nn.LeakyReLU(0.2, inplace=True)] 105 | else: 106 | raise NotImplementedError('NL layer {} is not implemented' % nl) 107 | return nn.Sequential(*layers) 108 | 109 | 110 | if __name__ == "__main__": 111 | import torch 112 | from easydict import EasyDict as edict 113 | nz = 64 114 | config = edict(bias=False, res=64, nz=nz, ngf=32, max_nf=8, norm="batch3d") 115 | model = Deconv3DDecoder(config) 116 | z = torch.zeros((1, 64, 1, 1, 1)) 117 | layers_out, out = model(z) 118 | -------------------------------------------------------------------------------- /src/utils/sample_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def random_face(mesh: torch.Tensor, num_samples: int, distrib=None): 5 | """Return an area weighted random sample of faces and their normals from the mesh. 6 | 7 | Args: 8 | mesh (torch.Tensor): #F, 3, 3 array of vertices 9 | num_samples (int): num of samples to return 10 | distrib: distribution to use. By default, area-weighted distribution is used. 11 | """ 12 | 13 | if distrib is None: 14 | distrib = area_weighted_distribution(mesh) 15 | 16 | normals = per_face_normals(mesh) 17 | 18 | idx = distrib.sample([num_samples]) 19 | 20 | return mesh[idx], normals[idx] 21 | 22 | 23 | def per_face_normals(mesh: torch.Tensor): 24 | """Compute normals per face. 
25 | 26 | Args: 27 | mesh (torch.Tensor): #F, 3, 3 array of vertices 28 | """ 29 | 30 | vec_a = mesh[:, 0] - mesh[:, 1] 31 | vec_b = mesh[:, 1] - mesh[:, 2] 32 | normals = torch.cross(vec_a, vec_b) 33 | return normals 34 | 35 | 36 | def area_weighted_distribution( 37 | mesh: torch.Tensor, normals: torch.Tensor = None 38 | ): 39 | """Construct discrete area weighted distribution over triangle mesh. 40 | 41 | Args: 42 | mesh (torch.Tensor): #F, 3, 3 array of vertices 43 | normals (torch.Tensor): normals (if precomputed) 44 | eps (float): epsilon 45 | """ 46 | 47 | if normals is None: 48 | normals = per_face_normals(mesh) 49 | areas = torch.norm(normals, p=2, dim=1) * 0.5 50 | areas /= torch.sum(areas) + 1e-10 51 | 52 | # Discrete PDF over triangles 53 | return torch.distributions.Categorical(areas.view(-1)) 54 | 55 | 56 | def sample_near_surface(mesh: torch.Tensor, num_samples: int, distrib=None): 57 | """Sample points near the mesh surface. 58 | 59 | Args: 60 | mesh (torch.Tensor): triangle mesh 61 | num_samples (int): number of surface samples 62 | distrib: distribution to use. By default, area-weighted distribution is used 63 | """ 64 | if distrib is None: 65 | distrib = area_weighted_distribution(mesh) 66 | samples = sample_surface(mesh, num_samples, distrib)[0] 67 | samples += torch.randn_like(samples) * 0.01 68 | return samples 69 | 70 | 71 | def sample_uniform(num_samples: int): 72 | """Sample uniformly in [-1,1] bounding volume. 73 | 74 | Args: 75 | num_samples(int) : number of points to sample 76 | """ 77 | return torch.rand(num_samples, 3) * 2.0 - 1.0 78 | 79 | 80 | def sample_surface( 81 | mesh: torch.Tensor, 82 | num_samples: int, 83 | distrib=None, 84 | ): 85 | """Sample points and their normals on mesh surface. 86 | 87 | Args: 88 | mesh (torch.Tensor): triangle mesh 89 | num_samples (int): number of surface samples 90 | distrib: distribution to use. By default, area-weighted distribution is used 91 | """ 92 | if distrib is None: 93 | distrib = area_weighted_distribution(mesh) 94 | 95 | # Select faces & sample their surface 96 | f, normals = random_face(mesh, num_samples, distrib) 97 | 98 | u = torch.sqrt(torch.rand(num_samples)).to(mesh.device).unsqueeze(-1) 99 | v = torch.rand(num_samples).to(mesh.device).unsqueeze(-1) 100 | 101 | samples = (1 - u) * f[:, 0, :] + (u * (1 - v)) * f[:, 1, :] + u * v * f[:, 2, :] 102 | 103 | return samples, normals 104 | 105 | 106 | def point_sample(mesh: torch.Tensor, techniques: list, num_samples: int): 107 | """Sample points from a mesh. 
108 | 109 | Args: 110 | mesh (torch.Tensor): #F, 3, 3 array of vertices 111 | techniques (list[str]): list of techniques to sample with 112 | num_samples (int): points to sample per technique 113 | """ 114 | if 'trace' in techniques or 'near' in techniques: 115 | # Precompute face distribution 116 | distrib = area_weighted_distribution(mesh) 117 | 118 | samples = [] 119 | for technique in techniques: 120 | if technique == 'trace': 121 | samples.append(sample_surface(mesh, num_samples, distrib)[0]) 122 | elif technique == 'near': 123 | samples.append(sample_near_surface(mesh, num_samples, distrib)) 124 | elif technique == 'rand': 125 | samples.append(sample_uniform(num_samples).to(mesh.device)) 126 | samples = torch.cat(samples, dim=0) 127 | return samples -------------------------------------------------------------------------------- /src/scripts/generate_fusion_data_arkit.py: -------------------------------------------------------------------------------- 1 | """ 2 | What idr needs: 3 | cameras.npz 4 | scale_mat_{i} 5 | world_mat_{i} 6 | image 7 | mask 8 | """ 9 | import argparse 10 | import os 11 | import sys 12 | from tqdm import tqdm 13 | import open3d as o3d 14 | import trimesh 15 | 16 | import cv2 17 | import numpy as np 18 | 19 | from src.utils.geometry import get_homogeneous, depth2xyz 20 | from src.datasets.arkitscene_dataset import get_association 21 | from src.utils.common import load_depth 22 | 23 | 24 | def make_dir(dir_): 25 | if not os.path.exists(dir_): 26 | os.makedirs(dir_) 27 | 28 | 29 | def read_cam_traj(path, n_imgs): 30 | T_wcs = [] 31 | start_line = 1 32 | end_line = 5 33 | with open(path, "r") as f: 34 | lines = f.read().splitlines() 35 | assert len(lines) / 5 == n_imgs 36 | for i in range(n_imgs): 37 | if "\t" in lines[start_line]: 38 | line = [ 39 | line.split("\t") for line in lines[start_line:end_line] 40 | ] 41 | else: 42 | line = [ 43 | [l for l in line.split(" ") if len(l) > 0] for line in lines[start_line:end_line] 44 | ] 45 | T_wc = np.asarray(line).astype(np.float32) 46 | start_line += 5 47 | end_line += 5 48 | T_wcs.append(T_wc) 49 | return T_wcs 50 | 51 | 52 | ROOT_DIR = "/home/kejie/Datasets_ssd/raw/Training/" 53 | out_base_dir = "/home/kejie/repository/fast_sdf/data/fusion/arkit" 54 | 55 | arg_parser = argparse.ArgumentParser() 56 | arg_parser.add_argument("--seq") 57 | args = arg_parser.parse_args() 58 | seq_names = [args.seq] 59 | 60 | for name in tqdm(seq_names): 61 | 62 | gt_mesh_path = os.path.join(ROOT_DIR, name, f"{name}_3dod_mesh.ply") 63 | gt_mesh = trimesh.load(gt_mesh_path) 64 | max_pts = np.max(gt_mesh.vertices, axis=0) 65 | min_pts = np.min(gt_mesh.vertices, axis=0) 66 | center = (min_pts + max_pts) / 2 67 | dimensions = max_pts - min_pts 68 | axis_align_mat = np.eye(4) 69 | axis_align_mat[:3, 3] = -center 70 | base_dir = os.path.join(ROOT_DIR, name) 71 | frames = get_association(base_dir, name) 72 | n_imgs = len(frames) 73 | out_dir = os.path.join(out_base_dir, name) 74 | out_rgb_dir = os.path.join(out_dir, "image") 75 | out_mask_dir = os.path.join(out_dir, "mask") 76 | out_depth_dir = os.path.join(out_dir, "depth") 77 | out_pose_dir = os.path.join(out_dir, "pose") 78 | make_dir(out_dir) 79 | make_dir(out_rgb_dir) 80 | make_dir(out_mask_dir) 81 | make_dir(out_depth_dir) 82 | make_dir(out_pose_dir) 83 | 84 | gt_mesh.vertices = (axis_align_mat @ get_homogeneous(gt_mesh.vertices).T)[:3, :].T 85 | gt_mesh.export(os.path.join(out_dir, "gt_mesh.ply")) 86 | # get the 3D bounding box of the scene 87 | cameras_new = {} 88 | skip = 0 89 | i = 0 90 
| for _ in range(0, n_imgs): 91 | if not (os.path.exists(frames[i]['rgb_path']) and os.path.exists(frames[i]['depth_path'])): 92 | skip += 1 93 | continue 94 | 95 | rgb = cv2.imread(frames[i]['rgb_path'], -1) 96 | depth_map, _, _ = load_depth(frames[i]['depth_path'], 1, max_depth=1000.) 97 | mask = cv2.imread(frames[i]['confidence_path'], -1) 98 | depth_map = depth_map * (mask >= 2) 99 | ind_y, ind_x = np.nonzero(depth_map != 0) 100 | mask = (depth_map > 0).astype(np.float32) 101 | img_h, img_w = mask.shape 102 | T_wc = np.linalg.inv(frames[i]['T_cw']) 103 | T_wc = axis_align_mat @ T_wc 104 | out_rgb_path = os.path.join(out_rgb_dir, f"{i}.jpg") 105 | cv2.imwrite(out_rgb_path, rgb) 106 | out_mask_path = os.path.join(out_mask_dir, f"{i}.png") 107 | cv2.imwrite(out_mask_path, mask.astype(np.uint8)*255) 108 | out_depth_path = os.path.join(out_depth_dir, f"{i}.png") 109 | cv2.imwrite(out_depth_path, (depth_map * 1000).astype(np.uint16)) 110 | 111 | intr_mat = np.eye(4) 112 | intr_mat[:3, :3] = frames[i]['intr_mat'] 113 | cameras_new['intr_mat_%d'%i] = intr_mat 114 | cameras_new['T_wc_%d'%i] = T_wc 115 | intr_path = os.path.join(out_pose_dir, f"intr_mat_{i}.txt") 116 | with open(intr_path, "w") as f: 117 | f.write(" ".join([str(t) for t in intr_mat.reshape(-1)])) 118 | extr_path = os.path.join(out_pose_dir, f"T_wc_{i}.txt") 119 | with open(extr_path, "w") as f: 120 | f.write(" ".join([str(t) for t in T_wc.reshape(-1)])) 121 | i += 1 122 | print(f"skip {skip} images") 123 | cameras_new['dimensions'] = dimensions 124 | np.savez('{0}/{1}.npz'.format(out_dir, "cameras"), **cameras_new) 125 | dimension_path = os.path.join(out_pose_dir, "dimensions.txt") 126 | with open(dimension_path, "w") as f: 127 | f.write(" ".join([str(t) for t in dimensions.reshape(-1)])) 128 | 129 | -------------------------------------------------------------------------------- /src/scripts/generate_fusion_data_scene3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | What idr needs: 3 | cameras.npz 4 | scale_mat_{i} 5 | world_mat_{i} 6 | image 7 | mask 8 | """ 9 | import os 10 | import sys 11 | from tqdm import tqdm 12 | import open3d as o3d 13 | import trimesh 14 | 15 | import cv2 16 | import numpy as np 17 | 18 | from src.datasets.scenenet import SceneNet 19 | from src.utils.geometry import get_homogeneous, depth2xyz 20 | 21 | 22 | def make_dir(dir_): 23 | if not os.path.exists(dir_): 24 | os.makedirs(dir_) 25 | 26 | 27 | def read_cam_traj(path, n_imgs): 28 | T_wcs = [] 29 | start_line = 1 30 | end_line = 5 31 | with open(path, "r") as f: 32 | lines = f.read().splitlines() 33 | assert len(lines) / 5 == n_imgs 34 | for i in range(n_imgs): 35 | if "\t" in lines[start_line]: 36 | line = [ 37 | line.split("\t") for line in lines[start_line:end_line] 38 | ] 39 | else: 40 | line = [ 41 | [l for l in line.split(" ") if len(l) > 0] for line in lines[start_line:end_line] 42 | ] 43 | T_wc = np.asarray(line).astype(np.float32) 44 | start_line += 5 45 | end_line += 5 46 | T_wcs.append(T_wc) 47 | return T_wcs 48 | 49 | 50 | ROOT_DIR = "/home/kejie/repository/fast_sdf/data/scene3d/" 51 | out_base_dir = "/home/kejie/repository/fast_sdf/data/fusion/scene3d" 52 | intr_mat = np.eye(3) 53 | intr_mat[0, 0] = 525. 54 | intr_mat[0, 2] = 319.5 55 | intr_mat[1, 1] = 525. 
56 | intr_mat[1, 2] = 239.5 57 | 58 | seq_names = ["lounge", "stonewall", "copyroom", "cactusgarden", "burghers"] 59 | for name in tqdm(seq_names): 60 | gt_mesh_path = os.path.join(ROOT_DIR, name, f"{name}.ply") 61 | gt_mesh = trimesh.load(gt_mesh_path) 62 | max_pts = np.max(gt_mesh.vertices, axis=0) 63 | min_pts = np.min(gt_mesh.vertices, axis=0) 64 | center = (min_pts + max_pts) / 2 65 | dimensions = max_pts - min_pts 66 | axis_align_mat = np.eye(4) 67 | axis_align_mat[:3, 3] = -center 68 | 69 | 70 | in_rgb_dir = os.path.join(ROOT_DIR, name, f"{name}_png", "color") 71 | in_depth_dir = os.path.join(ROOT_DIR, name, f"{name}_png", "depth") 72 | in_cam_traj_path = os.path.join(ROOT_DIR, name, f"{name}_trajectory.log") 73 | n_imgs = len(os.listdir(in_rgb_dir)) 74 | T_wcs = read_cam_traj(in_cam_traj_path, n_imgs) 75 | 76 | out_dir = os.path.join(out_base_dir, name) 77 | out_rgb_dir = os.path.join(out_dir, "image") 78 | out_mask_dir = os.path.join(out_dir, "mask") 79 | out_depth_dir = os.path.join(out_dir, "depth") 80 | out_pose_dir = os.path.join(out_dir, "pose") 81 | make_dir(out_dir) 82 | make_dir(out_rgb_dir) 83 | make_dir(out_mask_dir) 84 | make_dir(out_depth_dir) 85 | make_dir(out_pose_dir) 86 | 87 | gt_mesh.vertices = (axis_align_mat @ get_homogeneous(gt_mesh.vertices).T)[:3, :].T 88 | gt_mesh.export(os.path.join(out_dir, "gt_mesh.ply")) 89 | # get the 3D bounding box of the scene 90 | cameras_new = {} 91 | for i in range(0, n_imgs): 92 | rgb = cv2.imread( 93 | os.path.join(in_rgb_dir, "{:06d}.png".format(i+1)), 94 | -1 95 | ) 96 | depth_map = cv2.imread( 97 | os.path.join(in_depth_dir, "{:06d}.png".format(i+1)), 98 | -1 99 | ) / 1000. 100 | ind_y, ind_x = np.nonzero(depth_map != 0) 101 | mask = (depth_map > 0).astype(np.float32) 102 | img_h, img_w = mask.shape 103 | pts = depth2xyz(depth_map, intr_mat)[ind_y, ind_x, :] 104 | T_wc = T_wcs[i] 105 | T_wc = axis_align_mat @ T_wc 106 | out_rgb_path = os.path.join(out_rgb_dir, f"{i}.jpg") 107 | cv2.imwrite(out_rgb_path, rgb[:, :, ::-1]) 108 | out_mask_path = os.path.join(out_mask_dir, f"{i}.png") 109 | cv2.imwrite(out_mask_path, mask.astype(np.uint8)*255) 110 | out_depth_path = os.path.join(out_depth_dir, f"{i}.png") 111 | cv2.imwrite(out_depth_path, (depth_map * 1000).astype(np.uint16)) 112 | 113 | _intr_mat = np.eye(4) 114 | _intr_mat[:3, :3] = intr_mat 115 | cameras_new['intr_mat_%d'%i] = _intr_mat 116 | cameras_new['T_wc_%d'%i] = T_wc 117 | intr_path = os.path.join(out_pose_dir, f"intr_mat_{i}.txt") 118 | with open(intr_path, "w") as f: 119 | f.write(" ".join([str(t) for t in intr_mat.reshape(-1)])) 120 | extr_path = os.path.join(out_pose_dir, f"T_wc_{i}.txt") 121 | with open(extr_path, "w") as f: 122 | f.write(" ".join([str(t) for t in T_wc.reshape(-1)])) 123 | 124 | cameras_new['dimensions'] = dimensions 125 | np.savez('{0}/{1}.npz'.format(out_dir, "cameras"), **cameras_new) 126 | dimension_path = os.path.join(out_pose_dir, "dimensions.txt") 127 | with open(dimension_path, "w") as f: 128 | f.write(" ".join([str(t) for t in dimensions.reshape(-1)])) 129 | 130 | -------------------------------------------------------------------------------- /src/scripts/run_rgbd_intergration.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import open3d as o3d 4 | import numpy as np 5 | import trimesh 6 | from tqdm import tqdm 7 | 8 | from src.utils.common import load_depth, load_rgb 9 | 10 | 11 | def read_pose(path): 12 | with open(path, "r") as f: 13 | line = 
f.read().splitlines()[0].split(" ") 14 | pose = np.asarray([float(t) for t in line]) 15 | nrow = int(np.sqrt(len(pose))) 16 | return pose.reshape(nrow, nrow).astype(np.float32) 17 | 18 | 19 | if __name__ == "__main__": 20 | arg_parser = argparse.ArgumentParser() 21 | arg_parser.add_argument("--scan_id", required=True) 22 | arg_parser.add_argument("--dataset_name", required=True) 23 | arg_parser.add_argument("--skip", type=int, required=True) 24 | arg_parser.add_argument("--voxel", type=float, required=True) 25 | arg_parser.add_argument("--max_depth", type=float, required=True) 26 | arg_parser.add_argument("--shift", type=int, required=True) 27 | args = arg_parser.parse_args() 28 | 29 | root_dir = "/home/kejie/repository/fast_sdf/data/fusion/" 30 | 31 | voxel_size = args.voxel 32 | scale = args.skip 33 | render_path = f"{args.dataset_name}_{args.skip}_{args.shift}" 34 | max_depth = args.max_depth 35 | shift = args.shift 36 | sdf_trunc = min(voxel_size * 5, 0.05) 37 | volume = o3d.pipelines.integration.ScalableTSDFVolume( 38 | voxel_length=voxel_size, 39 | sdf_trunc=sdf_trunc, 40 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8) 41 | seq_dir = os.path.join(root_dir, args.dataset_name, args.scan_id) 42 | img_dir = os.path.join(seq_dir, "image") 43 | depth_dir = os.path.join(seq_dir, "depth") 44 | mask_dir = os.path.join(seq_dir, "mask") 45 | pose_dir = os.path.join(seq_dir, "pose") 46 | 47 | n_imgs = len(os.listdir(img_dir)) 48 | test_imgs = np.arange(shift, n_imgs, scale) 49 | cameras = np.load(os.path.join(seq_dir, "cameras.npz")) 50 | intr_mat = cameras['intr_mat_0'] 51 | dimension = cameras['dimensions'] / 2. 52 | downsample_scale = 1 53 | intr_mat[:2, :3] = intr_mat[:2, :3] * downsample_scale 54 | for img in tqdm(test_imgs): 55 | rgb_path = os.path.join(img_dir, f"{img}.jpg") 56 | rgb = load_rgb(rgb_path, downsample_scale).transpose(1, 2, 0) 57 | rgb = (rgb / 2 + 0.5) * 255. 
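        # load_rgb returns values in [-1, 1]; this maps them back to [0, 255] for Open3D's RGB8 volume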
58 | T_wc_path = os.path.join( 59 | pose_dir, f"T_wc_{img}.txt" 60 | ) 61 | T_wc = read_pose(T_wc_path) 62 | depth_path = os.path.join(depth_dir, f"{img}.png") 63 | depth_map, _, mask = load_depth( 64 | depth_path, downsample_scale, max_depth=max_depth) 65 | img_h, img_w = depth_map.shape 66 | depth_map = depth_map * (depth_map < 5) 67 | color = o3d.geometry.Image(rgb.astype(np.uint8)) 68 | depth = o3d.geometry.Image((depth_map * 1000).astype(np.uint16)) 69 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 70 | color, depth, depth_trunc=max_depth, convert_rgb_to_intensity=False) 71 | intrinsic = o3d.camera.PinholeCameraIntrinsic() 72 | img_h, img_w = depth_map.shape 73 | intrinsic.set_intrinsics( 74 | width=img_w, 75 | height=img_h, 76 | fx=intr_mat[0, 0], 77 | fy=intr_mat[1, 1], 78 | cx=intr_mat[0, 2], 79 | cy=intr_mat[1, 2], 80 | ) 81 | # T_wc = cameras[f"T_wc_{img}"] 82 | T_cw = np.linalg.inv(T_wc) 83 | volume.integrate( 84 | rgbd, 85 | intrinsic, 86 | T_cw 87 | ) 88 | print("Extract a triangle mesh from the volume and visualize it.") 89 | 90 | mesh_o3d = volume.extract_triangle_mesh() 91 | mesh_o3d.compute_vertex_normals() 92 | out_dir = f"./logs/tsdf_fusion/{render_path}/{args.dataset_name}/{args.scan_id}" 93 | if not os.path.exists(out_dir): 94 | os.makedirs(out_dir) 95 | mesh = trimesh.Trimesh( 96 | vertices=np.asarray(mesh_o3d.vertices), 97 | faces=np.asarray(mesh_o3d.triangles), 98 | vertex_normals=np.asarray(mesh_o3d.vertex_normals) 99 | ) 100 | mesh = trimesh.intersections.slice_mesh_plane( 101 | mesh, np.array([1, 0, 0]), np.array([-dimension[0], 0, 0])) 102 | mesh = trimesh.intersections.slice_mesh_plane( 103 | mesh, np.array([-1, 0, 0]), np.array([dimension[0], 0, 0])) 104 | mesh = trimesh.intersections.slice_mesh_plane( 105 | mesh, np.array([0, 1, 0]), np.array([0, -dimension[1], 0])) 106 | mesh = trimesh.intersections.slice_mesh_plane( 107 | mesh, np.array([0, -1, 0]), np.array([0, dimension[1], 0])) 108 | mesh = trimesh.intersections.slice_mesh_plane( 109 | mesh, np.array([0, 0, 1]), np.array([0, 0, -dimension[2]])) 110 | mesh = trimesh.intersections.slice_mesh_plane( 111 | mesh, np.array([0, 0, -1]), np.array([0, 0, dimension[2]])) 112 | mesh.export( 113 | os.path.join(out_dir, f"scene_scale_{scale}_voxel_size_{int(voxel_size*1000)}_max_depth_{int(max_depth*10)}_shift_{int(shift)}.ply") 114 | ) 115 | -------------------------------------------------------------------------------- /src/scripts/generate_fusion_data_scannet.py: -------------------------------------------------------------------------------- 1 | """ 2 | What idr needs: 3 | cameras.npz 4 | scale_mat_{i} 5 | world_mat_{i} 6 | image 7 | mask 8 | """ 9 | import os 10 | import cv2 11 | import numpy as np 12 | from tqdm import tqdm 13 | import trimesh 14 | import sys 15 | 16 | # from src.datasets.scenenet import SceneNet 17 | from src.utils.geometry import get_homogeneous, depth2xyz 18 | import src.utils.scannet_helper as scannet_helper 19 | 20 | 21 | def recenter(vertices): 22 | min_ = np.min(vertices, axis=0) 23 | max_ = np.max(vertices, axis=0) 24 | center = (max_ + min_) / 2. 25 | vertices = vertices - center[None, :] 26 | return vertices 27 | 28 | 29 | def make_dir(dir_): 30 | if not os.path.exists(dir_): 31 | os.makedirs(dir_) 32 | 33 | 34 | ROOT_DIR = "/home/kejie/repository/fast_sdf/data/ScanNet/" 35 | RENDER_PATH = "scene0575_00" 36 | out_base_dir = "/home/kejie/repository/fast_sdf/data/fusion/ScanNet" 37 | 38 | DEPTH_SCALE = 1000. 
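# ScanNet depth images are 16-bit PNGs in millimeters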
39 | SKIP_IMAGES = 1 40 | 41 | out_dir = os.path.join(out_base_dir, RENDER_PATH) 42 | out_rgb_dir = os.path.join(out_dir, "image") 43 | out_mask_dir = os.path.join(out_dir, "mask") 44 | out_depth_dir = os.path.join(out_dir, "depth") 45 | out_pose_dir = os.path.join(out_dir, "pose") 46 | make_dir(out_dir) 47 | make_dir(out_rgb_dir) 48 | make_dir(out_mask_dir) 49 | make_dir(out_depth_dir) 50 | make_dir(out_pose_dir) 51 | 52 | seq_dir = os.path.join(ROOT_DIR, RENDER_PATH, "frames") 53 | img_dir = os.path.join(seq_dir, "color") 54 | depth_dir = os.path.join(seq_dir, "depth") 55 | pose_dir = os.path.join(seq_dir, "pose") 56 | 57 | img_names = [f.split(".")[0] for f in os.listdir(img_dir)] 58 | img_names = sorted(img_names, key=lambda a: int(a)) 59 | n_imgs = len(img_names) 60 | intrinsic_path = os.path.join(seq_dir, "intrinsic", "intrinsic_depth.txt") 61 | min_pts = [] 62 | max_pts = [] 63 | 64 | cameras = {} 65 | axis_align_mat = scannet_helper.read_meta_file( 66 | os.path.join(os.path.join(ROOT_DIR, RENDER_PATH, RENDER_PATH + ".txt")) 67 | ) 68 | gt_mesh = trimesh.load(os.path.join(ROOT_DIR, RENDER_PATH, f"{RENDER_PATH}_vh_clean_2.ply")) 69 | vertices = gt_mesh.vertices 70 | vertices = (axis_align_mat @ get_homogeneous(vertices).T)[:3, :].T 71 | gt_mesh.vertices = vertices 72 | 73 | # get the 3D bounding box of the scene 74 | used_id = 0 75 | for i in range(0, n_imgs, SKIP_IMAGES): 76 | rgb = cv2.imread( 77 | os.path.join(img_dir, img_names[i] + ".jpg"), -1)[:, :, ::-1] 78 | depth = cv2.imread( 79 | os.path.join(depth_dir, img_names[i] + ".png"), -1) / DEPTH_SCALE 80 | mask = (depth > 0).astype(np.float32) 81 | img_h, img_w = mask.shape 82 | y, x = np.nonzero(depth) 83 | valid_pixels = np.stack([x, y], axis=-1) 84 | img_h, img_w = depth.shape 85 | rgb = cv2.resize(rgb, (img_w, img_h)) 86 | 87 | intr_mat = scannet_helper.read_intrinsic(intrinsic_path) 88 | T_cw_old = scannet_helper.read_extrinsic( 89 | os.path.join(pose_dir, img_names[i] + ".txt")) 90 | if np.isnan(T_cw_old).any(): 91 | continue 92 | T_wc = np.linalg.inv(T_cw_old) 93 | T_wc = axis_align_mat @ T_wc 94 | cameras['intr_mat_%d'%used_id] = intr_mat 95 | cameras['T_wc_%d'%used_id] = T_wc 96 | 97 | # pts_c = depth2xyz(depth, intr_mat) 98 | # pts_c = pts_c[valid_pixels[:, 1], valid_pixels[:, 0], :].reshape(-1, 3) 99 | # pts_w = (T_wc @ get_homogeneous(pts_c).T)[:3, :].T 100 | # _min = np.min(pts_w, axis=0) 101 | # _max = np.max(pts_w, axis=0) 102 | # min_pts.append(_min) 103 | # max_pts.append(_max) 104 | 105 | out_rgb_path = os.path.join(out_rgb_dir, f"{used_id}.jpg") 106 | cv2.imwrite(out_rgb_path, rgb[:, :, ::-1]) 107 | out_mask_path = os.path.join(out_mask_dir, f"{used_id}.png") 108 | cv2.imwrite(out_mask_path, mask.astype(np.uint8)*255) 109 | out_depth_path = os.path.join(out_depth_dir, f"{used_id}.png") 110 | cv2.imwrite(out_depth_path, (depth * 1000).astype(np.uint16)) 111 | used_id += 1 112 | n_imgs = used_id 113 | min_pts = np.min(np.stack(gt_mesh.vertices, axis=0), axis=0) 114 | max_pts = np.max(np.stack(gt_mesh.vertices, axis=0), axis=0) 115 | center = (min_pts + max_pts) / 2 116 | dimensions = max_pts - min_pts 117 | axis_align_mat = np.eye(4) 118 | axis_align_mat[:3, 3] = -center 119 | 120 | used_id = 0 121 | cameras_new = {} 122 | for i in range(n_imgs): 123 | T_wc = axis_align_mat @ cameras['T_wc_%d'%used_id] 124 | extr_path = os.path.join(out_pose_dir, f"T_wc_{used_id}.txt") 125 | with open(extr_path, "w") as f: 126 | f.write(" ".join([str(t) for t in T_wc.reshape(-1)])) 127 | cameras_new['T_wc_%d'%used_id] = T_wc 128 
| cameras_new['intr_mat_%d'%used_id] = cameras['intr_mat_%d'%used_id] 129 | intr_path = os.path.join(out_pose_dir, f"intr_mat_{used_id}.txt") 130 | with open(intr_path, "w") as f: 131 | f.write(" ".join([str(t) for t in intr_mat.reshape(-1)])) 132 | used_id += 1 133 | cameras_new['dimensions'] = dimensions 134 | np.savez('{0}/{1}.npz'.format(out_dir, "cameras"), **cameras_new) 135 | 136 | dimension_path = os.path.join(out_pose_dir, "dimensions.txt") 137 | with open(dimension_path, "w") as f: 138 | f.write(" ".join([str(t) for t in dimensions.reshape(-1)])) 139 | -------------------------------------------------------------------------------- /src/utils/pangolin_helper.py: -------------------------------------------------------------------------------- 1 | import pangolin 2 | import OpenGL.GL as gl 3 | import numpy as np 4 | 5 | 6 | def init_panel(): 7 | pangolin.ParseVarsFile('app.cfg') 8 | 9 | pangolin.CreateWindowAndBind('Main', 640, 480) 10 | gl.glEnable(gl.GL_DEPTH_TEST) 11 | 12 | scam = pangolin.OpenGlRenderState( 13 | pangolin.ProjectionMatrix(640, 480, 420, 420, 320, 240, 0.1, 1000), 14 | pangolin.ModelViewLookAt(0, 0.5, -3, 0, 0, 0, pangolin.AxisDirection.AxisY)) 15 | handler3d = pangolin.Handler3D(scam) 16 | 17 | dcam = pangolin.CreateDisplay() 18 | dcam.SetBounds(0.0, 1.0, 180/640., 1.0, -640.0/480.0) 19 | # dcam.SetBounds(pangolin.Attach(0.0), pangolin.Attach(1.0), 20 | # pangolin.Attach.Pix(180), pangolin.Attach(1.0), -640.0/480.0) 21 | 22 | dcam.SetHandler(pangolin.Handler3D(scam)) 23 | 24 | panel = pangolin.CreatePanel('ui') 25 | panel.SetBounds(0.0, 1.0, 0.0, 180/640.) 26 | return scam, dcam 27 | 28 | 29 | def draw_3d_points(display, scene_camera, points, colors, pt_size=5): 30 | """display 3d point cloud 31 | 32 | Args: 33 | display (Pangolin object): display window 34 | scene_camera (Pangolin object): scene camera 35 | points (np.ndarray): [N, 3] 3D point position 36 | colors (np.ndarray): [N, 3] color for each point, range [0, 1] 37 | pt_size (int): point size 38 | Returns: 39 | None 40 | """ 41 | display.Activate(scene_camera) 42 | gl.glPointSize(pt_size) 43 | gl.glColor3f(1.0, 0.0, 0.0) 44 | pangolin.DrawPoints(points, colors) 45 | 46 | 47 | def draw_3d_box(display, scene_camera, rgb_color, t_wo, dimensions, line_width=3, kitti=False, alpha=1): 48 | """display object 3d bounding box 49 | 50 | Args: 51 | display (Pangolin object): display window 52 | scene_camera (Pangolin object): scene camera 53 | rgb_color (np.ndarray): [3, ] box color, range [0, 1] 54 | t_wo (np.ndarray): [4, 4] transform matrix from object to world coordinate 55 | dimensions (np.ndarray): [3] object dimension 56 | line_width: line with of bounding box 57 | kitti: if use kitti format 58 | kitti use a weird format for bbox_3d and translation 59 | the translation is not the object bounding box center, but the (x_mean, y_min, z_mean) 60 | and although the dimension is for h, w, l, it doesn't align with the coordinate order. 61 | h is for y, w is for z, and l is for x. So we need to do two things to align with 62 | representation. (1) switch the order of dimensions, (2) minus t_y so that it is in the 63 | translation is in the object center. 64 | Returns: 65 | None 66 | """ 67 | if kitti: 68 | t_wo = deepcopy(t_wo) 69 | t_wo[1, 3] -= dimensions[1] / 2. 
70 | display.Activate(scene_camera) 71 | gl.glLineWidth(line_width) 72 | gl.glColor4f(rgb_color[0], rgb_color[1], rgb_color[2], alpha) 73 | pangolin.DrawBoxes([t_wo], dimensions[None, :]) 74 | 75 | 76 | def draw_mesh(display, scene_camera, points, faces, normals, alpha_value=0.6): 77 | """display mesh 78 | 79 | Args: 80 | display (Pangolin object): display window 81 | scene_camera (Pangolin object): scene camera 82 | points (np.ndarray): [N, 3] 3D point position 83 | faces (np.ndarray): [NUM_F, 3] mesh faces 84 | normals (np.ndarray): [N, 4] vertex normal 85 | colors (np.ndarray): [N, 3] color for each point, range [0, 1] 86 | Returns: 87 | None 88 | """ 89 | alpha = np.ones((len(normals), 1)) * alpha_value 90 | smoothed_normals = np.concatenate([normals, alpha], axis=1) 91 | display.Activate(scene_camera) 92 | pangolin.DrawMesh(points, faces.astype(np.int32), smoothed_normals) 93 | 94 | 95 | def draw_line(display, scene_camera, points, colors, line_width=5, alpha=1): 96 | """display line in 3D display 97 | 98 | Args: 99 | display (Pangolin object): display window 100 | scene_camera (Pangolin object): scene camera 101 | points (np.ndarray): [2, 3] points positions on the line 102 | colors (np.ndarray): [1, 3] lien color, range [0, 1] 103 | Returns: 104 | None 105 | """ 106 | 107 | display.Activate(scene_camera) 108 | gl.glLineWidth(line_width) 109 | gl.glColor4f(colors[0], colors[1], colors[2], colors[3]) 110 | pangolin.DrawLine(points) 111 | 112 | 113 | def draw_lines(display, scene_camera, lines, colors, line_width=5): 114 | """display line in 3D display 115 | 116 | Args: 117 | display (Pangolin object): display window 118 | scene_camera (Pangolin object): scene camera 119 | lines (np.ndarray): [n_lines, 2, 3] points positions on the line 120 | colors (np.ndarray): [3] lien color, range [0, 1] 121 | Returns: 122 | None 123 | """ 124 | 125 | display.Activate(scene_camera) 126 | for line in lines: 127 | gl.glLineWidth(line_width) 128 | gl.glColor3f(colors[0], colors[1], colors[2]) 129 | pangolin.DrawLine(line) 130 | 131 | 132 | def draw_image(rgb, texture, display): 133 | """display image 134 | 135 | Args: 136 | rgb (np.ndarray): [H, W, 3] rgb image 137 | texture (Pangolin object): see pangolin 138 | display (Pangolin object): display camera 139 | Returns: 140 | None 141 | """ 142 | texture.Upload(rgb, gl.GL_RGB, gl.GL_UNSIGNED_BYTE) 143 | display.Activate() 144 | gl.glColor3f(1.0, 1.0, 1.0) 145 | texture.RenderToViewport() -------------------------------------------------------------------------------- /src/models/fusion/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.measure import marching_cubes_lewiner 3 | from tqdm import tqdm 4 | import trimesh 5 | import torch 6 | 7 | 8 | 9 | def decode_feature_grid( 10 | nerf, 11 | volume, 12 | weight_mask, 13 | num_hits, 14 | sdf_delta, 15 | min_coords, 16 | max_coords, 17 | volume_resolution, 18 | voxel_size, 19 | step_size=0.25, 20 | batch_size=500, 21 | level=0., 22 | path=None 23 | ): 24 | device = volume.device 25 | occupied_voxels = torch.nonzero(num_hits[0][0]).cpu().numpy() 26 | assert step_size <= 1 27 | all_vertices = [] 28 | all_faces = [] 29 | last_face_id = 0 30 | min_sdf = [] 31 | max_sdf = [] 32 | for i in tqdm(range(0, len(occupied_voxels), batch_size)): 33 | origin = occupied_voxels[i:i+batch_size] 34 | n_batches = len(origin) 35 | range_ = np.arange(0, 1+step_size, step_size) - 0.5 36 | spacing = [range_[1] - range_[0]] * 3 37 | voxel_coords = 
np.stack( 38 | np.meshgrid(range_, range_, range_, indexing="ij"), 39 | axis=-1 40 | ) 41 | voxel_coords = np.tile(voxel_coords, (n_batches, 1, 1, 1, 1)) 42 | voxel_coords += origin[:, None, None, None, :] 43 | voxel_coords = torch.from_numpy( 44 | voxel_coords).float().to(device) 45 | voxel_pts = voxel_coords * voxel_size + min_coords 46 | H, W, D = voxel_pts.shape[1:4] 47 | voxel_pts = voxel_pts.reshape(1, n_batches, -1, 3) 48 | dirs = torch.zeros_like(voxel_pts) 49 | pts_and_dirs = torch.cat([voxel_pts, dirs], dim=-1) 50 | out, _ = nerf( 51 | pts_and_dirs, 52 | volume, 53 | weight_mask, 54 | sdf_delta, 55 | voxel_size, 56 | volume_resolution, 57 | min_coords, 58 | max_coords, 59 | active_voxels=None, 60 | ) 61 | sdf = out[0, :, :, -1].reshape(n_batches, H, W, D) 62 | sdf = sdf.detach().cpu().numpy() 63 | min_sdf.append(np.min(sdf)) 64 | max_sdf.append(np.max(sdf)) 65 | for j in range(n_batches): 66 | if np.max(sdf[j]) > level and np.min(sdf[j]) < level: 67 | verts, faces, normals, values = \ 68 | marching_cubes_lewiner( 69 | sdf[j], 70 | level=level, 71 | spacing=spacing 72 | ) 73 | verts += origin[j] - 0.5 74 | all_vertices.append(verts) 75 | all_faces.append(faces + last_face_id) 76 | last_face_id += np.max(faces) + 1 77 | print(np.min(min_sdf)) 78 | print(np.max(max_sdf)) 79 | 80 | if len(all_vertices) == 0: 81 | return None 82 | final_vertices = np.concatenate(all_vertices, axis=0) 83 | final_faces = np.concatenate(all_faces, axis=0) 84 | final_vertices = final_vertices * voxel_size + min_coords.cpu().numpy() 85 | # all_normals = np.concatenate(all_normals, axis=0) 86 | mesh = trimesh.Trimesh( 87 | vertices=final_vertices, 88 | faces=final_faces, 89 | # vertex_normals=all_normals, 90 | process=False 91 | ) 92 | if path is None: 93 | return mesh 94 | else: 95 | mesh.export(path) 96 | 97 | 98 | def get_neighbors(points): 99 | """ 100 | args: voxel_coordinates: [b, n_steps, n_samples, 3] 101 | """ 102 | return torch.stack([ 103 | torch.stack( 104 | [ 105 | torch.floor(points[:, :, :, 0]), 106 | torch.floor(points[:, :, :, 1]), 107 | torch.floor(points[:, :, :, 2]) 108 | ], 109 | dim=-1 110 | ), 111 | torch.stack( 112 | [ 113 | torch.ceil(points[:, :, :, 0]), 114 | torch.floor(points[:, :, :, 1]), 115 | torch.floor(points[:, :, :, 2]) 116 | ], 117 | dim=-1 118 | ), 119 | torch.stack( 120 | [ 121 | torch.floor(points[:, :, :, 0]), 122 | torch.ceil(points[:, :, :, 1]), 123 | torch.floor(points[:, :, :, 2]) 124 | ], 125 | dim=-1 126 | ), 127 | torch.stack( 128 | [ 129 | torch.floor(points[:, :, :, 0]), 130 | torch.floor(points[:, :, :, 1]), 131 | torch.ceil(points[:, :, :, 2]) 132 | ], 133 | dim=-1 134 | ), 135 | torch.stack( 136 | [ 137 | torch.ceil(points[:, :, :, 0]), 138 | torch.ceil(points[:, :, :, 1]), 139 | torch.floor(points[:, :, :, 2]) 140 | ], 141 | dim=-1 142 | ), 143 | torch.stack( 144 | [ 145 | torch.ceil(points[:, :, :, 0]), 146 | torch.floor(points[:, :, :, 1]), 147 | torch.ceil(points[:, :, :, 2]) 148 | ], 149 | dim=-1 150 | ), 151 | torch.stack( 152 | [ 153 | torch.floor(points[:, :, :, 0]), 154 | torch.ceil(points[:, :, :, 1]), 155 | torch.ceil(points[:, :, :, 2]) 156 | ], 157 | dim=-1 158 | ), 159 | torch.stack( 160 | [ 161 | torch.ceil(points[:, :, :, 0]), 162 | torch.ceil(points[:, :, :, 1]), 163 | torch.ceil(points[:, :, :, 2]) 164 | ], 165 | dim=-1 166 | ), 167 | ], dim=1) 168 | -------------------------------------------------------------------------------- /src/scripts/generate_fusion_data_icl_nuim.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | What idr needs: 3 | cameras.npz 4 | scale_mat_{i} 5 | world_mat_{i} 6 | image 7 | mask 8 | """ 9 | import argparse 10 | import os 11 | import open3d as o3d 12 | import sys 13 | from tqdm import tqdm 14 | import trimesh 15 | from scipy.spatial.transform import Rotation 16 | import cv2 17 | import numpy as np 18 | 19 | # from src.datasets.scenenet import SceneNet 20 | from src.utils.geometry import get_homogeneous, depth2xyz 21 | import src.utils.scannet_helper as scannet_helper 22 | 23 | 24 | def make_dir(dir_): 25 | if not os.path.exists(dir_): 26 | os.makedirs(dir_) 27 | 28 | 29 | def read_cam_traj(path, n_imgs): 30 | T_wcs = [] 31 | start_line = 1 32 | end_line = 5 33 | with open(path, "r") as f: 34 | lines = f.read().splitlines() 35 | assert len(lines) / 5 == n_imgs 36 | for i in range(n_imgs): 37 | T_wc = np.asarray([line.split(" ") for line in lines[start_line:end_line]]).astype(np.float32) 38 | start_line += 5 39 | end_line += 5 40 | T_wcs.append(T_wc) 41 | return T_wcs 42 | 43 | 44 | arg_parser = argparse.ArgumentParser() 45 | arg_parser.add_argument("--noise", action="store_true") 46 | arg_parser.add_argument("--scan_id", required=True) 47 | args = arg_parser.parse_args() 48 | 49 | 50 | ROOT_DIR = "/home/kejie/repository/fast_sdf/data/icl_nuim/" 51 | RENDER_PATH = args.scan_id 52 | GT_MESH_PATH = os.path.join(ROOT_DIR, RENDER_PATH[:-1] + ".ply") 53 | # gt_mesh = o3d.io.read_triangle_mesh(GT_MESH_PATH) 54 | gt_mesh = trimesh.load(GT_MESH_PATH) 55 | max_pts = np.max(np.asarray(gt_mesh.vertices), axis=0) 56 | min_pts = np.min(np.asarray(gt_mesh.vertices), axis=0) 57 | center = (min_pts + max_pts) / 2 58 | dimensions = max_pts - min_pts 59 | axis_align_mat = np.eye(4) 60 | axis_align_mat[:3, 3] = -center 61 | 62 | out_base_dir = "/home/kejie/repository/fast_sdf/data/fusion/icl_nuim" 63 | DEPTH_SCALE = 1000. 64 | SKIP_IMAGES = 1 65 | intr_mat = np.eye(3) 66 | intr_mat[0, 0] = 525 67 | intr_mat[0, 2] = 319.5 68 | intr_mat[1, 1] = 525. 
69 | intr_mat[1, 2] = 239.5 70 | 71 | if args.noise: 72 | out_dir = os.path.join(out_base_dir, f"{RENDER_PATH}_noise") 73 | else: 74 | out_dir = os.path.join(out_base_dir, f"{RENDER_PATH}") 75 | 76 | gt_mesh.vertices = (axis_align_mat @ get_homogeneous(np.asarray(gt_mesh.vertices)).T)[:3, :].T 77 | gt_mesh.export(os.path.join(out_dir, "gt_mesh.ply")) 78 | 79 | out_rgb_dir = os.path.join(out_dir, "image") 80 | out_mask_dir = os.path.join(out_dir, "mask") 81 | out_depth_dir = os.path.join(out_dir, "depth") 82 | out_pose_dir = os.path.join(out_dir, "pose") 83 | make_dir(out_dir) 84 | make_dir(out_rgb_dir) 85 | make_dir(out_mask_dir) 86 | make_dir(out_depth_dir) 87 | make_dir(out_pose_dir) 88 | 89 | seq_dir = os.path.join(ROOT_DIR, RENDER_PATH) 90 | img_dir = os.path.join(seq_dir, f"{RENDER_PATH}-color") 91 | if args.noise: 92 | depth_dir = os.path.join(seq_dir, f"{RENDER_PATH}-depth-simulated") 93 | else: 94 | depth_dir = os.path.join(seq_dir, f"{RENDER_PATH}-depth-clean") 95 | 96 | pose_path = os.path.join(seq_dir, "pose.txt") 97 | 98 | img_names = [f.split(".")[0] for f in os.listdir(img_dir)] 99 | img_names = sorted(img_names, key=lambda a: int(a)) 100 | n_imgs = len(img_names) 101 | 102 | T_wcs = read_cam_traj(pose_path, n_imgs) 103 | min_pts = [] 104 | max_pts = [] 105 | 106 | cameras = {} 107 | # get the 3D bounding box of the scene 108 | used_id = 0 109 | 110 | pts_o3d = [] 111 | cameras_new = {} 112 | for i in range(0, n_imgs, SKIP_IMAGES): 113 | rgb = cv2.imread( 114 | os.path.join(img_dir, img_names[i] + ".jpg"), -1)[:, :, ::-1] 115 | depth = cv2.imread( 116 | os.path.join(depth_dir, img_names[i] + ".png"), -1) / DEPTH_SCALE 117 | mask = (depth > 0).astype(np.float32) 118 | img_h, img_w = mask.shape 119 | y, x = np.nonzero(depth) 120 | valid_pixels = np.stack([x, y], axis=-1) 121 | img_h, img_w = depth.shape 122 | rgb = cv2.resize(rgb, (img_w, img_h)) 123 | 124 | T_wc = T_wcs[i] 125 | T_wc = axis_align_mat @ T_wc 126 | cameras['intr_mat_%d'%used_id] = intr_mat 127 | cameras['T_wc_%d'%used_id] = T_wc 128 | 129 | pts_c = depth2xyz(depth, intr_mat) 130 | pts_c = pts_c[valid_pixels[:, 1], valid_pixels[:, 0], :].reshape(-1, 3) 131 | pts_w = (T_wc @ get_homogeneous(pts_c).T)[:3, :].T 132 | _min = np.min(pts_w, axis=0) 133 | _max = np.max(pts_w, axis=0) 134 | min_pts.append(_min) 135 | max_pts.append(_max) 136 | 137 | out_rgb_path = os.path.join(out_rgb_dir, f"{used_id}.jpg") 138 | cv2.imwrite(out_rgb_path, rgb[:, :, ::-1]) 139 | out_mask_path = os.path.join(out_mask_dir, f"{used_id}.png") 140 | cv2.imwrite(out_mask_path, mask.astype(np.uint8)*255) 141 | out_depth_path = os.path.join(out_depth_dir, f"{used_id}.png") 142 | cv2.imwrite(out_depth_path, (depth * 1000).astype(np.uint16)) 143 | 144 | extr_path = os.path.join(out_pose_dir, f"T_wc_{used_id}.txt") 145 | with open(extr_path, "w") as f: 146 | f.write(" ".join([str(t) for t in T_wc.reshape(-1)])) 147 | cameras_new['T_wc_%d'%used_id] = T_wc 148 | cameras_new['intr_mat_%d'%used_id] = cameras['intr_mat_%d'%used_id] 149 | intr_path = os.path.join(out_pose_dir, f"intr_mat_{used_id}.txt") 150 | with open(intr_path, "w") as f: 151 | f.write(" ".join([str(t) for t in intr_mat.reshape(-1)])) 152 | used_id += 1 153 | 154 | cameras_new['dimensions'] = dimensions 155 | np.savez('{0}/{1}.npz'.format(out_dir, "cameras"), **cameras_new) 156 | 157 | dimension_path = os.path.join(out_pose_dir, "dimensions.txt") 158 | with open(dimension_path, "w") as f: 159 | f.write(" ".join([str(t) for t in dimensions.reshape(-1)])) 
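# A minimal sketch of the pose-file convention shared by the generate_fusion_data_*
# scripts and read_pose() in run_rgbd_intergration.py: each 4x4 T_wc is written as
# 16 space-separated floats and read back by reshaping. The file name below is only
# an illustrative example.
import numpy as np
T_wc_demo = np.eye(4, dtype=np.float32)
with open("T_wc_demo.txt", "w") as f:
    f.write(" ".join(str(t) for t in T_wc_demo.reshape(-1)))
with open("T_wc_demo.txt", "r") as f:
    vals = np.asarray([float(t) for t in f.read().split()])
nrow = int(np.sqrt(len(vals)))
assert np.allclose(T_wc_demo, vals.reshape(nrow, nrow))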
-------------------------------------------------------------------------------- /src/utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import matplotlib.cm 4 | from src.utils.motion_utils import Isometry 5 | 6 | 7 | def pointcloud(pc, color: np.ndarray = None, normal: np.ndarray = None): 8 | if isinstance(pc, o3d.geometry.PointCloud): 9 | if pc.has_normals() and normal is None: 10 | normal = np.asarray(pc.normals) 11 | if pc.has_colors() and color is None: 12 | color = np.asarray(pc.colors) 13 | pc = np.asarray(pc.points) 14 | 15 | assert pc.shape[1] == 3 and len(pc.shape) == 2, f"Point cloud is of size {pc.shape} and cannot be displayed!" 16 | point_cloud = o3d.geometry.PointCloud() 17 | point_cloud.points = o3d.utility.Vector3dVector(pc) 18 | if color is not None: 19 | assert color.shape[0] == pc.shape[0], f"Point and color must have same size {color.shape[0]}, {pc.shape[0]}" 20 | point_cloud.colors = o3d.utility.Vector3dVector(color) 21 | if normal is not None: 22 | point_cloud.normals = o3d.utility.Vector3dVector(normal) 23 | 24 | return point_cloud 25 | 26 | 27 | def frame(transform: Isometry = Isometry(), size=1.0): 28 | frame_obj = o3d.geometry.TriangleMesh.create_coordinate_frame(size=size) 29 | frame_obj.transform(transform.matrix) 30 | return frame_obj 31 | 32 | 33 | def merged_linesets(lineset_list: list): 34 | merged_points = [] 35 | merged_inds = [] 36 | merged_colors = [] 37 | point_acc_ind = 0 38 | for ls in lineset_list: 39 | merged_points.append(np.asarray(ls.points)) 40 | merged_inds.append(np.asarray(ls.lines) + point_acc_ind) 41 | if ls.has_colors(): 42 | merged_colors.append(np.asarray(ls.colors)) 43 | else: 44 | merged_colors.append(np.zeros((len(ls.lines), 3))) 45 | point_acc_ind += len(ls.points) 46 | 47 | geom = o3d.geometry.LineSet( 48 | points=o3d.utility.Vector3dVector(np.vstack(merged_points)), 49 | lines=o3d.utility.Vector2iVector(np.vstack(merged_inds)) 50 | ) 51 | geom.colors = o3d.utility.Vector3dVector(np.vstack(merged_colors)) 52 | return geom 53 | 54 | 55 | def trajectory(traj1: list, traj2: list = None, ucid: int = -1): 56 | if len(traj1) > 0 and isinstance(traj1[0], Isometry): 57 | traj1 = [t.t for t in traj1] 58 | if traj2 and isinstance(traj2[0], Isometry): 59 | traj2 = [t.t for t in traj2] 60 | 61 | traj1_lineset = o3d.geometry.LineSet(points=o3d.utility.Vector3dVector(np.asarray(traj1)), 62 | lines=o3d.utility.Vector2iVector(np.vstack((np.arange(0, len(traj1) - 1), 63 | np.arange(1, len(traj1)))).T)) 64 | if ucid != -1: 65 | color_map = np.asarray(matplotlib.cm.get_cmap('tab10').colors) 66 | traj1_lineset.paint_uniform_color(color_map[ucid % 10]) 67 | 68 | if traj2 is not None: 69 | assert len(traj1) == len(traj2) 70 | traj2_lineset = o3d.geometry.LineSet(points=o3d.utility.Vector3dVector(np.asarray(traj2)), 71 | lines=o3d.utility.Vector2iVector(np.vstack((np.arange(0, len(traj2) - 1), 72 | np.arange(1, len(traj2)))).T)) 73 | traj_diff = o3d.geometry.LineSet( 74 | points=o3d.utility.Vector3dVector(np.vstack((np.asarray(traj1), np.asarray(traj2)))), 75 | lines=o3d.utility.Vector2iVector(np.arange(2 * len(traj1)).reshape((2, len(traj1))).T)) 76 | traj_diff.colors = o3d.utility.Vector3dVector(np.array([[1.0, 0.0, 0.0]]).repeat(len(traj_diff.lines), axis=0)) 77 | 78 | traj1_lineset = merged_linesets([traj1_lineset, traj2_lineset, traj_diff]) 79 | return traj1_lineset 80 | 81 | 82 | def camera(T_wc, wh_ratio: float = 4.0 / 3.0, scale: float = 1.0, 
fovx: float = 90.0, 83 | color_id: int = -1): 84 | pw = np.tan(np.deg2rad(fovx / 2.)) * scale 85 | ph = pw / wh_ratio 86 | all_points = np.asarray([ 87 | [0.0, 0.0, 0.0], 88 | [pw, ph, scale], 89 | [pw, -ph, scale], 90 | [-pw, ph, scale], 91 | [-pw, -ph, scale], 92 | ]) 93 | line_indices = np.asarray([ 94 | [0, 1], [0, 2], [0, 3], [0, 4], 95 | [1, 2], [1, 3], [3, 4], [2, 4] 96 | ]) 97 | geom = o3d.geometry.LineSet( 98 | points=o3d.utility.Vector3dVector(all_points), 99 | lines=o3d.utility.Vector2iVector(line_indices)) 100 | 101 | if color_id == -1: 102 | my_color = np.zeros((3,)) 103 | else: 104 | my_color = np.asarray(matplotlib.cm.get_cmap('tab10').colors)[color_id, :3] 105 | geom.colors = o3d.utility.Vector3dVector(np.repeat(np.expand_dims(my_color, 0), line_indices.shape[0], 0)) 106 | 107 | geom.transform(T_wc) 108 | return geom 109 | 110 | 111 | def wireframe_bbox(extent_min=None, extent_max=None, color_id=-1): 112 | if extent_min is None: 113 | extent_min = [0.0, 0.0, 0.0] 114 | if extent_max is None: 115 | extent_max = [1.0, 1.0, 1.0] 116 | 117 | if color_id == -1: 118 | my_color = np.zeros((3,)) 119 | else: 120 | my_color = np.asarray(matplotlib.cm.get_cmap('tab10').colors)[color_id, :3] 121 | 122 | all_points = np.asarray([ 123 | [extent_min[0], extent_min[1], extent_min[2]], 124 | [extent_min[0], extent_min[1], extent_max[2]], 125 | [extent_min[0], extent_max[1], extent_min[2]], 126 | [extent_min[0], extent_max[1], extent_max[2]], 127 | [extent_max[0], extent_min[1], extent_min[2]], 128 | [extent_max[0], extent_min[1], extent_max[2]], 129 | [extent_max[0], extent_max[1], extent_min[2]], 130 | [extent_max[0], extent_max[1], extent_max[2]], 131 | ]) 132 | line_indices = np.asarray([ 133 | [0, 1], [2, 3], [4, 5], [6, 7], 134 | [0, 4], [1, 5], [2, 6], [3, 7], 135 | [0, 2], [4, 6], [1, 3], [5, 7] 136 | ]) 137 | geom = o3d.geometry.LineSet( 138 | points=o3d.utility.Vector3dVector(all_points), 139 | lines=o3d.utility.Vector2iVector(line_indices)) 140 | geom.colors = o3d.utility.Vector3dVector(np.repeat(np.expand_dims(my_color, 0), line_indices.shape[0], 0)) 141 | 142 | return geom 143 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

BNV-Fusion: Dense 3D Reconstruction using Bi-level Neural Volume Fusion

4 |

5 | Kejie Li 6 | ~ 7 | Yansong Tang 8 | ~ 9 | Victor Adrian Prisacariu 10 | ~ 11 | Philip H.S. Torr 12 |

13 |

14 | 15 | ## BNV-Fusion ([Video](https://www.youtube.com/watch?v=ptx5vtQ9SvM) | [Paper](https://arxiv.org/pdf/2204.01139.pdf)) 16 | 17 | This repo implements the CVPR 2022 paper [Bi-level Neural Volume Fusion (BNV-Fusion)](https://arxiv.org/abs/2204.01139). BNV-Fusion leverages recent advances in neural implicit representations and neural rendering for dense 3D reconstruction. The keys to BNV-Fusion are 1) a sparse voxel grid of local shape codes to model surface geometry; 2) a well-designed bi-level fusion mechanism that integrates raw depth observations into the implicit grid efficiently and effectively. As a result, BNV-Fusion can run at a relatively **high frame rate** (2-5 frames per second on a desktop GPU) and reconstruct the 3D environment with **high accuracy**, capturing fine details missed by recent neural-implicit-based methods or traditional TSDF fusion. 18 | 19 | ## Requirements 20 | 21 | Set up the anaconda environment using the following command: 22 | 23 | ` 24 | conda env create -f environment.yml -p CONDA_DIR/envs/bnv_fusion (CONDA_DIR is the folder where anaconda is installed) 25 | ` 26 | 27 | You will need to install [torch-scatter](https://github.com/rusty1s/pytorch_scatter) separately, since conda doesn't seem to handle this package particularly well. 28 | 29 | 30 | Alternatively, you can build a docker image using the Dockerfile provided (work in progress: we can't get Open3D working within the docker image; any help is appreciated!). 31 | 32 | [IMPORTANT] Set up the PYTHONPATH before running the code: 33 | 34 | ` 35 | export PYTHONPATH=$PYTHONPATH:$PWD 36 | ` 37 | 38 | If you don't want to run this command every time you open a new terminal, you can also set up a Bash alias that sets PYTHONPATH and activates the environment in one go: 39 | 40 | ` 41 | alias bnv_fusion="export PYTHONPATH=$PYTHONPATH:PROJECT_DIR;conda activate bnv_fusion" 42 | ` 43 | 44 | PROJECT_DIR is the root directory of this repo. 45 | 46 | 47 | **New: Running with sequences captured by iPhone/iPad** 48 | ------ 49 | We are happy to share that you can run BNV-Fusion reasonably easily on any sequence captured with an iOS device that has a lidar sensor (e.g., iPhone 12/13 Pro, iPad Pro). The instructions are as follows: 50 | 1. Download the [3D scanner app](https://apps.apple.com/us/app/3d-scanner-app/id1419913995) to an iOS device. 51 | 2. Capture a sequence using this app. 52 | 3. After recording, you need to transfer the raw data (e.g., depth images, camera poses) to a desktop with a GPU. To do this, tap "scans" at the bottom left of the app. Select "Share Model" after clicking the "..." button. There are various formats you can use to share the data, but what we need is the raw data, so select "All Data". You can then choose your favorite way of sharing, such as Google Drive. 53 | 4. After you unpack the data on your desktop, run BNV-Fusion using the following command: 54 | ``` 55 | python src/run_e2e.py model=fusion_pointnet_model dataset=fusion_inference_dataset_arkit trainer.checkpoint=$PWD/pretrained/pointnet_tcnn.ckpt 'dataset.scan_id="xxxxx"' dataset.data_dir=yyyyy model.tcnn_config=$PWD/src/models/tcnn_config.json 56 | ``` 57 | You need to specify the scan_id and the directory where you keep the data. You should be able to see the reconstruction produced by BNV-Fusion after this step. Hope you have fun! 58 | 59 | 60 | ## Datasets and pretrained models 61 | We tested BNV-Fusion on three datasets: 3D Scene, ICL-NUIM, and ScanNet.
Please go to the respective dataset repos to download the data. 62 | After downloading the data, run the preprocessing scripts: 63 | ``` 64 | python src/scripts/generate_fusion_data_{DATASET}.py 65 | ``` 66 | 67 | Instead of downloading those datasets, we also provide preprocessed data for one of the sequences in 3D Scene so that you can quickly try out BNV-Fusion. You can download the preprocessed data [here](https://drive.google.com/file/d/1nmdkK-mMpxebAO1MriCD_UbpLwbXYxah/view?usp=sharing). 68 | 69 | You can also run the following commands at the project root directory: 70 | ``` 71 | mkdir -p data/fusion/scene3d 72 | cd data/fusion/scene3d/ 73 | pip install gdown  # if gdown is not installed 74 | gdown https://drive.google.com/uc?id=1nmdkK-mMpxebAO1MriCD_UbpLwbXYxah 75 | unzip lounge.zip && rm lounge.zip 76 | ``` 77 | 78 | 79 | ## Running 80 | 86 | 87 | To process a sequence, use the following command: 88 | ``` 89 | python src/run_e2e.py model=fusion_pointnet_model dataset=fusion_inference_dataset dataset.scan_id="scene3d/lounge" trainer.checkpoint=$PWD/pretrained/pointnet_tcnn.ckpt model.tcnn_config=$PWD/src/models/tcnn_config.json model.mode="demo" 90 | ``` 91 | 92 | ## Evaluation 93 | The results and GT meshes are available here: https://drive.google.com/drive/folders/1gzsOIuCrj7ydX2-XXULQ61KjtITipYb5?usp=sharing 94 | 95 | After downloading the data, you can run the evaluation using ```evaluate_bnvf.py```. 96 | 97 | ## Training the embedding (optional) 98 | Instead of using the pretrained model provided, you can also train the local embedding yourself by running the following command: 99 | ``` 100 | python src/train.py model=fusion_pointnet_model dataset=fusion_pointnet_dataset model.voxel_size=0.01 model.min_pts_in_grid=8 model.train_ray_splits=1000 model.tcnn_config=$PWD/src/models/tcnn_config.json 101 | ``` 102 | 103 | ## FAQ 104 | - **Do I have to have a rough mesh, as requested [here](https://github.com/likojack/bnv_fusion/blob/9178e8c36743d6bf9a7828087553d365f50a6d7f/src/datasets/fusion_inference_dataset.py#L253), when running with my own data?** 105 | 106 | No, we only use the mesh to determine the dimensions of the scene to be reconstructed. You can manually set the boundary if you know the dimensions. 107 | 108 | - **How to set an appropriate voxel size?** 109 | 110 | The reconstruction quality strongly depends on the voxel size. If the voxel size is too small, there won't be enough points within each local region for the local embedding. If it is too large, the system fails to recover fine details. Therefore, we select the voxel size based on the number of 3D points in a voxel. You will get a statistic on the 3D points used in the local embedding after running the system (see [here](https://github.com/likojack/bnv_fusion/blob/9178e8c36743d6bf9a7828087553d365f50a6d7f/src/models/sparse_volume.py#L515)). Empirically, we found that a voxel size satisfying the following requirements gives better results: 1) the ```min``` is larger than 4, and 2) the ```mean``` is ideally larger than 8. A rough, stand-alone way to estimate these counts is sketched below.
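To get a feel for these numbers before running the full system, the following sketch bins a back-projected point cloud into voxels of a candidate size and reports the per-voxel point counts. It is only a rough approximation of the statistic printed by the system (linked above); the function name is ours, not part of the codebase:

```python
# Rough, stand-alone approximation of the per-voxel point-count statistic.
import numpy as np


def voxel_point_stats(points_w, voxel_size):
    """points_w: [N, 3] surface points in world coordinates (e.g., back-projected depth)."""
    # Assign each point to a voxel by flooring its coordinates at the candidate size.
    voxel_ids = np.floor(points_w / voxel_size).astype(np.int64)
    # Count how many points fall into each occupied voxel.
    _, counts = np.unique(voxel_ids, axis=0, return_counts=True)
    return {"min": int(counts.min()), "mean": float(counts.mean()), "max": int(counts.max())}


# Aim for min > 4 and a mean ideally > 8, as suggested above, e.g.:
# stats = voxel_point_stats(points_w, voxel_size=0.02)
```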
111 | 112 | ## Citation 113 | If you find our code or paper useful, please cite 114 | ```bibtex 115 | @inproceedings{li2022bnv, 116 | author = {Li, Kejie and Tang, Yansong and Prisacariu, Victor Adrian and Torr, Philip HS}, 117 | title = {BNV-Fusion: Dense 3D Reconstruction using Bi-level Neural Volume Fusion}, 118 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 119 | year = {2022} 120 | } 121 | ``` 122 | -------------------------------------------------------------------------------- /src/datasets/fusion_pointnet_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import pickle 6 | import trimesh 7 | from kornia.geometry.depth import depth_to_normals, depth_to_3d 8 | 9 | from src.datasets import register 10 | import src.utils.geometry as geometry 11 | from src.utils.shapenet_helper import read_pose 12 | 13 | 14 | @register("fusion_pointnet_dataset") 15 | class FusionPointNetDataset(torch.utils.data.Dataset): 16 | def __init__(self, cfg, stage): 17 | super().__init__() 18 | self.n_local_samples = cfg.dataset.n_local_samples 19 | self.voxel_size = cfg.model.voxel_size 20 | self.stage = stage 21 | self.max_depth = cfg.model.ray_tracer.ray_max_dist 22 | self.data_root_dir = os.path.join( 23 | cfg.dataset.data_dir, 24 | cfg.dataset.subdomain, 25 | ) 26 | # cats = ["03001627", "03636649"] 27 | cats = ["03001627_noise", "03636649_noise"] 28 | 29 | seq_dirs = [os.path.join(self.data_root_dir, f) for f in cats] 30 | file_paths = [] 31 | 32 | if stage == "test": 33 | self.data_root_dir = f"/home/kejie/repository/fast_sdf/data/rendering/03636649" 34 | seqs = sorted(os.listdir(self.data_root_dir))[:10] 35 | seqs = ["1a5ebc8575a4e5edcc901650bbbbb0b5"] 36 | for seq in seqs: 37 | file_names = os.listdir(os.path.join(self.data_root_dir, seq)) 38 | file_paths.append( 39 | [os.path.join(self.data_root_dir, seq, f) for f in file_names]) 40 | elif stage == "val": 41 | for seq_dir in seq_dirs: 42 | seqs = sorted(os.listdir(seq_dir))[:10] 43 | for seq in seqs: 44 | file_names = os.listdir(os.path.join(seq_dir, seq)) 45 | file_names = sorted(file_names) 46 | file_paths.append( 47 | [os.path.join(seq_dir, seq, f) for f in file_names]) 48 | else: # stage == "train" 49 | for seq_dir in seq_dirs: 50 | seqs = sorted(os.listdir(seq_dir))[10:] 51 | for seq in seqs: 52 | file_names = os.listdir(os.path.join(seq_dir, seq)) 53 | file_names = sorted(file_names) 54 | file_paths.extend( 55 | [os.path.join(seq_dir, seq, f) for f in file_names]) 56 | self.file_paths = file_paths 57 | 58 | def __len__(self): 59 | return len(self.file_paths) 60 | 61 | def _resize_input_pts(self, pts): 62 | """ pts: [N, 3] 63 | """ 64 | pts = torch.from_numpy(np.asarray(pts)).float() 65 | if len(pts) < self.n_local_samples: 66 | inds = torch.randint(0, len(pts), size=(self.n_local_samples,)) 67 | pts = pts[inds] 68 | pts = pts.numpy() 69 | pts = np.random.permutation(pts)[:self.n_local_samples] 70 | return pts 71 | 72 | def __getitem__(self, idx): 73 | if self.stage == "test": 74 | max_depth = self.max_depth 75 | img_path = self.file_paths[idx][0] 76 | img_name = img_path.split("/")[-1].split(".")[0] 77 | instance_name = img_path.split("/")[-2] 78 | T_wc, intr_mat = read_pose(img_name) 79 | gt_depth = cv2.imread(img_path, -1) / 5000. 
80 | mask = np.logical_and((gt_depth > 0), (gt_depth < max_depth)).reshape(-1) 81 | gt_depth[gt_depth > max_depth] = 0 82 | # compute the gt normal 83 | gt_normal = depth_to_normals( 84 | torch.from_numpy(gt_depth).unsqueeze(0).unsqueeze(0), 85 | torch.from_numpy(intr_mat).unsqueeze(0) 86 | )[0].permute(1, 2, 0).numpy() 87 | 88 | gt_xyz_map = depth_to_3d( 89 | torch.from_numpy(gt_depth).unsqueeze(0).unsqueeze(0), 90 | torch.from_numpy(intr_mat).unsqueeze(0) 91 | )[0].permute(1, 2, 0).numpy() 92 | 93 | gt_pts_c = gt_xyz_map.reshape(-1, 3)[mask] 94 | gt_pts_c[:, 2] *= -1 95 | gt_pts_c[:, 1] *= -1 96 | gt_pts_w = (T_wc @ geometry.get_homogeneous(gt_pts_c).T)[:3, :].T 97 | min_ = np.min(gt_pts_w, axis=0) 98 | max_ = np.max(gt_pts_w, axis=0) 99 | center = (min_ + max_) / 2 100 | gt_pts_w = gt_pts_w - center[None, :] 101 | gt_normal_w = (T_wc[:3, :3] @ gt_normal.reshape(-1, 3).T).T 102 | gt_normal_w = gt_normal_w[mask] 103 | gt_normal_w = gt_normal_w * -1 104 | input_pts = np.concatenate( 105 | [gt_pts_w, gt_normal_w], 106 | axis=-1 107 | ) 108 | min_ = np.min(gt_pts_w, axis=0) - 2 * self.voxel_size 109 | max_ = np.max(gt_pts_w, axis=0) + 2 * self.voxel_size 110 | n_xyz = np.ceil((max_ - min_) / self.voxel_size).astype(int).tolist() 111 | bound_min = min_ 112 | bound_max = bound_min + self.voxel_size * np.asarray(n_xyz) 113 | # input_pts = [] 114 | # sample_centers = [] 115 | # sample_ids = np.random.permutation( 116 | # np.arange(len(gt_pts_w)))[:500] 117 | # for ind, sample_id in enumerate(sample_ids): 118 | # dist = np.sqrt(np.sum( 119 | # (gt_pts_w[sample_id:sample_id+1] - gt_pts_w) ** 2, 120 | # axis=-1) 121 | # ) 122 | # valid_neighbors = dist < self.voxel_size * 2 123 | # neighbor_ids = np.random.permutation( 124 | # np.nonzero(valid_neighbors)[0] 125 | # )[:128] 126 | # neighbor_pts = gt_pts_w[neighbor_ids] 127 | # neighbor_normals = gt_normal_w[neighbor_ids] 128 | # sample_pts = np.concatenate( 129 | # [neighbor_pts, neighbor_normals], 130 | # axis=-1 131 | # ) 132 | # sample_pts = self._resize_input_pts(sample_pts) 133 | # sample_center = gt_pts_w[sample_id:sample_id+1] 134 | # sample_pts[:, :3] = sample_pts[:, :3] - sample_center 135 | # input_pts.append(sample_pts) 136 | # sample_centers.append(sample_center) 137 | # input_pts = np.stack(input_pts, axis=0) 138 | # sample_centers = np.concatenate(sample_centers, axis=0) 139 | 140 | return { 141 | "scene_id": instance_name, 142 | "input_pts": input_pts, 143 | # "sample_center": sample_centers, 144 | "bound_min": bound_min, 145 | "bound_max": bound_max, 146 | } 147 | elif self.stage == "val": 148 | file_paths = self.file_paths[idx] 149 | # randomly select 500 local patches of a shape 150 | sample_paths = np.random.permutation(file_paths)[:500] 151 | batch_input_pts = [] 152 | batch_center = [] 153 | batch_training_pts = [] 154 | batch_gt = [] 155 | for p in sample_paths: 156 | with open(p, "rb") as f: 157 | data = pickle.load(f) 158 | if len(data['input_pts']) < 16: 159 | continue 160 | input_pts = self._resize_input_pts(data['input_pts']) 161 | batch_input_pts.append(input_pts) 162 | batch_center.append(data['center'][0]) 163 | batch_training_pts.append(np.asarray(data['training_pts'])) 164 | batch_gt.append(np.asarray(data['gt_sdf'])) 165 | batch_input_pts = np.stack(batch_input_pts, axis=0) 166 | batch_center = np.stack(batch_center, axis=0) 167 | batch_training_pts = np.stack(batch_training_pts, axis=0) 168 | batch_gt = np.stack(batch_gt, axis=0) 169 | data = { 170 | "scene_id": p.split("/")[-2], 171 | "input_pts": batch_input_pts, 
172 | 'sample_center': batch_center, 173 | "training_pts": batch_training_pts, 174 | "gt": batch_gt 175 | } 176 | else: 177 | file_path = self.file_paths[idx] 178 | with open(file_path, "rb") as f: 179 | data = pickle.load(f) 180 | input_pts = self._resize_input_pts(data['input_pts']) 181 | # DEBUG: 182 | # import open3d as o3d 183 | # import src.utils.o3d_helper as o3d_helper 184 | # points = data['training_pts'] 185 | # sdfs = data['gt_sdf'] 186 | # surface_pts = data['input_pts'] 187 | # visual_list = [ 188 | # o3d_helper.np2pc(surface_pts[:, :3], surface_pts[:, -3:] / 0.5 + 0.5) 189 | # ] 190 | # color = np.zeros_like(points) 191 | # color[sdfs > 0, 0] = 1 192 | # color[sdfs <= 0, 1] = 1 193 | # visual_list.append(o3d_helper.np2pc(points, color)) 194 | # o3d.visualization.draw_geometries(visual_list) 195 | data = { 196 | "scene_id": file_path.split("/")[-1], 197 | "input_pts": input_pts, 198 | 'sample_center': data['center'][0], 199 | "training_pts": np.asarray(data['training_pts']), 200 | "gt": np.asarray(data['gt_sdf']) 201 | } 202 | 203 | return data 204 | -------------------------------------------------------------------------------- /src/utils/o3d_helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import open3d.core as o3c 4 | import trimesh 5 | import torch 6 | import os 7 | 8 | 9 | from src.utils.geometry import get_homogeneous 10 | 11 | 12 | def align_vector_to_another(a=np.array([0, 0, 1]), b=np.array([1, 0, 0])): 13 | """ 14 | Aligns vector a to vector b with axis angle rotation 15 | """ 16 | if np.array_equal(a, b): 17 | return None, None 18 | if np.sum(b + a) == 0: # if b is possite to a 19 | b += 1e-3 20 | axis_ = np.cross(a, b) 21 | axis_ = axis_ / (np.linalg.norm(axis_)) 22 | angle = np.arccos(np.dot(a, b)) 23 | return axis_, angle 24 | 25 | 26 | def normalized(a, axis=-1, order=2): 27 | """Normalizes a numpy array of points""" 28 | l2 = np.atleast_1d(np.linalg.norm(a, order, axis)) 29 | l2[l2 == 0] = 1 30 | return a / np.expand_dims(l2, axis), l2 31 | 32 | 33 | class LineMesh(object): 34 | def __init__(self, points, lines=None, colors=[0, 1, 0], radius=0.15): 35 | """Creates a line represented as sequence of cylinder triangular meshes 36 | 37 | Arguments: 38 | points {ndarray} -- Numpy array of ponts Nx3. 39 | 40 | Keyword Arguments: 41 | lines {list[list] or None} -- List of point index pairs denoting 42 | line segments. If None, implicit lines from ordered pairwise 43 | points. 
(default: {None}) 44 | colors {list} -- list of colors, or single color of the line 45 | (default: {[0, 1, 0]}) 46 | radius {float} -- radius of cylinder (default: {0.15}) 47 | """ 48 | self.points = np.array(points) 49 | self.lines = np.array(lines) if lines is not None else \ 50 | self.lines_from_ordered_points(self.points) 51 | self.colors = np.array(colors) 52 | self.radius = radius 53 | self.cylinder_segments = [] 54 | 55 | self.create_line_mesh() 56 | 57 | @staticmethod 58 | def lines_from_ordered_points(points): 59 | lines = [[i, i + 1] for i in range(0, points.shape[0] - 1, 1)] 60 | return np.array(lines) 61 | 62 | def create_line_mesh(self): 63 | first_points = self.points[self.lines[:, 0], :] 64 | second_points = self.points[self.lines[:, 1], :] 65 | line_segments = second_points - first_points 66 | line_segments_unit, line_lengths = normalized(line_segments) 67 | 68 | z_axis = np.array([0, 0, 1]) 69 | # Create triangular mesh cylinder segments of line 70 | for i in range(line_segments_unit.shape[0]): 71 | line_segment = line_segments_unit[i, :] 72 | line_length = line_lengths[i] 73 | # get axis angle rotation to allign cylinder with line segment 74 | axis, angle = align_vector_to_another(z_axis, line_segment) 75 | # Get translation vector 76 | translation = first_points[i, :] + line_segment * line_length * 0.5 77 | # create cylinder and apply transformations 78 | cylinder_segment = o3d.geometry.TriangleMesh.create_cylinder( 79 | self.radius, line_length) 80 | cylinder_segment = cylinder_segment.translate( 81 | translation, relative=False) 82 | if axis is not None: 83 | axis_a = axis * angle 84 | cylinder_segment = cylinder_segment.rotate( 85 | R=o3d.geometry.get_rotation_matrix_from_axis_angle(axis_a), 86 | center=True 87 | ) 88 | # color cylinder 89 | color = self.colors if self.colors.ndim == 1 else self.colors[i, :] 90 | cylinder_segment.paint_uniform_color(color) 91 | 92 | self.cylinder_segments.append(cylinder_segment) 93 | 94 | def add_line(self, vis): 95 | """Adds this line to the visualizer""" 96 | for cylinder in self.cylinder_segments: 97 | vis.add_geometry(cylinder) 98 | 99 | def remove_line(self, vis): 100 | """Removes this line from the visualizer""" 101 | for cylinder in self.cylinder_segments: 102 | vis.remove_geometry(cylinder) 103 | 104 | 105 | def lineset_from_pc(point_cloud, colors, orders=None): 106 | """ open3d lineset from numpy point cloud 107 | 108 | Args: 109 | point_cloud ([N, 3] np.ndarray): corner points of a 3D bounding box 110 | colors ([1, 3] np.ndarray): color of the lineset 111 | orders (): reorder the point cloud to build a valid 3D bbox 112 | 113 | Returns: 114 | line_set (open3d.geometry.Lineset) 115 | """ 116 | # vertex order is consistent with get_corner_pts() in Object class 117 | if orders is None: 118 | lines = [ 119 | [0, 1], 120 | [1, 2], 121 | [2, 3], 122 | [3, 0], 123 | [4, 5], 124 | [5, 6], 125 | [6, 7], 126 | [7, 4], 127 | [0, 4], 128 | [1, 5], 129 | [2, 6], 130 | [3, 7], 131 | ] 132 | else: 133 | lines = orders 134 | colors_tmp = np.zeros((len(lines), 3)) 135 | colors_tmp += colors 136 | line_set = o3d.geometry.LineSet( 137 | points=o3d.utility.Vector3dVector(point_cloud), 138 | lines=o3d.utility.Vector2iVector(lines), 139 | ) 140 | line_set.colors = o3d.utility.Vector3dVector(colors_tmp) 141 | return line_set 142 | 143 | 144 | def linemesh_from_pc(point_cloud, colors, orders=None): 145 | if orders is None: 146 | lines = [ 147 | [0, 1], 148 | [1, 2], 149 | [2, 3], 150 | [3, 0], 151 | [4, 5], 152 | [5, 6], 153 | [6, 7], 154 | [7, 
4], 155 | [0, 4], 156 | [1, 5], 157 | [2, 6], 158 | [3, 7], 159 | ] 160 | else: 161 | lines = orders 162 | 163 | colors_tmp = np.zeros((len(lines), 3)) 164 | colors_tmp += colors 165 | 166 | line_mesh = LineMesh(point_cloud, lines, colors_tmp, radius=0.02) 167 | return line_mesh.cylinder_segments 168 | 169 | 170 | def load_scene_mesh(path, trans_mat=None, open_3d=True): 171 | scene_mesh = trimesh.load(path) 172 | if trans_mat is not None: 173 | scene_mesh.vertices = np.dot(get_homogeneous( 174 | scene_mesh.vertices), trans_mat.T)[:, :3] 175 | if open_3d: 176 | scene_mesh_o3d = trimesh2o3d(scene_mesh) 177 | return scene_mesh_o3d 178 | else: 179 | return scene_mesh 180 | 181 | 182 | def trimesh2o3d(mesh): 183 | mesh_o3d = o3d.geometry.TriangleMesh() 184 | mesh_o3d.vertices = o3d.utility.Vector3dVector(mesh.vertices) 185 | mesh_o3d.triangles = o3d.utility.Vector3iVector(mesh.faces) 186 | mesh_o3d.compute_vertex_normals() 187 | if mesh.visual.vertex_colors is not None: 188 | mesh_o3d.vertex_colors = o3d.utility.Vector3dVector( 189 | mesh.visual.vertex_colors[:, :3] / 255. 190 | ) 191 | return mesh_o3d 192 | 193 | 194 | def np2pc(points, colors=None): 195 | """ convert numpy colors point cloud to o3d point cloud 196 | 197 | Args: 198 | points (np.ndarray): [n_pts, 3] 199 | colors (np.ndarray): [n_pts, 3] 200 | Return: 201 | pts_o3d (o3d.geometry.PointCloud) 202 | """ 203 | pts_o3d = o3d.geometry.PointCloud() 204 | pts_o3d.points = o3d.utility.Vector3dVector(points) 205 | if colors is not None: 206 | pts_o3d.colors = o3d.utility.Vector3dVector(colors) 207 | return pts_o3d 208 | 209 | 210 | def mesh2o3d(vertices, faces, normals=None, colors=None): 211 | mesh = trimesh.Trimesh( 212 | vertices=vertices, 213 | faces=faces, 214 | vertex_normals=normals, 215 | vertex_colors=colors 216 | ) 217 | return trimesh2o3d(mesh) 218 | 219 | 220 | def post_process_mesh(mesh, vertex_threshold=0.005, surface_threshold=0.1): 221 | """ merge close vertices and remove small connected components 222 | 223 | Args: 224 | mesh (trimesh.Trimesh): input trimesh 225 | 226 | Returns: 227 | _type_: _description_ 228 | """ 229 | mesh_o3d = trimesh2o3d(mesh) 230 | mesh_o3d.merge_close_vertices(vertex_threshold).remove_degenerate_triangles().remove_duplicated_triangles().remove_duplicated_vertices() 231 | mesh_o3d = mesh_o3d.filter_smooth_simple(number_of_iterations=1) 232 | # component_ids, component_nums, component_surfaces = mesh_o3d.cluster_connected_triangles() 233 | # remove_componenets = np.asarray(component_nums)[np.asarray(component_surfaces) < surface_threshold] 234 | # remove_mask = [c in remove_componenets for c in component_ids] 235 | # mesh_o3d.remove_triangles_by_mask(remove_mask) 236 | # mesh_o3d.remove_unreferenced_vertices() 237 | mesh = trimesh.Trimesh( 238 | vertices=np.asarray(mesh_o3d.vertices), 239 | faces=np.asarray(mesh_o3d.triangles), 240 | ) 241 | return mesh 242 | 243 | class TSDFFusion: 244 | def __init__(self, voxel_size=0.01): 245 | self.volume = o3d.pipelines.integration.ScalableTSDFVolume( 246 | voxel_length=voxel_size, 247 | sdf_trunc=voxel_size*5, 248 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8) 249 | 250 | def integrate(self, depth, color, T_wc, intr_mat): 251 | """integrate new RGBD frame 252 | 253 | Args: 254 | depth (np.ndarray): [h,w] in range[0,255] 255 | color (np.ndarray): [h,w,3] in meters 256 | T_wc (np.ndarray): [4,4] 257 | intr_mat (np.ndarray): [3,3] or [4,4] 258 | 259 | """ 260 | img_h, img_w = depth.shape 261 | color = 
o3d.geometry.Image(color.astype(np.uint8)) 262 | depth = o3d.geometry.Image((depth * 1000).astype(np.uint16)) 263 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 264 | color, depth, depth_trunc=10000000.0, convert_rgb_to_intensity=False) 265 | intrinsic = o3d.camera.PinholeCameraIntrinsic() 266 | intrinsic.set_intrinsics( 267 | width=img_w, 268 | height=img_h, 269 | fx=intr_mat[0, 0], 270 | fy=intr_mat[1, 1], 271 | cx=intr_mat[0, 2], 272 | cy=intr_mat[1, 2], 273 | ) 274 | T_cw = np.linalg.inv(T_wc) 275 | self.volume.integrate(rgbd, intrinsic, T_cw) 276 | 277 | def marching_cube(self, path=None): 278 | print("Extract a triangle mesh from the volume and visualize it.") 279 | mesh_o3d = self.volume.extract_triangle_mesh() 280 | mesh_o3d.compute_vertex_normals() 281 | mesh = trimesh.Trimesh( 282 | vertices=np.asarray(mesh_o3d.vertices), # / dimension, 283 | faces=np.asarray(mesh_o3d.triangles), 284 | vertex_normals=np.asarray(mesh_o3d.vertex_normals) 285 | ) 286 | if path is not None: 287 | dir_ = "/".join(path.split("/")[:-1]) 288 | if not os.path.exists(dir_): 289 | os.mkdir(dir_) 290 | mesh.export(path) 291 | return mesh 292 | 293 | 294 | if __name__ == "__main__": 295 | mesh = trimesh.load("/home/kejie/repository/bnv_fusion/logs/run_e2e/scene0000_00/3999.ply") 296 | mesh = post_process_mesh(mesh) 297 | -------------------------------------------------------------------------------- /src/utils/motion_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pyquaternion import Quaternion 3 | 4 | 5 | def so3_vee(Phi): 6 | if Phi.ndim < 3: 7 | Phi = np.expand_dims(Phi, axis=0) 8 | 9 | if Phi.shape[1:3] != (3, 3): 10 | raise ValueError("Phi must have shape ({},{}) or (N,{},{})".format(3, 3, 3, 3)) 11 | 12 | phi = np.empty([Phi.shape[0], 3]) 13 | phi[:, 0] = Phi[:, 2, 1] 14 | phi[:, 1] = Phi[:, 0, 2] 15 | phi[:, 2] = Phi[:, 1, 0] 16 | return np.squeeze(phi) 17 | 18 | 19 | def so3_wedge(phi): 20 | phi = np.atleast_2d(phi) 21 | if phi.shape[1] != 3: 22 | raise ValueError( 23 | "phi must have shape ({},) or (N,{})".format(3, 3)) 24 | 25 | Phi = np.zeros([phi.shape[0], 3, 3]) 26 | Phi[:, 0, 1] = -phi[:, 2] 27 | Phi[:, 1, 0] = phi[:, 2] 28 | Phi[:, 0, 2] = phi[:, 1] 29 | Phi[:, 2, 0] = -phi[:, 1] 30 | Phi[:, 1, 2] = -phi[:, 0] 31 | Phi[:, 2, 1] = phi[:, 0] 32 | return np.squeeze(Phi) 33 | 34 | 35 | def so3_log(matrix): 36 | cos_angle = 0.5 * np.trace(matrix) - 0.5 37 | cos_angle = np.clip(cos_angle, -1., 1.) 
38 | angle = np.arccos(cos_angle) 39 | if np.isclose(angle, 0.): 40 | return so3_vee(matrix - np.identity(3)) 41 | else: 42 | return so3_vee((0.5 * angle / np.sin(angle)) * (matrix - matrix.T)) 43 | 44 | 45 | def so3_left_jacobian(phi): 46 | angle = np.linalg.norm(phi) 47 | 48 | if np.isclose(angle, 0.): 49 | return np.identity(3) + 0.5 * so3_wedge(phi) 50 | 51 | axis = phi / angle 52 | s = np.sin(angle) 53 | c = np.cos(angle) 54 | 55 | return (s / angle) * np.identity(3) + \ 56 | (1 - s / angle) * np.outer(axis, axis) + \ 57 | ((1 - c) / angle) * so3_wedge(axis) 58 | 59 | 60 | def se3_curlywedge(xi): 61 | xi = np.atleast_2d(xi) 62 | 63 | Psi = np.zeros([xi.shape[0], 6, 6]) 64 | Psi[:, 0:3, 0:3] = so3_wedge(xi[:, 3:6]) 65 | Psi[:, 0:3, 3:6] = so3_wedge(xi[:, 0:3]) 66 | Psi[:, 3:6, 3:6] = Psi[:, 0:3, 0:3] 67 | 68 | return np.squeeze(Psi) 69 | 70 | 71 | def se3_left_jacobian_Q_matrix(xi): 72 | rho = xi[0:3] # translation part 73 | phi = xi[3:6] # rotation part 74 | 75 | rx = so3_wedge(rho) 76 | px = so3_wedge(phi) 77 | 78 | ph = np.linalg.norm(phi) 79 | ph2 = ph * ph 80 | ph3 = ph2 * ph 81 | ph4 = ph3 * ph 82 | ph5 = ph4 * ph 83 | 84 | cph = np.cos(ph) 85 | sph = np.sin(ph) 86 | 87 | m1 = 0.5 88 | m2 = (ph - sph) / ph3 89 | m3 = (0.5 * ph2 + cph - 1.) / ph4 90 | m4 = (ph - 1.5 * sph + 0.5 * ph * cph) / ph5 91 | 92 | t1 = rx 93 | t2 = px.dot(rx) + rx.dot(px) + px.dot(rx).dot(px) 94 | t3 = px.dot(px).dot(rx) + rx.dot(px).dot(px) - 3. * px.dot(rx).dot(px) 95 | t4 = px.dot(rx).dot(px).dot(px) + px.dot(px).dot(rx).dot(px) 96 | 97 | return m1 * t1 + m2 * t2 + m3 * t3 + m4 * t4 98 | 99 | 100 | def se3_left_jacobian(xi): 101 | rho = xi[0:3] # translation part 102 | phi = xi[3:6] # rotation part 103 | 104 | # Near |phi|==0, use first order Taylor expansion 105 | if np.isclose(np.linalg.norm(phi), 0.): 106 | return np.identity(6) + 0.5 * se3_curlywedge(xi) 107 | 108 | so3_jac = so3_left_jacobian(phi) 109 | Q_mat = se3_left_jacobian_Q_matrix(xi) 110 | 111 | jac = np.zeros([6, 6]) 112 | jac[0:3, 0:3] = so3_jac 113 | jac[0:3, 3:6] = Q_mat 114 | jac[3:6, 3:6] = so3_jac 115 | 116 | return jac 117 | 118 | 119 | def se3_inv_left_jacobian(xi): 120 | rho = xi[0:3] # translation part 121 | phi = xi[3:6] # rotation part 122 | 123 | # Near |phi|==0, use first order Taylor expansion 124 | if np.isclose(np.linalg.norm(phi), 0.): 125 | return np.identity(6) - 0.5 * se3_curlywedge(xi) 126 | 127 | so3_inv_jac = so3_inv_left_jacobian(phi) 128 | Q_mat = se3_left_jacobian_Q_matrix(xi) 129 | 130 | jac = np.zeros([6, 6]) 131 | jac[0:3, 0:3] = so3_inv_jac 132 | jac[0:3, 3:6] = -so3_inv_jac.dot(Q_mat).dot(so3_inv_jac) 133 | jac[3:6, 3:6] = so3_inv_jac 134 | 135 | return jac 136 | 137 | 138 | def so3_inv_left_jacobian(phi): 139 | angle = np.linalg.norm(phi) 140 | 141 | if np.isclose(angle, 0.): 142 | return np.identity(3) - 0.5 * so3_wedge(phi) 143 | 144 | axis = phi / angle 145 | half_angle = 0.5 * angle 146 | cot_half_angle = 1. 
/ np.tan(half_angle) 147 | 148 | return half_angle * cot_half_angle * np.identity(3) + \ 149 | (1 - half_angle * cot_half_angle) * np.outer(axis, axis) - \ 150 | half_angle * so3_wedge(axis) 151 | 152 | 153 | def project_orthogonal(rot): 154 | u, s, vh = np.linalg.svd(rot, full_matrices=True, compute_uv=True) 155 | rot = u @ vh 156 | if np.linalg.det(rot) < 0: 157 | u[:, 2] = -u[:, 2] 158 | rot = u @ vh 159 | return rot 160 | 161 | 162 | class Isometry: 163 | GL_POST_MULT = Quaternion(degrees=180.0, axis=[1.0, 0.0, 0.0]) 164 | 165 | def __init__(self, q=None, t=None): 166 | if q is None: 167 | q = Quaternion() 168 | if t is None: 169 | t = np.zeros(3) 170 | if not isinstance(t, np.ndarray): 171 | t = np.asarray(t) 172 | assert t.shape[0] == 3 and t.ndim == 1 173 | self.q = q 174 | self.t = t 175 | 176 | def __repr__(self): 177 | return f"Isometry: t = {self.t}, q = {self.q}" 178 | 179 | @property 180 | def rotation(self): 181 | return Isometry(q=self.q) 182 | 183 | @property 184 | def matrix(self): 185 | mat = self.q.transformation_matrix 186 | mat[0:3, 3] = self.t 187 | return mat 188 | 189 | @staticmethod 190 | def from_matrix(mat, t_component=None, ortho=False): 191 | assert isinstance(mat, np.ndarray) 192 | if t_component is None: 193 | assert mat.shape == (4, 4) 194 | if ortho: 195 | mat[:3, :3] = project_orthogonal(mat[:3, :3]) 196 | return Isometry(q=Quaternion(matrix=mat), t=mat[:3, 3]) 197 | else: 198 | assert mat.shape == (3, 3) 199 | assert t_component.shape == (3,) 200 | if ortho: 201 | mat = project_orthogonal(mat) 202 | return Isometry(q=Quaternion(matrix=mat), t=t_component) 203 | 204 | @staticmethod 205 | def from_twist(xi: np.ndarray): 206 | rho = xi[:3] 207 | phi = xi[3:6] 208 | iso = Isometry.from_so3_exp(phi) 209 | iso.t = so3_left_jacobian(phi) @ rho 210 | return iso 211 | 212 | @staticmethod 213 | def from_so3_exp(phi: np.ndarray): 214 | angle = np.linalg.norm(phi) 215 | 216 | # Near phi==0, use first order Taylor expansion 217 | if np.isclose(angle, 0.): 218 | return Isometry(q=Quaternion(matrix=np.identity(3) + so3_wedge(phi))) 219 | 220 | axis = phi / angle 221 | s = np.sin(angle) 222 | c = np.cos(angle) 223 | 224 | rot_mat = (c * np.identity(3) + 225 | (1 - c) * np.outer(axis, axis) + 226 | s * so3_wedge(axis)) 227 | return Isometry(q=Quaternion(matrix=rot_mat)) 228 | 229 | @property 230 | def continuous_repr(self): 231 | rot = self.q.rotation_matrix[:, 0:2].T.flatten() # (6,) 232 | return np.concatenate([rot, self.t]) # (9,) 233 | 234 | @staticmethod 235 | def from_continuous_repr(rep, gs=True): 236 | if isinstance(rep, list): 237 | rep = np.asarray(rep) 238 | assert isinstance(rep, np.ndarray) 239 | assert rep.shape == (9,) 240 | # For rotation, use Gram-Schmidt orthogonalization 241 | col1 = rep[0:3] 242 | col2 = rep[3:6] 243 | if gs: 244 | col1 /= np.linalg.norm(col1) 245 | col2 = col2 - np.dot(col1, col2) * col1 246 | col2 /= np.linalg.norm(col2) 247 | col3 = np.cross(col1, col2) 248 | return Isometry(q=Quaternion(matrix=np.column_stack([col1, col2, col3])), t=rep[6:9]) 249 | 250 | @property 251 | def full_repr(self): 252 | rot = self.q.rotation_matrix.T.flatten() 253 | return np.concatenate([rot, self.t]) 254 | 255 | @staticmethod 256 | def from_full_repr(rep, ortho=False): 257 | assert isinstance(rep, np.ndarray) 258 | assert rep.shape == (12,) 259 | rot = rep[0:9].reshape(3, 3).T 260 | if ortho: 261 | rot = project_orthogonal(rot) 262 | return Isometry(q=Quaternion(matrix=rot), t=rep[9:12]) 263 | 264 | def torch_matrices(self, device): 265 | import torch 
266 | return torch.from_numpy(self.q.rotation_matrix).to(device).float(), \ 267 | torch.from_numpy(self.t).to(device).float() 268 | 269 | @staticmethod 270 | def random(): 271 | return Isometry(q=Quaternion.random(), t=np.random.random((3,))) 272 | 273 | def inv(self): 274 | qinv = self.q.inverse 275 | return Isometry(q=qinv, t=-(qinv.rotate(self.t))) 276 | 277 | def dot(self, right): 278 | return Isometry(q=(self.q * right.q), t=(self.q.rotate(right.t) + self.t)) 279 | 280 | def to_gl_camera(self): 281 | return Isometry(q=(self.q * self.GL_POST_MULT), t=self.t) 282 | 283 | @staticmethod 284 | def look_at(source: np.ndarray, target: np.ndarray, up: np.ndarray = None): 285 | z_dir = target - source 286 | z_dir /= np.linalg.norm(z_dir) 287 | if up is None: 288 | up = np.asarray([0.0, 1.0, 0.0]) 289 | if np.linalg.norm(np.cross(z_dir, up)) < 1e-6: 290 | up = np.asarray([1.0, 0.0, 0.0]) 291 | else: 292 | up /= np.linalg.norm(up) 293 | x_dir = np.cross(z_dir, up) 294 | x_dir /= np.linalg.norm(x_dir) 295 | y_dir = np.cross(z_dir, x_dir) 296 | R = np.column_stack([x_dir, y_dir, z_dir]) 297 | return Isometry(q=Quaternion(matrix=R), t=source) 298 | 299 | def adjoint_matrix(self): 300 | R = self.q.rotation_matrix 301 | twR = so3_wedge(self.t) @ R 302 | adjoint = np.zeros((6, 6)) 303 | adjoint[0:3, 0:3] = R 304 | adjoint[3:6, 3:6] = R 305 | adjoint[0:3, 3:6] = twR 306 | return adjoint 307 | 308 | def log(self): 309 | phi = so3_log(self.q.rotation_matrix) 310 | rho = so3_inv_left_jacobian(phi) @ self.t 311 | return np.hstack([rho, phi]) 312 | 313 | def tangent(self, prev_iso, next_iso): 314 | t = 0.5 * (next_iso.t - prev_iso.t) 315 | l1 = Quaternion.log((self.q.inverse * prev_iso.q).normalised) 316 | l2 = Quaternion.log((self.q.inverse * next_iso.q).normalised) 317 | e = Quaternion() 318 | e.q = -0.25 * (l1.q + l2.q) 319 | e = self.q * Quaternion.exp(e) 320 | return Isometry(t=t, q=e) 321 | 322 | def __matmul__(self, other): 323 | # "@" operator: other can be (N,3) or (3,). 324 | if hasattr(other, "device"): # Torch tensor 325 | assert other.ndim == 2 and other.size(1) == 3 # (N,3) 326 | th_R, th_t = self.torch_matrices(other.device) 327 | return other @ th_R.t() + th_t.unsqueeze(0) 328 | if isinstance(other, Isometry): 329 | return self.dot(other) 330 | if type(other) != np.ndarray or other.ndim == 1: 331 | return self.q.rotate(other) + self.t 332 | else: 333 | return other @ self.q.rotation_matrix.T + self.t[np.newaxis, :] 334 | 335 | @staticmethod 336 | def interpolate(source, target, alpha): 337 | iquat = Quaternion.slerp(source.q, target.q, alpha) 338 | it = source.t * (1 - alpha) + target.t * alpha 339 | return Isometry(q=iquat, t=it) 340 | -------------------------------------------------------------------------------- /src/datasets/arkitscene_dataset.py: -------------------------------------------------------------------------------- 1 | """The highres depth is not error free. Some pixels are totally out-of-bound. 
2 | """ 3 | 4 | 5 | import numpy as np 6 | import os 7 | import os.path as osp 8 | from scipy.spatial.transform import Rotation 9 | 10 | 11 | def read_extr(info): 12 | rot = np.asarray(info[1:4]).astype(np.float32) 13 | T_cw = np.eye(4) 14 | rot_mat = Rotation.from_rotvec(rot).as_matrix() 15 | T_cw[:3, :3] = rot_mat 16 | trans = np.asarray(info[4:7]) 17 | T_cw[:3, 3] = trans 18 | return T_cw 19 | 20 | 21 | def read_intr(path): 22 | with open(path, "r") as f: 23 | intr_info = f.read().split(" ") 24 | intr_mat = np.eye(3) 25 | intr_mat[0][0] = intr_info[2] 26 | intr_mat[1][1] = intr_info[3] 27 | intr_mat[0][2] = intr_info[4] 28 | intr_mat[1][2] = intr_info[5] 29 | return intr_mat 30 | 31 | 32 | def get_frame_from_highres_time_stamps(dir, highres_dir, seq_name, poses, time_stamps): 33 | """ 34 | get frame given the input time stamp in the highres folder. 35 | The pose is linearly interpolate if not given in the pose traj. 36 | 37 | Args: 38 | poses (_type_): _description_ 39 | time_stamps (_type_): _description_ 40 | """ 41 | 42 | img_dir = osp.join(dir, "lowres_wide") 43 | depth_dir = osp.join(dir, "lowres_depth") 44 | confidence_dir = osp.join(dir, "confidence") 45 | high_res_img_dir = osp.join(highres_dir, "wide") 46 | high_res_depth_dir = osp.join(highres_dir, "highres_depth") 47 | 48 | time_stamps = np.asarray(time_stamps) 49 | src_time_stamps = np.asarray([float(p) for p in poses.keys()]) 50 | time_diffs =np.abs(time_stamps[:, None] - src_time_stamps[None, :]) 51 | match_ids = np.argsort(time_diffs, axis=-1) 52 | highres_frames = [] 53 | counter = 0 54 | for i, time_stamp in enumerate(time_stamps): 55 | # calculate the weights for pose interpolation 56 | weights = np.abs(time_stamp - src_time_stamps[match_ids[i, :2]]) 57 | weights = weights / np.sum(weights) 58 | weights = 1 - weights 59 | if weights[0] == 1: 60 | counter += 1 61 | 62 | assert np.abs(np.sum(weights) - 1) <= 1e-6 63 | T_cw = np.eye(4) 64 | id_0 = "{:.3f}".format(src_time_stamps[match_ids[i][0]]) 65 | id_1 = "{:.3f}".format(src_time_stamps[match_ids[i][1]]) 66 | p0 = poses[id_0] 67 | p1 = poses[id_1] 68 | quat = weights[0] * Rotation.from_matrix(p0[:3, :3]).as_quat() + \ 69 | weights[1] * Rotation.from_matrix(p1[:3, :3]).as_quat() 70 | rot_mat = Rotation.from_quat(quat).as_matrix() 71 | T_cw[:3, :3] = rot_mat 72 | trans = weights[0] * p0[:3, 3] + weights[1] * p1[:3, 3] 73 | T_cw[:3, 3] = trans 74 | time_stamp = "{:.3f}".format(time_stamp) 75 | depth_name = f"{seq_name}_{time_stamp}.png" 76 | rgb_name = f"{seq_name}_{time_stamp}.png" 77 | high_res_rgb_name = f"{seq_name}_{time_stamp}.png" 78 | if (not os.path.exists(osp.join(img_dir, rgb_name))) or \ 79 | (not os.path.exists(osp.join(depth_dir, depth_name))): 80 | 81 | print("lowres image not available") 82 | assert False 83 | 84 | intr_path = osp.join(dir, "lowres_wide_intrinsics", f"{seq_name}_{time_stamp}.pincam") 85 | intr_mat = read_intr(intr_path) 86 | intr_path = osp.join(dir, "wide_intrinsics", f"{seq_name}_{time_stamp}.pincam") 87 | highres_intr_mat = read_intr(intr_path) 88 | 89 | highres_frame = { 90 | "confidence_path": osp.join(confidence_dir, depth_name), 91 | "rgb_path": osp.join(img_dir, rgb_name), 92 | "depth_path": osp.join(depth_dir, depth_name), 93 | "high_res_rgb_path": osp.join(high_res_img_dir, high_res_rgb_name), 94 | "high_res_depth_path": osp.join(high_res_depth_dir, depth_name), 95 | "T_cw": T_cw, 96 | "intr_mat": intr_mat, 97 | "high_res_intr_mat": highres_intr_mat, 98 | "time_stamp": time_stamp 99 | } 100 | highres_frames.append(highres_frame) 
101 | print(counter) 102 | return highres_frames 103 | 104 | 105 | def read_poses(path): 106 | with open(path, "r") as f: 107 | lines = f.read().splitlines() 108 | 109 | poses = {} 110 | 111 | for l in lines: 112 | info = l.split(" ") 113 | time_stamp = "{:.3f}".format(round(float(info[0]), 3)) 114 | T_cw = read_extr(info) 115 | poses[time_stamp] = T_cw 116 | return poses 117 | 118 | 119 | def get_association(dir, poses, seq_name): 120 | img_dir = osp.join(dir, "lowres_wide") 121 | high_res_img_dir = osp.join(dir, "vga_wide") 122 | depth_dir = osp.join(dir, "lowres_depth") 123 | confidence_dir = osp.join(dir, "confidence") 124 | 125 | frames = [] 126 | available_rgbs = [] 127 | available_rgbs = np.asarray([float(f.split("_")[1][:-4]) for f in os.listdir(img_dir)]) 128 | available_depths = np.asarray([float(f.split("_")[1][:-4]) for f in os.listdir(depth_dir)]) 129 | n_skipped_imgs = 0 130 | for time_stamp in poses: 131 | depth_name = f"{seq_name}_{time_stamp}.png" 132 | rgb_name = f"{seq_name}_{time_stamp}.png" 133 | high_res_rgb_name = f"{seq_name}_{time_stamp}.png" 134 | if (not os.path.exists(osp.join(img_dir, rgb_name))) or \ 135 | (not os.path.exists(osp.join(depth_dir, depth_name))) or \ 136 | (not os.path.exists(osp.join(high_res_img_dir, high_res_rgb_name))): 137 | 138 | n_skipped_imgs += 1 139 | continue 140 | 141 | T_cw = poses[time_stamp] 142 | intr_path = osp.join(dir, "lowres_wide_intrinsics", f"{seq_name}_{time_stamp}.pincam") 143 | intr_mat = read_intr(intr_path) 144 | high_res_intr_mat = read_intr( 145 | osp.join(dir, "vga_wide_intrinsics", f"{seq_name}_{time_stamp}.pincam") 146 | ) 147 | frame = { 148 | "confidence_path": osp.join(confidence_dir, depth_name), 149 | "rgb_path": osp.join(img_dir, rgb_name), 150 | "high_res_rgb_path": osp.join(img_dir, high_res_rgb_name), 151 | "depth_path": osp.join(depth_dir, depth_name), 152 | "T_cw": T_cw, 153 | "intr_mat": intr_mat, 154 | "high_res_intr_mat": high_res_intr_mat, 155 | "time_stamp": time_stamp 156 | } 157 | frames.append(frame) 158 | print(f"{len(poses)} poses, {len(available_rgbs)} rgb, {len(available_depths)} depths") 159 | print(f"skipped {n_skipped_imgs}/{len(available_rgbs)} due to missing poses") 160 | return frames 161 | 162 | 163 | def get_pose_by_time_stamp(frames, time_stamp): 164 | poses = [f['T_cw'] for f in frames] 165 | time_stamps = [f['time_stamp'] for f in frames] 166 | time_diff = np.abs(time_stamps - time_stamp) 167 | ids = np.argsort(time_diff) 168 | 169 | 170 | def get_frame_by_time_stamp(frames, time_stamp): 171 | for frame in frames: 172 | if frame['time_stamp'] == time_stamp: 173 | return frame 174 | print("no frame found") 175 | 176 | 177 | def get_nearby_frames(frames, time_stamp, N=10, min_angle=15, min_distance=0.1): 178 | """ 179 | get nearby frames given a time stamp. The frames should have 180 | enough parallax and covisibility. 181 | 182 | Args: 183 | frames (_type_): _description_ 184 | time_stamp (_type_): _description_ 185 | N (int, optional): _description_. Defaults to 10. 
186 | """ 187 | 188 | n = 0 189 | i = 0 190 | times = np.asarray([f['time_stamp'] for f in frames]) 191 | time_diffs = np.abs(time_stamp - times) 192 | ref_frame = get_frame_by_time_stamp(frames, time_stamp) 193 | assert ref_frame is not None 194 | 195 | ids = np.argsort(time_diffs) 196 | pose_ref = np.linalg.inv(ref_frame['T_cw']) 197 | nearby_frames = [] 198 | while n < N: 199 | source_frame = frames[ids[i]] 200 | T_wc = np.linalg.inv(source_frame['T_cw']) 201 | angle = np.arccos( 202 | ((np.linalg.inv(T_wc[:3, :3]) @ pose_ref[:3, :3] @ np.array([0, 0, 1]).T) * np.array( 203 | [0, 0, 1])).sum()) 204 | dis = np.linalg.norm(T_wc[:3, 3] - pose_ref[:3, 3]) 205 | if angle > (min_angle / 180) * np.pi or dis > min_distance: 206 | nearby_frames.append(source_frame) 207 | n += 1 208 | i += 1 209 | return nearby_frames 210 | 211 | 212 | def get_high_res_time_stamp(in_dir): 213 | img_dir = osp.join(in_dir, "highres_depth") 214 | return sorted([float(f.split("_")[1][:-4]) for f in os.listdir(img_dir)]) 215 | 216 | 217 | class ARKitSceneDataset(torch.utils.data.Dataset): 218 | def __init__(self) -> None: 219 | super().__init__() 220 | 221 | def __len__(self): 222 | 223 | def __getitem__(self, idx): 224 | 225 | 226 | if __name__ == "__main__": 227 | import matplotlib.pyplot as plt 228 | import cv2 229 | import open3d as o3d 230 | import src.utils.o3d_helper as o3d_helper 231 | import src.utils.geometry as geometry 232 | 233 | 234 | # seq_ids = os.listdir("./data/fusion/arkit") 235 | seq_ids = ["41048190"] 236 | for seq_id in seq_ids: 237 | root_dir = f"/home/kejie/Datasets_ssd/raw/Training/{seq_id}" 238 | upsample_dir = f"/home/kejie/Datasets_ssd/raw/Training/{seq_id}" 239 | traj_path = osp.join(root_dir, "lowres_wide.traj") 240 | poses = read_poses(traj_path) 241 | frames = get_association(root_dir, poses, seq_id) 242 | highres_time_stamps = get_high_res_time_stamp(upsample_dir) 243 | training_pairs = [] 244 | highres_frames = get_frame_from_highres_time_stamps(root_dir, upsample_dir, seq_id, poses, highres_time_stamps) 245 | 246 | fusion = o3d_helper.TSDFFusion() 247 | pts_list = [] 248 | for frame in highres_frames: 249 | print(frame['high_res_depth_path']) 250 | depth = cv2.imread(frame['high_res_depth_path'], -1) / 1000. 251 | mask = (depth > 0).astype(np.float32) 252 | # color = cv2.imread(frame['high_res_rgb_path'], -1) 253 | # T_wc = np.linalg.inv(frame['T_cw']) 254 | # intr_mat = frame['high_res_intr_mat'] 255 | 256 | depth = cv2.imread(frame['depth_path'], -1) / 1000. 257 | confidence = cv2.imread(frame["confidence_path"], -1) 258 | depth = depth * (confidence >= 2) 259 | mask = cv2.resize(mask, dsize=(depth.shape[1], depth.shape[0])) 260 | depth = depth * mask 261 | color = cv2.imread(frame['rgb_path'], -1) 262 | T_wc = np.linalg.inv(frame['T_cw']) 263 | intr_mat = frame['intr_mat'] 264 | 265 | fusion.integrate(depth, color, T_wc, intr_mat) 266 | 267 | # xyz = geometry.depth2xyz(depth, intr_mat).reshape(-1, 3) 268 | # xyz = (T_wc @ geometry.get_homogeneous(xyz).T)[:3, :].T 269 | # pts_list.append(o3d_helper.np2pc(xyz)) 270 | # o3d.visualization.draw_geometries(pts_list) 271 | # lowres_depth = cv2.imread(frame['depth_path'], -1) / 1000. 
272 | # mask = cv2.imread(frame["confidence_path"], -1) 273 | # lowres_depth = lowres_depth * (mask >= 2) 274 | # _, axes = plt.subplots(1, 2) 275 | # axes[0].imshow(depth) 276 | # axes[1].imshow(lowres_depth) 277 | # plt.show() 278 | mesh = fusion.marching_cube("./low_res.ply") 279 | o3d.visualization.draw_geometries([o3d_helper.trimesh2o3d(mesh)]) 280 | 281 | # for time_stamp in highres_time_stamps: 282 | # ref_frame = get_frame_by_time_stamp(frames, time_stamp) 283 | # if ref_frame is not None: 284 | # source_frames = get_nearby_frames(frames, time_stamp) 285 | # training_pairs.append([source_frames, ref_frame]) 286 | # # fig, axes = plt.subplots(3, 5) 287 | # # src_color = cv2.imread(ref_frame['high_res_rgb_path'], -1)[:, :, ::-1] 288 | # # axes[0][2].imshow(src_color) 289 | # # print(ref_frame['time_stamp']) 290 | # # for i in range(len(source_frames)): 291 | # # print(source_frames[i]['time_stamp']) 292 | # # src_color = cv2.imread(source_frames[i]['high_res_rgb_path'], -1)[:, :, ::-1] 293 | # # y = i // 5 294 | # # x = i - y * 5 295 | # # axes[y+1][x].imshow(src_color) 296 | # # plt.show() 297 | # print(f"{len(training_pairs)}/{len(highres_time_stamps)}") 298 | 299 | -------------------------------------------------------------------------------- /src/run_e2e.py: -------------------------------------------------------------------------------- 1 | from weakref import KeyedRef 2 | import hydra 3 | import numpy as np 4 | import os 5 | from omegaconf import DictConfig 6 | import torch 7 | import time 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | from pytorch_lightning import seed_everything 11 | 12 | from src.datasets import datasets 13 | from src.datasets.fusion_inference_dataset import IterableInferenceDataset 14 | import src.utils.o3d_helper as o3d_helper 15 | import src.utils.hydra_utils as hydra_utils 16 | import src.utils.voxel_utils as voxel_utils 17 | from src.models.fusion.local_point_fusion import LitFusionPointNet 18 | from src.models.sparse_volume import SparseVolume 19 | from src.utils.render_utils import calculate_loss 20 | from src.utils.common import to_cuda, Timer 21 | import third_parties.fusion as fusion 22 | 23 | 24 | log = hydra_utils.get_logger(__name__) 25 | 26 | 27 | class NeuralMap: 28 | def __init__( 29 | self, 30 | dimensions, 31 | config, 32 | pointnet, 33 | working_dir, 34 | ): 35 | if "/" in config.dataset.scan_id: 36 | self.dataset_name, self.scan_id = config.dataset.scan_id.split("/") 37 | else: 38 | self.scan_id = config.dataset.scan_id 39 | min_coords, max_coords, n_xyz = voxel_utils.get_world_range( 40 | dimensions, config.model.voxel_size) 41 | self.pointnet = pointnet 42 | self.working_dir = working_dir 43 | self.config = config 44 | self.volume = SparseVolume( 45 | config.model.feature_vector_size, 46 | config.model.voxel_size, 47 | dimensions, 48 | config.model.min_pts_in_grid) 49 | self.bound_min = torch.from_numpy(min_coords).to("cuda").float() 50 | self.bound_max = torch.from_numpy(max_coords).to("cuda").float() 51 | self.voxel_size = config.model.voxel_size 52 | self.n_xyz = n_xyz 53 | self.dimensions = dimensions 54 | self.sampling_size = config.dataset.num_pixels 55 | self.train_ray_splits = config.model.train_ray_splits 56 | self.ray_max_dist = config.model.ray_tracer.ray_max_dist 57 | self.truncated_units = config.model.ray_tracer.truncated_units 58 | self.truncated_dist = min(self.truncated_units * self.voxel_size * 0.5, 0.1) 59 | self.depth_scale = 1000. 
60 | self.sdf_delta = None 61 | self.skip_images = config.dataset.skip_images 62 | self.tsdf_voxel_size = 0.025 63 | self.sdf_delta_weight = config.model.sdf_delta_weight 64 | min_coords, max_coords, n_xyz = voxel_utils.get_world_range( 65 | dimensions, self.tsdf_voxel_size) 66 | vol_bnds = np.zeros((3,2)) 67 | vol_bnds[:, 0] = min_coords 68 | vol_bnds[:, 1] = max_coords 69 | self.tsdf_vol = fusion.TSDFVolume( 70 | vol_bnds, 71 | voxel_size=self.tsdf_voxel_size) 72 | 73 | self.frames = [] 74 | self.iterable_dataset = IterableInferenceDataset( 75 | self.frames, self.ray_max_dist, self.bound_min.cpu(), 76 | self.bound_max.cpu(), self.n_xyz, self.sampling_size, config.dataset.confidence_level) 77 | 78 | def integrate(self, frame): 79 | if len(frame['input_pts']) == 0: 80 | return None 81 | with torch.no_grad(): 82 | # local-level fusion 83 | fine_feats, fine_weights, _, fine_coords, fine_n_pts = self.pointnet.encode_pointcloud( 84 | frame['input_pts'], # [1, N, 6] 85 | self.volume.n_xyz, 86 | self.volume.min_coords, 87 | self.volume.max_coords, 88 | self.volume.voxel_size, 89 | return_dense=self.pointnet.dense_volume 90 | ) 91 | if fine_feats is None: 92 | return None 93 | self.volume.track_n_pts(fine_n_pts) 94 | self.pointnet._integrate( 95 | self.volume, 96 | fine_coords, 97 | fine_feats, 98 | fine_weights) 99 | # tsdf fusion 100 | rgbd = frame['rgbd'].cpu().numpy() 101 | depth_map = rgbd[0, -1, :, :] 102 | rgb = (rgbd[0, :3, :, :].transpose(1, 2, 0) + 0.5) * 255. 103 | # depth_map = rgbd[0, 3, :, :] 104 | self.tsdf_vol.integrate( 105 | rgb, # [h, w, 3], [0, 255] 106 | depth_map, # [h, w], metric depth 107 | frame['intr_mat'].cpu().numpy()[0], 108 | frame["T_wc"].cpu().numpy()[0], 109 | obs_weight=1.) 110 | 111 | def optimize(self, n_iters, last_frame): 112 | self.volume.to_tensor() 113 | tsdf_delta = self.prepare_tsdf_volume() 114 | self.volume.features = torch.nn.Parameter(self.volume.features) 115 | self.iterable_dataset.n_iters = n_iters 116 | self.iterable_dataset.last_frame = last_frame 117 | loader = torch.utils.data.DataLoader(self.iterable_dataset, batch_size=None, num_workers=4) 118 | optimizer = torch.optim.Adam([self.volume.features], lr=0.001) 119 | for rays in tqdm(loader): 120 | optimizer.zero_grad() 121 | if torch.isnan(rays['T_wc']).any(): 122 | continue 123 | to_cuda(rays) 124 | batch_loss = {} 125 | n_rays = rays['uv'].shape[1] 126 | n_splits = n_rays / self.train_ray_splits 127 | for i, indx in enumerate(torch.split( 128 | torch.arange(n_rays).cuda(), self.train_ray_splits, dim=0 129 | )): 130 | ray_splits = { 131 | "uv": torch.index_select(rays['uv'], 1, indx), 132 | "rgb": torch.index_select(rays['rgb'], 1, indx), 133 | "gt_pts": torch.index_select(rays['gt_pts'], 1, indx), 134 | "mask": torch.index_select(rays['mask'], 1, indx), 135 | "neighbor_pts": torch.index_select(rays['neighbor_pts'], 1, indx), 136 | "neighbor_masks": torch.index_select(rays['neighbor_masks'], 1, indx), 137 | "T_wc": rays['T_wc'], 138 | "intr_mat": rays['intr_mat']} 139 | split_loss_out = calculate_loss( 140 | self.volume, 141 | ray_splits, 142 | self.pointnet.nerf, 143 | truncated_units=self.truncated_units, 144 | truncated_dist=self.truncated_dist, 145 | ray_max_dist=self.ray_max_dist, 146 | sdf_delta=tsdf_delta) 147 | loss_for_backward = 0 148 | for k in split_loss_out: 149 | if k[0] != "_": 150 | loss_for_backward += split_loss_out[k] 151 | if k not in batch_loss: 152 | batch_loss[k] = split_loss_out[k] 153 | else: 154 | batch_loss[k] += split_loss_out[k] 155 | loss_for_backward.backward() 
156 | optimizer.step() 157 | # store optimized features back to the sparse_volume 158 | self.volume.insert( 159 | self.volume.active_coordinates, 160 | self.volume.features, 161 | self.volume.weights, 162 | self.volume.num_hits) 163 | 164 | def extract_mesh(self): 165 | sdf_delta = self.prepare_tsdf_volume() 166 | surface_pts, mesh = self.volume.meshlize(self.pointnet.nerf, sdf_delta) 167 | return mesh 168 | 169 | def prepare_tsdf_volume(self): 170 | tsdf_volume, _ = self.tsdf_vol.get_volume() 171 | tsdf_volume = tsdf_volume * (self.tsdf_voxel_size * 5) 172 | tsdf_volume = torch.from_numpy(tsdf_volume).to(self.pointnet.device).float().unsqueeze(0).unsqueeze(0) 173 | resized_tsdf_volume = tsdf_volume 174 | # resized_tsdf_volume = F.interpolate( 175 | # tsdf_volume, 176 | # size=( 177 | # self.n_xyz[0], 178 | # self.n_xyz[1], 179 | # self.n_xyz[2] 180 | # ), 181 | # mode="trilinear", 182 | # align_corners=True) 183 | resized_tsdf_volume = torch.clip( 184 | resized_tsdf_volume, min=-self.truncated_dist, max=self.truncated_dist) 185 | resized_tsdf_volume *= self.sdf_delta_weight 186 | return resized_tsdf_volume 187 | 188 | def save(self): 189 | # save tsdf volume 190 | tsdf_out_path = os.path.join(self.working_dir, self.scan_id + ".npy") 191 | tsdf_vol, _ = self.tsdf_vol.get_volume() 192 | tsdf_vol = tsdf_vol * (self.tsdf_voxel_size * 5) 193 | np.save(tsdf_out_path, tsdf_vol) 194 | self.volume.save(os.path.join(self.working_dir, "final")) 195 | 196 | def track_memory(): 197 | div_GB = 1024 * 1024 * 1024 198 | print("GPU status:") 199 | print(f"allocated: {torch.cuda.memory_allocated() / div_GB} GB") 200 | print(f"max allocated: {torch.cuda.max_memory_allocated() / div_GB} GB") 201 | print(f"reserved: {torch.cuda.memory_reserved() / div_GB} GB") 202 | print(f"max reserved: {torch.cuda.max_memory_reserved() / div_GB} GB") 203 | 204 | 205 | @hydra.main(config_path="../configs/", config_name="config.yaml") 206 | def main(config: DictConfig): 207 | 208 | if "seed" in config.trainer: 209 | seed_everything(config.trainer.seed) 210 | 211 | hydra_utils.extras(config) 212 | hydra_utils.print_config(config, resolve=True) 213 | 214 | # setup dataset 215 | log.info("initializing dataset") 216 | val_dataset = datasets.get_dataset(config, "val") 217 | val_loader = DataLoader( 218 | val_dataset, 219 | batch_size=config.dataset.eval_batch_size, 220 | shuffle=False, 221 | num_workers=config.dataset.num_workers, 222 | collate_fn=val_dataset.collate_fn if hasattr(val_dataset, "collate_fn") else None 223 | ) 224 | 225 | plots_dir = os.path.join(os.getcwd(), config.dataset.scan_id) 226 | if not os.path.exists(plots_dir): 227 | os.makedirs(plots_dir) 228 | 229 | # setup model 230 | log.info("initializing model") 231 | pointnet_model = LitFusionPointNet(config) 232 | pretrained_weights = torch.load(config.trainer.checkpoint) 233 | pointnet_model.load_state_dict(pretrained_weights['state_dict']) 234 | pointnet_model.eval() 235 | pointnet_model.cuda() 236 | pointnet_model.freeze() 237 | neural_map = NeuralMap( 238 | val_dataset.dimensions, 239 | config, 240 | pointnet_model, 241 | working_dir=plots_dir) 242 | timer = Timer(["local", "global"]) 243 | for idx, data in enumerate(tqdm(val_loader)): 244 | # LOCAL FUSION: 245 | # integrate information from the new frame to the feature volume 246 | frame, _ = data 247 | for k in frame.keys(): 248 | if isinstance(frame[k], torch.Tensor): 249 | frame[k] = frame[k].cuda().float() 250 | timer.start("local") 251 | neural_map.integrate(frame) 252 | timer.log("local") 253 | if 
torch.isnan(frame['T_wc']).any(): 254 | continue 255 | meta_frame = { 256 | "frame_id": frame["frame_id"], 257 | "scan_id": frame["scene_id"], 258 | "T_wc": frame["T_wc"].clone().cpu(), 259 | "intr_mat": frame["intr_mat"].clone().cpu(), 260 | "img_path": frame['img_path'][0], 261 | "depth_path": frame['depth_path'][0], 262 | } 263 | if "mask_path" in frame: 264 | meta_frame['mask_path'] = frame['mask_path'][0] 265 | del frame 266 | neural_map.frames.append(meta_frame) 267 | # clear memory for open3d hashmap 268 | if (idx+1) % 2 == 0: 269 | torch.cuda.empty_cache() 270 | if config.model.mode == "demo": 271 | if (idx) % config.model.optim_interval == 0: 272 | last_frame = max(0, len(neural_map.frames) - config.model.optim_interval) 273 | n_iters = min(len(neural_map.frames), config.model.optim_interval) * neural_map.skip_images 274 | timer.start("global") 275 | neural_map.optimize(n_iters=n_iters, last_frame=last_frame) 276 | timer.log("global") 277 | mesh = neural_map.extract_mesh() 278 | mesh = o3d_helper.post_process_mesh(mesh) 279 | mesh_out_path = os.path.join(neural_map.working_dir, f"{idx}.ply") 280 | mesh.export(mesh_out_path) 281 | neural_map.volume.to_tensor() 282 | mesh = neural_map.extract_mesh() 283 | mesh.export(os.path.join(neural_map.working_dir, "before_optim.ply")) 284 | global_steps = int(len(neural_map.frames) * neural_map.skip_images) 285 | global_steps = global_steps * 2 if config.model.mode != "demo" else global_steps 286 | timer.start("global") 287 | neural_map.optimize(n_iters=global_steps, last_frame=-1) 288 | timer.log("global") 289 | for n in ["local", "global"]: 290 | print(f"speed on {n} fusion: {global_steps / timer.times[n]} fps") 291 | 292 | mesh = neural_map.extract_mesh() 293 | mesh = o3d_helper.post_process_mesh(mesh, vertex_threshold=neural_map.voxel_size / 4) 294 | mesh_out_path = os.path.join(neural_map.working_dir, "final.ply") 295 | mesh.export(mesh_out_path) 296 | neural_map.save() 297 | 298 | 299 | if __name__ == "__main__": 300 | main() -------------------------------------------------------------------------------- /src/models/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def chamfer_distance_naive(points1, points2): 6 | ''' Naive implementation of the Chamfer distance. 7 | Args: 8 | points1 (torch tensor): first point set, [B, T, 3] 9 | points2 (torch tensor): second point set, [B, T, 3] 10 | ''' 11 | assert(points1.size() == points2.size()) 12 | batch_size, T, _ = points1.size() 13 | 14 | points1 = points1.view(batch_size, T, 1, 3) 15 | points2 = points2.view(batch_size, 1, T, 3) 16 | 17 | distances = (points1 - points2).pow(2).sum(-1) 18 | 19 | chamfer1 = distances.min(dim=1)[0].mean(dim=1) 20 | chamfer2 = distances.min(dim=2)[0].mean(dim=1) 21 | 22 | chamfer = chamfer1 + chamfer2 23 | return chamfer 24 | 25 | 26 | def compute_iou(occ1, occ2): 27 | ''' Computes the Intersection over Union (IoU) value for two sets of 28 | occupancy values.
29 | 30 | Args: 31 | occ1 (tensor): first set of occupancy values 32 | occ2 (tensor): second set of occupancy values 33 | ''' 34 | occ1 = np.asarray(occ1) 35 | occ2 = np.asarray(occ2) 36 | 37 | # Put all data in second dimension 38 | # Also works for 1-dimensional data 39 | if occ1.ndim >= 2: 40 | occ1 = occ1.reshape(occ1.shape[0], -1) 41 | if occ2.ndim >= 2: 42 | occ2 = occ2.reshape(occ2.shape[0], -1) 43 | 44 | # Convert to boolean values 45 | occ1 = (occ1 >= 0.5) 46 | occ2 = (occ2 >= 0.5) 47 | 48 | # Compute IOU 49 | area_union = (occ1 | occ2).astype(np.float32).sum(axis=-1) 50 | area_intersect = (occ1 & occ2).astype(np.float32).sum(axis=-1) 51 | 52 | iou = (area_intersect / area_union) 53 | 54 | return iou 55 | 56 | 57 | def coordinate2index(x, reso, coord_type='2d'): 58 | ''' Convert normalized coordinates in [0, 1] to flattened grid indices. 59 | Supports 2D (plane) and 3D (grid) layouts. 60 | 61 | Args: 62 | x (tensor): coordinates in [0, 1] 63 | reso (int): defined resolution 64 | coord_type (str): coordinate type 65 | ''' 66 | x = (x * reso).long() 67 | if coord_type == '2d': # plane 68 | index = x[:, :, 0] + reso * x[:, :, 1] 69 | elif coord_type == '3d': # grid 70 | index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2]) 71 | index = index[:, None, :] 72 | return index 73 | 74 | 75 | def coordinate2index_rectangle(x, res, coord_type='2d'): 76 | ''' Flatten 3D grid coordinates into linear indices 77 | for a rectangular (res_x, res_y, res_z) grid. 78 | 79 | Args: 80 | x (tensor): grid coordinates 81 | res (tuple): per-axis resolution (res_x, res_y, res_z) 82 | coord_type (str): coordinate type (unused) 83 | ''' 84 | res_x, res_y, res_z = res 85 | index = x[:, :, 0] * res_y * res_z + x[:, :, 1] * res_z + x[:, :, 2] 86 | index = index[:, None, :] 87 | return index 88 | 89 | 90 | def get_neighbors(points): 91 | """ 92 | Return the 8 voxel-corner neighbours (every floor/ceil combination per axis) of each point, shape [8, B, N, 3] 93 | """ 94 | return torch.stack([ 95 | torch.stack( 96 | [ 97 | torch.floor(points[:, :, 0]), 98 | torch.floor(points[:, :, 1]), 99 | torch.floor(points[:, :, 2]) 100 | ], 101 | dim=-1 102 | ), 103 | torch.stack( 104 | [ 105 | torch.ceil(points[:, :, 0]), 106 | torch.floor(points[:, :, 1]), 107 | torch.floor(points[:, :, 2]) 108 | ], 109 | dim=-1 110 | ), 111 | torch.stack( 112 | [ 113 | torch.floor(points[:, :, 0]), 114 | torch.ceil(points[:, :, 1]), 115 | torch.floor(points[:, :, 2]) 116 | ], 117 | dim=-1 118 | ), 119 | torch.stack( 120 | [ 121 | torch.floor(points[:, :, 0]), 122 | torch.floor(points[:, :, 1]), 123 | torch.ceil(points[:, :, 2]) 124 | ], 125 | dim=-1 126 | ), 127 | torch.stack( 128 | [ 129 | torch.floor(points[:, :, 0]), 130 | torch.ceil(points[:, :, 1]), 131 | torch.ceil(points[:, :, 2]) 132 | ], 133 | dim=-1 134 | ), 135 | torch.stack( 136 | [ 137 | torch.ceil(points[:, :, 0]), 138 | torch.floor(points[:, :, 1]), 139 | torch.ceil(points[:, :, 2]) 140 | ], 141 | dim=-1 142 | ), 143 | torch.stack( 144 | [ 145 | torch.ceil(points[:, :, 0]), 146 | torch.ceil(points[:, :, 1]), 147 | torch.floor(points[:, :, 2]) 148 | ], 149 | dim=-1 150 | ), 151 | torch.stack( 152 | [ 153 | torch.ceil(points[:, :, 0]), 154 | torch.ceil(points[:, :, 1]), 155 | torch.ceil(points[:, :, 2]) 156 | ], 157 | dim=-1 158 | ), 159 | ], dim=0) 160 | 161 | 162 | def recenter(points, grid_resolution): 163 | """ 164 | Args: 165 | points: [B, N, 3] point positions in 3D, in [-1, 1] 166 | grid_resolution: the resolution of the feature volume 167 | 168 | Returns: 169 | local_coordinates: [B, 8, N, 3] coordinates w.r.t. the neighboring grid points 170 | indices: [B, 8, N, 3] feature grid indices of the neighboring points 171 | """ 172 | 173 | # convert points
from [-1, 1] to the voxel grid coordinate 174 | points = (points + 1) / 2 * grid_resolution 175 | 176 | # get neighbouring points 177 | indices = get_neighbors(points) 178 | 179 | # calculate relative coordinate 180 | local_coordinates = points.unsqueeze(0).repeat((8, 1, 1, 1)) \ 181 | - indices 182 | 183 | # rescale the coordinates to [-1, 1] 184 | local_coordinates = local_coordinates # / grid_resolution * 2 185 | 186 | # [8, B, N, 3] -> [B, 8, N, 3] 187 | local_coordinates = local_coordinates.permute(1, 0, 2, 3) 188 | indices = indices.permute(1, 0, 2, 3) 189 | 190 | return local_coordinates, indices.int() 191 | 192 | 193 | def get_neighbors_new(points, resolution): 194 | """ 195 | Return the 8 grid-corner neighbours of each point on a grid with cell size `resolution` (every floor/ceil combination per axis), shape [8, B, N, 3] 196 | """ 197 | return torch.stack([ 198 | torch.stack( 199 | [ 200 | torch.floor(points[:, :, 0] / resolution), 201 | torch.floor(points[:, :, 1] / resolution), 202 | torch.floor(points[:, :, 2] / resolution) 203 | ], 204 | dim=-1 205 | ), 206 | torch.stack( 207 | [ 208 | torch.ceil(points[:, :, 0] / resolution), 209 | torch.floor(points[:, :, 1] / resolution), 210 | torch.floor(points[:, :, 2] / resolution) 211 | ], 212 | dim=-1 213 | ), 214 | torch.stack( 215 | [ 216 | torch.floor(points[:, :, 0] / resolution), 217 | torch.ceil(points[:, :, 1] / resolution), 218 | torch.floor(points[:, :, 2] / resolution) 219 | ], 220 | dim=-1 221 | ), 222 | torch.stack( 223 | [ 224 | torch.floor(points[:, :, 0] / resolution), 225 | torch.floor(points[:, :, 1] / resolution), 226 | torch.ceil(points[:, :, 2] / resolution) 227 | ], 228 | dim=-1 229 | ), 230 | torch.stack( 231 | [ 232 | torch.floor(points[:, :, 0] / resolution), 233 | torch.ceil(points[:, :, 1] / resolution), 234 | torch.ceil(points[:, :, 2] / resolution) 235 | ], 236 | dim=-1 237 | ), 238 | torch.stack( 239 | [ 240 | torch.ceil(points[:, :, 0] / resolution), 241 | torch.floor(points[:, :, 1] / resolution), 242 | torch.ceil(points[:, :, 2] / resolution) 243 | ], 244 | dim=-1 245 | ), 246 | torch.stack( 247 | [ 248 | torch.ceil(points[:, :, 0] / resolution), 249 | torch.ceil(points[:, :, 1] / resolution), 250 | torch.floor(points[:, :, 2] / resolution) 251 | ], 252 | dim=-1 253 | ), 254 | torch.stack( 255 | [ 256 | torch.ceil(points[:, :, 0] / resolution), 257 | torch.ceil(points[:, :, 1] / resolution), 258 | torch.ceil(points[:, :, 2] / resolution) 259 | ], 260 | dim=-1 261 | ), 262 | ], dim=0) 263 | 264 | 265 | def recenter_new(points, grid_step_size, resolution): 266 | """ 267 | Args: 268 | points: [B, N, 3] point positions in 3D, in [-1, 1] 269 | grid_step_size: spacing between feature-grid points (in voxel units); resolution: the resolution of the feature volume 270 | 271 | Returns: 272 | local_coordinates: [B, 8, N, 3] coordinates w.r.t. the neighboring grid points 273 | indices: [B, 8, N, 3] feature grid indices of the neighboring points 274 | """ 275 | 276 | # convert points from [-1, 1] to the voxel grid coordinate 277 | points = (points + 1) / 2 * (resolution - grid_step_size) 278 | 279 | # get neighbouring points 280 | indices = get_neighbors_new(points, grid_step_size) 281 | 282 | neighbor_centers = indices * grid_step_size 283 | 284 | # calculate relative coordinate 285 | local_coordinates = points.unsqueeze(0).repeat((8, 1, 1, 1)) \ 286 | - neighbor_centers 287 | 288 | # rescale the coordinates to [-1, 1] 289 | local_coordinates = local_coordinates / grid_step_size 290 | 291 | # [8, B, N, 3] -> [B, 8, N, 3] 292 | local_coordinates = local_coordinates.permute(1, 0, 2, 3) 293 | indices = indices.permute(1, 0, 2, 3) 294 | 295 | return local_coordinates, indices.int() 296 | 297 | 298 | def index_w_border(grid, 
batch_ind, x, y, z, D, H, W): 299 | batch_size, feat_dim = grid.shape[:2] 300 | n_pts = x.shape[1] 301 | x = torch.clamp(x, 0, D-1) # [B, N] 302 | y = torch.clamp(y, 0, H-1) # [B, N] 303 | z = torch.clamp(z, 0, W-1) 304 | 305 | out = grid[batch_ind, :, x.flatten(), y.flatten(), z.flatten()].reshape( 306 | batch_size, n_pts, feat_dim) 307 | return out 308 | 309 | 310 | def bilinear_interpolate_torch(im, xy): 311 | """ Trilinear interpolation, implemented by hand because grid_sample in pytorch can't provide a second-order derivative 312 | 313 | Argument: 314 | im: (B, C, D, H, W) 315 | xy: (B, N, 3) in [-1, 1] (modified in place) 316 | 317 | Return: 318 | result: (B, N, C) 319 | """ 320 | 321 | batch_size, feat_dim, img_d, img_h, img_w = im.shape 322 | n_pts = xy.shape[1] 323 | assert batch_size == xy.shape[0] 324 | 325 | xy[:, :, 0] = (xy[:, :, 0] + 1) / 2 * (img_w-1) 326 | xy[:, :, 1] = (xy[:, :, 1] + 1) / 2 * (img_h-1) 327 | xy[:, :, 2] = (xy[:, :, 2] + 1) / 2 * (img_d-1) 328 | 329 | x = xy[:, :, 2] 330 | y = xy[:, :, 1] 331 | z = xy[:, :, 0] 332 | dtype = x.type() 333 | 334 | x0 = torch.floor(x).long() 335 | x1 = x0 + 1 336 | y0 = torch.floor(y).long() 337 | y1 = y0 + 1 338 | z0 = torch.floor(z).long() 339 | z1 = z0 + 1 340 | 341 | batch_ind = torch.arange(0, batch_size).unsqueeze(-1).to(im.device) 342 | batch_ind = batch_ind.repeat(1, n_pts).flatten() 343 | 344 | I000 = index_w_border(im, batch_ind, x0, y0, z0, img_d, img_h, img_w) 345 | I010 = index_w_border(im, batch_ind, x0, y1, z0, img_d, img_h, img_w) 346 | I100 = index_w_border(im, batch_ind, x1, y0, z0, img_d, img_h, img_w) 347 | I110 = index_w_border(im, batch_ind, x1, y1, z0, img_d, img_h, img_w) 348 | I001 = index_w_border(im, batch_ind, x0, y0, z1, img_d, img_h, img_w) 349 | I011 = index_w_border(im, batch_ind, x0, y1, z1, img_d, img_h, img_w) 350 | I101 = index_w_border(im, batch_ind, x1, y0, z1, img_d, img_h, img_w) 351 | I111 = index_w_border(im, batch_ind, x1, y1, z1, img_d, img_h, img_w) 352 | 353 | x1_weight = (x1.type(dtype)-x) 354 | y1_weight = (y1.type(dtype)-y) 355 | z1_weight = (z1.type(dtype)-z) 356 | x0_weight = (x-x0.type(dtype)) 357 | y0_weight = (y-y0.type(dtype)) 358 | z0_weight = (z-z0.type(dtype)) 359 | 360 | w000 = x1_weight * y1_weight * z1_weight 361 | w000 = w000.unsqueeze(-1) 362 | w010 = x1_weight * y0_weight * z1_weight 363 | w010 = w010.unsqueeze(-1) 364 | w100 = x0_weight * y1_weight * z1_weight 365 | w100 = w100.unsqueeze(-1) 366 | w110 = x0_weight * y0_weight * z1_weight 367 | w110 = w110.unsqueeze(-1) 368 | w001 = x1_weight * y1_weight * z0_weight 369 | w001 = w001.unsqueeze(-1) 370 | w011 = x1_weight * y0_weight * z0_weight 371 | w011 = w011.unsqueeze(-1) 372 | w101 = x0_weight * y1_weight * z0_weight 373 | w101 = w101.unsqueeze(-1) 374 | w111 = x0_weight * y0_weight * z0_weight 375 | w111 = w111.unsqueeze(-1) 376 | 377 | return I000*w000 + I010*w010 + I100*w100 + I110*w110 + I001*w001 + I011*w011 + I101*w101 + I111*w111 378 | 379 | 380 | if __name__ == "__main__": 381 | # points = np.array([ 382 | # [ 383 | # [-0.9999, -0.9999, -0.9999], 384 | # [-0.975, -0.99, -0.99], 385 | # [-0.8125, -0.8125, -0.8125], 386 | # [0.99, 0.99, 0.99] 387 | # ] 388 | # ]) 389 | points = np.stack(np.meshgrid(np.arange(16), np.arange( 390 | 16), np.arange(16)), axis=-1).reshape(-1, 3) 391 | points = points / 128 * 2 - 1 392 | recenter_new(torch.from_numpy(points).unsqueeze(0), 393 | grid_step_size=8, resolution=128) 394 | -------------------------------------------------------------------------------- /src/utils/pointnet_utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import commentjson as json 5 | import tinycudann as tcnn 6 | 7 | 8 | def query_ball_point(radius, nsample, xyz, new_xyz, return_dist=False): 9 | """ 10 | Input: 11 | radius: local region radius 12 | nsample: max sample number in local region 13 | xyz: all points, [B, N, 3] 14 | new_xyz: query points, [B, S, 3] 15 | Return: 16 | group_idx: grouped points index, [B, S, nsample] 17 | """ 18 | device = xyz.device 19 | B, N, C = xyz.shape 20 | _, S, _ = new_xyz.shape 21 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 22 | sqrdists = square_distance(new_xyz, xyz) 23 | group_idx[sqrdists > radius ** 2] = N 24 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 25 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 26 | mask = group_idx == N 27 | group_idx[mask] = group_first[mask] 28 | if return_dist: 29 | view_shape = list(group_idx.shape) 30 | view_shape[1:] = [1] * (len(view_shape) - 1) 31 | repeat_shape = list(group_idx.shape) 32 | repeat_shape[0] = 1 33 | batch_indices = torch.arange( 34 | B, dtype=torch.long 35 | ).to(device).view(view_shape).repeat(repeat_shape) 36 | query_indices = torch.arange( 37 | S, dtype=torch.long 38 | ).to(device).view(1, S, 1).repeat(B, 1, nsample) 39 | dists = sqrdists[ 40 | batch_indices.reshape(-1), 41 | query_indices.reshape(-1), 42 | group_idx.reshape(-1) 43 | ] 44 | dists = dists.reshape(B, S, nsample) 45 | return group_idx, dists 46 | return group_idx 47 | 48 | 49 | def square_distance(src, dst): 50 | """ 51 | Calculate the squared Euclidean distance between each pair of points. 52 | 53 | src^T * dst = xn * xm + yn * ym + zn * zm; 54 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 55 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 56 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 57 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 58 | 59 | Input: 60 | src: source points, [B, N, C] 61 | dst: target points, [B, M, C] 62 | Output: 63 | dist: per-point square distance, [B, N, M] 64 | """ 65 | B, N, _ = src.shape 66 | _, M, _ = dst.shape 67 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 68 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 69 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 70 | return dist 71 | 72 | 73 | def farthest_point_sample(xyz, npoint): 74 | """ 75 | Input: 76 | xyz: pointcloud data, [B, N, 3] 77 | npoint: number of samples 78 | Return: 79 | centroids: sampled pointcloud index, [B, npoint] 80 | """ 81 | device = xyz.device 82 | B, N, C = xyz.shape 83 | centroids = torch.zeros(B, npoint, dtype=torch.long, device=device) 84 | distance = torch.ones((B, N), device=device) * 1e10 85 | farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device) 86 | batch_indices = torch.arange(B, dtype=torch.long, device=device) 87 | for i in range(npoint): 88 | centroids[:, i] = farthest 89 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 90 | dist = torch.sum((xyz - centroid) ** 2, -1) 91 | mask = dist < distance 92 | distance[mask] = dist[mask] 93 | farthest = torch.max(distance, -1)[1] 94 | return centroids 95 | 96 | 97 | def index_points(points, idx): 98 | """ 99 | 100 | Input: 101 | points: input points data, [B, N, C] 102 | idx: sample index data, [B, S] 103 | Return: 104 | new_points: indexed points data, [B, S, C] 105 | """ 106 | device = points.device 107 | B = points.shape[0] 108 | view_shape = list(idx.shape)
109 | view_shape[1:] = [1] * (len(view_shape) - 1) 110 | repeat_shape = list(idx.shape) 111 | repeat_shape[0] = 1 112 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 113 | new_points = points[batch_indices, idx, :] 114 | return new_points 115 | 116 | 117 | class PointNetSetAbstractionMsg(nn.Module): 118 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 119 | super(PointNetSetAbstractionMsg, self).__init__() 120 | self.npoint = npoint 121 | self.radius_list = radius_list 122 | self.nsample_list = nsample_list 123 | self.conv_blocks = nn.ModuleList() 124 | self.bn_blocks = nn.ModuleList() 125 | for i in range(len(mlp_list)): 126 | convs = nn.ModuleList() 127 | bns = nn.ModuleList() 128 | last_channel = in_channel + 3 129 | for out_channel in mlp_list[i]: 130 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 131 | bns.append(nn.BatchNorm2d(out_channel)) 132 | last_channel = out_channel 133 | self.conv_blocks.append(convs) 134 | self.bn_blocks.append(bns) 135 | 136 | def forward(self, xyz, points): 137 | """ 138 | Input: 139 | xyz: input points position data, [B, C, N] 140 | points: input points data, [B, D, N] 141 | Return: 142 | new_xyz: sampled points position data, [B, C, S] 143 | new_points_concat: sample points feature data, [B, D', S] 144 | """ 145 | xyz = xyz.permute(0, 2, 1) 146 | if points is not None: 147 | points = points.permute(0, 2, 1) 148 | 149 | B, N, C = xyz.shape 150 | S = self.npoint 151 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 152 | new_points_list = [] 153 | for i, radius in enumerate(self.radius_list): 154 | K = self.nsample_list[i] 155 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 156 | grouped_xyz = index_points(xyz, group_idx) 157 | grouped_xyz -= new_xyz.view(B, S, 1, C) 158 | if points is not None: 159 | grouped_points = index_points(points, group_idx) 160 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 161 | else: 162 | grouped_points = grouped_xyz 163 | 164 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 165 | for j in range(len(self.conv_blocks[i])): 166 | conv = self.conv_blocks[i][j] 167 | bn = self.bn_blocks[i][j] 168 | grouped_points = F.relu(bn(conv(grouped_points))) 169 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 170 | new_points_list.append(new_points) 171 | 172 | new_xyz = new_xyz.permute(0, 2, 1) 173 | new_points_concat = torch.cat(new_points_list, dim=1) 174 | return new_xyz, new_points_concat 175 | 176 | 177 | class PointNetFeaturePropagation(nn.Module): 178 | def __init__(self, in_channel, mlp): 179 | super(PointNetFeaturePropagation, self).__init__() 180 | self.mlp_convs = nn.ModuleList() 181 | self.mlp_bns = nn.ModuleList() 182 | last_channel = in_channel 183 | for out_channel in mlp: 184 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 185 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 186 | last_channel = out_channel 187 | 188 | def forward(self, xyz1, xyz2, points1, points2): 189 | """ 190 | Input: 191 | xyz1: input points position data, [B, C, N] 192 | xyz2: sampled input points position data, [B, C, S] 193 | points1: input points data, [B, D, N] 194 | points2: input points data, [B, D, S] 195 | Return: 196 | new_points: upsampled points data, [B, D', N] 197 | """ 198 | xyz1 = xyz1.permute(0, 2, 1) 199 | xyz2 = xyz2.permute(0, 2, 1) 200 | 201 | points2 = points2.permute(0, 2, 1) 202 | B, N, C = xyz1.shape 203 | _, S, _ = xyz2.shape 204 | 205 | if S == 1: 
206 | interpolated_points = points2.repeat(1, N, 1) 207 | else: 208 | dists = square_distance(xyz1, xyz2) 209 | dists, idx = dists.sort(dim=-1) 210 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 211 | 212 | dist_recip = 1.0 / (dists + 1e-8) 213 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 214 | weight = dist_recip / norm 215 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 216 | 217 | if points1 is not None: 218 | points1 = points1.permute(0, 2, 1) 219 | new_points = torch.cat([points1, interpolated_points], dim=-1) 220 | else: 221 | new_points = interpolated_points 222 | 223 | new_points = new_points.permute(0, 2, 1) 224 | for i, conv in enumerate(self.mlp_convs): 225 | bn = self.mlp_bns[i] 226 | new_points = F.relu(bn(conv(new_points))) 227 | return new_points 228 | 229 | 230 | class PointNetEncoder(nn.Module): 231 | def __init__(self, feat_dims, in_channels, **kwargs): 232 | super(PointNetEncoder, self).__init__() 233 | self.feat_dims = feat_dims 234 | self.conv1 = torch.nn.Conv1d(in_channels, 128, 1) 235 | self.conv2 = torch.nn.Conv1d(128, 128, 1) 236 | self.conv3 = torch.nn.Conv1d(128, 128, 1) 237 | self.conv4 = torch.nn.Conv1d(128, feat_dims, 1) 238 | # self.conv3 = torch.nn.Conv1d(128, feat_dims, 1) 239 | 240 | self.bn1 = nn.BatchNorm1d(128) 241 | self.bn2 = nn.BatchNorm1d(128) 242 | self.bn3 = nn.BatchNorm1d(128) 243 | self.bn4 = nn.BatchNorm1d(feat_dims) 244 | # self.bn3 = nn.BatchNorm1d(feat_dims) 245 | 246 | def forward(self, x, global_feat): 247 | B, D, N = x.size() 248 | trans = x 249 | x = x.transpose(2, 1) 250 | if D > 3: 251 | feature = x[:, :, 3:] 252 | x = x[:, :, :3] 253 | if D > 3: 254 | x = torch.cat([x, feature], dim=2) 255 | x = x.transpose(2, 1) 256 | x = F.relu(self.bn1(self.conv1(x))) 257 | x = F.relu(self.bn2(self.conv2(x))) 258 | x = F.relu(self.bn3(self.conv3(x))) 259 | x = self.bn4(self.conv4(x)) 260 | # x = self.bn3(self.conv3(x)) 261 | if global_feat: 262 | x = torch.mean(x, 2, keepdim=True) 263 | x = x.view(-1, self.feat_dims) 264 | return x # [B, F] 265 | else: 266 | return x # [B, F, N] 267 | 268 | 269 | class tcnnPointNetEncoder(nn.Module): 270 | def __init__(self, feat_dims, in_channels, **kwargs): 271 | super(tcnnPointNetEncoder, self).__init__() 272 | with open(kwargs['tcnn_config']) as config_file: 273 | config = json.load(config_file) 274 | self.model = tcnn.NetworkWithInputEncoding( 275 | n_input_dims=in_channels, 276 | n_output_dims=feat_dims, 277 | encoding_config=config["encoding"], 278 | network_config=config["network"] 279 | ) 280 | self.feat_dims = feat_dims 281 | self.in_channels = in_channels 282 | 283 | def forward(self, x, global_feat): 284 | x = x.transpose(2, 1) 285 | B, N, D = x.size() 286 | x = self.model(x.reshape(-1, self.in_channels)) 287 | x = x.reshape(B, N, self.feat_dims) 288 | x = x.permute(0, 2, 1) # [B, F, N] 289 | if global_feat: 290 | x = torch.mean(x, 2, keepdim=True) 291 | x = x.view(-1, self.feat_dims) 292 | return x # [B, F] 293 | else: 294 | return x # [B, F, N] 295 | 296 | 297 | 298 | class PointNet(nn.Module): 299 | def __init__(self, feat_dims, in_channels): 300 | super(PointNet, self).__init__() 301 | 302 | self.sa1 = PointNetSetAbstractionMsg( 303 | npoint=1024, 304 | radius_list=[0.05, 0.1], 305 | nsample_list=[16, 32], 306 | in_channel=in_channels, 307 | mlp_list=[[16, 16, 32], [32, 32, 64]] 308 | ) 309 | self.sa2 = PointNetSetAbstractionMsg(256, [0.1, 0.2], [16, 32], 32+64, [[64, 64, 128], [64, 96, 128]]) 310 | self.sa3 = 
PointNetSetAbstractionMsg(64, [0.2, 0.4], [16, 32], 128+128, [[128, 196, 256], [128, 196, 256]]) 311 | self.sa4 = PointNetSetAbstractionMsg(16, [0.4, 0.8], [16, 32], 256+256, [[256, 256, 512], [256, 384, 512]]) 312 | self.fp4 = PointNetFeaturePropagation(512+512+256+256, [256, 256]) 313 | self.fp3 = PointNetFeaturePropagation(128+128+256, [256, 256]) 314 | self.fp2 = PointNetFeaturePropagation(32+64+256, [256, 128]) 315 | self.fp1 = PointNetFeaturePropagation(128, [128, 128, feat_dims]) 316 | 317 | def forward(self, xyz): 318 | l0_points = xyz 319 | l0_xyz = xyz[:,:3,:] 320 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points) 321 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 322 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 323 | l4_xyz, l4_points = self.sa4(l3_xyz, l3_points) 324 | # import numpy as np 325 | # import open3d as o3d 326 | # import src.utils.o3d_helper as o3d_helper 327 | # def visualize(xyz, _color): 328 | # xyz_np = xyz.permute(0, 2, 1).detach().cpu().numpy()[0] 329 | # color = np.zeros_like(xyz_np) 330 | # color += _color[None, :] 331 | # return o3d_helper.np2pc(xyz_np, color) 332 | # visual_list = [] 333 | # visual_list.append(visualize(l0_xyz, np.array([1, 0, 0]))) 334 | # visual_list.append(visualize(l1_xyz, np.array([0, 1, 0]))) 335 | # visual_list.append(visualize(l2_xyz, np.array([0, 0, 1]))) 336 | # o3d.visualization.draw_geometries(visual_list) 337 | l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points) 338 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points) 339 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) 340 | l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points) 341 | 342 | return l0_points --------------------------------------------------------------------------------
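A minimal usage sketch for the building blocks in src/utils/pointnet_utils.py (illustrative only: shapes follow the docstrings, the random tensor stands in for a real point cloud, feat_dims=32 and in_channels=6 are placeholder values rather than the settings of the pretrained checkpoints, and the import assumes the repository root is on PYTHONPATH):

import torch

from src.utils.pointnet_utils import PointNetEncoder, square_distance

# toy batch: 2 point clouds, 1024 points each, 6 channels (e.g. xyz + normals)
points = torch.rand(2, 6, 1024)  # [B, C, N]

encoder = PointNetEncoder(feat_dims=32, in_channels=6)
encoder.eval()  # BatchNorm layers fall back to their running statistics
with torch.no_grad():
    per_point_feats = encoder(points, global_feat=False)  # [B, 32, N]
    global_feats = encoder(points, global_feat=True)      # [B, 32]

# square_distance expects [B, N, C] point sets and returns pairwise squared distances
xyz = points.permute(0, 2, 1)[..., :3]         # [B, N, 3]
pairwise_sq_dists = square_distance(xyz, xyz)  # [B, N, N]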