├── .gitignore ├── .gitmodules ├── LICENSE.md ├── README.md ├── arguments └── __init__.py ├── assets └── teaser.png ├── async_seele_render.py ├── convert.py ├── finetune.py ├── full_eval.py ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── generate_cluster.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── metrics.py ├── render.py ├── render_video.py ├── requirements.txt ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py └── gaussian_model.py ├── scripts ├── generate_cluster.sh ├── run_all.sh ├── run_finetune.sh ├── run_render.sh ├── run_seele_render.sh └── run_train.sh ├── seele_render.py ├── train.py └── utils ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── make_depth_scale.py ├── pose_utils.py ├── read_write_model.py ├── sh_utils.py └── system_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | seele-gaussian-rasterization/diff_rast.egg-info 6 | seele-gaussian-rasterization/dist 7 | tensorboard_3d 8 | screenshots 9 | temps* 10 | output* 11 | dataset* 12 | .out 13 | .ipynb -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | 2 | [submodule "SIBR_viewers"] 3 | path = SIBR_viewers 4 | url = https://gitlab.inria.fr/sibr/sibr_core.git 5 | [submodule "submodules/fused-ssim"] 6 | path = submodules/fused-ssim 7 | url = https://github.com/rahul-goel/fused-ssim.git 8 | [submodule "submodules/simple-knn"] 9 | path = submodules/simple-knn 10 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 11 | [submodule "submodules/seele-gaussian-rasterization"] 12 | path = submodules/seele-gaussian-rasterization 13 | url = https://github.com/StoneSix16/seele-gaussian-rasterization 14 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 sjtu-mvclab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SeeLe: A Unified Acceleration Framework for Real-Time Gaussian Splatting 2 | | [🌍Webpage](https://seele-project.netlify.app/) | [📄Full Paper](https://arxiv.org/abs/2503.05168) | [🎥Video](https://github.com/user-attachments/assets/49cafdb6-5c8f-43cf-ab05-aa24a39ea1fc) | 3 |
4 | ![Teaser image](assets/teaser.png)
5 | 
6 | ## 🔍 What is it?
7 | This repository provides the official implementation of **SeeLe**, a general acceleration framework for the [3D Gaussian Splatting (3DGS)](https://github.com/graphdeco-inria/gaussian-splatting) pipeline, specifically designed for resource-constrained mobile devices. Our framework achieves a **2.6× speedup** and a **32.5% model reduction** while maintaining superior rendering quality compared to existing methods. On an NVIDIA AGX Orin mobile SoC, SeeLe achieves over **90 FPS**⚡, meeting the real-time requirements for VR applications.
8 | 
9 | Here is a short demo video of our algorithm running on an NVIDIA AGX Orin SoC:
10 | 
11 | https://github.com/user-attachments/assets/49cafdb6-5c8f-43cf-ab05-aa24a39ea1fc
12 | 
13 | ## 🛠️ How to run?
14 | ### Installation
15 | To clone the repository:
16 | ```shell
17 | git clone https://github.com/SJTU-MVCLab/SeeLe.git --recursive && cd SeeLe
18 | ```
19 | To install the requirements:
20 | ```shell
21 | conda create -n seele python=3.9
22 | conda activate seele
23 | # Example for CUDA 12.4:
24 | pip3 install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
25 | pip3 install -r requirements.txt
26 | ```
27 | **Note:** [PyTorch](https://pytorch.org/) installation varies by system. Please ensure you install the appropriate version for your hardware.
28 | 
29 | ### Dataset
30 | We use datasets from **MipNeRF360** and **Tanks & Temples**, which can be downloaded from the authors' official [website](https://jonbarron.info/mipnerf360/). The dataset should be organized in the following structure:
31 | ```
32 | dataset
33 | └── seele
34 |     └── [bicycle|bonsai|counter|train|truck|playroom|drjohnson|...]
35 |         ├── images
36 |         └── sparse
37 | ```
38 | 
39 | ## 🚀 Training and Evaluation
40 | This section provides detailed instructions on how to **train**, **cluster**, **fine-tune**, and **render** the model using our provided scripts. We also provide **standalone evaluation scripts** for assessing the trained model.
41 | 
42 | ### 🔄 One-Click Pipeline: Run Everything at Once
43 | For convenience, you can use the `run_all.sh` script to **automate the entire process** from training to rendering in a single command:
44 | ```shell
45 | bash scripts/run_all.sh
46 | ```
47 | **Note:** By default, all scripts run on an example scene, "**Counter**", from **MipNeRF360**. If you want to train on other datasets, please modify the `datasets` variable in the script accordingly.
48 | 
49 | ### 🏗️ Step-by-Step Training and Rendering
50 | #### 1. Train the 3DGS Model (30,000 Iterations)
51 | To train the **3D Gaussian Splatting (3DGS) model**, use:
52 | ```shell
53 | bash scripts/run_train.sh seele
54 | ```
55 | 
56 | #### 2. Cluster the Trained Model
57 | Once training is complete, apply **k-means clustering** to the trained model with:
58 | ```shell
59 | bash scripts/generate_cluster.sh seele
60 | ```
61 | 
62 | #### 3. Fine-Tune the Clustered Model
63 | After clustering, fine-tune each cluster for improved rendering quality:
64 | ```shell
65 | bash scripts/run_finetune.sh seele
66 | ```
67 | 
68 | #### 4. Render the Final Output with SeeLe
69 | To generate the rendered images using the fine-tuned model, run:
70 | ```shell
71 | bash scripts/run_seele_render.sh seele
72 | ```
73 | 
74 | ### 🎨 Evaluation
75 | After training and fine-tuning, you can **evaluate the model** using the following standalone scripts:
76 | 
77 | #### 1. Render with `seele_render.py`
78 | Renders a **SeeLe** model, optionally loading the fine-tuned weights:
79 | ```shell
80 | python3 seele_render.py -m <path_to_model> [--load_finetune] [--debug]
81 | ```
82 | - **With `--load_finetune`**: Loads the **fine-tuned** model for improved rendering quality. Otherwise, loads the model **before fine-tuning** (the output of `generate_cluster.py`).
83 | - **With `--debug`**: Prints the execution time of each rendered frame.
84 | 
85 | #### 2. Asynchronous Rendering with `async_seele_render.py`
86 | Uses the **CUDA stream API** for **efficient memory management**, asynchronously loading the fine-tuned Gaussian point clouds (see the conceptual sketch near the end of this README):
87 | ```shell
88 | python3 async_seele_render.py -m <path_to_model> [--debug]
89 | ```
90 | 
91 | #### 3. Visualize in GUI with `render_video.py`
92 | Interactively preview rendered results in a GUI:
93 | ```shell
94 | python3 render_video.py -m <path_to_model> --use_gui [--load_seele]
95 | ```
96 | - **With `--load_seele`**: Loads the **fine-tuned SeeLe** model. Otherwise, loads the **original** model.
97 | 
98 | ## 🏋️‍♂️ Validate with a Pretrained Model
99 | To verify the correctness of **SeeLe**, we provide an example (dataset and checkpoint) for evaluation. You can download it [here](https://drive.google.com/file/d/1xfqSLFSLvx5IrpEZU62dw7xm1YZHiyYu/view?usp=sharing). This example includes the following key components:
100 | 
101 | - **clusters** — The fine-tuned **SeeLe** model.
102 | - **point_cloud** — The original **3DGS** checkpoint.
103 | 
104 | You can use this checkpoint to test the pipeline and ensure everything is working correctly.
105 | 
106 | ## 🙏 Acknowledgments
107 | 
108 | Our work is largely based on the implementation of **[3DGS](https://github.com/graphdeco-inria/gaussian-splatting)**, with significant modifications and optimizations to improve performance for mobile devices. Our key improvements include:
109 | 
110 | - **`submodules/seele-gaussian-rasterization`** — Optimized **[diff-gaussian-rasterization](https://github.com/graphdeco-inria/diff-gaussian-rasterization/tree/9c5c2028f6fbee2be239bc4c9421ff894fe4fbe0)** with the **Opti** and **CR** techniques.
111 | - **`generate_cluster.py`** — Implements **k-means clustering** to partition the scene into multiple clusters.
112 | - **`finetune.py`** — Fine-tunes each cluster separately and saves the trained models.
113 | - **`seele_render.py`** — A modified version of `render.py`, designed to **load and render SeeLe models**.
114 | - **`async_seele_render.py`** — Utilizes the **CUDA stream API** for **asynchronous memory optimization** across different clusters.
115 | - **`render_video.py`** — Uses **pyglet** to render images in a GUI. The `--load_seele` option enables **SeeLe model rendering**.
116 | 
117 | For more technical details, please refer to our [paper](https://arxiv.org/abs/2503.05168).
118 | 
119 | ## 📬 Contact
120 | If you have any questions, feel free to reach out to:
121 | 
122 | - **Xiaotong Huang** — [hxt0512@sjtu.edu.cn](mailto:hxt0512@sjtu.edu.cn)
123 | - **He Zhu** — [2394241800@qq.com](mailto:2394241800@qq.com)
124 | 
125 | We appreciate your interest in **SeeLe**!
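## 🧪 Appendix: Asynchronous Cluster Loading (Conceptual Sketch)

The snippet below is an editorial sketch of the double-buffered, stream-based preloading idea used by `async_seele_render.py`. The class name, method names, and the per-cluster tensor layout here are illustrative assumptions, not the repository's actual `GaussianStreamManager` API: while the current cluster is rendered on the default CUDA stream, the next cluster's Gaussians are copied host-to-device on a side stream, and the buffers are swapped before the next view is rendered.

```python
import torch

class TwoBufferPreloader:
    """Illustrative double-buffered loader: render from `current` while `pending` is filled."""

    def __init__(self, cpu_clusters, first_cid):
        # cpu_clusters: list of dicts of (ideally pinned) CPU tensors, one dict per cluster.
        self.cpu_clusters = cpu_clusters
        self.load_stream = torch.cuda.Stream()   # side stream for async host-to-device copies
        self.current = self._to_gpu(first_cid)   # buffer the renderer reads from
        self.pending = None                      # buffer being filled in the background

    def _to_gpu(self, cid):
        # non_blocking=True only overlaps with compute if the CPU tensors are pinned.
        return {k: v.cuda(non_blocking=True) for k, v in self.cpu_clusters[cid].items()}

    def preload_next(self, cid):
        # Issue the copies for the next cluster on the side stream.
        with torch.cuda.stream(self.load_stream):
            self.pending = self._to_gpu(cid)

    def switch(self):
        # Make the default stream wait for the background copies to finish,
        # then promote the preloaded buffer to the active one.
        torch.cuda.current_stream().wait_stream(self.load_stream)
        self.current, self.pending = self.pending, None
```

In `async_seele_render.py` the same pattern shows up as `preload_next(...)`, `get_current()`, and `switch_gaussians()` on the `GaussianStreamManager`, with the rendering loop waiting on `stream_manager.load_stream` before saving each frame.
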
126 | 127 | ## 📖 Citation 128 | If you find this work helpful, please kindly consider citing our paper: 129 | ``` 130 | @article{huang2025seele, 131 |   title={SeeLe: A Unified Acceleration Framework for Real-Time Gaussian Splatting}, 132 |   author={Xiaotong Huang and He Zhu and Zihan Liu and Weikai Lin and Xiaohong Liu and Zhezhi He and Jingwen Leng and Minyi Guo and Yu Feng}, 133 |   journal={arXiv preprint arXiv:2503.05168}, 134 |   year={2025} 135 | } 136 | ``` 137 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self._source_path = "" 51 | self._model_path = "" 52 | self._images = "images" 53 | self._depths = "" 54 | self._resolution = -1 55 | self._white_background = False 56 | self.train_test_exp = False 57 | self.data_device = "cuda" 58 | self.eval = False 59 | super().__init__(parser, "Loading Parameters", sentinel) 60 | 61 | def extract(self, args): 62 | g = super().extract(args) 63 | g.source_path = os.path.abspath(g.source_path) 64 | return g 65 | 66 | class PipelineParams(ParamGroup): 67 | def __init__(self, parser): 68 | self.convert_SHs_python = False 69 | self.compute_cov3D_python = False 70 | self.debug = False 71 | self.antialiasing = False 72 | super().__init__(parser, "Pipeline Parameters") 73 | 74 | class OptimizationParams(ParamGroup): 75 | def __init__(self, parser): 76 | self.iterations = 30_000 77 | self.position_lr_init = 0.00016 78 | self.position_lr_final = 0.0000016 79 | self.position_lr_delay_mult = 0.01 80 | self.position_lr_max_steps = 30_000 81 | self.feature_lr = 0.0025 82 | self.opacity_lr = 0.025 83 | self.scaling_lr = 0.005 84 | self.rotation_lr = 0.001 85 | self.exposure_lr_init = 0.01 86 | self.exposure_lr_final = 0.001 87 | self.exposure_lr_delay_steps = 0 88 | self.exposure_lr_delay_mult = 0.0 89 | self.percent_dense = 0.01 90 | self.lambda_dssim = 0.2 91 | self.densification_interval = 100 92 | 
self.opacity_reset_interval = 3000 93 | self.densify_from_iter = 500 94 | self.densify_until_iter = 15_000 95 | self.densify_grad_threshold = 0.0002 96 | self.depth_l1_weight_init = 1.0 97 | self.depth_l1_weight_final = 0.01 98 | self.random_background = False 99 | self.optimizer_type = "default" 100 | super().__init__(parser, "Optimization Parameters") 101 | 102 | def get_combined_args(parser : ArgumentParser): 103 | cmdlne_string = sys.argv[1:] 104 | cfgfile_string = "Namespace()" 105 | args_cmdline = parser.parse_args(cmdlne_string) 106 | 107 | try: 108 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 109 | print("Looking for config file in", cfgfilepath) 110 | with open(cfgfilepath) as cfg_file: 111 | print("Config file found: {}".format(cfgfilepath)) 112 | cfgfile_string = cfg_file.read() 113 | except TypeError: 114 | print("Config file not found at") 115 | pass 116 | args_cfgfile = eval(cfgfile_string) 117 | 118 | merged_dict = vars(args_cfgfile).copy() 119 | for k,v in vars(args_cmdline).items(): 120 | if v != None: 121 | merged_dict[k] = v 122 | return Namespace(**merged_dict) 123 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-MVCLab/SeeLe/867c009c7da8fd6c497df47985b41d60cdc4f4e0/assets/teaser.png -------------------------------------------------------------------------------- /async_seele_render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import numpy as np 12 | import joblib 13 | import torch 14 | from scene import Scene 15 | import os 16 | from tqdm import tqdm 17 | from os import makedirs 18 | from gaussian_renderer import render 19 | import torchvision 20 | from utils.general_utils import safe_state 21 | from argparse import ArgumentParser 22 | from arguments import ModelParams, PipelineParams, get_combined_args 23 | # from gaussian_renderer import GaussianModel 24 | from gaussian_renderer import GaussianModel, GaussianStreamManager 25 | try: 26 | from diff_gaussian_rasterization import SparseGaussianAdam 27 | SPARSE_ADAM_AVAILABLE = True 28 | except: 29 | SPARSE_ADAM_AVAILABLE = False 30 | 31 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background, train_test_exp, separate_sh, args): 32 | # Initialize paths and configuration 33 | render_path = os.path.join(model_path, name, f"ours_{iteration}", "renders") 34 | gts_path = os.path.join(model_path, name, f"ours_{iteration}", "gt") 35 | makedirs(render_path, exist_ok=True) 36 | makedirs(gts_path, exist_ok=True) 37 | 38 | # Load cluster data 39 | cluster_data = joblib.load(os.path.join(model_path, "clusters", "clusters.pkl")) 40 | K = len(cluster_data["cluster_viewpoint"]) 41 | 42 | # Load all Gaussians to CPU 43 | cluster_gaussians = [ 44 | torch.load(os.path.join(model_path, f"clusters/finetune/point_cloud_{cid}.pth"), map_location="cpu") 45 | for cid in range(K) 46 | ] 47 | 48 | labels = cluster_data[f"{name}_labels"] 49 | 50 | stream_manager = GaussianStreamManager( 51 | cluster_gaussians=cluster_gaussians, 52 | initial_cid=labels[0] 53 | ) 54 | 55 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 56 | if idx + 1 < len(views): 57 | next_cid = labels[idx+1] 58 | stream_manager.preload_next(next_cid) 59 | 60 | gaussians.restore_gaussians(stream_manager.get_current()) 61 | 62 | rendering = render( 63 | view, gaussians, pipeline, background, 64 | use_trained_exp=train_test_exp, 65 | separate_sh=separate_sh, 66 | rasterizer_type="CR" 67 | )["render"] 68 | 69 | torch.cuda.current_stream().wait_stream(stream_manager.load_stream) 70 | gt = view.original_image[0:3, :, :] 71 | if args.train_test_exp: 72 | rendering = rendering[..., rendering.shape[-1]//2:] 73 | gt = gt[..., gt.shape[-1]//2:] 74 | 75 | torchvision.utils.save_image(rendering, os.path.join(render_path, f"{idx:05d}.png")) 76 | torchvision.utils.save_image(gt, os.path.join(gts_path, f"{idx:05d}.png")) 77 | 78 | stream_manager.switch_gaussians() 79 | 80 | stream_manager.cleanup() 81 | 82 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, separate_sh: bool, args: ArgumentParser): 83 | with torch.no_grad(): 84 | gaussians = GaussianModel(dataset.sh_degree) 85 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 86 | 87 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 88 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 89 | 90 | if not skip_train: 91 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh, args) 92 | 93 | if not skip_test: 94 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh, args) 95 | 96 | if __name__ == "__main__": 97 | # Set up command line argument 
parser 98 | parser = ArgumentParser(description="Testing script parameters") 99 | model = ModelParams(parser, sentinel=True) 100 | pipeline = PipelineParams(parser) 101 | parser.add_argument("--iteration", default=-1, type=int) 102 | parser.add_argument("--skip_train", action="store_true") 103 | parser.add_argument("--skip_test", action="store_true") 104 | parser.add_argument("--quiet", action="store_true") 105 | args = get_combined_args(parser) 106 | args.data_device = 'cpu' 107 | print("Rendering " + args.model_path) 108 | # Initialize system state (RNG) 109 | safe_state(args.quiet) 110 | 111 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, SPARSE_ADAM_AVAILABLE, args) -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | args = parser.parse_args() 27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 29 | use_gpu = 1 if not args.no_gpu else 0 30 | 31 | if not args.skip_matching: 32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 33 | 34 | ## Feature extraction 35 | feat_extracton_cmd = colmap_command + " feature_extractor "\ 36 | "--database_path " + args.source_path + "/distorted/database.db \ 37 | --image_path " + args.source_path + "/input \ 38 | --ImageReader.single_camera 1 \ 39 | --ImageReader.camera_model " + args.camera + " \ 40 | --SiftExtraction.use_gpu " + str(use_gpu) 41 | exit_code = os.system(feat_extracton_cmd) 42 | if exit_code != 0: 43 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 44 | exit(exit_code) 45 | 46 | ## Feature matching 47 | feat_matching_cmd = colmap_command + " exhaustive_matcher \ 48 | --database_path " + args.source_path + "/distorted/database.db \ 49 | --SiftMatching.use_gpu " + str(use_gpu) 50 | exit_code = os.system(feat_matching_cmd) 51 | if exit_code != 0: 52 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 53 | exit(exit_code) 54 | 55 | ### Bundle adjustment 56 | # The default Mapper tolerance is unnecessarily large, 57 | # decreasing it speeds up bundle adjustment steps. 
58 | mapper_cmd = (colmap_command + " mapper \ 59 | --database_path " + args.source_path + "/distorted/database.db \ 60 | --image_path " + args.source_path + "/input \ 61 | --output_path " + args.source_path + "/distorted/sparse \ 62 | --Mapper.ba_global_function_tolerance=0.000001") 63 | exit_code = os.system(mapper_cmd) 64 | if exit_code != 0: 65 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 66 | exit(exit_code) 67 | 68 | ### Image undistortion 69 | ## We need to undistort our images into ideal pinhole intrinsics. 70 | img_undist_cmd = (colmap_command + " image_undistorter \ 71 | --image_path " + args.source_path + "/input \ 72 | --input_path " + args.source_path + "/distorted/sparse/0 \ 73 | --output_path " + args.source_path + "\ 74 | --output_type COLMAP") 75 | exit_code = os.system(img_undist_cmd) 76 | if exit_code != 0: 77 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 78 | exit(exit_code) 79 | 80 | files = os.listdir(args.source_path + "/sparse") 81 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 82 | # Copy each file from the source directory to the destination directory 83 | for file in files: 84 | if file == '0': 85 | continue 86 | source_file = os.path.join(args.source_path, "sparse", file) 87 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 88 | shutil.move(source_file, destination_file) 89 | 90 | if(args.resize): 91 | print("Copying and resizing...") 92 | 93 | # Resize images. 94 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 95 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 96 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 97 | # Get the list of files in the source directory 98 | files = os.listdir(args.source_path + "/images") 99 | # Copy each file from the source directory to the destination directory 100 | for file in files: 101 | source_file = os.path.join(args.source_path, "images", file) 102 | 103 | destination_file = os.path.join(args.source_path, "images_2", file) 104 | shutil.copy2(source_file, destination_file) 105 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) 106 | if exit_code != 0: 107 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 108 | exit(exit_code) 109 | 110 | destination_file = os.path.join(args.source_path, "images_4", file) 111 | shutil.copy2(source_file, destination_file) 112 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 113 | if exit_code != 0: 114 | logging.error(f"25% resize failed with code {exit_code}. Exiting.") 115 | exit(exit_code) 116 | 117 | destination_file = os.path.join(args.source_path, "images_8", file) 118 | shutil.copy2(source_file, destination_file) 119 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 120 | if exit_code != 0: 121 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.") 122 | exit(exit_code) 123 | 124 | print("Done.") 125 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | import joblib 15 | from random import randint 16 | from utils.loss_utils import l1_loss, ssim 17 | from gaussian_renderer import render, network_gui 18 | import sys 19 | from scene import Scene, GaussianModel 20 | from utils.general_utils import safe_state, get_expon_lr_func 21 | import uuid 22 | from tqdm import tqdm 23 | from utils.image_utils import psnr 24 | from argparse import ArgumentParser, Namespace 25 | from arguments import ModelParams, PipelineParams, OptimizationParams 26 | import gc 27 | 28 | try: 29 | from torch.utils.tensorboard import SummaryWriter 30 | TENSORBOARD_FOUND = True 31 | except ImportError: 32 | TENSORBOARD_FOUND = False 33 | 34 | try: 35 | from fused_ssim import fused_ssim 36 | FUSED_SSIM_AVAILABLE = True 37 | except: 38 | FUSED_SSIM_AVAILABLE = False 39 | 40 | try: 41 | from diff_gaussian_rasterization import SparseGaussianAdam 42 | SPARSE_ADAM_AVAILABLE = True 43 | except: 44 | SPARSE_ADAM_AVAILABLE = False 45 | 46 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 47 | clusters_data_path = os.path.join(dataset.model_path, "clusters") 48 | 49 | cluster_data = joblib.load(os.path.join(clusters_data_path, "clusters.pkl")) 50 | K = len(cluster_data["cluster_viewpoint"]) 51 | 52 | finetune_path = os.path.join(clusters_data_path, "finetune") 53 | os.makedirs(finetune_path, exist_ok=True) 54 | dataset.finetune_path = finetune_path 55 | 56 | gaussians = GaussianModel(dataset.sh_degree, opt.optimizer_type) 57 | scene = Scene(dataset, gaussians, shuffle=False) 58 | for cid in range(K): 59 | print(f"----------------- training cluster {cid} -----------------") 60 | viewpoint_indices = cluster_data["cluster_viewpoint"][cid].tolist() 61 | (gaussian_ids, lens) = cluster_data["cluster_gaussians"][cid] 62 | (model_params, first_iter) = torch.load(checkpoint, weights_only=False) 63 | dataset.cid = cid 64 | dataset.viewpoint_indices = viewpoint_indices 65 | gaussians.restore_models(model_params, (gaussian_ids, lens), opt) 66 | training_cluster(dataset, opt, pipe, gaussians, scene, first_iter, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from) 67 | 68 | del model_params 69 | torch.cuda.empty_cache() 70 | gc.collect() 71 | 72 | def training_cluster(dataset, opt, pipe, gaussians, scene, first_iter, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 73 | if not SPARSE_ADAM_AVAILABLE and opt.optimizer_type == "sparse_adam": 74 | sys.exit(f"Trying to use sparse adam but it is not installed, please install the correct rasterizer using pip install [3dgs_accel].") 75 | 76 | # tb_writer = prepare_output_and_logger(dataset) 77 | # if checkpoint: 78 | # (model_params, first_iter) = torch.load(checkpoint) 79 | # gaussians.restore(model_params, opt) 80 | 81 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 82 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 83 | 84 | iter_start = torch.cuda.Event(enable_timing = True) 85 | iter_end = torch.cuda.Event(enable_timing = True) 86 | 87 | use_sparse_adam = opt.optimizer_type == "sparse_adam" and SPARSE_ADAM_AVAILABLE 88 | depth_l1_weight = get_expon_lr_func(opt.depth_l1_weight_init, opt.depth_l1_weight_final, max_steps=opt.iterations) 89 | 90 | trainCameras = scene.getTrainCameras().copy() 91 | trainCameras = [trainCameras[view_id] for view_id in 
dataset.viewpoint_indices] 92 | 93 | viewpoint_stack = trainCameras.copy() 94 | viewpoint_indices = list(range(len(viewpoint_stack))) 95 | ema_loss_for_log = 0.0 96 | ema_Ll1depth_for_log = 0.0 97 | 98 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 99 | first_iter += 1 100 | for iteration in range(first_iter, opt.iterations + 1): 101 | gaussians.update_learning_rate(iteration) 102 | 103 | # # Every 1000 its we increase the levels of SH up to a maximum degree 104 | # if iteration % 1000 == 0: 105 | # gaussians.oneupSHdegree() 106 | 107 | # Pick a random Camera 108 | if not viewpoint_stack: 109 | viewpoint_stack = trainCameras.copy() 110 | viewpoint_indices = list(range(len(viewpoint_stack))) 111 | rand_idx = randint(0, len(viewpoint_indices) - 1) 112 | viewpoint_cam = viewpoint_stack.pop(rand_idx) 113 | vind = viewpoint_indices.pop(rand_idx) 114 | 115 | # Render 116 | if (iteration - 1) == debug_from: 117 | pipe.debug = True 118 | 119 | bg = torch.rand((3), device="cuda") if opt.random_background else background 120 | 121 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg, use_trained_exp=dataset.train_test_exp, separate_sh=SPARSE_ADAM_AVAILABLE) 122 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 123 | 124 | # if viewpoint_cam.alpha_mask is not None: 125 | # alpha_mask = viewpoint_cam.alpha_mask.cuda() 126 | # image *= alpha_mask 127 | 128 | # Loss 129 | gt_image = viewpoint_cam.original_image.cuda() 130 | Ll1 = l1_loss(image, gt_image) 131 | if FUSED_SSIM_AVAILABLE: 132 | ssim_value = fused_ssim(image.unsqueeze(0), gt_image.unsqueeze(0)) 133 | else: 134 | ssim_value = ssim(image, gt_image) 135 | 136 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value) 137 | 138 | # Depth regularization 139 | Ll1depth_pure = 0.0 140 | if depth_l1_weight(iteration) > 0 and viewpoint_cam.depth_reliable: 141 | invDepth = render_pkg["depth"] 142 | mono_invdepth = viewpoint_cam.invdepthmap.cuda() 143 | depth_mask = viewpoint_cam.depth_mask.cuda() 144 | 145 | Ll1depth_pure = torch.abs((invDepth - mono_invdepth) * depth_mask).mean() 146 | Ll1depth = depth_l1_weight(iteration) * Ll1depth_pure 147 | loss += Ll1depth 148 | Ll1depth = Ll1depth.item() 149 | else: 150 | Ll1depth = 0 151 | 152 | loss.backward() 153 | 154 | iter_end.record() 155 | 156 | with torch.no_grad(): 157 | # Progress bar 158 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 159 | ema_Ll1depth_for_log = 0.4 * Ll1depth + 0.6 * ema_Ll1depth_for_log 160 | 161 | if iteration % 10 == 0: 162 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}", "Depth Loss": f"{ema_Ll1depth_for_log:.{7}f}"}) 163 | progress_bar.update(10) 164 | if iteration == opt.iterations: 165 | progress_bar.close() 166 | 167 | # Log and save 168 | if (iteration in saving_iterations): 169 | print("\n[ITER {} Cid {}] Saving Gaussians".format(iteration, dataset.cid)) 170 | # saving_gaussians(dataset, pipe, gaussians, trainCameras) 171 | torch.save(gaussians.capture_gaussians(), os.path.join(dataset.finetune_path, f"point_cloud_{dataset.cid}.pth")) 172 | 173 | # Densification 174 | if iteration < opt.densify_until_iter: 175 | # Keep track of max radii in image-space for pruning 176 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 177 | gaussians.add_densification_stats(viewspace_point_tensor, 
visibility_filter) 178 | 179 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 180 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 181 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold, radii) 182 | 183 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): 184 | gaussians.reset_opacity() 185 | 186 | # Optimizer step 187 | if iteration < opt.iterations: 188 | gaussians.exposure_optimizer.step() 189 | gaussians.exposure_optimizer.zero_grad(set_to_none = True) 190 | if use_sparse_adam: 191 | visible = radii > 0 192 | gaussians.optimizer.step(visible, radii.shape[0]) 193 | gaussians.optimizer.zero_grad(set_to_none = True) 194 | else: 195 | gaussians.optimizer.step() 196 | gaussians.optimizer.zero_grad(set_to_none = True) 197 | 198 | if (iteration in checkpoint_iterations): 199 | print("\n[ITER {} Cid {}] Saving Checkpoint".format(iteration, dataset.cid)) 200 | torch.save((gaussians.capture(), iteration), os.path.join(dataset.finetune_path, "chkpnt" + str(iteration) + f"_{dataset.cid}.pth")) 201 | 202 | def prepare_output_and_logger(args): 203 | if not args.model_path: 204 | if os.getenv('OAR_JOB_ID'): 205 | unique_str=os.getenv('OAR_JOB_ID') 206 | else: 207 | unique_str = str(uuid.uuid4()) 208 | args.model_path = os.path.join("./output/", unique_str[0:10]) 209 | 210 | # Set up output folder 211 | print("Output folder: {}".format(args.model_path)) 212 | os.makedirs(args.model_path, exist_ok = True) 213 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 214 | cfg_log_f.write(str(Namespace(**vars(args)))) 215 | 216 | # Create Tensorboard writer 217 | tb_writer = None 218 | if TENSORBOARD_FOUND: 219 | tb_writer = SummaryWriter(args.model_path) 220 | else: 221 | print("Tensorboard not available: not logging progress") 222 | return tb_writer 223 | 224 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, train_test_exp): 225 | if tb_writer: 226 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 227 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 228 | tb_writer.add_scalar('iter_time', elapsed, iteration) 229 | 230 | # Report test and samples of training set 231 | if iteration in testing_iterations: 232 | torch.cuda.empty_cache() 233 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 234 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 235 | 236 | for config in validation_configs: 237 | if config['cameras'] and len(config['cameras']) > 0: 238 | l1_test = 0.0 239 | psnr_test = 0.0 240 | for idx, viewpoint in enumerate(config['cameras']): 241 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0) 242 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 243 | if train_test_exp: 244 | image = image[..., image.shape[-1] // 2:] 245 | gt_image = gt_image[..., gt_image.shape[-1] // 2:] 246 | if tb_writer and (idx < 5): 247 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) 248 | if iteration == testing_iterations[0]: 249 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], 
global_step=iteration) 250 | l1_test += l1_loss(image, gt_image).mean().double() 251 | psnr_test += psnr(image, gt_image).mean().double() 252 | psnr_test /= len(config['cameras']) 253 | l1_test /= len(config['cameras']) 254 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 255 | if tb_writer: 256 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 257 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 258 | 259 | if tb_writer: 260 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 261 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 262 | torch.cuda.empty_cache() 263 | 264 | if __name__ == "__main__": 265 | # Set up command line argument parser 266 | parser = ArgumentParser(description="Training script parameters") 267 | lp = ModelParams(parser) 268 | op = OptimizationParams(parser) 269 | pp = PipelineParams(parser) 270 | parser.add_argument('--ip', type=str, default="127.0.0.1") 271 | parser.add_argument('--port', type=int, default=6009) 272 | parser.add_argument('--debug_from', type=int, default=-1) 273 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 274 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[31_000]) 275 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[31_000]) 276 | parser.add_argument("--quiet", action="store_true") 277 | parser.add_argument('--disable_viewer', action='store_true', default=False) 278 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 279 | parser.add_argument("--start_checkpoint", type=str, default = None) 280 | args = parser.parse_args(sys.argv[1:]) 281 | # op.densify_until_iter = args.iterations 282 | op.position_lr_max_steps = args.iterations 283 | args.save_iterations.append(args.iterations) 284 | 285 | print("Optimizing " + args.model_path) 286 | 287 | # Initialize system state (RNG) 288 | safe_state(args.quiet) 289 | 290 | # Start GUI server, configure and run training 291 | if not args.disable_viewer: 292 | network_gui.init(args.ip, args.port) 293 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 294 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 295 | 296 | # All done 297 | print("\nTraining complete.") 298 | -------------------------------------------------------------------------------- /full_eval.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | from argparse import ArgumentParser 14 | import time 15 | 16 | mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"] 17 | mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"] 18 | tanks_and_temples_scenes = ["truck", "train"] 19 | deep_blending_scenes = ["drjohnson", "playroom"] 20 | 21 | parser = ArgumentParser(description="Full evaluation script parameters") 22 | parser.add_argument("--skip_training", action="store_true") 23 | parser.add_argument("--skip_rendering", action="store_true") 24 | parser.add_argument("--skip_metrics", action="store_true") 25 | parser.add_argument("--output_path", default="./eval") 26 | parser.add_argument("--use_depth", action="store_true") 27 | parser.add_argument("--use_expcomp", action="store_true") 28 | parser.add_argument("--fast", action="store_true") 29 | parser.add_argument("--aa", action="store_true") 30 | 31 | 32 | 33 | 34 | args, _ = parser.parse_known_args() 35 | 36 | all_scenes = [] 37 | all_scenes.extend(mipnerf360_outdoor_scenes) 38 | all_scenes.extend(mipnerf360_indoor_scenes) 39 | all_scenes.extend(tanks_and_temples_scenes) 40 | all_scenes.extend(deep_blending_scenes) 41 | 42 | if not args.skip_training or not args.skip_rendering: 43 | parser.add_argument('--mipnerf360', "-m360", required=True, type=str) 44 | parser.add_argument("--tanksandtemples", "-tat", required=True, type=str) 45 | parser.add_argument("--deepblending", "-db", required=True, type=str) 46 | args = parser.parse_args() 47 | if not args.skip_training: 48 | common_args = " --disable_viewer --quiet --eval --test_iterations -1 " 49 | 50 | if args.aa: 51 | common_args += " --antialiasing " 52 | if args.use_depth: 53 | common_args += " -d depths2/ " 54 | 55 | if args.use_expcomp: 56 | common_args += " --exposure_lr_init 0.001 --exposure_lr_final 0.0001 --exposure_lr_delay_steps 5000 --exposure_lr_delay_mult 0.001 --train_test_exp " 57 | 58 | if args.fast: 59 | common_args += " --optimizer_type sparse_adam " 60 | 61 | start_time = time.time() 62 | for scene in mipnerf360_outdoor_scenes: 63 | source = args.mipnerf360 + "/" + scene 64 | os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args) 65 | for scene in mipnerf360_indoor_scenes: 66 | source = args.mipnerf360 + "/" + scene 67 | os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args) 68 | m360_timing = (time.time() - start_time)/60.0 69 | 70 | start_time = time.time() 71 | for scene in tanks_and_temples_scenes: 72 | source = args.tanksandtemples + "/" + scene 73 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 74 | tandt_timing = (time.time() - start_time)/60.0 75 | 76 | start_time = time.time() 77 | for scene in deep_blending_scenes: 78 | source = args.deepblending + "/" + scene 79 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 80 | db_timing = (time.time() - start_time)/60.0 81 | 82 | with open(os.path.join(args.output_path,"timing.txt"), 'w') as file: 83 | file.write(f"m360: {m360_timing} minutes \n tandt: {tandt_timing} minutes \n db: {db_timing} minutes\n") 84 | 85 | if not args.skip_rendering: 86 | all_sources = [] 87 | for scene in mipnerf360_outdoor_scenes: 88 | all_sources.append(args.mipnerf360 + "/" + scene) 89 | for scene in mipnerf360_indoor_scenes: 90 | 
all_sources.append(args.mipnerf360 + "/" + scene) 91 | for scene in tanks_and_temples_scenes: 92 | all_sources.append(args.tanksandtemples + "/" + scene) 93 | for scene in deep_blending_scenes: 94 | all_sources.append(args.deepblending + "/" + scene) 95 | 96 | common_args = " --quiet --eval --skip_train" 97 | 98 | if args.aa: 99 | common_args += " --antialiasing " 100 | if args.use_expcomp: 101 | common_args += " --train_test_exp " 102 | 103 | for scene, source in zip(all_scenes, all_sources): 104 | os.system("python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 105 | os.system("python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 106 | 107 | if not args.skip_metrics: 108 | scenes_string = "" 109 | for scene in all_scenes: 110 | scenes_string += "\"" + args.output_path + "/" + scene + "\" " 111 | 112 | os.system("python metrics.py -m " + scenes_string) 113 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel, GaussianStreamManager 16 | from utils.sh_utils import eval_sh 17 | 18 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, separate_sh = False, override_color = None, use_trained_exp=False, rasterizer_type=""): 19 | """ 20 | Render the scene. 21 | 22 | Background tensor (bg_color) must be on GPU! 23 | """ 24 | 25 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 26 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 27 | try: 28 | screenspace_points.retain_grad() 29 | except: 30 | pass 31 | 32 | # Set up rasterization configuration 33 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 34 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 35 | 36 | raster_settings = GaussianRasterizationSettings( 37 | image_height=int(viewpoint_camera.image_height), 38 | image_width=int(viewpoint_camera.image_width), 39 | tanfovx=tanfovx, 40 | tanfovy=tanfovy, 41 | bg=bg_color, 42 | scale_modifier=scaling_modifier, 43 | viewmatrix=viewpoint_camera.world_view_transform, 44 | projmatrix=viewpoint_camera.full_proj_transform, 45 | sh_degree=pc.active_sh_degree, 46 | campos=viewpoint_camera.camera_center, 47 | prefiltered=False, 48 | debug=pipe.debug, 49 | rasterizer_type=rasterizer_type, 50 | ) 51 | 52 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 53 | 54 | means3D = pc.get_xyz 55 | means2D = screenspace_points 56 | opacity = pc.get_opacity 57 | 58 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 59 | # scaling / rotation by the rasterizer. 
60 | scales = None 61 | rotations = None 62 | cov3D_precomp = None 63 | 64 | if pipe.compute_cov3D_python: 65 | cov3D_precomp = pc.get_covariance(scaling_modifier) 66 | else: 67 | scales = pc.get_scaling 68 | rotations = pc.get_rotation 69 | 70 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 71 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 72 | shs = None 73 | colors_precomp = None 74 | if override_color is None: 75 | if pipe.convert_SHs_python: 76 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 77 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 78 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 79 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 80 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 81 | else: 82 | if separate_sh: 83 | dc, shs = pc.get_features_dc, pc.get_features_rest 84 | else: 85 | shs = pc.get_features 86 | else: 87 | colors_precomp = override_color 88 | 89 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 90 | if separate_sh: 91 | returns = rasterizer( 92 | means3D = means3D, 93 | means2D = means2D, 94 | dc = dc, 95 | shs = shs, 96 | colors_precomp = colors_precomp, 97 | opacities = opacity, 98 | scales = scales, 99 | rotations = rotations, 100 | cov3D_precomp = cov3D_precomp) 101 | else: 102 | returns = rasterizer( 103 | means3D = means3D, 104 | means2D = means2D, 105 | shs = shs, 106 | colors_precomp = colors_precomp, 107 | opacities = opacity, 108 | scales = scales, 109 | rotations = rotations, 110 | cov3D_precomp = cov3D_precomp) 111 | visible_gaussians = None 112 | if rasterizer_type == "Mark": 113 | rendered_image, visible_gaussians, radii = returns 114 | else: 115 | rendered_image, radii = returns 116 | 117 | # Apply exposure to rendered image (training only) 118 | if use_trained_exp: 119 | exposure = pc.get_exposure_from_name(viewpoint_camera.image_name) 120 | rendered_image = torch.matmul(rendered_image.permute(1, 2, 0), exposure[:3, :3]).permute(2, 0, 1) + exposure[:3, 3, None, None] 121 | 122 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 123 | # They will be excluded from value updates used in the splitting criteria. 124 | rendered_image = rendered_image.clamp(0, 1) 125 | out = { 126 | "render": rendered_image, 127 | "viewspace_points": screenspace_points, 128 | "visibility_filter" : (radii > 0).nonzero(), 129 | "radii": radii, 130 | "depth" : None, 131 | "visible_gaussians": visible_gaussians 132 | } 133 | 134 | return out 135 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"\nConnected by {addr}") 39 | conn.settimeout(None) 40 | except Exception as inst: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes != None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /generate_cluster.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene 14 | import os 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render 18 | from utils.general_utils import safe_state 19 | from utils.graphics_utils import getWorld2View2 20 | 21 | import joblib 22 | import numpy as np 23 | from sklearn.cluster import KMeans 24 | from scipy.spatial.transform import Rotation as Rot 25 | 26 | from argparse import ArgumentParser 27 | from arguments import ModelParams, PipelineParams, get_combined_args 28 | from gaussian_renderer import GaussianModel 29 | try: 30 | from diff_gaussian_rasterization import SparseGaussianAdam 31 | SPARSE_ADAM_AVAILABLE = True 32 | except: 33 | SPARSE_ADAM_AVAILABLE = False 34 | 35 | def generate_features_from_Rt(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 36 | # R_w2c: R.T, t_w2c: t 37 | # R_c2w: R, t_c2w: -R.T @ t 38 | w2c = getWorld2View2(R, t, translate=translate, scale=scale) 39 | c2w = np.linalg.inv(w2c) 40 | 41 | rot = Rot.from_matrix(c2w[:3, :3]) # This function will orthonormalize R automatically. 42 | q = rot.as_quat(canonical=True) 43 | feature_vector = np.concatenate([c2w[:3, 3], q]) 44 | return feature_vector 45 | 46 | def extract_features(views): 47 | features = [] 48 | for view in views: 49 | features.append(generate_features_from_Rt(view.R, view.T)) 50 | features = np.stack(features, axis=0) 51 | return features 52 | 53 | def merge_neighbor_mask(centers, cluster_masks, labels, neigh): 54 | K, P = cluster_masks.shape 55 | 56 | total_shared = total_exclusive = 0 57 | merge_gaussians, merge_viewpoint = [], [] 58 | cluster_masks = cluster_masks.astype(np.uint32) 59 | average_gaussians = 0 60 | for cid in range(K): 61 | base = centers[cid:cid+1] 62 | dist2 = np.square(base - centers).sum(1) 63 | merge_clusters = np.argsort(dist2)[:neigh + 1] 64 | 65 | viewpoints = np.concatenate([(labels == cluster).nonzero()[0] for cluster in merge_clusters]) 66 | merge_viewpoint.append(viewpoints) 67 | 68 | gaussians_counter = cluster_masks[merge_clusters].sum(axis=0) 69 | shared_mask = (gaussians_counter > ((neigh + 1) // 2)) 70 | exclusive_mask = np.logical_xor(shared_mask, (gaussians_counter != 0)) 71 | 72 | shared, exclusive = map(lambda x: x.nonzero()[0], [shared_mask, exclusive_mask]) 73 | gaussian_ids = np.concatenate([shared, exclusive], axis=0) 74 | 75 | lens = (len(shared), len(exclusive)) 76 | merge_gaussians.append((gaussian_ids, lens)) 77 | 78 | total_shared += lens[0] 79 | total_exclusive += lens[1] 80 | average_gaussians += lens[0] + lens[1] 81 | 82 | total_shared //= K 83 | total_exclusive //= K 84 | average_gaussians //= K 85 | print(f"Total gaussians: {P}, average shared gaussians: {total_shared}, average exclusive gaussians: {total_exclusive}, average number of gaussians: {average_gaussians}") 86 | print(f"Expansion ratio: {(total_exclusive + total_shared) / P}") 87 | return merge_gaussians, merge_viewpoint 88 | 89 | def render_set(views, gaussians, pipeline, background, train_test_exp, separate_sh): 90 | gaussian_masks = [] 91 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 92 | out = render(view, gaussians, pipeline, background, use_trained_exp=train_test_exp, separate_sh=separate_sh, rasterizer_type="Mark") 93 | visible_gaussians = out["visible_gaussians"].cpu().numpy() 94 | gaussian_masks.append(visible_gaussians != 0) 95 | 96 | return np.stack(gaussian_masks, axis=0) 97 | 98 | def render_sets(dataset : ModelParams, 
iteration : int, pipeline : PipelineParams, args, separate_sh: bool): 99 | with torch.no_grad(): 100 | gaussians = GaussianModel(dataset.sh_degree) 101 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 102 | train_features = extract_features(scene.getTrainCameras()) 103 | test_features = extract_features(scene.getTestCameras()) 104 | kmeans = KMeans(n_clusters=args.k, random_state=42, n_init='auto').fit(train_features) 105 | centers = kmeans.cluster_centers_ 106 | train_labels = kmeans.labels_ 107 | test_labels = kmeans.predict(test_features) 108 | 109 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 110 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 111 | view_gaussian_masks = render_set(scene.getTrainCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh) 112 | cluster_gaussian_masks = np.stack([np.any(view_gaussian_masks[train_labels == j], axis=0) for j in range(args.k)], axis=0) 113 | merge_gaussians, merge_viewpoint = merge_neighbor_mask(centers, cluster_gaussian_masks, train_labels, neigh=args.n) 114 | 115 | save_path = os.path.join(dataset.model_path, "clusters") 116 | makedirs(save_path, exist_ok=True) 117 | data = { 118 | "cluster_gaussians": merge_gaussians, 119 | "cluster_viewpoint": merge_viewpoint, 120 | "train_labels": train_labels, 121 | "test_labels": test_labels, 122 | "centers": centers, 123 | } 124 | joblib.dump(data, os.path.join(save_path, "clusters.pkl")) 125 | joblib.dump(kmeans, os.path.join(save_path, "kmeans_model.pkl")) 126 | 127 | if __name__ == "__main__": 128 | # Set up command line argument parser 129 | parser = ArgumentParser(description="Testing script parameters") 130 | model = ModelParams(parser, sentinel=True) 131 | pipeline = PipelineParams(parser) 132 | parser.add_argument("--iteration", default=-1, type=int) 133 | parser.add_argument("--quiet", action="store_true") 134 | parser.add_argument("-k", type=int, default = 24) 135 | parser.add_argument("-n", type=int, default = 4) 136 | args = get_combined_args(parser) 137 | print("Generating clusters for" + args.model_path) 138 | print(f"k: {args.k}, n: {args.n}") 139 | # Initialize system state (RNG) 140 | safe_state(args.quiet) 141 | 142 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args, SPARSE_ADAM_AVAILABLE) 143 | -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 
18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = 
models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | from lpipsPyTorch import lpips 19 | import json 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser 23 | 24 | def readImages(renders_dir, gt_dir): 25 | renders = [] 26 | gts = [] 27 | image_names = [] 28 | for fname in os.listdir(renders_dir): 29 | render = Image.open(renders_dir / fname) 30 | gt = Image.open(gt_dir / fname) 31 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 32 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 33 | image_names.append(fname) 34 | return renders, gts, image_names 35 | 36 | def evaluate(model_paths): 37 | 38 | full_dict = {} 39 | per_view_dict = {} 40 | full_dict_polytopeonly = {} 41 | per_view_dict_polytopeonly = {} 42 | print("") 43 | 44 | for scene_dir in model_paths: 45 | try: 46 | print("Scene:", scene_dir) 47 | full_dict[scene_dir] = {} 48 | per_view_dict[scene_dir] = {} 49 | full_dict_polytopeonly[scene_dir] = {} 50 | per_view_dict_polytopeonly[scene_dir] = {} 51 | 52 | test_dir = Path(scene_dir) / "test" 53 | # test_dir = Path(scene_dir) / "train" 54 | 55 | for method in os.listdir(test_dir): 56 | print("Method:", method) 57 | 58 | full_dict[scene_dir][method] = {} 59 | per_view_dict[scene_dir][method] = {} 60 | full_dict_polytopeonly[scene_dir][method] = {} 61 | per_view_dict_polytopeonly[scene_dir][method] = {} 62 | 63 | method_dir = test_dir / method 64 | gt_dir = method_dir/ "gt" 65 | renders_dir = method_dir / "renders" 66 | renders, gts, image_names = readImages(renders_dir, gt_dir) 67 | 68 | ssims = [] 69 | psnrs = [] 70 | lpipss = [] 71 | 72 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 73 | ssims.append(ssim(renders[idx], gts[idx])) 74 | psnrs.append(psnr(renders[idx], gts[idx])) 75 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) 76 | 77 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 78 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 79 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 80 | print("") 81 | 82 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 83 | "PSNR": torch.tensor(psnrs).mean().item(), 84 | "LPIPS": torch.tensor(lpipss).mean().item()}) 85 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 86 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 87 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 88 | 89 | with open(scene_dir + "/results.json", 'w') as fp: 90 | json.dump(full_dict[scene_dir], fp, indent=True) 91 | with open(scene_dir + "/per_view.json", 'w') as fp: 92 | json.dump(per_view_dict[scene_dir], fp, indent=True) 93 | except: 94 | print("Unable to compute metrics for model", scene_dir) 95 | 96 | if __name__ == "__main__": 97 | device = torch.device("cuda:0") 98 | torch.cuda.set_device(device) 99 | 100 | # Set up command line argument parser 101 | parser = ArgumentParser(description="Training script parameters") 102 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 103 | args = parser.parse_args() 104 | evaluate(args.model_paths) 105 | 
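For reference, the same three metrics driven by `evaluate()` above can be computed for a single render/ground-truth pair with the helpers this file already imports. This is a minimal sketch, not part of the repository; the two image paths are placeholders, and both images are assumed to have the same resolution.

```python
# Minimal sketch: score one render against its ground truth using the repo's own helpers.
# The two file paths below are hypothetical placeholders.
import torch
import torchvision.transforms.functional as tf
from PIL import Image

from utils.loss_utils import ssim
from utils.image_utils import psnr
from lpipsPyTorch import lpips

render = tf.to_tensor(Image.open("renders/00000.png")).unsqueeze(0)[:, :3, :, :].cuda()
gt = tf.to_tensor(Image.open("gt/00000.png")).unsqueeze(0)[:, :3, :, :].cuda()

with torch.no_grad():
    print(f"SSIM : {ssim(render, gt).item():.5f}")
    print(f"PSNR : {psnr(render, gt).mean().item():.5f}")
    print(f"LPIPS: {lpips(render, gt, net_type='vgg').item():.5f}")
```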
-------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene 14 | import os 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render 18 | import torchvision 19 | from utils.general_utils import safe_state 20 | from argparse import ArgumentParser 21 | from arguments import ModelParams, PipelineParams, get_combined_args 22 | from gaussian_renderer import GaussianModel 23 | try: 24 | from diff_gaussian_rasterization import SparseGaussianAdam 25 | SPARSE_ADAM_AVAILABLE = True 26 | except: 27 | SPARSE_ADAM_AVAILABLE = False 28 | 29 | 30 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background, train_test_exp, separate_sh): 31 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 32 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 33 | 34 | makedirs(render_path, exist_ok=True) 35 | makedirs(gts_path, exist_ok=True) 36 | 37 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 38 | rendering = render(view, gaussians, pipeline, background, use_trained_exp=train_test_exp, separate_sh=separate_sh)["render"] 39 | gt = view.original_image[0:3, :, :] 40 | 41 | if args.train_test_exp: 42 | rendering = rendering[..., rendering.shape[-1] // 2:] 43 | gt = gt[..., gt.shape[-1] // 2:] 44 | 45 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 46 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 47 | 48 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, separate_sh: bool): 49 | with torch.no_grad(): 50 | gaussians = GaussianModel(dataset.sh_degree) 51 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 52 | 53 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 54 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 55 | 56 | if not skip_train: 57 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh) 58 | 59 | if not skip_test: 60 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh) 61 | 62 | if __name__ == "__main__": 63 | # Set up command line argument parser 64 | parser = ArgumentParser(description="Testing script parameters") 65 | model = ModelParams(parser, sentinel=True) 66 | pipeline = PipelineParams(parser) 67 | parser.add_argument("--iteration", default=-1, type=int) 68 | parser.add_argument("--skip_train", action="store_true") 69 | parser.add_argument("--skip_test", action="store_true") 70 | parser.add_argument("--quiet", action="store_true") 71 | args = get_combined_args(parser) 72 | print("Rendering " + args.model_path) 73 | 74 | # Initialize system state (RNG) 75 | safe_state(args.quiet) 76 | 77 | render_sets(model.extract(args), args.iteration, 
pipeline.extract(args), args.skip_train, args.skip_test, SPARSE_ADAM_AVAILABLE) -------------------------------------------------------------------------------- /render_video.py: -------------------------------------------------------------------------------- 1 | import pyglet 2 | import numpy as np 3 | import joblib 4 | import torch 5 | import os 6 | import time 7 | from tqdm import tqdm 8 | from argparse import ArgumentParser 9 | from scene import Scene 10 | from gaussian_renderer import render, GaussianModel, GaussianStreamManager 11 | from utils.general_utils import safe_state 12 | from utils.pose_utils import generate_ellipse_path, getWorld2View2 13 | from arguments import ModelParams, PipelineParams, get_combined_args 14 | from generate_cluster import generate_features_from_Rt 15 | import torchvision 16 | SPARSE_ADAM_AVAILABLE = False 17 | 18 | class VideoPlayer: 19 | """Efficient video player using pyglet for 3DGS rendering display.""" 20 | 21 | def __init__(self, width: int, height: int, total_frames: int): 22 | """Initialize the video player window and UI elements. 23 | 24 | Args: 25 | width: Width of the video frame 26 | height: Height of the video frame 27 | total_frames: Total number of frames to be displayed 28 | """ 29 | self.window = pyglet.window.Window( 30 | width=width, 31 | height=height, 32 | caption='3DGS Rendering Viewer' 33 | ) 34 | self.total_frames = total_frames 35 | self.current_frame = 0 36 | self.fps = 0.0 37 | self.last_time = time.time() 38 | 39 | # Initialize texture with blank frame 40 | self._init_texture(width, height) 41 | 42 | # Setup UI elements 43 | self._setup_ui(width, height) 44 | 45 | # Register event handlers 46 | self.window.event(self.on_draw) 47 | 48 | def _init_texture(self, width: int, height: int): 49 | """Initialize the OpenGL texture with blank data.""" 50 | blank_data = np.zeros((height, width, 3), dtype=np.uint8).tobytes() 51 | self.texture = pyglet.image.ImageData( 52 | width, height, 'RGB', blank_data 53 | ).get_texture() 54 | 55 | def _setup_ui(self, width: int, height: int): 56 | """Initialize UI components (FPS counter and progress bar).""" 57 | self.batch = pyglet.graphics.Batch() 58 | 59 | # Frame counter label 60 | self.label = pyglet.text.Label( 61 | '', 62 | x=10, y=height-30, 63 | font_size=16, 64 | color=(255, 255, 255, 255), 65 | batch=self.batch 66 | ) 67 | 68 | # Progress bar (positioned at bottom with 2% margin) 69 | self.progress_bar = pyglet.shapes.Rectangle( 70 | x=width*0.01, y=5, 71 | width=0, height=10, 72 | color=(0, 255, 0), 73 | batch=self.batch 74 | ) 75 | self.progress_bar_max_width = width*0.98 76 | 77 | def update_frame(self, frame_data: np.ndarray): 78 | """Update the display with new frame data. 
79 | 80 | Args: 81 | frame_data: Numpy array containing frame data (H,W,3) 82 | """ 83 | # Convert tensor if necessary 84 | if isinstance(frame_data, torch.Tensor): 85 | frame_data = frame_data.detach().cpu().numpy() 86 | 87 | # Ensure correct shape and type 88 | if frame_data.shape[0] == 3: # CHW to HWC 89 | frame_data = frame_data.transpose(1, 2, 0) 90 | if frame_data.dtype != np.uint8: 91 | frame_data = (frame_data * 255).astype(np.uint8) 92 | 93 | # Flip vertically and update texture 94 | frame_data = np.ascontiguousarray(np.flipud(frame_data)) 95 | self.texture = pyglet.image.ImageData( 96 | self.window.width, self.window.height, 97 | 'RGB', frame_data.tobytes() 98 | ).get_texture() 99 | 100 | # Update performance metrics 101 | self._update_perf_metrics() 102 | 103 | # Update UI 104 | self.label.text = f'Frame: {self.current_frame+1}/{self.total_frames} | FPS: {self.fps:.2f}' 105 | self.progress_bar.width = self.progress_bar_max_width * (self.current_frame+1)/self.total_frames 106 | self.current_frame += 1 107 | 108 | def _update_perf_metrics(self): 109 | """Calculate and update FPS metrics.""" 110 | current_time = time.time() 111 | self.fps = 1.0 / max(0.001, current_time - self.last_time) # Avoid division by zero 112 | self.last_time = current_time 113 | 114 | def on_draw(self): 115 | """Window draw event handler.""" 116 | self.window.clear() 117 | if self.texture: 118 | self.texture.blit(0, 0, width=self.window.width, height=self.window.height) 119 | self.batch.draw() 120 | 121 | def predict(X, centers): 122 | distances = np.sum((X[:, np.newaxis, :] - centers) ** 2,axis=2) 123 | labels = np.argmin(distances, axis=1) 124 | return labels 125 | 126 | def extract_features(Rt_list, trans=np.array([0.0, 0.0, 0.0]), scale=1.0): 127 | features = [] 128 | for (R, t) in Rt_list: 129 | features.append(generate_features_from_Rt(R, t, trans, scale)) 130 | return np.stack(features, axis=0) 131 | 132 | def render_set(model_path, views, gaussians, pipeline, background, train_test_exp, separate_sh, args): 133 | total_frame = args.frames 134 | load_seele = args.load_seele 135 | use_gui = args.use_gui 136 | 137 | # prepare the views 138 | poses = generate_ellipse_path(views, total_frame) 139 | Rt_list = [(pose[:3, :3].T, pose[:3, 3]) for pose in poses] 140 | w2c_list = [ 141 | torch.tensor(getWorld2View2(Rt_list[frame][0], Rt_list[frame][1], views[0].trans, views[0].scale)).transpose(0, 1).cuda() 142 | for frame in range(total_frame) 143 | ] 144 | 145 | stream_manager, labels = None, None 146 | if load_seele: 147 | # Load cluster data 148 | cluster_data = joblib.load(os.path.join(model_path, "clusters", "clusters.pkl")) 149 | K = len(cluster_data["cluster_viewpoint"]) 150 | cluster_centers = cluster_data["centers"] 151 | 152 | # Determine the test cluster labels 153 | test_features = extract_features(Rt_list, trans=views[0].trans, scale=views[0].scale) 154 | labels = predict(test_features, cluster_centers) 155 | 156 | # Load all Gaussians to CPU 157 | cluster_gaussians = [ 158 | torch.load(os.path.join(model_path, f"clusters/finetune/point_cloud_{cid}.pth"), map_location="cpu") 159 | for cid in range(K) 160 | ] 161 | 162 | # Initialize stream manager 163 | stream_manager = GaussianStreamManager( 164 | cluster_gaussians=cluster_gaussians, 165 | initial_cid=labels[0] 166 | ) 167 | 168 | # Warm up 169 | for _ in range(5): 170 | render(views[0], gaussians, pipeline, background, use_trained_exp=train_test_exp, separate_sh=separate_sh) 171 | 172 | def render_view(frame): 173 | view = views[0] 174 | 
view.world_view_transform = w2c_list[frame]
175 |         view.full_proj_transform = (view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0)
176 |         view.camera_center = view.world_view_transform.inverse()[3, :3]
177 | 
178 |         if load_seele:
179 |             # Preload next frame's Gaussians
180 |             if frame + 1 < total_frame:
181 |                 next_cid = labels[frame + 1]
182 |                 stream_manager.preload_next(next_cid)
183 | 
184 |             # Restore current Gaussians and render
185 |             gaussians.restore_gaussians(stream_manager.get_current())
186 |             rendering = render(
187 |                 view, gaussians, pipeline, background,
188 |                 use_trained_exp=train_test_exp,
189 |                 separate_sh=separate_sh,
190 |                 rasterizer_type="CR"
191 |             )["render"]
192 | 
193 |             # Synchronize streams and switch buffers
194 |             stream_manager.switch_gaussians()
195 |         else:
196 |             # Standard rendering path
197 |             rendering = render(
198 |                 view, gaussians, pipeline, background,
199 |                 use_trained_exp=train_test_exp,
200 |                 separate_sh=separate_sh
201 |             )["render"]
202 | 
203 |         return rendering
204 | 
205 |     if use_gui:
206 |         # Initialize video player
207 |         player = VideoPlayer(width=views[0].image_width, height=views[0].image_height, total_frames=total_frame)
208 | 
209 |         def update_frame(dt):
210 |             """Callback function for frame updates."""
211 |             nonlocal stream_manager, gaussians
212 | 
213 |             if player.current_frame >= args.frames - 1:
214 |                 pyglet.app.exit()
215 |                 return
216 | 
217 |             rendering = render_view(player.current_frame)
218 |             # Update display
219 |             player.update_frame(rendering)
220 | 
221 |         # Start rendering loop (target 500 FPS)
222 |         pyglet.clock.schedule_interval(update_frame, 1/500.0)
223 |         pyglet.app.run()
224 |     else:
225 |         output_dir = args.output_dir
226 |         os.makedirs(output_dir, exist_ok=True)
227 |         for frame_idx in range(total_frame):
228 |             if load_seele:
229 |                 print(f"Rendering image {frame_idx}, belonging to cluster {labels[frame_idx]}")
230 |             else:
231 |                 print(f"Rendering image {frame_idx}")
232 |             rendering = render_view(frame_idx)
233 |             torchvision.utils.save_image(rendering, os.path.join(output_dir, '{0:05d}'.format(frame_idx) + ".png"))
234 | 
235 |     # clean up
236 |     if stream_manager is not None:
237 |         stream_manager.cleanup()
238 | 
239 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, separate_sh: bool, args: ArgumentParser):
240 |     with torch.no_grad():
241 |         gaussians = GaussianModel(dataset.sh_degree)
242 |         scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
243 | 
244 |         bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
245 |         background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
246 | 
247 |         render_set(dataset.model_path, scene.getTestCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh, args)
248 | 
249 | # Example usage
250 | if __name__ == "__main__":
251 |     # Set up command line argument parser
252 |     parser = ArgumentParser(description="Testing script parameters")
253 |     model = ModelParams(parser, sentinel=True)
254 |     pipeline = PipelineParams(parser)
255 |     parser.add_argument("--iteration", default=-1, type=int)
256 |     parser.add_argument("--frames", default=200, type=int)
257 |     parser.add_argument("--quiet", action="store_true")
258 |     parser.add_argument("--load_seele", action="store_true")
259 |     parser.add_argument("--use_gui", action="store_true")
260 |     parser.add_argument('--output_dir', type=str, default="output/videos")
261 |     args = get_combined_args(parser)
262 |     print("Rendering " + args.model_path)
263 |     # Initialize system state (RNG)
264 |     safe_state(args.quiet)
265 | 
266 |     render_sets(model.extract(args), args.iteration, pipeline.extract(args), SPARSE_ADAM_AVAILABLE, args)
267 | 
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | submodules/seele-gaussian-rasterization
2 | submodules/simple-knn
3 | submodules/fused-ssim
4 | 
5 | plyfile
6 | scikit-learn
7 | tqdm
8 | opencv-python
9 | joblib
10 | icecream
11 | pyglet
-------------------------------------------------------------------------------- /scene/__init__.py: --------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | 
12 | import os
13 | import random
14 | import json
15 | from utils.system_utils import searchForMaxIteration
16 | from scene.dataset_readers import sceneLoadTypeCallbacks
17 | from scene.gaussian_model import GaussianModel
18 | from arguments import ModelParams
19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
20 | 
21 | class Scene:
22 | 
23 |     gaussians : GaussianModel
24 | 
25 |     def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
26 |         """
27 |         :param args: model parameters; args.source_path points to the COLMAP (or Blender) scene main folder.
28 |         """
29 |         self.model_path = args.model_path
30 |         self.loaded_iter = None
31 |         self.gaussians = gaussians
32 | 
33 |         if load_iteration:
34 |             if load_iteration == -1:
35 |                 self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
36 |             else:
37 |                 self.loaded_iter = load_iteration
38 |             print("Loading trained model at iteration {}".format(self.loaded_iter))
39 | 
40 |         self.train_cameras = {}
41 |         self.test_cameras = {}
42 | 
43 |         if os.path.exists(os.path.join(args.source_path, "sparse")):
44 |             scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.depths, args.eval, args.train_test_exp)
45 |         elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
46 |             print("Found transforms_train.json file, assuming Blender data set!")
47 |             scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.depths, args.eval)
48 |         else:
49 |             assert False, "Could not recognize scene type!"
50 | 51 | if not self.loaded_iter: 52 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 53 | dest_file.write(src_file.read()) 54 | json_cams = [] 55 | camlist = [] 56 | if scene_info.test_cameras: 57 | camlist.extend(scene_info.test_cameras) 58 | if scene_info.train_cameras: 59 | camlist.extend(scene_info.train_cameras) 60 | for id, cam in enumerate(camlist): 61 | json_cams.append(camera_to_JSON(id, cam)) 62 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 63 | json.dump(json_cams, file) 64 | 65 | if shuffle: 66 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 67 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 68 | 69 | self.cameras_extent = scene_info.nerf_normalization["radius"] 70 | 71 | for resolution_scale in resolution_scales: 72 | print("Loading Training Cameras") 73 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args, scene_info.is_nerf_synthetic, False) 74 | print("Loading Test Cameras") 75 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args, scene_info.is_nerf_synthetic, True) 76 | 77 | if self.loaded_iter: 78 | self.gaussians.load_ply(os.path.join(self.model_path, 79 | "point_cloud", 80 | "iteration_" + str(self.loaded_iter), 81 | "point_cloud.ply"), args.train_test_exp) 82 | else: 83 | self.gaussians.create_from_pcd(scene_info.point_cloud, scene_info.train_cameras, self.cameras_extent) 84 | 85 | def save(self, iteration): 86 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 87 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 88 | exposure_dict = { 89 | image_name: self.gaussians.get_exposure_from_name(image_name).detach().cpu().numpy().tolist() 90 | for image_name in self.gaussians.exposure_mapping 91 | } 92 | 93 | with open(os.path.join(self.model_path, "exposure.json"), "w") as f: 94 | json.dump(exposure_dict, f, indent=2) 95 | 96 | def getTrainCameras(self, scale=1.0): 97 | return self.train_cameras[scale] 98 | 99 | def getTestCameras(self, scale=1.0): 100 | return self.test_cameras[scale] 101 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | from utils.general_utils import PILtoTorch 17 | import cv2 18 | 19 | class Camera(nn.Module): 20 | def __init__(self, resolution, colmap_id, R, T, FoVx, FoVy, depth_params, image, invdepthmap, 21 | image_name, uid, 22 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", 23 | train_test_exp = False, is_test_dataset = False, is_test_view = False 24 | ): 25 | super(Camera, self).__init__() 26 | 27 | self.uid = uid 28 | self.colmap_id = colmap_id 29 | self.R = R 30 | self.T = T 31 | self.FoVx = FoVx 32 | self.FoVy = FoVy 33 | self.image_name = image_name 34 | 35 | try: 36 | self.data_device = torch.device(data_device) 37 | except Exception as e: 38 | print(e) 39 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 40 | self.data_device = torch.device("cuda") 41 | 42 | resized_image_rgb = PILtoTorch(image, resolution) 43 | gt_image = resized_image_rgb[:3, ...] 44 | self.alpha_mask = None 45 | if resized_image_rgb.shape[0] == 4: 46 | self.alpha_mask = resized_image_rgb[3:4, ...].to(self.data_device) 47 | else: 48 | self.alpha_mask = torch.ones_like(resized_image_rgb[0:1, ...].to(self.data_device)) 49 | 50 | if train_test_exp and is_test_view: 51 | if is_test_dataset: 52 | self.alpha_mask[..., :self.alpha_mask.shape[-1] // 2] = 0 53 | else: 54 | self.alpha_mask[..., self.alpha_mask.shape[-1] // 2:] = 0 55 | 56 | self.original_image = gt_image.clamp(0.0, 1.0).to(self.data_device) 57 | self.image_width = self.original_image.shape[2] 58 | self.image_height = self.original_image.shape[1] 59 | 60 | self.invdepthmap = None 61 | self.depth_reliable = False 62 | if invdepthmap is not None: 63 | self.depth_mask = torch.ones_like(self.alpha_mask) 64 | self.invdepthmap = cv2.resize(invdepthmap, resolution) 65 | self.invdepthmap[self.invdepthmap < 0] = 0 66 | self.depth_reliable = True 67 | 68 | if depth_params is not None: 69 | if depth_params["scale"] < 0.2 * depth_params["med_scale"] or depth_params["scale"] > 5 * depth_params["med_scale"]: 70 | self.depth_reliable = False 71 | self.depth_mask *= 0 72 | 73 | if depth_params["scale"] > 0: 74 | self.invdepthmap = self.invdepthmap * depth_params["scale"] + depth_params["offset"] 75 | 76 | if self.invdepthmap.ndim != 2: 77 | self.invdepthmap = self.invdepthmap[..., 0] 78 | self.invdepthmap = torch.from_numpy(self.invdepthmap[None]).to(self.data_device) 79 | 80 | self.zfar = 100.0 81 | self.znear = 0.01 82 | 83 | self.trans = trans 84 | self.scale = scale 85 | 86 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 87 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 88 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 89 | self.camera_center = self.world_view_transform.inverse()[3, :3] 90 | 91 | class MiniCam: 92 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 93 | self.image_width = width 94 | self.image_height = height 95 | self.FoVy = fovy 96 | self.FoVx = fovx 97 | self.znear = znear 98 | self.zfar = zfar 99 | self.world_view_transform = world_view_transform 100 | self.full_proj_transform = full_proj_transform 101 | view_inv = 
torch.inverse(self.world_view_transform) 102 | self.camera_center = view_inv[3][:3] 103 | 104 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 54 | 55 | def rotmat2qvec(R): 56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 57 | K = np.array([ 58 | [Rxx - Ryy - Rzz, 0, 0, 0], 59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 62 | eigvals, eigvecs = np.linalg.eigh(K) 63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 64 | if qvec[0] < 0: 65 | qvec *= -1 66 | return qvec 67 | 68 | class Image(BaseImage): 69 | def qvec2rotmat(self): 70 | return qvec2rotmat(self.qvec) 71 | 72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 73 | """Read and unpack the next bytes from a binary file. 74 | :param fid: 75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 
77 | :param endian_character: Any of {@, =, <, >, !} 78 | :return: Tuple of read and unpacked values. 79 | """ 80 | data = fid.read(num_bytes) 81 | return struct.unpack(endian_character + format_char_sequence, data) 82 | 83 | def read_points3D_text(path): 84 | """ 85 | see: src/base/reconstruction.cc 86 | void Reconstruction::ReadPoints3DText(const std::string& path) 87 | void Reconstruction::WritePoints3DText(const std::string& path) 88 | """ 89 | xyzs = None 90 | rgbs = None 91 | errors = None 92 | num_points = 0 93 | with open(path, "r") as fid: 94 | while True: 95 | line = fid.readline() 96 | if not line: 97 | break 98 | line = line.strip() 99 | if len(line) > 0 and line[0] != "#": 100 | num_points += 1 101 | 102 | 103 | xyzs = np.empty((num_points, 3)) 104 | rgbs = np.empty((num_points, 3)) 105 | errors = np.empty((num_points, 1)) 106 | count = 0 107 | with open(path, "r") as fid: 108 | while True: 109 | line = fid.readline() 110 | if not line: 111 | break 112 | line = line.strip() 113 | if len(line) > 0 and line[0] != "#": 114 | elems = line.split() 115 | xyz = np.array(tuple(map(float, elems[1:4]))) 116 | rgb = np.array(tuple(map(int, elems[4:7]))) 117 | error = np.array(float(elems[7])) 118 | xyzs[count] = xyz 119 | rgbs[count] = rgb 120 | errors[count] = error 121 | count += 1 122 | 123 | return xyzs, rgbs, errors 124 | 125 | def read_points3D_binary(path_to_model_file): 126 | """ 127 | see: src/base/reconstruction.cc 128 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 129 | void Reconstruction::WritePoints3DBinary(const std::string& path) 130 | """ 131 | 132 | 133 | with open(path_to_model_file, "rb") as fid: 134 | num_points = read_next_bytes(fid, 8, "Q")[0] 135 | 136 | xyzs = np.empty((num_points, 3)) 137 | rgbs = np.empty((num_points, 3)) 138 | errors = np.empty((num_points, 1)) 139 | 140 | for p_id in range(num_points): 141 | binary_point_line_properties = read_next_bytes( 142 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 143 | xyz = np.array(binary_point_line_properties[1:4]) 144 | rgb = np.array(binary_point_line_properties[4:7]) 145 | error = np.array(binary_point_line_properties[7]) 146 | track_length = read_next_bytes( 147 | fid, num_bytes=8, format_char_sequence="Q")[0] 148 | track_elems = read_next_bytes( 149 | fid, num_bytes=8*track_length, 150 | format_char_sequence="ii"*track_length) 151 | xyzs[p_id] = xyz 152 | rgbs[p_id] = rgb 153 | errors[p_id] = error 154 | return xyzs, rgbs, errors 155 | 156 | def read_intrinsics_text(path): 157 | """ 158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 159 | """ 160 | cameras = {} 161 | with open(path, "r") as fid: 162 | while True: 163 | line = fid.readline() 164 | if not line: 165 | break 166 | line = line.strip() 167 | if len(line) > 0 and line[0] != "#": 168 | elems = line.split() 169 | camera_id = int(elems[0]) 170 | model = elems[1] 171 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 172 | width = int(elems[2]) 173 | height = int(elems[3]) 174 | params = np.array(tuple(map(float, elems[4:]))) 175 | cameras[camera_id] = Camera(id=camera_id, model=model, 176 | width=width, height=height, 177 | params=params) 178 | return cameras 179 | 180 | def read_extrinsics_binary(path_to_model_file): 181 | """ 182 | see: src/base/reconstruction.cc 183 | void Reconstruction::ReadImagesBinary(const std::string& path) 184 | void Reconstruction::WriteImagesBinary(const std::string& path) 185 | """ 186 | images = {} 
187 | with open(path_to_model_file, "rb") as fid: 188 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 189 | for _ in range(num_reg_images): 190 | binary_image_properties = read_next_bytes( 191 | fid, num_bytes=64, format_char_sequence="idddddddi") 192 | image_id = binary_image_properties[0] 193 | qvec = np.array(binary_image_properties[1:5]) 194 | tvec = np.array(binary_image_properties[5:8]) 195 | camera_id = binary_image_properties[8] 196 | image_name = "" 197 | current_char = read_next_bytes(fid, 1, "c")[0] 198 | while current_char != b"\x00": # look for the ASCII 0 entry 199 | image_name += current_char.decode("utf-8") 200 | current_char = read_next_bytes(fid, 1, "c")[0] 201 | num_points2D = read_next_bytes(fid, num_bytes=8, 202 | format_char_sequence="Q")[0] 203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 204 | format_char_sequence="ddq"*num_points2D) 205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 206 | tuple(map(float, x_y_id_s[1::3]))]) 207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 208 | images[image_id] = Image( 209 | id=image_id, qvec=qvec, tvec=tvec, 210 | camera_id=camera_id, name=image_name, 211 | xys=xys, point3D_ids=point3D_ids) 212 | return images 213 | 214 | 215 | def read_intrinsics_binary(path_to_model_file): 216 | """ 217 | see: src/base/reconstruction.cc 218 | void Reconstruction::WriteCamerasBinary(const std::string& path) 219 | void Reconstruction::ReadCamerasBinary(const std::string& path) 220 | """ 221 | cameras = {} 222 | with open(path_to_model_file, "rb") as fid: 223 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 224 | for _ in range(num_cameras): 225 | camera_properties = read_next_bytes( 226 | fid, num_bytes=24, format_char_sequence="iiQQ") 227 | camera_id = camera_properties[0] 228 | model_id = camera_properties[1] 229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 230 | width = camera_properties[2] 231 | height = camera_properties[3] 232 | num_params = CAMERA_MODEL_IDS[model_id].num_params 233 | params = read_next_bytes(fid, num_bytes=8*num_params, 234 | format_char_sequence="d"*num_params) 235 | cameras[camera_id] = Camera(id=camera_id, 236 | model=model_name, 237 | width=width, 238 | height=height, 239 | params=np.array(params)) 240 | assert len(cameras) == num_cameras 241 | return cameras 242 | 243 | 244 | def read_extrinsics_text(path): 245 | """ 246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 247 | """ 248 | images = {} 249 | with open(path, "r") as fid: 250 | while True: 251 | line = fid.readline() 252 | if not line: 253 | break 254 | line = line.strip() 255 | if len(line) > 0 and line[0] != "#": 256 | elems = line.split() 257 | image_id = int(elems[0]) 258 | qvec = np.array(tuple(map(float, elems[1:5]))) 259 | tvec = np.array(tuple(map(float, elems[5:8]))) 260 | camera_id = int(elems[8]) 261 | image_name = elems[9] 262 | elems = fid.readline().split() 263 | xys = np.column_stack([tuple(map(float, elems[0::3])), 264 | tuple(map(float, elems[1::3]))]) 265 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 266 | images[image_id] = Image( 267 | id=image_id, qvec=qvec, tvec=tvec, 268 | camera_id=camera_id, name=image_name, 269 | xys=xys, point3D_ids=point3D_ids) 270 | return images 271 | 272 | 273 | def read_colmap_bin_array(path): 274 | """ 275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 276 | 277 | :param path: path to the colmap binary file. 
278 | :return: nd array with the floating point values in the value 279 | """ 280 | with open(path, "rb") as fid: 281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 282 | usecols=(0, 1, 2), dtype=int) 283 | fid.seek(0) 284 | num_delimiter = 0 285 | byte = fid.read(1) 286 | while True: 287 | if byte == b"&": 288 | num_delimiter += 1 289 | if num_delimiter >= 3: 290 | break 291 | byte = fid.read(1) 292 | array = np.fromfile(fid, np.float32) 293 | array = array.reshape((width, height, channels), order="F") 294 | return np.transpose(array, (1, 0, 2)).squeeze() 295 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple 16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ 17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 19 | import numpy as np 20 | import json 21 | from pathlib import Path 22 | from plyfile import PlyData, PlyElement 23 | from utils.sh_utils import SH2RGB 24 | from scene.gaussian_model import BasicPointCloud 25 | 26 | class CameraInfo(NamedTuple): 27 | uid: int 28 | R: np.array 29 | T: np.array 30 | FovY: np.array 31 | FovX: np.array 32 | depth_params: dict 33 | image_path: str 34 | image_name: str 35 | depth_path: str 36 | width: int 37 | height: int 38 | is_test: bool 39 | 40 | class SceneInfo(NamedTuple): 41 | point_cloud: BasicPointCloud 42 | train_cameras: list 43 | test_cameras: list 44 | nerf_normalization: dict 45 | ply_path: str 46 | is_nerf_synthetic: bool 47 | 48 | def getNerfppNorm(cam_info): 49 | def get_center_and_diag(cam_centers): 50 | cam_centers = np.hstack(cam_centers) 51 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 52 | center = avg_cam_center 53 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 54 | diagonal = np.max(dist) 55 | return center.flatten(), diagonal 56 | 57 | cam_centers = [] 58 | 59 | for cam in cam_info: 60 | W2C = getWorld2View2(cam.R, cam.T) 61 | C2W = np.linalg.inv(W2C) 62 | cam_centers.append(C2W[:3, 3:4]) 63 | 64 | center, diagonal = get_center_and_diag(cam_centers) 65 | radius = diagonal * 1.1 66 | 67 | translate = -center 68 | 69 | return {"translate": translate, "radius": radius} 70 | 71 | def readColmapCameras(cam_extrinsics, cam_intrinsics, depths_params, images_folder, depths_folder, test_cam_names_list): 72 | cam_infos = [] 73 | for idx, key in enumerate(cam_extrinsics): 74 | sys.stdout.write('\r') 75 | # the exact output you're looking for: 76 | sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) 77 | sys.stdout.flush() 78 | 79 | extr = cam_extrinsics[key] 80 | intr = cam_intrinsics[extr.camera_id] 81 | height = intr.height 82 | width = intr.width 83 | 84 | uid = intr.id 85 | R = np.transpose(qvec2rotmat(extr.qvec)) 86 | T = np.array(extr.tvec) 87 | 88 | if intr.model=="SIMPLE_PINHOLE": 89 | focal_length_x = intr.params[0] 90 | FovY = 
focal2fov(focal_length_x, height) 91 | FovX = focal2fov(focal_length_x, width) 92 | elif intr.model=="PINHOLE": 93 | focal_length_x = intr.params[0] 94 | focal_length_y = intr.params[1] 95 | FovY = focal2fov(focal_length_y, height) 96 | FovX = focal2fov(focal_length_x, width) 97 | else: 98 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 99 | 100 | n_remove = len(extr.name.split('.')[-1]) + 1 101 | depth_params = None 102 | if depths_params is not None: 103 | try: 104 | depth_params = depths_params[extr.name[:-n_remove]] 105 | except: 106 | print("\n", key, "not found in depths_params") 107 | 108 | image_path = os.path.join(images_folder, extr.name) 109 | image_name = extr.name 110 | depth_path = os.path.join(depths_folder, f"{extr.name[:-n_remove]}.png") if depths_folder != "" else "" 111 | 112 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, depth_params=depth_params, 113 | image_path=image_path, image_name=image_name, depth_path=depth_path, 114 | width=width, height=height, is_test=image_name in test_cam_names_list) 115 | cam_infos.append(cam_info) 116 | 117 | sys.stdout.write('\n') 118 | return cam_infos 119 | 120 | def fetchPly(path): 121 | plydata = PlyData.read(path) 122 | vertices = plydata['vertex'] 123 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 124 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 125 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 126 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 127 | 128 | def storePly(path, xyz, rgb): 129 | # Define the dtype for the structured array 130 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 131 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 132 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 133 | 134 | normals = np.zeros_like(xyz) 135 | 136 | elements = np.empty(xyz.shape[0], dtype=dtype) 137 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 138 | elements[:] = list(map(tuple, attributes)) 139 | 140 | # Create the PlyData object and write to file 141 | vertex_element = PlyElement.describe(elements, 'vertex') 142 | ply_data = PlyData([vertex_element]) 143 | ply_data.write(path) 144 | 145 | def readColmapSceneInfo(path, images, depths, eval, train_test_exp, llffhold=8): 146 | try: 147 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 148 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 149 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 150 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 151 | except: 152 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 153 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 154 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 155 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 156 | 157 | depth_params_file = os.path.join(path, "sparse/0", "depth_params.json") 158 | ## if depth_params_file isnt there AND depths file is here -> throw error 159 | depths_params = None 160 | if depths != "": 161 | try: 162 | with open(depth_params_file, "r") as f: 163 | depths_params = json.load(f) 164 | all_scales = np.array([depths_params[key]["scale"] for key in depths_params]) 165 | if (all_scales > 0).sum(): 166 | med_scale = np.median(all_scales[all_scales > 0]) 167 | else: 168 | med_scale = 0 169 | for key in depths_params: 170 | 
depths_params[key]["med_scale"] = med_scale 171 | 172 | except FileNotFoundError: 173 | print(f"Error: depth_params.json file not found at path '{depth_params_file}'.") 174 | sys.exit(1) 175 | except Exception as e: 176 | print(f"An unexpected error occurred when trying to open depth_params.json file: {e}") 177 | sys.exit(1) 178 | 179 | if eval: 180 | if "360" in path: 181 | llffhold = 8 182 | if llffhold: 183 | print("------------LLFF HOLD-------------") 184 | cam_names = [cam_extrinsics[cam_id].name for cam_id in cam_extrinsics] 185 | cam_names = sorted(cam_names) 186 | test_cam_names_list = [name for idx, name in enumerate(cam_names) if idx % llffhold == 0] 187 | else: 188 | with open(os.path.join(path, "sparse/0", "test.txt"), 'r') as file: 189 | test_cam_names_list = [line.strip() for line in file] 190 | else: 191 | test_cam_names_list = [] 192 | 193 | reading_dir = "images" if images == None else images 194 | cam_infos_unsorted = readColmapCameras( 195 | cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, depths_params=depths_params, 196 | images_folder=os.path.join(path, reading_dir), 197 | depths_folder=os.path.join(path, depths) if depths != "" else "", test_cam_names_list=test_cam_names_list) 198 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) 199 | 200 | train_cam_infos = [c for c in cam_infos if train_test_exp or not c.is_test] 201 | test_cam_infos = [c for c in cam_infos if c.is_test] 202 | 203 | nerf_normalization = getNerfppNorm(train_cam_infos) 204 | 205 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 206 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 207 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 208 | if not os.path.exists(ply_path): 209 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") 210 | try: 211 | xyz, rgb, _ = read_points3D_binary(bin_path) 212 | except: 213 | xyz, rgb, _ = read_points3D_text(txt_path) 214 | storePly(ply_path, xyz, rgb) 215 | try: 216 | pcd = fetchPly(ply_path) 217 | except: 218 | pcd = None 219 | 220 | scene_info = SceneInfo(point_cloud=pcd, 221 | train_cameras=train_cam_infos, 222 | test_cameras=test_cam_infos, 223 | nerf_normalization=nerf_normalization, 224 | ply_path=ply_path, 225 | is_nerf_synthetic=False) 226 | return scene_info 227 | 228 | def readCamerasFromTransforms(path, transformsfile, depths_folder, white_background, is_test, extension=".png"): 229 | cam_infos = [] 230 | 231 | with open(os.path.join(path, transformsfile)) as json_file: 232 | contents = json.load(json_file) 233 | fovx = contents["camera_angle_x"] 234 | 235 | frames = contents["frames"] 236 | for idx, frame in enumerate(frames): 237 | cam_name = os.path.join(path, frame["file_path"] + extension) 238 | 239 | # NeRF 'transform_matrix' is a camera-to-world transform 240 | c2w = np.array(frame["transform_matrix"]) 241 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) 242 | c2w[:3, 1:3] *= -1 243 | 244 | # get the world-to-camera transform and set R, T 245 | w2c = np.linalg.inv(c2w) 246 | R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code 247 | T = w2c[:3, 3] 248 | 249 | image_path = os.path.join(path, cam_name) 250 | image_name = Path(cam_name).stem 251 | image = Image.open(image_path) 252 | 253 | im_data = np.array(image.convert("RGBA")) 254 | 255 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) 256 | 257 | norm_data = im_data / 255.0 258 | arr = norm_data[:,:,:3] * 
norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 259 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") 260 | 261 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 262 | FovY = fovy 263 | FovX = fovx 264 | 265 | depth_path = os.path.join(depths_folder, f"{image_name}.png") if depths_folder != "" else "" 266 | 267 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, 268 | image_path=image_path, image_name=image_name, 269 | width=image.size[0], height=image.size[1], depth_path=depth_path, depth_params=None, is_test=is_test)) 270 | 271 | return cam_infos 272 | 273 | def readNerfSyntheticInfo(path, white_background, depths, eval, extension=".png"): 274 | 275 | depths_folder=os.path.join(path, depths) if depths != "" else "" 276 | print("Reading Training Transforms") 277 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", depths_folder, white_background, False, extension) 278 | print("Reading Test Transforms") 279 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", depths_folder, white_background, True, extension) 280 | 281 | if not eval: 282 | train_cam_infos.extend(test_cam_infos) 283 | test_cam_infos = [] 284 | 285 | nerf_normalization = getNerfppNorm(train_cam_infos) 286 | 287 | ply_path = os.path.join(path, "points3d.ply") 288 | if not os.path.exists(ply_path): 289 | # Since this data set has no colmap data, we start with random points 290 | num_pts = 100_000 291 | print(f"Generating random point cloud ({num_pts})...") 292 | 293 | # We create random points inside the bounds of the synthetic Blender scenes 294 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 295 | shs = np.random.random((num_pts, 3)) / 255.0 296 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 297 | 298 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 299 | try: 300 | pcd = fetchPly(ply_path) 301 | except: 302 | pcd = None 303 | 304 | scene_info = SceneInfo(point_cloud=pcd, 305 | train_cameras=train_cam_infos, 306 | test_cameras=test_cam_infos, 307 | nerf_normalization=nerf_normalization, 308 | ply_path=ply_path, 309 | is_nerf_synthetic=True) 310 | return scene_info 311 | 312 | sceneLoadTypeCallbacks = { 313 | "Colmap": readColmapSceneInfo, 314 | "Blender" : readNerfSyntheticInfo 315 | } -------------------------------------------------------------------------------- /scripts/generate_cluster.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Function to get an available GPU with memory usage below the threshold 3 | get_available_gpu() { 4 | local mem_threshold=25000 5 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | \ 6 | awk -v threshold="$mem_threshold" -F', ' ' 7 | $2 < threshold { print $1; exit } 8 | ' 9 | } 10 | 11 | # List of dataset names 12 | # datasets=("bicycle" "bonsai" "counter" "flowers" "garden" "kitchen" "room" "stump" "treehill" "train" "truck" "playroom" "drjohnson") 13 | datasets=("counter") # Replace with your actual dataset names 14 | 15 | # Path to models 16 | model_base_path="output/seele" # PATH TO YOUR MODELS 17 | 18 | # Iterate over each dataset 19 | for dataset_name in "${datasets[@]}"; do 20 | echo "Processing dataset: $dataset_name" 21 | 22 | # Find an available GPU 23 | while true; do 24 | available_gpu=$(get_available_gpu) 25 | if [ -z "$available_gpu" ]; then 26 | echo "No GPU available with memory usage below threshold. Waiting..." 
27 | sleep 60 28 | continue 29 | fi 30 | 31 | echo "Using GPU: $available_gpu" 32 | # Run the Python script with the selected GPU 33 | CUDA_VISIBLE_DEVICES="$available_gpu" python generate_cluster.py -m "$model_base_path/$dataset_name" 34 | break 35 | done 36 | done 37 | 38 | # Completion signal 39 | echo "All datasets processed. Task complete." -------------------------------------------------------------------------------- /scripts/run_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | dataset_base_path="dataset/seele" # PATH TO YOUR DATASET 4 | output_base_path="output/seele" # PATH TO YOUR OUTPUT 5 | 6 | datasets=("counter") # Replace with your actual dataset names 7 | # datasets=("bicycle" "bonsai" "counter" "flowers" "garden" "kitchen" "room" "stump" "treehill" "train" "truck" "playroom" "drjohnson") 8 | 9 | for dataset in "${datasets[@]}"; do 10 | model_path="$output_base_path/$dataset" 11 | dataset_path="$dataset_base_path/$dataset" 12 | 13 | echo "Train dataset: $dataset" 14 | python3 train.py -m $model_path -s $dataset_path --eval 15 | 16 | echo "Generate clusters for dataset: $dataset" 17 | if [[ "$dataset" == "playroom" || "$dataset" == "drjohnson" ]]; then 18 | python3 generate_cluster.py -m $model_path -n 8 19 | else 20 | python3 generate_cluster.py -m $model_path -n 4 21 | fi 22 | 23 | echo "Finetune dataset: $dataset" 24 | python3 finetune.py \ 25 | -s $dataset_path \ 26 | -m $model_path \ 27 | --start_checkpoint "$model_path/chkpnt30000.pth" \ 28 | --eval \ 29 | --iterations 31_000 30 | 31 | echo "Render dataset: $dataset" 32 | python3 seele_render.py -m $model_path -s $dataset_path --eval --load_finetune --save_image --debug 33 | 34 | echo "Metrics for dataset: $dataset" 35 | python3 metrics.py -m $model_path 36 | done 37 | -------------------------------------------------------------------------------- /scripts/run_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Function to get an available GPU with memory usage below the threshold 3 | get_available_gpu() { 4 | local mem_threshold=25000 5 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | \ 6 | awk -v threshold="$mem_threshold" -F', ' ' 7 | $2 < threshold { print $1; exit } 8 | ' 9 | } 10 | 11 | # List of dataset names 12 | # datasets=("bicycle" "bonsai" "counter" "flowers" "garden" "kitchen" "room" "stump" "treehill" "train" "truck" "playroom" "drjohnson") 13 | datasets=("counter") # Replace with your actual dataset names 14 | 15 | # Path to models 16 | model_base_path="output/seele" # PATH TO YOUR MODELS 17 | 18 | dataset_base_path="dataset/seele" # PATH TO YOUR DATASET 19 | port=6035 20 | 21 | # Iterate over each dataset 22 | for dataset_name in "${datasets[@]}"; do 23 | echo "Processing dataset: $dataset_name" 24 | 25 | # Find an available GPU 26 | while true; do 27 | available_gpu=$(get_available_gpu) 28 | if [ -z "$available_gpu" ]; then 29 | echo "No GPU available with memory usage below threshold. Waiting..." 
30 | sleep 60 31 | continue 32 | fi 33 | 34 | echo "Using GPU: $available_gpu" 35 | # Run the Python script with the selected GPU 36 | CUDA_VISIBLE_DEVICES="$available_gpu" python finetune.py \ 37 | -s "$dataset_base_path/$dataset_name" \ 38 | -m "$model_base_path/$dataset_name" \ 39 | --start_checkpoint "$model_base_path/$dataset_name/chkpnt30000.pth" \ 40 | --eval \ 41 | --iterations 31_000 42 | break 43 | done 44 | done 45 | 46 | # Wait for all background processes to finish 47 | wait 48 | 49 | # Completion signal 50 | echo "All datasets processed. Task complete." -------------------------------------------------------------------------------- /scripts/run_render.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Function to get an available GPU with memory usage below the threshold 3 | get_available_gpu() { 4 | local mem_threshold=25000 5 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | \ 6 | awk -v threshold="$mem_threshold" -F', ' ' 7 | $2 < threshold { print $1; exit } 8 | ' 9 | } 10 | 11 | # List of dataset names 12 | # datasets=("bicycle" "bonsai" "counter" "flowers" "garden" "kitchen" "room" "stump" "treehill" "train" "truck" "playroom" "drjohnson") 13 | datasets=("counter") # Replace with your actual dataset names 14 | 15 | # Path to models 16 | model_base_path="output/seele" # PATH TO YOUR MODELS 17 | 18 | dataset_base_path="dataset/seele" # PATH TO YOUR DATASET 19 | 20 | # Iterate over each dataset 21 | for dataset_name in "${datasets[@]}"; do 22 | echo "Processing dataset: $dataset_name" 23 | 24 | # Find an available GPU 25 | while true; do 26 | available_gpu=$(get_available_gpu) 27 | if [ -z "$available_gpu" ]; then 28 | echo "No GPU available with memory usage below threshold. Waiting..." 29 | sleep 60 30 | continue 31 | fi 32 | 33 | echo "Using GPU: $available_gpu" 34 | # Run the render.py script with the selected GPU 35 | CUDA_VISIBLE_DEVICES="$available_gpu" python render.py -m "$model_base_path/$dataset_name" -s "$dataset_base_path/$dataset_name" --skip_train 36 | # Run the metrics.py script and append the output to the same log file 37 | python3 metrics.py -m "$model_base_path/$dataset_name" 38 | break 39 | done 40 | done 41 | 42 | # Completion signal 43 | echo "All datasets processed. Task complete." 
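Note: the get_available_gpu helper above is duplicated verbatim in generate_cluster.sh, run_finetune.sh, run_render.sh and run_seele_render.sh, each with the same hard-coded 25000 MiB threshold. A minimal factoring sketch, assuming a new, hypothetical scripts/gpu_utils.sh that is not part of this repository:

    # scripts/gpu_utils.sh (hypothetical shared helper, sourced by the batch scripts)
    get_available_gpu() {
        local mem_threshold=${1:-25000}   # MiB; callers may override the default
        nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | \
        awk -v threshold="$mem_threshold" -F', ' '$2 < threshold { print $1; exit }'
    }

    # In each batch script, the inlined definition could then be replaced by:
    #   source "$(dirname "$0")/gpu_utils.sh"
    #   available_gpu=$(get_available_gpu 25000)

This keeps the memory threshold in one place; the scripts' polling loop and behavior are otherwise unchanged.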
-------------------------------------------------------------------------------- /scripts/run_seele_render.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Function to get an available GPU with memory usage below the threshold 3 | get_available_gpu() { 4 | local mem_threshold=25000 5 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | \ 6 | awk -v threshold="$mem_threshold" -F', ' ' 7 | $2 < threshold { print $1; exit } 8 | ' 9 | } 10 | 11 | # List of dataset names 12 | # datasets=("bicycle" "bonsai" "counter" "flowers" "garden" "kitchen" "room" "stump" "treehill" "train" "truck" "playroom" "drjohnson") 13 | datasets=("counter") # Replace with your actual dataset names 14 | 15 | # Path to models 16 | model_base_path="output/seele" # PATH TO YOUR MODELS 17 | 18 | dataset_base_path="dataset/seele" # PATH TO YOUR DATASET 19 | 20 | # Setting for load_finetune 21 | load_finetune=true # Set to true or false based on your requirement 22 | 23 | # Iterate over each dataset 24 | for dataset_name in "${datasets[@]}"; do 25 | echo "Processing dataset: $dataset_name" 26 | 27 | # Find an available GPU 28 | while true; do 29 | available_gpu=$(get_available_gpu) 30 | if [ -z "$available_gpu" ]; then 31 | echo "No GPU available with memory usage below threshold. Waiting..." 32 | sleep 60 33 | continue 34 | fi 35 | 36 | echo "Using GPU: $available_gpu" 37 | echo "load_finetune: $load_finetune" 38 | if [ "$load_finetune" = true ]; then 39 | CUDA_VISIBLE_DEVICES="$available_gpu" python3 seele_render.py -m "$model_base_path/$dataset_name" -s "$dataset_base_path/$dataset_name" --skip_train --load_finetune --save_image 40 | else 41 | CUDA_VISIBLE_DEVICES="$available_gpu" python3 seele_render.py -m "$model_base_path/$dataset_name" -s "$dataset_base_path/$dataset_name" --skip_train --save_image 42 | fi 43 | 44 | # Run the metrics.py script and append the output to the same log file 45 | python3 metrics.py -m "$model_base_path/$dataset_name" 46 | break 47 | done 48 | done 49 | 50 | # Completion signal 51 | echo "All datasets processed. Task complete." -------------------------------------------------------------------------------- /scripts/run_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # List of dataset names 4 | # datasets=("bicycle" "bonsai" "counter" "flowers" "garden" "kitchen" "room" "stump" "treehill" "train" "truck" "playroom" "drjohnson") 5 | datasets=("counter") # Replace with your actual dataset names 6 | 7 | # Path to models 8 | model_base_path="output/seele" # PATH TO YOUR MODELS 9 | 10 | dataset_base_path="dataset/seele" # PATH TO YOUR DATASET 11 | 12 | # Iterate over each dataset 13 | for dataset_name in "${datasets[@]}"; do 14 | echo "Processing dataset: $dataset_name" 15 | python3 train.py -m "$model_base_path/$dataset_name" -s "$dataset_base_path/$dataset_name" --eval 16 | echo "Test:" 17 | python3 render.py -m "$model_base_path/$dataset_name" -s "$dataset_base_path/$dataset_name" --skip_train --eval 18 | python3 metrics.py -m "$model_base_path/$dataset_name" 19 | done 20 | 21 | # Completion signal 22 | echo "All datasets processed. Task complete." 
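The seele_render.py listing that follows loads per-cluster data from <model_path>/clusters/clusters.pkl (written by generate_cluster.py) and, when --load_finetune is set, the per-cluster checkpoints clusters/finetune/point_cloud_<cid>.pth produced by finetune.py. Below is a minimal inspection sketch; the key names are taken only from what seele_render.py itself reads ("cluster_viewpoint", "cluster_gaussians", "train_labels"/"test_labels"), and the script name inspect_clusters.py is hypothetical, not part of the repository.

    # inspect_clusters.py (illustrative sketch only; clusters.pkl is produced by
    # generate_cluster.py, so treat the keys below as observed from seele_render.py
    # rather than as a documented schema)
    import os
    import joblib

    model_path = "output/seele/counter"  # path to a trained model directory
    data = joblib.load(os.path.join(model_path, "clusters", "clusters.pkl"))

    num_clusters = len(data["cluster_viewpoint"])
    print("number of clusters:", num_clusters)

    # Each entry pairs the Gaussian ids assigned to a cluster with their lengths.
    for cid, (gaussian_ids, lens) in enumerate(data["cluster_gaussians"]):
        print(f"cluster {cid}: {len(gaussian_ids)} Gaussian ids")

    # Per-view cluster assignment used by render_set (one label per camera).
    print("first test-view labels:", data["test_labels"][:10])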
-------------------------------------------------------------------------------- /seele_render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import numpy as np 12 | import joblib 13 | import torch 14 | from scene import Scene 15 | import os 16 | from tqdm import tqdm 17 | from os import makedirs 18 | from gaussian_renderer import render 19 | import torchvision 20 | from utils.general_utils import safe_state 21 | from argparse import ArgumentParser 22 | from arguments import ModelParams, PipelineParams, get_combined_args 23 | # from gaussian_renderer import GaussianModel 24 | from gaussian_renderer import GaussianModel 25 | try: 26 | from diff_gaussian_rasterization import SparseGaussianAdam 27 | SPARSE_ADAM_AVAILABLE = True 28 | except: 29 | SPARSE_ADAM_AVAILABLE = False 30 | 31 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background, train_test_exp, separate_sh, args): 32 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 33 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 34 | 35 | cluster_data = joblib.load(os.path.join(model_path, "clusters", "clusters.pkl")) 36 | K = len(cluster_data["cluster_viewpoint"]) 37 | 38 | if args.load_finetune: 39 | cluster_gaussians = [torch.load(os.path.join(model_path, f"clusters/finetune/point_cloud_{cid}.pth")) for cid in range(K)] 40 | cluster_gaussians = [tuple(map(lambda x: x.cuda(), data)) for data in cluster_gaussians] 41 | else: 42 | global_gaussians = gaussians.capture_gaussians() 43 | cluster_gaussian_ids = [] 44 | for (gaussian_ids, lens) in cluster_data["cluster_gaussians"]: 45 | gaussian_ids = torch.tensor(gaussian_ids).cuda() 46 | cluster_gaussian_ids.append((gaussian_ids, lens)) 47 | labels = cluster_data[f"{name}_labels"] 48 | 49 | makedirs(render_path, exist_ok=True) 50 | makedirs(gts_path, exist_ok=True) 51 | 52 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 53 | if args.load_finetune: 54 | gaussians.restore_gaussians(cluster_gaussians[labels[idx]]) 55 | else: 56 | gaussians.restore_gaussians(global_gaussians, cluster_gaussian_ids[labels[idx]]) 57 | rendering = render(view, gaussians, pipeline, background, use_trained_exp=train_test_exp, separate_sh=separate_sh, rasterizer_type="CR")["render"] 58 | gt = view.original_image[0:3, :, :] 59 | 60 | if args.train_test_exp: 61 | rendering = rendering[..., rendering.shape[-1] // 2:] 62 | gt = gt[..., gt.shape[-1] // 2:] 63 | 64 | if args.save_image: 65 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 66 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 67 | 68 | if not args.load_finetune: 69 | gaussians.restore_gaussians(global_gaussians) 70 | 71 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, separate_sh: bool, args: ArgumentParser): 72 | with torch.no_grad(): 73 | gaussians = GaussianModel(dataset.sh_degree) 74 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 75 | 76 | bg_color = [1,1,1] if dataset.white_background else [0, 
0, 0] 77 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 78 | 79 | if not skip_train: 80 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh, args) 81 | 82 | if not skip_test: 83 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh, args) 84 | 85 | if __name__ == "__main__": 86 | # Set up command line argument parser 87 | parser = ArgumentParser(description="Testing script parameters") 88 | model = ModelParams(parser, sentinel=True) 89 | pipeline = PipelineParams(parser) 90 | parser.add_argument("--iteration", default=-1, type=int) 91 | parser.add_argument("--skip_train", action="store_true") 92 | parser.add_argument("--skip_test", action="store_true") 93 | parser.add_argument("--quiet", action="store_true") 94 | parser.add_argument("--load_finetune", action="store_true") 95 | parser.add_argument("--save_image", action="store_true") 96 | args = get_combined_args(parser) 97 | args.depths = "" 98 | args.train_test_exp = False 99 | print("Rendering " + args.model_path) 100 | # Initialize system state (RNG) 101 | safe_state(args.quiet) 102 | 103 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, SPARSE_ADAM_AVAILABLE, args) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from utils.loss_utils import l1_loss, ssim 16 | from gaussian_renderer import render, network_gui 17 | import sys 18 | from scene import Scene, GaussianModel 19 | from utils.general_utils import safe_state, get_expon_lr_func 20 | import uuid 21 | from tqdm import tqdm 22 | from utils.image_utils import psnr 23 | from argparse import ArgumentParser, Namespace 24 | from arguments import ModelParams, PipelineParams, OptimizationParams 25 | try: 26 | from torch.utils.tensorboard import SummaryWriter 27 | TENSORBOARD_FOUND = True 28 | except ImportError: 29 | TENSORBOARD_FOUND = False 30 | 31 | try: 32 | from fused_ssim import fused_ssim 33 | FUSED_SSIM_AVAILABLE = True 34 | except: 35 | FUSED_SSIM_AVAILABLE = False 36 | 37 | try: 38 | from diff_gaussian_rasterization import SparseGaussianAdam 39 | SPARSE_ADAM_AVAILABLE = True 40 | except: 41 | SPARSE_ADAM_AVAILABLE = False 42 | 43 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 44 | 45 | if not SPARSE_ADAM_AVAILABLE and opt.optimizer_type == "sparse_adam": 46 | sys.exit(f"Trying to use sparse adam but it is not installed, please install the correct rasterizer using pip install [3dgs_accel].") 47 | 48 | first_iter = 0 49 | tb_writer = prepare_output_and_logger(dataset) 50 | gaussians = GaussianModel(dataset.sh_degree, opt.optimizer_type) 51 | scene = Scene(dataset, gaussians) 52 | gaussians.training_setup(opt) 53 | if checkpoint: 54 | (model_params, first_iter) = torch.load(checkpoint) 55 | gaussians.restore(model_params, opt) 56 | 57 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 58 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 59 | 60 | iter_start = torch.cuda.Event(enable_timing = True) 61 | iter_end = torch.cuda.Event(enable_timing = True) 62 | 63 | use_sparse_adam = opt.optimizer_type == "sparse_adam" and SPARSE_ADAM_AVAILABLE 64 | depth_l1_weight = get_expon_lr_func(opt.depth_l1_weight_init, opt.depth_l1_weight_final, max_steps=opt.iterations) 65 | 66 | viewpoint_stack = scene.getTrainCameras().copy() 67 | viewpoint_indices = list(range(len(viewpoint_stack))) 68 | ema_loss_for_log = 0.0 69 | ema_Ll1depth_for_log = 0.0 70 | 71 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 72 | first_iter += 1 73 | for iteration in range(first_iter, opt.iterations + 1): 74 | if network_gui.conn == None: 75 | network_gui.try_connect() 76 | while network_gui.conn != None: 77 | try: 78 | net_image_bytes = None 79 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 80 | if custom_cam != None: 81 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifier=scaling_modifer, use_trained_exp=dataset.train_test_exp, separate_sh=SPARSE_ADAM_AVAILABLE)["render"] 82 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 83 | network_gui.send(net_image_bytes, dataset.source_path) 84 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 85 | break 86 | except Exception as e: 87 | network_gui.conn = None 88 | 89 | iter_start.record() 90 | 91 | gaussians.update_learning_rate(iteration) 92 | 93 | # Every 1000 its we increase the levels of SH up to a maximum degree 94 | if iteration % 1000 
== 0: 95 | gaussians.oneupSHdegree() 96 | 97 | # Pick a random Camera 98 | if not viewpoint_stack: 99 | viewpoint_stack = scene.getTrainCameras().copy() 100 | viewpoint_indices = list(range(len(viewpoint_stack))) 101 | rand_idx = randint(0, len(viewpoint_indices) - 1) 102 | viewpoint_cam = viewpoint_stack.pop(rand_idx) 103 | vind = viewpoint_indices.pop(rand_idx) 104 | 105 | # Render 106 | if (iteration - 1) == debug_from: 107 | pipe.debug = True 108 | 109 | bg = torch.rand((3), device="cuda") if opt.random_background else background 110 | 111 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg, use_trained_exp=dataset.train_test_exp, separate_sh=SPARSE_ADAM_AVAILABLE) 112 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 113 | 114 | # if viewpoint_cam.alpha_mask is not None: 115 | # alpha_mask = viewpoint_cam.alpha_mask.cuda() 116 | # image *= alpha_mask 117 | 118 | # Loss 119 | gt_image = viewpoint_cam.original_image.cuda() 120 | Ll1 = l1_loss(image, gt_image) 121 | if FUSED_SSIM_AVAILABLE: 122 | ssim_value = fused_ssim(image.unsqueeze(0), gt_image.unsqueeze(0)) 123 | else: 124 | ssim_value = ssim(image, gt_image) 125 | 126 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value) 127 | 128 | # Depth regularization 129 | Ll1depth_pure = 0.0 130 | if depth_l1_weight(iteration) > 0 and viewpoint_cam.depth_reliable: 131 | invDepth = render_pkg["depth"] 132 | mono_invdepth = viewpoint_cam.invdepthmap.cuda() 133 | depth_mask = viewpoint_cam.depth_mask.cuda() 134 | 135 | Ll1depth_pure = torch.abs((invDepth - mono_invdepth) * depth_mask).mean() 136 | Ll1depth = depth_l1_weight(iteration) * Ll1depth_pure 137 | loss += Ll1depth 138 | Ll1depth = Ll1depth.item() 139 | else: 140 | Ll1depth = 0 141 | 142 | loss.backward() 143 | 144 | iter_end.record() 145 | 146 | with torch.no_grad(): 147 | # Progress bar 148 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 149 | ema_Ll1depth_for_log = 0.4 * Ll1depth + 0.6 * ema_Ll1depth_for_log 150 | 151 | if iteration % 10 == 0: 152 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}", "Depth Loss": f"{ema_Ll1depth_for_log:.{7}f}"}) 153 | progress_bar.update(10) 154 | if iteration == opt.iterations: 155 | progress_bar.close() 156 | 157 | # Log and save 158 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background, 1., SPARSE_ADAM_AVAILABLE, None, dataset.train_test_exp), dataset.train_test_exp) 159 | if (iteration in saving_iterations): 160 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 161 | scene.save(iteration) 162 | 163 | # Densification 164 | if iteration < opt.densify_until_iter: 165 | # Keep track of max radii in image-space for pruning 166 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 167 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 168 | 169 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 170 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 171 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold, radii) 172 | 173 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): 174 | gaussians.reset_opacity() 175 | 
176 | # Optimizer step 177 | if iteration < opt.iterations: 178 | gaussians.exposure_optimizer.step() 179 | gaussians.exposure_optimizer.zero_grad(set_to_none = True) 180 | if use_sparse_adam: 181 | visible = radii > 0 182 | gaussians.optimizer.step(visible, radii.shape[0]) 183 | gaussians.optimizer.zero_grad(set_to_none = True) 184 | else: 185 | gaussians.optimizer.step() 186 | gaussians.optimizer.zero_grad(set_to_none = True) 187 | 188 | if (iteration in checkpoint_iterations): 189 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 190 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 191 | 192 | def prepare_output_and_logger(args): 193 | if not args.model_path: 194 | if os.getenv('OAR_JOB_ID'): 195 | unique_str=os.getenv('OAR_JOB_ID') 196 | else: 197 | unique_str = str(uuid.uuid4()) 198 | args.model_path = os.path.join("./output/", unique_str[0:10]) 199 | 200 | # Set up output folder 201 | print("Output folder: {}".format(args.model_path)) 202 | os.makedirs(args.model_path, exist_ok = True) 203 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 204 | cfg_log_f.write(str(Namespace(**vars(args)))) 205 | 206 | # Create Tensorboard writer 207 | tb_writer = None 208 | if TENSORBOARD_FOUND: 209 | tb_writer = SummaryWriter(args.model_path) 210 | else: 211 | print("Tensorboard not available: not logging progress") 212 | return tb_writer 213 | 214 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, train_test_exp): 215 | if tb_writer: 216 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 217 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 218 | tb_writer.add_scalar('iter_time', elapsed, iteration) 219 | 220 | # Report test and samples of training set 221 | if iteration in testing_iterations: 222 | torch.cuda.empty_cache() 223 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 224 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 225 | 226 | for config in validation_configs: 227 | if config['cameras'] and len(config['cameras']) > 0: 228 | l1_test = 0.0 229 | psnr_test = 0.0 230 | for idx, viewpoint in enumerate(config['cameras']): 231 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0) 232 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 233 | if train_test_exp: 234 | image = image[..., image.shape[-1] // 2:] 235 | gt_image = gt_image[..., gt_image.shape[-1] // 2:] 236 | if tb_writer and (idx < 5): 237 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) 238 | if iteration == testing_iterations[0]: 239 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) 240 | l1_test += l1_loss(image, gt_image).mean().double() 241 | psnr_test += psnr(image, gt_image).mean().double() 242 | psnr_test /= len(config['cameras']) 243 | l1_test /= len(config['cameras']) 244 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 245 | if tb_writer: 246 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 247 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 248 | 249 | 
if tb_writer: 250 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 251 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 252 | torch.cuda.empty_cache() 253 | 254 | if __name__ == "__main__": 255 | # Set up command line argument parser 256 | parser = ArgumentParser(description="Training script parameters") 257 | lp = ModelParams(parser) 258 | op = OptimizationParams(parser) 259 | pp = PipelineParams(parser) 260 | parser.add_argument('--ip', type=str, default="127.0.0.1") 261 | parser.add_argument('--port', type=int, default=6009) 262 | parser.add_argument('--debug_from', type=int, default=-1) 263 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 264 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[30_000]) 265 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[]) 266 | parser.add_argument("--quiet", action="store_true") 267 | parser.add_argument('--disable_viewer', action='store_true', default=False) 268 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[30_000]) 269 | parser.add_argument("--start_checkpoint", type=str, default = None) 270 | args = parser.parse_args(sys.argv[1:]) 271 | args.save_iterations.append(args.iterations) 272 | 273 | print("Optimizing " + args.model_path) 274 | 275 | # Initialize system state (RNG) 276 | safe_state(args.quiet) 277 | 278 | # Start GUI server, configure and run training 279 | if not args.disable_viewer: 280 | network_gui.init(args.ip, args.port) 281 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 282 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 283 | 284 | # All done 285 | print("\nTraining complete.") 286 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.graphics_utils import fov2focal 15 | from PIL import Image 16 | import cv2 17 | 18 | WARNED = False 19 | 20 | def loadCam(args, id, cam_info, resolution_scale, is_nerf_synthetic, is_test_dataset): 21 | image = Image.open(cam_info.image_path) 22 | 23 | if cam_info.depth_path != "": 24 | try: 25 | if is_nerf_synthetic: 26 | invdepthmap = cv2.imread(cam_info.depth_path, -1).astype(np.float32) / 512 27 | else: 28 | invdepthmap = cv2.imread(cam_info.depth_path, -1).astype(np.float32) / float(2**16) 29 | 30 | except FileNotFoundError: 31 | print(f"Error: The depth file at path '{cam_info.depth_path}' was not found.") 32 | raise 33 | except IOError: 34 | print(f"Error: Unable to open the image file '{cam_info.depth_path}'. 
It may be corrupted or an unsupported format.") 35 | raise 36 | except Exception as e: 37 | print(f"An unexpected error occurred when trying to read depth at {cam_info.depth_path}: {e}") 38 | raise 39 | else: 40 | invdepthmap = None 41 | 42 | orig_w, orig_h = image.size 43 | if args.resolution in [1, 2, 4, 8]: 44 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 45 | else: # should be a type that converts to float 46 | if args.resolution == -1: 47 | if orig_w > 1600: 48 | global WARNED 49 | if not WARNED: 50 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 51 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 52 | WARNED = True 53 | global_down = orig_w / 1600 54 | else: 55 | global_down = 1 56 | else: 57 | global_down = orig_w / args.resolution 58 | 59 | 60 | scale = float(global_down) * float(resolution_scale) 61 | resolution = (int(orig_w / scale), int(orig_h / scale)) 62 | return Camera(resolution, colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 63 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, depth_params=cam_info.depth_params, 64 | image=image, invdepthmap=invdepthmap, 65 | image_name=cam_info.image_name, uid=id, data_device=args.data_device, 66 | train_test_exp=args.train_test_exp, is_test_dataset=is_test_dataset, is_test_view=cam_info.is_test) 67 | 68 | def cameraList_from_camInfos(cam_infos, resolution_scale, args, is_nerf_synthetic, is_test_dataset): 69 | camera_list = [] 70 | 71 | for id, c in enumerate(cam_infos): 72 | camera_list.append(loadCam(args, id, c, resolution_scale, is_nerf_synthetic, is_test_dataset)) 73 | 74 | return camera_list 75 | 76 | def camera_to_JSON(id, camera : Camera): 77 | Rt = np.zeros((4, 4)) 78 | Rt[:3, :3] = camera.R.transpose() 79 | Rt[:3, 3] = camera.T 80 | Rt[3, 3] = 1.0 81 | 82 | W2C = np.linalg.inv(Rt) 83 | pos = W2C[:3, 3] 84 | rot = W2C[:3, :3] 85 | serializable_array_2d = [x.tolist() for x in rot] 86 | camera_entry = { 87 | 'id' : id, 88 | 'img_name' : camera.image_name, 89 | 'width' : camera.width, 90 | 'height' : camera.height, 91 | 'position': pos.tolist(), 92 | 'rotation': serializable_array_2d, 93 | 'fy' : fov2focal(camera.FovY, camera.height), 94 | 'fx' : fov2focal(camera.FovX, camera.width) 95 | } 96 | return camera_entry -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | def inverse_sigmoid(x): 19 | return torch.log(x/(1-x)) 20 | 21 | def PILtoTorch(pil_image, resolution): 22 | resized_image_PIL = pil_image.resize(resolution) 23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 24 | if len(resized_image.shape) == 3: 25 | return resized_image.permute(2, 0, 1) 26 | else: 27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 28 | 29 | def get_expon_lr_func( 30 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 31 | ): 32 | """ 33 | Copied from Plenoxels 34 | 35 | Continuous learning rate decay function. Adapted from JaxNeRF 36 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 37 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 38 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 39 | function of lr_delay_mult, such that the initial learning rate is 40 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 41 | to the normal learning rate when steps>lr_delay_steps. 42 | :param conf: config subtree 'lr' or similar 43 | :param max_steps: int, the number of steps during optimization. 44 | :return HoF which takes step as input 45 | """ 46 | 47 | def helper(step): 48 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 49 | # Disable this parameter 50 | return 0.0 51 | if lr_delay_steps > 0: 52 | # A kind of reverse cosine decay. 53 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 54 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 55 | ) 56 | else: 57 | delay_rate = 1.0 58 | t = np.clip(step / max_steps, 0, 1) 59 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 60 | return delay_rate * log_lerp 61 | 62 | return helper 63 | 64 | def strip_lowerdiag(L): 65 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 66 | 67 | uncertainty[:, 0] = L[:, 0, 0] 68 | uncertainty[:, 1] = L[:, 0, 1] 69 | uncertainty[:, 2] = L[:, 0, 2] 70 | uncertainty[:, 3] = L[:, 1, 1] 71 | uncertainty[:, 4] = L[:, 1, 2] 72 | uncertainty[:, 5] = L[:, 2, 2] 73 | return uncertainty 74 | 75 | def strip_symmetric(sym): 76 | return strip_lowerdiag(sym) 77 | 78 | def build_rotation(r): 79 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 80 | 81 | q = r / norm[:, None] 82 | 83 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 84 | 85 | r = q[:, 0] 86 | x = q[:, 1] 87 | y = q[:, 2] 88 | z = q[:, 3] 89 | 90 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 91 | R[:, 0, 1] = 2 * (x*y - r*z) 92 | R[:, 0, 2] = 2 * (x*z + r*y) 93 | R[:, 1, 0] = 2 * (x*y + r*z) 94 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 95 | R[:, 1, 2] = 2 * (y*z - r*x) 96 | R[:, 2, 0] = 2 * (x*z - r*y) 97 | R[:, 2, 1] = 2 * (y*z + r*x) 98 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 99 | return R 100 | 101 | def build_scaling_rotation(s, r): 102 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 103 | R = build_rotation(r) 104 | 105 | L[:,0,0] = s[:,0] 106 | L[:,1,1] = s[:,1] 107 | L[:,2,2] = s[:,2] 108 | 109 | L = R @ L 110 | return L 111 | 112 | def safe_state(silent): 113 | old_f = sys.stdout 114 | class F: 115 | def __init__(self, silent): 116 | self.silent = silent 117 | 118 | def write(self, x): 119 | if not self.silent: 120 | if x.endswith("\n"): 121 | old_f.write(x.replace("\n", " 
[{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 122 | else: 123 | old_f.write(x) 124 | 125 | def flush(self): 126 | old_f.flush() 127 | 128 | sys.stdout = F(silent) 129 | 130 | random.seed(0) 131 | np.random.seed(0) 132 | torch.manual_seed(0) 133 | torch.cuda.set_device(torch.device("cuda:0")) 134 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | 22 | def geom_transform_points(points, transf_matrix): 23 | P, _ = points.shape 24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 25 | points_hom = torch.cat([points, ones], dim=1) 26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 27 | 28 | denom = points_out[..., 3:] + 0.0000001 29 | return (points_out[..., :3] / denom).squeeze(dim=0) 30 | 31 | def getWorld2View(R, t): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | return np.float32(Rt) 37 | 38 | def orthonormalize_rotation_matrix(R, eps=1e-6): 39 | U, S, Vt = np.linalg.svd(R) 40 | R_ortho = U @ Vt 41 | 42 | if np.linalg.det(R_ortho) < 0: 43 | Vt[-1, :] *= -1 44 | R_ortho = U @ Vt 45 | 46 | return R_ortho 47 | 48 | 49 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 50 | Rt = np.zeros((4, 4)) 51 | Rt[:3, :3] = R.transpose() 52 | Rt[:3, 3] = t 53 | Rt[3, 3] = 1.0 54 | 55 | C2W = np.linalg.inv(Rt) 56 | cam_center = C2W[:3, 3] 57 | cam_center = (cam_center + translate) * scale 58 | C2W[:3, 3] = cam_center 59 | Rt = np.linalg.inv(C2W) 60 | return np.float32(Rt) 61 | 62 | def getProjectionMatrix(znear, zfar, fovX, fovY): 63 | tanHalfFovY = math.tan((fovY / 2)) 64 | tanHalfFovX = math.tan((fovX / 2)) 65 | 66 | top = tanHalfFovY * znear 67 | bottom = -top 68 | right = tanHalfFovX * znear 69 | left = -right 70 | 71 | P = torch.zeros(4, 4) 72 | 73 | z_sign = 1.0 74 | 75 | P[0, 0] = 2.0 * znear / (right - left) 76 | P[1, 1] = 2.0 * znear / (top - bottom) 77 | P[0, 2] = (right + left) / (right - left) 78 | P[1, 2] = (top + bottom) / (top - bottom) 79 | P[3, 2] = z_sign 80 | P[2, 2] = z_sign * zfar / (zfar - znear) 81 | P[2, 3] = -(zfar * znear) / (zfar - znear) 82 | return P 83 | 84 | def fov2focal(fov, pixels): 85 | return pixels / (2 * math.tan(fov / 2)) 86 | 87 | def focal2fov(focal, pixels): 88 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | 17 | def psnr(img1, img2): 18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 20 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | try: 17 | from diff_gaussian_rasterization._C import fusedssim, fusedssim_backward 18 | except: 19 | pass 20 | 21 | C1 = 0.01 ** 2 22 | C2 = 0.03 ** 2 23 | 24 | class FusedSSIMMap(torch.autograd.Function): 25 | @staticmethod 26 | def forward(ctx, C1, C2, img1, img2): 27 | ssim_map = fusedssim(C1, C2, img1, img2) 28 | ctx.save_for_backward(img1.detach(), img2) 29 | ctx.C1 = C1 30 | ctx.C2 = C2 31 | return ssim_map 32 | 33 | @staticmethod 34 | def backward(ctx, opt_grad): 35 | img1, img2 = ctx.saved_tensors 36 | C1, C2 = ctx.C1, ctx.C2 37 | grad = fusedssim_backward(C1, C2, img1, img2, opt_grad) 38 | return None, None, grad, None 39 | 40 | def l1_loss(network_output, gt): 41 | return torch.abs((network_output - gt)).mean() 42 | 43 | def l2_loss(network_output, gt): 44 | return ((network_output - gt) ** 2).mean() 45 | 46 | def gaussian(window_size, sigma): 47 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 48 | return gauss / gauss.sum() 49 | 50 | def create_window(window_size, channel): 51 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 52 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 53 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 54 | return window 55 | 56 | def ssim(img1, img2, window_size=11, size_average=True): 57 | channel = img1.size(-3) 58 | window = create_window(window_size, channel) 59 | 60 | if img1.is_cuda: 61 | window = window.cuda(img1.get_device()) 62 | window = window.type_as(img1) 63 | 64 | return _ssim(img1, img2, window, window_size, channel, size_average) 65 | 66 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 67 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 68 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 69 | 70 | mu1_sq = mu1.pow(2) 71 | mu2_sq = mu2.pow(2) 72 | mu1_mu2 = mu1 * mu2 73 | 74 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 75 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 76 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 77 | 78 | C1 = 0.01 ** 2 79 | C2 = 0.03 ** 2 80 | 81 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 82 | 83 | if size_average: 84 | return ssim_map.mean() 85 | else: 86 | return ssim_map.mean(1).mean(1).mean(1) 87 | 
88 | 89 | def fast_ssim(img1, img2): 90 | ssim_map = FusedSSIMMap.apply(C1, C2, img1, img2) 91 | return ssim_map.mean() 92 | -------------------------------------------------------------------------------- /utils/make_depth_scale.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import cv2 4 | from joblib import delayed, Parallel 5 | import json 6 | from read_write_model import * 7 | 8 | def get_scales(key, cameras, images, points3d_ordered, args): 9 | image_meta = images[key] 10 | cam_intrinsic = cameras[image_meta.camera_id] 11 | 12 | pts_idx = images_metas[key].point3D_ids 13 | 14 | mask = pts_idx >= 0 15 | mask *= pts_idx < len(points3d_ordered) 16 | 17 | pts_idx = pts_idx[mask] 18 | valid_xys = image_meta.xys[mask] 19 | 20 | if len(pts_idx) > 0: 21 | pts = points3d_ordered[pts_idx] 22 | else: 23 | pts = np.array([0, 0, 0]) 24 | 25 | R = qvec2rotmat(image_meta.qvec) 26 | pts = np.dot(pts, R.T) + image_meta.tvec 27 | 28 | invcolmapdepth = 1. / pts[..., 2] 29 | n_remove = len(image_meta.name.split('.')[-1]) + 1 30 | invmonodepthmap = cv2.imread(f"{args.depths_dir}/{image_meta.name[:-n_remove]}.png", cv2.IMREAD_UNCHANGED) 31 | 32 | if invmonodepthmap is None: 33 | return None 34 | 35 | if invmonodepthmap.ndim != 2: 36 | invmonodepthmap = invmonodepthmap[..., 0] 37 | 38 | invmonodepthmap = invmonodepthmap.astype(np.float32) / (2**16) 39 | s = invmonodepthmap.shape[0] / cam_intrinsic.height 40 | 41 | maps = (valid_xys * s).astype(np.float32) 42 | valid = ( 43 | (maps[..., 0] >= 0) * 44 | (maps[..., 1] >= 0) * 45 | (maps[..., 0] < cam_intrinsic.width * s) * 46 | (maps[..., 1] < cam_intrinsic.height * s) * (invcolmapdepth > 0)) 47 | 48 | if valid.sum() > 10 and (invcolmapdepth.max() - invcolmapdepth.min()) > 1e-3: 49 | maps = maps[valid, :] 50 | invcolmapdepth = invcolmapdepth[valid] 51 | invmonodepth = cv2.remap(invmonodepthmap, maps[..., 0], maps[..., 1], interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)[..., 0] 52 | 53 | ## Median / dev 54 | t_colmap = np.median(invcolmapdepth) 55 | s_colmap = np.mean(np.abs(invcolmapdepth - t_colmap)) 56 | 57 | t_mono = np.median(invmonodepth) 58 | s_mono = np.mean(np.abs(invmonodepth - t_mono)) 59 | scale = s_colmap / s_mono 60 | offset = t_colmap - t_mono * scale 61 | else: 62 | scale = 0 63 | offset = 0 64 | return {"image_name": image_meta.name[:-n_remove], "scale": scale, "offset": offset} 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--base_dir', default="../data/big_gaussians/standalone_chunks/campus") 69 | parser.add_argument('--depths_dir', default="../data/big_gaussians/standalone_chunks/campus/depths_any") 70 | parser.add_argument('--model_type', default="bin") 71 | args = parser.parse_args() 72 | 73 | 74 | cam_intrinsics, images_metas, points3d = read_model(os.path.join(args.base_dir, "sparse", "0"), ext=f".{args.model_type}") 75 | 76 | pts_indices = np.array([points3d[key].id for key in points3d]) 77 | pts_xyzs = np.array([points3d[key].xyz for key in points3d]) 78 | points3d_ordered = np.zeros([pts_indices.max()+1, 3]) 79 | points3d_ordered[pts_indices] = pts_xyzs 80 | 81 | # depth_param_list = [get_scales(key, cam_intrinsics, images_metas, points3d_ordered, args) for key in images_metas] 82 | depth_param_list = Parallel(n_jobs=-1, backend="threading")( 83 | delayed(get_scales)(key, cam_intrinsics, images_metas, points3d_ordered, args) for key in images_metas 84 | ) 85 | 86 | depth_params = 
{ 87 | depth_param["image_name"]: {"scale": depth_param["scale"], "offset": depth_param["offset"]} 88 | for depth_param in depth_param_list if depth_param != None 89 | } 90 | 91 | with open(f"{args.base_dir}/sparse/0/depth_params.json", "w") as f: 92 | json.dump(depth_params, f, indent=2) 93 | 94 | print(0) 95 | -------------------------------------------------------------------------------- /utils/pose_utils.py: -------------------------------------------------------------------------------- 1 | # Copy ideas from [LightGaussian](https://github.com/VITA-Group/LightGaussian) 2 | 3 | import numpy as np 4 | import torch 5 | from icecream import ic 6 | from utils.graphics_utils import getWorld2View2 7 | 8 | 9 | def normalize(x): 10 | return x / np.linalg.norm(x) 11 | 12 | def viewmatrix(z, up, pos): 13 | vec2 = normalize(z) 14 | vec1_avg = up 15 | vec0 = normalize(np.cross(vec1_avg, vec2)) 16 | vec1 = normalize(np.cross(vec2, vec0)) 17 | m = np.stack([vec0, vec1, vec2, pos], 1) 18 | return m 19 | 20 | def poses_avg(poses): 21 | hwf = poses[0, :3, -1:] 22 | 23 | center = poses[:, :3, 3].mean(0) 24 | vec2 = normalize(poses[:, :3, 2].sum(0)) 25 | up = poses[:, :3, 1].sum(0) 26 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 27 | 28 | return c2w 29 | 30 | def get_focal(camera): 31 | focal = camera.FoVx 32 | return focal 33 | 34 | def poses_avg_fixed_center(poses): 35 | hwf = poses[0, :3, -1:] 36 | center = poses[:, :3, 3].mean(0) 37 | vec2 = [1, 0, 0] 38 | up = [0, 0, 1] 39 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 40 | return c2w 41 | 42 | def integrate_weights_np(w): 43 | """Compute the cumulative sum of w, assuming all weight vectors sum to 1. 44 | 45 | The output's size on the last dimension is one greater than that of the input, 46 | because we're computing the integral corresponding to the endpoints of a step 47 | function, not the integral of the interior/bin values. 48 | 49 | Args: 50 | w: Tensor, which will be integrated along the last axis. This is assumed to 51 | sum to 1 along the last axis, and this function will (silently) break if 52 | that is not the case. 53 | 54 | Returns: 55 | cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1 56 | """ 57 | cw = np.minimum(1, np.cumsum(w[..., :-1], axis=-1)) 58 | shape = cw.shape[:-1] + (1,) 59 | # Ensure that the CDF starts with exactly 0 and ends with exactly 1. 60 | cw0 = np.concatenate([np.zeros(shape), cw, 61 | np.ones(shape)], axis=-1) 62 | return cw0 63 | 64 | def invert_cdf_np(u, t, w_logits): 65 | """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).""" 66 | # Compute the PDF and CDF for each weight vector. 67 | w = np.exp(w_logits) / np.exp(w_logits).sum(axis=-1, keepdims=True) 68 | cw = integrate_weights_np(w) 69 | # Interpolate into the inverse CDF. 70 | interp_fn = np.interp 71 | t_new = interp_fn(u, cw, t) 72 | return t_new 73 | 74 | def sample_np(rand, 75 | t, 76 | w_logits, 77 | num_samples, 78 | single_jitter=False, 79 | deterministic_center=False): 80 | """ 81 | numpy version of sample() 82 | """ 83 | eps = np.finfo(np.float32).eps 84 | 85 | # Draw uniform samples. 86 | if not rand: 87 | if deterministic_center: 88 | pad = 1 / (2 * num_samples) 89 | u = np.linspace(pad, 1. - pad - eps, num_samples) 90 | else: 91 | u = np.linspace(0, 1. - eps, num_samples) 92 | u = np.broadcast_to(u, t.shape[:-1] + (num_samples,)) 93 | else: 94 | # `u` is in [0, 1) --- it can be zero, but it can never be 1. 
95 | u_max = eps + (1 - eps) / num_samples 96 | max_jitter = (1 - u_max) / (num_samples - 1) - eps 97 | d = 1 if single_jitter else num_samples 98 | u = np.linspace(0, 1 - u_max, num_samples) + \ 99 | np.random.rand(*t.shape[:-1], d) * max_jitter 100 | 101 | return invert_cdf_np(u, t, w_logits) 102 | 103 | 104 | 105 | def focus_point_fn(poses): 106 | """Calculate nearest point to all focal axes in poses.""" 107 | directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4] 108 | m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1]) 109 | mt_m = np.transpose(m, [0, 2, 1]) @ m 110 | focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0] 111 | return focus_pt 112 | 113 | 114 | def average_pose(poses: np.ndarray) -> np.ndarray: 115 | """New pose using average position, z-axis, and up vector of input poses.""" 116 | position = poses[:, :3, 3].mean(0) 117 | z_axis = poses[:, :3, 2].mean(0) 118 | up = poses[:, :3, 1].mean(0) 119 | cam2world = viewmatrix(z_axis, up, position) 120 | return cam2world 121 | 122 | from typing import Tuple 123 | def recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 124 | """Recenter poses around the origin.""" 125 | cam2world = average_pose(poses) 126 | transform = np.linalg.inv(pad_poses(cam2world)) 127 | poses = transform @ pad_poses(poses) 128 | return unpad_poses(poses), transform 129 | 130 | 131 | NEAR_STRETCH = .9 # Push forward near bound for forward facing render path. 132 | FAR_STRETCH = 5. # Push back far bound for forward facing render path. 133 | FOCUS_DISTANCE = .75 # Relative weighting of near, far bounds for render path. 134 | def generate_spiral_path(views, bounds, 135 | n_frames: int = 180, 136 | n_rots: int = 2, 137 | zrate: float = .5) -> np.ndarray: 138 | """Calculates a forward facing spiral path for rendering.""" 139 | # Find a reasonable 'focus depth' for this dataset as a weighted average 140 | # of conservative near and far bounds in disparity space. 141 | poses = [] 142 | for view in views: 143 | tmp_view = np.eye(4) 144 | tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1) 145 | tmp_view = np.linalg.inv(tmp_view) 146 | tmp_view[:, 1:3] *= -1 147 | poses.append(tmp_view) 148 | poses = np.stack(poses, 0) 149 | 150 | print(poses.shape) 151 | bounds = bounds.repeat(poses.shape[0], 0) #np.array([[ 16.21311152, 153.86329729]]) 152 | scale = 1. / (bounds.min() * .75) 153 | poses[:, :3, 3] *= scale 154 | bounds *= scale 155 | # Recenter poses. 156 | # tmp, _ = recenter_poses(poses) 157 | # poses[:, :3, :3] = tmp[:, :3, :3] @ np.diag(np.array([1, -1, -1])) 158 | 159 | near_bound = bounds.min() * NEAR_STRETCH 160 | far_bound = bounds.max() * FAR_STRETCH 161 | # All cameras will point towards the world space point (0, 0, -focal). 162 | focal = 1 / (((1 - FOCUS_DISTANCE) / near_bound + FOCUS_DISTANCE / far_bound)) 163 | 164 | # Get radii for spiral path using 90th percentile of camera positions. 165 | positions = poses[:, :3, 3] 166 | radii = np.percentile(np.abs(positions), 90, 0) 167 | radii = np.concatenate([radii, [1.]]) 168 | 169 | # Generate poses for spiral path. 170 | render_poses = [] 171 | cam2world = average_pose(poses) 172 | up = poses[:, :3, 1].mean(0) 173 | for theta in np.linspace(0., 2. * np.pi * n_rots, n_frames, endpoint=False): 174 | t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.] 175 | position = cam2world @ t 176 | lookat = cam2world @ [0, 0, -focal, 1.] 
177 | z_axis = position - lookat 178 | render_pose = np.eye(4) 179 | render_pose[:3] = viewmatrix(z_axis, up, position) 180 | render_pose[:3, 1:3] *= -1 181 | render_poses.append(np.linalg.inv(render_pose)) 182 | render_poses = np.stack(render_poses, axis=0) 183 | return render_poses 184 | 185 | 186 | def render_path_spiral(views, focal=50, zrate=0.5, rots=2, N=10): 187 | poses = [] 188 | for view in views: 189 | tmp_view = np.eye(4) 190 | tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1) 191 | tmp_view = np.linalg.inv(tmp_view) 192 | tmp_view[:, 1:3] *= -1 193 | poses.append(tmp_view) 194 | poses = np.stack(poses, 0) 195 | # poses = np.stack([np.concatenate([view.R.T, view.T[:, None]], 1) for view in views], 0) 196 | c2w = poses_avg(poses) 197 | up = normalize(poses[:, :3, 1].sum(0)) 198 | 199 | # Get radii for spiral path 200 | rads = np.percentile(np.abs(poses[:, :3, 3]), 90, 0) 201 | render_poses = [] 202 | rads = np.array(list(rads) + [1.0]) 203 | 204 | for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]: 205 | c = np.dot( 206 | c2w[:3, :4], 207 | np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]) * rads, 208 | ) 209 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0]))) 210 | render_pose = np.eye(4) 211 | render_pose[:3] = viewmatrix(z, up, c) 212 | render_pose[:3, 1:3] *= -1 213 | render_poses.append(np.linalg.inv(render_pose)) 214 | return render_poses 215 | 216 | def pad_poses(p): 217 | """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1].""" 218 | bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape) 219 | return np.concatenate([p[..., :3, :4], bottom], axis=-2) 220 | 221 | 222 | def unpad_poses(p): 223 | """Remove the homogeneous bottom row from [..., 4, 4] pose matrices.""" 224 | return p[..., :3, :4] 225 | 226 | def transform_poses_pca(poses): 227 | """Transforms poses so principal components lie on XYZ axes. 228 | 229 | Args: 230 | poses: a (N, 3, 4) array containing the cameras' camera to world transforms. 231 | 232 | Returns: 233 | A tuple (poses, transform), with the transformed poses and the applied 234 | camera_to_world transforms. 235 | """ 236 | t = poses[:, :3, 3] 237 | t_mean = t.mean(axis=0) 238 | t = t - t_mean 239 | 240 | eigval, eigvec = np.linalg.eig(t.T @ t) 241 | # Sort eigenvectors in order of largest to smallest eigenvalue. 242 | inds = np.argsort(eigval)[::-1] 243 | eigvec = eigvec[:, inds] 244 | rot = eigvec.T 245 | if np.linalg.det(rot) < 0: 246 | rot = np.diag(np.array([1, 1, -1])) @ rot 247 | 248 | transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1) 249 | poses_recentered = unpad_poses(transform @ pad_poses(poses)) 250 | transform = np.concatenate([transform, np.eye(4)[3:]], axis=0) 251 | 252 | # Flip coordinate system if z component of y-axis is negative 253 | if poses_recentered.mean(axis=0)[2, 1] < 0: 254 | poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered 255 | transform = np.diag(np.array([1, -1, -1, 1])) @ transform 256 | 257 | # Just make sure it's it in the [-1, 1]^3 cube 258 | scale_factor = 1. 
/ np.max(np.abs(poses_recentered[:, :3, 3])) 259 | poses_recentered[:, :3, 3] *= scale_factor 260 | transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform 261 | return poses_recentered, transform 262 | 263 | def generate_ellipse_path(views, n_frames=600, const_speed=True, z_variation=0., z_phase=0.): 264 | poses = [] 265 | for view in views: 266 | tmp_view = np.eye(4) 267 | tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1) 268 | tmp_view = np.linalg.inv(tmp_view) 269 | tmp_view[:, 1:3] *= -1 270 | poses.append(tmp_view) 271 | poses = np.stack(poses, 0) 272 | poses, transform = transform_poses_pca(poses) 273 | 274 | 275 | # Calculate the focal point for the path (cameras point toward this). 276 | center = focus_point_fn(poses) 277 | offset = np.array([center[0] , center[1], center[2]*0 ]) 278 | # Calculate scaling for ellipse axes based on input camera positions. 279 | sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0) 280 | 281 | # Use ellipse that is symmetric about the focal point in xy. 282 | low = -sc + offset 283 | high = sc + offset 284 | # Optional height variation need not be symmetric 285 | z_low = np.percentile((poses[:, :3, 3]), 10, axis=0) 286 | z_high = np.percentile((poses[:, :3, 3]), 90, axis=0) 287 | 288 | 289 | def get_positions(theta): 290 | # Interpolate between bounds with trig functions to get ellipse in x-y. 291 | # Optionally also interpolate in z to change camera height along path. 292 | return np.stack([ 293 | (low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5)), 294 | (low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5)), 295 | z_variation * (z_low[2] + (z_high - z_low)[2] * 296 | (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)), 297 | ], -1) 298 | 299 | theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True) 300 | positions = get_positions(theta) 301 | 302 | if const_speed: 303 | # Resample theta angles so that the velocity is closer to constant. 304 | lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1) 305 | theta = sample_np(None, theta, np.log(lengths), n_frames + 1) 306 | positions = get_positions(theta) 307 | 308 | # Throw away duplicated last position. 309 | positions = positions[:-1] 310 | 311 | # Set path's up vector to axis closest to average of input pose up vectors. 
312 | avg_up = poses[:, :3, 1].mean(0) 313 | avg_up = avg_up / np.linalg.norm(avg_up) 314 | ind_up = np.argmax(np.abs(avg_up)) 315 | up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) 316 | 317 | render_poses = [] 318 | for p in positions: 319 | render_pose = np.eye(4) 320 | render_pose[:3] = viewmatrix(p - center, up, p) 321 | render_pose = np.linalg.inv(transform) @ render_pose 322 | render_pose[:3, 1:3] *= -1 323 | render_poses.append(np.linalg.inv(render_pose)) 324 | return render_poses 325 | 326 | 327 | def generate_spherify_path(views): 328 | poses = [] 329 | for view in views: 330 | tmp_view = np.eye(4) 331 | tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1) 332 | tmp_view = np.linalg.inv(tmp_view) 333 | tmp_view[:, 1:3] *= -1 334 | poses.append(tmp_view) 335 | poses = np.stack(poses, 0) 336 | 337 | p34_to_44 = lambda p: np.concatenate( 338 | [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1 339 | ) 340 | 341 | rays_d = poses[:, :3, 2:3] 342 | rays_o = poses[:, :3, 3:4] 343 | 344 | def min_line_dist(rays_o, rays_d): 345 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1]) 346 | b_i = -A_i @ rays_o 347 | pt_mindist = np.squeeze( 348 | -np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0) 349 | ) 350 | return pt_mindist 351 | 352 | pt_mindist = min_line_dist(rays_o, rays_d) 353 | 354 | center = pt_mindist 355 | up = (poses[:, :3, 3] - center).mean(0) 356 | 357 | vec0 = normalize(up) 358 | vec1 = normalize(np.cross([0.1, 0.2, 0.3], vec0)) 359 | vec2 = normalize(np.cross(vec0, vec1)) 360 | pos = center 361 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 362 | 363 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]) 364 | 365 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1))) 366 | 367 | sc = 1.0 / rad 368 | poses_reset[:, :3, 3] *= sc 369 | rad *= sc 370 | 371 | centroid = np.mean(poses_reset[:, :3, 3], 0) 372 | zh = centroid[2] 373 | radcircle = np.sqrt(rad**2 - zh**2) 374 | new_poses = [] 375 | 376 | for th in np.linspace(0.0, 2.0 * np.pi, 120): 377 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 378 | up = np.array([0, 0, -1.0]) 379 | 380 | vec2 = normalize(camorigin) 381 | vec0 = normalize(np.cross(vec2, up)) 382 | vec1 = normalize(np.cross(vec2, vec0)) 383 | pos = camorigin 384 | p = np.stack([vec0, vec1, vec2, pos], 1) 385 | 386 | render_pose = np.eye(4) 387 | render_pose[:3] = p 388 | #render_pose[:3, 1:3] *= -1 389 | new_poses.append(render_pose) 390 | 391 | new_poses = np.stack(new_poses, 0) 392 | return new_poses 393 | 394 | # def gaussian_poses(viewpoint_cam, mean =0, std_dev = 0.03): 395 | # translate_x = np.random.normal(mean, std_dev) 396 | # translate_y = np.random.normal(mean, std_dev) 397 | # translate_z = np.random.normal(mean, std_dev) 398 | # translate = np.array([translate_x, translate_y, translate_z]) 399 | # viewpoint_cam.world_view_transform = torch.tensor(getWorld2View2(viewpoint_cam.R, viewpoint_cam.T, translate)).transpose(0, 1).cuda() 400 | # viewpoint_cam.full_proj_transform = (viewpoint_cam.world_view_transform.unsqueeze(0).bmm(viewpoint_cam.projection_matrix.unsqueeze(0))).squeeze(0) 401 | # viewpoint_cam.camera_center = viewpoint_cam.world_view_transform.inverse()[3, :3] 402 | # return viewpoint_cam 403 | 404 | def get_rotation_matrix(axis, angle): 405 | """ 406 | Create a rotation matrix for a given axis (x, y, or z) and angle. 
407 | """ 408 | axis = axis.lower() 409 | cos_angle = np.cos(angle) 410 | sin_angle = np.sin(angle) 411 | 412 | if axis == 'x': 413 | return np.array([ 414 | [1, 0, 0], 415 | [0, cos_angle, -sin_angle], 416 | [0, sin_angle, cos_angle] 417 | ]) 418 | elif axis == 'y': 419 | return np.array([ 420 | [cos_angle, 0, sin_angle], 421 | [0, 1, 0], 422 | [-sin_angle, 0, cos_angle] 423 | ]) 424 | elif axis == 'z': 425 | return np.array([ 426 | [cos_angle, -sin_angle, 0], 427 | [sin_angle, cos_angle, 0], 428 | [0, 0, 1] 429 | ]) 430 | else: 431 | raise ValueError("Invalid axis. Choose from 'x', 'y', 'z'.") 432 | 433 | 434 | 435 | def gaussian_poses(viewpoint_cam, mean=0, std_dev_translation=0.03, std_dev_rotation=0.01): 436 | # Translation Perturbation 437 | translate_x = np.random.normal(mean, std_dev_translation) 438 | translate_y = np.random.normal(mean, std_dev_translation) 439 | translate_z = np.random.normal(mean, std_dev_translation) 440 | translate = np.array([translate_x, translate_y, translate_z]) 441 | 442 | # Rotation Perturbation 443 | angle_x = np.random.normal(mean, std_dev_rotation) 444 | angle_y = np.random.normal(mean, std_dev_rotation) 445 | angle_z = np.random.normal(mean, std_dev_rotation) 446 | 447 | rot_x = get_rotation_matrix('x', angle_x) 448 | rot_y = get_rotation_matrix('y', angle_y) 449 | rot_z = get_rotation_matrix('z', angle_z) 450 | 451 | # Combined Rotation Matrix 452 | combined_rot = np.matmul(rot_z, np.matmul(rot_y, rot_x)) 453 | 454 | # Apply Rotation to Camera 455 | rotated_R = np.matmul(viewpoint_cam.R, combined_rot) 456 | 457 | # Update Camera Transformation 458 | viewpoint_cam.world_view_transform = torch.tensor(getWorld2View2(rotated_R, viewpoint_cam.T, translate)).transpose(0, 1).cuda() 459 | viewpoint_cam.full_proj_transform = (viewpoint_cam.world_view_transform.unsqueeze(0).bmm(viewpoint_cam.projection_matrix.unsqueeze(0))).squeeze(0) 460 | viewpoint_cam.camera_center = viewpoint_cam.world_view_transform.inverse()[3, :3] 461 | 462 | return viewpoint_cam 463 | 464 | 465 | 466 | def circular_poses(viewpoint_cam, radius, angle=0.0): 467 | translate_x = radius * np.cos(angle) 468 | translate_y = radius * np.sin(angle) 469 | translate_z = 0 470 | translate = np.array([translate_x, translate_y, translate_z]) 471 | viewpoint_cam.world_view_transform = torch.tensor(getWorld2View2(viewpoint_cam.R, viewpoint_cam.T, translate)).transpose(0, 1).cuda() 472 | viewpoint_cam.full_proj_transform = (viewpoint_cam.world_view_transform.unsqueeze(0).bmm(viewpoint_cam.projection_matrix.unsqueeze(0))).squeeze(0) 473 | viewpoint_cam.camera_center = viewpoint_cam.world_view_transform.inverse()[3, :3] 474 | 475 | return viewpoint_cam 476 | 477 | def generate_spherical_sample_path(views, azimuthal_rots=1, polar_rots=0.75, N=10): 478 | poses = [] 479 | for view in views: 480 | tmp_view = np.eye(4) 481 | tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1) 482 | tmp_view = np.linalg.inv(tmp_view) 483 | tmp_view[:, 1:3] *= -1 484 | poses.append(tmp_view) 485 | focal = get_focal(view) 486 | poses = np.stack(poses, 0) 487 | # ic(min_focal, max_focal) 488 | 489 | c2w = poses_avg(poses) 490 | up = normalize(poses[:, :3, 1].sum(0)) 491 | rads = np.percentile(np.abs(poses[:, :3, 3]), 90, 0) 492 | rads = np.array(list(rads) + [1.0]) 493 | ic(rads) 494 | render_poses = [] 495 | focal_range = np.linspace(0.5, 3, N **2+1) 496 | index = 0 497 | # Modify this loop to include phi 498 | for theta in np.linspace(0.0, 2.0 * np.pi * azimuthal_rots, N + 1)[:-1]: 499 | for phi in np.linspace(0.0, 
np.pi * polar_rots, N + 1)[:-1]: 500 | # Modify these lines to use spherical coordinates for c 501 | c = np.dot( 502 | c2w[:3, :4], 503 | rads * np.array([ 504 | np.sin(phi) * np.cos(theta), 505 | np.sin(phi) * np.sin(theta), 506 | np.cos(phi), 507 | 1.0 508 | ]) 509 | ) 510 | 511 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal_range[index], 1.0]))) 512 | render_pose = np.eye(4) 513 | render_pose[:3] = viewmatrix(z, up, c) 514 | render_pose[:3, 1:3] *= -1 515 | render_poses.append(np.linalg.inv(render_pose)) 516 | index += 1 517 | return render_poses 518 | 519 | 520 | def generate_spiral_path(views, focal=1.5, zrate= 0, rots=1, N=600): 521 | poses = [] 522 | focal = 0 523 | for view in views: 524 | tmp_view = np.eye(4) 525 | tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1) 526 | tmp_view = np.linalg.inv(tmp_view) 527 | tmp_view[:, 1:3] *= -1 528 | poses.append(tmp_view) 529 | focal += get_focal(views[0]) 530 | poses = np.stack(poses, 0) 531 | 532 | 533 | c2w = poses_avg(poses) 534 | up = normalize(poses[:, :3, 1].sum(0)) 535 | 536 | # Get radii for spiral path 537 | rads = np.percentile(np.abs(poses[:, :3, 3]), 90, 0) 538 | render_poses = [] 539 | 540 | rads = np.array(list(rads) + [1.0]) 541 | focal /= len(views) 542 | 543 | for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]: 544 | c = np.dot( 545 | c2w[:3, :4], 546 | np.array([np.cos(theta), -np.sin(theta),-np.sin(theta * zrate), 1.0]) * rads, 547 | ) 548 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0]))) 549 | 550 | render_pose = np.eye(4) 551 | render_pose[:3] = viewmatrix(z, up, c) 552 | render_pose[:3, 1:3] *= -1 553 | render_poses.append(np.linalg.inv(render_pose)) 554 | return render_poses -------------------------------------------------------------------------------- /utils/read_write_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | import os 32 | import collections 33 | import numpy as np 34 | import struct 35 | import argparse 36 | 37 | 38 | CameraModel = collections.namedtuple( 39 | "CameraModel", ["model_id", "model_name", "num_params"] 40 | ) 41 | Camera = collections.namedtuple( 42 | "Camera", ["id", "model", "width", "height", "params"] 43 | ) 44 | BaseImage = collections.namedtuple( 45 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] 46 | ) 47 | Point3D = collections.namedtuple( 48 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] 49 | ) 50 | 51 | 52 | class Image(BaseImage): 53 | def qvec2rotmat(self): 54 | return qvec2rotmat(self.qvec) 55 | 56 | 57 | CAMERA_MODELS = { 58 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 59 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 60 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 61 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 62 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 63 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 64 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 65 | CameraModel(model_id=7, model_name="FOV", num_params=5), 66 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 67 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 68 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), 69 | } 70 | CAMERA_MODEL_IDS = dict( 71 | [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] 72 | ) 73 | CAMERA_MODEL_NAMES = dict( 74 | [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS] 75 | ) 76 | 77 | 78 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 79 | """Read and unpack the next bytes from a binary file. 80 | :param fid: 81 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 82 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 83 | :param endian_character: Any of {@, =, <, >, !} 84 | :return: Tuple of read and unpacked values. 85 | """ 86 | data = fid.read(num_bytes) 87 | return struct.unpack(endian_character + format_char_sequence, data) 88 | 89 | 90 | def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): 91 | """pack and write to a binary file. 92 | :param fid: 93 | :param data: data to send, if multiple elements are sent at the same time, 94 | they should be encapsuled either in a list or a tuple 95 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 
96 | should be the same length as the data list or tuple 97 | :param endian_character: Any of {@, =, <, >, !} 98 | """ 99 | if isinstance(data, (list, tuple)): 100 | bytes = struct.pack(endian_character + format_char_sequence, *data) 101 | else: 102 | bytes = struct.pack(endian_character + format_char_sequence, data) 103 | fid.write(bytes) 104 | 105 | 106 | def read_cameras_text(path): 107 | """ 108 | see: src/colmap/scene/reconstruction.cc 109 | void Reconstruction::WriteCamerasText(const std::string& path) 110 | void Reconstruction::ReadCamerasText(const std::string& path) 111 | """ 112 | cameras = {} 113 | with open(path, "r") as fid: 114 | while True: 115 | line = fid.readline() 116 | if not line: 117 | break 118 | line = line.strip() 119 | if len(line) > 0 and line[0] != "#": 120 | elems = line.split() 121 | camera_id = int(elems[0]) 122 | model = elems[1] 123 | width = int(elems[2]) 124 | height = int(elems[3]) 125 | params = np.array(tuple(map(float, elems[4:]))) 126 | cameras[camera_id] = Camera( 127 | id=camera_id, 128 | model=model, 129 | width=width, 130 | height=height, 131 | params=params, 132 | ) 133 | return cameras 134 | 135 | 136 | def read_cameras_binary(path_to_model_file): 137 | """ 138 | see: src/colmap/scene/reconstruction.cc 139 | void Reconstruction::WriteCamerasBinary(const std::string& path) 140 | void Reconstruction::ReadCamerasBinary(const std::string& path) 141 | """ 142 | cameras = {} 143 | with open(path_to_model_file, "rb") as fid: 144 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 145 | for _ in range(num_cameras): 146 | camera_properties = read_next_bytes( 147 | fid, num_bytes=24, format_char_sequence="iiQQ" 148 | ) 149 | camera_id = camera_properties[0] 150 | model_id = camera_properties[1] 151 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 152 | width = camera_properties[2] 153 | height = camera_properties[3] 154 | num_params = CAMERA_MODEL_IDS[model_id].num_params 155 | params = read_next_bytes( 156 | fid, 157 | num_bytes=8 * num_params, 158 | format_char_sequence="d" * num_params, 159 | ) 160 | cameras[camera_id] = Camera( 161 | id=camera_id, 162 | model=model_name, 163 | width=width, 164 | height=height, 165 | params=np.array(params), 166 | ) 167 | assert len(cameras) == num_cameras 168 | return cameras 169 | 170 | 171 | def write_cameras_text(cameras, path): 172 | """ 173 | see: src/colmap/scene/reconstruction.cc 174 | void Reconstruction::WriteCamerasText(const std::string& path) 175 | void Reconstruction::ReadCamerasText(const std::string& path) 176 | """ 177 | HEADER = ( 178 | "# Camera list with one line of data per camera:\n" 179 | + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" 180 | + "# Number of cameras: {}\n".format(len(cameras)) 181 | ) 182 | with open(path, "w") as fid: 183 | fid.write(HEADER) 184 | for _, cam in cameras.items(): 185 | to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] 186 | line = " ".join([str(elem) for elem in to_write]) 187 | fid.write(line + "\n") 188 | 189 | 190 | def write_cameras_binary(cameras, path_to_model_file): 191 | """ 192 | see: src/colmap/scene/reconstruction.cc 193 | void Reconstruction::WriteCamerasBinary(const std::string& path) 194 | void Reconstruction::ReadCamerasBinary(const std::string& path) 195 | """ 196 | with open(path_to_model_file, "wb") as fid: 197 | write_next_bytes(fid, len(cameras), "Q") 198 | for _, cam in cameras.items(): 199 | model_id = CAMERA_MODEL_NAMES[cam.model].model_id 200 | camera_properties = [cam.id, model_id, cam.width, cam.height] 201 | 
write_next_bytes(fid, camera_properties, "iiQQ") 202 | for p in cam.params: 203 | write_next_bytes(fid, float(p), "d") 204 | return cameras 205 | 206 | 207 | def read_images_text(path): 208 | """ 209 | see: src/colmap/scene/reconstruction.cc 210 | void Reconstruction::ReadImagesText(const std::string& path) 211 | void Reconstruction::WriteImagesText(const std::string& path) 212 | """ 213 | images = {} 214 | with open(path, "r") as fid: 215 | while True: 216 | line = fid.readline() 217 | if not line: 218 | break 219 | line = line.strip() 220 | if len(line) > 0 and line[0] != "#": 221 | elems = line.split() 222 | image_id = int(elems[0]) 223 | qvec = np.array(tuple(map(float, elems[1:5]))) 224 | tvec = np.array(tuple(map(float, elems[5:8]))) 225 | camera_id = int(elems[8]) 226 | image_name = elems[9] 227 | elems = fid.readline().split() 228 | xys = np.column_stack( 229 | [ 230 | tuple(map(float, elems[0::3])), 231 | tuple(map(float, elems[1::3])), 232 | ] 233 | ) 234 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 235 | images[image_id] = Image( 236 | id=image_id, 237 | qvec=qvec, 238 | tvec=tvec, 239 | camera_id=camera_id, 240 | name=image_name, 241 | xys=xys, 242 | point3D_ids=point3D_ids, 243 | ) 244 | return images 245 | 246 | 247 | def read_images_binary(path_to_model_file): 248 | """ 249 | see: src/colmap/scene/reconstruction.cc 250 | void Reconstruction::ReadImagesBinary(const std::string& path) 251 | void Reconstruction::WriteImagesBinary(const std::string& path) 252 | """ 253 | images = {} 254 | with open(path_to_model_file, "rb") as fid: 255 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 256 | for _ in range(num_reg_images): 257 | binary_image_properties = read_next_bytes( 258 | fid, num_bytes=64, format_char_sequence="idddddddi" 259 | ) 260 | image_id = binary_image_properties[0] 261 | qvec = np.array(binary_image_properties[1:5]) 262 | tvec = np.array(binary_image_properties[5:8]) 263 | camera_id = binary_image_properties[8] 264 | image_name = "" 265 | current_char = read_next_bytes(fid, 1, "c")[0] 266 | while current_char != b"\x00": # look for the ASCII 0 entry 267 | image_name += current_char.decode("utf-8") 268 | current_char = read_next_bytes(fid, 1, "c")[0] 269 | num_points2D = read_next_bytes( 270 | fid, num_bytes=8, format_char_sequence="Q" 271 | )[0] 272 | x_y_id_s = read_next_bytes( 273 | fid, 274 | num_bytes=24 * num_points2D, 275 | format_char_sequence="ddq" * num_points2D, 276 | ) 277 | xys = np.column_stack( 278 | [ 279 | tuple(map(float, x_y_id_s[0::3])), 280 | tuple(map(float, x_y_id_s[1::3])), 281 | ] 282 | ) 283 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 284 | images[image_id] = Image( 285 | id=image_id, 286 | qvec=qvec, 287 | tvec=tvec, 288 | camera_id=camera_id, 289 | name=image_name, 290 | xys=xys, 291 | point3D_ids=point3D_ids, 292 | ) 293 | return images 294 | 295 | 296 | def write_images_text(images, path): 297 | """ 298 | see: src/colmap/scene/reconstruction.cc 299 | void Reconstruction::ReadImagesText(const std::string& path) 300 | void Reconstruction::WriteImagesText(const std::string& path) 301 | """ 302 | if len(images) == 0: 303 | mean_observations = 0 304 | else: 305 | mean_observations = sum( 306 | (len(img.point3D_ids) for _, img in images.items()) 307 | ) / len(images) 308 | HEADER = ( 309 | "# Image list with two lines of data per image:\n" 310 | + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" 311 | + "# POINTS2D[] as (X, Y, POINT3D_ID)\n" 312 | + "# Number of images: {}, mean observations per image: 
{}\n".format( 313 | len(images), mean_observations 314 | ) 315 | ) 316 | 317 | with open(path, "w") as fid: 318 | fid.write(HEADER) 319 | for _, img in images.items(): 320 | image_header = [ 321 | img.id, 322 | *img.qvec, 323 | *img.tvec, 324 | img.camera_id, 325 | img.name, 326 | ] 327 | first_line = " ".join(map(str, image_header)) 328 | fid.write(first_line + "\n") 329 | 330 | points_strings = [] 331 | for xy, point3D_id in zip(img.xys, img.point3D_ids): 332 | points_strings.append(" ".join(map(str, [*xy, point3D_id]))) 333 | fid.write(" ".join(points_strings) + "\n") 334 | 335 | 336 | def write_images_binary(images, path_to_model_file): 337 | """ 338 | see: src/colmap/scene/reconstruction.cc 339 | void Reconstruction::ReadImagesBinary(const std::string& path) 340 | void Reconstruction::WriteImagesBinary(const std::string& path) 341 | """ 342 | with open(path_to_model_file, "wb") as fid: 343 | write_next_bytes(fid, len(images), "Q") 344 | for _, img in images.items(): 345 | write_next_bytes(fid, img.id, "i") 346 | write_next_bytes(fid, img.qvec.tolist(), "dddd") 347 | write_next_bytes(fid, img.tvec.tolist(), "ddd") 348 | write_next_bytes(fid, img.camera_id, "i") 349 | for char in img.name: 350 | write_next_bytes(fid, char.encode("utf-8"), "c") 351 | write_next_bytes(fid, b"\x00", "c") 352 | write_next_bytes(fid, len(img.point3D_ids), "Q") 353 | for xy, p3d_id in zip(img.xys, img.point3D_ids): 354 | write_next_bytes(fid, [*xy, p3d_id], "ddq") 355 | 356 | 357 | def read_points3D_text(path): 358 | """ 359 | see: src/colmap/scene/reconstruction.cc 360 | void Reconstruction::ReadPoints3DText(const std::string& path) 361 | void Reconstruction::WritePoints3DText(const std::string& path) 362 | """ 363 | points3D = {} 364 | with open(path, "r") as fid: 365 | while True: 366 | line = fid.readline() 367 | if not line: 368 | break 369 | line = line.strip() 370 | if len(line) > 0 and line[0] != "#": 371 | elems = line.split() 372 | point3D_id = int(elems[0]) 373 | xyz = np.array(tuple(map(float, elems[1:4]))) 374 | rgb = np.array(tuple(map(int, elems[4:7]))) 375 | error = float(elems[7]) 376 | image_ids = np.array(tuple(map(int, elems[8::2]))) 377 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 378 | points3D[point3D_id] = Point3D( 379 | id=point3D_id, 380 | xyz=xyz, 381 | rgb=rgb, 382 | error=error, 383 | image_ids=image_ids, 384 | point2D_idxs=point2D_idxs, 385 | ) 386 | return points3D 387 | 388 | 389 | def read_points3D_binary(path_to_model_file): 390 | """ 391 | see: src/colmap/scene/reconstruction.cc 392 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 393 | void Reconstruction::WritePoints3DBinary(const std::string& path) 394 | """ 395 | points3D = {} 396 | with open(path_to_model_file, "rb") as fid: 397 | num_points = read_next_bytes(fid, 8, "Q")[0] 398 | for _ in range(num_points): 399 | binary_point_line_properties = read_next_bytes( 400 | fid, num_bytes=43, format_char_sequence="QdddBBBd" 401 | ) 402 | point3D_id = binary_point_line_properties[0] 403 | xyz = np.array(binary_point_line_properties[1:4]) 404 | rgb = np.array(binary_point_line_properties[4:7]) 405 | error = np.array(binary_point_line_properties[7]) 406 | track_length = read_next_bytes( 407 | fid, num_bytes=8, format_char_sequence="Q" 408 | )[0] 409 | track_elems = read_next_bytes( 410 | fid, 411 | num_bytes=8 * track_length, 412 | format_char_sequence="ii" * track_length, 413 | ) 414 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 415 | point2D_idxs = np.array(tuple(map(int, 
track_elems[1::2]))) 416 | points3D[point3D_id] = Point3D( 417 | id=point3D_id, 418 | xyz=xyz, 419 | rgb=rgb, 420 | error=error, 421 | image_ids=image_ids, 422 | point2D_idxs=point2D_idxs, 423 | ) 424 | return points3D 425 | 426 | 427 | def write_points3D_text(points3D, path): 428 | """ 429 | see: src/colmap/scene/reconstruction.cc 430 | void Reconstruction::ReadPoints3DText(const std::string& path) 431 | void Reconstruction::WritePoints3DText(const std::string& path) 432 | """ 433 | if len(points3D) == 0: 434 | mean_track_length = 0 435 | else: 436 | mean_track_length = sum( 437 | (len(pt.image_ids) for _, pt in points3D.items()) 438 | ) / len(points3D) 439 | HEADER = ( 440 | "# 3D point list with one line of data per point:\n" 441 | + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" 442 | + "# Number of points: {}, mean track length: {}\n".format( 443 | len(points3D), mean_track_length 444 | ) 445 | ) 446 | 447 | with open(path, "w") as fid: 448 | fid.write(HEADER) 449 | for _, pt in points3D.items(): 450 | point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] 451 | fid.write(" ".join(map(str, point_header)) + " ") 452 | track_strings = [] 453 | for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): 454 | track_strings.append(" ".join(map(str, [image_id, point2D]))) 455 | fid.write(" ".join(track_strings) + "\n") 456 | 457 | 458 | def write_points3D_binary(points3D, path_to_model_file): 459 | """ 460 | see: src/colmap/scene/reconstruction.cc 461 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 462 | void Reconstruction::WritePoints3DBinary(const std::string& path) 463 | """ 464 | with open(path_to_model_file, "wb") as fid: 465 | write_next_bytes(fid, len(points3D), "Q") 466 | for _, pt in points3D.items(): 467 | write_next_bytes(fid, pt.id, "Q") 468 | write_next_bytes(fid, pt.xyz.tolist(), "ddd") 469 | write_next_bytes(fid, pt.rgb.tolist(), "BBB") 470 | write_next_bytes(fid, pt.error, "d") 471 | track_length = pt.image_ids.shape[0] 472 | write_next_bytes(fid, track_length, "Q") 473 | for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): 474 | write_next_bytes(fid, [image_id, point2D_id], "ii") 475 | 476 | 477 | def detect_model_format(path, ext): 478 | if ( 479 | os.path.isfile(os.path.join(path, "cameras" + ext)) 480 | and os.path.isfile(os.path.join(path, "images" + ext)) 481 | and os.path.isfile(os.path.join(path, "points3D" + ext)) 482 | ): 483 | print("Detected model format: '" + ext + "'") 484 | return True 485 | 486 | return False 487 | 488 | 489 | def read_model(path, ext=""): 490 | # try to detect the extension automatically 491 | if ext == "": 492 | if detect_model_format(path, ".bin"): 493 | ext = ".bin" 494 | elif detect_model_format(path, ".txt"): 495 | ext = ".txt" 496 | else: 497 | print("Provide model format: '.bin' or '.txt'") 498 | return 499 | 500 | if ext == ".txt": 501 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 502 | images = read_images_text(os.path.join(path, "images" + ext)) 503 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 504 | else: 505 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 506 | images = read_images_binary(os.path.join(path, "images" + ext)) 507 | points3D = read_points3D_binary(os.path.join(path, "points3D") + ext) 508 | return cameras, images, points3D 509 | 510 | 511 | def write_model(cameras, images, points3D, path, ext=".bin"): 512 | if ext == ".txt": 513 | write_cameras_text(cameras, os.path.join(path, "cameras" + 
ext)) 514 | write_images_text(images, os.path.join(path, "images" + ext)) 515 | write_points3D_text(points3D, os.path.join(path, "points3D") + ext) 516 | else: 517 | write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) 518 | write_images_binary(images, os.path.join(path, "images" + ext)) 519 | write_points3D_binary(points3D, os.path.join(path, "points3D") + ext) 520 | return cameras, images, points3D 521 | 522 | 523 | def qvec2rotmat(qvec): 524 | return np.array( 525 | [ 526 | [ 527 | 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 528 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 529 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], 530 | ], 531 | [ 532 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 533 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 534 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], 535 | ], 536 | [ 537 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 538 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 539 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, 540 | ], 541 | ] 542 | ) 543 | 544 | 545 | def rotmat2qvec(R): 546 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 547 | K = ( 548 | np.array( 549 | [ 550 | [Rxx - Ryy - Rzz, 0, 0, 0], 551 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 552 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 553 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], 554 | ] 555 | ) 556 | / 3.0 557 | ) 558 | eigvals, eigvecs = np.linalg.eigh(K) 559 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 560 | if qvec[0] < 0: 561 | qvec *= -1 562 | return qvec 563 | 564 | 565 | # def main(): 566 | # parser = argparse.ArgumentParser( 567 | # description="Read and write COLMAP binary and text models" 568 | # ) 569 | # parser.add_argument("--input_model", help="path to input model folder") 570 | # parser.add_argument( 571 | # "--input_format", 572 | # choices=[".bin", ".txt"], 573 | # help="input model format", 574 | # default="", 575 | # ) 576 | # parser.add_argument("--output_model", help="path to output model folder") 577 | # parser.add_argument( 578 | # "--output_format", 579 | # choices=[".bin", ".txt"], 580 | # help="outut model format", 581 | # default=".txt", 582 | # ) 583 | # args = parser.parse_args() 584 | 585 | # cameras, images, points3D = read_model( 586 | # path=args.input_model, ext=args.input_format 587 | # ) 588 | 589 | # print("num_cameras:", len(cameras)) 590 | # print("num_images:", len(images)) 591 | # print("num_points3D:", len(points3D)) 592 | 593 | # if args.output_model is not None: 594 | # write_model( 595 | # cameras, 596 | # images, 597 | # points3D, 598 | # path=args.output_model, 599 | # ext=args.output_format, 600 | # ) 601 | 602 | 603 | # if __name__ == "__main__": 604 | # main() 605 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 
11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | 
return sh * C0 + 0.5 -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory; equivalent to using `mkdir -p` on the command line. 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] # folder names are expected to end in "_<iteration>" 28 | return max(saved_iters) 29 | --------------------------------------------------------------------------------
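The commented-out main() at the end of utils/read_write_model.py already sketches the intended command-line usage. As a minimal sketch (assuming the script is run from the repository root; the dataset and output paths below are placeholders, not paths shipped with the repo), a COLMAP sparse reconstruction can be converted from binary to text with the read_model/write_model helpers above:

```python
import os
from utils.read_write_model import read_model, write_model

# Placeholder paths; point these at an actual COLMAP sparse reconstruction.
sparse_dir = "dataset/scene/sparse/0"
out_dir = "output/sparse_txt"

# ext="" lets read_model() auto-detect '.bin' vs '.txt' from the files present.
cameras, images, points3D = read_model(path=sparse_dir, ext="")
print("num_cameras:", len(cameras))
print("num_images:", len(images))
print("num_points3D:", len(points3D))

# Write the same reconstruction back out as cameras.txt / images.txt / points3D.txt.
os.makedirs(out_dir, exist_ok=True)
write_model(cameras, images, points3D, path=out_dir, ext=".txt")
```

Similarly, eval_sh in utils/sh_utils.py evaluates the hard-coded SH polynomials for degrees 0 through 4. A small shape sanity check, with tensor sizes chosen purely for illustration:

```python
import torch
from utils.sh_utils import eval_sh

# 1000 points, 3 color channels, (3 + 1) ** 2 = 16 SH coefficients per channel.
sh = torch.randn(1000, 3, 16)
dirs = torch.nn.functional.normalize(torch.randn(1000, 3), dim=-1)  # unit view directions
rgb = eval_sh(3, sh, dirs)  # -> tensor of shape [1000, 3]
```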