├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── SLAM ├── __init__.py ├── eval.py ├── gaussian_pointcloud.py ├── icp.py ├── multiprocess │ ├── mapper.py │ ├── system.py │ └── tracker.py ├── render.py └── utils.py ├── arguments └── __init__.py ├── assets └── teaser.png ├── build_orb.sh ├── configs ├── base.yaml ├── orb_config │ ├── ours.yaml │ ├── tum1.yaml │ ├── tum2.yaml │ └── tum3.yaml ├── ours │ ├── corridor.yaml │ ├── home.yaml │ ├── hotel.yaml │ ├── office.yaml │ └── outside.yaml ├── ours_base.yaml ├── replica │ ├── office0.yaml │ ├── office1.yaml │ ├── office2.yaml │ ├── office3.yaml │ ├── office4.yaml │ ├── room0.yaml │ ├── room1.yaml │ └── room2.yaml ├── replica_base.yaml ├── scannetpp │ ├── 39f36da05b.yaml │ ├── 8b5caf3398.yaml │ ├── b20a261fdf.yaml │ └── f34d532901.yaml ├── scannetpp_base.yaml ├── tum │ ├── dataset │ │ ├── fr1_desk.yaml │ │ ├── fr2_xyz.yaml │ │ └── fr3_office.yaml │ ├── fr1_desk.yaml │ ├── fr2_xyz.yaml │ └── fr3_office.yaml └── tum_base.yaml ├── environment.yaml ├── metric.py ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py └── gaussian_model.py ├── scripts ├── associate.py ├── download_ours.sh ├── download_replica.sh ├── download_tum.sh ├── eval_ate.py ├── parse_scannetpp.py └── parse_scannetpp.sh ├── slam.py ├── slam_mp.py └── utils ├── camera_utils.py ├── config_utils.py ├── general_utils.py ├── graphics_utils.py ├── loss_utils.py ├── monitor.py └── sh_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/.vscode 2 | **/__pycache__ 3 | data 4 | output 5 | temp/ 6 | *.ipynb 7 | boost_1_80_0 8 | install 9 | opencv-4.2.0 10 | Pangolin 11 | cuda_utils 12 | diff-gaussian-rasterizer-depth 13 | simple-knn 14 | *.txt 15 | build 16 | lib -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules"] 2 | path = submodules 3 | url = git@github.com:CJAPPLE5/RTG-SLAM-cuda_utils.git 4 | branch = main 5 | [submodule "thirdParty/pybind"] 6 | path = thirdParty/pybind 7 | url = git@github.com:CJAPPLE5/RTG-SLAM-PYBIND.git 8 | branch = main 9 | [submodule "thirdParty/ORB-SLAM2-PYBIND"] 10 | path = thirdParty/ORB-SLAM2-PYBIND 11 | url = git@github.com:CJAPPLE5/RTG-SLAM-BACKEND.git 12 | branch = main 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RTG-SLAM: Real-time 3D Reconstruction at Scale Using Gaussian Splatting 2 | 3 | Zhexi Peng, Tianjia Shao, Liu Yong, Jingke Zhou, Yin Yang, Jingdong Wang, Kun Zhou 4 | ![Teaser image](assets/teaser.png) 5 | 6 | This repository contains the official authors implementation associated with the paper "RTG-SLAM: Real-time 3D Reconstruction 7 | at Scale Using Gaussian Splatting", which can be found [here](https://gapszju.github.io/RTG-SLAM/static/pdfs/RTG-SLAM_arxiv.pdf). 8 | 9 | Abstract: *We present Real-time Gaussian SLAM (RTG-SLAM), a real-time 3D reconstruction system with an RGBD camera for large-scale environments using Gaussian splatting. The system features a compact Gaussian representation and a highly efficient on-the-fly Gaussian optimization scheme. We force each Gaussian to be either opaque or nearly transparent, with the opaque ones fitting the surface and dominant colors, and transparent ones fitting residual colors. 
By rendering depth in a different way from color rendering, we let a single opaque Gaussian well fit a local surface region without the need of multiple overlapping Gaussians, hence largely reducing the memory and computation cost. For on-the-fly Gaussian optimization, we explicitly add Gaussians for three types of pixels per frame: newly observed, with large color errors, and with large depth errors. We also categorize all Gaussians into stable and unstable ones, where the stable Gaussians are expected to well fit previously observed RGBD images and otherwise unstable. We only optimize the unstable Gaussians and only render the pixels occupied by unstable Gaussians. In this way, both the number of Gaussians to be optimized and pixels to be rendered are largely reduced, and the optimization can be done in real time. We show real-time reconstructions of a variety of large scenes. Compared with the state-of-the-art NeRF-based RGBD SLAM, our system achieves comparable high-quality reconstruction but with around twice the speed and half the memory cost, and shows superior performance in the realism of novel view synthesis and camera tracking accuracy.* 10 | 11 | 12 | ## 1. Installation 13 | 14 | ### 1.1 Clone the Repository 15 | 16 | ``` 17 | git clone --recursive https://github.com/MisEty/RTG-SLAM.git 18 | ``` 19 | 20 | ### 1.2 Python Environment 21 | RTG-SLAM has been tested with Python 3.9, CUDA 11.7, and PyTorch 1.13.1. The simplest way to install all dependencies is to use [anaconda](https://www.anaconda.com/) and [pip](https://pypi.org/project/pip/) in the following steps: 22 | 23 | ```bash 24 | conda env create -f environment.yaml 25 | ``` 26 | 27 | ### 1.3 Modified ORB-SLAM2 Python Binding 28 | We have made some changes to ORB-SLAM2 so that it works with our ICP front-end. You can run this script to install Pangolin, OpenCV, ORB-SLAM2, and the Boost-Python binding. 29 | 30 | ```bash 31 | bash build_orb.sh 32 | ``` 33 | 34 | If you encounter the following problem while installing Pangolin: 35 | 36 | ```bash 37 | xxx/Pangolin/src/video/drivers/ffmpeg.cpp: In function ‘std::__cxx11::string pangolin::FfmpegFmtToString(AVPixelFormat)’: 38 | xxx/Pangolin/src/video/drivers/ffmpeg.cpp:41:41: error: ‘AV_PIX_FMT_XVMC_MPEG2_MC’ was not declared in this scope 39 | ``` 40 | 41 | you can follow this [solution](https://github.com/stevenlovegrove/Pangolin/pull/318/files?diff=split&w=0). 42 | 43 | #### Note 44 | For real data, backend optimization based on ORB-SLAM2 is crucial, so you need to install the Python binding for ORB-SLAM2 following the steps above. We have modified some code based on 45 | [ORB_SLAM2-PythonBindings](https://github.com/jskinn/ORB_SLAM2-PythonBindings). If you encounter any compilation problem, you can refer to 46 | [ORB_SLAM2-PythonBindings](https://github.com/jskinn/ORB_SLAM2-PythonBindings) for solutions. Our ICP front-end works well when the depth is accurate, so if you only want to test on a synthetic dataset like Replica, you do not need to install the ORB-SLAM2 Python binding. 47 | 48 | 49 | ### 1.4 Test ORB-SLAM2 Python Binding 50 | ```bash 51 | cd thirdParty/pybind/examples 52 | python orbslam_rgbd_tum.py # please set voc_path, association_path ... 
53 | python eval_ate.py path_to_groundtruth.txt trajectory.txt --plot PLOT --verbose 54 | ``` 55 | If you encounter a `SIGSEGV` problem similar to [raulmur/ORB_SLAM2#844](https://github.com/raulmur/ORB_SLAM2/pull/844) or [silencht/SG-SLAM#31](https://github.com/silencht/SG-SLAM/issues/31), you can try modifying 'thirdParty/ORB-SLAM2-PYBIND/CMakeLists.txt': 56 | ``` 57 | #set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE}") 58 | SET(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} -O3 -march=native") 59 | #set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}") 60 | SET(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -march=native") 61 | ``` 62 | 63 | If the code runs without errors and the estimated trajectory is correct, you can move on to the next step. 64 | 65 | ## 2. Dataset Preparation 66 | ### 2.1 Replica 67 | ``` 68 | bash scripts/download_replica.sh 69 | ``` 70 | ### 2.2 TUM-RGBD 71 | ```bash 72 | bash scripts/download_tum.sh 73 | ``` 74 | Then copy the config files to the corresponding data folders. 75 | ```bash 76 | cp configs/tum/dataset/fr1_desk.yaml data/TUM_RGBD/rgbd_dataset_freiburg1_desk/config.yaml 77 | cp configs/tum/dataset/fr2_xyz.yaml data/TUM_RGBD/rgbd_dataset_freiburg2_xyz/config.yaml 78 | cp configs/tum/dataset/fr3_office.yaml data/TUM_RGBD/rgbd_dataset_freiburg3_long_office_household/config.yaml 79 | ``` 80 | 81 | 82 | ### 2.3 ScanNet++ 83 | Please follow [ScanNet++](https://kaldir.vc.in.tum.de/scannetpp/) to download the dataset, then run 84 | ```bash 85 | python scripts/parse_scannetpp.py --data_path scannetpp_path/download/data/8b5caf3398 --output_path data/ScanNetpp/8b5caf3398 86 | ``` 87 | ### 2.4 Ours 88 | You can download our dataset from [Google Drive](https://drive.google.com/drive/folders/161QHjVTHRCED9WmRWAlOEhJQ_GXxgtn5?usp=sharing). 89 | 90 | 91 | ### 2.5 Dataset Structure 92 | 93 | ``` 94 | |-- data 95 | |-- Ours 96 | |-- hotel 97 | |-- Replica 98 | |-- office0 99 | |-- results 100 | |-- office0.ply 101 | |-- traj.txt 102 | |-- cam_params.json 103 | |-- TUM_RGBD 104 | |-- rgbd_dataset_freiburg1_desk 105 | |-- ScanNetpp 106 | |-- 8b5caf3398 107 | |-- color 108 | |-- depth 109 | |-- intrinsic 110 | |-- pose 111 | |-- mesh_aligned_cull.ply 112 | ``` 113 | 114 | ## 3. Run 115 | ### 3.1 Replica 116 | ```bash 117 | # Single Process: Recommended, More Stable 118 | python slam.py --config ./configs/replica/office0.yaml 119 | # Multi Process: 120 | python slam_mp.py --config ./configs/replica/office0.yaml 121 | ``` 122 | 123 | ### 3.2 TUM-RGBD 124 | ```bash 125 | # Single Process: Recommended, More Stable 126 | python slam.py --config ./configs/tum/fr1_desk.yaml 127 | # Multi Process: 128 | python slam_mp.py --config ./configs/tum/fr1_desk.yaml 129 | ``` 130 | 131 | ### 3.3 ScanNet++ 132 | ```bash 133 | # Single Process: Recommended, More Stable 134 | python slam.py --config ./configs/scannetpp/8b5caf3398.yaml 135 | # Multi Process: 136 | python slam_mp.py --config ./configs/scannetpp/8b5caf3398.yaml 137 | ``` 138 | 139 | ### 3.4 Ours 140 | ```bash 141 | # Single Process: Recommended, More Stable 142 | python slam.py --config ./configs/ours/hotel.yaml 143 | # Multi Process: 144 | python slam_mp.py --config ./configs/ours/hotel.yaml 145 | ``` 146 | 147 | ## 4. Evaluate 148 | You can run metric.py to evaluate rendering quality on the Replica, ScanNet++, and Ours datasets, and to compute geometry accuracy on Replica and ScanNet++. 149 | A CSV result file is written to the model path. 150 | Tracking accuracy is evaluated right after slam.py finishes; the ATE result is saved in model_path/save_traj. 
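For reference, the ATE reported there is the root-mean-square translational error after rigidly aligning the estimated trajectory to the ground truth. The sketch below only illustrates that computation on two already timestamp-associated N×3 position arrays; `scripts/eval_ate.py` remains the authoritative implementation, and the function name here is made up for the example.

```python
import numpy as np

def ate_rmse(gt_xyz: np.ndarray, est_xyz: np.ndarray) -> float:
    """RMSE of translational error after rigid (rotation + translation) alignment.

    Illustrative only; assumes gt_xyz and est_xyz are timestamp-associated (N, 3) arrays.
    """
    gt_mean, est_mean = gt_xyz.mean(axis=0), est_xyz.mean(axis=0)
    gt_c, est_c = gt_xyz - gt_mean, est_xyz - est_mean
    # Horn/Umeyama closed-form rotation from the SVD of the cross-covariance
    U, _, Vt = np.linalg.svd(est_c.T @ gt_c)
    S = np.eye(3)
    if np.linalg.det(U @ Vt) < 0:  # guard against reflections
        S[2, 2] = -1.0
    R = (U @ S @ Vt).T             # rotation mapping est -> gt
    t = gt_mean - R @ est_mean
    aligned = est_xyz @ R.T + t
    return float(np.sqrt(((aligned - gt_xyz) ** 2).sum(axis=1).mean()))
```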
151 | #### Note 152 | The script uses all images when computing PSNR, LPIPS, and SSIM (see the short sketch at the end of this section). Our method adds Gaussians according to depth, so performance may decrease in the presence of significant depth noise or invalid depth (such as transparent materials, highly reflective materials, etc.). For fairness, when evaluating novel view synthesis on ScanNet++ in the paper, we manually removed images with large invalid-depth areas. 153 | 154 | 155 | ```bash 156 | # evaluate only the first k frames 157 | # and save the rendered pictures 158 | python metric.py --config config_path \ 159 |     --load_frame k \ 160 |     --save_pic 161 | ``` 162 | 163 | ### 4.1 Replica 164 | ```bash 165 | python metric.py --config ./configs/replica/office0.yaml 166 | ``` 167 | ### 4.2 TUM-RGBD 168 | ```bash 169 | python metric.py --config ./configs/tum/fr1_desk.yaml 170 | ``` 171 | ### 4.3 ScanNet++ 172 | ```bash 173 | python metric.py --config ./configs/scannetpp/8b5caf3398.yaml # all novel view images 174 | ``` 175 | ### 4.4 Ours 176 | ```bash 177 | python metric.py --config ./configs/ours/hotel.yaml 178 | ``` 179 | 180 | 
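For reference, the reported image metrics follow the usual definitions: PSNR on [0, 1]-normalized images reduces to the snippet below, while SSIM and LPIPS come from `pytorch_msssim` and `torchmetrics` as imported in `SLAM/eval.py`. This is a minimal sketch rather than the repository's own `psnr` from `utils/loss_utils.py`.

```python
import torch

def psnr(img_pred: torch.Tensor, img_gt: torch.Tensor) -> torch.Tensor:
    """PSNR in dB, assuming both images are normalized to [0, 1]."""
    mse = ((img_pred - img_gt) ** 2).mean()
    return 20.0 * torch.log10(1.0 / torch.sqrt(mse))
```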
181 | ## BibTeX
182 | 
183 | ```bibtex
184 | @article{peng2024rtgslam,
185 |   author    = {Zhexi Peng and Tianjia Shao and Liu Yong and Jingke Zhou and Yin Yang and Jingdong Wang and Kun Zhou},
186 |   title     = {RTG-SLAM: Real-time 3D Reconstruction at Scale using Gaussian Splatting},
187 |   booktitle = {ACM SIGGRAPH Conference Proceedings, Denver, CO, United States, July 28 - August 1, 2024},
188 |   year      = {2024},
189 | }
190 | ```
191 | 192 | 193 | ## Acknowledgments 194 | This project is built upon [3DGS](https://github.com/graphdeco-inria/gaussian-splatting). The ORB-SLAM backend is based on [ORB-SLAM2](https://github.com/raulmur/ORB_SLAM2). The ORB-SLAM2 Python Binding is based on [ORB_SLAM2-PythonBindings](https://github.com/jskinn/ORB_SLAM2-PythonBindings). The evaluation script is adopted from [NICE-SLAM](https://github.com/cvg/nice-slam) and [Point-SLAM](https://github.com/eriksandstroem/Point-SLAM). We thank all the authors for their great work. 195 | -------------------------------------------------------------------------------- /SLAM/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MisEty/RTG-SLAM/15ac7e3de5bdffd06e651d5a65435a5b1ad82173/SLAM/__init__.py -------------------------------------------------------------------------------- /SLAM/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import numpy as np 5 | import open3d as o3d 6 | import torch 7 | import torch.nn.functional as F 8 | import torchvision 9 | 10 | from scene.cameras import Camera 11 | from SLAM.utils import * 12 | from utils.loss_utils import l1_loss, ssim, psnr 13 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 14 | import trimesh 15 | from pytorch_msssim import ms_ssim 16 | from scipy.spatial import cKDTree as KDTree 17 | from tqdm import tqdm 18 | 19 | def eval_ssim(image_es, image_gt): 20 | return ms_ssim( 21 | image_es.unsqueeze(0), 22 | image_gt.unsqueeze(0), 23 | data_range=1.0, 24 | size_average=True, 25 | ) 26 | 27 | 28 | loss_fn_alex = LearnedPerceptualImagePatchSimilarity( 29 | net_type="alex", normalize=True 30 | ).cuda() 31 | 32 | depth_error_max = 0.08 33 | transmission_max = 0.2 34 | color_hit_weight_max = 1 35 | depth_hit_weight_max = 1 36 | 37 | 38 | def eval_picture( 39 | render_output, 40 | frame: Camera, 41 | save_path, 42 | min_depth, 43 | max_depth, 44 | save_picture 45 | ): 46 | move_to_gpu(frame) 47 | image, depth, normal, index = ( 48 | render_output["render"], 49 | render_output["depth"], 50 | render_output["normal"], 51 | render_output["depth_index_map"], 52 | ) 53 | color_hit_weight, depth_hit_weight, T_map = ( 54 | render_output["color_hit_weight"], 55 | render_output["depth_hit_weight"], 56 | render_output["T_map"], 57 | ) 58 | # check color map 59 | gt_image = frame.original_image 60 | image_error = (gt_image - image).abs() 61 | # check others 62 | psnr_value = psnr(gt_image, image).mean() 63 | ssim_value = eval_ssim(image, gt_image).mean() 64 | lpips_value = loss_fn_alex( 65 | torch.clamp(gt_image.unsqueeze(0), 0.0, 1.0), 66 | torch.clamp(image.unsqueeze(0), 0.0, 1.0), 67 | ).item() 68 | 69 | color_loss = l1_loss(gt_image, image) 70 | 71 | if save_picture: 72 | image_concat = torch.concat([image, gt_image, image_error], dim=-1) 73 | torchvision.utils.save_image( 74 | image_concat, 75 | os.path.join(save_path, "color_compare.jpg"), 76 | ) 77 | 78 | # check depth map 79 | gt_depth = 255.0 * frame.original_depth 80 | valid_range_mask = (gt_depth > min_depth) & (gt_depth < max_depth) 81 | gt_depth[~valid_range_mask] = 0 82 | 83 | depth_error = (gt_depth - depth).abs() 84 | invalid_depth_mask = (index == -1) | (gt_depth == 0) 85 | depth_error[invalid_depth_mask] = 0 86 | 87 | valid_depth_mask = ~invalid_depth_mask 88 | pixel_num = depth.shape[1] * depth.shape[2] 89 | valid_pixel_ratio = valid_depth_mask.sum() / 
pixel_num 90 | depth_loss = l1_loss(depth[valid_depth_mask], gt_depth[valid_depth_mask]) 91 | 92 | if save_picture: 93 | min_depth = gt_depth[gt_depth > 0].min() 94 | max_depth = gt_depth[gt_depth > 0].max() 95 | colored_depth_render = color_value( 96 | depth, depth == 0, min_depth, max_depth, cv2.COLORMAP_INFERNO 97 | ) 98 | colored_depth_gt = color_value( 99 | gt_depth, gt_depth == 0, min_depth, max_depth, cv2.COLORMAP_INFERNO 100 | ) 101 | colored_depth_error = color_value( 102 | depth_error, invalid_depth_mask, 0.0, depth_error_max 103 | ) 104 | 105 | colored_depth_error = color_value( 106 | depth_error, invalid_depth_mask, 0, 0, cv2.COLORMAP_INFERNO 107 | ) 108 | colored_depth_error[:, (depth == 0)[0]] = 0 109 | depth_concat = torch.concat( 110 | [colored_depth_render, colored_depth_gt, colored_depth_error], dim=-1 111 | ) 112 | torchvision.utils.save_image( 113 | depth_concat, 114 | os.path.join(save_path, "depth_compare.jpg"), 115 | ) 116 | 117 | 118 | if save_picture: 119 | color_weight_color = color_value( 120 | color_hit_weight, None, 0, color_hit_weight_max, cv2.COLORMAP_JET 121 | ) 122 | depth_weight_color = color_value( 123 | depth_hit_weight, None, 0, depth_hit_weight_max, cv2.COLORMAP_JET 124 | ) 125 | T_color = color_value(T_map, None, 0, transmission_max, cv2.COLORMAP_JET) 126 | torchvision.utils.save_image( 127 | torch.concat([color_weight_color, depth_weight_color, T_color], dim=-1), 128 | os.path.join(save_path, "weight_compare.png"), 129 | ) 130 | 131 | normal_loss = torch.tensor(0) 132 | # save log 133 | log_info = "valid pixel ratio={:.2%}, color loss={:.3f}, depth loss={:.3f}cm, normal loss={:.3f}, psnr={:.3f}".format( 134 | valid_pixel_ratio, color_loss, depth_loss * 100, normal_loss, psnr_value 135 | ) 136 | print(log_info) 137 | losses = { 138 | "valid_pixel_ratio": valid_pixel_ratio.item(), 139 | "depth_loss": depth_loss.item(), 140 | "normal_loss": normal_loss.item(), 141 | "psnr": psnr_value.item(), 142 | "ssim": ssim_value.item(), 143 | "lpips": lpips_value, 144 | } 145 | move_to_cpu(frame) 146 | 147 | return losses 148 | 149 | def completion_ratio(gt_points, rec_points, dist_th=0.03): 150 | gen_points_kd_tree = KDTree(rec_points) 151 | distances, _ = gen_points_kd_tree.query(gt_points) 152 | comp_ratio = np.mean((distances < dist_th).astype(np.float32)) 153 | return comp_ratio 154 | 155 | 156 | def accuracy_ratio(gt_points, rec_points, dist_th=0.03): 157 | gt_points_kd_tree = KDTree(gt_points) 158 | distances, _ = gt_points_kd_tree.query(rec_points) 159 | acc_ratio = np.mean((distances < dist_th).astype(np.float32)) 160 | return acc_ratio 161 | 162 | 163 | def accuracy(gt_points, rec_points): 164 | gt_points_kd_tree = KDTree(gt_points) 165 | distances, _ = gt_points_kd_tree.query(rec_points) 166 | acc = np.mean(distances) 167 | return acc 168 | 169 | 170 | def completion(gt_points, rec_points): 171 | gt_points_kd_tree = KDTree(rec_points) 172 | distances, _ = gt_points_kd_tree.query(gt_points) 173 | comp = np.mean(distances) 174 | return comp 175 | 176 | def eval_pcd( 177 | rec_meshfile, gt_meshfile, dist_thres=[0.03], transform=np.eye(4), 178 | sample_nums = 1000000 179 | ): 180 | """ 181 | 3D reconstruction metric. 
182 | 183 | """ 184 | mesh_gt = trimesh.load(gt_meshfile, process=False) 185 | bbox = np.zeros([2, 3]) 186 | bbox[0] = mesh_gt.vertices.min(axis=0) - 0.05 187 | bbox[1] = mesh_gt.vertices.max(axis=0) + 0.05 188 | rec_pc = o3d.io.read_point_cloud(rec_meshfile) 189 | rec_pc.transform(transform) 190 | points = np.asarray(rec_pc.points) 191 | P = points.shape[0] 192 | points = points[np.random.choice(P, min(P, sample_nums), replace=False), :] 193 | rec_pc_tri = trimesh.PointCloud(vertices=points) 194 | 195 | gt_pc = trimesh.sample.sample_surface(mesh_gt, sample_nums) 196 | gt_pc_tri = trimesh.PointCloud(vertices=gt_pc[0]) 197 | print("compute acc") 198 | accuracy_rec = accuracy(gt_pc_tri.vertices, rec_pc_tri.vertices) 199 | print("compute comp") 200 | completion_rec = completion(gt_pc_tri.vertices, rec_pc_tri.vertices) 201 | Ps = {} 202 | Rs = {} 203 | Fs = {} 204 | for thre in tqdm(dist_thres): 205 | P = accuracy_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices, dist_th=thre) * 100 206 | R = ( 207 | completion_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices, dist_th=thre) 208 | * 100 209 | ) 210 | F1 = 2 * P * R / (P + R) 211 | Ps["P (< {})".format(thre)] = P 212 | Rs["R (< {})".format(thre)] = R 213 | Fs["F1 (< {})".format(thre)] = F1 214 | accuracy_rec *= 100 # convert to cm 215 | completion_rec *= 100 # convert to cm 216 | results = { 217 | "accuracy": accuracy_rec, 218 | "completion": completion_rec, 219 | } 220 | results.update(Ps) 221 | results.update(Rs) 222 | results.update(Fs) 223 | return results 224 | 225 | 226 | def eval_frame( 227 | mapping, 228 | cam, 229 | dir_name, 230 | run_picture=True, 231 | run_pcd=False, 232 | min_depth=0.5, 233 | max_depth=3.0, 234 | pcd_path=None, 235 | gt_mesh_path=None, 236 | dist_threshs=[0.03], 237 | sample_nums=1000000, 238 | pcd_transform=np.eye(4), 239 | save_picture=False, 240 | ): 241 | with torch.no_grad(): 242 | # save render 243 | frame_name = "frame_{:04d}".format(mapping.time) 244 | render_save_path = os.path.join(dir_name, frame_name) 245 | losses = {} 246 | if run_picture: 247 | os.makedirs(render_save_path, exist_ok=True) 248 | with torch.no_grad(): 249 | render_output = mapping.renderer.render( 250 | cam, mapping.global_params 251 | ) 252 | 253 | pic_loss = eval_picture( 254 | render_output, 255 | cam, 256 | render_save_path, 257 | min_depth, 258 | max_depth, 259 | save_picture, 260 | ) 261 | losses.update(pic_loss) 262 | 263 | 264 | if run_pcd and pcd_path is not None and gt_mesh_path is not None: 265 | os.makedirs(render_save_path, exist_ok=True) 266 | pcd_losses = eval_pcd( 267 | pcd_path, 268 | gt_mesh_path, 269 | dist_threshs, 270 | pcd_transform, 271 | sample_nums 272 | ) 273 | losses.update(pcd_losses) 274 | return losses 275 | -------------------------------------------------------------------------------- /SLAM/gaussian_pointcloud.py: -------------------------------------------------------------------------------- 1 | import os 2 | from plyfile import PlyData, PlyElement 3 | from simple_knn._C import distCUDA2 4 | from torch import nn 5 | 6 | from SLAM.utils import * 7 | from utils.general_utils import ( 8 | build_rotation, 9 | build_covariance_from_scaling_rotation, 10 | inverse_sigmoid 11 | ) 12 | from utils.sh_utils import RGB2SH, SH2RGB 13 | 14 | 15 | class GaussianPointCloud(object): 16 | def setup_functions(self): 17 | self.scaling_activation = torch.exp 18 | self.scaling_inverse_activation = torch.log 19 | 20 | self.covariance_activation = build_covariance_from_scaling_rotation 21 | 22 | self.opacity_activation = 
torch.sigmoid 23 | self.inverse_opacity_activation = inverse_sigmoid 24 | 25 | self.rotation_activation = torch.nn.functional.normalize 26 | 27 | def __init__(self, args) -> None: 28 | # gaussian optimize parameters 29 | self._xyz = devF(torch.empty(0)) 30 | self._features_dc = devF(torch.empty(0)) 31 | self._features_rest = devF(torch.empty(0)) 32 | self._scaling = devF(torch.empty(0)) 33 | self._rotation = devF(torch.empty(0)) 34 | self._opacity = devF(torch.empty(0)) 35 | # map management paramters 36 | self._normal = devF(torch.empty(0)) 37 | self._confidence = devF(torch.empty(0)) 38 | self._add_tick = devI(torch.empty(0)) 39 | # error counter 40 | self._depth_error_counter = devI(torch.empty(0)) 41 | self._color_error_counter = devI(torch.empty(0)) 42 | 43 | self.init_opacity = args.init_opacity 44 | self.scale_factor = args.scale_factor 45 | self.min_radius = args.min_radius 46 | self.max_radius = args.max_radius 47 | self.max_sh_degree = args.max_sh_degree 48 | self.active_sh_degree = args.active_sh_degree 49 | assert self.active_sh_degree <= self.max_sh_degree 50 | self.xyz_factor = devF(torch.tensor(args.xyz_factor)) 51 | self.setup_functions() 52 | 53 | def densify(self, sigma, circle_num, levels): 54 | means3D = self._xyz 55 | normal = self.get_normal 56 | normal = normal.cpu() 57 | plane0, plane1, axis0, axis1 = self.get_plane 58 | plane0 = plane0.cpu() 59 | plane1 = plane1.cpu() 60 | axis0 = axis0.cpu() 61 | axis1 = axis1.cpu() 62 | P = normal.shape[0] 63 | 64 | # generate theta 65 | theta = torch.rand(1, circle_num) * torch.pi * 2 66 | theta = theta.repeat(1, levels * sigma) 67 | 68 | a_random = None 69 | b_random = None 70 | a_random_ = torch.ones(P, circle_num * levels) * axis0 * sigma 71 | b_random_ = torch.ones(P, circle_num * levels) * axis1 * sigma 72 | # normal = normal.repeat([sample_num]) 73 | for level in range(levels): 74 | a_random_[:, level * circle_num : (level + 1) * circle_num] *= ( 75 | level + 0.5 76 | ) / levels 77 | b_random_[:, level * circle_num : (level + 1) * circle_num] *= ( 78 | level + 0.5 79 | ) / levels 80 | 81 | for sigma_ in range(sigma): 82 | if a_random is None: 83 | a_random = a_random_ 84 | b_random = b_random_ 85 | else: 86 | # print(a_random.shape, a_random_.shape) 87 | a_random = torch.concat([a_random, a_random_ + axis0 * sigma_], dim=1) 88 | b_random = torch.concat([b_random, b_random_ + axis1 * sigma_], dim=1) 89 | 90 | x = a_random * torch.cos(theta) 91 | z = b_random * torch.sin(theta) 92 | 93 | xyz = torch.concat( 94 | [x[..., None], torch.zeros_like(x)[..., None], z[..., None]], dim=-1 95 | ).unsqueeze(-1) 96 | rotation = ( 97 | torch.stack([plane0, normal, plane1], dim=-1) 98 | .permute(0, 2, 1) 99 | .unsqueeze(1) 100 | .repeat(1, circle_num * levels * sigma, 1, 1) 101 | ) 102 | 103 | xyz = xyz.cpu() 104 | rotation = rotation.cpu() 105 | xyz = torch.matmul(rotation, xyz).squeeze(-1) 106 | means3D = means3D.cpu() 107 | xyz += means3D.unsqueeze(1).repeat(1, circle_num * levels * sigma, 1) 108 | normal = normal[:, None, :].repeat([1, circle_num * levels * sigma, 1]) 109 | xyz = xyz.reshape(-1, 3) 110 | normal = normal.reshape(-1, 3) 111 | 112 | pcd = o3d.geometry.PointCloud() 113 | pcd.points = o3d.utility.Vector3dVector(xyz.cpu().numpy()) 114 | pcd.normals = o3d.utility.Vector3dVector(normal.cpu().numpy()) 115 | 116 | return pcd 117 | 118 | def load(self, ply_path): 119 | plydata = PlyData.read(ply_path) 120 | xyz = np.stack( 121 | ( 122 | np.asarray(plydata.elements[0]["x"]), 123 | np.asarray(plydata.elements[0]["y"]), 124 | 
np.asarray(plydata.elements[0]["z"]), 125 | ), 126 | axis=1, 127 | ) 128 | P = xyz.shape[0] 129 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 130 | if "confidence" in plydata.elements[0]: 131 | confidences = np.asarray(plydata.elements[0]["confidence"])[..., np.newaxis] 132 | else: 133 | confidences = np.zeros((P, 1)) 134 | 135 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 136 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 137 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 138 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) 139 | extra_f_names = [ 140 | p.name 141 | for p in plydata.elements[0].properties 142 | if p.name.startswith("f_rest_") 143 | ] 144 | extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1])) 145 | assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3 146 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) 147 | for idx, attr_name in enumerate(extra_f_names): 148 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) 149 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) 150 | features_extra = features_extra.reshape( 151 | (features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1) 152 | ) 153 | scale_names = [ 154 | p.name 155 | for p in plydata.elements[0].properties 156 | if p.name.startswith("scale_") 157 | ] 158 | scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1])) 159 | scales = np.zeros((xyz.shape[0], len(scale_names))) 160 | for idx, attr_name in enumerate(scale_names): 161 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 162 | rot_names = [ 163 | p.name for p in plydata.elements[0].properties if p.name.startswith("rot") 164 | ] 165 | rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1])) 166 | rots = np.zeros((xyz.shape[0], len(rot_names))) 167 | for idx, attr_name in enumerate(rot_names): 168 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 169 | self._xyz = torch.tensor(xyz, dtype=torch.float, device="cuda") 170 | self._features_dc = ( 171 | torch.tensor(features_dc, dtype=torch.float, device="cuda") 172 | .transpose(1, 2) 173 | .contiguous() 174 | ) 175 | 176 | self._features_rest = ( 177 | torch.tensor(features_extra, dtype=torch.float, device="cuda") 178 | .transpose(1, 2) 179 | .contiguous() 180 | ) 181 | self._opacity = torch.tensor(opacities, dtype=torch.float, device="cuda") 182 | self._scaling = torch.tensor(scales, dtype=torch.float, device="cuda") 183 | self._rotation = torch.tensor(rots, dtype=torch.float, device="cuda") 184 | self._normal = self.get_normal 185 | self._confidence = torch.tensor(confidences, dtype=torch.float, device="cuda") 186 | 187 | self._add_tick = torch.zeros([P, 1], dtype=torch.int32, device="cuda") 188 | self._depth_error_counter = torch.zeros( 189 | [P, 1], dtype=torch.int32, device="cuda" 190 | ) 191 | self._color_error_counter = torch.zeros( 192 | [P, 1], dtype=torch.int32, device="cuda" 193 | ) 194 | 195 | def delete(self, delte_mask): 196 | self._xyz = self._xyz[~delte_mask] 197 | self._features_dc = self._features_dc[~delte_mask] 198 | self._features_rest = self._features_rest[~delte_mask] 199 | self._scaling = self._scaling[~delte_mask] 200 | self._rotation = self._rotation[~delte_mask] 201 | self._opacity = self._opacity[~delte_mask] 202 | self._normal = self._normal[~delte_mask] 203 | self._confidence = self._confidence[~delte_mask] 204 | self._add_tick = self._add_tick[~delte_mask] 205 | 
self._depth_error_counter = self._depth_error_counter[~delte_mask] 206 | self._color_error_counter = self._color_error_counter[~delte_mask] 207 | 208 | def remove(self, remove_mask): 209 | xyz = self._xyz[remove_mask] 210 | features_dc = self._features_dc[remove_mask] 211 | features_rest = self._features_rest[remove_mask] 212 | scaling = self._scaling[remove_mask] 213 | rotation = self._rotation[remove_mask] 214 | opacity = self._opacity[remove_mask] 215 | normal = self._normal[remove_mask] 216 | confidence = self._confidence[remove_mask] 217 | add_tick = self._add_tick[remove_mask] 218 | depth_error_counter = self._depth_error_counter[remove_mask] 219 | color_error_counter = self._color_error_counter[remove_mask] 220 | 221 | gaussian_params = { 222 | "xyz": xyz, 223 | "features_dc": features_dc, 224 | "features_rest": features_rest, 225 | "scaling": scaling, 226 | "rotation": rotation, 227 | "opacity": opacity, 228 | "normal": normal, 229 | "confidence": confidence, 230 | "add_tick": add_tick, 231 | "depth_error_counter": depth_error_counter, 232 | "color_error_counter": color_error_counter, 233 | } 234 | self.delete(remove_mask) 235 | return gaussian_params 236 | 237 | def detach(self): 238 | self._xyz = self._xyz.detach() 239 | self._features_dc = self._features_dc.detach() 240 | self._features_rest = self._features_rest.detach() 241 | self._scaling = self._scaling.detach() 242 | self._rotation = self._rotation.detach() 243 | self._opacity = self._opacity.detach() 244 | 245 | def parametrize(self, update_args): 246 | self._xyz = nn.Parameter(self._xyz.requires_grad_(True)) 247 | self._features_dc = nn.Parameter(self._features_dc.requires_grad_(True)) 248 | self._features_rest = nn.Parameter(self._features_rest.requires_grad_(True)) 249 | self._scaling = nn.Parameter(self._scaling.requires_grad_(True)) 250 | self._rotation = nn.Parameter(self._rotation.requires_grad_(True)) 251 | self._opacity = nn.Parameter(self._opacity.requires_grad_(True)) 252 | l = [ 253 | { 254 | "params": [self._xyz], 255 | "lr": update_args.position_lr, 256 | "name": "xyz", 257 | }, 258 | { 259 | "params": [self._features_dc], 260 | "lr": update_args.feature_lr, 261 | "name": "f_dc", 262 | }, 263 | { 264 | "params": [self._features_rest], 265 | "lr": update_args.feature_lr / 20.0, 266 | "name": "f_rest", 267 | }, 268 | { 269 | "params": [self._opacity], 270 | "lr": update_args.opacity_lr, 271 | "name": "opacity", 272 | }, 273 | { 274 | "params": [self._scaling], 275 | "lr": update_args.scaling_lr, 276 | "name": "scaling", 277 | }, 278 | { 279 | "params": [self._rotation], 280 | "lr": update_args.rotation_lr, 281 | "name": "rotation", 282 | }, 283 | ] 284 | return l 285 | 286 | def cat(self, paramters): 287 | self._xyz = torch.cat([self._xyz, paramters["xyz"]]) 288 | self._features_dc = torch.cat([self._features_dc, paramters["features_dc"]]) 289 | self._features_rest = torch.cat( 290 | [self._features_rest, paramters["features_rest"]] 291 | ) 292 | self._scaling = torch.cat([self._scaling, paramters["scaling"]]) 293 | self._rotation = torch.cat([self._rotation, paramters["rotation"]], dim=0) 294 | self._opacity = torch.cat([self._opacity, paramters["opacity"]]) 295 | self._confidence = torch.cat([self._confidence, paramters["confidence"]]) 296 | self._normal = self.get_normal 297 | self._add_tick = torch.cat([self._add_tick, paramters["add_tick"]]) 298 | self._depth_error_counter = torch.cat( 299 | [self._depth_error_counter, paramters["depth_error_counter"]] 300 | ) 301 | self._color_error_counter = torch.cat( 
302 | [self._color_error_counter, paramters["color_error_counter"]] 303 | ) 304 | 305 | def add_empty_points(self, xyz, normal, color, time): 306 | """ 307 | :param xyz: [N, 3] 308 | :param normal: [N, 3] 309 | :param color: [N, 3] 310 | """ 311 | # preprocess 312 | assert xyz.shape[0] == color.shape[0] and color.shape[0] == normal.shape[0] 313 | if xyz.shape[0] < 1: 314 | return 315 | mag = l2_norm(normal) 316 | normal = normal / (mag + 1e-8) 317 | valid_normal_mask = normal.sum(dim=-1) != 0 318 | xyz = xyz[valid_normal_mask] 319 | normal = normal[valid_normal_mask] 320 | color = color[valid_normal_mask] 321 | points_num = xyz.shape[0] 322 | # compute SH feature 323 | features = devF(torch.zeros((points_num, 3, (self.max_sh_degree + 1) ** 2))) 324 | sh_color = RGB2SH(color) 325 | features[:, :3, 0] = sh_color 326 | features[:, 3:, 1:] = 0.0 327 | # init scale and rot 328 | raw_scales = devF(torch.ones(points_num, 3)) * 1e-6 329 | scales = torch.log(raw_scales) 330 | if ( 331 | self.xyz_factor[0] == 1 332 | and self.xyz_factor[1] == 1 333 | and self.xyz_factor[2] == 1 334 | ): 335 | rots = devF(torch.zeros((points_num, 4))) 336 | rots[:, 0] = 1 337 | else: 338 | z_axis = devF(torch.tensor([0, 0, 1]).repeat(points_num, 1)) 339 | rots = compute_rot(z_axis, normal) 340 | # init opacity 341 | opacities = inverse_sigmoid( 342 | self.init_opacity * devF(torch.ones((points_num, 1))) 343 | ) 344 | # init other flags 345 | confidence = devF(torch.zeros([points_num, 1])) 346 | add_tick = time * devI(torch.ones([points_num, 1])) 347 | 348 | depth_error_counter = devI(torch.zeros([points_num, 1])) 349 | color_error_counter = devI(torch.zeros([points_num, 1])) 350 | 351 | add_params = { 352 | "xyz": xyz, 353 | "features_dc": features[..., 0:1].transpose(1, 2).contiguous(), 354 | "features_rest": features[..., 1:].transpose(1, 2).contiguous(), 355 | "scaling": scales, 356 | "rotation": rots, 357 | "opacity": opacities, 358 | "normal": normal, 359 | "confidence": confidence, 360 | "add_tick": add_tick, 361 | "depth_error_counter": depth_error_counter, 362 | "color_error_counter": color_error_counter, 363 | } 364 | self.cat(add_params) 365 | 366 | def update_geometry(self, extra_xyz, extra_radius): 367 | xyz = self.get_xyz 368 | radius = self.get_radius 369 | points_num = self.get_points_num 370 | if torch.numel(extra_xyz) > 0: 371 | inbbox_mask = bbox_filter(xyz, extra_xyz) 372 | extra_xyz = extra_xyz[inbbox_mask] 373 | extra_radius = extra_radius[inbbox_mask] 374 | total_xyz = torch.cat([xyz, extra_xyz]) 375 | total_radius = torch.cat([radius, extra_radius]) 376 | _, knn_indices = distCUDA2(total_xyz.float().cuda()) 377 | knn_indices = knn_indices[:points_num].long() 378 | dist_0 = ( 379 | torch.norm(xyz - total_xyz[knn_indices[:, 0]], p=2, dim=1) 380 | - 3 * total_radius[knn_indices[:, 0]] 381 | ) 382 | dist_1 = ( 383 | torch.norm(xyz - total_xyz[knn_indices[:, 1]], p=2, dim=1) 384 | - 3 * total_radius[knn_indices[:, 1]] 385 | ) 386 | dist_2 = ( 387 | torch.norm(xyz - total_xyz[knn_indices[:, 2]], p=2, dim=1) 388 | - 3 * total_radius[knn_indices[:, 2]] 389 | ) 390 | invalid_dist_0 = dist_0 < 0 391 | invalid_dist_1 = dist_1 < 0 392 | invalid_dist_2 = dist_2 < 0 393 | 394 | invalid_scale_mask = invalid_dist_0 | invalid_dist_1 | invalid_dist_2 395 | dist2 = (dist_0**2 + dist_1**2 + dist_2**2) / 3 396 | scales = torch.sqrt(dist2) 397 | scales = torch.clip(scales, min=self.min_radius, max=self.max_radius) 398 | if (~invalid_scale_mask).sum() == 0: 399 | self.delete(invalid_scale_mask) 400 | else: 401 | 
scales = scales[..., None].repeat(1, 3) 402 | factor_scales = self.scale_factor * torch.mul(scales, self.xyz_factor) 403 | log_scales = torch.log(factor_scales) 404 | self._scaling = log_scales 405 | self.delete(invalid_scale_mask) 406 | 407 | def construct_list_of_attributes(self, include_confidence=True): 408 | l = ["x", "y", "z", "nx", "ny", "nz"] 409 | # All channels except the 3 DC 410 | for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]): 411 | l.append("f_dc_{}".format(i)) 412 | for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]): 413 | l.append("f_rest_{}".format(i)) 414 | l.append("opacity") 415 | for i in range(self._scaling.shape[1]): 416 | l.append("scale_{}".format(i)) 417 | for i in range(self._rotation.shape[1]): 418 | l.append("rot_{}".format(i)) 419 | if include_confidence: 420 | l.append("confidence") 421 | return l 422 | 423 | def save_model_ply(self, path, include_confidence=True): 424 | if self.get_points_num == 0: 425 | return 426 | xyz = self._xyz.detach().cpu().numpy() 427 | normals = np.zeros_like(xyz) 428 | f_dc = ( 429 | self._features_dc.detach() 430 | .transpose(1, 2) 431 | .flatten(start_dim=1) 432 | .contiguous() 433 | .cpu() 434 | .numpy() 435 | ) 436 | f_rest = ( 437 | self._features_rest.detach() 438 | .transpose(1, 2) 439 | .flatten(start_dim=1) 440 | .contiguous() 441 | .cpu() 442 | .numpy() 443 | ) 444 | opacities = self._opacity.detach().cpu().numpy() 445 | scale = self._scaling.detach().cpu().numpy() 446 | rotation = self._rotation.detach().cpu().numpy() 447 | confidence = self._confidence.detach().cpu().numpy() 448 | 449 | dtype_full = [ 450 | (attribute, "f4") 451 | for attribute in self.construct_list_of_attributes(include_confidence) 452 | ] 453 | 454 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 455 | if include_confidence: 456 | attributes = np.concatenate( 457 | (xyz, normals, f_dc, f_rest, opacities, scale, rotation, confidence), 458 | axis=1, 459 | ) 460 | else: 461 | attributes = np.concatenate( 462 | (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1 463 | ) 464 | elements[:] = list(map(tuple, attributes)) 465 | el = PlyElement.describe(elements, "vertex") 466 | PlyData([el]).write(path) 467 | 468 | def save_color_ply(self, path): 469 | os.makedirs(os.path.dirname(path), exist_ok=True) 470 | xyz = self._xyz.detach().cpu().numpy() 471 | f_dc = ( 472 | self._features_dc.detach() 473 | .transpose(1, 2) 474 | .flatten(start_dim=1) 475 | .contiguous() 476 | .cpu() 477 | .numpy() 478 | ) 479 | # save color ply 480 | elements = np.empty( 481 | xyz.shape[0], 482 | dtype=[ 483 | ("x", "f4"), 484 | ("y", "f4"), 485 | ("z", "f4"), 486 | ("red", "u1"), 487 | ("green", "u1"), 488 | ("blue", "u1"), 489 | ], 490 | ) 491 | color = SH2RGB(f_dc.reshape(-1, 3)) * 255 492 | attributes = np.concatenate((xyz, color), axis=1) 493 | elements[:] = list(map(tuple, attributes)) 494 | el = PlyElement.describe(elements, "vertex") 495 | file_name = os.path.basename(path) 496 | file_base = os.path.dirname(path) 497 | color_name = file_name.split(".") 498 | color_name = color_name[0] + "_color" + "." 
+ color_name[1] 499 | color_path = os.path.join(file_base, color_name) 500 | PlyData([el]).write(color_path) 501 | 502 | @property 503 | def get_xyz(self): 504 | return self._xyz 505 | 506 | @property 507 | def get_points_num(self): 508 | return self._xyz.shape[0] 509 | 510 | @property 511 | def get_scaling(self): 512 | return self.scaling_activation(self._scaling) 513 | 514 | @property 515 | def get_radius(self): 516 | scales = self.get_scaling 517 | min_length, _ = torch.min(scales, dim=1) 518 | radius = (torch.sum(scales, dim=1) - min_length) / 2 519 | return radius 520 | 521 | @property 522 | def get_rotation(self): 523 | return self.rotation_activation(self._rotation) 524 | 525 | @property 526 | def get_R(self): 527 | return build_rotation(self.rotation_activation(self._rotation)) 528 | 529 | def get_covariance(self, scaling_modifier=1): 530 | return self.covariance_activation( 531 | self.get_scaling, scaling_modifier, self.get_rotation 532 | ) 533 | 534 | @property 535 | def get_xyz(self): 536 | return self._xyz 537 | 538 | @property 539 | def get_normal(self): 540 | scales = self.get_scaling 541 | R = self.get_R 542 | min_indices = torch.argmin(scales, dim=1) 543 | normal = torch.gather( 544 | R.transpose(1, 2), 545 | 1, 546 | min_indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, 3), 547 | ) 548 | normal = normal[:, 0, :] 549 | mag = l2_norm(normal) 550 | return normal / (mag + 1e-8) 551 | 552 | @property 553 | def get_plane(self): 554 | scales = self.get_scaling 555 | R = self.get_R 556 | plane_indices = scales.argsort(dim=1)[:, 1:] 557 | plane0 = torch.gather( 558 | R.transpose(1, 2), 559 | 1, 560 | plane_indices[:, 0].unsqueeze(1).unsqueeze(2).expand(-1, -1, 3), 561 | )[:, 0, :] 562 | plane1 = torch.gather( 563 | R.transpose(1, 2), 564 | 1, 565 | plane_indices[:, 1].unsqueeze(1).unsqueeze(2).expand(-1, -1, 3), 566 | )[:, 0, :] 567 | plane0 = plane0 / (l2_norm(plane0) + 1e-8) 568 | plane1 = plane1 / (l2_norm(plane1) + 1e-8) 569 | axis0 = torch.gather(scales, 1, plane_indices[:, 0:1]) 570 | axis1 = torch.gather(scales, 1, plane_indices[:, 1:]) 571 | return plane0, plane1, axis0, axis1 572 | 573 | @property 574 | def get_features(self): 575 | features_dc = self._features_dc 576 | features_rest = self._features_rest 577 | return torch.cat((features_dc, features_rest), dim=1) 578 | 579 | @property 580 | def get_opacity(self): 581 | return self.opacity_activation(self._opacity) 582 | 583 | @property 584 | def get_color(self): 585 | f_dc = self._features_dc.transpose(1, 2).flatten(start_dim=1).contiguous() 586 | color = SH2RGB(f_dc.reshape(-1, 3)) 587 | return color 588 | 589 | @property 590 | def get_confidence(self): 591 | return self._confidence 592 | 593 | @property 594 | def get_add_tick(self): 595 | return self._add_tick 596 | -------------------------------------------------------------------------------- /SLAM/icp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from SLAM.utils import * 6 | 7 | def point2plane_loss(p_t0, p_t1, n_t0, reduce="mean"): 8 | loss = ((p_t1 - p_t0) * n_t0).sum(dim=-1) 9 | if reduce == "mean": 10 | loss = (loss * loss).mean() 11 | else: 12 | loss = (loss * loss).sum() 13 | return loss 14 | 15 | 16 | class ICP(nn.Module): 17 | def __init__( 18 | self, 19 | max_iter=3, 20 | damping=1e-6, 21 | distance_threshold=0.2, 22 | normal_threshold=20, 23 | verbose=False, 24 | ): 25 | super(ICP, self).__init__() 26 | 27 | 
self.max_iterations = max_iter 28 | self.distance_threshold = distance_threshold 29 | self.normal_threshold = np.cos(np.deg2rad(normal_threshold)) 30 | self.damping = damping 31 | self.verbose = verbose 32 | 33 | def icp(self, pose10,vertex_t0,vertex_t1,normal_t0,normal_t1, K): 34 | mask0 = (vertex_t0[..., -1] > 0.0) 35 | 36 | for idx in range(self.max_iterations): 37 | # compute residuals 38 | residuals, J_F_p, valid_mask = self.compute_residuals_jacobian( 39 | vertex_t0, vertex_t1, normal_t0, normal_t1, mask0, pose10, K, 40 | self.distance_threshold, self.normal_threshold 41 | ) 42 | 43 | JtWJ = self.compute_jtj(J_F_p) # [B, 6, 6] 44 | JtR = self.compute_jtr(J_F_p, residuals) 45 | pose10 = self.GN_solver(JtWJ, JtR, pose10, damping=self.damping) 46 | H,W = vertex_t0.shape[:2] 47 | valid_ratio = valid_mask.sum() / H / W 48 | return pose10, valid_ratio 49 | 50 | 51 | @staticmethod 52 | def compute_residuals_jacobian(vertex0, vertex1, normal0, normal1, mask0, pose10, K, 53 | distance_threshold, normal_threshold): 54 | """ 55 | :param vertex0: vertex map 0 56 | :param vertex1: vertex map 1 57 | :param normal0: normal map 0 58 | :param normal1: normal map 1 59 | :param mask0: valid mask of template depth image 60 | :param pose10: current estimate of pose10 61 | :param K: intrinsics 62 | :return: residuals and Jacobians 63 | """ 64 | R = pose10[:3, :3] 65 | t = pose10[:3, 3] 66 | H, W, C = vertex0.shape 67 | 68 | rot_vertex0_to1 = (R @ vertex0.view(-1, 3).permute(1, 0)).permute(1, 0).view(H, W, 3) 69 | vertex0_to1 = rot_vertex0_to1 + t[None, None, :] 70 | normal0_to1 = (R @ normal0.view(-1, 3).permute(1, 0)).permute(1, 0).view(H, W, 3) 71 | 72 | fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] 73 | x_, y_, z_ = vertex0_to1[..., 0], vertex0_to1[..., 1], vertex0_to1[..., 2] # [h, w] 74 | u_ = (x_ / z_) * fx + cx # [h, w] 75 | v_ = (y_ / z_) * fy + cy # [h, w] 76 | 77 | inviews = (u_ > 0) & (u_ < W-1) & (v_ > 0) & (v_ < H-1) 78 | # projective data association 79 | r_vertex1 = warp_features(vertex1, u_, v_) # [h, w, 3] 80 | r_normal1 = warp_features(normal1, u_, v_) # [h, w, 3] 81 | mask1 = r_vertex1[..., -1] > 0. 82 | diff = vertex0_to1 - r_vertex1 # [h, w, 3] 83 | 84 | normal_diff_mask = torch.sum(normal0_to1 * r_normal1, dim=-1) > normal_threshold 85 | 86 | # point-to-plane residuals 87 | res = (r_normal1 * diff).sum(dim=-1) # [h, w] 88 | # point-to-plane jacobians 89 | J_trs = r_normal1.view(-1, 3) # [hw, 3] 90 | J_rot = -torch.bmm(J_trs.unsqueeze(dim=1), batch_skew(vertex0_to1.view(-1, 3))).squeeze() # [hw, 3] 91 | 92 | # compose jacobians 93 | J_F_p = torch.cat((J_rot, J_trs), dim=-1).view(H, W, 6) # follow the order of [rot, trs] [hw, 1, 6] 94 | 95 | # occlusion 96 | occ = ~inviews | (diff.norm(p=2, dim=-1) > distance_threshold) 97 | invalid_mask = occ | ~mask0 | ~mask1 | ~normal_diff_mask 98 | J_F_p[invalid_mask] = 0. 99 | res[invalid_mask] = 0. 
100 | 101 | res = res.view(-1, 1) # [hw, 1] 102 | J_F_p = J_F_p.view(-1, 1, 6) # [hw, 1, 6] 103 | 104 | return res, J_F_p, ~invalid_mask 105 | 106 | @staticmethod 107 | def compute_jtj(jac): 108 | # J in the dimension of (HW, C, 6) 109 | jacT = jac.transpose(-1, -2) # [HW, 6, C] 110 | jtj = torch.bmm(jacT, jac).sum(0) # [6, 6] 111 | return jtj # [6, 6] 112 | 113 | @staticmethod 114 | def compute_jtr(jac, res): 115 | # J in the dimension of (HW, C, 6) 116 | # res in the dimension of [HW, C] 117 | jacT = jac.transpose(-1, -2) # [HW, 6, C] 118 | jtr = torch.bmm(jacT, res.unsqueeze(-1)).sum(0) # [6, 1] 119 | return jtr # [6, 1] 120 | 121 | @staticmethod 122 | def GN_solver(JtJ, JtR, pose0, damping=1e-6): 123 | # Add a small diagonal damping. Without it, the training becomes quite unstable 124 | # Do not see a clear difference by removing the damping in inference though 125 | Hessian = lev_mar_H(JtJ, damping) 126 | # Hessian = JtJ 127 | updated_pose = forward_update_pose(Hessian, JtR, pose0) 128 | 129 | return updated_pose 130 | 131 | 132 | def warp_features(Feat, u, v, mode="nearest"): 133 | """ 134 | Warp the feature map (F) w.r.t. the grid (u, v). This is the non-batch version 135 | """ 136 | assert len(Feat.shape) == 3 137 | H, W, C = Feat.shape 138 | u_norm = u / ((W - 1) / 2) - 1 # [h, w] 139 | v_norm = v / ((H - 1) / 2) - 1 # [h, w] 140 | uv_grid = torch.cat((u_norm.view(1, H, W, 1), v_norm.view(1, H, W, 1)), dim=-1) 141 | Feat_warped = F.grid_sample( 142 | Feat.unsqueeze(0).permute(0, 3, 1, 2), 143 | uv_grid, 144 | mode=mode, 145 | padding_mode="border", 146 | align_corners=True, 147 | ).squeeze() 148 | return Feat_warped.permute(1, 2, 0) 149 | 150 | 151 | def compute_vertex(depth, K): 152 | H, W = depth.shape 153 | fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] 154 | device = depth.device 155 | 156 | i, j = torch.meshgrid( 157 | torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H) 158 | ) # pytorch's meshgrid has indexing='ij' 159 | i = i.t().to(device) # [h, w] 160 | j = j.t().to(device) # [h, w] 161 | 162 | vertex = ( 163 | torch.stack([(i - cx) / fx, (j - cy) / fy, torch.ones_like(i)], -1).to(device) 164 | * depth[..., None] 165 | ) # [h, w, 3] 166 | return vertex 167 | 168 | 169 | def compute_normal(vertex_map): 170 | """Calculate the normal map from a depth map 171 | :param the input depth image 172 | ----------- 173 | :return the normal map 174 | """ 175 | H, W, C = vertex_map.shape 176 | img_dx, img_dy = feature_gradient(vertex_map, normalize_gradient=False) # [h, w, 3] 177 | 178 | normal = torch.cross(img_dx.view(-1, 3), img_dy.view(-1, 3)) 179 | normal = normal.view(H, W, 3) # [h, w, 3] 180 | 181 | mag = torch.norm(normal, p=2, dim=-1, keepdim=True) 182 | normal = normal / (mag + 1e-8) 183 | 184 | # filter out invalid pixels 185 | depth = vertex_map[:, :, -1] 186 | # 0.5 and 5. 
187 | invalid_mask = (depth <= depth.min()) | (depth >= depth.max()) 188 | zero_normal = torch.zeros_like(normal) 189 | normal = torch.where(invalid_mask[..., None], zero_normal, normal) 190 | 191 | return normal 192 | 193 | 194 | def feature_gradient(img, normalize_gradient=True): 195 | """Calculate the gradient on the feature space using Sobel operator 196 | :param the input image 197 | ----------- 198 | :return the gradient of the image in x, y direction 199 | """ 200 | H, W, C = img.shape 201 | # to filter the image equally in each channel 202 | wx = ( 203 | torch.FloatTensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]) 204 | .view(1, 1, 3, 3) 205 | .type_as(img) 206 | ) 207 | wy = ( 208 | torch.FloatTensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]) 209 | .view(1, 1, 3, 3) 210 | .type_as(img) 211 | ) 212 | 213 | img_permuted = img.permute(2, 0, 1).view(-1, 1, H, W) # [c, 1, h, w] 214 | img_pad = F.pad(img_permuted, (1, 1, 1, 1), mode="replicate") 215 | img_dx = ( 216 | F.conv2d(img_pad, wx, stride=1, padding=0).squeeze().permute(1, 2, 0) 217 | ) # [h, w, c] 218 | img_dy = ( 219 | F.conv2d(img_pad, wy, stride=1, padding=0).squeeze().permute(1, 2, 0) 220 | ) # [h, w, c] 221 | 222 | if normalize_gradient: 223 | mag = torch.sqrt((img_dx**2) + (img_dy**2) + 1e-8) 224 | img_dx = img_dx / mag 225 | img_dy = img_dy / mag 226 | 227 | return img_dx, img_dy # [h, w, c] 228 | 229 | 230 | def batch_skew(w): 231 | """Generate a batch of skew-symmetric matrices. 232 | 233 | function tested in 'test_geometry.py' 234 | 235 | :input 236 | :param skew symmetric matrix entry Bx3 237 | --------- 238 | :return 239 | :param the skew-symmetric matrix Bx3x3 240 | """ 241 | B, D = w.shape 242 | assert D == 3 243 | o = torch.zeros(B).type_as(w) 244 | w0, w1, w2 = w[:, 0], w[:, 1], w[:, 2] 245 | return torch.stack((o, -w2, w1, w2, o, -w0, -w1, w0, o), 1).view(B, 3, 3) 246 | 247 | 248 | def lev_mar_H(JtWJ, damping): 249 | # Add a small diagonal damping. 
Without it, the training becomes quite unstable 250 | # Do not see a clear difference by removing the damping in inference though 251 | diag_mask = torch.eye(6).to(JtWJ) 252 | diagJtJ = diag_mask * JtWJ 253 | traceJtJ = torch.sum(diagJtJ) 254 | epsilon = (traceJtJ * damping) * diag_mask 255 | Hessian = JtWJ + epsilon 256 | return Hessian 257 | 258 | 259 | def forward_update_pose(H, Rhs, pose): 260 | """ 261 | :param H: 262 | :param Rhs: 263 | :param pose: 264 | :return: 265 | """ 266 | xi = least_square_solve(H, Rhs).squeeze() 267 | pose = exp_se3(xi) @ pose 268 | return pose 269 | 270 | 271 | def exp_se3(xi): 272 | """ 273 | :param x: Cartesian vector of Lie Algebra se(3) 274 | :return: exponential map of x 275 | """ 276 | w = xi[:3].squeeze() # rotation 277 | v = xi[3:6].squeeze() # translation 278 | w_hat = torch.tensor( 279 | [[0.0, -w[2], w[1]], [w[2], 0.0, -w[0]], [-w[1], w[0], 0.0]] 280 | ).to(xi) 281 | w_hat_second = torch.mm(w_hat, w_hat).to(xi) 282 | 283 | theta = torch.norm(w) 284 | theta_2 = theta**2 285 | theta_3 = theta**3 286 | sin_theta = torch.sin(theta) 287 | cos_theta = torch.cos(theta) 288 | eye_3 = torch.eye(3).to(xi) 289 | 290 | eps = 1e-8 291 | 292 | if theta <= eps: 293 | e_w = eye_3 294 | j = eye_3 295 | else: 296 | e_w = ( 297 | eye_3 298 | + w_hat * sin_theta / theta 299 | + w_hat_second * (1.0 - cos_theta) / theta_2 300 | ) 301 | k1 = (1 - cos_theta) / theta_2 302 | k2 = (theta - sin_theta) / theta_3 303 | j = eye_3 + k1 * w_hat + k2 * w_hat_second 304 | 305 | T = torch.eye(4).to(xi) 306 | T[:3, :3] = e_w 307 | T[:3, 3] = torch.mv(j, v) 308 | # T[:3, 3] = v 309 | 310 | return T 311 | 312 | 313 | def invH(H): 314 | """Generate (H+damp)^{-1}, with predicted damping values 315 | :param approximate Hessian matrix JtWJ 316 | ----------- 317 | :return the inverse of Hessian 318 | """ 319 | # GPU is much slower for matrix inverse when the size is small (compare to CPU) 320 | # works (50x faster) than inversing the dense matrix in GPU 321 | if H.is_cuda: 322 | invH = torch.inverse(H.cpu()).cuda() 323 | else: 324 | invH = torch.inverse(H) 325 | return invH 326 | 327 | 328 | def least_square_solve(H, Rhs): 329 | """ 330 | Solve for JTJ @ xi = -JTR 331 | """ 332 | inv_H = invH(H) # [B, 6, 6] square matrix 333 | xi = -inv_H @ Rhs 334 | return xi 335 | 336 | 337 | class ImagePyramids(nn.Module): 338 | """ Construct the pyramids in the image / depth space 339 | """ 340 | def __init__(self, scales, pool='avg'): 341 | super(ImagePyramids, self).__init__() 342 | if pool == 'avg': 343 | self.multiscales = [nn.AvgPool2d(1< self.icp_sample_normal_threshold 405 | depth_filling_mask = ( 406 | ( 407 | torch.abs(render_depth - frame_depth) 408 | > self.icp_sample_distance_threshold 409 | )[..., 0] 410 | | (render_depth == 0)[..., 0] 411 | | (normal_mask) 412 | ) & (frame_depth > 0)[..., 0] 413 | 414 | render_depth[depth_filling_mask] = frame_depth[depth_filling_mask] 415 | self.last_model_depth = render_depth 416 | 417 | def predict_pose(self, frame): 418 | K = frame["K"] 419 | frame_id = frame["frame_id"] 420 | if self.vertex_pyramid_t0 is None: 421 | pose_t1_t0 = np.eye(4) 422 | self.K = K 423 | else: 424 | if self.icp_use_model_depth and frame_id >= self.icp_warmup_frames: 425 | self.vertex_pyramid_t0 = build_vertex_pyramid(self.last_model_depth, self.depth_pyramid_builder, self.K) 426 | self.normal_pyramid_t0 = build_normal_pyramid(self.vertex_pyramid_t0) 427 | 428 | pose_t1_t0 = devF(torch.from_numpy(np.eye(4))) 429 | levels = len(self.icp_downscales) 430 | for level in 
range(levels): 431 | downscale = self.icp_downscales[level] 432 | K_downscale = K * downscale 433 | K_downscale[2,2] = 1.0 434 | vertex_t0 = self.vertex_pyramid_t0[level] 435 | vertex_t1 = self.vertex_pyramid_t1[level] 436 | normal_t0 = self.normal_pyramid_t0[level] 437 | normal_t1 = self.normal_pyramid_t1[level] 438 | pose_t1_t0, valid_ratio = self.icp_trackers[level].icp( 439 | pose_t1_t0, vertex_t1, vertex_t0, normal_t1, normal_t0, 440 | K_downscale 441 | ) 442 | pose_t1_t0 = pose_t1_t0.cpu().numpy() 443 | pose_t1_t0_pytorch = torch.tensor(pose_t1_t0).cuda().float() 444 | p2ploss = point2plane_loss(self.vertex_pyramid_t0[-1], 445 | self.vertex_pyramid_t1[-1] @ pose_t1_t0_pytorch[:3,:3].T + pose_t1_t0_pytorch[:3, 3], 446 | self.normal_pyramid_t0[-1], 447 | ) 448 | print(p2ploss, valid_ratio) 449 | tracking_success = True 450 | if p2ploss > self.icp_fail_threshold: 451 | tracking_success = False 452 | return pose_t1_t0, tracking_success 453 | -------------------------------------------------------------------------------- /SLAM/multiprocess/system.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import torch.multiprocessing as mp 4 | import os 5 | from SLAM.multiprocess.mapper import MappingProcess 6 | from SLAM.multiprocess.tracker import TrackingProcess 7 | from SLAM.utils import merge_ply 8 | 9 | sleep_time = 0.01 10 | 11 | 12 | class SLAM(object): 13 | def __init__(self, map_params, optimization_params, dataset, args) -> None: 14 | 15 | self.verbose = True 16 | self._end = torch.zeros((2)).int().share_memory_() 17 | self.dataset = dataset 18 | 19 | # strict: mapping : tracker == 1 : sync_tracker2mapper_frames 20 | # loose: tracker frame_id should be: [mapper_frame_id - sync_tracker2mapper_frames, 21 | # mapper_frame_id + sync_tracker2mapper_frames] 22 | # free: there is no sync 23 | self.sync_tracker2mapper_method = map_params.sync_tracker2mapper_method 24 | self.sync_tracker2mapper_frames = map_params.sync_tracker2mapper_frames 25 | 26 | # tracker 2 mapper 27 | self._tracker2mapper_call = mp.Condition() 28 | self._tracker2mapper_frame_queue = mp.Queue() 29 | 30 | # mapper 2 tracker 31 | self._mapper2tracker_call = mp.Condition() 32 | self._mapper2tracker_map_queue = mp.Queue() 33 | 34 | # mapper 2 system 35 | self._mapper2system_call = mp.Condition() 36 | self._mapper2system_requires = [False, False] # tb call, save_model call 37 | self._mapper2system_map_queue = mp.Queue() 38 | self._mapper2system_tb_queue = mp.Queue() 39 | 40 | self.map_process = MappingProcess(args, optimization_params, 41 | self) 42 | self.track_process = TrackingProcess(self, args) 43 | self.save_path = self.map_process.save_path 44 | 45 | 46 | def run(self): 47 | processes = [] 48 | for rank in range(2): 49 | if rank == 0: 50 | print("start mapping process") 51 | p = mp.Process(target=self.mapping, args=(rank, )) 52 | elif rank == 1: 53 | print("start tracking process") 54 | p = mp.Process(target=self.tracking, args=(rank,)) 55 | p.start() 56 | processes.append(p) 57 | while self._end.sum() != 2: 58 | # process save model task 59 | with self._mapper2system_call: 60 | if self._mapper2system_requires.count(True) == 0: 61 | self._mapper2system_call.wait() 62 | 63 | if self._mapper2system_requires[0]: 64 | self._mapper2system_requires[0] = False 65 | 66 | # save model 67 | if self._mapper2system_requires[1]: 68 | while not self._mapper2system_map_queue.empty(): 69 | map_output = self._mapper2system_map_queue.get() 70 | 
self.save_model(map_output) 71 | del map_output 72 | break 73 | self._mapper2system_requires[1] = False 74 | if self._end[1] == 1: 75 | break 76 | print("system finish") 77 | while not self._mapper2system_map_queue.empty(): 78 | print("delete model") 79 | x = self._mapper2system_map_queue.get() 80 | self.save_model(x) 81 | del x 82 | self.release() 83 | self.track_process.stop() 84 | self.map_process.stop() 85 | print("main finish") 86 | for p in processes: 87 | p.join() 88 | 89 | def tracking(self, rank): 90 | print("start traking") 91 | self.track_process.run() 92 | 93 | def mapping(self, rank): 94 | print("start mapping") 95 | self.map_process.run() 96 | 97 | def release_mp_queue(self, mp_queue): 98 | while not mp_queue.empty(): 99 | x = mp_queue.get() 100 | del x 101 | 102 | def release(self): 103 | self.release_mp_queue(self._tracker2mapper_frame_queue) 104 | self.release_mp_queue(self._mapper2system_map_queue) 105 | self.release_mp_queue(self._mapper2system_tb_queue) 106 | self.release_mp_queue(self._mapper2tracker_map_queue) 107 | 108 | def save_model(self, map_output, save_data=True, save_sibr=True, save_merge=False): 109 | print("save model") 110 | self.pointcloud = map_output["pointcloud"] 111 | self.stable_pointcloud = map_output["stable_pointcloud"] 112 | self.map_time = map_output["time"] 113 | self.map_iter = map_output["iter"] 114 | print("save model:", self.map_time) 115 | frame_name = "frame_{:04d}".format(self.map_time) 116 | frame_save_path = os.path.join(self.save_path, "save_model", frame_name) 117 | os.makedirs(frame_save_path, exist_ok=True) 118 | path = os.path.join( 119 | frame_save_path, 120 | "iter_{:04d}".format(self.map_iter), 121 | ) 122 | if save_data: 123 | self.pointcloud.save_model_ply(path + ".ply", include_confidence=True) 124 | self.stable_pointcloud.save_model_ply( 125 | path + "_stable.ply", include_confidence=True 126 | ) 127 | if save_sibr: 128 | self.pointcloud.save_model_ply(path + "_sibr.ply", include_confidence=False) 129 | self.stable_pointcloud.save_model_ply( 130 | path + "_stable_sibr.ply", include_confidence=False 131 | ) 132 | if save_data and save_merge: 133 | merge_ply( 134 | path + ".ply", 135 | path + "_stable.ply", 136 | path + "_merge.ply", 137 | include_confidence=True, 138 | ) 139 | if save_sibr and save_merge: 140 | merge_ply( 141 | path + "_sibr.ply", 142 | path + "_stable_sibr.ply", 143 | path + "_merge_sibr.ply", 144 | include_confidence=False, 145 | ) 146 | print("save finish") -------------------------------------------------------------------------------- /SLAM/multiprocess/tracker.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import matplotlib.pyplot as plt 3 | from SLAM.gaussian_pointcloud import * 4 | 5 | import torch.multiprocessing as mp 6 | from SLAM.render import Renderer 7 | from collections import defaultdict 8 | from tqdm import tqdm 9 | from scipy.spatial.transform import Rotation as R 10 | 11 | from SLAM.icp import IcpTracker 12 | from threading import Thread 13 | from utils.camera_utils import loadCam 14 | 15 | 16 | def convert_poses(trajs): 17 | poses = [] 18 | stamps = [] 19 | for traj in trajs: 20 | stamp, r00, r01, r02, t0, r10, r11, r12, t1, r20, r21, r22, t2 = traj 21 | pose_ = np.eye(4) 22 | pose_[:3, :3] = np.array([[r00, r01, r02], [r10, r11, r12], [r20, r21, r22]]) 23 | pose_[:3, 3] = np.array([t0, t1, t2]) 24 | poses.append(pose_) 25 | stamps.append(stamp) 26 | return poses, stamps 27 | 28 | 29 | class Tracker(object): 30 | def __init__(self, 
args): 31 | self.use_gt_pose = args.use_gt_pose 32 | self.mode = args.mode 33 | self.K = None 34 | 35 | self.min_depth = args.min_depth 36 | self.max_depth = args.max_depth 37 | self.depth_filter = args.depth_filter 38 | self.verbose = args.verbose 39 | 40 | self.icp_tracker = IcpTracker(args) 41 | 42 | self.status = defaultdict(bool) 43 | self.pose_gt = [] 44 | self.pose_es = [] 45 | self.timestampes = [] 46 | self.finish = mp.Condition() 47 | 48 | self.icp_success_count = 0 49 | 50 | self.use_orb_backend = args.use_orb_backend 51 | self.orb_vocab_path = args.orb_vocab_path 52 | self.orb_settings_path = args.orb_settings_path 53 | self.orb_backend = None 54 | self.orb_useicp = args.orb_useicp 55 | 56 | self.invalid_confidence_thresh = args.invalid_confidence_thresh 57 | 58 | if self.mode == "single process": 59 | self.initialize_orb() 60 | 61 | def get_new_poses_byid(self, frame_ids): 62 | if self.use_orb_backend and not self.use_gt_pose: 63 | new_poses = convert_poses(self.orb_backend.get_trajectory_points()) 64 | frame_poses = [new_poses[frame_id] for frame_id in frame_ids] 65 | else: 66 | frame_poses = [self.pose_es[frame_id] for frame_id in frame_ids] 67 | return frame_poses 68 | 69 | def get_new_poses(self): 70 | if self.use_orb_backend and not self.use_gt_pose: 71 | new_poses, _ = convert_poses(self.orb_backend.get_trajectory_points()) 72 | else: 73 | new_poses = None 74 | return new_poses 75 | 76 | def save_invalid_traing(self, path): 77 | if np.linalg.norm(self.pose_es[-1][:3, 3] - self.pose_gt[-1][:3, 3]) > 0.15: 78 | if self.track_mode == "icp": 79 | frame_id = len(self.pose_es) 80 | torch.save( 81 | self.icp_tracker.vertex_pyramid_t1, 82 | os.path.join(path, "vertex_pyramid_t1_{}.pt".format(frame_id)), 83 | ) 84 | torch.save( 85 | self.icp_tracker.vertex_pyramid_t0, 86 | os.path.join(path, "vertex_pyramid_t0_{}.pt".format(frame_id)), 87 | ) 88 | torch.save( 89 | self.icp_tracker.normal_pyramid_t1, 90 | os.path.join(path, "normal_pyramid_t1_{}.pt".format(frame_id)), 91 | ) 92 | torch.save( 93 | self.icp_tracker.normal_pyramid_t0, 94 | os.path.join(path, "normal_pyramid_t0_{}.pt".format(frame_id)), 95 | ) 96 | 97 | def map_preprocess(self, frame, frame_id): 98 | depth_map, color_map = ( 99 | frame.original_depth.permute(1, 2, 0) * 255, 100 | frame.original_image.permute(1, 2, 0), 101 | ) # [H, W, C], the image is scaled by 255 in function "PILtoTorch" 102 | depth_map_orb = ( 103 | frame.original_depth.permute(1, 2, 0).cpu().numpy() 104 | * 255 105 | * frame.depth_scale 106 | ).astype(np.uint16) 107 | intrinsic = frame.get_intrinsic 108 | # depth filter 109 | if self.depth_filter: 110 | depth_map_filter = bilateralFilter_torch(depth_map, 5, 2, 2) 111 | else: 112 | depth_map_filter = depth_map 113 | 114 | valid_range_mask = (depth_map_filter > self.min_depth) & (depth_map_filter < self.max_depth) 115 | depth_map_filter[~valid_range_mask] = 0.0 116 | # update depth map 117 | frame.original_depth = depth_map_filter.permute(2, 0, 1) / 255.0 118 | # compute geometry info 119 | vertex_map_c = compute_vertex_map(depth_map_filter, intrinsic) 120 | normal_map_c = compute_normal_map(vertex_map_c) 121 | confidence_map = compute_confidence_map(normal_map_c, intrinsic) 122 | 123 | # confidence_threshold tum: 0.5, others: 0.2 124 | invalid_confidence_mask = ((normal_map_c == 0).all(dim=-1)) | ( 125 | confidence_map < self.invalid_confidence_thresh 126 | )[..., 0] 127 | 128 | depth_map_filter[invalid_confidence_mask] = 0 129 | normal_map_c[invalid_confidence_mask] = 0 130 | 
vertex_map_c[invalid_confidence_mask] = 0 131 | confidence_map[invalid_confidence_mask] = 0 132 | 133 | color_map_orb = ( 134 | (frame.original_image * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8) 135 | ) 136 | 137 | self.update_curr_status( 138 | frame, 139 | frame_id, 140 | depth_map, 141 | depth_map_filter, 142 | vertex_map_c, 143 | normal_map_c, 144 | color_map, 145 | color_map_orb, 146 | depth_map_orb, 147 | intrinsic, 148 | ) 149 | 150 | frame_map = {} 151 | frame_map["depth_map"] = depth_map_filter 152 | frame_map["color_map"] = color_map 153 | frame_map["normal_map_c"] = normal_map_c 154 | frame_map["vertex_map_c"] = vertex_map_c 155 | frame_map["confidence_map"] = confidence_map 156 | frame_map["invalid_confidence_mask"] = invalid_confidence_mask 157 | frame_map["time"] = frame_id 158 | 159 | return frame_map 160 | 161 | def update_curr_status( 162 | self, 163 | frame, 164 | frame_id, 165 | depth_t1, 166 | depth_t1_filter, 167 | vertex_t1, 168 | normal_t1, 169 | color_t1, 170 | color_orb, 171 | depth_orb, 172 | K, 173 | ): 174 | if self.K is None: 175 | self.K = K 176 | self.curr_frame = { 177 | "K": frame.get_intrinsic, 178 | "normal_map": normal_t1, 179 | "depth_map": depth_t1, 180 | "depth_map_filter": depth_t1_filter, 181 | "vertex_map": vertex_t1, 182 | "frame_id": frame_id, 183 | "pose_gt": frame.get_c2w.cpu().numpy(), # 1 184 | "color_map": color_t1, 185 | "timestamp": frame.timestamp, # 1 186 | "color_map_orb": color_orb, # 1 187 | "depth_map_orb": depth_orb, # 1 188 | } 189 | self.icp_tracker.update_curr_status(depth_t1_filter, self.K) 190 | 191 | def update_last_status_v2( 192 | self, frame, render_depth, frame_depth, render_normal, frame_normal 193 | ): 194 | intrinsic = frame.get_intrinsic 195 | normal_mask = ( 196 | 1 - F.cosine_similarity(render_normal, frame_normal, dim=-1) 197 | ) < self.icp_sample_normal_threshold 198 | depth_filling_mask = ( 199 | ( 200 | torch.abs(render_depth - frame_depth) 201 | > self.icp_sample_distance_threshold 202 | )[..., 0] 203 | | (render_depth == 0)[..., 0] 204 | | (normal_mask) 205 | ) & (frame_depth > 0)[..., 0] 206 | 207 | render_depth[depth_filling_mask] = frame_depth[depth_filling_mask] 208 | render_depth[(frame_depth == 0)[..., 0]] = 0 209 | 210 | self.last_model_vertex = compute_vertex_map(render_depth, intrinsic) 211 | self.last_model_normal = compute_normal_map(self.last_model_vertex) 212 | 213 | def update_last_status( 214 | self, 215 | frame, 216 | render_depth, 217 | frame_depth, 218 | render_normal, 219 | frame_normal, 220 | ): 221 | self.icp_tracker.update_last_status( 222 | frame, render_depth, frame_depth, render_normal, frame_normal 223 | ) 224 | 225 | def refine_icp_pose(self, pose_t1_t0, tracking_success): 226 | if tracking_success and self.orb_useicp: 227 | print("success") 228 | self.orb_backend.track_with_icp_pose( 229 | self.curr_frame["color_map_orb"], 230 | self.curr_frame["depth_map_orb"], 231 | pose_t1_t0.astype(np.float32), 232 | self.curr_frame["timestamp"], 233 | ) 234 | time.sleep(0.005) 235 | else: 236 | self.orb_backend.track_with_orb_feature( 237 | self.curr_frame["color_map_orb"], 238 | self.curr_frame["depth_map_orb"], 239 | self.curr_frame["timestamp"], 240 | ) 241 | time.sleep(0.005) 242 | traj_history = self.orb_backend.get_trajectory_points() 243 | pose_es_t1, _ = convert_poses(traj_history[-2:]) 244 | return pose_es_t1[-1] 245 | 246 | def initialize_orb(self): 247 | if not self.use_gt_pose and self.use_orb_backend and self.orb_backend is None: 248 | import orbslam2 249 | print("init orb 
backend") 250 | self.orb_backend = orbslam2.System( 251 | self.orb_vocab_path, self.orb_settings_path, orbslam2.Sensor.RGBD 252 | ) 253 | self.orb_backend.set_use_viewer(False) 254 | self.orb_backend.initialize(self.orb_useicp) 255 | 256 | def initialize_tracker(self): 257 | if self.use_orb_backend: 258 | self.orb_backend.process_image_rgbd( 259 | self.curr_frame["color_map_orb"], 260 | self.curr_frame["depth_map_orb"], 261 | self.curr_frame["timestamp"], 262 | ) 263 | self.status["initialized"] = True 264 | 265 | def tracking(self, frame, frame_map): 266 | self.pose_gt.append(self.curr_frame["pose_gt"]) 267 | self.timestampes.append(self.curr_frame["timestamp"]) 268 | p2loss = 0 269 | tracking_success = True 270 | if self.use_gt_pose: 271 | pose_t1_w = self.pose_gt[-1] 272 | else: 273 | # initialize 274 | if not self.status["initialized"]: 275 | self.initialize_tracker() 276 | pose_t1_w = np.eye(4) 277 | else: 278 | pose_t1_t0, tracking_success = self.icp_tracker.predict_pose(self.curr_frame) 279 | if self.use_orb_backend: 280 | pose_t1_w = self.refine_icp_pose(pose_t1_t0, tracking_success) 281 | else: 282 | pose_t1_w = self.pose_es[-1] @ pose_t1_t0 283 | 284 | self.icp_tracker.move_last_status() 285 | self.pose_es.append(pose_t1_w) 286 | 287 | frame.updatePose(pose_t1_w) 288 | frame_map["vertex_map_w"] = transform_map( 289 | frame_map["vertex_map_c"], frame.get_c2w 290 | ) 291 | frame_map["normal_map_w"] = transform_map( 292 | frame_map["normal_map_c"], get_rot(frame.get_c2w) 293 | ) 294 | 295 | return tracking_success 296 | 297 | def eval_total_ate(self, pose_es, pose_gt): 298 | ates = [] 299 | for i in tqdm(range(1, len(pose_gt) + 1)): 300 | ates.append(self.eval_ate(pose_es, pose_gt, i)) 301 | ates = np.array(ates) 302 | return ates 303 | 304 | def save_ate_fig(self, ates, save_path, save_name): 305 | plt.plot(range(len(ates)), ates) 306 | plt.ylim(0, max(ates) + 0.1) 307 | plt.title("ate:{}".format(ates[-1])) 308 | plt.savefig(os.path.join(save_path, "{}.png".format(save_name))) 309 | 310 | 311 | def save_keyframe_traj(self, save_file): 312 | if self.use_orb_backend: 313 | poses, stamps = convert_poses(self.orb_backend.get_keyframe_points()) 314 | with open(save_file, "w") as f: 315 | for pose_id, pose_es_ in enumerate(poses): 316 | t = pose_es_[:3, 3] 317 | q = R.from_matrix(pose_es_[:3, :3]) 318 | f.write(str(stamps[pose_id]) + " ") 319 | for i in t.tolist(): 320 | f.write(str(i) + " ") 321 | for i in q.as_quat().tolist(): 322 | f.write(str(i) + " ") 323 | f.write("\n") 324 | 325 | def save_traj_tum(self, save_file): 326 | poses, stamps = convert_poses(self.orb_backend.get_trajectory_points()) 327 | with open(save_file, "w") as f: 328 | for pose_id, pose_es_ in enumerate(self.pose_es): 329 | t = pose_es_[:3, 3] 330 | q = R.from_matrix(pose_es_[:3, :3]) 331 | f.write(str(stamps[pose_id]) + " ") 332 | for i in t.tolist(): 333 | f.write(str(i) + " ") 334 | for i in q.as_quat().tolist(): 335 | f.write(str(i) + " ") 336 | f.write("\n") 337 | 338 | def save_orb_traj_tum(self, save_file): 339 | if self.use_orb_backend: 340 | poses, stamps = convert_poses(self.orb_backend.get_trajectory_points()) 341 | with open(save_file, "w") as f: 342 | for pose_id, pose_es_ in enumerate(poses): 343 | t = pose_es_[:3, 3] 344 | q = R.from_matrix(pose_es_[:3, :3]) 345 | f.write(str(stamps[pose_id]) + " ") 346 | for i in t.tolist(): 347 | f.write(str(i) + " ") 348 | for i in q.as_quat().tolist(): 349 | f.write(str(i) + " ") 350 | f.write("\n") 351 | 352 | def save_traj(self, save_path): 353 | 
save_traj_path = os.path.join(save_path, "save_traj") 354 | if not self.use_gt_pose and self.use_orb_backend: 355 | traj_history = self.orb_backend.get_trajectory_points() 356 | self.pose_es, _ = convert_poses(traj_history) 357 | pose_es = np.stack(self.pose_es, axis=0) 358 | pose_gt = np.stack(self.pose_gt, axis=0) 359 | ates_ba = self.eval_total_ate(pose_es, pose_gt) 360 | print("ate: ", ates_ba[-1]) 361 | np.save(os.path.join(save_traj_path, "pose_gt.npy"), pose_gt) 362 | np.save(os.path.join(save_traj_path, "pose_es.npy"), pose_es) 363 | self.save_ate_fig(ates_ba, save_traj_path, "ate") 364 | 365 | plt.figure() 366 | plt.plot(pose_es[:, 0, 3], pose_es[:, 1, 3]) 367 | plt.plot(pose_gt[:, 0, 3], pose_gt[:, 1, 3]) 368 | plt.legend(["es", "gt"]) 369 | plt.savefig(os.path.join(save_traj_path, "traj_xy.jpg")) 370 | 371 | if self.use_orb_backend: 372 | self.orb_backend.shutdown() 373 | 374 | def eval_ate(self, pose_es, pose_gt, frame_id=-1): 375 | pose_es = np.stack(pose_es, axis=0)[:frame_id, :3, 3] 376 | pose_gt = np.stack(pose_gt, axis=0)[:frame_id, :3, 3] 377 | ate = eval_ate(pose_gt, pose_es) 378 | return ate 379 | 380 | 381 | class TrackingProcess(Tracker): 382 | def __init__(self, slam, args): 383 | args.icp_use_model_depth = False 384 | super().__init__(args) 385 | 386 | self.args = args 387 | # online scanner 388 | self.use_online_scanner = args.use_online_scanner 389 | self.scanner_finish = False 390 | 391 | # sync mode 392 | self.sync_tracker2mapper_method = slam.sync_tracker2mapper_method 393 | self.sync_tracker2mapper_frames = slam.sync_tracker2mapper_frames 394 | 395 | # tracker2mapper 396 | self._tracker2mapper_call = slam._tracker2mapper_call 397 | self._tracker2mapper_frame_queue = slam._tracker2mapper_frame_queue 398 | 399 | self.mapper_running = True 400 | 401 | # mapper2tracker 402 | self._mapper2tracker_call = slam._mapper2tracker_call 403 | self._mapper2tracker_map_queue = slam._mapper2tracker_map_queue 404 | 405 | self.dataset_cameras = slam.dataset.scene_info.train_cameras 406 | self.map_process = slam.map_process 407 | self._end = slam._end 408 | self.max_fps = args.tracker_max_fps 409 | self.frame_time = 1.0 / self.max_fps 410 | self.frame_id = 0 411 | self.last_mapper_frame_id = 0 412 | 413 | self.last_frame = None 414 | self.last_global_params = None 415 | 416 | self.track_renderer = Renderer(args) 417 | self.save_path = args.save_path 418 | 419 | def map_preprocess_mp(self, frame, frame_id): 420 | self.map_input = super().map_preprocess(frame, frame_id) 421 | 422 | def send_frame_to_mapper(self): 423 | print("tracker send frame {} to mapper".format(self.map_input["time"])) 424 | self._tracker2mapper_call.acquire() 425 | self._tracker2mapper_frame_queue.put(self.map_input) 426 | self.map_process._requests[0] = True 427 | self._tracker2mapper_call.notify() 428 | self._tracker2mapper_call.release() 429 | 430 | def finish_(self): 431 | if self.use_online_scanner: 432 | return self.scanner_finish 433 | else: 434 | return self.frame_id >= len(self.dataset_cameras) 435 | 436 | def getNextFrame(self): 437 | frame_info = self.dataset_cameras[self.frame_id] 438 | frame = loadCam(self.args, self.frame_id, frame_info, 1) 439 | print("get frame: {}".format(self.frame_id)) 440 | self.frame_id += 1 441 | return frame 442 | 443 | 444 | def run(self): 445 | self.time = 0 446 | self.initialize_orb() 447 | 448 | while not self.finish_(): 449 | frame = self.getNextFrame() 450 | if frame is None: 451 | break 452 | frame_id = frame.uid 453 | print("current tracker frame = %d" % 
self.time) 454 | # update current map 455 | move_to_gpu(frame) 456 | 457 | self.map_preprocess_mp(frame, frame_id) 458 | self.tracking(frame, self.map_input) 459 | self.map_input["frame"] = copy.deepcopy(frame) 460 | self.map_input["frame"] = frame 461 | 462 | self.map_input["poses_new"] = self.get_new_poses() 463 | # send message to mapper 464 | 465 | self.send_frame_to_mapper() 466 | 467 | wait_begin = time.time() 468 | if not self.finish_() and self.mapper_running: 469 | if self.sync_tracker2mapper_method == "strict": 470 | if (frame_id + 1) % self.sync_tracker2mapper_frames == 0: 471 | with self._mapper2tracker_call: 472 | print("wait mapper to wakeup") 473 | print( 474 | "tracker buffer size: {}".format( 475 | self._tracker2mapper_frame_queue.qsize() 476 | ) 477 | ) 478 | self._mapper2tracker_call.wait() 479 | elif self.sync_tracker2mapper_method == "loose": 480 | if ( 481 | frame_id - self.last_mapper_frame_id 482 | ) > self.sync_tracker2mapper_frames: 483 | with self._mapper2tracker_call: 484 | print("wait mapper to wakeup") 485 | self._mapper2tracker_call.wait() 486 | else: 487 | pass 488 | wait_end = time.time() 489 | 490 | self.unpack_map_to_tracker() 491 | self.update_last_mapper_render(frame) 492 | self.update_viewer(frame) 493 | 494 | move_to_cpu(frame) 495 | 496 | self.time += 1 497 | # send a invalid time stamp as end signal 498 | self.map_input["time"] = -1 499 | self.send_frame_to_mapper() 500 | self.save_traj(self.save_path) 501 | self._end[0] = 1 502 | with self.finish: 503 | print("tracker wating finish") 504 | self.finish.wait() 505 | print("track finish") 506 | 507 | def stop(self): 508 | with self.finish: 509 | self.finish.notify() 510 | 511 | def unpack_map_to_tracker(self): 512 | self._mapper2tracker_call.acquire() 513 | while not self._mapper2tracker_map_queue.empty(): 514 | map_info = self._mapper2tracker_map_queue.get() 515 | self.last_frame = map_info["frame"] 516 | self.last_global_params = map_info["global_params"] 517 | self.last_mapper_frame_id = map_info["frame_id"] 518 | print("tracker unpack map {}".format(self.last_mapper_frame_id)) 519 | self._mapper2tracker_call.notify() 520 | self._mapper2tracker_call.release() 521 | 522 | def update_last_mapper_render(self, frame): 523 | pose_t0_w = frame.get_c2w.cpu().numpy() 524 | if self.last_frame is not None: 525 | pose_w_t0 = np.linalg.inv(pose_t0_w) 526 | self.last_frame.update(pose_w_t0[:3, :3].transpose(), pose_w_t0[:3, 3]) 527 | render_output = self.track_renderer.render( 528 | self.last_frame, 529 | self.last_global_params, 530 | None 531 | ) 532 | self.update_last_status( 533 | self.last_frame, 534 | render_output["depth"].permute(1, 2, 0), 535 | self.map_input["depth_map"], 536 | render_output["normal"].permute(1, 2, 0), 537 | self.map_input["normal_map_w"], 538 | ) 539 | -------------------------------------------------------------------------------- /SLAM/render.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch 4 | 5 | from scene.cameras import Camera 6 | from SLAM.utils import devF, devI 7 | 8 | from diff_gaussian_rasterization_depth import ( 9 | GaussianRasterizationSettings as GaussianRasterizationSettings_depth, 10 | ) 11 | from diff_gaussian_rasterization_depth import ( 12 | GaussianRasterizer as GaussianRasterizer_depth, 13 | ) 14 | 15 | from utils.general_utils import ( 16 | build_covariance_from_scaling_rotation, 17 | inverse_sigmoid 18 | ) 19 | 20 | 21 | class Renderer: 22 | def setup_functions(self): 
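# Activations follow the usual 3D Gaussian splatting parameterization: scales are
# stored in log space (exp / log pair), opacities in logit space (sigmoid /
# inverse_sigmoid), and rotations as quaternions that are re-normalized on use.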
23 | self.scaling_activation = torch.exp 24 | self.scaling_inverse_activation = torch.log 25 | 26 | self.covariance_activation = build_covariance_from_scaling_rotation 27 | 28 | self.opacity_activation = torch.sigmoid 29 | self.inverse_opacity_activation = inverse_sigmoid 30 | 31 | self.rotation_activation = torch.nn.functional.normalize 32 | 33 | def __init__(self, args): 34 | self.raster_settings = None 35 | self.rasterizer = None 36 | self.bg_color = devF(torch.tensor([0, 0, 0])) 37 | self.renderer_opaque_threshold = args.renderer_opaque_threshold 38 | self.renderer_normal_threshold = np.cos( 39 | np.deg2rad(args.renderer_normal_threshold) 40 | ) 41 | self.scaling_modifier = 1.0 42 | self.renderer_depth_threshold = args.renderer_depth_threshold 43 | self.max_sh_degree = args.max_sh_degree 44 | self.color_sigma = args.color_sigma 45 | if args.active_sh_degree < 0: 46 | self.active_sh_degree = self.max_sh_degree 47 | else: 48 | self.active_sh_degree = args.active_sh_degree 49 | self.setup_functions() 50 | 51 | def get_scaling(self, scaling): 52 | return self.scaling_activation(scaling) 53 | 54 | def get_rotation(self, rotaion): 55 | return self.rotation_activation(rotaion) 56 | 57 | def get_covariance(self, scaling, rotaion, scaling_modifier=1): 58 | return self.covariance_activation(scaling, scaling_modifier, rotaion) 59 | 60 | def render( 61 | self, 62 | viewpoint_camera: Camera, 63 | gaussian_data, 64 | tile_mask=None, 65 | ): 66 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 67 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 68 | self.raster_settings = GaussianRasterizationSettings_depth( 69 | image_height=int(viewpoint_camera.image_height), 70 | image_width=int(viewpoint_camera.image_width), 71 | tanfovx=tanfovx, 72 | tanfovy=tanfovy, 73 | bg=self.bg_color, 74 | scale_modifier=self.scaling_modifier, 75 | viewmatrix=viewpoint_camera.world_view_transform, 76 | projmatrix=viewpoint_camera.full_proj_transform, 77 | sh_degree=self.active_sh_degree, 78 | campos=viewpoint_camera.camera_center, 79 | opaque_threshold=self.renderer_opaque_threshold, 80 | depth_threshold=self.renderer_depth_threshold, 81 | normal_threshold=self.renderer_normal_threshold, 82 | color_sigma=self.color_sigma, 83 | prefiltered=False, 84 | debug=False, 85 | cx=viewpoint_camera.cx, 86 | cy=viewpoint_camera.cy, 87 | T_threshold=0.0001, 88 | ) 89 | self.rasterizer = GaussianRasterizer_depth( 90 | raster_settings=self.raster_settings 91 | ) 92 | 93 | means3D = gaussian_data["xyz"] 94 | opacity = gaussian_data["opacity"] 95 | scales = gaussian_data["scales"] 96 | rotations = gaussian_data["rotations"] 97 | shs = gaussian_data["shs"] 98 | normal = gaussian_data["normal"] 99 | cov3D_precomp = None 100 | colors_precomp = None 101 | if tile_mask is None: 102 | tile_mask = devI( 103 | torch.ones( 104 | (viewpoint_camera.image_height + 15) // 16, 105 | (viewpoint_camera.image_width + 15) // 16, 106 | dtype=torch.int32, 107 | ) 108 | ) 109 | 110 | render_results = self.rasterizer( 111 | means3D=means3D, 112 | opacities=opacity, 113 | shs=shs, 114 | colors_precomp=colors_precomp, 115 | scales=scales, 116 | rotations=rotations, 117 | cov3D_precomp=cov3D_precomp, 118 | normal_w=normal, 119 | tile_mask=tile_mask, 120 | ) 121 | 122 | rendered_image = render_results[0] 123 | rendered_depth = render_results[1] 124 | color_index_map = render_results[2] 125 | depth_index_map = render_results[3] 126 | color_hit_weight = render_results[4] 127 | depth_hit_weight = render_results[5] 128 | T_map = render_results[6] 129 | 130 | 
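# Besides color and depth, the depth-aware rasterizer returns per-pixel index maps
# (which Gaussian was used for color / depth at each pixel), hit-weight buffers,
# and the accumulated transmittance T_map. The world-space normal image is built
# next by looking up each pixel's entry in depth_index_map and copying that
# Gaussian's input normal.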
render_normal = devF(torch.zeros_like(rendered_image)) 131 | render_normal[:, depth_index_map[0] > -1] = normal[ 132 | depth_index_map[depth_index_map > -1].long() 133 | ].permute(1, 0) 134 | 135 | results = { 136 | "render": rendered_image, 137 | "depth": rendered_depth, 138 | "normal": render_normal, 139 | "color_index_map": color_index_map, 140 | "depth_index_map": depth_index_map, 141 | "color_hit_weight": color_hit_weight, 142 | "depth_hit_weight": depth_hit_weight, 143 | "T_map": T_map, 144 | } 145 | return results 146 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from argparse import ArgumentParser, Namespace 15 | 16 | import numpy as np 17 | 18 | 19 | class GroupParams: 20 | pass 21 | 22 | 23 | class ParamGroup: 24 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False): 25 | group = parser.add_argument_group(name) 26 | for key, value in vars(self).items(): 27 | shorthand = False 28 | if key.startswith("_"): 29 | shorthand = True 30 | key = key[1:] 31 | t = type(value) 32 | value = value if not fill_none else None 33 | if shorthand: 34 | if t == bool: 35 | group.add_argument( 36 | "--" + key, ("-" + key[0:1]), default=value, action="store_true" 37 | ) 38 | else: 39 | group.add_argument( 40 | "--" + key, ("-" + key[0:1]), default=value, type=t 41 | ) 42 | else: 43 | if t == bool: 44 | group.add_argument("--" + key, default=value, action="store_true") 45 | else: 46 | group.add_argument("--" + key, default=value, type=t) 47 | 48 | def extract(self, args): 49 | group = GroupParams() 50 | for arg in vars(args).items(): 51 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 52 | if "densify_until_iter" == arg[0]: 53 | print(arg[0], arg[1]) 54 | setattr(group, arg[0], arg[1]) 55 | return group 56 | 57 | def extract_dict(self, config): 58 | group = GroupParams() 59 | for k, v in config.items(): 60 | if k in vars(self) or ("_" + k) in vars(self): 61 | setattr(group, k, v) 62 | return group 63 | 64 | 65 | class ModelParams(ParamGroup): 66 | def __init__(self, parser, sentinel=False): 67 | self.sh_degree = 3 68 | self._source_path = "" 69 | self._model_path = "" 70 | self._images = "images" 71 | self._white_background = False 72 | self.data_device = "cuda" 73 | self.eval = False 74 | self.init_mode = "random" 75 | self.frame_num = -1 76 | self.eval_llff = 8 77 | super().__init__(parser, "Loading Parameters", sentinel) 78 | 79 | def extract(self, args): 80 | g = super().extract(args) 81 | g.source_path = os.path.abspath(g.source_path) 82 | return g 83 | 84 | 85 | class PipelineParams(ParamGroup): 86 | def __init__(self, parser): 87 | self.convert_SHs_python = False 88 | self.compute_cov3D_python = False 89 | self.debug = False 90 | self.debug_gaussian_id = -1 91 | self.use_network_gui = False 92 | self.render_mode = "raw" 93 | self.fix_position = False 94 | self.fix_opacity = False 95 | self.init_opacity = 0.5 96 | self.fix_sh = False 97 | self.fix_cov = False 98 | self.fix_density = False 99 | self.opaque_threshold = 0.9 100 | 101 | super().__init__(parser, "Pipeline 
Parameters") 102 | 103 | 104 | class OptimizationParams(ParamGroup): 105 | def __init__(self, parser): 106 | self.train_iterations = 30_000 107 | self.position_lr = 0.0016 108 | self.feature_lr = 0.0025 109 | self.opacity_lr = 0.05 110 | self.scaling_lr = 0.005 111 | self.rotation_lr = 0.001 112 | 113 | self.color_weight = 0.8 114 | self.depth_weight = 1 115 | self.ssim_weight = 0.2 116 | self.history_weight = 0.1 117 | self.normal_weight = 0.1 118 | super().__init__(parser, "Optimization Parameters") 119 | 120 | 121 | class DatasetParams(ParamGroup): 122 | def __init__(self, parser, sentinel=False): 123 | self._source_path = "" 124 | self._model_path = "" 125 | self._images = "images" 126 | self._resolution = -1 127 | self._white_background = False 128 | self.type = "ours" 129 | self.data_device = "cuda" 130 | self.eval = False 131 | self.init_mode = "random" 132 | self.frame_num = -1 133 | self.frame_start = 0 134 | self.frame_step = 0 135 | self.eval_llff = 8 136 | self.sh_degree = 3 137 | self.preload = False 138 | self.resolution_scales = [1.0] 139 | super().__init__(parser, "Dataset Parameters", sentinel) 140 | 141 | def extract(self, args): 142 | g = super().extract(args) 143 | g.source_path = os.path.abspath(g.source_path) 144 | return g 145 | 146 | 147 | class MapParams(ParamGroup): 148 | def __init__(self, parser, sentinel=False): 149 | self.init_opacity = 0.999 150 | self.max_sh_degree = 4 151 | self.active_sh_degree = -1 152 | self.uniform_sample_num = 5000 153 | self.gaussian_update_iter = 300 154 | self.gaussian_update_frame = 1 155 | self.KNN_num = 15 156 | self.KNN_threshold = 0.005 157 | 158 | self.spatial_lr_scale = 1 159 | self.save_path = "output/slam_test" 160 | self.min_depth = 0 161 | self.max_depth = 0 162 | self.renderer_opaque_threshold = 0.7 163 | self.renderer_normal_threshold = 80 164 | self.renderer_depth_threshold = 1.0 165 | self.render_mode = "ours" 166 | 167 | self.memory_length = 10 168 | self.xyz_factor = [1, 1, 1] 169 | self.use_tensorboard = True 170 | self.add_depth_thres = 0.05 171 | self.add_normal_thres = 0.1 172 | self.add_color_thres = 0.1 173 | self.add_transmission_thres = 0.1 174 | self.transmission_sample_ratio = 0.5 175 | self.error_sample_ratio = 0.3 176 | self.save_step = 1 177 | self.stable_confidence_thres = 200 178 | self.unstable_time_window = 50 179 | self.min_radius = 0.01 180 | self.max_radius = 0.10 181 | self.scale_factor = 0.5 182 | self.color_sigma = 1.0 183 | self.depth_filter = False 184 | self.verbose = False 185 | 186 | 187 | self.keyframe_trans_thes = 0.3 188 | self.keyframe_theta_thes = 20 189 | self.global_keyframe_num = 3 190 | self.sync_tracker2mapper_method = "strict" 191 | self.sync_tracker2mapper_frames = 5 192 | super().__init__(parser, "Map Parameters", sentinel) 193 | 194 | 195 | def get_combined_args(parser: ArgumentParser): 196 | cmdlne_string = sys.argv[1:] 197 | cfgfile_string = "Namespace()" 198 | args_cmdline = parser.parse_args(cmdlne_string) 199 | 200 | try: 201 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 202 | print("Looking for config file in", cfgfilepath) 203 | with open(cfgfilepath) as cfg_file: 204 | print("Config file found: {}".format(cfgfilepath)) 205 | cfgfile_string = cfg_file.read() 206 | except TypeError: 207 | print("Config file not found at") 208 | pass 209 | args_cfgfile = eval(cfgfile_string) 210 | 211 | merged_dict = vars(args_cfgfile).copy() 212 | for k, v in vars(args_cmdline).items(): 213 | if v != None: 214 | merged_dict[k] = v 215 | return 
Namespace(**merged_dict) 216 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MisEty/RTG-SLAM/15ac7e3de5bdffd06e651d5a65435a5b1ad82173/assets/teaser.png -------------------------------------------------------------------------------- /build_orb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p thirdParty && cd thirdParty 4 | install_path=$(pwd)/install 5 | mkdir -p ${install_path} 6 | 7 | python_prefix=$(python -c "import sys; print(sys.prefix)") 8 | python_include=${python_prefix}/include/python3.9/ 9 | python_lib=${python_prefix}/lib/libpython3.9.so 10 | python_exe=${python_prefix}/bin/python 11 | python_env=${python_prefix}/lib/python3.9/site-packages/ 12 | numpy_include=$(python -c "import numpy; print(numpy.get_include())") 13 | 14 | echo ${python_env} 15 | 16 | # # build pangolin 17 | git clone -b v0.5 https://github.com/stevenlovegrove/Pangolin.git 18 | cd Pangolin 19 | mkdir -p build && cd build 20 | cmake .. -DCMAKE_INSTALL_PREFIX=${install_path} 21 | make install -j 22 | 23 | # build opencv-4.2.0 24 | cd ../../ 25 | wget https://github.com/opencv/opencv/archive/4.2.0.zip 26 | unzip 4.2.0.zip 27 | cd opencv-4.2.0 28 | mkdir -p build && cd build 29 | cmake .. -DCMAKE_INSTALL_PREFIX=${install_path} 30 | make install -j 31 | 32 | opencv_dir=${install_path}/lib/cmake/opencv4 33 | 34 | # build orbslam2 35 | cd ../../ 36 | cd ORB-SLAM2-PYBIND 37 | bash build.sh ${opencv_dir} ${install_path} 38 | cd ../ 39 | 40 | 41 | # # build pybind 42 | # # build boost 43 | wget -t 999 -c https://boostorg.jfrog.io/artifactory/main/release/1.80.0/source/boost_1_80_0.zip 44 | unzip boost_1_80_0.zip 45 | cd boost_1_80_0 46 | ./bootstrap.sh --with-libraries=python --prefix=${install_path} --with-python=${python_exe} 47 | 48 | # # ./b2 49 | ./b2 install --with-python include=${python_include} --prefix=${install_path} 50 | 51 | 52 | # # build orbslam_pybind 53 | cd ../pybind 54 | mkdir -p build && cd build 55 | 56 | cmake .. 
-DPYTHON_INCLUDE_DIRS=${python_include} \ 57 | -DPYTHON_LIBRARIES=${python_lib} \ 58 | -DPYTHON_EXECUTABLE=${python_exe} \ 59 | -DBoost_INCLUDE_DIRS=${install_path}/include/boost \ 60 | -DBoost_LIBRARIES=${install_path}/lib/libboost_python39.so \ 61 | -DORB_SLAM2_INCLUDE_DIR=${install_path}/include/ORB_SLAM2 \ 62 | -DORB_SLAM2_LIBRARIES=${install_path}/lib/libORB_SLAM2.so \ 63 | -DOpenCV_DIR=${install_path}/lib/cmake/opencv4 \ 64 | -DPangolin_DIR=${install_path}/lib/cmake/Pangolin \ 65 | -DPYTHON_NUMPY_INCLUDE_DIR=${numpy_include} \ 66 | -DCMAKE_INSTALL_PREFIX=${python_env} 67 | 68 | make install -j 69 | 70 | -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | parent: None 2 | quiet: False 3 | device_list: [0] 4 | save_path: "output/replica_test/debug" 5 | use_tensorboard: True 6 | record_mem: False 7 | verbose: False 8 | mode: "single process" 9 | use_network_viewer: False 10 | use_online_scanner: False 11 | sync_tracker2mapper_method: "strict" 12 | sync_tracker2mapper_frames: 5 13 | 14 | # dataset params: 15 | type: "Replica" 16 | source_path: "" 17 | frame_start: 0 18 | frame_step: 0 19 | frame_num: -1 20 | save_step: 2000 21 | preload: False 22 | resolution : 1 23 | resolution_scales: [1.0] 24 | data_device: "cuda" # only cuda work 25 | eval: False # whether select frames for eval 26 | eval_llff: 2 # the step for select frame tor eavl 27 | init_mode: "depth" 28 | 29 | # gaussian params 30 | active_sh_degree: 3 31 | max_sh_degree: 3 32 | xyz_factor: [1, 1, 0.1] # z should be smallest 33 | init_opacity: 0.99 34 | scale_factor: 1.0 35 | max_radius: 0.05 36 | min_radius: 0.001 37 | 38 | # map preprocess 39 | min_depth: 0.3 40 | max_depth: 5.0 41 | depth_filter: False 42 | invalid_confidence_thresh: 0.2 43 | global_keyframe_num: 3 44 | 45 | # map params 46 | memory_length: 1 47 | uniform_sample_num: 50000 48 | add_transmission_thres: 0.5 49 | transmission_sample_ratio: 1.0 50 | error_sample_ratio: 0.05 51 | add_depth_thres: 0.1 52 | add_color_thres: 0.1 53 | add_normal_thres: 1000 54 | history_merge_max_weight: 0.5 55 | 56 | # state manage 57 | keyframe_trans_thes: 0.3 58 | keyframe_theta_thes: 30 59 | stable_confidence_thres: 500 60 | unstable_time_window: 200 61 | KNN_num: 15 62 | KNN_threshold: -1 63 | 64 | # render params 65 | renderer_opaque_threshold: 0.6 66 | renderer_normal_threshold: 60 67 | renderer_depth_threshold: 1.0 68 | color_sigma: 3.0 69 | render_mode: "ours" # "torch", "ours" 70 | depth_mode: "normal" # 'alpha', "opaque", "normal" 71 | global_opt_top_ratio: 0.4 72 | 73 | 74 | # optimize params: 75 | gaussian_update_iter: 100 76 | gaussian_update_frame: 5 77 | final_global_iter: 10 78 | color_weight: 0.8 79 | depth_weight: 1.0 80 | ssim_weight: 0.2 81 | normal_weight: 0.0 82 | position_lr : 0.001 83 | feature_lr : 0.0005 84 | opacity_lr : 0.000 85 | scaling_lr : 0.004 86 | rotation_lr : 0.001 87 | feature_lr_coef: 1.0 88 | scaling_lr_coef: 1.0 89 | rotation_lr_coef: 1.0 90 | 91 | # ICP 92 | use_gt_pose: True 93 | icp_use_model_depth: False # if False, use dataset depth frame to frame 94 | icp_downscales: [0.25, 0.5, 1.0] 95 | icp_damping: 0.0001 96 | icp_downscale_iters: [5, 5, 5] 97 | icp_distance_threshold: 0.1 # m 98 | icp_normal_threshold: 20 # degree 99 | icp_sample_distance_threshold: 0.01 # m 100 | icp_sample_normal_threshold: 0.01 # cos similarity 101 | icp_warmup_frames: 0 102 | icp_fail_threshold: 0.02 103 | 104 | # orb backend 105 
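# use_orb_backend enables the modified ORB-SLAM2 backend built by build_orb.sh;
# orb_vocab_path points at the ORB vocabulary, and orb_settings_path is filled in
# by the per-dataset configs with a camera YAML from configs/orb_config/.
# With orb_useicp, a successful ICP estimate seeds ORB-SLAM2 tracking
# (track_with_icp_pose); otherwise tracking falls back to ORB features.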
| use_orb_backend: False 106 | use_orb_viewer: False 107 | orb_vocab_path: "thirdParty/ORB-SLAM2-PYBIND/Vocabulary/ORBvoc.txt" 108 | orb_settings_path: "" 109 | tracker_max_fps: 15 110 | orb_useicp: True 111 | 112 | sync_tracker2mapper_method : "strict" 113 | # strict: mapping : tracker == 1 : sync_tracker2mapper_frames 114 | # loose: tracker frame_id should be: [mapper_frame_id - sync_tracker2mapper_frames, 115 | # mapper_frame_id + sync_tracker2mapper_frames] 116 | # free: there is no sync 117 | sync_tracker2mapper_frames : 5 118 | system_verbose: False 119 | tracker_max_fps: 30 120 | 121 | # evaluate 122 | renderer_opaque_threshold_eval: 0.5 123 | pcd_densify: False -------------------------------------------------------------------------------- /configs/orb_config/ours.yaml: -------------------------------------------------------------------------------- 1 | %YAML:1.0 2 | 3 | #-------------------------------------------------------------------------------------------- 4 | # Camera Parameters. Adjust them! 5 | #-------------------------------------------------------------------------------------------- 6 | 7 | # Camera calibration and distortion parameters (OpenCV) 8 | Camera.fx: 605.371 9 | Camera.fy: 605.245 10 | Camera.cx: 635.3085 11 | Camera.cy: 366.510 12 | 13 | 14 | Camera.k1: -0.052405 15 | Camera.k2: -1.758217 16 | Camera.p1: 0.000439 17 | Camera.p2: -0.000216 18 | Camera.k3: 1.028101 19 | Camera.k4: -0.171968 20 | Camera.k5: -1.575789 21 | Camera.k6: 0.954648 22 | 23 | # k1: -0.052405, k2: -1.758217, k3: 1.028101, 24 | # p1: 0.000439, p2: -0.000216, k4: -0.171968, k5: -1.575789, k6: 0.954648 25 | Camera.width: 1280 26 | Camera.height: 720 27 | 28 | # Camera frames per second 29 | Camera.fps: 15.0 30 | 31 | # IR projector baseline times fx (aprox.) 32 | Camera.bf: 40.0 33 | 34 | # Color order of the images (0: BGR, 1: RGB. It is ignored if images are grayscale) 35 | Camera.RGB: 1 36 | 37 | # Close/Far threshold. Baseline times. 38 | ThDepth: 40.0 39 | 40 | # Deptmap values factor 41 | DepthMapFactor: 1000.0 42 | 43 | #-------------------------------------------------------------------------------------------- 44 | # ORB Parameters 45 | #-------------------------------------------------------------------------------------------- 46 | 47 | # ORB Extractor: Number of features per image 48 | ORBextractor.nFeatures: 4000 49 | 50 | # ORB Extractor: Scale factor between levels in the scale pyramid 51 | ORBextractor.scaleFactor: 1.2 52 | 53 | # ORB Extractor: Number of levels in the scale pyramid 54 | ORBextractor.nLevels: 12 55 | 56 | # ORB Extractor: Fast threshold 57 | # Image is divided in a grid. At each cell FAST are extracted imposing a minimum response. 58 | # Firstly we impose iniThFAST. 
If no corners are detected we impose a lower value minThFAST 59 | # You can lower these values if your images have low contrast 60 | ORBextractor.iniThFAST: 20 61 | ORBextractor.minThFAST: 7 62 | 63 | #-------------------------------------------------------------------------------------------- 64 | # Viewer Parameters 65 | #-------------------------------------------------------------------------------------------- 66 | Viewer.KeyFrameSize: 0.05 67 | Viewer.KeyFrameLineWidth: 1 68 | Viewer.GraphLineWidth: 0.9 69 | Viewer.PointSize:2 70 | Viewer.CameraSize: 0.08 71 | Viewer.CameraLineWidth: 3 72 | Viewer.ViewpointX: 0 73 | Viewer.ViewpointY: -0.7 74 | Viewer.ViewpointZ: -1.8 75 | Viewer.ViewpointF: 500 76 | 77 | -------------------------------------------------------------------------------- /configs/orb_config/tum1.yaml: -------------------------------------------------------------------------------- 1 | %YAML:1.0 2 | 3 | #-------------------------------------------------------------------------------------------- 4 | # Camera Parameters. Adjust them! 5 | #-------------------------------------------------------------------------------------------- 6 | 7 | # Camera calibration and distortion parameters (OpenCV) 8 | Camera.fx: 517.306408 9 | Camera.fy: 516.469215 10 | Camera.cx: 318.643040 11 | Camera.cy: 255.313989 12 | 13 | Camera.k1: 0.262383 14 | Camera.k2: -0.953104 15 | Camera.p1: -0.005358 16 | Camera.p2: 0.002628 17 | Camera.k3: 1.163314 18 | 19 | Camera.width: 640 20 | Camera.height: 480 21 | 22 | # Camera frames per second 23 | Camera.fps: 30.0 24 | 25 | # IR projector baseline times fx (aprox.) 26 | Camera.bf: 40.0 27 | 28 | # Color order of the images (0: BGR, 1: RGB. It is ignored if images are grayscale) 29 | Camera.RGB: 1 30 | 31 | # Close/Far threshold. Baseline times. 32 | ThDepth: 40.0 33 | 34 | # Deptmap values factor 35 | DepthMapFactor: 5000.0 36 | 37 | #-------------------------------------------------------------------------------------------- 38 | # ORB Parameters 39 | #-------------------------------------------------------------------------------------------- 40 | 41 | # ORB Extractor: Number of features per image 42 | ORBextractor.nFeatures: 1000 43 | 44 | # ORB Extractor: Scale factor between levels in the scale pyramid 45 | ORBextractor.scaleFactor: 1.2 46 | 47 | # ORB Extractor: Number of levels in the scale pyramid 48 | ORBextractor.nLevels: 8 49 | 50 | # ORB Extractor: Fast threshold 51 | # Image is divided in a grid. At each cell FAST are extracted imposing a minimum response. 52 | # Firstly we impose iniThFAST. 
If no corners are detected we impose a lower value minThFAST 53 | # You can lower these values if your images have low contrast 54 | ORBextractor.iniThFAST: 20 55 | ORBextractor.minThFAST: 7 56 | 57 | #-------------------------------------------------------------------------------------------- 58 | # Viewer Parameters 59 | #-------------------------------------------------------------------------------------------- 60 | Viewer.KeyFrameSize: 0.05 61 | Viewer.KeyFrameLineWidth: 1 62 | Viewer.GraphLineWidth: 0.9 63 | Viewer.PointSize:2 64 | Viewer.CameraSize: 0.08 65 | Viewer.CameraLineWidth: 3 66 | Viewer.ViewpointX: 0 67 | Viewer.ViewpointY: -0.7 68 | Viewer.ViewpointZ: -1.8 69 | Viewer.ViewpointF: 500 70 | 71 | -------------------------------------------------------------------------------- /configs/orb_config/tum2.yaml: -------------------------------------------------------------------------------- 1 | %YAML:1.0 2 | 3 | #-------------------------------------------------------------------------------------------- 4 | # Camera Parameters. Adjust them! 5 | #-------------------------------------------------------------------------------------------- 6 | 7 | # Camera calibration and distortion parameters (OpenCV) 8 | Camera.fx: 520.908620 9 | Camera.fy: 521.007327 10 | Camera.cx: 325.141442 11 | Camera.cy: 249.701764 12 | 13 | Camera.k1: 0.231222 14 | Camera.k2: -0.784899 15 | Camera.p1: -0.003257 16 | Camera.p2: -0.000105 17 | Camera.k3: 0.917205 18 | 19 | Camera.width: 640 20 | Camera.height: 480 21 | 22 | # Camera frames per second 23 | Camera.fps: 30.0 24 | 25 | # IR projector baseline times fx (aprox.) 26 | Camera.bf: 40.0 27 | 28 | # Color order of the images (0: BGR, 1: RGB. It is ignored if images are grayscale) 29 | Camera.RGB: 1 30 | 31 | # Close/Far threshold. Baseline times. 32 | ThDepth: 40.0 33 | 34 | # Deptmap values factor 35 | DepthMapFactor: 5208.0 36 | 37 | #-------------------------------------------------------------------------------------------- 38 | # ORB Parameters 39 | #-------------------------------------------------------------------------------------------- 40 | 41 | # ORB Extractor: Number of features per image 42 | ORBextractor.nFeatures: 1000 43 | 44 | # ORB Extractor: Scale factor between levels in the scale pyramid 45 | ORBextractor.scaleFactor: 1.2 46 | 47 | # ORB Extractor: Number of levels in the scale pyramid 48 | ORBextractor.nLevels: 8 49 | 50 | # ORB Extractor: Fast threshold 51 | # Image is divided in a grid. At each cell FAST are extracted imposing a minimum response. 52 | # Firstly we impose iniThFAST. 
If no corners are detected we impose a lower value minThFAST 53 | # You can lower these values if your images have low contrast 54 | ORBextractor.iniThFAST: 20 55 | ORBextractor.minThFAST: 7 56 | 57 | #-------------------------------------------------------------------------------------------- 58 | # Viewer Parameters 59 | #-------------------------------------------------------------------------------------------- 60 | Viewer.KeyFrameSize: 0.05 61 | Viewer.KeyFrameLineWidth: 1 62 | Viewer.GraphLineWidth: 0.9 63 | Viewer.PointSize:2 64 | Viewer.CameraSize: 0.08 65 | Viewer.CameraLineWidth: 3 66 | Viewer.ViewpointX: 0 67 | Viewer.ViewpointY: -0.7 68 | Viewer.ViewpointZ: -1.8 69 | Viewer.ViewpointF: 500 70 | 71 | -------------------------------------------------------------------------------- /configs/orb_config/tum3.yaml: -------------------------------------------------------------------------------- 1 | %YAML:1.0 2 | 3 | #-------------------------------------------------------------------------------------------- 4 | # Camera Parameters. Adjust them! 5 | #-------------------------------------------------------------------------------------------- 6 | 7 | # Camera calibration and distortion parameters (OpenCV) 8 | Camera.fx: 535.4 9 | Camera.fy: 539.2 10 | Camera.cx: 320.1 11 | Camera.cy: 247.6 12 | 13 | Camera.k1: 0.0 14 | Camera.k2: 0.0 15 | Camera.p1: 0.0 16 | Camera.p2: 0.0 17 | 18 | Camera.width: 640 19 | Camera.height: 480 20 | 21 | # Camera frames per second 22 | Camera.fps: 30.0 23 | 24 | # IR projector baseline times fx (aprox.) 25 | Camera.bf: 40.0 26 | 27 | # Color order of the images (0: BGR, 1: RGB. It is ignored if images are grayscale) 28 | Camera.RGB: 1 29 | 30 | # Close/Far threshold. Baseline times. 31 | ThDepth: 40.0 32 | 33 | # Deptmap values factor 34 | DepthMapFactor: 5000.0 35 | 36 | #-------------------------------------------------------------------------------------------- 37 | # ORB Parameters 38 | #-------------------------------------------------------------------------------------------- 39 | 40 | # ORB Extractor: Number of features per image 41 | ORBextractor.nFeatures: 1000 42 | 43 | # ORB Extractor: Scale factor between levels in the scale pyramid 44 | ORBextractor.scaleFactor: 1.2 45 | 46 | # ORB Extractor: Number of levels in the scale pyramid 47 | ORBextractor.nLevels: 8 48 | 49 | # ORB Extractor: Fast threshold 50 | # Image is divided in a grid. At each cell FAST are extracted imposing a minimum response. 51 | # Firstly we impose iniThFAST. 
If no corners are detected we impose a lower value minThFAST 52 | # You can lower these values if your images have low contrast 53 | ORBextractor.iniThFAST: 20 54 | ORBextractor.minThFAST: 7 55 | 56 | #-------------------------------------------------------------------------------------------- 57 | # Viewer Parameters 58 | #-------------------------------------------------------------------------------------------- 59 | Viewer.KeyFrameSize: 0.05 60 | Viewer.KeyFrameLineWidth: 1 61 | Viewer.GraphLineWidth: 0.9 62 | Viewer.PointSize:2 63 | Viewer.CameraSize: 0.08 64 | Viewer.CameraLineWidth: 3 65 | Viewer.ViewpointX: 0 66 | Viewer.ViewpointY: -0.7 67 | Viewer.ViewpointZ: -1.8 68 | Viewer.ViewpointF: 500 69 | 70 | -------------------------------------------------------------------------------- /configs/ours/corridor.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/ours_base.yaml" 2 | save_step: 400 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "./data/Ours/corridor" 10 | save_path: "output/dataset/Ours/corridor" 11 | 12 | use_gt_pose: False -------------------------------------------------------------------------------- /configs/ours/home.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/ours_base.yaml" 2 | save_step: 400 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "./data/Ours/home" 10 | save_path: "output/dataset/Ours/home" 11 | 12 | use_gt_pose: False 13 | -------------------------------------------------------------------------------- /configs/ours/hotel.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/ours_base.yaml" 2 | save_step: 400 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "./data/Ours/hotel" 10 | save_path: "output/dataset/Ours/hotel" 11 | 12 | use_gt_pose: False -------------------------------------------------------------------------------- /configs/ours/office.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/ours_base.yaml" 2 | save_step: 400 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "./data/Ours/office" 10 | save_path: "output/dataset/Ours/office" 11 | 12 | use_gt_pose: False -------------------------------------------------------------------------------- /configs/ours/outside.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/ours_base.yaml" 2 | save_step: 400 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "./data/Ours/outside" 10 | save_path: "output/dataset/Ours/outside" 11 | 12 | use_gt_pose: False -------------------------------------------------------------------------------- /configs/ours_base.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/base.yaml" 2 | 3 | save_path: "output/replica_test/debug" 4 | 5 | # map params 6 | uniform_sample_num: 46080 7 | min_depth: 0.3 8 | max_depth: 2.5 9 | KNN_num: 20 10 | KNN_threshold: -1 11 | memory_length: 3 12 | global_keyframe_num: 3 13 | 14 | # map preprocess 15 | invalid_confidence_thresh: 0.1 16 | 17 | # dataset params: 18 | type: "Ours" 19 | source_path: "data/Ours/hotel" 20 | 21 | # state manage 22 | 
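# Roughly: a Gaussian is promoted from unstable to stable once its confidence
# exceeds stable_confidence_thres, and unstable_time_window bounds how many frames
# it may stay unstable before the mapper reconsiders it; the values here are
# larger than the Replica defaults (100 / 120) used for synthetic data.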
stable_confidence_thres: 250 23 | unstable_time_window: 400 24 | 25 | # optimize params: 26 | gaussian_update_iter: 50 27 | gaussian_update_frame: 8 28 | position_lr : 0.001 29 | feature_lr : 0.001 30 | opacity_lr : 0.000 31 | scaling_lr : 0.002 32 | rotation_lr : 0.001 33 | KNN_threshold: -1 34 | 35 | # track params 36 | use_orb_backend: True 37 | tracker_max_fps: 15 38 | orb_settings_path: "configs/orb_config/ours.yaml" 39 | 40 | use_gt_pose: False 41 | icp_use_model_depth: True # if False, use dataset depth frame to frame 42 | icp_matches_threshold: 0.2 # ratio * valie pixels 43 | icp_normal_threshold: 20 # degree 44 | icp_sample_distance_threshold: 0.01 # m 45 | icp_sample_normal_threshold: 0.01 # cos similarity 46 | icp_warmup_frames: 200 47 | icp_fail_threshold: 0.01 -------------------------------------------------------------------------------- /configs/replica/office0.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/replica_base.yaml" 2 | save_step: 400 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "./data/Replica/office0" 10 | save_path: "output/dataset/Replica/office0" 11 | -------------------------------------------------------------------------------- /configs/replica/office1.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/replica_base.yaml" 2 | save_step: 400 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "./data/Replica/office1" 10 | save_path: "output/dataset/Replica/office1" 11 | 12 | -------------------------------------------------------------------------------- /configs/replica/office2.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/replica_base.yaml" 2 | save_step: 400 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "./data/Replica/office2" 10 | save_path: "output/dataset/Replica/office2" 11 | -------------------------------------------------------------------------------- /configs/replica/office3.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/replica_base.yaml" 2 | save_step: 400 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "./data/Replica/office3" 10 | save_path: "output/dataset/Replica/office3" 11 | -------------------------------------------------------------------------------- /configs/replica/office4.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/replica_base.yaml" 2 | save_step: 400 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "./data/Replica/office4" 10 | save_path: "output/dataset/Replica/office4" 11 | -------------------------------------------------------------------------------- /configs/replica/room0.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/replica_base.yaml" 2 | save_step: 400 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "./data/Replica/room0" 10 | save_path: "output/dataset/Replica/room0" 11 | -------------------------------------------------------------------------------- /configs/replica/room1.yaml: 
-------------------------------------------------------------------------------- 1 | parent: "configs/replica_base.yaml" 2 | save_step: 400 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "./data/Replica/room1" 10 | save_path: "output/dataset/Replica/room1" 11 | -------------------------------------------------------------------------------- /configs/replica/room2.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/replica_base.yaml" 2 | save_step: 400 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "./data/Replica/room2" 10 | save_path: "output/dataset/Replica/room2" 11 | -------------------------------------------------------------------------------- /configs/replica_base.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/base.yaml" 2 | device_list: [1] 3 | save_path: "output/replica_test/debug" 4 | 5 | # dataset params: 6 | type: "Replica" 7 | source_path: "data/Replica/office0" 8 | 9 | uniform_sample_num: 40800 10 | 11 | # state manage 12 | stable_confidence_thres: 100 13 | unstable_time_window: 120 14 | memory_length: 5 15 | 16 | # optimize params: 17 | gaussian_update_iter: 50 18 | gaussian_update_frame: 6 19 | position_lr : 0.001 20 | feature_lr : 0.0005 21 | opacity_lr : 0.000 22 | scaling_lr : 0.004 23 | rotation_lr : 0.001 24 | final_global_iter: 20 25 | 26 | use_gt_pose: False 27 | icp_use_model_depth: True 28 | 29 | 30 | # track params 31 | icp_use_model_depth: True # if False, use dataset depth frame to frame 32 | icp_normal_threshold: 20 # degree 33 | icp_sample_distance_threshold: 0.01 # m 34 | icp_sample_normal_threshold: 0.01 # cos similarity 35 | 36 | 37 | feature_lr_coef: 4.0 38 | scaling_lr_coef: 4.0 39 | rotation_lr_coef: 4.0 40 | 41 | pcd_densify: True -------------------------------------------------------------------------------- /configs/scannetpp/39f36da05b.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/scannetpp_base.yaml" 2 | save_step: 100 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "data/ScanNetpp/39f36da05b" 10 | save_path: "output/dataset/ScanNetpp/39f36da05b" -------------------------------------------------------------------------------- /configs/scannetpp/8b5caf3398.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/scannetpp_base.yaml" 2 | save_step: 100 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "data/ScanNetpp/8b5caf3398" 10 | save_path: "output/dataset/ScanNetpp/8b5caf3398" 11 | -------------------------------------------------------------------------------- /configs/scannetpp/b20a261fdf.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/scannetpp_base.yaml" 2 | save_step: 100 3 | 4 | frame_start: 0 5 | frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "data/ScanNetpp/b20a261fdf" 10 | save_path: "output/dataset/ScanNetpp/b20a261fdf" -------------------------------------------------------------------------------- /configs/scannetpp/f34d532901.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/scannetpp_base.yaml" 2 | save_step: 100 3 | 4 | frame_start: 0 5 | 
frame_step: 0 6 | frame_num: -1 7 | 8 | # dataset params: 9 | source_path: "data/ScanNetpp/f34d532901" 10 | save_path: "output/dataset/ScanNetpp/f34d532901" -------------------------------------------------------------------------------- /configs/scannetpp_base.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/base.yaml" 2 | 3 | save_path: "output/Scannetpp/debug" 4 | 5 | # map params 6 | uniform_sample_num: 68620 7 | min_depth: 0.3 8 | max_depth: 10.0 9 | 10 | # dataset params: 11 | type: "Scannetpp" 12 | source_path: "data/ScanNetpp/8b5caf3398" 13 | 14 | # state manage 15 | stable_confidence_thres: 400 16 | unstable_time_window: 200 17 | 18 | # optimize params: 19 | gaussian_update_iter: 75 20 | gaussian_update_frame: 3 21 | position_lr : 0.001 22 | feature_lr : 0.0005 23 | opacity_lr : 0.000 24 | scaling_lr : 0.004 25 | rotation_lr : 0.001 26 | KNN_threshold: -1 27 | 28 | use_gt_pose: True 29 | 30 | pcd_densify: True -------------------------------------------------------------------------------- /configs/tum/dataset/fr1_desk.yaml: -------------------------------------------------------------------------------- 1 | H: 480 2 | W: 640 3 | fx: 517.3 4 | fy: 516.5 5 | cx: 318.6 6 | cy: 255.3 7 | crop_edge: 0 8 | distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633] 9 | depth_scale: 5000.0 10 | -------------------------------------------------------------------------------- /configs/tum/dataset/fr2_xyz.yaml: -------------------------------------------------------------------------------- 1 | H: 480 2 | W: 640 3 | fx: 520.9 4 | fy: 521.0 5 | cx: 325.1 6 | cy: 249.7 7 | crop_edge: 0 8 | distortion: [0.2312, -0.7849, -0.0033, -0.0001, 0.9172] 9 | depth_scale: 5208.0 -------------------------------------------------------------------------------- /configs/tum/dataset/fr3_office.yaml: -------------------------------------------------------------------------------- 1 | H: 480 2 | W: 640 3 | fx: 535.4 4 | fy: 539.2 5 | cx: 320.1 6 | cy: 247.6 7 | crop_edge: 0 8 | depth_scale: 5000.0 9 | -------------------------------------------------------------------------------- /configs/tum/fr1_desk.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/tum_base.yaml" 2 | 3 | frame_start: 0 4 | frame_step: 0 5 | frame_num: -1 6 | 7 | # dataset params: 8 | source_path: "data/TUM_RGBD/rgbd_dataset_freiburg1_desk" 9 | save_path: "output/dataset/TUM_RGBD/fr1_desk" 10 | 11 | orb_settings_path: "configs/orb_config/tum1.yaml" 12 | -------------------------------------------------------------------------------- /configs/tum/fr2_xyz.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/tum_base.yaml" 2 | 3 | frame_start: 0 4 | frame_step: 0 5 | frame_num: -1 6 | 7 | # dataset params: 8 | source_path: "data/TUM_RGBD/rgbd_dataset_freiburg2_xyz" 9 | save_path: "output/dataset/TUM_RGBD/fr2_xyz" 10 | 11 | orb_settings_path: "configs/orb_config/tum2.yaml" 12 | -------------------------------------------------------------------------------- /configs/tum/fr3_office.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/tum_base.yaml" 2 | 3 | frame_start: 0 4 | frame_step: 0 5 | frame_num: -1 6 | 7 | # dataset params: 8 | source_path: "data/TUM_RGBD/rgbd_dataset_freiburg3_long_office_household" 9 | save_path: "output/dataset/TUM_RGBD/fr3_office" 10 | 11 | orb_settings_path: "configs/orb_config/tum3.yaml" 12 | 
-------------------------------------------------------------------------------- /configs/tum_base.yaml: -------------------------------------------------------------------------------- 1 | parent: "configs/base.yaml" 2 | 3 | save_path: "output/tum/debug" 4 | 5 | # dataset params: 6 | type: "TUM" 7 | source_path: "data/TUM_RGBD/rgbd_dataset_freiburg1_desk" 8 | 9 | # state manage 10 | stable_confidence_thres: 200 11 | unstable_time_window: 150 12 | memory_length: 5 13 | 14 | # optimize params: 15 | gaussian_update_iter: 50 16 | gaussian_update_frame: 4 17 | position_lr : 0.001 18 | feature_lr : 0.001 19 | opacity_lr : 0.000 20 | scaling_lr : 0.02 21 | rotation_lr : 0.001 22 | 23 | use_gt_pose: False 24 | use_orb_backend: True 25 | orb_useicp: True 26 | icp_use_model_depth: True 27 | 28 | # track params 29 | use_gt_pose: False 30 | icp_use_model_depth: True # if False, use dataset depth frame to frame 31 | icp_matches_threshold: 0.2 # ratio * valid pixels 32 | icp_normal_threshold: 20 # degree 33 | icp_sample_distance_threshold: 0.01 # m 34 | icp_sample_normal_threshold: 0.01 # cos similarity 35 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: RTG-SLAM 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - nvidia 6 | - pytorch3d 7 | - defaults 8 | dependencies: 9 | - cudatoolkit=11.7 10 | - plyfile 11 | - python=3.9.0 12 | - pytorch=1.13.1 13 | - torchaudio=0.13.1 14 | - torchvision=0.14.1 15 | - tqdm 16 | - tensorboard 17 | - torchmetrics 18 | - pytorch3d 19 | - pip: 20 | - pytorch_msssim 21 | - trimesh 22 | - scikit-image 23 | - open3d 24 | - pyyaml 25 | - scipy 26 | - opencv-python 27 | - matplotlib 28 | - GPUtil 29 | - submodules/cuda_utils 30 | - submodules/diff-gaussian-rasterizer-depth 31 | - submodules/simple-knn -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | 5 | from utils.config_utils import read_config 6 | 7 | parser = ArgumentParser(description="Eval script parameters") 8 | parser.add_argument("--config", type=str) 9 | parser.add_argument("--load_frame", type=int, default=-1) 10 | parser.add_argument("--eval_frames", type=int, default=-1) 11 | parser.add_argument("--load_iter", nargs="+", type=int, default=[]) 12 | parser.add_argument("--eval_merge", action="store_true") 13 | parser.add_argument("--save_pic", action="store_true") 14 | 15 | eval_args = parser.parse_args() 16 | config_path = eval_args.config 17 | args = read_config(config_path) 18 | # set visible devices 19 | device_list = args.device_list 20 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(device) for device in device_list) 21 | 22 | from utils.camera_utils import loadCam 23 | import pandas as pd 24 | import torch 25 | from tqdm import tqdm 26 | from arguments import DatasetParams, MapParams, OptimizationParams 27 | from scene import Dataset 28 | from SLAM.multiprocess.mapper import Mapping 29 | from SLAM.utils import * 30 | from SLAM.eval import eval_frame 31 | from utils.general_utils import safe_state 32 | 33 | 34 | torch.set_printoptions(4, sci_mode=False) 35 | 36 | 37 | def filter_models(frame_path, eval_merge, load_iter): 38 | if eval_merge: 39 | exclud_ = "stable" 40 | include_ = "merge" 41 | else: 42 | exclud_ = "merge" 43 | include_ = "stable" 44 | total_models = [ 45 | i for i 
in os.listdir(frame_path) if "sibr" not in i and exclud_ not in i 46 | ] 47 | select_models = [] 48 | if len(load_iter) > 0: 49 | for eval_iter in load_iter: 50 | model_iter = [i for i in total_models if "%04d" % eval_iter in i] 51 | merge_models = [i for i in model_iter if include_ in i] 52 | if len(merge_models) > 0: 53 | select_models.extend(merge_models) 54 | else: 55 | select_models.extend(model_iter) 56 | else: 57 | max_iter = sorted([i[5:9] for i in total_models], reverse=True)[0] 58 | total_models = [i for i in total_models if max_iter in i] 59 | merge_models = [i for i in total_models if include_ in i] 60 | if len(merge_models) > 0: 61 | select_models.extend(merge_models) 62 | else: 63 | select_models.extend(total_models) 64 | return select_models 65 | 66 | 67 | def move_to_gpu(frame): 68 | frame.original_depth = devF(frame.original_depth) 69 | frame.original_image = devF(frame.original_image) 70 | 71 | 72 | def move_to_cpu(frame): 73 | frame.original_depth = frame.original_depth.to("cpu") 74 | frame.original_image = frame.original_image.to("cpu") 75 | 76 | 77 | def read_pose_t0(args): 78 | data_type = args.type 79 | if data_type == "Replica": 80 | pose_t0_c2w = np.loadtxt(os.path.join(args.source_path, "traj.txt"))[0].reshape( 81 | 4, 4 82 | ) 83 | elif data_type == "Scannetpp": 84 | pose_t0_c2w = np.loadtxt(os.path.join(args.source_path, "pose", "0000.txt")).reshape(4,4) 85 | else: 86 | pose_t0_c2w = np.eye(4) 87 | return pose_t0_c2w 88 | 89 | 90 | def main(): 91 | if os.path.exists(os.path.join(args.save_path, "eval_metric")): 92 | os.system("rm -r {}".format(os.path.join(args.save_path, "eval_metric"))) 93 | 94 | load_iter = eval_args.load_iter 95 | load_frame = eval_args.load_frame 96 | eval_merge = eval_args.eval_merge 97 | eval_frames = eval_args.eval_frames 98 | model_base = os.path.join(args.save_path, "save_model") 99 | frames = [i for i in os.listdir(model_base) if os.path.isdir(os.path.join(model_base, i))] 100 | frames = sorted(frames) 101 | if load_frame < 0: 102 | check_frame = frames[-1] 103 | else: 104 | check_frame = [i for i in frames if "%04d" % load_frame in i][0] 105 | print("check frame: ", check_frame) 106 | if eval_frames < 0: 107 | max_cams = int(check_frame.split("_")[-1]) 108 | else: 109 | max_cams = min(eval_frames, int(check_frame.split("_")[-1])) 110 | optimization_params = OptimizationParams(parser) 111 | dataset_params = DatasetParams(parser, sentinel=True) 112 | map_params = MapParams(parser) 113 | 114 | safe_state(args.quiet) 115 | save_pic = eval_args.save_pic 116 | optimization_params = optimization_params.extract(args) 117 | dataset_params = dataset_params.extract(args) 118 | dataset_params.frame_num = max_cams 119 | map_params = map_params.extract(args) 120 | 121 | # read pose_es 122 | if not args.use_gt_pose: 123 | pose_es = np.load(os.path.join(args.save_path, "save_traj", "pose_es.npy")).reshape( 124 | -1, 4, 4 125 | )[args.frame_start :, ...] 
126 | 127 | # Initialize dataset 128 | dataset = Dataset( 129 | dataset_params, 130 | shuffle=False, 131 | resolution_scales=dataset_params.resolution_scales, 132 | ) 133 | 134 | pose_t0_c2w = read_pose_t0(args) 135 | pose_t0_w2c = np.linalg.inv(pose_t0_c2w) 136 | 137 | # evaluate depth map opaque 138 | args.renderer_opaque_threshold = args.renderer_opaque_threshold_eval 139 | pcd_densify = args.pcd_densify 140 | 141 | gaussian_map = Mapping(args) 142 | 143 | frame_id = int(check_frame.split("_")[1]) 144 | gaussian_map.time = frame_id 145 | frame_path = os.path.join(model_base, check_frame) 146 | select_models = filter_models(frame_path, eval_merge, load_iter) 147 | 148 | print("select models", select_models) 149 | select_model = select_models[0] 150 | test_iter = select_model[5:9] 151 | print("test_iter: ", test_iter) 152 | 153 | select_ply = os.path.join(frame_path, select_model) 154 | gaussian_map.pointcloud.load(select_ply) 155 | 156 | if pcd_densify: 157 | pcd_path = os.path.join(model_base, "pcd_densify.ply") 158 | if not os.path.exists(pcd_path): 159 | pcd_path = select_ply 160 | else: 161 | pcd_path = select_ply 162 | 163 | print("geometry eval ply: ", pcd_path) 164 | 165 | gaussian_map.iter = int(test_iter) 166 | total_loss = [] 167 | run_pcd = False 168 | for cam_id, frame_info in tqdm( 169 | enumerate(dataset.scene_info.train_cameras), 170 | desc="Evaluating", 171 | total=len(dataset.scene_info.train_cameras), 172 | ): 173 | test_frame = loadCam( 174 | dataset_params, 175 | frame_id, 176 | frame_info, 177 | dataset_params.resolution_scales[0], 178 | ) 179 | move_to_gpu(test_frame) 180 | if not args.use_gt_pose: 181 | test_frame.updatePose(pose_es[cam_id]) 182 | gaussian_map.time = cam_id 183 | if cam_id == len(dataset.scene_info.train_cameras) - 1: 184 | run_pcd = True 185 | move_to_gpu(test_frame) 186 | 187 | losses = eval_frame( 188 | gaussian_map, 189 | test_frame, 190 | os.path.join(gaussian_map.save_path, "eval_metric"), 191 | run_picture=True, 192 | run_pcd=run_pcd, 193 | min_depth=args.min_depth, 194 | max_depth=args.max_depth, 195 | pcd_path=pcd_path, 196 | gt_mesh_path=dataset.mesh_path, 197 | dist_threshs=[0.03], 198 | sample_nums=1000000, 199 | pcd_transform=pose_t0_c2w, 200 | save_picture= save_pic, 201 | ) 202 | 203 | losses["frame"] = gaussian_map.time 204 | losses["iter"] = gaussian_map.iter 205 | total_loss.append(losses) 206 | move_to_cpu(test_frame) 207 | 208 | 209 | df = pd.DataFrame(total_loss) 210 | print(df.mean()) 211 | mean_row = df.mean().to_frame().T 212 | mean_row["frame"] = "mean" 213 | df = pd.concat([df, mean_row], ignore_index=True) 214 | df.to_csv( 215 | os.path.join( 216 | args.save_path, 217 | "statis_frame_{}_iter_{}.csv".format(frame_id, test_iter), 218 | ) 219 | ) 220 | 221 | 222 | if __name__ == "__main__": 223 | main() 224 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from arguments import DatasetParams 13 | from scene.dataset_readers import sceneLoadTypeCallbacks 14 | 15 | 16 | class Dataset: 17 | def __init__( 18 | self, 19 | args: DatasetParams, 20 | shuffle=True, 21 | resolution_scales=[1.0], 22 | ): 23 | self.train_cameras = {} 24 | self.test_cameras = {} 25 | if args.type == "TUM": 26 | print("Assuming TUM data set!") 27 | scene_info = sceneLoadTypeCallbacks["Tum"]( 28 | args.source_path, 29 | args.eval, 30 | args.eval_llff, 31 | args.frame_start, 32 | args.frame_num, 33 | args.frame_step, 34 | ) 35 | elif args.type == "Replica": 36 | print("Assuming Replica data set!") 37 | scene_info = sceneLoadTypeCallbacks["Replica"]( 38 | args.source_path, 39 | args.eval, 40 | args.eval_llff, 41 | args.frame_start, 42 | args.frame_num, 43 | args.frame_step, 44 | ) 45 | elif args.type == "Ours": 46 | print("Assuming Ours dataset!") 47 | scene_info = sceneLoadTypeCallbacks["ours"]( 48 | args.source_path, 49 | args.eval, 50 | args.eval_llff, 51 | args.frame_start, 52 | args.frame_num, 53 | args.frame_step, 54 | ) 55 | elif args.type == "Scannetpp": 56 | print("Assuming Scannetpp dataset!") 57 | scene_info = sceneLoadTypeCallbacks["Scannetpp"]( 58 | args.source_path, 59 | args.eval, 60 | args.eval_llff, 61 | args.frame_start, 62 | args.frame_num, 63 | args.frame_step, 64 | isscannetpp=True 65 | ) 66 | else: 67 | print("scene dataset path:", args.source_path) 68 | assert False, "Could not recognize scene type!" 69 | 70 | self.cameras_extent = scene_info.nerf_normalization["radius"] 71 | self.mesh_path = scene_info.mesh_path 72 | self.scene_info = scene_info 73 | 74 | 75 | def getTrainCameras(self, scale=1.0): 76 | return self.train_cameras[scale] 77 | 78 | def getTestCameras(self, scale=1.0): 79 | return self.test_cameras[scale] 80 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import torch 14 | from torch import nn 15 | from SLAM.utils import downscale_img 16 | 17 | from utils.graphics_utils import fov2focal, getProjectionMatrix, getWorld2View2 18 | from utils.general_utils import devF 19 | 20 | 21 | class Camera(nn.Module): 22 | def __init__( 23 | self, 24 | colmap_id, 25 | R, 26 | T, 27 | FoVx, 28 | FoVy, 29 | image, 30 | depth, 31 | gt_alpha_mask, 32 | image_name, 33 | uid, 34 | trans=np.array([0.0, 0.0, 0.0]), 35 | scale=1.0, 36 | pose_gt=np.eye(4), 37 | cx=-1, 38 | cy=-1, 39 | timestamp=0, 40 | depth_scale=1.0, 41 | preload=True, 42 | data_device="cuda", 43 | ): 44 | super(Camera, self).__init__() 45 | 46 | self.uid = uid 47 | self.colmap_id = colmap_id 48 | self.R = R 49 | self.T = T 50 | self.FoVx = FoVx 51 | self.FoVy = FoVy 52 | self.image_name = image_name 53 | self.preload = preload 54 | self.timestamp = timestamp 55 | self.depth_scale = depth_scale 56 | try: 57 | self.data_device = torch.device(data_device) 58 | except Exception as e: 59 | print(e) 60 | print( 61 | f"[Warning] Custom device {data_device} failed, fallback to default cuda device" 62 | ) 63 | self.data_device = torch.device("cuda") 64 | 65 | if not self.preload: 66 | self.data_device = torch.device("cpu") 67 | 68 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 69 | 70 | self.image_width = self.original_image.shape[2] 71 | self.image_height = self.original_image.shape[1] 72 | 73 | if depth is not None: 74 | self.original_depth = depth.to(self.data_device) 75 | else: 76 | self.original_depth = torch.ones(1, self.image_height, self.image_width).to( 77 | self.data_device 78 | ) 79 | 80 | if gt_alpha_mask is not None: 81 | self.original_image *= gt_alpha_mask.to(self.data_device) 82 | self.original_depth *= gt_alpha_mask.to(self.data_device) 83 | else: 84 | self.original_image *= torch.ones( 85 | (1, self.image_height, self.image_width), device=self.data_device 86 | ) 87 | self.original_depth *= torch.ones( 88 | (1, self.image_height, self.image_width), device=self.data_device 89 | ) 90 | self.zfar = 100.0 91 | self.znear = 0.01 92 | 93 | self.trans = trans 94 | self.scale = scale 95 | 96 | self.world_view_transform = ( 97 | torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 98 | ) 99 | self.projection_matrix = ( 100 | getProjectionMatrix( 101 | znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy 102 | ) 103 | .transpose(0, 1) 104 | .cuda() 105 | ) 106 | self.full_proj_transform = ( 107 | self.world_view_transform.unsqueeze(0).bmm( 108 | self.projection_matrix.unsqueeze(0) 109 | ) 110 | ).squeeze(0) 111 | self.camera_center = self.world_view_transform.inverse()[3, :3] 112 | 113 | # for evaluation, unchange 114 | self.pose_gt = pose_gt 115 | self.cx = cx 116 | self.cy = cy 117 | 118 | self.world_view_transform.share_memory_() 119 | self.full_proj_transform.share_memory_() 120 | 121 | def updatePose(self, pose_c2w): 122 | pose_w2c = np.linalg.inv(pose_c2w) 123 | self.update(pose_w2c[:3, :3].transpose(), pose_w2c[:3, 3]) 124 | 125 | def update(self, R, T): 126 | self.R = R 127 | self.T = T 128 | self.world_view_transform = ( 129 | torch.tensor(getWorld2View2(R, T, self.trans, self.scale)) 130 | .transpose(0, 1) 131 | .cuda() 132 | ) 133 | self.full_proj_transform = ( 134 | self.world_view_transform.unsqueeze(0).bmm( 135 | self.projection_matrix.unsqueeze(0) 136 | ) 137 | ).squeeze(0) 138 | 139 | def get_w2c(self): 140 | return 
self.world_view_transform.transpose(0, 1) 141 | 142 | @property 143 | def get_c2w(self): 144 | return self.world_view_transform.transpose(0, 1).inverse() 145 | 146 | # TODO: only works for the Replica dataset, need to add loading of the local depth intrinsic for ScanNet 147 | @property 148 | def get_intrinsic(self): 149 | w, h = self.image_width, self.image_height 150 | fx, fy = fov2focal(self.FoVx, w), fov2focal(self.FoVy, h) 151 | cx = self.cx if self.cx > 0 else w / 2 152 | cy = self.cy if self.cy > 0 else h / 2 153 | intrinsic = devF(torch.tensor([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])) 154 | return intrinsic 155 | 156 | def get_focal_length(self): 157 | w, h = self.image_width, self.image_height 158 | fx, fy = fov2focal(self.FoVx, w), fov2focal(self.FoVy, h) 159 | return (fx + fy) / 2.0 160 | 161 | def get_uv(self, xyz_w): 162 | intrinsic = self.get_intrinsic 163 | w2c = self.get_w2c() 164 | xyz_c = xyz_w @ w2c[:3, :3].T + w2c[:3, 3] 165 | uv = xyz_c @ intrinsic.T 166 | uv = uv[:, :2] / uv[:, 2:] 167 | uv = uv.long() 168 | return uv 169 | 170 | def move_to_cpu_clone(self): 171 | new_cam = Camera( 172 | colmap_id=self.colmap_id, 173 | R=self.R, 174 | T=self.T, 175 | FoVx=self.FoVx, 176 | FoVy=self.FoVy, 177 | image=self.original_image.detach(), 178 | depth=self.original_depth.detach(), 179 | gt_alpha_mask=None, 180 | image_name=self.image_name, 181 | uid=self.uid, 182 | data_device=self.data_device, 183 | pose_gt=self.pose_gt, 184 | cx=self.cx, 185 | cy=self.cy, 186 | timestamp=self.timestamp, 187 | preload=self.preload, 188 | depth_scale=self.depth_scale, 189 | ) 190 | new_cam.original_depth = new_cam.original_depth.to("cpu") 191 | new_cam.original_image = new_cam.original_image.to("cpu") 192 | return new_cam 193 | 194 | 195 | class MiniCam: 196 | def __init__( 197 | self, 198 | width, 199 | height, 200 | fovy, 201 | fovx, 202 | znear, 203 | zfar, 204 | world_view_transform, 205 | full_proj_transform, 206 | ): 207 | self.image_width = width 208 | self.image_height = height 209 | self.FoVy = fovy 210 | self.FoVx = fovx 211 | self.znear = znear 212 | self.zfar = zfar 213 | self.world_view_transform = world_view_transform 214 | self.full_proj_transform = full_proj_transform 215 | view_inv = torch.inverse(self.world_view_transform) 216 | self.camera_center = view_inv[3][:3] 217 | self.cx = -1 218 | self.cy = -1 219 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import collections 13 | import struct 14 | 15 | import numpy as np 16 | 17 | CameraModel = collections.namedtuple( 18 | "CameraModel", ["model_id", "model_name", "num_params"] 19 | ) 20 | Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"]) 21 | BaseImage = collections.namedtuple( 22 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] 23 | ) 24 | Point3D = collections.namedtuple( 25 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] 26 | ) 27 | CAMERA_MODELS = { 28 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 29 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 30 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 31 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 32 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 33 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 34 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 35 | CameraModel(model_id=7, model_name="FOV", num_params=5), 36 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 37 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 38 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), 39 | } 40 | CAMERA_MODEL_IDS = dict( 41 | [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] 42 | ) 43 | CAMERA_MODEL_NAMES = dict( 44 | [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS] 45 | ) 46 | 47 | 48 | def qvec2rotmat(qvec): 49 | return np.array( 50 | [ 51 | [ 52 | 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 53 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 54 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], 55 | ], 56 | [ 57 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 58 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 59 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], 60 | ], 61 | [ 62 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 63 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 64 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, 65 | ], 66 | ] 67 | ) 68 | 69 | 70 | def rotmat2qvec(R): 71 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 72 | K = ( 73 | np.array( 74 | [ 75 | [Rxx - Ryy - Rzz, 0, 0, 0], 76 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 77 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 78 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], 79 | ] 80 | ) 81 | / 3.0 82 | ) 83 | eigvals, eigvecs = np.linalg.eigh(K) 84 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 85 | if qvec[0] < 0: 86 | qvec *= -1 87 | return qvec 88 | 89 | 90 | class Image(BaseImage): 91 | def qvec2rotmat(self): 92 | return qvec2rotmat(self.qvec) 93 | 94 | 95 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 96 | """Read and unpack the next bytes from a binary file. 97 | :param fid: 98 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 99 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 100 | :param endian_character: Any of {@, =, <, >, !} 101 | :return: Tuple of read and unpacked values. 
102 | """ 103 | data = fid.read(num_bytes) 104 | return struct.unpack(endian_character + format_char_sequence, data) 105 | 106 | 107 | def read_points3D_text(path): 108 | """ 109 | see: src/base/reconstruction.cc 110 | void Reconstruction::ReadPoints3DText(const std::string& path) 111 | void Reconstruction::WritePoints3DText(const std::string& path) 112 | """ 113 | xyzs = None 114 | rgbs = None 115 | errors = None 116 | num_points = 0 117 | with open(path, "r") as fid: 118 | while True: 119 | line = fid.readline() 120 | if not line: 121 | break 122 | line = line.strip() 123 | if len(line) > 0 and line[0] != "#": 124 | num_points += 1 125 | 126 | xyzs = np.empty((num_points, 3)) 127 | rgbs = np.empty((num_points, 3)) 128 | errors = np.empty((num_points, 1)) 129 | count = 0 130 | with open(path, "r") as fid: 131 | while True: 132 | line = fid.readline() 133 | if not line: 134 | break 135 | line = line.strip() 136 | if len(line) > 0 and line[0] != "#": 137 | elems = line.split() 138 | xyz = np.array(tuple(map(float, elems[1:4]))) 139 | rgb = np.array(tuple(map(int, elems[4:7]))) 140 | error = np.array(float(elems[7])) 141 | xyzs[count] = xyz 142 | rgbs[count] = rgb 143 | errors[count] = error 144 | count += 1 145 | 146 | return xyzs, rgbs, errors 147 | 148 | 149 | def read_points3D_binary(path_to_model_file): 150 | """ 151 | see: src/base/reconstruction.cc 152 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 153 | void Reconstruction::WritePoints3DBinary(const std::string& path) 154 | """ 155 | 156 | with open(path_to_model_file, "rb") as fid: 157 | num_points = read_next_bytes(fid, 8, "Q")[0] 158 | 159 | xyzs = np.empty((num_points, 3)) 160 | rgbs = np.empty((num_points, 3)) 161 | errors = np.empty((num_points, 1)) 162 | 163 | for p_id in range(num_points): 164 | binary_point_line_properties = read_next_bytes( 165 | fid, num_bytes=43, format_char_sequence="QdddBBBd" 166 | ) 167 | xyz = np.array(binary_point_line_properties[1:4]) 168 | rgb = np.array(binary_point_line_properties[4:7]) 169 | error = np.array(binary_point_line_properties[7]) 170 | track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 171 | 0 172 | ] 173 | track_elems = read_next_bytes( 174 | fid, 175 | num_bytes=8 * track_length, 176 | format_char_sequence="ii" * track_length, 177 | ) 178 | xyzs[p_id] = xyz 179 | rgbs[p_id] = rgb 180 | errors[p_id] = error 181 | return xyzs, rgbs, errors 182 | 183 | 184 | def read_intrinsics_text(path): 185 | """ 186 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 187 | """ 188 | cameras = {} 189 | with open(path, "r") as fid: 190 | while True: 191 | line = fid.readline() 192 | if not line: 193 | break 194 | line = line.strip() 195 | if len(line) > 0 and line[0] != "#": 196 | elems = line.split() 197 | camera_id = int(elems[0]) 198 | model = elems[1] 199 | print(model) 200 | assert ( 201 | model == "PINHOLE" 202 | ), "While the loader support other types, the rest of the code assumes PINHOLE" 203 | width = int(elems[2]) 204 | height = int(elems[3]) 205 | params = np.array(tuple(map(float, elems[4:]))) 206 | cameras[camera_id] = Camera( 207 | id=camera_id, model=model, width=width, height=height, params=params 208 | ) 209 | return cameras 210 | 211 | 212 | def read_extrinsics_binary(path_to_model_file): 213 | """ 214 | see: src/base/reconstruction.cc 215 | void Reconstruction::ReadImagesBinary(const std::string& path) 216 | void Reconstruction::WriteImagesBinary(const std::string& path) 217 | """ 218 | images = {} 
219 | with open(path_to_model_file, "rb") as fid: 220 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 221 | for _ in range(num_reg_images): 222 | binary_image_properties = read_next_bytes( 223 | fid, num_bytes=64, format_char_sequence="idddddddi" 224 | ) 225 | image_id = binary_image_properties[0] 226 | qvec = np.array(binary_image_properties[1:5]) 227 | tvec = np.array(binary_image_properties[5:8]) 228 | camera_id = binary_image_properties[8] 229 | image_name = "" 230 | current_char = read_next_bytes(fid, 1, "c")[0] 231 | while current_char != b"\x00": # look for the ASCII 0 entry 232 | image_name += current_char.decode("utf-8") 233 | current_char = read_next_bytes(fid, 1, "c")[0] 234 | num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 235 | 0 236 | ] 237 | x_y_id_s = read_next_bytes( 238 | fid, 239 | num_bytes=24 * num_points2D, 240 | format_char_sequence="ddq" * num_points2D, 241 | ) 242 | xys = np.column_stack( 243 | [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))] 244 | ) 245 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 246 | images[image_id] = Image( 247 | id=image_id, 248 | qvec=qvec, 249 | tvec=tvec, 250 | camera_id=camera_id, 251 | name=image_name, 252 | xys=xys, 253 | point3D_ids=point3D_ids, 254 | ) 255 | return images 256 | 257 | 258 | def read_intrinsics_binary(path_to_model_file): 259 | """ 260 | see: src/base/reconstruction.cc 261 | void Reconstruction::WriteCamerasBinary(const std::string& path) 262 | void Reconstruction::ReadCamerasBinary(const std::string& path) 263 | """ 264 | cameras = {} 265 | with open(path_to_model_file, "rb") as fid: 266 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 267 | for _ in range(num_cameras): 268 | camera_properties = read_next_bytes( 269 | fid, num_bytes=24, format_char_sequence="iiQQ" 270 | ) 271 | camera_id = camera_properties[0] 272 | model_id = camera_properties[1] 273 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 274 | width = camera_properties[2] 275 | height = camera_properties[3] 276 | num_params = CAMERA_MODEL_IDS[model_id].num_params 277 | params = read_next_bytes( 278 | fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params 279 | ) 280 | cameras[camera_id] = Camera( 281 | id=camera_id, 282 | model=model_name, 283 | width=width, 284 | height=height, 285 | params=np.array(params), 286 | ) 287 | assert len(cameras) == num_cameras 288 | return cameras 289 | 290 | 291 | def read_extrinsics_text(path): 292 | """ 293 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 294 | """ 295 | images = {} 296 | with open(path, "r") as fid: 297 | while True: 298 | line = fid.readline() 299 | if not line: 300 | break 301 | line = line.strip() 302 | if len(line) > 0 and line[0] != "#": 303 | elems = line.split() 304 | image_id = int(elems[0]) 305 | qvec = np.array(tuple(map(float, elems[1:5]))) 306 | tvec = np.array(tuple(map(float, elems[5:8]))) 307 | camera_id = int(elems[8]) 308 | image_name = elems[9] 309 | elems = fid.readline().split() 310 | xys = np.column_stack( 311 | [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))] 312 | ) 313 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 314 | images[image_id] = Image( 315 | id=image_id, 316 | qvec=qvec, 317 | tvec=tvec, 318 | camera_id=camera_id, 319 | name=image_name, 320 | xys=xys, 321 | point3D_ids=point3D_ids, 322 | ) 323 | return images 324 | 325 | 326 | def read_colmap_bin_array(path): 327 | """ 328 | Taken from 
https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 329 | 330 | :param path: path to the colmap binary file. 331 | :return: nd array with the floating point values in the value 332 | """ 333 | with open(path, "rb") as fid: 334 | width, height, channels = np.genfromtxt( 335 | fid, delimiter="&", max_rows=1, usecols=(0, 1, 2), dtype=int 336 | ) 337 | fid.seek(0) 338 | num_delimiter = 0 339 | byte = fid.read(1) 340 | while True: 341 | if byte == b"&": 342 | num_delimiter += 1 343 | if num_delimiter >= 3: 344 | break 345 | byte = fid.read(1) 346 | array = np.fromfile(fid, np.float32) 347 | array = array.reshape((width, height, channels), order="F") 348 | return np.transpose(array, (1, 0, 2)).squeeze() 349 | -------------------------------------------------------------------------------- /scripts/associate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Software License Agreement (BSD License) 3 | # 4 | # Copyright (c) 2013, Juergen Sturm, TUM 5 | # All rights reserved. 6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions 9 | # are met: 10 | # 11 | # * Redistributions of source code must retain the above copyright 12 | # notice, this list of conditions and the following disclaimer. 13 | # * Redistributions in binary form must reproduce the above 14 | # copyright notice, this list of conditions and the following 15 | # disclaimer in the documentation and/or other materials provided 16 | # with the distribution. 17 | # * Neither the name of TUM nor the names of its 18 | # contributors may be used to endorse or promote products derived 19 | # from this software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 24 | # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 25 | # COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 26 | # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 27 | # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 28 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 29 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 30 | # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 31 | # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | # POSSIBILITY OF SUCH DAMAGE. 33 | # 34 | # Requirements: 35 | # sudo apt-get install python-argparse 36 | 37 | """ 38 | The Kinect provides the color and depth images in an un-synchronized way. This means that the set of time stamps from the color images do not intersect with those of the depth images. Therefore, we need some way of associating color images to depth images. 39 | 40 | For this purpose, you can use the ''associate.py'' script. It reads the time stamps from the rgb.txt file and the depth.txt file, and joins them by finding the best matches. 41 | """ 42 | 43 | import argparse 44 | import sys 45 | import os 46 | import numpy 47 | 48 | 49 | def read_file_list(filename): 50 | """ 51 | Reads a trajectory from a text file. 52 | 53 | File format: 54 | The file format is "stamp d1 d2 d3 ...", where stamp denotes the time stamp (to be matched) 55 | and "d1 d2 d3.." 
is arbitrary data (e.g., a 3D position and 3D orientation) associated to this timestamp. 56 | 57 | Input: 58 | filename -- File name 59 | 60 | Output: 61 | dict -- dictionary of (stamp,data) tuples 62 | 63 | """ 64 | file = open(filename) 65 | data = file.read() 66 | lines = data.replace(","," ").replace("\t"," ").split("\n") 67 | list = [[v.strip() for v in line.split(" ") if v.strip()!=""] for line in lines if len(line)>0 and line[0]!="#"] 68 | list = [(float(l[0]),l[1:]) for l in list if len(l)>1] 69 | return dict(list) 70 | 71 | def associate(first_list, second_list,offset,max_difference): 72 | """ 73 | Associate two dictionaries of (stamp,data). As the time stamps never match exactly, we aim 74 | to find the closest match for every input tuple. 75 | 76 | Input: 77 | first_list -- first dictionary of (stamp,data) tuples 78 | second_list -- second dictionary of (stamp,data) tuples 79 | offset -- time offset between both dictionaries (e.g., to model the delay between the sensors) 80 | max_difference -- search radius for candidate generation 81 | 82 | Output: 83 | matches -- list of matched tuples ((stamp1,data1),(stamp2,data2)) 84 | 85 | """ 86 | first_keys = list(first_list.keys()) 87 | second_keys = list(second_list.keys()) 88 | potential_matches = [(abs(a - (b + offset)), a, b) 89 | for a in first_keys 90 | for b in second_keys 91 | if abs(a - (b + offset)) < max_difference] 92 | potential_matches.sort() 93 | matches = [] 94 | for diff, a, b in potential_matches: 95 | if a in first_keys and b in second_keys: 96 | first_keys.remove(a) 97 | second_keys.remove(b) 98 | matches.append((a, b)) 99 | 100 | matches.sort() 101 | return matches 102 | 103 | if __name__ == '__main__': 104 | 105 | # parse command line 106 | parser = argparse.ArgumentParser(description=''' 107 | This script takes two data files with timestamps and associates them 108 | ''') 109 | parser.add_argument('first_file', help='first text file (format: timestamp data)') 110 | parser.add_argument('second_file', help='second text file (format: timestamp data)') 111 | parser.add_argument('--first_only', help='only output associated lines from first file', action='store_true') 112 | parser.add_argument('--offset', help='time offset added to the timestamps of the second file (default: 0.0)',default=0.0) 113 | parser.add_argument('--max_difference', help='maximally allowed time difference for matching entries (default: 0.02)',default=0.02) 114 | args = parser.parse_args() 115 | 116 | first_list = read_file_list(args.first_file) 117 | second_list = read_file_list(args.second_file) 118 | 119 | matches = associate(first_list, second_list,float(args.offset),float(args.max_difference)) 120 | 121 | if args.first_only: 122 | for a,b in matches: 123 | print("%f %s"%(a," ".join(first_list[a]))) 124 | else: 125 | for a,b in matches: 126 | print("%f %s %f %s"%(a," ".join(first_list[a]),b-float(args.offset)," ".join(second_list[b]))) 127 | 128 | -------------------------------------------------------------------------------- /scripts/download_ours.sh: -------------------------------------------------------------------------------- 1 | wget ... 
-------------------------------------------------------------------------------- /scripts/download_replica.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data && cd data 2 | # you can also download the Replica.zip manually through 3 | # link: https://caiyun.139.com/m/i?1A5Ch5C3abNiL password: v3fY (the zip is split into smaller zips because of the size limitation of caiyun) 4 | wget https://cvg-data.inf.ethz.ch/nice-slam/data/Replica.zip 5 | unzip Replica.zip 6 | 7 | wget https://cvg-data.inf.ethz.ch/nice-slam/cull_replica_mesh.zip 8 | unzip cull_replica_mesh.zip -------------------------------------------------------------------------------- /scripts/download_tum.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data/TUM_RGBD 2 | cd data/TUM_RGBD 3 | wget https://vision.in.tum.de/rgbd/dataset/freiburg1/rgbd_dataset_freiburg1_desk.tgz 4 | tar -xvzf rgbd_dataset_freiburg1_desk.tgz 5 | wget https://vision.in.tum.de/rgbd/dataset/freiburg2/rgbd_dataset_freiburg2_xyz.tgz 6 | tar -xvzf rgbd_dataset_freiburg2_xyz.tgz 7 | wget https://vision.in.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_long_office_household.tgz 8 | tar -xvzf rgbd_dataset_freiburg3_long_office_household.tgz 9 | -------------------------------------------------------------------------------- /scripts/eval_ate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Software License Agreement (BSD License) 3 | # 4 | # Copyright (c) 2013, Juergen Sturm, TUM 5 | # All rights reserved. 6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions 9 | # are met: 10 | # 11 | # * Redistributions of source code must retain the above copyright 12 | # notice, this list of conditions and the following disclaimer. 13 | # * Redistributions in binary form must reproduce the above 14 | # copyright notice, this list of conditions and the following 15 | # disclaimer in the documentation and/or other materials provided 16 | # with the distribution. 17 | # * Neither the name of TUM nor the names of its 18 | # contributors may be used to endorse or promote products derived 19 | # from this software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 24 | # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 25 | # COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 26 | # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 27 | # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 28 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 29 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 30 | # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 31 | # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | # POSSIBILITY OF SUCH DAMAGE. 33 | # 34 | # Requirements: 35 | # sudo apt-get install python-argparse 36 | 37 | """ 38 | This script computes the absolute trajectory error from the ground truth 39 | trajectory and the estimated trajectory. 
40 | """ 41 | 42 | import sys 43 | import numpy 44 | import numpy as np 45 | import argparse 46 | import associate 47 | from scipy.spatial.transform import Rotation as R 48 | 49 | 50 | def align(model, data): 51 | """Align two trajectories using the method of Horn (closed-form). 52 | 53 | Input: 54 | model -- first trajectory (3xn) 55 | data -- second trajectory (3xn) 56 | 57 | Output: 58 | rot -- rotation matrix (3x3) 59 | trans -- translation vector (3x1) 60 | trans_error -- translational error per point (1xn) 61 | 62 | """ 63 | numpy.set_printoptions(precision=3, suppress=True) 64 | model_zerocentered = model - model.mean(1) 65 | data_zerocentered = data - data.mean(1) 66 | 67 | W = numpy.zeros((3, 3)) 68 | for column in range(model.shape[1]): 69 | W += numpy.outer(model_zerocentered[:, column], data_zerocentered[:, column]) 70 | U, d, Vh = numpy.linalg.linalg.svd(W.transpose()) 71 | S = numpy.matrix(numpy.identity(3)) 72 | if numpy.linalg.det(U) * numpy.linalg.det(Vh) < 0: 73 | S[2, 2] = -1 74 | rot = U * S * Vh 75 | trans = data.mean(1) - rot * model.mean(1) 76 | 77 | model_aligned = rot * model + trans 78 | alignment_error = model_aligned - data 79 | 80 | trans_error = numpy.sqrt( 81 | numpy.sum(numpy.multiply(alignment_error, alignment_error), 0) 82 | ).A[0] 83 | 84 | return rot, trans, trans_error 85 | 86 | 87 | def plot_traj(ax, stamps, traj, style, color, label): 88 | """ 89 | Plot a trajectory using matplotlib. 90 | 91 | Input: 92 | ax -- the plot 93 | stamps -- time stamps (1xn) 94 | traj -- trajectory (3xn) 95 | style -- line style 96 | color -- line color 97 | label -- plot legend 98 | 99 | """ 100 | stamps.sort() 101 | interval = numpy.median([s - t for s, t in zip(stamps[1:], stamps[:-1])]) 102 | x = [] 103 | y = [] 104 | last = stamps[0] 105 | for i in range(len(stamps)): 106 | if stamps[i] - last < 2 * interval: 107 | x.append(traj[i][0]) 108 | y.append(traj[i][1]) 109 | elif len(x) > 0: 110 | ax.plot(x, y, style, color=color, label=label) 111 | label = "" 112 | x = [] 113 | y = [] 114 | last = stamps[i] 115 | if len(x) > 0: 116 | ax.plot(x, y, style, color=color, label=label) 117 | 118 | 119 | if __name__ == "__main__": 120 | # parse command line 121 | parser = argparse.ArgumentParser( 122 | description=""" 123 | This script computes the absolute trajectory error from the ground truth trajectory and the estimated trajectory. 
124 | """ 125 | ) 126 | parser.add_argument( 127 | "first_file", 128 | help="ground truth trajectory (format: timestamp tx ty tz qx qy qz qw)", 129 | ) 130 | parser.add_argument( 131 | "second_file", 132 | help="estimated trajectory (format: timestamp tx ty tz qx qy qz qw)", 133 | ) 134 | parser.add_argument("--num", type=int, default=-1) 135 | parser.add_argument( 136 | "--offset", 137 | help="time offset added to the timestamps of the second file (default: 0.0)", 138 | default=0.0, 139 | ) 140 | parser.add_argument( 141 | "--scale", 142 | help="scaling factor for the second trajectory (default: 1.0)", 143 | default=1.0, 144 | ) 145 | parser.add_argument( 146 | "--max_difference", 147 | help="maximally allowed time difference for matching entries (default: 0.02)", 148 | default=0.02, 149 | ) 150 | parser.add_argument( 151 | "--save", 152 | help="save aligned second trajectory to disk (format: stamp2 x2 y2 z2)", 153 | ) 154 | parser.add_argument( 155 | "--save_associations", 156 | help="save associated first and aligned second trajectory to disk (format: stamp1 x1 y1 z1 stamp2 x2 y2 z2)", 157 | ) 158 | parser.add_argument( 159 | "--plot", 160 | help="plot the first and the aligned second trajectory to an image (format: png)", 161 | ) 162 | parser.add_argument( 163 | "--verbose", 164 | help="print all evaluation data (otherwise, only the RMSE absolute translational error in meters after alignment will be printed)", 165 | action="store_true", 166 | ) 167 | args = parser.parse_args() 168 | 169 | eval_num = int(args.num) 170 | first_list = associate.read_file_list(args.first_file) 171 | second_list = associate.read_file_list(args.second_file) 172 | matches = associate.associate( 173 | first_list, second_list, float(args.offset), float(args.max_difference) 174 | )[:eval_num] 175 | if len(matches) < 2: 176 | sys.exit( 177 | "Couldn't find matching timestamp pairs between groundtruth and estimated trajectory! Did you choose the correct sequence?" 
178 | ) 179 | 180 | first_xyz = numpy.matrix( 181 | [[float(value) for value in first_list[a][0:3]] for a, b in matches] 182 | ).transpose() 183 | second_xyz = numpy.matrix( 184 | [ 185 | [float(value) * float(args.scale) for value in second_list[b][0:3]] 186 | for a, b in matches 187 | ] 188 | ).transpose() 189 | # second_rot = [[R.from_rotvec([float(v) for v in value]) for value in second_list[b][3:]] for a,b in matches] 190 | # pose_gt = [] 191 | # print(second_rot) 192 | # for i in range(second_xyz.shape[0]): 193 | # parser 194 | 195 | rot, trans, trans_error = align(second_xyz, first_xyz) 196 | 197 | second_xyz_aligned = rot * second_xyz + trans 198 | 199 | for i in range(first_xyz.shape[1]): 200 | print( 201 | "kframe-{}, align_distance-error:{}".format( 202 | i, 100 * np.linalg.norm(first_xyz[:, i] - second_xyz_aligned[:, i]) 203 | ) 204 | ) 205 | 206 | first_stamps = list(first_list.keys()) 207 | first_stamps.sort() 208 | first_xyz_full = numpy.matrix( 209 | [[float(value) for value in first_list[b][0:3]] for b in first_stamps] 210 | ).transpose() 211 | 212 | second_stamps = list(second_list.keys()) 213 | second_stamps.sort() 214 | second_xyz_full = numpy.matrix( 215 | [ 216 | [float(value) * float(args.scale) for value in second_list[b][0:3]] 217 | for b in second_stamps 218 | ] 219 | ).transpose() 220 | second_xyz_full_aligned = rot * second_xyz_full + trans 221 | 222 | print( 223 | "%f" 224 | % (100.0 * numpy.sqrt(numpy.dot(trans_error, trans_error) / len(trans_error))) 225 | ) 226 | 227 | if args.save_associations: 228 | file = open(args.save_associations, "w") 229 | file.write( 230 | "\n".join( 231 | [ 232 | "%f %f %f %f %f %f %f %f" % (a, x1, y1, z1, b, x2, y2, z2) 233 | for (a, b), (x1, y1, z1), (x2, y2, z2) in zip( 234 | matches, 235 | first_xyz.transpose().A, 236 | second_xyz_aligned.transpose().A, 237 | ) 238 | ] 239 | ) 240 | ) 241 | file.close() 242 | -------------------------------------------------------------------------------- /scripts/parse_scannetpp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from argparse import ArgumentParser 5 | from tqdm import tqdm 6 | 7 | if __name__ == "__main__": 8 | parser = ArgumentParser() 9 | parser.add_argument("--data_base", type=str) 10 | parser.add_argument("--output_path", type=str) 11 | args = parser.parse_args() 12 | data_base = args.data_base 13 | output_path = args.output_path 14 | 15 | data_path = os.path.join(data_base, "dslr") 16 | mesh_path = os.path.join(data_base, "scans") 17 | 18 | scene_name = os.path.basename(os.path.dirname(data_path)) 19 | save_path = os.path.join(output_path, scene_name) 20 | img_save_path = os.path.join(save_path, "color") 21 | img_eval_save_path = os.path.join(save_path, "color_eval") 22 | intrinsic_save_path = os.path.join(save_path, "intrinsic") 23 | depth_save_path = os.path.join(save_path, "depth") 24 | depth_eval_save_path = os.path.join(save_path, "depth_eval") 25 | intrinsic_save_path = os.path.join(save_path, "intrinsic") 26 | pose_save_path = os.path.join(save_path, "pose") 27 | pose_eval_save_path = os.path.join(save_path, "pose_eval") 28 | 29 | # os.system("rm -r {}".format(save_path)) 30 | os.makedirs(save_path, exist_ok=True) 31 | os.makedirs(img_save_path, exist_ok=True) 32 | os.makedirs(depth_save_path, exist_ok=True) 33 | os.makedirs(intrinsic_save_path, exist_ok=True) 34 | os.makedirs(pose_save_path, exist_ok=True) 35 | os.makedirs(img_eval_save_path, exist_ok=True) 36 | 
os.makedirs(depth_eval_save_path, exist_ok=True) 37 | os.makedirs(pose_eval_save_path, exist_ok=True) 38 | 39 | 40 | ply_files = [i for i in os.listdir(mesh_path) if ".ply" in i] 41 | for ply_file in ply_files: 42 | os.system("cp {} {}".format(os.path.join(mesh_path, ply_file), 43 | os.path.join(save_path, ply_file))) 44 | 45 | 46 | img_read_path = os.path.join(data_path, "undistorted_images") 47 | depth_read_path = os.path.join(data_path, "undistorted_depths") 48 | pose_read_path = os.path.join( 49 | data_path, "nerfstudio", "transforms_undistorted.json" 50 | ) 51 | 52 | with open(pose_read_path, "r") as pose_file: 53 | pose_intrinsic = json.load(pose_file) 54 | 55 | pose_intrinsic["frames"] = sorted( 56 | pose_intrinsic["frames"], key=lambda x: x["file_path"] 57 | ) 58 | 59 | pose_intrinsic["test_frames"] = sorted( 60 | pose_intrinsic["test_frames"], key=lambda x: x["file_path"] 61 | ) 62 | 63 | fx = pose_intrinsic["fl_x"] 64 | fy = pose_intrinsic["fl_y"] 65 | cx = pose_intrinsic["cx"] 66 | cy = pose_intrinsic["cy"] 67 | 68 | intrinsic = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) 69 | 70 | np.savetxt( 71 | os.path.join(intrinsic_save_path, "intrinsic_color.txt"), intrinsic, fmt="%f" 72 | ) 73 | np.savetxt( 74 | os.path.join(intrinsic_save_path, "intrinsic_depth.txt"), intrinsic, fmt="%f" 75 | ) 76 | 77 | frame_id = 0 78 | 79 | for frame in tqdm(pose_intrinsic["frames"]): 80 | is_bad = frame["is_bad"] 81 | if is_bad: 82 | continue 83 | color_file = frame["file_path"] 84 | depth_file = color_file.replace(".JPG", ".png") 85 | pose_file = color_file.replace(".JPG", ".txt") 86 | pose_c2w = np.array(frame["transform_matrix"]).reshape(4, 4) 87 | pose_c2w[:, 1:3] *= -1 88 | 89 | # copy image 90 | os.system( 91 | "cp {} {}".format( 92 | os.path.join(img_read_path, color_file), 93 | os.path.join(img_save_path, "%04d.jpg" % frame_id), 94 | ) 95 | ) 96 | 97 | os.system( 98 | "cp {} {}".format( 99 | os.path.join(depth_read_path, depth_file), 100 | os.path.join(depth_save_path, "%04d.png" % frame_id), 101 | ) 102 | ) 103 | 104 | np.savetxt( 105 | os.path.join(pose_save_path, "%04d.txt" % frame_id), 106 | pose_c2w.tolist(), 107 | fmt="%f", 108 | ) 109 | frame_id += 1 110 | 111 | frame_id = 0 112 | 113 | for frame in tqdm(pose_intrinsic["test_frames"]): 114 | is_bad = frame["is_bad"] 115 | if is_bad: 116 | continue 117 | color_file = frame["file_path"] 118 | depth_file = color_file.replace(".JPG", ".png") 119 | pose_file = color_file.replace(".JPG", ".txt") 120 | pose_c2w = np.array(frame["transform_matrix"]).reshape(4, 4) 121 | pose_c2w[:, 1:3] *= -1 122 | 123 | # copy image 124 | os.system( 125 | "cp {} {}".format( 126 | os.path.join(img_read_path, color_file), 127 | os.path.join(img_eval_save_path, "%04d.jpg" % frame_id), 128 | ) 129 | ) 130 | 131 | os.system( 132 | "cp {} {}".format( 133 | os.path.join(depth_read_path, depth_file), 134 | os.path.join(depth_eval_save_path, "%04d.png" % frame_id), 135 | ) 136 | ) 137 | 138 | np.savetxt( 139 | os.path.join(pose_eval_save_path, "%04d.txt" % frame_id), 140 | pose_c2w.tolist(), 141 | fmt="%f", 142 | ) 143 | frame_id += 1 144 | -------------------------------------------------------------------------------- /scripts/parse_scannetpp.sh: -------------------------------------------------------------------------------- 1 | # python parse.py --data_path \ 2 | # /home/pzx/download/ly/GSroom_new/GSroom/data/Scannetpp/download/data/39f36da05b \ 3 | # --output_path /home/pzx/download/ly/GSroom_new/GSroom/data/Scannetpp_ours_test 4 | 5 | # python parse.py 
--data_path \ 6 | # /home/pzx/download/ly/GSroom_new/GSroom/data/Scannetpp/download/data/0cf2e9402d \ 7 | # --output_path /home/pzx/download/ly/GSroom_new/GSroom/data/Scannetpp_ours 8 | 9 | # python parse.py --data_path \ 10 | # /home/pzx/download/ly/GSroom_new/GSroom/data/Scannetpp/download/data/2a496183e1 \ 11 | # --output_path /home/pzx/download/ly/GSroom_new/GSroom/data/Scannetpp_ours 12 | 13 | # python parse.py --data_path \ 14 | # /home/pzx/download/ly/GSroom_new/GSroom/data/Scannetpp/download/data/8b5caf3398 \ 15 | # --output_path /home/pzx/download/ly/GSroom_new/GSroom/data/Scannetpp_ours 16 | 17 | # python parse.py --data_path \ 18 | # /home/pzx/download/ly/GSroom_new/GSroom/data/Scannetpp/download/data/b20a261fdf \ 19 | # --output_path /home/pzx/download/ly/GSroom_new/GSroom/data/Scannetpp_ours 20 | 21 | # python parse.py --data_path \ 22 | # /home/pzx/download/ly/GSroom_new/GSroom/data/Scannetpp/download/data/f34d532901 \ 23 | # --output_path /home/pzx/download/ly/GSroom_new/GSroom/data/Scannetpp_ours 24 | 25 | # python parse.py --data_path \ 26 | # /home/pzx/download/ly/GSroom_new/GSroom/data/Scannetpp/download/data/9b74afd2d2 \ 27 | # --output_path /home/pzx/download/ly/GSroom_new/GSroom/data/Scannetpp_ours 28 | 29 | 30 | 31 | python parse.py --data_path /mnt/data/pzx/data/SLAM/Scannetpp/download/data/9b74afd2d2 \ 32 | --output_path /home/ly/projects/SLAM/GSroom_history/GSroom_open_1/data/test_parse 33 | -------------------------------------------------------------------------------- /slam.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | from utils.config_utils import read_config 5 | parser = ArgumentParser(description="Training script parameters") 6 | parser.add_argument("--config", type=str, default="configs/replica/office0.yaml") 7 | args = parser.parse_args() 8 | config_path = args.config 9 | args = read_config(config_path) 10 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(device) for device in args.device_list) 11 | import torch 12 | import json 13 | from utils.camera_utils import loadCam 14 | from arguments import DatasetParams, MapParams, OptimizationParams 15 | from scene import Dataset 16 | from SLAM.multiprocess.mapper import Mapping 17 | from SLAM.multiprocess.tracker import Tracker 18 | from SLAM.utils import * 19 | from SLAM.eval import eval_frame 20 | from utils.general_utils import safe_state 21 | from utils.monitor import Recorder 22 | 23 | torch.set_printoptions(4, sci_mode=False) 24 | 25 | 26 | def main(): 27 | # set visible devices 28 | time_recorder = Recorder(args.device_list[0]) 29 | optimization_params = OptimizationParams(parser) 30 | dataset_params = DatasetParams(parser, sentinel=True) 31 | map_params = MapParams(parser) 32 | 33 | safe_state(args.quiet) 34 | optimization_params = optimization_params.extract(args) 35 | dataset_params = dataset_params.extract(args) 36 | map_params = map_params.extract(args) 37 | 38 | # Initialize dataset 39 | dataset = Dataset( 40 | dataset_params, 41 | shuffle=False, 42 | resolution_scales=dataset_params.resolution_scales, 43 | ) 44 | 45 | record_mem = args.record_mem 46 | 47 | gaussian_map = Mapping(args, time_recorder) 48 | gaussian_map.create_workspace() 49 | gaussian_tracker = Tracker(args) 50 | # save config file 51 | prepare_cfg(args) 52 | # set time log 53 | tracker_time_sum = 0 54 | mapper_time_sum = 0 55 | 56 | # start SLAM 57 | for frame_id, frame_info in enumerate(dataset.scene_info.train_cameras): 58 | 
curr_frame = loadCam( 59 | dataset_params, frame_id, frame_info, dataset_params.resolution_scales[0] 60 | ) 61 | 62 | print("\n========== curr frame is: %d ==========\n" % frame_id) 63 | move_to_gpu(curr_frame) 64 | start_time = time.time() 65 | # tracker process 66 | frame_map = gaussian_tracker.map_preprocess(curr_frame, frame_id) 67 | gaussian_tracker.tracking(curr_frame, frame_map) 68 | tracker_time = time.time() 69 | tracker_consume_time = tracker_time - start_time 70 | time_recorder.update_mean("tracking", tracker_consume_time, 1) 71 | 72 | tracker_time_sum += tracker_consume_time 73 | print(f"[LOG] tracker cost time: {tracker_time - start_time}") 74 | 75 | mapper_start_time = time.time() 76 | 77 | new_poses = gaussian_tracker.get_new_poses() 78 | gaussian_map.update_poses(new_poses) 79 | # mapper process 80 | gaussian_map.mapping(curr_frame, frame_map, frame_id, optimization_params) 81 | 82 | gaussian_map.get_render_output(curr_frame) 83 | gaussian_tracker.update_last_status( 84 | curr_frame, 85 | gaussian_map.model_map["render_depth"], 86 | gaussian_map.frame_map["depth_map"], 87 | gaussian_map.model_map["render_normal"], 88 | gaussian_map.frame_map["normal_map_w"], 89 | ) 90 | mapper_time = time.time() 91 | mapper_consume_time = mapper_time - mapper_start_time 92 | time_recorder.update_mean("mapping", mapper_consume_time, 1) 93 | 94 | mapper_time_sum += mapper_consume_time 95 | print(f"[LOG] mapper cost time: {mapper_time - tracker_time}") 96 | if record_mem: 97 | time_recorder.watch_gpu() 98 | # report eval loss 99 | if ((gaussian_map.time + 1) % gaussian_map.save_step == 0) or ( 100 | gaussian_map.time == 0 101 | ): 102 | eval_frame( 103 | gaussian_map, 104 | curr_frame, 105 | os.path.join(gaussian_map.save_path, "eval_render"), 106 | min_depth=gaussian_map.min_depth, 107 | max_depth=gaussian_map.max_depth, 108 | save_picture=True, 109 | run_pcd=False 110 | ) 111 | gaussian_map.save_model(save_data=True) 112 | 113 | gaussian_map.time += 1 114 | move_to_cpu(curr_frame) 115 | torch.cuda.empty_cache() 116 | print("\n========== main loop finish ==========\n") 117 | print( 118 | "[LOG] stable num: {:d}, unstable num: {:d}".format( 119 | gaussian_map.get_stable_num, gaussian_map.get_unstable_num 120 | ) 121 | ) 122 | print("[LOG] processed frame: ", gaussian_map.optimize_frames_ids) 123 | print("[LOG] keyframes: ", gaussian_map.keyframe_ids) 124 | print("[LOG] mean tracker process time: ", tracker_time_sum / (frame_id + 1)) 125 | print("[LOG] mean mapper process time: ", mapper_time_sum / (frame_id + 1)) 126 | 127 | new_poses = gaussian_tracker.get_new_poses() 128 | gaussian_map.update_poses(new_poses) 129 | gaussian_map.global_optimization(optimization_params, is_end=True) 130 | eval_frame( 131 | gaussian_map, 132 | gaussian_map.keyframe_list[-1], 133 | os.path.join(gaussian_map.save_path, "eval_render"), 134 | min_depth=gaussian_map.min_depth, 135 | max_depth=gaussian_map.max_depth, 136 | save_picture=True, 137 | run_pcd=False 138 | ) 139 | 140 | gaussian_map.save_model(save_data=True) 141 | gaussian_tracker.save_traj(args.save_path) 142 | time_recorder.cal_fps() 143 | time_recorder.save(args.save_path) 144 | gaussian_map.time += 1 145 | 146 | if args.pcd_densify: 147 | densify_pcd = gaussian_map.stable_pointcloud.densify(1, 30, 5) 148 | o3d.io.write_point_cloud( 149 | os.path.join(args.save_path, "save_model", "pcd_densify.ply"), densify_pcd 150 | ) 151 | 152 | 153 | if __name__ == "__main__": 154 | main() 155 | 
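A note on how `slam.py` gets its arguments: `read_config` (defined in `utils/config_utils.py`, shown further below) loads the YAML named by `--config` and repeatedly merges it into the file referenced by its `parent` key, with the child's values overriding the parent's. The sketch below re-implements that resolution in a few lines purely for illustration; the helper name `resolve_config` and the exact chain `configs/replica/office0.yaml` -> `configs/replica_base.yaml` -> `configs/base.yaml` are assumptions, not code from the repository.

```python
import os
import yaml

def resolve_config(path):
    # Hypothetical, simplified re-implementation of utils/config_utils.read_config:
    # walk the "parent" chain, letting child keys override the parent's defaults.
    with open(path, "r") as f:
        cfg = yaml.safe_load(f)
    while cfg.get("parent", "None") != "None" and os.path.exists(cfg["parent"]):
        with open(cfg["parent"], "r") as f:
            parent = yaml.safe_load(f)
        next_parent = parent.get("parent", "None")
        parent.update(cfg)        # child values win
        cfg = parent
        cfg["parent"] = next_parent
    return cfg

# cfg = resolve_config("configs/replica/office0.yaml")   # assumed parent chain, see above
# print(cfg["device_list"], cfg["save_path"])
```

Fields that `slam.py` reads from the resolved config (`device_list`, `quiet`, `record_mem`, `save_path`, `pcd_densify`) are therefore expected to be defined somewhere along this chain.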
-------------------------------------------------------------------------------- /slam_mp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | from utils.config_utils import read_config 5 | 6 | parser = ArgumentParser(description="Training script parameters") 7 | parser.add_argument("--config", type=str) 8 | args = parser.parse_args() 9 | config_path = args.config 10 | args = read_config(config_path) 11 | # set visible devices 12 | device_list = args.device_list 13 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(device) for device in device_list) 14 | 15 | import torch 16 | import torch.multiprocessing as mp 17 | 18 | from arguments import DatasetParams, MapParams, OptimizationParams 19 | from scene import Dataset 20 | from SLAM.multiprocess.system import * 21 | from SLAM.multiprocess.mapper import * 22 | from SLAM.utils import * 23 | from utils.general_utils import safe_state 24 | 25 | torch.set_printoptions(4, sci_mode=False) 26 | np.set_printoptions(4) 27 | mp.set_sharing_strategy("file_system") 28 | 29 | 30 | def main(): 31 | optimization_params = OptimizationParams(parser) 32 | dataset_params = DatasetParams(parser, sentinel=True) 33 | map_params = MapParams(parser) 34 | 35 | safe_state(args.quiet) 36 | optimization_params = optimization_params.extract(args) 37 | dataset_params = dataset_params.extract(args) 38 | map_params = map_params.extract(args) 39 | 40 | # Initialize dataset 41 | dataset = Dataset( 42 | dataset_params, 43 | shuffle=False, 44 | resolution_scales=dataset_params.resolution_scales, 45 | ) 46 | 47 | # need to use spawn 48 | try: 49 | mp.set_start_method("spawn", force=True) 50 | except RuntimeError: 51 | pass 52 | 53 | slam = SLAM(map_params, optimization_params, dataset, args) 54 | slam.run() 55 | 56 | 57 | if __name__ == "__main__": 58 | main() -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | from PIL import Image 14 | 15 | from scene.cameras import Camera 16 | from utils.general_utils import PILtoTorch 17 | from utils.graphics_utils import fov2focal 18 | 19 | WARNED = False 20 | 21 | 22 | def loadCam(args, id, cam_info, resolution_scale): 23 | orig_w, orig_h = cam_info.image.size 24 | preload = args.preload 25 | if args.resolution in [1, 2, 4, 8]: 26 | resolution = round(orig_w / (resolution_scale * args.resolution)), round( 27 | orig_h / (resolution_scale * args.resolution) 28 | ) 29 | else: # should be a type that converts to float 30 | if args.resolution == -1: 31 | if orig_w > 1600: 32 | global WARNED 33 | if not WARNED: 34 | print( 35 | "[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 36 | "If this is not desired, please explicitly specify '--resolution/-r' as 1" 37 | ) 38 | WARNED = True 39 | global_down = orig_w / 1600 40 | else: 41 | global_down = 1 42 | else: 43 | global_down = orig_w / args.resolution 44 | 45 | scale = float(global_down) * float(resolution_scale) 46 | resolution = (int(orig_w / scale), int(orig_h / scale)) 47 | 48 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 49 | resized_image_depth = PILtoTorch(cam_info.depth, resolution, Image.NEAREST) 50 | gt_image = resized_image_rgb[:3, ...] 51 | gt_depth = resized_image_depth 52 | loaded_mask = None 53 | if resized_image_rgb.shape[1] == 4: 54 | loaded_mask = resized_image_rgb[3:4, ...] 55 | 56 | return Camera( 57 | colmap_id=cam_info.uid, 58 | R=cam_info.R, 59 | T=cam_info.T, 60 | FoVx=cam_info.FovX, 61 | FoVy=cam_info.FovY, 62 | image=gt_image, 63 | depth=gt_depth, 64 | gt_alpha_mask=loaded_mask, 65 | image_name=cam_info.image_name, 66 | uid=id, 67 | data_device=args.data_device, 68 | pose_gt=cam_info.pose_gt, 69 | cx=cam_info.cx / resolution_scale, 70 | cy=cam_info.cy / resolution_scale, 71 | timestamp=cam_info.timestamp, 72 | preload=preload, 73 | depth_scale=cam_info.depth_scale, 74 | ) 75 | 76 | 77 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 78 | camera_list = [] 79 | 80 | for id, c in enumerate(cam_infos): 81 | camera_list.append(loadCam(args, id, c, resolution_scale)) 82 | 83 | return camera_list 84 | 85 | 86 | def camera_to_JSON(id, camera: Camera): 87 | Rt = np.zeros((4, 4)) 88 | Rt[:3, :3] = camera.R.transpose() 89 | Rt[:3, 3] = camera.T 90 | Rt[3, 3] = 1.0 91 | 92 | W2C = np.linalg.inv(Rt) 93 | pos = W2C[:3, 3] 94 | rot = W2C[:3, :3] 95 | serializable_array_2d = [x.tolist() for x in rot] 96 | camera_entry = { 97 | "id": id, 98 | "img_name": camera.image_name, 99 | "width": camera.width, 100 | "height": camera.height, 101 | "position": pos.tolist(), 102 | "rotation": serializable_array_2d, 103 | "fy": fov2focal(camera.FovY, camera.height), 104 | "fx": fov2focal(camera.FovX, camera.width), 105 | } 106 | return camera_entry 107 | -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | 4 | class GroupParams: 5 | pass 6 | 7 | 8 | def merge_yaml(a, b): 9 | if isinstance(a, dict) and isinstance(b, dict): 10 | for key in b: 11 | if key in a: 12 | a[key] = merge_yaml(a[key], b[key]) 13 | else: 14 | a[key] = b[key] 15 | return a 16 | else: 17 | return b 18 | 19 | 20 | def read_config(config_path): 21 | with open(config_path, "r") as f: 22 | base_config = 
yaml.safe_load(f) 23 | while base_config["parent"] != "None" and os.path.exists(base_config["parent"]): 24 | with open(base_config["parent"], "r") as f: 25 | parent_config = yaml.safe_load(f) 26 | parent_config_path = parent_config["parent"] 27 | parent_config.update(base_config) 28 | base_config = parent_config 29 | base_config["parent"] = parent_config_path 30 | group = GroupParams() 31 | for k, v in base_config.items(): 32 | setattr(group, k.lstrip("_"), v) 33 | return group 34 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import random 13 | import sys 14 | from datetime import datetime 15 | import open3d as o3d 16 | 17 | import numpy as np 18 | import torch 19 | from PIL import Image 20 | 21 | 22 | float_dev = torch.tensor([0], device="cuda", dtype=torch.float32) 23 | int_dev = torch.tensor([0], device="cuda", dtype=torch.int32) 24 | bool_dev = torch.tensor([0], device="cuda", dtype=torch.bool) 25 | 26 | 27 | def devF(tensor: torch.Tensor): 28 | return tensor.type_as(float_dev) 29 | 30 | 31 | def devI(tensor: torch.Tensor): 32 | return tensor.type_as(int_dev) 33 | 34 | 35 | def devB(tensor: torch.Tensor): 36 | return tensor.type_as(bool_dev) 37 | 38 | 39 | def inverse_sigmoid(x): 40 | return torch.log(x / (1 - x)) 41 | 42 | 43 | def PILtoTorch(pil_image, resolution, method=Image.BILINEAR): 44 | resized_image_PIL = pil_image.resize(resolution, method) 45 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 46 | if len(resized_image.shape) == 3: 47 | return resized_image.permute(2, 0, 1) 48 | else: 49 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 50 | 51 | 52 | def NPtoTorch(np_image, resolution): 53 | resized_image = torch.from_numpy(np_image) 54 | 55 | 56 | def get_expon_lr_func( 57 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 58 | ): 59 | """ 60 | Copied from Plenoxels 61 | 62 | Continuous learning rate decay function. Adapted from JaxNeRF 63 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 64 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 65 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 66 | function of lr_delay_mult, such that the initial learning rate is 67 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 68 | to the normal learning rate when steps>lr_delay_steps. 69 | :param conf: config subtree 'lr' or similar 70 | :param max_steps: int, the number of steps during optimization. 71 | :return HoF which takes step as input 72 | """ 73 | 74 | def helper(step): 75 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 76 | # Disable this parameter 77 | return 0.0 78 | if lr_delay_steps > 0: 79 | # A kind of reverse cosine decay. 
80 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 81 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 82 | ) 83 | else: 84 | delay_rate = 1.0 85 | t = np.clip(step / max_steps, 0, 1) 86 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 87 | return delay_rate * log_lerp 88 | 89 | return helper 90 | 91 | 92 | def strip_lowerdiag(L): 93 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 94 | 95 | uncertainty[:, 0] = L[:, 0, 0] 96 | uncertainty[:, 1] = L[:, 0, 1] 97 | uncertainty[:, 2] = L[:, 0, 2] 98 | uncertainty[:, 3] = L[:, 1, 1] 99 | uncertainty[:, 4] = L[:, 1, 2] 100 | uncertainty[:, 5] = L[:, 2, 2] 101 | return uncertainty 102 | 103 | 104 | def strip_symmetric(sym): 105 | return strip_lowerdiag(sym) 106 | 107 | 108 | def build_rotation(r): 109 | norm = torch.sqrt( 110 | r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3] 111 | ) 112 | 113 | q = r / norm[:, None] 114 | 115 | R = torch.zeros((q.size(0), 3, 3), device="cuda") 116 | 117 | r = q[:, 0] 118 | x = q[:, 1] 119 | y = q[:, 2] 120 | z = q[:, 3] 121 | 122 | R[:, 0, 0] = 1 - 2 * (y * y + z * z) 123 | R[:, 0, 1] = 2 * (x * y - r * z) 124 | R[:, 0, 2] = 2 * (x * z + r * y) 125 | R[:, 1, 0] = 2 * (x * y + r * z) 126 | R[:, 1, 1] = 1 - 2 * (x * x + z * z) 127 | R[:, 1, 2] = 2 * (y * z - r * x) 128 | R[:, 2, 0] = 2 * (x * z - r * y) 129 | R[:, 2, 1] = 2 * (y * z + r * x) 130 | R[:, 2, 2] = 1 - 2 * (x * x + y * y) 131 | return R 132 | 133 | 134 | def build_scaling_rotation(s, r): 135 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 136 | R = build_rotation(r) 137 | 138 | L[:, 0, 0] = s[:, 0] 139 | L[:, 1, 1] = s[:, 1] 140 | L[:, 2, 2] = s[:, 2] 141 | 142 | L = R @ L 143 | return L 144 | 145 | 146 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): 147 | L = build_scaling_rotation(scaling_modifier * scaling, rotation) 148 | actual_covariance = L @ L.transpose(1, 2) 149 | symm = strip_symmetric(actual_covariance) 150 | return symm 151 | 152 | 153 | def safe_state(silent): 154 | old_f = sys.stdout 155 | 156 | class F: 157 | def __init__(self, silent): 158 | self.silent = silent 159 | 160 | def write(self, x): 161 | if not self.silent: 162 | if x.endswith("\n"): 163 | old_f.write( 164 | x.replace( 165 | "\n", 166 | " [{}]\n".format( 167 | str(datetime.now().strftime("%d/%m %H:%M:%S")) 168 | ), 169 | ) 170 | ) 171 | else: 172 | old_f.write(x) 173 | 174 | def flush(self): 175 | old_f.flush() 176 | 177 | sys.stdout = F(silent) 178 | 179 | random.seed(2024) 180 | np.random.seed(2024) 181 | torch.manual_seed(2024) 182 | torch.cuda.set_device(torch.device("cuda:0")) 183 | 184 | 185 | def quaternion_from_axis_angle(axis, angle): 186 | axis = axis / (torch.norm(axis, p=2, dim=-1, keepdim=True) + 1e-8) 187 | half_angle = angle / 2 188 | real_part = torch.cos(half_angle).type_as(axis) 189 | complex_part = axis * torch.sin(half_angle).type_as(axis) 190 | quaternion = torch.cat([real_part, complex_part], dim=1) 191 | return quaternion 192 | 193 | 194 | def is_valid_tensor(x: torch.Tensor): 195 | value_state = not (torch.isnan(x).any().item() or torch.isinf(x).any().item()) 196 | grad_state = True 197 | if x.grad is not None: 198 | grad_state = not ( 199 | torch.isnan(x.grad).any().item() or torch.isinf(x.grad).any().item() 200 | ) 201 | return value_state and grad_state 202 | 203 | 204 | def save_tensor_to_ply(save_path, xyz, voxel_size=-1): 205 | pcd = o3d.geometry.PointCloud() 206 | pcd.points = 
o3d.utility.Vector3dVector(xyz.cpu().numpy()) 207 | if voxel_size > 0: 208 | pcd = pcd.voxel_down_sample(voxel_size=voxel_size) 209 | o3d.io.write_point_cloud(save_path, pcd) 210 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import math 13 | from typing import NamedTuple 14 | 15 | import numpy as np 16 | import torch 17 | 18 | 19 | class BasicPointCloud(NamedTuple): 20 | points: np.array 21 | colors: np.array 22 | normals: np.array 23 | 24 | 25 | def geom_transform_points(points, transf_matrix): 26 | P, _ = points.shape 27 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 28 | points_hom = torch.cat([points, ones], dim=1) 29 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 30 | 31 | denom = points_out[..., 3:] + 0.0000001 32 | return (points_out[..., :3] / denom).squeeze(dim=0) 33 | 34 | 35 | def getWorld2View(R, t): 36 | Rt = np.zeros((4, 4)) 37 | Rt[:3, :3] = R.transpose() 38 | Rt[:3, 3] = t 39 | Rt[3, 3] = 1.0 40 | return np.float32(Rt) 41 | 42 | 43 | def getK(fx, fy, cx, cy): 44 | K = np.eye(3) 45 | K[0, 0] = fx 46 | K[1, 1] = fy 47 | K[0, 2] = cx 48 | K[1, 2] = cy 49 | return K 50 | 51 | 52 | def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0): 53 | Rt = np.zeros((4, 4)) 54 | Rt[:3, :3] = R.transpose() 55 | Rt[:3, 3] = t 56 | Rt[3, 3] = 1.0 57 | 58 | C2W = np.linalg.inv(Rt) 59 | cam_center = C2W[:3, 3] 60 | cam_center = (cam_center + translate) * scale 61 | C2W[:3, 3] = cam_center 62 | Rt = np.linalg.inv(C2W) 63 | return np.float32(Rt) 64 | 65 | 66 | def getProjectionMatrix(znear, zfar, fovX, fovY): 67 | tanHalfFovY = math.tan((fovY / 2)) 68 | tanHalfFovX = math.tan((fovX / 2)) 69 | 70 | top = tanHalfFovY * znear 71 | bottom = -top 72 | right = tanHalfFovX * znear 73 | left = -right 74 | 75 | P = torch.zeros(4, 4) 76 | 77 | z_sign = 1.0 78 | 79 | P[0, 0] = 2.0 * znear / (right - left) 80 | P[1, 1] = 2.0 * znear / (top - bottom) 81 | P[0, 2] = (right + left) / (right - left) 82 | P[1, 2] = (top + bottom) / (top - bottom) 83 | P[3, 2] = z_sign 84 | P[2, 2] = z_sign * zfar / (zfar - znear) 85 | P[2, 3] = -(zfar * znear) / (zfar - znear) 86 | return P 87 | 88 | 89 | def fov2focal(fov, pixels): 90 | return pixels / (2 * math.tan(fov / 2)) 91 | 92 | 93 | def focal2fov(focal, pixels): 94 | return 2 * math.atan(pixels / (2 * focal)) 95 | 96 | 97 | def eulerAngles2rotationMat(theta, format='degree'): 98 | """ 99 | Calculates Rotation Matrix given euler angles. 
100 |     :param theta: 1-by-3 list [rx, ry, rz] of angles, in degrees by default
101 |     :return: 3x3 rotation matrix R = R_z @ R_y @ R_x
102 |     RPY angles are ZYX Euler angles, i.e. successive rotations about the fixed X, Y and Z axes by [rx, ry, rz].
103 |     """
104 |     if format == 'degree':
105 |         theta = [i * math.pi / 180.0 for i in theta]
106 | 
107 |     R_x = np.array([[1, 0, 0],
108 |                     [0, math.cos(theta[0]), -math.sin(theta[0])],
109 |                     [0, math.sin(theta[0]), math.cos(theta[0])]
110 |                     ])
111 | 
112 |     R_y = np.array([[math.cos(theta[1]), 0, math.sin(theta[1])],
113 |                     [0, 1, 0],
114 |                     [-math.sin(theta[1]), 0, math.cos(theta[1])]
115 |                     ])
116 | 
117 |     R_z = np.array([[math.cos(theta[2]), -math.sin(theta[2]), 0],
118 |                     [math.sin(theta[2]), math.cos(theta[2]), 0],
119 |                     [0, 0, 1]
120 |                     ])
121 |     R = np.dot(R_z, np.dot(R_y, R_x))
122 |     return R
123 | 
124 | # print(eulerAngles2rotationMat([120,0,0]))
125 | 
126 | # Checks if a matrix is a valid rotation matrix.
127 | def isRotationMatrix(R) :
128 |     Rt = np.transpose(R)
129 |     shouldBeIdentity = np.dot(Rt, R)
130 |     I = np.identity(3, dtype = R.dtype)
131 |     n = np.linalg.norm(I - shouldBeIdentity)
132 |     return n < 1e-6
133 | 
134 | 
135 | # Calculates rotation matrix to euler angles
136 | # The result is the same as MATLAB except the order
137 | # of the euler angles ( x and z are swapped ).
138 | def rotationMatrixToEulerAngles(R) :
139 | 
140 |     assert(isRotationMatrix(R))
141 | 
142 |     sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0])
143 | 
144 |     singular = sy < 1e-6
145 | 
146 |     if not singular :
147 |         x = math.atan2(R[2,1] , R[2,2])
148 |         y = math.atan2(-R[2,0], sy)
149 |         z = math.atan2(R[1,0], R[0,0])
150 |     else :
151 |         x = math.atan2(-R[1,2], R[1,1])
152 |         y = math.atan2(-R[2,0], sy)
153 |         z = 0
154 | 
155 |     return np.array([x, y, z])
156 | 
--------------------------------------------------------------------------------
/utils/loss_utils.py:
--------------------------------------------------------------------------------
 1 | #
 2 | # Copyright (C) 2023, Inria
 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
 4 | # All rights reserved.
 5 | #
 6 | # This software is free for non-commercial, research and evaluation use
 7 | # under the terms of the LICENSE.md file.
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from math import exp 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from torch.autograd import Variable 18 | 19 | def mse(img1, img2): 20 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 21 | 22 | 23 | def psnr(img1, img2): 24 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 25 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 26 | 27 | def l1_loss(network_output, gt, weight=None): 28 | if weight is None: 29 | return torch.abs((network_output - gt)).mean() 30 | else: 31 | return torch.mean(torch.abs((network_output - gt)).sum(dim=-1) * weight) 32 | 33 | 34 | def l2_loss(network_output, gt, weight=None): 35 | if weight is None: 36 | return ((network_output - gt) ** 2).mean() 37 | else: 38 | return torch.mean(((network_output - gt) ** 2).sum(dim=-1) * weight) 39 | 40 | 41 | def gaussian(window_size, sigma): 42 | gauss = torch.Tensor( 43 | [ 44 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) 45 | for x in range(window_size) 46 | ] 47 | ) 48 | return gauss / gauss.sum() 49 | 50 | 51 | def create_window(window_size, channel): 52 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 53 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 54 | window = Variable( 55 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 56 | ) 57 | return window 58 | 59 | 60 | def ssim(img1, img2, window_size=11, size_average=True): 61 | channel = img1.size(-3) 62 | window = create_window(window_size, channel) 63 | 64 | if img1.is_cuda: 65 | window = window.cuda(img1.get_device()) 66 | window = window.type_as(img1) 67 | 68 | return _ssim(img1, img2, window, window_size, channel, size_average) 69 | 70 | 71 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 72 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 73 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 74 | 75 | mu1_sq = mu1.pow(2) 76 | mu2_sq = mu2.pow(2) 77 | mu1_mu2 = mu1 * mu2 78 | 79 | sigma1_sq = ( 80 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 81 | ) 82 | sigma2_sq = ( 83 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 84 | ) 85 | sigma12 = ( 86 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 87 | - mu1_mu2 88 | ) 89 | 90 | C1 = 0.01**2 91 | C2 = 0.03**2 92 | 93 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 94 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 95 | ) 96 | 97 | if size_average: 98 | return ssim_map.mean() 99 | else: 100 | return ssim_map.mean(1).mean(1).mean(1) 101 | -------------------------------------------------------------------------------- /utils/monitor.py: -------------------------------------------------------------------------------- 1 | import GPUtil 2 | import time 3 | import numpy as np 4 | import os 5 | import json 6 | from os.path import join, exists 7 | 8 | 9 | class Recorder(object): 10 | def __init__(self, gpu_id) -> None: 11 | self._gpu_id = gpu_id 12 | self._value = {} 13 | self._counter = {} 14 | 15 | def update_max(self, name, value): 16 | if name not in self._value: 17 | self._value[name] = value 18 | self._counter[name] = 1 19 | else: 20 | self._value[name] = max(self._value[name], value) 21 | 22 | def cal_fps(self): 23 | self._value["fps"] = 1 / self._value["mapping"] 24 | self._counter["fps"] = 1 
25 | 
26 |     def update_mean(self, name, value, count):
27 |         if count == 0:
28 |             return
29 |         value_mean = value / count
30 |         if name not in self._value:
31 |             self._value[name] = value_mean
32 |             self._counter[name] = count
33 |         else:
34 |             self._value[name] = (self._value[name] * self._counter[name] + value) / (
35 |                 self._counter[name] + count
36 |             )
37 |             self._counter[name] = self._counter[name] + count
38 | 
39 |     def watch_gpu(self):
40 |         # current gpu info
41 |         gpu = GPUtil.getGPUs()[self._gpu_id]
42 |         memory_used = gpu.memoryUsed
43 |         self.update_max("gpu_memory", memory_used)
44 |         return memory_used / 1024.0
45 | 
46 |     def save(self, dir):
47 |         if not exists(dir):
48 |             os.makedirs(dir)
49 |         with open(os.path.join(dir, "performance.json"), "w") as f:
50 |             json.dump(self._value, f)
51 | 
52 |     def display(self):
53 |         print(self._value)
54 | 
55 | 
56 | if __name__ == "__main__":
57 |     monitor = Recorder(0)
58 |     for i in range(5):
59 |         monitor.watch_gpu()
60 |         time.sleep(0.5)
61 |     monitor.update_mean("money", 100, 5)
62 |     monitor.update_mean("money", 28, 3)
63 |     monitor.save("./temp")
64 | 
--------------------------------------------------------------------------------
/utils/sh_utils.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2021 The PlenOctree Authors.
 2 | # Redistribution and use in source and binary forms, with or without
 3 | # modification, are permitted provided that the following conditions are met:
 4 | #
 5 | # 1. Redistributions of source code must retain the above copyright notice,
 6 | # this list of conditions and the following disclaimer.
 7 | #
 8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
 9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | 
24 | import torch
25 | 
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 |     1.0925484305920792,
30 |     -1.0925484305920792,
31 |     0.31539156525252005,
32 |     -1.0925484305920792,
33 |     0.5462742152960396,
34 | ]
35 | C3 = [
36 |     -0.5900435899266435,
37 |     2.890611442640554,
38 |     -0.4570457994644658,
39 |     0.3731763325901154,
40 |     -0.4570457994644658,
41 |     1.445305721320277,
42 |     -0.5900435899266435,
43 | ]
44 | C4 = [
45 |     2.5033429417967046,
46 |     -1.7701307697799304,
47 |     0.9461746957575601,
48 |     -0.6690465435572892,
49 |     0.10578554691520431,
50 |     -0.6690465435572892,
51 |     0.47308734787878004,
52 |     -1.7701307697799304,
53 |     0.6258357354491761,
54 | ]
55 | 
56 | 
57 | def eval_sh(deg, sh, dirs):
58 |     """
59 |     Evaluate spherical harmonics at unit directions
60 |     using hardcoded SH polynomials.
61 |     Works with torch/np/jnp.
62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = ( 78 | result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3] 79 | ) 80 | 81 | if deg > 1: 82 | xx, yy, zz = x * x, y * y, z * z 83 | xy, yz, xz = x * y, y * z, x * z 84 | result = ( 85 | result 86 | + C2[0] * xy * sh[..., 4] 87 | + C2[1] * yz * sh[..., 5] 88 | + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] 89 | + C2[3] * xz * sh[..., 7] 90 | + C2[4] * (xx - yy) * sh[..., 8] 91 | ) 92 | 93 | if deg > 2: 94 | result = ( 95 | result 96 | + C3[0] * y * (3 * xx - yy) * sh[..., 9] 97 | + C3[1] * xy * z * sh[..., 10] 98 | + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] 99 | + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] 100 | + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] 101 | + C3[5] * z * (xx - yy) * sh[..., 14] 102 | + C3[6] * x * (xx - 3 * yy) * sh[..., 15] 103 | ) 104 | 105 | if deg > 3: 106 | result = ( 107 | result 108 | + C4[0] * xy * (xx - yy) * sh[..., 16] 109 | + C4[1] * yz * (3 * xx - yy) * sh[..., 17] 110 | + C4[2] * xy * (7 * zz - 1) * sh[..., 18] 111 | + C4[3] * yz * (7 * zz - 3) * sh[..., 19] 112 | + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] 113 | + C4[5] * xz * (7 * zz - 3) * sh[..., 21] 114 | + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] 115 | + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] 116 | + C4[8] 117 | * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) 118 | * sh[..., 24] 119 | ) 120 | return result 121 | 122 | 123 | def RGB2SH(rgb): 124 | return (rgb - 0.5) / C0 125 | 126 | 127 | def SH2RGB(sh): 128 | return sh * C0 + 0.5 129 | --------------------------------------------------------------------------------
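As a quick sanity check of the helpers in `utils/sh_utils.py`, the minimal sketch below (assuming it is run from the repository root with PyTorch installed) verifies that at degree 0 `eval_sh` reduces to `C0 * sh[..., 0]`, so a colour converted with `RGB2SH` is recovered by adding back the 0.5 offset, and that `SH2RGB` inverts `RGB2SH`:

```python
import torch

from utils.sh_utils import eval_sh, RGB2SH, SH2RGB

rgb = torch.tensor([[0.2, 0.5, 0.9]])            # one point, one RGB colour
sh = RGB2SH(rgb).unsqueeze(-1)                   # [..., C=3, 1]: degree-0 coefficients
view_dir = torch.tensor([[0.0, 0.0, 1.0]])       # direction is ignored at degree 0
recovered = eval_sh(0, sh, view_dir) + 0.5       # eval_sh(0, ...) == C0 * sh[..., 0] == rgb - 0.5
assert torch.allclose(recovered, rgb)
assert torch.allclose(SH2RGB(RGB2SH(rgb)), rgb)  # the two conversions are inverses
```

This mirrors how degree-0 SH coefficients are typically initialized from RGB colours and converted back for rendering.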