├── .gitignore
├── .gitmodules
├── LICENSE.md
├── README.md
├── arguments
│   └── __init__.py
├── assets
│   └── teaser.png
├── async_seele_render.py
├── convert.py
├── finetune.py
├── full_eval.py
├── gaussian_renderer
│   ├── __init__.py
│   └── network_gui.py
├── generate_cluster.py
├── lpipsPyTorch
│   ├── __init__.py
│   └── modules
│       ├── lpips.py
│       ├── networks.py
│       └── utils.py
├── metrics.py
├── render.py
├── render_video.py
├── requirements.txt
├── scene
│   ├── __init__.py
│   ├── cameras.py
│   ├── colmap_loader.py
│   ├── dataset_readers.py
│   └── gaussian_model.py
├── scripts
│   ├── generate_cluster.sh
│   ├── run_all.sh
│   ├── run_finetune.sh
│   ├── run_render.sh
│   ├── run_seele_render.sh
│   └── run_train.sh
├── seele_render.py
├── train.py
└── utils
    ├── camera_utils.py
    ├── general_utils.py
    ├── graphics_utils.py
    ├── image_utils.py
    ├── loss_utils.py
    ├── make_depth_scale.py
    ├── pose_utils.py
    ├── read_write_model.py
    ├── sh_utils.py
    └── system_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .vscode
3 | output
4 | build
5 | seele-gaussian-rasterization/diff_rast.egg-info
6 | seele-gaussian-rasterization/dist
7 | tensorboard_3d
8 | screenshots
9 | temps*
10 | output*
11 | dataset*
12 | .out
13 | .ipynb
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 |
2 | [submodule "SIBR_viewers"]
3 | path = SIBR_viewers
4 | url = https://gitlab.inria.fr/sibr/sibr_core.git
5 | [submodule "submodules/fused-ssim"]
6 | path = submodules/fused-ssim
7 | url = https://github.com/rahul-goel/fused-ssim.git
8 | [submodule "submodules/simple-knn"]
9 | path = submodules/simple-knn
10 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git
11 | [submodule "submodules/seele-gaussian-rasterization"]
12 | path = submodules/seele-gaussian-rasterization
13 | url = https://github.com/StoneSix16/seele-gaussian-rasterization
14 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 sjtu-mvclab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SeeLe: A Unified Acceleration Framework for Real-Time Gaussian Splatting
2 | | [🌍Webpage](https://seele-project.netlify.app/) | [📄Full Paper](https://arxiv.org/abs/2503.05168) | [🎥Video](https://github.com/user-attachments/assets/49cafdb6-5c8f-43cf-ab05-aa24a39ea1fc) |
3 |
4 | ![SeeLe teaser](assets/teaser.png)
5 |
6 | ## 🔍What is it?
7 | This repository provides the official implementation of **SeeLe**, a general acceleration framework for the [3D Gaussian Splatting (3DGS)](https://github.com/graphdeco-inria/gaussian-splatting) pipeline, specifically designed for resource-constrained mobile devices. Our framework achieves a **2.6× speedup** and **32.5% model reduction** while maintaining superior rendering quality compared to existing methods. On an NVIDIA AGX Orin mobile SoC, SeeLe achieves over **90 FPS**⚡, meeting the real-time requirements for VR applications.
8 |
9 | Here is a short demo video of our algorithm running on an NVIDIA AGX Orin SoC:
10 |
11 | https://github.com/user-attachments/assets/49cafdb6-5c8f-43cf-ab05-aa24a39ea1fc
12 |
13 | ## 🛠️ How to run?
14 | ### Installation
15 | To clone the repository:
16 | ```shell
17 | git clone https://github.com/SJTU-MVCLab/SeeLe.git --recursive && cd SeeLe
18 | ```
19 | To install requirements:
20 | ```shell
21 | conda create -n seele python=3.9
22 | conda activate seele
23 | # Example for CUDA 12.4:
24 | pip3 install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
25 | pip3 install -r requirements.txt
26 | ```
27 | **Note:** [PyTorch](https://pytorch.org/) installation varies by system. Please ensure you install the appropriate version for your hardware.
28 |
29 | ### Dataset
30 | We use datasets from **MipNeRF360** and **Tanks and Temples**, which can be downloaded from the authors' official [website](https://jonbarron.info/mipnerf360/). The datasets should be organized in the following structure:
31 | ```
32 | dataset
33 | └── seele
34 |     └── [bicycle|bonsai|counter|train|truck|playroom|drjohnson|...]
35 |         ├── images
36 |         └── sparse
37 | ```
38 |
39 | ## 🚀 Training and Evaluation
40 | This section provides detailed instructions on how to **train**, **cluster**, **fine-tune**, and **render** the model using our provided scripts. We also provide **standalone evaluation scripts** for assessing the trained model.
41 |
42 | ### 🔄 One-Click Pipeline: Run Everything at Once
43 | For convenience, you can use the `run_all.sh` script to **automate the entire process** from training to rendering in a single command:
44 | ```shell
45 | bash scripts/run_all.sh
46 | ```
47 | **Note:** By default, all scripts run on an example scene, "**Counter**", from **MipNeRF360**. If you want to train on other datasets, please modify the `datasets` variable in the script accordingly.
48 |
49 | ### 🏗️ Step-by-Step Training and Rendering
50 | #### 1. Train the 3DGS Model (30,000 Iterations)
51 | To train the **3D Gaussian Splatting (3DGS) model**, use:
52 | ```shell
53 | bash scripts/run_train.sh seele
54 | ```
55 |
56 | #### 2. Cluster the Trained Model
57 | Once training is complete, apply **k-means clustering** to the trained model with:
58 | ```shell
59 | bash scripts/generate_cluster.sh seele
60 | ```
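Under the hood, `generate_cluster.py` clusters the training cameras by pose (position plus orientation quaternion) with k-means, then records which Gaussians each cluster actually sees. Below is a minimal sketch of the camera-clustering part; the `c2w_matrices` list is an assumption, and `k = 24` mirrors the script's default:
```python
import numpy as np
from sklearn.cluster import KMeans
from scipy.spatial.transform import Rotation as Rot

def camera_feature(c2w):
    # 7-D feature per camera: world-space position + camera-to-world rotation
    # as a unit quaternion (same idea as generate_features_from_Rt).
    q = Rot.from_matrix(c2w[:3, :3]).as_quat(canonical=True)
    return np.concatenate([c2w[:3, 3], q])

# `c2w_matrices` is an assumed list of 4x4 camera-to-world matrices.
features = np.stack([camera_feature(m) for m in c2w_matrices])
kmeans = KMeans(n_clusters=24, random_state=42, n_init="auto").fit(features)
labels = kmeans.labels_            # cluster id per training view
centers = kmeans.cluster_centers_  # used later to find neighboring clusters
```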
61 |
62 | #### 3. Fine-Tune the Clustered Model
63 | After clustering, fine-tune the model for better optimization:
64 | ```shell
65 | bash scripts/run_finetune.sh seele
66 | ```
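Fine-tuning loops over the clusters produced in the previous step: for each cluster it restores only that cluster's subset of Gaussians from the full checkpoint, optimizes them against that cluster's training views, and writes a per-cluster point cloud to `clusters/finetune/`. A simplified sketch of that loop, following `finetune.py` (`model_path`, `checkpoint`, `gaussians`, `opt`, and `train_on_views` are assumed to be set up as in the script):
```python
import os
import joblib
import torch

cluster_data = joblib.load(os.path.join(model_path, "clusters", "clusters.pkl"))
K = len(cluster_data["cluster_viewpoint"])

for cid in range(K):
    gaussian_ids, lens = cluster_data["cluster_gaussians"][cid]
    view_ids = cluster_data["cluster_viewpoint"][cid].tolist()
    model_params, first_iter = torch.load(checkpoint, weights_only=False)
    # Restore only this cluster's Gaussians, fine-tune on its views, then save.
    gaussians.restore_models(model_params, (gaussian_ids, lens), opt)
    train_on_views(gaussians, view_ids)  # assumed per-cluster training loop
    torch.save(gaussians.capture_gaussians(),
               os.path.join(model_path, "clusters", "finetune", f"point_cloud_{cid}.pth"))
```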
67 |
68 | #### 4. Render the Final Output with SeeLe
69 | To generate the rendered images using the fine-tuned model, run:
70 | ```shell
71 | bash scripts/run_seele_render.sh seele
72 | ```
73 |
74 | ### 🎨 Evaluation
75 | After training and fine-tuning, you can **evaluate the model** using the following standalone scripts:
76 |
77 | #### 1. Render with `seele_render.py`
78 | Renders a **SeeLe** model with optional fine-tuning:
79 | ```shell
80 | python3 seele_render.py -m <model_path> [--load_finetune] [--debug]
81 | ```
82 | - **With `--load_finetune`**: Loads the **fine-tuned** model for improved rendering quality. Otherwise, loads the model **before fine-tuning** (output from `generate_cluster.py`).
83 | - **With `--debug`**: Prints the execution time for each rendered frame.
84 |
85 | #### 2. Asynchronous Rendering with `async_seele_render.py`
86 | Uses the **CUDA stream API** for **efficient memory management**, asynchronously prefetching the next cluster's fine-tuned Gaussian point cloud while the current one renders:
87 | ```shell
88 | python3 async_seele_render.py -m <model_path> [--debug]
89 | ```
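Concretely, while cluster *i* is rendered on the default stream, the Gaussians of cluster *i+1* are copied host-to-device on a second CUDA stream, and the buffers are swapped after each frame. Here is a minimal sketch of this double-buffering pattern in plain PyTorch (not the actual `GaussianStreamManager` internals; `clusters_cpu`, `labels`, `views`, and `render_view` are assumptions):
```python
import torch

copy_stream = torch.cuda.Stream()

def prefetch(tensors):
    # Issue non-blocking host->device copies on the side stream.
    # (They only overlap with rendering if the CPU tensors are pinned.)
    with torch.cuda.stream(copy_stream):
        return [t.cuda(non_blocking=True) for t in tensors]

current = prefetch(clusters_cpu[labels[0]])
torch.cuda.current_stream().wait_stream(copy_stream)

for i, view in enumerate(views):
    if i + 1 < len(views):
        nxt = prefetch(clusters_cpu[labels[i + 1]])        # preload next cluster
    image = render_view(view, current)                     # render current cluster
    torch.cuda.current_stream().wait_stream(copy_stream)   # copy must finish
    if i + 1 < len(views):
        current = nxt                                      # swap buffers
```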
90 |
91 | #### 3. Visualize in GUI with `render_video.py`
92 | Interactively preview rendered results in a GUI:
93 | ```shell
94 | python3 render_video.py -m <model_path> --use_gui [--load_seele]
95 | ```
96 | - **With `--load_seele`**: Loads the **fine-tuned SeeLe** model. Otherwise, loads the **original** model.
97 |
98 | ## 🏋️♂️ Validate with a Pretrained Model
99 | To verify the correctness of **SeeLe**, we provide an example (dataset and checkpoint) for evaluation. You can download it [here](https://drive.google.com/file/d/1xfqSLFSLvx5IrpEZU62dw7xm1YZHiyYu/view?usp=sharing). This example includes the following key components:
100 |
101 | - **clusters** — The fine-tuned **SeeLe** model.
102 | - **point_cloud** — The original **3DGS** checkpoint.
103 |
104 | You can use this checkpoint to test the pipeline and ensure everything is working correctly.
105 |
106 | ## 🙏 Acknowledgments
107 |
108 | Our work is largely based on the implementation of **[3DGS](https://github.com/graphdeco-inria/gaussian-splatting)**, with significant modifications and optimizations to improve performance for mobile devices. Our key improvements include:
109 |
110 | - **`submodules/seele-gaussian-rasterization`** — Optimized **[diff-gaussian-rasterization](https://github.com/graphdeco-inria/diff-gaussian-rasterization/tree/9c5c2028f6fbee2be239bc4c9421ff894fe4fbe0)** with the **Opti** and **CR** techniques.
111 | - **`generate_cluster.py`** — Implements **k-means clustering** to partition the scene into multiple clusters.
112 | - **`finetune.py`** — Fine-tunes each cluster separately and saves the trained models.
113 | - **`seele_render.py`** — A modified version of `render.py`, designed to **load and render SeeLe models**.
114 | - **`async_seele_render.py`** — Utilizes **CUDA stream API** for **asynchronous memory optimization** across different clusters.
115 | - **`render_video.py`** — Uses **pyglet** to render images in a GUI. The `--load_seele` option enables **SeeLe model rendering**.
116 |
117 | For more technical details, please refer to our [paper](https://arxiv.org/abs/2503.05168).
118 |
119 | ## 📬 Contact
120 | If you have any questions, feel free to reach out to:
121 |
122 | - **Xiaotong Huang** — [hxt0512@sjtu.edu.cn](mailto:hxt0512@sjtu.edu.cn)
123 | - **He Zhu** — [2394241800@qq.com](mailto:2394241800@qq.com)
124 |
125 | We appreciate your interest in **SeeLe**!
126 |
127 | ## 📖 Citation
128 | If you find this work helpful, please kindly consider citing our paper:
129 | ```
130 | @article{huang2025seele,
131 |   title={SeeLe: A Unified Acceleration Framework for Real-Time Gaussian Splatting},
132 |   author={Xiaotong Huang and He Zhu and Zihan Liu and Weikai Lin and Xiaohong Liu and Zhezhi He and Jingwen Leng and Minyi Guo and Yu Feng},
133 |   journal={arXiv preprint arXiv:2503.05168},
134 |   year={2025}
135 | }
136 | ```
137 |
--------------------------------------------------------------------------------
/arguments/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from argparse import ArgumentParser, Namespace
13 | import sys
14 | import os
15 |
16 | class GroupParams:
17 | pass
18 |
19 | class ParamGroup:
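# Attributes defined in subclasses become command line arguments. A leading
# underscore (e.g. _model_path) additionally creates a one-letter shorthand
# flag ("-m"); bool attributes turn into store_true flags, everything else
# keeps its Python type as the argparse type.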
20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
21 | group = parser.add_argument_group(name)
22 | for key, value in vars(self).items():
23 | shorthand = False
24 | if key.startswith("_"):
25 | shorthand = True
26 | key = key[1:]
27 | t = type(value)
28 | value = value if not fill_none else None
29 | if shorthand:
30 | if t == bool:
31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
32 | else:
33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
34 | else:
35 | if t == bool:
36 | group.add_argument("--" + key, default=value, action="store_true")
37 | else:
38 | group.add_argument("--" + key, default=value, type=t)
39 |
40 | def extract(self, args):
41 | group = GroupParams()
42 | for arg in vars(args).items():
43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
44 | setattr(group, arg[0], arg[1])
45 | return group
46 |
47 | class ModelParams(ParamGroup):
48 | def __init__(self, parser, sentinel=False):
49 | self.sh_degree = 3
50 | self._source_path = ""
51 | self._model_path = ""
52 | self._images = "images"
53 | self._depths = ""
54 | self._resolution = -1
55 | self._white_background = False
56 | self.train_test_exp = False
57 | self.data_device = "cuda"
58 | self.eval = False
59 | super().__init__(parser, "Loading Parameters", sentinel)
60 |
61 | def extract(self, args):
62 | g = super().extract(args)
63 | g.source_path = os.path.abspath(g.source_path)
64 | return g
65 |
66 | class PipelineParams(ParamGroup):
67 | def __init__(self, parser):
68 | self.convert_SHs_python = False
69 | self.compute_cov3D_python = False
70 | self.debug = False
71 | self.antialiasing = False
72 | super().__init__(parser, "Pipeline Parameters")
73 |
74 | class OptimizationParams(ParamGroup):
75 | def __init__(self, parser):
76 | self.iterations = 30_000
77 | self.position_lr_init = 0.00016
78 | self.position_lr_final = 0.0000016
79 | self.position_lr_delay_mult = 0.01
80 | self.position_lr_max_steps = 30_000
81 | self.feature_lr = 0.0025
82 | self.opacity_lr = 0.025
83 | self.scaling_lr = 0.005
84 | self.rotation_lr = 0.001
85 | self.exposure_lr_init = 0.01
86 | self.exposure_lr_final = 0.001
87 | self.exposure_lr_delay_steps = 0
88 | self.exposure_lr_delay_mult = 0.0
89 | self.percent_dense = 0.01
90 | self.lambda_dssim = 0.2
91 | self.densification_interval = 100
92 | self.opacity_reset_interval = 3000
93 | self.densify_from_iter = 500
94 | self.densify_until_iter = 15_000
95 | self.densify_grad_threshold = 0.0002
96 | self.depth_l1_weight_init = 1.0
97 | self.depth_l1_weight_final = 0.01
98 | self.random_background = False
99 | self.optimizer_type = "default"
100 | super().__init__(parser, "Optimization Parameters")
101 |
102 | def get_combined_args(parser : ArgumentParser):
103 | cmdlne_string = sys.argv[1:]
104 | cfgfile_string = "Namespace()"
105 | args_cmdline = parser.parse_args(cmdlne_string)
106 |
107 | try:
108 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
109 | print("Looking for config file in", cfgfilepath)
110 | with open(cfgfilepath) as cfg_file:
111 | print("Config file found: {}".format(cfgfilepath))
112 | cfgfile_string = cfg_file.read()
113 | except TypeError:
114 | print("Config file not found, using command line arguments only")
115 | pass
116 | args_cfgfile = eval(cfgfile_string)
117 |
118 | merged_dict = vars(args_cfgfile).copy()
119 | for k,v in vars(args_cmdline).items():
120 | if v != None:
121 | merged_dict[k] = v
122 | return Namespace(**merged_dict)
123 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-MVCLab/SeeLe/867c009c7da8fd6c497df47985b41d60cdc4f4e0/assets/teaser.png
--------------------------------------------------------------------------------
/async_seele_render.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import numpy as np
12 | import joblib
13 | import torch
14 | from scene import Scene
15 | import os
16 | from tqdm import tqdm
17 | from os import makedirs
18 | from gaussian_renderer import render
19 | import torchvision
20 | from utils.general_utils import safe_state
21 | from argparse import ArgumentParser
22 | from arguments import ModelParams, PipelineParams, get_combined_args
23 | # from gaussian_renderer import GaussianModel
24 | from gaussian_renderer import GaussianModel, GaussianStreamManager
25 | try:
26 | from diff_gaussian_rasterization import SparseGaussianAdam
27 | SPARSE_ADAM_AVAILABLE = True
28 | except:
29 | SPARSE_ADAM_AVAILABLE = False
30 |
31 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background, train_test_exp, separate_sh, args):
32 | # Initialize paths and configuration
33 | render_path = os.path.join(model_path, name, f"ours_{iteration}", "renders")
34 | gts_path = os.path.join(model_path, name, f"ours_{iteration}", "gt")
35 | makedirs(render_path, exist_ok=True)
36 | makedirs(gts_path, exist_ok=True)
37 |
38 | # Load cluster data
39 | cluster_data = joblib.load(os.path.join(model_path, "clusters", "clusters.pkl"))
40 | K = len(cluster_data["cluster_viewpoint"])
41 |
42 | # Load all Gaussians to CPU
43 | cluster_gaussians = [
44 | torch.load(os.path.join(model_path, f"clusters/finetune/point_cloud_{cid}.pth"), map_location="cpu")
45 | for cid in range(K)
46 | ]
47 |
48 | labels = cluster_data[f"{name}_labels"]
49 |
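# GaussianStreamManager overlaps data transfer with rendering: preload_next()
# issues the next cluster's host->device copy on a side CUDA stream
# (load_stream), get_current() returns the cluster that is resident on the GPU,
# and switch_gaussians() swaps to the preloaded cluster after each frame.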
50 | stream_manager = GaussianStreamManager(
51 | cluster_gaussians=cluster_gaussians,
52 | initial_cid=labels[0]
53 | )
54 |
55 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
56 | if idx + 1 < len(views):
57 | next_cid = labels[idx+1]
58 | stream_manager.preload_next(next_cid)
59 |
60 | gaussians.restore_gaussians(stream_manager.get_current())
61 |
62 | rendering = render(
63 | view, gaussians, pipeline, background,
64 | use_trained_exp=train_test_exp,
65 | separate_sh=separate_sh,
66 | rasterizer_type="CR"
67 | )["render"]
68 |
69 | torch.cuda.current_stream().wait_stream(stream_manager.load_stream)
70 | gt = view.original_image[0:3, :, :]
71 | if args.train_test_exp:
72 | rendering = rendering[..., rendering.shape[-1]//2:]
73 | gt = gt[..., gt.shape[-1]//2:]
74 |
75 | torchvision.utils.save_image(rendering, os.path.join(render_path, f"{idx:05d}.png"))
76 | torchvision.utils.save_image(gt, os.path.join(gts_path, f"{idx:05d}.png"))
77 |
78 | stream_manager.switch_gaussians()
79 |
80 | stream_manager.cleanup()
81 |
82 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, separate_sh: bool, args: ArgumentParser):
83 | with torch.no_grad():
84 | gaussians = GaussianModel(dataset.sh_degree)
85 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
86 |
87 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
88 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
89 |
90 | if not skip_train:
91 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh, args)
92 |
93 | if not skip_test:
94 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh, args)
95 |
96 | if __name__ == "__main__":
97 | # Set up command line argument parser
98 | parser = ArgumentParser(description="Testing script parameters")
99 | model = ModelParams(parser, sentinel=True)
100 | pipeline = PipelineParams(parser)
101 | parser.add_argument("--iteration", default=-1, type=int)
102 | parser.add_argument("--skip_train", action="store_true")
103 | parser.add_argument("--skip_test", action="store_true")
104 | parser.add_argument("--quiet", action="store_true")
105 | args = get_combined_args(parser)
106 | args.data_device = 'cpu'
107 | print("Rendering " + args.model_path)
108 | # Initialize system state (RNG)
109 | safe_state(args.quiet)
110 |
111 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, SPARSE_ADAM_AVAILABLE, args)
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import logging
14 | from argparse import ArgumentParser
15 | import shutil
16 |
17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository.
18 | parser = ArgumentParser("Colmap converter")
19 | parser.add_argument("--no_gpu", action='store_true')
20 | parser.add_argument("--skip_matching", action='store_true')
21 | parser.add_argument("--source_path", "-s", required=True, type=str)
22 | parser.add_argument("--camera", default="OPENCV", type=str)
23 | parser.add_argument("--colmap_executable", default="", type=str)
24 | parser.add_argument("--resize", action="store_true")
25 | parser.add_argument("--magick_executable", default="", type=str)
26 | args = parser.parse_args()
27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap"
28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick"
29 | use_gpu = 1 if not args.no_gpu else 0
30 |
31 | if not args.skip_matching:
32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True)
33 |
34 | ## Feature extraction
35 | feat_extraction_cmd = colmap_command + " feature_extractor "\
36 | "--database_path " + args.source_path + "/distorted/database.db \
37 | --image_path " + args.source_path + "/input \
38 | --ImageReader.single_camera 1 \
39 | --ImageReader.camera_model " + args.camera + " \
40 | --SiftExtraction.use_gpu " + str(use_gpu)
41 | exit_code = os.system(feat_extraction_cmd)
42 | if exit_code != 0:
43 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.")
44 | exit(exit_code)
45 |
46 | ## Feature matching
47 | feat_matching_cmd = colmap_command + " exhaustive_matcher \
48 | --database_path " + args.source_path + "/distorted/database.db \
49 | --SiftMatching.use_gpu " + str(use_gpu)
50 | exit_code = os.system(feat_matching_cmd)
51 | if exit_code != 0:
52 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.")
53 | exit(exit_code)
54 |
55 | ### Bundle adjustment
56 | # The default Mapper tolerance is unnecessarily large,
57 | # decreasing it speeds up bundle adjustment steps.
58 | mapper_cmd = (colmap_command + " mapper \
59 | --database_path " + args.source_path + "/distorted/database.db \
60 | --image_path " + args.source_path + "/input \
61 | --output_path " + args.source_path + "/distorted/sparse \
62 | --Mapper.ba_global_function_tolerance=0.000001")
63 | exit_code = os.system(mapper_cmd)
64 | if exit_code != 0:
65 | logging.error(f"Mapper failed with code {exit_code}. Exiting.")
66 | exit(exit_code)
67 |
68 | ### Image undistortion
69 | ## We need to undistort our images into ideal pinhole intrinsics.
70 | img_undist_cmd = (colmap_command + " image_undistorter \
71 | --image_path " + args.source_path + "/input \
72 | --input_path " + args.source_path + "/distorted/sparse/0 \
73 | --output_path " + args.source_path + "\
74 | --output_type COLMAP")
75 | exit_code = os.system(img_undist_cmd)
76 | if exit_code != 0:
77 | logging.error(f"Mapper failed with code {exit_code}. Exiting.")
78 | exit(exit_code)
79 |
80 | files = os.listdir(args.source_path + "/sparse")
81 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
82 | # Copy each file from the source directory to the destination directory
83 | for file in files:
84 | if file == '0':
85 | continue
86 | source_file = os.path.join(args.source_path, "sparse", file)
87 | destination_file = os.path.join(args.source_path, "sparse", "0", file)
88 | shutil.move(source_file, destination_file)
89 |
90 | if(args.resize):
91 | print("Copying and resizing...")
92 |
93 | # Resize images.
94 | os.makedirs(args.source_path + "/images_2", exist_ok=True)
95 | os.makedirs(args.source_path + "/images_4", exist_ok=True)
96 | os.makedirs(args.source_path + "/images_8", exist_ok=True)
97 | # Get the list of files in the source directory
98 | files = os.listdir(args.source_path + "/images")
99 | # Copy each file from the source directory to the destination directory
100 | for file in files:
101 | source_file = os.path.join(args.source_path, "images", file)
102 |
103 | destination_file = os.path.join(args.source_path, "images_2", file)
104 | shutil.copy2(source_file, destination_file)
105 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file)
106 | if exit_code != 0:
107 | logging.error(f"50% resize failed with code {exit_code}. Exiting.")
108 | exit(exit_code)
109 |
110 | destination_file = os.path.join(args.source_path, "images_4", file)
111 | shutil.copy2(source_file, destination_file)
112 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file)
113 | if exit_code != 0:
114 | logging.error(f"25% resize failed with code {exit_code}. Exiting.")
115 | exit(exit_code)
116 |
117 | destination_file = os.path.join(args.source_path, "images_8", file)
118 | shutil.copy2(source_file, destination_file)
119 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file)
120 | if exit_code != 0:
121 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.")
122 | exit(exit_code)
123 |
124 | print("Done.")
125 |
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import torch
14 | import joblib
15 | from random import randint
16 | from utils.loss_utils import l1_loss, ssim
17 | from gaussian_renderer import render, network_gui
18 | import sys
19 | from scene import Scene, GaussianModel
20 | from utils.general_utils import safe_state, get_expon_lr_func
21 | import uuid
22 | from tqdm import tqdm
23 | from utils.image_utils import psnr
24 | from argparse import ArgumentParser, Namespace
25 | from arguments import ModelParams, PipelineParams, OptimizationParams
26 | import gc
27 |
28 | try:
29 | from torch.utils.tensorboard import SummaryWriter
30 | TENSORBOARD_FOUND = True
31 | except ImportError:
32 | TENSORBOARD_FOUND = False
33 |
34 | try:
35 | from fused_ssim import fused_ssim
36 | FUSED_SSIM_AVAILABLE = True
37 | except:
38 | FUSED_SSIM_AVAILABLE = False
39 |
40 | try:
41 | from diff_gaussian_rasterization import SparseGaussianAdam
42 | SPARSE_ADAM_AVAILABLE = True
43 | except:
44 | SPARSE_ADAM_AVAILABLE = False
45 |
46 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
47 | clusters_data_path = os.path.join(dataset.model_path, "clusters")
48 |
49 | cluster_data = joblib.load(os.path.join(clusters_data_path, "clusters.pkl"))
50 | K = len(cluster_data["cluster_viewpoint"])
51 |
52 | finetune_path = os.path.join(clusters_data_path, "finetune")
53 | os.makedirs(finetune_path, exist_ok=True)
54 | dataset.finetune_path = finetune_path
55 |
56 | gaussians = GaussianModel(dataset.sh_degree, opt.optimizer_type)
57 | scene = Scene(dataset, gaussians, shuffle=False)
58 | for cid in range(K):
59 | print(f"----------------- training cluster {cid} -----------------")
60 | viewpoint_indices = cluster_data["cluster_viewpoint"][cid].tolist()
61 | (gaussian_ids, lens) = cluster_data["cluster_gaussians"][cid]
62 | (model_params, first_iter) = torch.load(checkpoint, weights_only=False)
63 | dataset.cid = cid
64 | dataset.viewpoint_indices = viewpoint_indices
65 | gaussians.restore_models(model_params, (gaussian_ids, lens), opt)
66 | training_cluster(dataset, opt, pipe, gaussians, scene, first_iter, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from)
67 |
68 | del model_params
69 | torch.cuda.empty_cache()
70 | gc.collect()
71 |
72 | def training_cluster(dataset, opt, pipe, gaussians, scene, first_iter, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
73 | if not SPARSE_ADAM_AVAILABLE and opt.optimizer_type == "sparse_adam":
74 | sys.exit(f"Trying to use sparse adam but it is not installed, please install the correct rasterizer using pip install [3dgs_accel].")
75 |
76 | # tb_writer = prepare_output_and_logger(dataset)
77 | # if checkpoint:
78 | # (model_params, first_iter) = torch.load(checkpoint)
79 | # gaussians.restore(model_params, opt)
80 |
81 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
82 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
83 |
84 | iter_start = torch.cuda.Event(enable_timing = True)
85 | iter_end = torch.cuda.Event(enable_timing = True)
86 |
87 | use_sparse_adam = opt.optimizer_type == "sparse_adam" and SPARSE_ADAM_AVAILABLE
88 | depth_l1_weight = get_expon_lr_func(opt.depth_l1_weight_init, opt.depth_l1_weight_final, max_steps=opt.iterations)
89 |
90 | trainCameras = scene.getTrainCameras().copy()
91 | trainCameras = [trainCameras[view_id] for view_id in dataset.viewpoint_indices]
92 |
93 | viewpoint_stack = trainCameras.copy()
94 | viewpoint_indices = list(range(len(viewpoint_stack)))
95 | ema_loss_for_log = 0.0
96 | ema_Ll1depth_for_log = 0.0
97 |
98 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
99 | first_iter += 1
100 | for iteration in range(first_iter, opt.iterations + 1):
101 | gaussians.update_learning_rate(iteration)
102 |
103 | # # Every 1000 its we increase the levels of SH up to a maximum degree
104 | # if iteration % 1000 == 0:
105 | # gaussians.oneupSHdegree()
106 |
107 | # Pick a random Camera
108 | if not viewpoint_stack:
109 | viewpoint_stack = trainCameras.copy()
110 | viewpoint_indices = list(range(len(viewpoint_stack)))
111 | rand_idx = randint(0, len(viewpoint_indices) - 1)
112 | viewpoint_cam = viewpoint_stack.pop(rand_idx)
113 | vind = viewpoint_indices.pop(rand_idx)
114 |
115 | # Render
116 | if (iteration - 1) == debug_from:
117 | pipe.debug = True
118 |
119 | bg = torch.rand((3), device="cuda") if opt.random_background else background
120 |
121 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg, use_trained_exp=dataset.train_test_exp, separate_sh=SPARSE_ADAM_AVAILABLE)
122 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
123 |
124 | # if viewpoint_cam.alpha_mask is not None:
125 | # alpha_mask = viewpoint_cam.alpha_mask.cuda()
126 | # image *= alpha_mask
127 |
128 | # Loss
129 | gt_image = viewpoint_cam.original_image.cuda()
130 | Ll1 = l1_loss(image, gt_image)
131 | if FUSED_SSIM_AVAILABLE:
132 | ssim_value = fused_ssim(image.unsqueeze(0), gt_image.unsqueeze(0))
133 | else:
134 | ssim_value = ssim(image, gt_image)
135 |
136 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value)
137 |
138 | # Depth regularization
139 | Ll1depth_pure = 0.0
140 | if depth_l1_weight(iteration) > 0 and viewpoint_cam.depth_reliable:
141 | invDepth = render_pkg["depth"]
142 | mono_invdepth = viewpoint_cam.invdepthmap.cuda()
143 | depth_mask = viewpoint_cam.depth_mask.cuda()
144 |
145 | Ll1depth_pure = torch.abs((invDepth - mono_invdepth) * depth_mask).mean()
146 | Ll1depth = depth_l1_weight(iteration) * Ll1depth_pure
147 | loss += Ll1depth
148 | Ll1depth = Ll1depth.item()
149 | else:
150 | Ll1depth = 0
151 |
152 | loss.backward()
153 |
154 | iter_end.record()
155 |
156 | with torch.no_grad():
157 | # Progress bar
158 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
159 | ema_Ll1depth_for_log = 0.4 * Ll1depth + 0.6 * ema_Ll1depth_for_log
160 |
161 | if iteration % 10 == 0:
162 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}", "Depth Loss": f"{ema_Ll1depth_for_log:.{7}f}"})
163 | progress_bar.update(10)
164 | if iteration == opt.iterations:
165 | progress_bar.close()
166 |
167 | # Log and save
168 | if (iteration in saving_iterations):
169 | print("\n[ITER {} Cid {}] Saving Gaussians".format(iteration, dataset.cid))
170 | # saving_gaussians(dataset, pipe, gaussians, trainCameras)
171 | torch.save(gaussians.capture_gaussians(), os.path.join(dataset.finetune_path, f"point_cloud_{dataset.cid}.pth"))
172 |
173 | # Densification
174 | if iteration < opt.densify_until_iter:
175 | # Keep track of max radii in image-space for pruning
176 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
177 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
178 |
179 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
180 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None
181 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold, radii)
182 |
183 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
184 | gaussians.reset_opacity()
185 |
186 | # Optimizer step
187 | if iteration < opt.iterations:
188 | gaussians.exposure_optimizer.step()
189 | gaussians.exposure_optimizer.zero_grad(set_to_none = True)
190 | if use_sparse_adam:
191 | visible = radii > 0
192 | gaussians.optimizer.step(visible, radii.shape[0])
193 | gaussians.optimizer.zero_grad(set_to_none = True)
194 | else:
195 | gaussians.optimizer.step()
196 | gaussians.optimizer.zero_grad(set_to_none = True)
197 |
198 | if (iteration in checkpoint_iterations):
199 | print("\n[ITER {} Cid {}] Saving Checkpoint".format(iteration, dataset.cid))
200 | torch.save((gaussians.capture(), iteration), os.path.join(dataset.finetune_path, "chkpnt" + str(iteration) + f"_{dataset.cid}.pth"))
201 |
202 | def prepare_output_and_logger(args):
203 | if not args.model_path:
204 | if os.getenv('OAR_JOB_ID'):
205 | unique_str=os.getenv('OAR_JOB_ID')
206 | else:
207 | unique_str = str(uuid.uuid4())
208 | args.model_path = os.path.join("./output/", unique_str[0:10])
209 |
210 | # Set up output folder
211 | print("Output folder: {}".format(args.model_path))
212 | os.makedirs(args.model_path, exist_ok = True)
213 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
214 | cfg_log_f.write(str(Namespace(**vars(args))))
215 |
216 | # Create Tensorboard writer
217 | tb_writer = None
218 | if TENSORBOARD_FOUND:
219 | tb_writer = SummaryWriter(args.model_path)
220 | else:
221 | print("Tensorboard not available: not logging progress")
222 | return tb_writer
223 |
224 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, train_test_exp):
225 | if tb_writer:
226 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
227 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
228 | tb_writer.add_scalar('iter_time', elapsed, iteration)
229 |
230 | # Report test and samples of training set
231 | if iteration in testing_iterations:
232 | torch.cuda.empty_cache()
233 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()},
234 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
235 |
236 | for config in validation_configs:
237 | if config['cameras'] and len(config['cameras']) > 0:
238 | l1_test = 0.0
239 | psnr_test = 0.0
240 | for idx, viewpoint in enumerate(config['cameras']):
241 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
242 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
243 | if train_test_exp:
244 | image = image[..., image.shape[-1] // 2:]
245 | gt_image = gt_image[..., gt_image.shape[-1] // 2:]
246 | if tb_writer and (idx < 5):
247 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
248 | if iteration == testing_iterations[0]:
249 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
250 | l1_test += l1_loss(image, gt_image).mean().double()
251 | psnr_test += psnr(image, gt_image).mean().double()
252 | psnr_test /= len(config['cameras'])
253 | l1_test /= len(config['cameras'])
254 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
255 | if tb_writer:
256 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
257 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
258 |
259 | if tb_writer:
260 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
261 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
262 | torch.cuda.empty_cache()
263 |
264 | if __name__ == "__main__":
265 | # Set up command line argument parser
266 | parser = ArgumentParser(description="Training script parameters")
267 | lp = ModelParams(parser)
268 | op = OptimizationParams(parser)
269 | pp = PipelineParams(parser)
270 | parser.add_argument('--ip', type=str, default="127.0.0.1")
271 | parser.add_argument('--port', type=int, default=6009)
272 | parser.add_argument('--debug_from', type=int, default=-1)
273 | parser.add_argument('--detect_anomaly', action='store_true', default=False)
274 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[31_000])
275 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[31_000])
276 | parser.add_argument("--quiet", action="store_true")
277 | parser.add_argument('--disable_viewer', action='store_true', default=False)
278 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
279 | parser.add_argument("--start_checkpoint", type=str, default = None)
280 | args = parser.parse_args(sys.argv[1:])
281 | # op.densify_until_iter = args.iterations
282 | op.position_lr_max_steps = args.iterations
283 | args.save_iterations.append(args.iterations)
284 |
285 | print("Optimizing " + args.model_path)
286 |
287 | # Initialize system state (RNG)
288 | safe_state(args.quiet)
289 |
290 | # Start GUI server, configure and run training
291 | if not args.disable_viewer:
292 | network_gui.init(args.ip, args.port)
293 | torch.autograd.set_detect_anomaly(args.detect_anomaly)
294 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
295 |
296 | # All done
297 | print("\nTraining complete.")
298 |
--------------------------------------------------------------------------------
/full_eval.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | from argparse import ArgumentParser
14 | import time
15 |
16 | mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"]
17 | mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"]
18 | tanks_and_temples_scenes = ["truck", "train"]
19 | deep_blending_scenes = ["drjohnson", "playroom"]
20 |
21 | parser = ArgumentParser(description="Full evaluation script parameters")
22 | parser.add_argument("--skip_training", action="store_true")
23 | parser.add_argument("--skip_rendering", action="store_true")
24 | parser.add_argument("--skip_metrics", action="store_true")
25 | parser.add_argument("--output_path", default="./eval")
26 | parser.add_argument("--use_depth", action="store_true")
27 | parser.add_argument("--use_expcomp", action="store_true")
28 | parser.add_argument("--fast", action="store_true")
29 | parser.add_argument("--aa", action="store_true")
30 |
31 |
32 |
33 |
34 | args, _ = parser.parse_known_args()
35 |
36 | all_scenes = []
37 | all_scenes.extend(mipnerf360_outdoor_scenes)
38 | all_scenes.extend(mipnerf360_indoor_scenes)
39 | all_scenes.extend(tanks_and_temples_scenes)
40 | all_scenes.extend(deep_blending_scenes)
41 |
42 | if not args.skip_training or not args.skip_rendering:
43 | parser.add_argument('--mipnerf360', "-m360", required=True, type=str)
44 | parser.add_argument("--tanksandtemples", "-tat", required=True, type=str)
45 | parser.add_argument("--deepblending", "-db", required=True, type=str)
46 | args = parser.parse_args()
47 | if not args.skip_training:
48 | common_args = " --disable_viewer --quiet --eval --test_iterations -1 "
49 |
50 | if args.aa:
51 | common_args += " --antialiasing "
52 | if args.use_depth:
53 | common_args += " -d depths2/ "
54 |
55 | if args.use_expcomp:
56 | common_args += " --exposure_lr_init 0.001 --exposure_lr_final 0.0001 --exposure_lr_delay_steps 5000 --exposure_lr_delay_mult 0.001 --train_test_exp "
57 |
58 | if args.fast:
59 | common_args += " --optimizer_type sparse_adam "
60 |
61 | start_time = time.time()
62 | for scene in mipnerf360_outdoor_scenes:
63 | source = args.mipnerf360 + "/" + scene
64 | os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args)
65 | for scene in mipnerf360_indoor_scenes:
66 | source = args.mipnerf360 + "/" + scene
67 | os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args)
68 | m360_timing = (time.time() - start_time)/60.0
69 |
70 | start_time = time.time()
71 | for scene in tanks_and_temples_scenes:
72 | source = args.tanksandtemples + "/" + scene
73 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)
74 | tandt_timing = (time.time() - start_time)/60.0
75 |
76 | start_time = time.time()
77 | for scene in deep_blending_scenes:
78 | source = args.deepblending + "/" + scene
79 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)
80 | db_timing = (time.time() - start_time)/60.0
81 |
82 | with open(os.path.join(args.output_path,"timing.txt"), 'w') as file:
83 | file.write(f"m360: {m360_timing} minutes \n tandt: {tandt_timing} minutes \n db: {db_timing} minutes\n")
84 |
85 | if not args.skip_rendering:
86 | all_sources = []
87 | for scene in mipnerf360_outdoor_scenes:
88 | all_sources.append(args.mipnerf360 + "/" + scene)
89 | for scene in mipnerf360_indoor_scenes:
90 | all_sources.append(args.mipnerf360 + "/" + scene)
91 | for scene in tanks_and_temples_scenes:
92 | all_sources.append(args.tanksandtemples + "/" + scene)
93 | for scene in deep_blending_scenes:
94 | all_sources.append(args.deepblending + "/" + scene)
95 |
96 | common_args = " --quiet --eval --skip_train"
97 |
98 | if args.aa:
99 | common_args += " --antialiasing "
100 | if args.use_expcomp:
101 | common_args += " --train_test_exp "
102 |
103 | for scene, source in zip(all_scenes, all_sources):
104 | os.system("python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args)
105 | os.system("python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args)
106 |
107 | if not args.skip_metrics:
108 | scenes_string = ""
109 | for scene in all_scenes:
110 | scenes_string += "\"" + args.output_path + "/" + scene + "\" "
111 |
112 | os.system("python metrics.py -m " + scenes_string)
113 |
--------------------------------------------------------------------------------
/gaussian_renderer/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
15 | from scene.gaussian_model import GaussianModel, GaussianStreamManager
16 | from utils.sh_utils import eval_sh
17 |
18 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, separate_sh = False, override_color = None, use_trained_exp=False, rasterizer_type=""):
19 | """
20 | Render the scene.
21 |
22 | Background tensor (bg_color) must be on GPU!
23 | """
24 |
25 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
26 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
27 | try:
28 | screenspace_points.retain_grad()
29 | except:
30 | pass
31 |
32 | # Set up rasterization configuration
33 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
34 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
35 |
36 | raster_settings = GaussianRasterizationSettings(
37 | image_height=int(viewpoint_camera.image_height),
38 | image_width=int(viewpoint_camera.image_width),
39 | tanfovx=tanfovx,
40 | tanfovy=tanfovy,
41 | bg=bg_color,
42 | scale_modifier=scaling_modifier,
43 | viewmatrix=viewpoint_camera.world_view_transform,
44 | projmatrix=viewpoint_camera.full_proj_transform,
45 | sh_degree=pc.active_sh_degree,
46 | campos=viewpoint_camera.camera_center,
47 | prefiltered=False,
48 | debug=pipe.debug,
49 | rasterizer_type=rasterizer_type,
50 | )
51 |
52 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
53 |
54 | means3D = pc.get_xyz
55 | means2D = screenspace_points
56 | opacity = pc.get_opacity
57 |
58 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
59 | # scaling / rotation by the rasterizer.
60 | scales = None
61 | rotations = None
62 | cov3D_precomp = None
63 |
64 | if pipe.compute_cov3D_python:
65 | cov3D_precomp = pc.get_covariance(scaling_modifier)
66 | else:
67 | scales = pc.get_scaling
68 | rotations = pc.get_rotation
69 |
70 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
71 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
72 | shs = None
73 | colors_precomp = None
74 | if override_color is None:
75 | if pipe.convert_SHs_python:
76 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
77 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
78 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
79 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
80 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
81 | else:
82 | if separate_sh:
83 | dc, shs = pc.get_features_dc, pc.get_features_rest
84 | else:
85 | shs = pc.get_features
86 | else:
87 | colors_precomp = override_color
88 |
89 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
90 | if separate_sh:
91 | returns = rasterizer(
92 | means3D = means3D,
93 | means2D = means2D,
94 | dc = dc,
95 | shs = shs,
96 | colors_precomp = colors_precomp,
97 | opacities = opacity,
98 | scales = scales,
99 | rotations = rotations,
100 | cov3D_precomp = cov3D_precomp)
101 | else:
102 | returns = rasterizer(
103 | means3D = means3D,
104 | means2D = means2D,
105 | shs = shs,
106 | colors_precomp = colors_precomp,
107 | opacities = opacity,
108 | scales = scales,
109 | rotations = rotations,
110 | cov3D_precomp = cov3D_precomp)
111 | visible_gaussians = None
112 | if rasterizer_type == "Mark":
113 | rendered_image, visible_gaussians, radii = returns
114 | else:
115 | rendered_image, radii = returns
116 |
117 | # Apply exposure to rendered image (training only)
118 | if use_trained_exp:
119 | exposure = pc.get_exposure_from_name(viewpoint_camera.image_name)
120 | rendered_image = torch.matmul(rendered_image.permute(1, 2, 0), exposure[:3, :3]).permute(2, 0, 1) + exposure[:3, 3, None, None]
121 |
122 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
123 | # They will be excluded from value updates used in the splitting criteria.
124 | rendered_image = rendered_image.clamp(0, 1)
125 | out = {
126 | "render": rendered_image,
127 | "viewspace_points": screenspace_points,
128 | "visibility_filter" : (radii > 0).nonzero(),
129 | "radii": radii,
130 | "depth" : None,
131 | "visible_gaussians": visible_gaussians
132 | }
133 |
134 | return out
135 |
--------------------------------------------------------------------------------
/gaussian_renderer/network_gui.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import traceback
14 | import socket
15 | import json
16 | from scene.cameras import MiniCam
17 |
18 | host = "127.0.0.1"
19 | port = 6009
20 |
21 | conn = None
22 | addr = None
23 |
24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
25 |
26 | def init(wish_host, wish_port):
27 | global host, port, listener
28 | host = wish_host
29 | port = wish_port
30 | listener.bind((host, port))
31 | listener.listen()
32 | listener.settimeout(0)
33 |
34 | def try_connect():
35 | global conn, addr, listener
36 | try:
37 | conn, addr = listener.accept()
38 | print(f"\nConnected by {addr}")
39 | conn.settimeout(None)
40 | except Exception as inst:
41 | pass
42 |
43 | def read():
44 | global conn
45 | messageLength = conn.recv(4)
46 | messageLength = int.from_bytes(messageLength, 'little')
47 | message = conn.recv(messageLength)
48 | return json.loads(message.decode("utf-8"))
49 |
50 | def send(message_bytes, verify):
51 | global conn
52 | if message_bytes != None:
53 | conn.sendall(message_bytes)
54 | conn.sendall(len(verify).to_bytes(4, 'little'))
55 | conn.sendall(bytes(verify, 'ascii'))
56 |
57 | def receive():
58 | message = read()
59 |
60 | width = message["resolution_x"]
61 | height = message["resolution_y"]
62 |
63 | if width != 0 and height != 0:
64 | try:
65 | do_training = bool(message["train"])
66 | fovy = message["fov_y"]
67 | fovx = message["fov_x"]
68 | znear = message["z_near"]
69 | zfar = message["z_far"]
70 | do_shs_python = bool(message["shs_python"])
71 | do_rot_scale_python = bool(message["rot_scale_python"])
72 | keep_alive = bool(message["keep_alive"])
73 | scaling_modifier = message["scaling_modifier"]
74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
75 | world_view_transform[:,1] = -world_view_transform[:,1]
76 | world_view_transform[:,2] = -world_view_transform[:,2]
77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
78 | full_proj_transform[:,1] = -full_proj_transform[:,1]
79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
80 | except Exception as e:
81 | print("")
82 | traceback.print_exc()
83 | raise e
84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
85 | else:
86 | return None, None, None, None, None, None
--------------------------------------------------------------------------------
/generate_cluster.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from scene import Scene
14 | import os
15 | from tqdm import tqdm
16 | from os import makedirs
17 | from gaussian_renderer import render
18 | from utils.general_utils import safe_state
19 | from utils.graphics_utils import getWorld2View2
20 |
21 | import joblib
22 | import numpy as np
23 | from sklearn.cluster import KMeans
24 | from scipy.spatial.transform import Rotation as Rot
25 |
26 | from argparse import ArgumentParser
27 | from arguments import ModelParams, PipelineParams, get_combined_args
28 | from gaussian_renderer import GaussianModel
29 | try:
30 | from diff_gaussian_rasterization import SparseGaussianAdam
31 | SPARSE_ADAM_AVAILABLE = True
32 | except:
33 | SPARSE_ADAM_AVAILABLE = False
34 |
35 | def generate_features_from_Rt(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
36 | # R_w2c: R.T, t_w2c: t
37 | # R_c2w: R, t_c2w: -R.T @ t
38 | w2c = getWorld2View2(R, t, translate=translate, scale=scale)
39 | c2w = np.linalg.inv(w2c)
40 |
41 | rot = Rot.from_matrix(c2w[:3, :3]) # This function will orthonormalize R automatically.
42 | q = rot.as_quat(canonical=True)
43 | feature_vector = np.concatenate([c2w[:3, 3], q])
44 | return feature_vector
45 |
46 | def extract_features(views):
47 | features = []
48 | for view in views:
49 | features.append(generate_features_from_Rt(view.R, view.T))
50 | features = np.stack(features, axis=0)
51 | return features
52 |
53 | def merge_neighbor_mask(centers, cluster_masks, labels, neigh):
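# For each cluster: take its `neigh` nearest cluster centers (plus itself),
# merge their training viewpoints, and split the union of their visible
# Gaussians into a "shared" part (visible in more than half of the merged
# clusters) and an "exclusive" part (visible in at least one, but not shared).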
54 | K, P = cluster_masks.shape
55 |
56 | total_shared = total_exclusive = 0
57 | merge_gaussians, merge_viewpoint = [], []
58 | cluster_masks = cluster_masks.astype(np.uint32)
59 | average_gaussians = 0
60 | for cid in range(K):
61 | base = centers[cid:cid+1]
62 | dist2 = np.square(base - centers).sum(1)
63 | merge_clusters = np.argsort(dist2)[:neigh + 1]
64 |
65 | viewpoints = np.concatenate([(labels == cluster).nonzero()[0] for cluster in merge_clusters])
66 | merge_viewpoint.append(viewpoints)
67 |
68 | gaussians_counter = cluster_masks[merge_clusters].sum(axis=0)
69 | shared_mask = (gaussians_counter > ((neigh + 1) // 2))
70 | exclusive_mask = np.logical_xor(shared_mask, (gaussians_counter != 0))
71 |
72 | shared, exclusive = map(lambda x: x.nonzero()[0], [shared_mask, exclusive_mask])
73 | gaussian_ids = np.concatenate([shared, exclusive], axis=0)
74 |
75 | lens = (len(shared), len(exclusive))
76 | merge_gaussians.append((gaussian_ids, lens))
77 |
78 | total_shared += lens[0]
79 | total_exclusive += lens[1]
80 | average_gaussians += lens[0] + lens[1]
81 |
82 | total_shared //= K
83 | total_exclusive //= K
84 | average_gaussians //= K
85 | print(f"Total gaussians: {P}, average shared gaussians: {total_shared}, average exclusive gaussians: {total_exclusive}, average number of gaussians: {average_gaussians}")
86 | print(f"Expansion ratio: {(total_exclusive + total_shared) / P}")
87 | return merge_gaussians, merge_viewpoint
88 |
89 | def render_set(views, gaussians, pipeline, background, train_test_exp, separate_sh):
90 | gaussian_masks = []
91 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
92 | out = render(view, gaussians, pipeline, background, use_trained_exp=train_test_exp, separate_sh=separate_sh, rasterizer_type="Mark")
93 | visible_gaussians = out["visible_gaussians"].cpu().numpy()
94 | gaussian_masks.append(visible_gaussians != 0)
95 |
96 | return np.stack(gaussian_masks, axis=0)
97 |
98 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, args, separate_sh: bool):
99 | with torch.no_grad():
100 | gaussians = GaussianModel(dataset.sh_degree)
101 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
102 | train_features = extract_features(scene.getTrainCameras())
103 | test_features = extract_features(scene.getTestCameras())
104 | kmeans = KMeans(n_clusters=args.k, random_state=42, n_init='auto').fit(train_features)
105 | centers = kmeans.cluster_centers_
106 | train_labels = kmeans.labels_
107 | test_labels = kmeans.predict(test_features)
108 |
109 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
110 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
111 | view_gaussian_masks = render_set(scene.getTrainCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh)
112 | cluster_gaussian_masks = np.stack([np.any(view_gaussian_masks[train_labels == j], axis=0) for j in range(args.k)], axis=0)
113 | merge_gaussians, merge_viewpoint = merge_neighbor_mask(centers, cluster_gaussian_masks, train_labels, neigh=args.n)
114 |
115 | save_path = os.path.join(dataset.model_path, "clusters")
116 | makedirs(save_path, exist_ok=True)
117 | data = {
118 | "cluster_gaussians": merge_gaussians,
119 | "cluster_viewpoint": merge_viewpoint,
120 | "train_labels": train_labels,
121 | "test_labels": test_labels,
122 | "centers": centers,
123 | }
124 | joblib.dump(data, os.path.join(save_path, "clusters.pkl"))
125 | joblib.dump(kmeans, os.path.join(save_path, "kmeans_model.pkl"))
126 |
127 | if __name__ == "__main__":
128 | # Set up command line argument parser
129 | parser = ArgumentParser(description="Testing script parameters")
130 | model = ModelParams(parser, sentinel=True)
131 | pipeline = PipelineParams(parser)
132 | parser.add_argument("--iteration", default=-1, type=int)
133 | parser.add_argument("--quiet", action="store_true")
134 | parser.add_argument("-k", type=int, default = 24)
135 | parser.add_argument("-n", type=int, default = 4)
136 | args = get_combined_args(parser)
137 |     print("Generating clusters for " + args.model_path)
138 | print(f"k: {args.k}, n: {args.n}")
139 | # Initialize system state (RNG)
140 | safe_state(args.quiet)
141 |
142 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args, SPARSE_ADAM_AVAILABLE)
143 |
--------------------------------------------------------------------------------
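Note: generate_cluster.py above clusters the training cameras by a 7-D pose feature (camera center plus unit quaternion), records which Gaussians each cluster sees, and merges each cluster with its nearest neighbors, splitting the result into "shared" Gaussians (seen by a majority of the merged clusters) and "exclusive" ones. A minimal sketch of consuming its outputs, assuming a model directory already processed by the script; the path and pose below are placeholders:

    import os
    import joblib
    import numpy as np
    from generate_cluster import generate_features_from_Rt

    model_path = "output/seele/counter"   # hypothetical model directory
    data = joblib.load(os.path.join(model_path, "clusters", "clusters.pkl"))
    kmeans = joblib.load(os.path.join(model_path, "clusters", "kmeans_model.pkl"))

    # R (3x3) and t (3,) follow the same convention as scene.cameras.Camera.R / .T
    R, t = np.eye(3), np.zeros(3)         # placeholder pose
    feature = generate_features_from_Rt(R, t)[None, :]   # shape (1, 7)
    cid = int(kmeans.predict(feature)[0])

    gaussian_ids, (n_shared, n_exclusive) = data["cluster_gaussians"][cid]
    print(f"cluster {cid}: {n_shared} shared + {n_exclusive} exclusive gaussians")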
/lpipsPyTorch/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .modules.lpips import LPIPS
4 |
5 |
6 | def lpips(x: torch.Tensor,
7 | y: torch.Tensor,
8 | net_type: str = 'alex',
9 | version: str = '0.1'):
10 | r"""Function that measures
11 | Learned Perceptual Image Patch Similarity (LPIPS).
12 |
13 | Arguments:
14 | x, y (torch.Tensor): the input tensors to compare.
15 | net_type (str): the network type to compare the features:
16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
17 | version (str): the version of LPIPS. Default: 0.1.
18 | """
19 | device = x.device
20 | criterion = LPIPS(net_type, version).to(device)
21 | return criterion(x, y)
22 |
--------------------------------------------------------------------------------
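A quick usage sketch of the lpips() helper above, matching how metrics.py calls it: both inputs are (N, 3, H, W) float tensors on the same device, and the pretrained LPIPS weights are downloaded on first use. The random images below are placeholders only:

    import torch
    from lpipsPyTorch import lpips

    x = torch.rand(1, 3, 256, 256, device="cuda")    # e.g. a rendered image
    y = torch.rand(1, 3, 256, 256, device="cuda")    # e.g. its ground truth
    score = lpips(x, y, net_type="vgg")              # lower means more similar
    print(score.item())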
/lpipsPyTorch/modules/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .networks import get_network, LinLayers
5 | from .utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 |
12 | Arguments:
13 | net_type (str): the network type to compare the features:
14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
15 | version (str): the version of LPIPS. Default: 0.1.
16 | """
17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
18 |
19 |         assert version in ['0.1'], 'only v0.1 is supported for now'
20 |
21 | super(LPIPS, self).__init__()
22 |
23 | # pretrained network
24 | self.net = get_network(net_type)
25 |
26 | # linear layers
27 | self.lin = LinLayers(self.net.n_channels_list)
28 | self.lin.load_state_dict(get_state_dict(net_type, version))
29 |
30 | def forward(self, x: torch.Tensor, y: torch.Tensor):
31 | feat_x, feat_y = self.net(x), self.net(y)
32 |
33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
35 |
36 | return torch.sum(torch.cat(res, 0), 0, True)
37 |
--------------------------------------------------------------------------------
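For reference, the forward pass above amounts to: square the difference of the unit-normalized activations of each tapped layer, weight it per channel with that layer's 1x1 convolution from LinLayers, average spatially, and sum over layers. A hand-written sketch of one layer's contribution; fx, fy, and w are stand-in tensors, not the module's internals:

    import torch

    N, C, H, W = 1, 64, 32, 32
    fx, fy = torch.rand(N, C, H, W), torch.rand(N, C, H, W)   # unit-normalized activations
    w = torch.rand(C)                                         # that layer's 1x1 conv weights
    score = (w[None, :, None, None] * (fx - fy) ** 2).sum(1).mean((1, 2))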
/lpipsPyTorch/modules/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from .utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 |         self.layers = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 |         self.layers = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
97 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from pathlib import Path
13 | import os
14 | from PIL import Image
15 | import torch
16 | import torchvision.transforms.functional as tf
17 | from utils.loss_utils import ssim
18 | from lpipsPyTorch import lpips
19 | import json
20 | from tqdm import tqdm
21 | from utils.image_utils import psnr
22 | from argparse import ArgumentParser
23 |
24 | def readImages(renders_dir, gt_dir):
25 | renders = []
26 | gts = []
27 | image_names = []
28 | for fname in os.listdir(renders_dir):
29 | render = Image.open(renders_dir / fname)
30 | gt = Image.open(gt_dir / fname)
31 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
32 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
33 | image_names.append(fname)
34 | return renders, gts, image_names
35 |
36 | def evaluate(model_paths):
37 |
38 | full_dict = {}
39 | per_view_dict = {}
40 | full_dict_polytopeonly = {}
41 | per_view_dict_polytopeonly = {}
42 | print("")
43 |
44 | for scene_dir in model_paths:
45 | try:
46 | print("Scene:", scene_dir)
47 | full_dict[scene_dir] = {}
48 | per_view_dict[scene_dir] = {}
49 | full_dict_polytopeonly[scene_dir] = {}
50 | per_view_dict_polytopeonly[scene_dir] = {}
51 |
52 | test_dir = Path(scene_dir) / "test"
53 | # test_dir = Path(scene_dir) / "train"
54 |
55 | for method in os.listdir(test_dir):
56 | print("Method:", method)
57 |
58 | full_dict[scene_dir][method] = {}
59 | per_view_dict[scene_dir][method] = {}
60 | full_dict_polytopeonly[scene_dir][method] = {}
61 | per_view_dict_polytopeonly[scene_dir][method] = {}
62 |
63 | method_dir = test_dir / method
64 | gt_dir = method_dir/ "gt"
65 | renders_dir = method_dir / "renders"
66 | renders, gts, image_names = readImages(renders_dir, gt_dir)
67 |
68 | ssims = []
69 | psnrs = []
70 | lpipss = []
71 |
72 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
73 | ssims.append(ssim(renders[idx], gts[idx]))
74 | psnrs.append(psnr(renders[idx], gts[idx]))
75 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))
76 |
77 |                 print("  SSIM : {:>12.7f}".format(torch.tensor(ssims).mean()))
78 |                 print("  PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean()))
79 |                 print("  LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean()))
80 | print("")
81 |
82 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
83 | "PSNR": torch.tensor(psnrs).mean().item(),
84 | "LPIPS": torch.tensor(lpipss).mean().item()})
85 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
86 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
87 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}})
88 |
89 | with open(scene_dir + "/results.json", 'w') as fp:
90 | json.dump(full_dict[scene_dir], fp, indent=True)
91 | with open(scene_dir + "/per_view.json", 'w') as fp:
92 | json.dump(per_view_dict[scene_dir], fp, indent=True)
93 | except:
94 | print("Unable to compute metrics for model", scene_dir)
95 |
96 | if __name__ == "__main__":
97 | device = torch.device("cuda:0")
98 | torch.cuda.set_device(device)
99 |
100 | # Set up command line argument parser
101 | parser = ArgumentParser(description="Training script parameters")
102 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
103 | args = parser.parse_args()
104 | evaluate(args.model_paths)
105 |
--------------------------------------------------------------------------------
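metrics.py expects the layout written by render.py (the next file): <model_path>/test/ours_<iteration>/renders and .../gt with matching file names. A single-image check can also be done directly with the same helpers; the paths below are placeholders:

    import torchvision.transforms.functional as tf
    from PIL import Image
    from utils.loss_utils import ssim
    from utils.image_utils import psnr
    from lpipsPyTorch import lpips

    render = tf.to_tensor(Image.open("renders/00000.png")).unsqueeze(0)[:, :3].cuda()
    gt = tf.to_tensor(Image.open("gt/00000.png")).unsqueeze(0)[:, :3].cuda()
    print(ssim(render, gt).item(), psnr(render, gt).item(), lpips(render, gt, net_type="vgg").item())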
/render.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from scene import Scene
14 | import os
15 | from tqdm import tqdm
16 | from os import makedirs
17 | from gaussian_renderer import render
18 | import torchvision
19 | from utils.general_utils import safe_state
20 | from argparse import ArgumentParser
21 | from arguments import ModelParams, PipelineParams, get_combined_args
22 | from gaussian_renderer import GaussianModel
23 | try:
24 | from diff_gaussian_rasterization import SparseGaussianAdam
25 | SPARSE_ADAM_AVAILABLE = True
26 | except:
27 | SPARSE_ADAM_AVAILABLE = False
28 |
29 |
30 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background, train_test_exp, separate_sh):
31 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
32 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
33 |
34 | makedirs(render_path, exist_ok=True)
35 | makedirs(gts_path, exist_ok=True)
36 |
37 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
38 | rendering = render(view, gaussians, pipeline, background, use_trained_exp=train_test_exp, separate_sh=separate_sh)["render"]
39 | gt = view.original_image[0:3, :, :]
40 |
41 |         if train_test_exp:
42 | rendering = rendering[..., rendering.shape[-1] // 2:]
43 | gt = gt[..., gt.shape[-1] // 2:]
44 |
45 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
46 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
47 |
48 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, separate_sh: bool):
49 | with torch.no_grad():
50 | gaussians = GaussianModel(dataset.sh_degree)
51 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
52 |
53 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
54 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
55 |
56 | if not skip_train:
57 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh)
58 |
59 | if not skip_test:
60 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh)
61 |
62 | if __name__ == "__main__":
63 | # Set up command line argument parser
64 | parser = ArgumentParser(description="Testing script parameters")
65 | model = ModelParams(parser, sentinel=True)
66 | pipeline = PipelineParams(parser)
67 | parser.add_argument("--iteration", default=-1, type=int)
68 | parser.add_argument("--skip_train", action="store_true")
69 | parser.add_argument("--skip_test", action="store_true")
70 | parser.add_argument("--quiet", action="store_true")
71 | args = get_combined_args(parser)
72 | print("Rendering " + args.model_path)
73 |
74 | # Initialize system state (RNG)
75 | safe_state(args.quiet)
76 |
77 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, SPARSE_ADAM_AVAILABLE)
--------------------------------------------------------------------------------
/render_video.py:
--------------------------------------------------------------------------------
1 | import pyglet
2 | import numpy as np
3 | import joblib
4 | import torch
5 | import os
6 | import time
7 | from tqdm import tqdm
8 | from argparse import ArgumentParser
9 | from scene import Scene
10 | from gaussian_renderer import render, GaussianModel, GaussianStreamManager
11 | from utils.general_utils import safe_state
12 | from utils.pose_utils import generate_ellipse_path, getWorld2View2
13 | from arguments import ModelParams, PipelineParams, get_combined_args
14 | from generate_cluster import generate_features_from_Rt
15 | import torchvision
16 | SPARSE_ADAM_AVAILABLE = False
17 |
18 | class VideoPlayer:
19 | """Efficient video player using pyglet for 3DGS rendering display."""
20 |
21 | def __init__(self, width: int, height: int, total_frames: int):
22 | """Initialize the video player window and UI elements.
23 |
24 | Args:
25 | width: Width of the video frame
26 | height: Height of the video frame
27 | total_frames: Total number of frames to be displayed
28 | """
29 | self.window = pyglet.window.Window(
30 | width=width,
31 | height=height,
32 | caption='3DGS Rendering Viewer'
33 | )
34 | self.total_frames = total_frames
35 | self.current_frame = 0
36 | self.fps = 0.0
37 | self.last_time = time.time()
38 |
39 | # Initialize texture with blank frame
40 | self._init_texture(width, height)
41 |
42 | # Setup UI elements
43 | self._setup_ui(width, height)
44 |
45 | # Register event handlers
46 | self.window.event(self.on_draw)
47 |
48 | def _init_texture(self, width: int, height: int):
49 | """Initialize the OpenGL texture with blank data."""
50 | blank_data = np.zeros((height, width, 3), dtype=np.uint8).tobytes()
51 | self.texture = pyglet.image.ImageData(
52 | width, height, 'RGB', blank_data
53 | ).get_texture()
54 |
55 | def _setup_ui(self, width: int, height: int):
56 | """Initialize UI components (FPS counter and progress bar)."""
57 | self.batch = pyglet.graphics.Batch()
58 |
59 | # Frame counter label
60 | self.label = pyglet.text.Label(
61 | '',
62 | x=10, y=height-30,
63 | font_size=16,
64 | color=(255, 255, 255, 255),
65 | batch=self.batch
66 | )
67 |
68 | # Progress bar (positioned at bottom with 2% margin)
69 | self.progress_bar = pyglet.shapes.Rectangle(
70 | x=width*0.01, y=5,
71 | width=0, height=10,
72 | color=(0, 255, 0),
73 | batch=self.batch
74 | )
75 | self.progress_bar_max_width = width*0.98
76 |
77 | def update_frame(self, frame_data: np.ndarray):
78 | """Update the display with new frame data.
79 |
80 | Args:
81 | frame_data: Numpy array containing frame data (H,W,3)
82 | """
83 | # Convert tensor if necessary
84 | if isinstance(frame_data, torch.Tensor):
85 | frame_data = frame_data.detach().cpu().numpy()
86 |
87 | # Ensure correct shape and type
88 | if frame_data.shape[0] == 3: # CHW to HWC
89 | frame_data = frame_data.transpose(1, 2, 0)
90 | if frame_data.dtype != np.uint8:
91 | frame_data = (frame_data * 255).astype(np.uint8)
92 |
93 | # Flip vertically and update texture
94 | frame_data = np.ascontiguousarray(np.flipud(frame_data))
95 | self.texture = pyglet.image.ImageData(
96 | self.window.width, self.window.height,
97 | 'RGB', frame_data.tobytes()
98 | ).get_texture()
99 |
100 | # Update performance metrics
101 | self._update_perf_metrics()
102 |
103 | # Update UI
104 | self.label.text = f'Frame: {self.current_frame+1}/{self.total_frames} | FPS: {self.fps:.2f}'
105 | self.progress_bar.width = self.progress_bar_max_width * (self.current_frame+1)/self.total_frames
106 | self.current_frame += 1
107 |
108 | def _update_perf_metrics(self):
109 | """Calculate and update FPS metrics."""
110 | current_time = time.time()
111 | self.fps = 1.0 / max(0.001, current_time - self.last_time) # Avoid division by zero
112 | self.last_time = current_time
113 |
114 | def on_draw(self):
115 | """Window draw event handler."""
116 | self.window.clear()
117 | if self.texture:
118 | self.texture.blit(0, 0, width=self.window.width, height=self.window.height)
119 | self.batch.draw()
120 |
121 | def predict(X, centers):
122 | distances = np.sum((X[:, np.newaxis, :] - centers) ** 2,axis=2)
123 | labels = np.argmin(distances, axis=1)
124 | return labels
125 |
126 | def extract_features(Rt_list, trans=np.array([0.0, 0.0, 0.0]), scale=1.0):
127 | features = []
128 | for (R, t) in Rt_list:
129 | features.append(generate_features_from_Rt(R, t, trans, scale))
130 | return np.stack(features, axis=0)
131 |
132 | def render_set(model_path, views, gaussians, pipeline, background, train_test_exp, separate_sh, args):
133 | total_frame = args.frames
134 | load_seele = args.load_seele
135 | use_gui = args.use_gui
136 |
137 | # prepare the views
138 | poses = generate_ellipse_path(views, total_frame)
139 | Rt_list = [(pose[:3, :3].T, pose[:3, 3]) for pose in poses]
140 | w2c_list = [
141 | torch.tensor(getWorld2View2(Rt_list[frame][0], Rt_list[frame][1], views[0].trans, views[0].scale)).transpose(0, 1).cuda()
142 | for frame in range(total_frame)
143 | ]
144 |
145 | stream_manager, labels = None, None
146 | if load_seele:
147 | # Load cluster data
148 | cluster_data = joblib.load(os.path.join(model_path, "clusters", "clusters.pkl"))
149 | K = len(cluster_data["cluster_viewpoint"])
150 | cluster_centers = cluster_data["centers"]
151 |
152 | # Determine the test cluster labels
153 | test_features = extract_features(Rt_list, trans=views[0].trans, scale=views[0].scale)
154 | labels = predict(test_features, cluster_centers)
155 |
156 | # Load all Gaussians to CPU
157 | cluster_gaussians = [
158 | torch.load(os.path.join(model_path, f"clusters/finetune/point_cloud_{cid}.pth"), map_location="cpu")
159 | for cid in range(K)
160 | ]
161 |
162 | # Initialize stream manager
163 | stream_manager = GaussianStreamManager(
164 | cluster_gaussians=cluster_gaussians,
165 | initial_cid=labels[0]
166 | )
167 |
168 | # Warm up
169 | for _ in range(5):
170 | render(views[0], gaussians, pipeline, background, use_trained_exp=train_test_exp, separate_sh=separate_sh)
171 |
172 | def render_view(frame):
173 | view = views[0]
174 | view.world_view_transform = w2c_list[frame]
175 | view.full_proj_transform = (view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0)
176 | view.camera_center = view.world_view_transform.inverse()[3, :3]
177 |
178 | if load_seele:
179 | # Preload next frame's Gaussians
180 | if frame + 1 < total_frame:
181 | next_cid = labels[frame + 1]
182 | stream_manager.preload_next(next_cid)
183 |
184 | # Restore current Gaussians and render
185 | gaussians.restore_gaussians(stream_manager.get_current())
186 | rendering = render(
187 | view, gaussians, pipeline, background,
188 | use_trained_exp=train_test_exp,
189 | separate_sh=separate_sh,
190 | rasterizer_type="CR"
191 | )["render"]
192 |
193 | # Synchronize streams and switch buffers
194 | stream_manager.switch_gaussians()
195 | else:
196 | # Standard rendering path
197 | rendering = render(
198 | view, gaussians, pipeline, background,
199 | use_trained_exp=train_test_exp,
200 | separate_sh=separate_sh
201 | )["render"]
202 |
203 | return rendering
204 |
205 | if use_gui:
206 | # Initialize video player
207 | player = VideoPlayer(width=views[0].image_width, height=views[0].image_height, total_frames=total_frame)
208 |
209 | def update_frame(dt):
210 | """Callback function for frame updates."""
211 | nonlocal stream_manager, gaussians
212 |
213 | if player.current_frame >= args.frames - 1:
214 | pyglet.app.exit()
215 | return
216 |
217 | rendering = render_view(player.current_frame)
218 | # Update display
219 | player.update_frame(rendering)
220 |
221 | # Start rendering loop (target 500 FPS)
222 | pyglet.clock.schedule_interval(update_frame, 1/500.0)
223 | pyglet.app.run()
224 | else:
225 | output_dir = args.output_dir
226 | os.makedirs(output_dir, exist_ok=True)
227 | for frame_idx in range(total_frame):
228 | if load_seele:
229 |                 print(f"Rendering frame {frame_idx} (cluster {labels[frame_idx]})")
230 |             else:
231 |                 print(f"Rendering frame {frame_idx}")
232 | rendering = render_view(frame_idx)
233 | torchvision.utils.save_image(rendering, os.path.join(output_dir, '{0:05d}'.format(frame_idx) + ".png"))
234 |
235 | # clean up
236 | if stream_manager is not None:
237 | stream_manager.cleanup()
238 |
239 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, separate_sh: bool, args: ArgumentParser):
240 | with torch.no_grad():
241 | gaussians = GaussianModel(dataset.sh_degree)
242 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
243 |
244 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
245 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
246 |
247 | render_set(dataset.model_path, scene.getTestCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh, args)
248 |
249 | # Example usage
250 | if __name__ == "__main__":
251 | # Set up command line argument parser
252 | parser = ArgumentParser(description="Testing script parameters")
253 | model = ModelParams(parser, sentinel=True)
254 | pipeline = PipelineParams(parser)
255 | parser.add_argument("--iteration", default=-1, type=int)
256 | parser.add_argument("--frames", default=200, type=int)
257 | parser.add_argument("--quiet", action="store_true")
258 | parser.add_argument("--load_seele", action="store_true")
259 | parser.add_argument("--use_gui", action="store_true")
260 | parser.add_argument('--output_dir', type=str, default="output/videos")
261 | args = get_combined_args(parser)
262 | print("Rendering " + args.model_path)
263 | # Initialize system state (RNG)
264 | safe_state(args.quiet)
265 |
266 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), SPARSE_ADAM_AVAILABLE, args)
267 |
--------------------------------------------------------------------------------
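render_view() above overlaps compute and data movement: while frame i renders from the current cluster's Gaussians, frame i+1's cluster is preloaded, and switch_gaussians() swaps the buffers afterwards. The repo's GaussianStreamManager (in gaussian_renderer/) implements this; the toy class below only illustrates the double-buffering idea with CUDA streams for a single tensor per cluster and is not the actual implementation:

    import torch

    class ToyStreamManager:
        """Keep the current cluster on GPU while the next one uploads on a side stream."""
        def __init__(self, cluster_tensors, initial_cid):
            self.cpu_tensors = cluster_tensors              # one CPU tensor per cluster
            self.copy_stream = torch.cuda.Stream()
            self.current = self.cpu_tensors[initial_cid].cuda()
            self.pending = None

        def preload_next(self, cid):
            with torch.cuda.stream(self.copy_stream):       # upload without blocking rendering
                self.pending = self.cpu_tensors[cid].to("cuda", non_blocking=True)

        def get_current(self):
            return self.current

        def switch_gaussians(self):
            torch.cuda.current_stream().wait_stream(self.copy_stream)
            if self.pending is not None:
                self.current, self.pending = self.pending, None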
/requirements.txt:
--------------------------------------------------------------------------------
1 | submodules/seele-gaussian-rasterization
2 | submodules/simple-knn
3 | submodules/fused-ssim
4 |
5 | plyfile
6 | scikit-learn
7 | tqdm
8 | opencv-python
9 | joblib
10 | icecream
11 | pyglet
--------------------------------------------------------------------------------
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import random
14 | import json
15 | from utils.system_utils import searchForMaxIteration
16 | from scene.dataset_readers import sceneLoadTypeCallbacks
17 | from scene.gaussian_model import GaussianModel
18 | from arguments import ModelParams
19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
20 |
21 | class Scene:
22 |
23 | gaussians : GaussianModel
24 |
25 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
26 |         """
27 |         :param args: ModelParams holding the path to the COLMAP scene main folder.
28 | """
29 | self.model_path = args.model_path
30 | self.loaded_iter = None
31 | self.gaussians = gaussians
32 |
33 | if load_iteration:
34 | if load_iteration == -1:
35 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
36 | else:
37 | self.loaded_iter = load_iteration
38 | print("Loading trained model at iteration {}".format(self.loaded_iter))
39 |
40 | self.train_cameras = {}
41 | self.test_cameras = {}
42 |
43 | if os.path.exists(os.path.join(args.source_path, "sparse")):
44 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.depths, args.eval, args.train_test_exp)
45 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
46 | print("Found transforms_train.json file, assuming Blender data set!")
47 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.depths, args.eval)
48 | else:
49 | assert False, "Could not recognize scene type!"
50 |
51 | if not self.loaded_iter:
52 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
53 | dest_file.write(src_file.read())
54 | json_cams = []
55 | camlist = []
56 | if scene_info.test_cameras:
57 | camlist.extend(scene_info.test_cameras)
58 | if scene_info.train_cameras:
59 | camlist.extend(scene_info.train_cameras)
60 | for id, cam in enumerate(camlist):
61 | json_cams.append(camera_to_JSON(id, cam))
62 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
63 | json.dump(json_cams, file)
64 |
65 | if shuffle:
66 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
67 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
68 |
69 | self.cameras_extent = scene_info.nerf_normalization["radius"]
70 |
71 | for resolution_scale in resolution_scales:
72 | print("Loading Training Cameras")
73 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args, scene_info.is_nerf_synthetic, False)
74 | print("Loading Test Cameras")
75 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args, scene_info.is_nerf_synthetic, True)
76 |
77 | if self.loaded_iter:
78 | self.gaussians.load_ply(os.path.join(self.model_path,
79 | "point_cloud",
80 | "iteration_" + str(self.loaded_iter),
81 | "point_cloud.ply"), args.train_test_exp)
82 | else:
83 | self.gaussians.create_from_pcd(scene_info.point_cloud, scene_info.train_cameras, self.cameras_extent)
84 |
85 | def save(self, iteration):
86 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
87 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
88 | exposure_dict = {
89 | image_name: self.gaussians.get_exposure_from_name(image_name).detach().cpu().numpy().tolist()
90 | for image_name in self.gaussians.exposure_mapping
91 | }
92 |
93 | with open(os.path.join(self.model_path, "exposure.json"), "w") as f:
94 | json.dump(exposure_dict, f, indent=2)
95 |
96 | def getTrainCameras(self, scale=1.0):
97 | return self.train_cameras[scale]
98 |
99 | def getTestCameras(self, scale=1.0):
100 | return self.test_cameras[scale]
101 |
--------------------------------------------------------------------------------
/scene/cameras.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from torch import nn
14 | import numpy as np
15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix
16 | from utils.general_utils import PILtoTorch
17 | import cv2
18 |
19 | class Camera(nn.Module):
20 | def __init__(self, resolution, colmap_id, R, T, FoVx, FoVy, depth_params, image, invdepthmap,
21 | image_name, uid,
22 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda",
23 | train_test_exp = False, is_test_dataset = False, is_test_view = False
24 | ):
25 | super(Camera, self).__init__()
26 |
27 | self.uid = uid
28 | self.colmap_id = colmap_id
29 | self.R = R
30 | self.T = T
31 | self.FoVx = FoVx
32 | self.FoVy = FoVy
33 | self.image_name = image_name
34 |
35 | try:
36 | self.data_device = torch.device(data_device)
37 | except Exception as e:
38 | print(e)
39 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
40 | self.data_device = torch.device("cuda")
41 |
42 | resized_image_rgb = PILtoTorch(image, resolution)
43 | gt_image = resized_image_rgb[:3, ...]
44 | self.alpha_mask = None
45 | if resized_image_rgb.shape[0] == 4:
46 | self.alpha_mask = resized_image_rgb[3:4, ...].to(self.data_device)
47 | else:
48 | self.alpha_mask = torch.ones_like(resized_image_rgb[0:1, ...].to(self.data_device))
49 |
50 | if train_test_exp and is_test_view:
51 | if is_test_dataset:
52 | self.alpha_mask[..., :self.alpha_mask.shape[-1] // 2] = 0
53 | else:
54 | self.alpha_mask[..., self.alpha_mask.shape[-1] // 2:] = 0
55 |
56 | self.original_image = gt_image.clamp(0.0, 1.0).to(self.data_device)
57 | self.image_width = self.original_image.shape[2]
58 | self.image_height = self.original_image.shape[1]
59 |
60 | self.invdepthmap = None
61 | self.depth_reliable = False
62 | if invdepthmap is not None:
63 | self.depth_mask = torch.ones_like(self.alpha_mask)
64 | self.invdepthmap = cv2.resize(invdepthmap, resolution)
65 | self.invdepthmap[self.invdepthmap < 0] = 0
66 | self.depth_reliable = True
67 |
68 | if depth_params is not None:
69 | if depth_params["scale"] < 0.2 * depth_params["med_scale"] or depth_params["scale"] > 5 * depth_params["med_scale"]:
70 | self.depth_reliable = False
71 | self.depth_mask *= 0
72 |
73 | if depth_params["scale"] > 0:
74 | self.invdepthmap = self.invdepthmap * depth_params["scale"] + depth_params["offset"]
75 |
76 | if self.invdepthmap.ndim != 2:
77 | self.invdepthmap = self.invdepthmap[..., 0]
78 | self.invdepthmap = torch.from_numpy(self.invdepthmap[None]).to(self.data_device)
79 |
80 | self.zfar = 100.0
81 | self.znear = 0.01
82 |
83 | self.trans = trans
84 | self.scale = scale
85 |
86 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
87 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
88 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
89 | self.camera_center = self.world_view_transform.inverse()[3, :3]
90 |
91 | class MiniCam:
92 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
93 | self.image_width = width
94 | self.image_height = height
95 | self.FoVy = fovy
96 | self.FoVx = fovx
97 | self.znear = znear
98 | self.zfar = zfar
99 | self.world_view_transform = world_view_transform
100 | self.full_proj_transform = full_proj_transform
101 | view_inv = torch.inverse(self.world_view_transform)
102 | self.camera_center = view_inv[3][:3]
103 |
104 |
--------------------------------------------------------------------------------
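Note that the 4x4 matrices in cameras.py are stored transposed (row-vector convention expected by the CUDA rasterizer), which is why full_proj_transform is world_view_transform @ projection_matrix and the camera center is read from row 3 of the inverse view matrix. A small sketch building a MiniCam from an (R, T) pair with the same helpers; every numeric value is a placeholder:

    import math
    import numpy as np
    import torch
    from utils.graphics_utils import getWorld2View2, getProjectionMatrix
    from scene.cameras import MiniCam

    R, T = np.eye(3), np.array([0.0, 0.0, 4.0])      # hypothetical pose
    fovx = fovy = math.radians(60)
    world_view = torch.tensor(getWorld2View2(R, T)).transpose(0, 1).cuda()
    proj = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=fovx, fovY=fovy).transpose(0, 1).cuda()
    full_proj = (world_view.unsqueeze(0).bmm(proj.unsqueeze(0))).squeeze(0)
    cam = MiniCam(800, 800, fovy, fovx, 0.01, 100.0, world_view, full_proj)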
/scene/colmap_loader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import numpy as np
13 | import collections
14 | import struct
15 |
16 | CameraModel = collections.namedtuple(
17 | "CameraModel", ["model_id", "model_name", "num_params"])
18 | Camera = collections.namedtuple(
19 | "Camera", ["id", "model", "width", "height", "params"])
20 | BaseImage = collections.namedtuple(
21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
22 | Point3D = collections.namedtuple(
23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
24 | CAMERA_MODELS = {
25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
32 | CameraModel(model_id=7, model_name="FOV", num_params=5),
33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
36 | }
37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
38 | for camera_model in CAMERA_MODELS])
39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
40 | for camera_model in CAMERA_MODELS])
41 |
42 |
43 | def qvec2rotmat(qvec):
44 | return np.array([
45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
54 |
55 | def rotmat2qvec(R):
56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
57 | K = np.array([
58 | [Rxx - Ryy - Rzz, 0, 0, 0],
59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
62 | eigvals, eigvecs = np.linalg.eigh(K)
63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
64 | if qvec[0] < 0:
65 | qvec *= -1
66 | return qvec
67 |
68 | class Image(BaseImage):
69 | def qvec2rotmat(self):
70 | return qvec2rotmat(self.qvec)
71 |
72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
73 | """Read and unpack the next bytes from a binary file.
74 | :param fid:
75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
77 | :param endian_character: Any of {@, =, <, >, !}
78 | :return: Tuple of read and unpacked values.
79 | """
80 | data = fid.read(num_bytes)
81 | return struct.unpack(endian_character + format_char_sequence, data)
82 |
83 | def read_points3D_text(path):
84 | """
85 | see: src/base/reconstruction.cc
86 | void Reconstruction::ReadPoints3DText(const std::string& path)
87 | void Reconstruction::WritePoints3DText(const std::string& path)
88 | """
89 | xyzs = None
90 | rgbs = None
91 | errors = None
92 | num_points = 0
93 | with open(path, "r") as fid:
94 | while True:
95 | line = fid.readline()
96 | if not line:
97 | break
98 | line = line.strip()
99 | if len(line) > 0 and line[0] != "#":
100 | num_points += 1
101 |
102 |
103 | xyzs = np.empty((num_points, 3))
104 | rgbs = np.empty((num_points, 3))
105 | errors = np.empty((num_points, 1))
106 | count = 0
107 | with open(path, "r") as fid:
108 | while True:
109 | line = fid.readline()
110 | if not line:
111 | break
112 | line = line.strip()
113 | if len(line) > 0 and line[0] != "#":
114 | elems = line.split()
115 | xyz = np.array(tuple(map(float, elems[1:4])))
116 | rgb = np.array(tuple(map(int, elems[4:7])))
117 | error = np.array(float(elems[7]))
118 | xyzs[count] = xyz
119 | rgbs[count] = rgb
120 | errors[count] = error
121 | count += 1
122 |
123 | return xyzs, rgbs, errors
124 |
125 | def read_points3D_binary(path_to_model_file):
126 | """
127 | see: src/base/reconstruction.cc
128 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
129 | void Reconstruction::WritePoints3DBinary(const std::string& path)
130 | """
131 |
132 |
133 | with open(path_to_model_file, "rb") as fid:
134 | num_points = read_next_bytes(fid, 8, "Q")[0]
135 |
136 | xyzs = np.empty((num_points, 3))
137 | rgbs = np.empty((num_points, 3))
138 | errors = np.empty((num_points, 1))
139 |
140 | for p_id in range(num_points):
141 | binary_point_line_properties = read_next_bytes(
142 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
143 | xyz = np.array(binary_point_line_properties[1:4])
144 | rgb = np.array(binary_point_line_properties[4:7])
145 | error = np.array(binary_point_line_properties[7])
146 | track_length = read_next_bytes(
147 | fid, num_bytes=8, format_char_sequence="Q")[0]
148 | track_elems = read_next_bytes(
149 | fid, num_bytes=8*track_length,
150 | format_char_sequence="ii"*track_length)
151 | xyzs[p_id] = xyz
152 | rgbs[p_id] = rgb
153 | errors[p_id] = error
154 | return xyzs, rgbs, errors
155 |
156 | def read_intrinsics_text(path):
157 | """
158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
159 | """
160 | cameras = {}
161 | with open(path, "r") as fid:
162 | while True:
163 | line = fid.readline()
164 | if not line:
165 | break
166 | line = line.strip()
167 | if len(line) > 0 and line[0] != "#":
168 | elems = line.split()
169 | camera_id = int(elems[0])
170 | model = elems[1]
171 |                 assert model == "PINHOLE", "While the loader supports other types, the rest of the code assumes PINHOLE"
172 | width = int(elems[2])
173 | height = int(elems[3])
174 | params = np.array(tuple(map(float, elems[4:])))
175 | cameras[camera_id] = Camera(id=camera_id, model=model,
176 | width=width, height=height,
177 | params=params)
178 | return cameras
179 |
180 | def read_extrinsics_binary(path_to_model_file):
181 | """
182 | see: src/base/reconstruction.cc
183 | void Reconstruction::ReadImagesBinary(const std::string& path)
184 | void Reconstruction::WriteImagesBinary(const std::string& path)
185 | """
186 | images = {}
187 | with open(path_to_model_file, "rb") as fid:
188 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
189 | for _ in range(num_reg_images):
190 | binary_image_properties = read_next_bytes(
191 | fid, num_bytes=64, format_char_sequence="idddddddi")
192 | image_id = binary_image_properties[0]
193 | qvec = np.array(binary_image_properties[1:5])
194 | tvec = np.array(binary_image_properties[5:8])
195 | camera_id = binary_image_properties[8]
196 | image_name = ""
197 | current_char = read_next_bytes(fid, 1, "c")[0]
198 | while current_char != b"\x00": # look for the ASCII 0 entry
199 | image_name += current_char.decode("utf-8")
200 | current_char = read_next_bytes(fid, 1, "c")[0]
201 | num_points2D = read_next_bytes(fid, num_bytes=8,
202 | format_char_sequence="Q")[0]
203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
204 | format_char_sequence="ddq"*num_points2D)
205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
206 | tuple(map(float, x_y_id_s[1::3]))])
207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
208 | images[image_id] = Image(
209 | id=image_id, qvec=qvec, tvec=tvec,
210 | camera_id=camera_id, name=image_name,
211 | xys=xys, point3D_ids=point3D_ids)
212 | return images
213 |
214 |
215 | def read_intrinsics_binary(path_to_model_file):
216 | """
217 | see: src/base/reconstruction.cc
218 | void Reconstruction::WriteCamerasBinary(const std::string& path)
219 | void Reconstruction::ReadCamerasBinary(const std::string& path)
220 | """
221 | cameras = {}
222 | with open(path_to_model_file, "rb") as fid:
223 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
224 | for _ in range(num_cameras):
225 | camera_properties = read_next_bytes(
226 | fid, num_bytes=24, format_char_sequence="iiQQ")
227 | camera_id = camera_properties[0]
228 | model_id = camera_properties[1]
229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
230 | width = camera_properties[2]
231 | height = camera_properties[3]
232 | num_params = CAMERA_MODEL_IDS[model_id].num_params
233 | params = read_next_bytes(fid, num_bytes=8*num_params,
234 | format_char_sequence="d"*num_params)
235 | cameras[camera_id] = Camera(id=camera_id,
236 | model=model_name,
237 | width=width,
238 | height=height,
239 | params=np.array(params))
240 | assert len(cameras) == num_cameras
241 | return cameras
242 |
243 |
244 | def read_extrinsics_text(path):
245 | """
246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
247 | """
248 | images = {}
249 | with open(path, "r") as fid:
250 | while True:
251 | line = fid.readline()
252 | if not line:
253 | break
254 | line = line.strip()
255 | if len(line) > 0 and line[0] != "#":
256 | elems = line.split()
257 | image_id = int(elems[0])
258 | qvec = np.array(tuple(map(float, elems[1:5])))
259 | tvec = np.array(tuple(map(float, elems[5:8])))
260 | camera_id = int(elems[8])
261 | image_name = elems[9]
262 | elems = fid.readline().split()
263 | xys = np.column_stack([tuple(map(float, elems[0::3])),
264 | tuple(map(float, elems[1::3]))])
265 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
266 | images[image_id] = Image(
267 | id=image_id, qvec=qvec, tvec=tvec,
268 | camera_id=camera_id, name=image_name,
269 | xys=xys, point3D_ids=point3D_ids)
270 | return images
271 |
272 |
273 | def read_colmap_bin_array(path):
274 | """
275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
276 |
277 | :param path: path to the colmap binary file.
278 | :return: nd array with the floating point values in the value
279 | """
280 | with open(path, "rb") as fid:
281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
282 | usecols=(0, 1, 2), dtype=int)
283 | fid.seek(0)
284 | num_delimiter = 0
285 | byte = fid.read(1)
286 | while True:
287 | if byte == b"&":
288 | num_delimiter += 1
289 | if num_delimiter >= 3:
290 | break
291 | byte = fid.read(1)
292 | array = np.fromfile(fid, np.float32)
293 | array = array.reshape((width, height, channels), order="F")
294 | return np.transpose(array, (1, 0, 2)).squeeze()
295 |
--------------------------------------------------------------------------------
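A quick consistency check for the two quaternion helpers above (COLMAP stores qvec as [w, x, y, z]): converting a quaternion to a rotation matrix and back recovers it up to sign, and rotmat2qvec always returns the representative with a non-negative w.

    import numpy as np
    from scene.colmap_loader import qvec2rotmat, rotmat2qvec

    qvec = np.array([0.7071068, 0.0, 0.7071068, 0.0])    # a 90-degree rotation about the y axis
    R = qvec2rotmat(qvec)
    print(np.allclose(rotmat2qvec(R), qvec, atol=1e-6))   # True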
/scene/dataset_readers.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import sys
14 | from PIL import Image
15 | from typing import NamedTuple
16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
19 | import numpy as np
20 | import json
21 | from pathlib import Path
22 | from plyfile import PlyData, PlyElement
23 | from utils.sh_utils import SH2RGB
24 | from scene.gaussian_model import BasicPointCloud
25 |
26 | class CameraInfo(NamedTuple):
27 | uid: int
28 | R: np.array
29 | T: np.array
30 | FovY: np.array
31 | FovX: np.array
32 | depth_params: dict
33 | image_path: str
34 | image_name: str
35 | depth_path: str
36 | width: int
37 | height: int
38 | is_test: bool
39 |
40 | class SceneInfo(NamedTuple):
41 | point_cloud: BasicPointCloud
42 | train_cameras: list
43 | test_cameras: list
44 | nerf_normalization: dict
45 | ply_path: str
46 | is_nerf_synthetic: bool
47 |
48 | def getNerfppNorm(cam_info):
49 | def get_center_and_diag(cam_centers):
50 | cam_centers = np.hstack(cam_centers)
51 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
52 | center = avg_cam_center
53 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
54 | diagonal = np.max(dist)
55 | return center.flatten(), diagonal
56 |
57 | cam_centers = []
58 |
59 | for cam in cam_info:
60 | W2C = getWorld2View2(cam.R, cam.T)
61 | C2W = np.linalg.inv(W2C)
62 | cam_centers.append(C2W[:3, 3:4])
63 |
64 | center, diagonal = get_center_and_diag(cam_centers)
65 | radius = diagonal * 1.1
66 |
67 | translate = -center
68 |
69 | return {"translate": translate, "radius": radius}
70 |
71 | def readColmapCameras(cam_extrinsics, cam_intrinsics, depths_params, images_folder, depths_folder, test_cam_names_list):
72 | cam_infos = []
73 | for idx, key in enumerate(cam_extrinsics):
74 | sys.stdout.write('\r')
75 |         # overwrite the same console line with the reading progress
76 | sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
77 | sys.stdout.flush()
78 |
79 | extr = cam_extrinsics[key]
80 | intr = cam_intrinsics[extr.camera_id]
81 | height = intr.height
82 | width = intr.width
83 |
84 | uid = intr.id
85 | R = np.transpose(qvec2rotmat(extr.qvec))
86 | T = np.array(extr.tvec)
87 |
88 | if intr.model=="SIMPLE_PINHOLE":
89 | focal_length_x = intr.params[0]
90 | FovY = focal2fov(focal_length_x, height)
91 | FovX = focal2fov(focal_length_x, width)
92 | elif intr.model=="PINHOLE":
93 | focal_length_x = intr.params[0]
94 | focal_length_y = intr.params[1]
95 | FovY = focal2fov(focal_length_y, height)
96 | FovX = focal2fov(focal_length_x, width)
97 | else:
98 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
99 |
100 | n_remove = len(extr.name.split('.')[-1]) + 1
101 | depth_params = None
102 | if depths_params is not None:
103 | try:
104 | depth_params = depths_params[extr.name[:-n_remove]]
105 | except:
106 | print("\n", key, "not found in depths_params")
107 |
108 | image_path = os.path.join(images_folder, extr.name)
109 | image_name = extr.name
110 | depth_path = os.path.join(depths_folder, f"{extr.name[:-n_remove]}.png") if depths_folder != "" else ""
111 |
112 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, depth_params=depth_params,
113 | image_path=image_path, image_name=image_name, depth_path=depth_path,
114 | width=width, height=height, is_test=image_name in test_cam_names_list)
115 | cam_infos.append(cam_info)
116 |
117 | sys.stdout.write('\n')
118 | return cam_infos
119 |
120 | def fetchPly(path):
121 | plydata = PlyData.read(path)
122 | vertices = plydata['vertex']
123 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
124 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
125 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
126 | return BasicPointCloud(points=positions, colors=colors, normals=normals)
127 |
128 | def storePly(path, xyz, rgb):
129 | # Define the dtype for the structured array
130 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
131 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
132 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
133 |
134 | normals = np.zeros_like(xyz)
135 |
136 | elements = np.empty(xyz.shape[0], dtype=dtype)
137 | attributes = np.concatenate((xyz, normals, rgb), axis=1)
138 | elements[:] = list(map(tuple, attributes))
139 |
140 | # Create the PlyData object and write to file
141 | vertex_element = PlyElement.describe(elements, 'vertex')
142 | ply_data = PlyData([vertex_element])
143 | ply_data.write(path)
144 |
145 | def readColmapSceneInfo(path, images, depths, eval, train_test_exp, llffhold=8):
146 | try:
147 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
148 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
149 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
150 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
151 | except:
152 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
153 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
154 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
155 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
156 |
157 | depth_params_file = os.path.join(path, "sparse/0", "depth_params.json")
158 |     ## if depth_params_file isn't there AND depths are requested -> throw an error
159 | depths_params = None
160 | if depths != "":
161 | try:
162 | with open(depth_params_file, "r") as f:
163 | depths_params = json.load(f)
164 | all_scales = np.array([depths_params[key]["scale"] for key in depths_params])
165 | if (all_scales > 0).sum():
166 | med_scale = np.median(all_scales[all_scales > 0])
167 | else:
168 | med_scale = 0
169 | for key in depths_params:
170 | depths_params[key]["med_scale"] = med_scale
171 |
172 | except FileNotFoundError:
173 | print(f"Error: depth_params.json file not found at path '{depth_params_file}'.")
174 | sys.exit(1)
175 | except Exception as e:
176 | print(f"An unexpected error occurred when trying to open depth_params.json file: {e}")
177 | sys.exit(1)
178 |
179 | if eval:
180 | if "360" in path:
181 | llffhold = 8
182 | if llffhold:
183 | print("------------LLFF HOLD-------------")
184 | cam_names = [cam_extrinsics[cam_id].name for cam_id in cam_extrinsics]
185 | cam_names = sorted(cam_names)
186 | test_cam_names_list = [name for idx, name in enumerate(cam_names) if idx % llffhold == 0]
187 | else:
188 | with open(os.path.join(path, "sparse/0", "test.txt"), 'r') as file:
189 | test_cam_names_list = [line.strip() for line in file]
190 | else:
191 | test_cam_names_list = []
192 |
193 |     reading_dir = "images" if images is None else images
194 | cam_infos_unsorted = readColmapCameras(
195 | cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, depths_params=depths_params,
196 | images_folder=os.path.join(path, reading_dir),
197 | depths_folder=os.path.join(path, depths) if depths != "" else "", test_cam_names_list=test_cam_names_list)
198 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
199 |
200 | train_cam_infos = [c for c in cam_infos if train_test_exp or not c.is_test]
201 | test_cam_infos = [c for c in cam_infos if c.is_test]
202 |
203 | nerf_normalization = getNerfppNorm(train_cam_infos)
204 |
205 | ply_path = os.path.join(path, "sparse/0/points3D.ply")
206 | bin_path = os.path.join(path, "sparse/0/points3D.bin")
207 | txt_path = os.path.join(path, "sparse/0/points3D.txt")
208 | if not os.path.exists(ply_path):
209 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
210 | try:
211 | xyz, rgb, _ = read_points3D_binary(bin_path)
212 | except:
213 | xyz, rgb, _ = read_points3D_text(txt_path)
214 | storePly(ply_path, xyz, rgb)
215 | try:
216 | pcd = fetchPly(ply_path)
217 | except:
218 | pcd = None
219 |
220 | scene_info = SceneInfo(point_cloud=pcd,
221 | train_cameras=train_cam_infos,
222 | test_cameras=test_cam_infos,
223 | nerf_normalization=nerf_normalization,
224 | ply_path=ply_path,
225 | is_nerf_synthetic=False)
226 | return scene_info
227 |
228 | def readCamerasFromTransforms(path, transformsfile, depths_folder, white_background, is_test, extension=".png"):
229 | cam_infos = []
230 |
231 | with open(os.path.join(path, transformsfile)) as json_file:
232 | contents = json.load(json_file)
233 | fovx = contents["camera_angle_x"]
234 |
235 | frames = contents["frames"]
236 | for idx, frame in enumerate(frames):
237 | cam_name = os.path.join(path, frame["file_path"] + extension)
238 |
239 | # NeRF 'transform_matrix' is a camera-to-world transform
240 | c2w = np.array(frame["transform_matrix"])
241 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
242 | c2w[:3, 1:3] *= -1
243 |
244 | # get the world-to-camera transform and set R, T
245 | w2c = np.linalg.inv(c2w)
246 | R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code
247 | T = w2c[:3, 3]
248 |
249 | image_path = os.path.join(path, cam_name)
250 | image_name = Path(cam_name).stem
251 | image = Image.open(image_path)
252 |
253 | im_data = np.array(image.convert("RGBA"))
254 |
255 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
256 |
257 | norm_data = im_data / 255.0
258 | arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
259 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
260 |
261 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
262 | FovY = fovy
263 | FovX = fovx
264 |
265 | depth_path = os.path.join(depths_folder, f"{image_name}.png") if depths_folder != "" else ""
266 |
267 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,
268 | image_path=image_path, image_name=image_name,
269 | width=image.size[0], height=image.size[1], depth_path=depth_path, depth_params=None, is_test=is_test))
270 |
271 | return cam_infos
272 |
273 | def readNerfSyntheticInfo(path, white_background, depths, eval, extension=".png"):
274 |
275 | depths_folder=os.path.join(path, depths) if depths != "" else ""
276 | print("Reading Training Transforms")
277 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", depths_folder, white_background, False, extension)
278 | print("Reading Test Transforms")
279 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", depths_folder, white_background, True, extension)
280 |
281 | if not eval:
282 | train_cam_infos.extend(test_cam_infos)
283 | test_cam_infos = []
284 |
285 | nerf_normalization = getNerfppNorm(train_cam_infos)
286 |
287 | ply_path = os.path.join(path, "points3d.ply")
288 | if not os.path.exists(ply_path):
289 | # Since this data set has no colmap data, we start with random points
290 | num_pts = 100_000
291 | print(f"Generating random point cloud ({num_pts})...")
292 |
293 | # We create random points inside the bounds of the synthetic Blender scenes
294 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
295 | shs = np.random.random((num_pts, 3)) / 255.0
296 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
297 |
298 | storePly(ply_path, xyz, SH2RGB(shs) * 255)
299 | try:
300 | pcd = fetchPly(ply_path)
301 | except:
302 | pcd = None
303 |
304 | scene_info = SceneInfo(point_cloud=pcd,
305 | train_cameras=train_cam_infos,
306 | test_cameras=test_cam_infos,
307 | nerf_normalization=nerf_normalization,
308 | ply_path=ply_path,
309 | is_nerf_synthetic=True)
310 | return scene_info
311 |
312 | sceneLoadTypeCallbacks = {
313 | "Colmap": readColmapSceneInfo,
314 | "Blender" : readNerfSyntheticInfo
315 | }
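
For orientation, a minimal sketch (not part of the repository) of how a caller could choose between the two loaders registered above. The folder-layout heuristic used here (a "sparse" directory for COLMAP, "transforms_train.json" for Blender) is an assumption for illustration; the actual dispatch lives in scene/__init__.py, which is not shown in this excerpt.

import os

def pick_scene_loader(source_path):
    # Hypothetical helper: probe the dataset layout and return the matching loader.
    if os.path.exists(os.path.join(source_path, "sparse")):
        return sceneLoadTypeCallbacks["Colmap"]        # COLMAP reconstructions ship sparse/0
    if os.path.exists(os.path.join(source_path, "transforms_train.json")):
        return sceneLoadTypeCallbacks["Blender"]       # NeRF-synthetic scenes ship transforms_*.json
    raise ValueError(f"Could not recognize scene type in {source_path}")
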
--------------------------------------------------------------------------------
/scripts/generate_cluster.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Function to get an available GPU with memory usage below the threshold
3 | get_available_gpu() {
4 | local mem_threshold=25000
5 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | \
6 | awk -v threshold="$mem_threshold" -F', ' '
7 | $2 < threshold { print $1; exit }
8 | '
9 | }
10 |
11 | # List of dataset names
12 | # datasets=("bicycle" "bonsai" "counter" "flowers" "garden" "kitchen" "room" "stump" "treehill" "train" "truck" "playroom" "drjohnson")
13 | datasets=("counter") # Replace with your actual dataset names
14 |
15 | # Path to models
16 | model_base_path="output/seele" # PATH TO YOUR MODELS
17 |
18 | # Iterate over each dataset
19 | for dataset_name in "${datasets[@]}"; do
20 | echo "Processing dataset: $dataset_name"
21 |
22 | # Find an available GPU
23 | while true; do
24 | available_gpu=$(get_available_gpu)
25 | if [ -z "$available_gpu" ]; then
26 | echo "No GPU available with memory usage below threshold. Waiting..."
27 | sleep 60
28 | continue
29 | fi
30 |
31 | echo "Using GPU: $available_gpu"
32 | # Run the Python script with the selected GPU
33 | CUDA_VISIBLE_DEVICES="$available_gpu" python generate_cluster.py -m "$model_base_path/$dataset_name"
34 | break
35 | done
36 | done
37 |
38 | # Completion signal
39 | echo "All datasets processed. Task complete."
--------------------------------------------------------------------------------
/scripts/run_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=0
3 | dataset_base_path="dataset/seele" # PATH TO YOUR DATASET
4 | output_base_path="output/seele" # PATH TO YOUR OUTPUT
5 |
6 | datasets=("counter") # Replace with your actual dataset names
7 | # datasets=("bicycle" "bonsai" "counter" "flowers" "garden" "kitchen" "room" "stump" "treehill" "train" "truck" "playroom" "drjohnson")
8 |
9 | for dataset in "${datasets[@]}"; do
10 | model_path="$output_base_path/$dataset"
11 | dataset_path="$dataset_base_path/$dataset"
12 |
13 | echo "Train dataset: $dataset"
14 | python3 train.py -m $model_path -s $dataset_path --eval
15 |
16 | echo "Generate clusters for dataset: $dataset"
17 | if [[ "$dataset" == "playroom" || "$dataset" == "drjohnson" ]]; then
18 | python3 generate_cluster.py -m $model_path -n 8
19 | else
20 | python3 generate_cluster.py -m $model_path -n 4
21 | fi
22 |
23 | echo "Finetune dataset: $dataset"
24 | python3 finetune.py \
25 | -s $dataset_path \
26 | -m $model_path \
27 | --start_checkpoint "$model_path/chkpnt30000.pth" \
28 | --eval \
29 | --iterations 31_000
30 |
31 | echo "Render dataset: $dataset"
32 | python3 seele_render.py -m $model_path -s $dataset_path --eval --load_finetune --save_image --debug
33 |
34 | echo "Metrics for dataset: $dataset"
35 | python3 metrics.py -m $model_path
36 | done
37 |
--------------------------------------------------------------------------------
/scripts/run_finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Function to get an available GPU with memory usage below the threshold
3 | get_available_gpu() {
4 | local mem_threshold=25000
5 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | \
6 | awk -v threshold="$mem_threshold" -F', ' '
7 | $2 < threshold { print $1; exit }
8 | '
9 | }
10 |
11 | # List of dataset names
12 | # datasets=("bicycle" "bonsai" "counter" "flowers" "garden" "kitchen" "room" "stump" "treehill" "train" "truck" "playroom" "drjohnson")
13 | datasets=("counter") # Replace with your actual dataset names
14 |
15 | # Path to models
16 | model_base_path="output/seele" # PATH TO YOUR MODELS
17 |
18 | dataset_base_path="dataset/seele" # PATH TO YOUR DATASET
19 | port=6035
20 |
21 | # Iterate over each dataset
22 | for dataset_name in "${datasets[@]}"; do
23 | echo "Processing dataset: $dataset_name"
24 |
25 | # Find an available GPU
26 | while true; do
27 | available_gpu=$(get_available_gpu)
28 | if [ -z "$available_gpu" ]; then
29 | echo "No GPU available with memory usage below threshold. Waiting..."
30 | sleep 60
31 | continue
32 | fi
33 |
34 | echo "Using GPU: $available_gpu"
35 | # Run the Python script with the selected GPU
36 | CUDA_VISIBLE_DEVICES="$available_gpu" python finetune.py \
37 | -s "$dataset_base_path/$dataset_name" \
38 | -m "$model_base_path/$dataset_name" \
39 | --start_checkpoint "$model_base_path/$dataset_name/chkpnt30000.pth" \
40 | --eval \
41 | --iterations 31_000
42 | break
43 | done
44 | done
45 |
46 | # Wait for all background processes to finish
47 | wait
48 |
49 | # Completion signal
50 | echo "All datasets processed. Task complete."
--------------------------------------------------------------------------------
/scripts/run_render.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Function to get an available GPU with memory usage below the threshold
3 | get_available_gpu() {
4 | local mem_threshold=25000
5 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | \
6 | awk -v threshold="$mem_threshold" -F', ' '
7 | $2 < threshold { print $1; exit }
8 | '
9 | }
10 |
11 | # List of dataset names
12 | # datasets=("bicycle" "bonsai" "counter" "flowers" "garden" "kitchen" "room" "stump" "treehill" "train" "truck" "playroom" "drjohnson")
13 | datasets=("counter") # Replace with your actual dataset names
14 |
15 | # Path to models
16 | model_base_path="output/seele" # PATH TO YOUR MODELS
17 |
18 | dataset_base_path="dataset/seele" # PATH TO YOUR DATASET
19 |
20 | # Iterate over each dataset
21 | for dataset_name in "${datasets[@]}"; do
22 | echo "Processing dataset: $dataset_name"
23 |
24 | # Find an available GPU
25 | while true; do
26 | available_gpu=$(get_available_gpu)
27 | if [ -z "$available_gpu" ]; then
28 | echo "No GPU available with memory usage below threshold. Waiting..."
29 | sleep 60
30 | continue
31 | fi
32 |
33 | echo "Using GPU: $available_gpu"
34 | # Run the render.py script with the selected GPU
35 | CUDA_VISIBLE_DEVICES="$available_gpu" python render.py -m "$model_base_path/$dataset_name" -s "$dataset_base_path/$dataset_name" --skip_train
36 | # Run the metrics.py script to compute evaluation metrics
37 | python3 metrics.py -m "$model_base_path/$dataset_name"
38 | break
39 | done
40 | done
41 |
42 | # Completion signal
43 | echo "All datasets processed. Task complete."
--------------------------------------------------------------------------------
/scripts/run_seele_render.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Function to get an available GPU with memory usage below the threshold
3 | get_available_gpu() {
4 | local mem_threshold=25000
5 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | \
6 | awk -v threshold="$mem_threshold" -F', ' '
7 | $2 < threshold { print $1; exit }
8 | '
9 | }
10 |
11 | # List of dataset names
12 | # datasets=("bicycle" "bonsai" "counter" "flowers" "garden" "kitchen" "room" "stump" "treehill" "train" "truck" "playroom" "drjohnson")
13 | datasets=("counter") # Replace with your actual dataset names
14 |
15 | # Path to models
16 | model_base_path="output/seele" # PATH TO YOUR MODELS
17 |
18 | dataset_base_path="dataset/seele" # PATH TO YOUR DATASET
19 |
20 | # Setting for load_finetune
21 | load_finetune=true # Set to true or false based on your requirement
22 |
23 | # Iterate over each dataset
24 | for dataset_name in "${datasets[@]}"; do
25 | echo "Processing dataset: $dataset_name"
26 |
27 | # Find an available GPU
28 | while true; do
29 | available_gpu=$(get_available_gpu)
30 | if [ -z "$available_gpu" ]; then
31 | echo "No GPU available with memory usage below threshold. Waiting..."
32 | sleep 60
33 | continue
34 | fi
35 |
36 | echo "Using GPU: $available_gpu"
37 | echo "load_finetune: $load_finetune"
38 | if [ "$load_finetune" = true ]; then
39 | CUDA_VISIBLE_DEVICES="$available_gpu" python3 seele_render.py -m "$model_base_path/$dataset_name" -s "$dataset_base_path/$dataset_name" --skip_train --load_finetune --save_image
40 | else
41 | CUDA_VISIBLE_DEVICES="$available_gpu" python3 seele_render.py -m "$model_base_path/$dataset_name" -s "$dataset_base_path/$dataset_name" --skip_train --save_image
42 | fi
43 |
44 | # Run the metrics.py script to compute evaluation metrics
45 | python3 metrics.py -m "$model_base_path/$dataset_name"
46 | break
47 | done
48 | done
49 |
50 | # Completion signal
51 | echo "All datasets processed. Task complete."
--------------------------------------------------------------------------------
/scripts/run_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # List of dataset names
4 | # datasets=("bicycle" "bonsai" "counter" "flowers" "garden" "kitchen" "room" "stump" "treehill" "train" "truck" "playroom" "drjohnson")
5 | datasets=("counter") # Replace with your actual dataset names
6 |
7 | # Path to models
8 | model_base_path="output/seele" # PATH TO YOUR MODELS
9 |
10 | dataset_base_path="dataset/seele" # PATH TO YOUR DATASET
11 |
12 | # Iterate over each dataset
13 | for dataset_name in "${datasets[@]}"; do
14 | echo "Processing dataset: $dataset_name"
15 | python3 train.py -m "$model_base_path/$dataset_name" -s "$dataset_base_path/$dataset_name" --eval
16 | echo "Test:"
17 | python3 render.py -m "$model_base_path/$dataset_name" -s "$dataset_base_path/$dataset_name" --skip_train --eval
18 | python3 metrics.py -m "$model_base_path/$dataset_name"
19 | done
20 |
21 | # Completion signal
22 | echo "All datasets processed. Task complete."
--------------------------------------------------------------------------------
/seele_render.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import numpy as np
12 | import joblib
13 | import torch
14 | from scene import Scene
15 | import os
16 | from tqdm import tqdm
17 | from os import makedirs
18 | from gaussian_renderer import render
19 | import torchvision
20 | from utils.general_utils import safe_state
21 | from argparse import ArgumentParser
22 | from arguments import ModelParams, PipelineParams, get_combined_args
23 | # from gaussian_renderer import GaussianModel
24 | from gaussian_renderer import GaussianModel
25 | try:
26 | from diff_gaussian_rasterization import SparseGaussianAdam
27 | SPARSE_ADAM_AVAILABLE = True
28 | except:
29 | SPARSE_ADAM_AVAILABLE = False
30 |
31 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background, train_test_exp, separate_sh, args):
32 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
33 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
34 |
35 | cluster_data = joblib.load(os.path.join(model_path, "clusters", "clusters.pkl"))
36 | K = len(cluster_data["cluster_viewpoint"])
37 |
38 | if args.load_finetune:
39 | cluster_gaussians = [torch.load(os.path.join(model_path, f"clusters/finetune/point_cloud_{cid}.pth")) for cid in range(K)]
40 | cluster_gaussians = [tuple(map(lambda x: x.cuda(), data)) for data in cluster_gaussians]
41 | else:
42 | global_gaussians = gaussians.capture_gaussians()
43 | cluster_gaussian_ids = []
44 | for (gaussian_ids, lens) in cluster_data["cluster_gaussians"]:
45 | gaussian_ids = torch.tensor(gaussian_ids).cuda()
46 | cluster_gaussian_ids.append((gaussian_ids, lens))
47 | labels = cluster_data[f"{name}_labels"]
48 |
49 | makedirs(render_path, exist_ok=True)
50 | makedirs(gts_path, exist_ok=True)
51 |
52 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
53 | if args.load_finetune:
54 | gaussians.restore_gaussians(cluster_gaussians[labels[idx]])
55 | else:
56 | gaussians.restore_gaussians(global_gaussians, cluster_gaussian_ids[labels[idx]])
57 | rendering = render(view, gaussians, pipeline, background, use_trained_exp=train_test_exp, separate_sh=separate_sh, rasterizer_type="CR")["render"]
58 | gt = view.original_image[0:3, :, :]
59 |
60 | if args.train_test_exp:
61 | rendering = rendering[..., rendering.shape[-1] // 2:]
62 | gt = gt[..., gt.shape[-1] // 2:]
63 |
64 | if args.save_image:
65 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
66 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
67 |
68 | if not args.load_finetune:
69 | gaussians.restore_gaussians(global_gaussians)
70 |
71 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, separate_sh: bool, args: ArgumentParser):
72 | with torch.no_grad():
73 | gaussians = GaussianModel(dataset.sh_degree)
74 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
75 |
76 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
77 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
78 |
79 | if not skip_train:
80 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh, args)
81 |
82 | if not skip_test:
83 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background, dataset.train_test_exp, separate_sh, args)
84 |
85 | if __name__ == "__main__":
86 | # Set up command line argument parser
87 | parser = ArgumentParser(description="Testing script parameters")
88 | model = ModelParams(parser, sentinel=True)
89 | pipeline = PipelineParams(parser)
90 | parser.add_argument("--iteration", default=-1, type=int)
91 | parser.add_argument("--skip_train", action="store_true")
92 | parser.add_argument("--skip_test", action="store_true")
93 | parser.add_argument("--quiet", action="store_true")
94 | parser.add_argument("--load_finetune", action="store_true")
95 | parser.add_argument("--save_image", action="store_true")
96 | args = get_combined_args(parser)
97 | args.depths = ""
98 | args.train_test_exp = False
99 | print("Rendering " + args.model_path)
100 | # Initialize system state (RNG)
101 | safe_state(args.quiet)
102 |
103 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, SPARSE_ADAM_AVAILABLE, args)
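
The per-view loop in render_set above swaps in a cluster-specific subset of Gaussians before each call to render. Below is a self-contained toy sketch of that selection pattern; it uses random stand-in data and hypothetical names, not the repository's clusters.pkl structures.

import torch

num_gaussians, num_views, num_clusters = 1000, 5, 2
positions = torch.randn(num_gaussians, 3)                        # stand-in for Gaussian centers
cluster_of_gaussian = torch.randint(0, num_clusters, (num_gaussians,))
label_of_view = [0, 1, 0, 1, 1]                                  # hypothetical per-view cluster labels

for view_idx in range(num_views):
    cid = label_of_view[view_idx]
    subset = positions[cluster_of_gaussian == cid]               # only this cluster's Gaussians are used
    print(f"view {view_idx}: using {subset.shape[0]} of {num_gaussians} Gaussians")
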
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import torch
14 | from random import randint
15 | from utils.loss_utils import l1_loss, ssim
16 | from gaussian_renderer import render, network_gui
17 | import sys
18 | from scene import Scene, GaussianModel
19 | from utils.general_utils import safe_state, get_expon_lr_func
20 | import uuid
21 | from tqdm import tqdm
22 | from utils.image_utils import psnr
23 | from argparse import ArgumentParser, Namespace
24 | from arguments import ModelParams, PipelineParams, OptimizationParams
25 | try:
26 | from torch.utils.tensorboard import SummaryWriter
27 | TENSORBOARD_FOUND = True
28 | except ImportError:
29 | TENSORBOARD_FOUND = False
30 |
31 | try:
32 | from fused_ssim import fused_ssim
33 | FUSED_SSIM_AVAILABLE = True
34 | except:
35 | FUSED_SSIM_AVAILABLE = False
36 |
37 | try:
38 | from diff_gaussian_rasterization import SparseGaussianAdam
39 | SPARSE_ADAM_AVAILABLE = True
40 | except:
41 | SPARSE_ADAM_AVAILABLE = False
42 |
43 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
44 |
45 | if not SPARSE_ADAM_AVAILABLE and opt.optimizer_type == "sparse_adam":
46 | sys.exit(f"Trying to use sparse adam but it is not installed, please install the correct rasterizer using pip install [3dgs_accel].")
47 |
48 | first_iter = 0
49 | tb_writer = prepare_output_and_logger(dataset)
50 | gaussians = GaussianModel(dataset.sh_degree, opt.optimizer_type)
51 | scene = Scene(dataset, gaussians)
52 | gaussians.training_setup(opt)
53 | if checkpoint:
54 | (model_params, first_iter) = torch.load(checkpoint)
55 | gaussians.restore(model_params, opt)
56 |
57 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
58 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
59 |
60 | iter_start = torch.cuda.Event(enable_timing = True)
61 | iter_end = torch.cuda.Event(enable_timing = True)
62 |
63 | use_sparse_adam = opt.optimizer_type == "sparse_adam" and SPARSE_ADAM_AVAILABLE
64 | depth_l1_weight = get_expon_lr_func(opt.depth_l1_weight_init, opt.depth_l1_weight_final, max_steps=opt.iterations)
65 |
66 | viewpoint_stack = scene.getTrainCameras().copy()
67 | viewpoint_indices = list(range(len(viewpoint_stack)))
68 | ema_loss_for_log = 0.0
69 | ema_Ll1depth_for_log = 0.0
70 |
71 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
72 | first_iter += 1
73 | for iteration in range(first_iter, opt.iterations + 1):
74 | if network_gui.conn == None:
75 | network_gui.try_connect()
76 | while network_gui.conn != None:
77 | try:
78 | net_image_bytes = None
79 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifier = network_gui.receive()
80 | if custom_cam != None:
81 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifier=scaling_modifier, use_trained_exp=dataset.train_test_exp, separate_sh=SPARSE_ADAM_AVAILABLE)["render"]
82 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
83 | network_gui.send(net_image_bytes, dataset.source_path)
84 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
85 | break
86 | except Exception as e:
87 | network_gui.conn = None
88 |
89 | iter_start.record()
90 |
91 | gaussians.update_learning_rate(iteration)
92 |
93 | # Every 1000 iterations we increase the levels of SH up to a maximum degree
94 | if iteration % 1000 == 0:
95 | gaussians.oneupSHdegree()
96 |
97 | # Pick a random Camera
98 | if not viewpoint_stack:
99 | viewpoint_stack = scene.getTrainCameras().copy()
100 | viewpoint_indices = list(range(len(viewpoint_stack)))
101 | rand_idx = randint(0, len(viewpoint_indices) - 1)
102 | viewpoint_cam = viewpoint_stack.pop(rand_idx)
103 | vind = viewpoint_indices.pop(rand_idx)
104 |
105 | # Render
106 | if (iteration - 1) == debug_from:
107 | pipe.debug = True
108 |
109 | bg = torch.rand((3), device="cuda") if opt.random_background else background
110 |
111 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg, use_trained_exp=dataset.train_test_exp, separate_sh=SPARSE_ADAM_AVAILABLE)
112 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
113 |
114 | # if viewpoint_cam.alpha_mask is not None:
115 | # alpha_mask = viewpoint_cam.alpha_mask.cuda()
116 | # image *= alpha_mask
117 |
118 | # Loss
119 | gt_image = viewpoint_cam.original_image.cuda()
120 | Ll1 = l1_loss(image, gt_image)
121 | if FUSED_SSIM_AVAILABLE:
122 | ssim_value = fused_ssim(image.unsqueeze(0), gt_image.unsqueeze(0))
123 | else:
124 | ssim_value = ssim(image, gt_image)
125 |
126 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value)
127 |
128 | # Depth regularization
129 | Ll1depth_pure = 0.0
130 | if depth_l1_weight(iteration) > 0 and viewpoint_cam.depth_reliable:
131 | invDepth = render_pkg["depth"]
132 | mono_invdepth = viewpoint_cam.invdepthmap.cuda()
133 | depth_mask = viewpoint_cam.depth_mask.cuda()
134 |
135 | Ll1depth_pure = torch.abs((invDepth - mono_invdepth) * depth_mask).mean()
136 | Ll1depth = depth_l1_weight(iteration) * Ll1depth_pure
137 | loss += Ll1depth
138 | Ll1depth = Ll1depth.item()
139 | else:
140 | Ll1depth = 0
141 |
142 | loss.backward()
143 |
144 | iter_end.record()
145 |
146 | with torch.no_grad():
147 | # Progress bar
148 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
149 | ema_Ll1depth_for_log = 0.4 * Ll1depth + 0.6 * ema_Ll1depth_for_log
150 |
151 | if iteration % 10 == 0:
152 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}", "Depth Loss": f"{ema_Ll1depth_for_log:.{7}f}"})
153 | progress_bar.update(10)
154 | if iteration == opt.iterations:
155 | progress_bar.close()
156 |
157 | # Log and save
158 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background, 1., SPARSE_ADAM_AVAILABLE, None, dataset.train_test_exp), dataset.train_test_exp)
159 | if (iteration in saving_iterations):
160 | print("\n[ITER {}] Saving Gaussians".format(iteration))
161 | scene.save(iteration)
162 |
163 | # Densification
164 | if iteration < opt.densify_until_iter:
165 | # Keep track of max radii in image-space for pruning
166 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
167 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
168 |
169 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
170 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None
171 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold, radii)
172 |
173 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
174 | gaussians.reset_opacity()
175 |
176 | # Optimizer step
177 | if iteration < opt.iterations:
178 | gaussians.exposure_optimizer.step()
179 | gaussians.exposure_optimizer.zero_grad(set_to_none = True)
180 | if use_sparse_adam:
181 | visible = radii > 0
182 | gaussians.optimizer.step(visible, radii.shape[0])
183 | gaussians.optimizer.zero_grad(set_to_none = True)
184 | else:
185 | gaussians.optimizer.step()
186 | gaussians.optimizer.zero_grad(set_to_none = True)
187 |
188 | if (iteration in checkpoint_iterations):
189 | print("\n[ITER {}] Saving Checkpoint".format(iteration))
190 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
191 |
192 | def prepare_output_and_logger(args):
193 | if not args.model_path:
194 | if os.getenv('OAR_JOB_ID'):
195 | unique_str=os.getenv('OAR_JOB_ID')
196 | else:
197 | unique_str = str(uuid.uuid4())
198 | args.model_path = os.path.join("./output/", unique_str[0:10])
199 |
200 | # Set up output folder
201 | print("Output folder: {}".format(args.model_path))
202 | os.makedirs(args.model_path, exist_ok = True)
203 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
204 | cfg_log_f.write(str(Namespace(**vars(args))))
205 |
206 | # Create Tensorboard writer
207 | tb_writer = None
208 | if TENSORBOARD_FOUND:
209 | tb_writer = SummaryWriter(args.model_path)
210 | else:
211 | print("Tensorboard not available: not logging progress")
212 | return tb_writer
213 |
214 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, train_test_exp):
215 | if tb_writer:
216 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
217 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
218 | tb_writer.add_scalar('iter_time', elapsed, iteration)
219 |
220 | # Report test and samples of training set
221 | if iteration in testing_iterations:
222 | torch.cuda.empty_cache()
223 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()},
224 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
225 |
226 | for config in validation_configs:
227 | if config['cameras'] and len(config['cameras']) > 0:
228 | l1_test = 0.0
229 | psnr_test = 0.0
230 | for idx, viewpoint in enumerate(config['cameras']):
231 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
232 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
233 | if train_test_exp:
234 | image = image[..., image.shape[-1] // 2:]
235 | gt_image = gt_image[..., gt_image.shape[-1] // 2:]
236 | if tb_writer and (idx < 5):
237 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
238 | if iteration == testing_iterations[0]:
239 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
240 | l1_test += l1_loss(image, gt_image).mean().double()
241 | psnr_test += psnr(image, gt_image).mean().double()
242 | psnr_test /= len(config['cameras'])
243 | l1_test /= len(config['cameras'])
244 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
245 | if tb_writer:
246 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
247 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
248 |
249 | if tb_writer:
250 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
251 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
252 | torch.cuda.empty_cache()
253 |
254 | if __name__ == "__main__":
255 | # Set up command line argument parser
256 | parser = ArgumentParser(description="Training script parameters")
257 | lp = ModelParams(parser)
258 | op = OptimizationParams(parser)
259 | pp = PipelineParams(parser)
260 | parser.add_argument('--ip', type=str, default="127.0.0.1")
261 | parser.add_argument('--port', type=int, default=6009)
262 | parser.add_argument('--debug_from', type=int, default=-1)
263 | parser.add_argument('--detect_anomaly', action='store_true', default=False)
264 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[30_000])
265 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
266 | parser.add_argument("--quiet", action="store_true")
267 | parser.add_argument('--disable_viewer', action='store_true', default=False)
268 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[30_000])
269 | parser.add_argument("--start_checkpoint", type=str, default = None)
270 | args = parser.parse_args(sys.argv[1:])
271 | args.save_iterations.append(args.iterations)
272 |
273 | print("Optimizing " + args.model_path)
274 |
275 | # Initialize system state (RNG)
276 | safe_state(args.quiet)
277 |
278 | # Start GUI server, configure and run training
279 | if not args.disable_viewer:
280 | network_gui.init(args.ip, args.port)
281 | torch.autograd.set_detect_anomaly(args.detect_anomaly)
282 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
283 |
284 | # All done
285 | print("\nTraining complete.")
286 |
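
A standalone sketch of the photometric objective assembled in the training loop above: loss = (1 - lambda_dssim) * L1 + lambda_dssim * (1 - SSIM), optionally plus a weighted depth L1 term. The 0.2 weight and the fixed SSIM placeholder below are assumptions for illustration only.

import torch

def combined_loss(image, gt_image, ssim_value, lambda_dssim=0.2):
    # L1 term plus D-SSIM term, mirroring the combination used in training()
    l1 = torch.abs(image - gt_image).mean()
    return (1.0 - lambda_dssim) * l1 + lambda_dssim * (1.0 - ssim_value)

img, gt = torch.rand(3, 8, 8), torch.rand(3, 8, 8)
print(combined_loss(img, gt, ssim_value=torch.tensor(0.9)))
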
--------------------------------------------------------------------------------
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from scene.cameras import Camera
13 | import numpy as np
14 | from utils.graphics_utils import fov2focal
15 | from PIL import Image
16 | import cv2
17 |
18 | WARNED = False
19 |
20 | def loadCam(args, id, cam_info, resolution_scale, is_nerf_synthetic, is_test_dataset):
21 | image = Image.open(cam_info.image_path)
22 |
23 | if cam_info.depth_path != "":
24 | try:
25 | if is_nerf_synthetic:
26 | invdepthmap = cv2.imread(cam_info.depth_path, -1).astype(np.float32) / 512
27 | else:
28 | invdepthmap = cv2.imread(cam_info.depth_path, -1).astype(np.float32) / float(2**16)
29 |
30 | except FileNotFoundError:
31 | print(f"Error: The depth file at path '{cam_info.depth_path}' was not found.")
32 | raise
33 | except IOError:
34 | print(f"Error: Unable to open the image file '{cam_info.depth_path}'. It may be corrupted or an unsupported format.")
35 | raise
36 | except Exception as e:
37 | print(f"An unexpected error occurred when trying to read depth at {cam_info.depth_path}: {e}")
38 | raise
39 | else:
40 | invdepthmap = None
41 |
42 | orig_w, orig_h = image.size
43 | if args.resolution in [1, 2, 4, 8]:
44 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
45 | else: # should be a type that converts to float
46 | if args.resolution == -1:
47 | if orig_w > 1600:
48 | global WARNED
49 | if not WARNED:
50 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
51 | "If this is not desired, please explicitly specify '--resolution/-r' as 1")
52 | WARNED = True
53 | global_down = orig_w / 1600
54 | else:
55 | global_down = 1
56 | else:
57 | global_down = orig_w / args.resolution
58 |
59 |
60 | scale = float(global_down) * float(resolution_scale)
61 | resolution = (int(orig_w / scale), int(orig_h / scale))
62 | return Camera(resolution, colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
63 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, depth_params=cam_info.depth_params,
64 | image=image, invdepthmap=invdepthmap,
65 | image_name=cam_info.image_name, uid=id, data_device=args.data_device,
66 | train_test_exp=args.train_test_exp, is_test_dataset=is_test_dataset, is_test_view=cam_info.is_test)
67 |
68 | def cameraList_from_camInfos(cam_infos, resolution_scale, args, is_nerf_synthetic, is_test_dataset):
69 | camera_list = []
70 |
71 | for id, c in enumerate(cam_infos):
72 | camera_list.append(loadCam(args, id, c, resolution_scale, is_nerf_synthetic, is_test_dataset))
73 |
74 | return camera_list
75 |
76 | def camera_to_JSON(id, camera : Camera):
77 | Rt = np.zeros((4, 4))
78 | Rt[:3, :3] = camera.R.transpose()
79 | Rt[:3, 3] = camera.T
80 | Rt[3, 3] = 1.0
81 |
82 | W2C = np.linalg.inv(Rt)
83 | pos = W2C[:3, 3]
84 | rot = W2C[:3, :3]
85 | serializable_array_2d = [x.tolist() for x in rot]
86 | camera_entry = {
87 | 'id' : id,
88 | 'img_name' : camera.image_name,
89 | 'width' : camera.width,
90 | 'height' : camera.height,
91 | 'position': pos.tolist(),
92 | 'rotation': serializable_array_2d,
93 | 'fy' : fov2focal(camera.FovY, camera.height),
94 | 'fx' : fov2focal(camera.FovX, camera.width)
95 | }
96 | return camera_entry
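
A condensed restatement of the resolution rule in loadCam, for illustration: with --resolution -1, images wider than 1600 px are downscaled so the width becomes 1600, otherwise they are kept as-is; explicit factors 1/2/4/8 divide both dimensions. The helper below is hypothetical and only mirrors that branch.

def target_resolution(orig_w, orig_h, resolution=-1, resolution_scale=1.0):
    if resolution in [1, 2, 4, 8]:
        return round(orig_w / (resolution_scale * resolution)), round(orig_h / (resolution_scale * resolution))
    if resolution == -1:
        global_down = orig_w / 1600 if orig_w > 1600 else 1
    else:
        global_down = orig_w / resolution
    scale = float(global_down) * float(resolution_scale)
    return int(orig_w / scale), int(orig_h / scale)

print(target_resolution(3200, 2100))   # -> (1600, 1050): auto mode caps width at 1.6K
print(target_resolution(1200, 800))    # -> (1200, 800): already below the cap
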
--------------------------------------------------------------------------------
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import sys
14 | from datetime import datetime
15 | import numpy as np
16 | import random
17 |
18 | def inverse_sigmoid(x):
19 | return torch.log(x/(1-x))
20 |
21 | def PILtoTorch(pil_image, resolution):
22 | resized_image_PIL = pil_image.resize(resolution)
23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
24 | if len(resized_image.shape) == 3:
25 | return resized_image.permute(2, 0, 1)
26 | else:
27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
28 |
29 | def get_expon_lr_func(
30 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
31 | ):
32 | """
33 | Copied from Plenoxels
34 |
35 | Continuous learning rate decay function. Adapted from JaxNeRF
36 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
37 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
38 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
39 | function of lr_delay_mult, such that the initial learning rate is
40 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
41 | to the normal learning rate when steps>lr_delay_steps.
42 | :param lr_init, lr_final: float, the initial and final learning rates.
43 | :param max_steps: int, the number of steps during optimization.
44 | :return: a function that maps the current step to a learning rate.
45 | """
46 |
47 | def helper(step):
48 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
49 | # Disable this parameter
50 | return 0.0
51 | if lr_delay_steps > 0:
52 | # A kind of reverse cosine decay.
53 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
54 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
55 | )
56 | else:
57 | delay_rate = 1.0
58 | t = np.clip(step / max_steps, 0, 1)
59 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
60 | return delay_rate * log_lerp
61 |
62 | return helper
63 |
64 | def strip_lowerdiag(L):
65 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
66 |
67 | uncertainty[:, 0] = L[:, 0, 0]
68 | uncertainty[:, 1] = L[:, 0, 1]
69 | uncertainty[:, 2] = L[:, 0, 2]
70 | uncertainty[:, 3] = L[:, 1, 1]
71 | uncertainty[:, 4] = L[:, 1, 2]
72 | uncertainty[:, 5] = L[:, 2, 2]
73 | return uncertainty
74 |
75 | def strip_symmetric(sym):
76 | return strip_lowerdiag(sym)
77 |
78 | def build_rotation(r):
79 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
80 |
81 | q = r / norm[:, None]
82 |
83 | R = torch.zeros((q.size(0), 3, 3), device='cuda')
84 |
85 | r = q[:, 0]
86 | x = q[:, 1]
87 | y = q[:, 2]
88 | z = q[:, 3]
89 |
90 | R[:, 0, 0] = 1 - 2 * (y*y + z*z)
91 | R[:, 0, 1] = 2 * (x*y - r*z)
92 | R[:, 0, 2] = 2 * (x*z + r*y)
93 | R[:, 1, 0] = 2 * (x*y + r*z)
94 | R[:, 1, 1] = 1 - 2 * (x*x + z*z)
95 | R[:, 1, 2] = 2 * (y*z - r*x)
96 | R[:, 2, 0] = 2 * (x*z - r*y)
97 | R[:, 2, 1] = 2 * (y*z + r*x)
98 | R[:, 2, 2] = 1 - 2 * (x*x + y*y)
99 | return R
100 |
101 | def build_scaling_rotation(s, r):
102 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
103 | R = build_rotation(r)
104 |
105 | L[:,0,0] = s[:,0]
106 | L[:,1,1] = s[:,1]
107 | L[:,2,2] = s[:,2]
108 |
109 | L = R @ L
110 | return L
111 |
112 | def safe_state(silent):
113 | old_f = sys.stdout
114 | class F:
115 | def __init__(self, silent):
116 | self.silent = silent
117 |
118 | def write(self, x):
119 | if not self.silent:
120 | if x.endswith("\n"):
121 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
122 | else:
123 | old_f.write(x)
124 |
125 | def flush(self):
126 | old_f.flush()
127 |
128 | sys.stdout = F(silent)
129 |
130 | random.seed(0)
131 | np.random.seed(0)
132 | torch.manual_seed(0)
133 | torch.cuda.set_device(torch.device("cuda:0"))
134 |
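
A quick usage sketch of get_expon_lr_func defined above. Because the rate is log-linearly interpolated, the half-way value is the geometric mean of the endpoints; the numbers below are illustrative, not the project's configured learning rates.

lr_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)
print(lr_fn(0))        # 1.6e-4 at the first step
print(lr_fn(15_000))   # ~1.6e-5, the geometric mean of the endpoints
print(lr_fn(30_000))   # 1.6e-6 at the last step
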
--------------------------------------------------------------------------------
/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | import numpy as np
15 | from typing import NamedTuple
16 |
17 | class BasicPointCloud(NamedTuple):
18 | points : np.array
19 | colors : np.array
20 | normals : np.array
21 |
22 | def geom_transform_points(points, transf_matrix):
23 | P, _ = points.shape
24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
25 | points_hom = torch.cat([points, ones], dim=1)
26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
27 |
28 | denom = points_out[..., 3:] + 0.0000001
29 | return (points_out[..., :3] / denom).squeeze(dim=0)
30 |
31 | def getWorld2View(R, t):
32 | Rt = np.zeros((4, 4))
33 | Rt[:3, :3] = R.transpose()
34 | Rt[:3, 3] = t
35 | Rt[3, 3] = 1.0
36 | return np.float32(Rt)
37 |
38 | def orthonormalize_rotation_matrix(R, eps=1e-6):
39 | U, S, Vt = np.linalg.svd(R)
40 | R_ortho = U @ Vt
41 |
42 | if np.linalg.det(R_ortho) < 0:
43 | Vt[-1, :] *= -1
44 | R_ortho = U @ Vt
45 |
46 | return R_ortho
47 |
48 |
49 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
50 | Rt = np.zeros((4, 4))
51 | Rt[:3, :3] = R.transpose()
52 | Rt[:3, 3] = t
53 | Rt[3, 3] = 1.0
54 |
55 | C2W = np.linalg.inv(Rt)
56 | cam_center = C2W[:3, 3]
57 | cam_center = (cam_center + translate) * scale
58 | C2W[:3, 3] = cam_center
59 | Rt = np.linalg.inv(C2W)
60 | return np.float32(Rt)
61 |
62 | def getProjectionMatrix(znear, zfar, fovX, fovY):
63 | tanHalfFovY = math.tan((fovY / 2))
64 | tanHalfFovX = math.tan((fovX / 2))
65 |
66 | top = tanHalfFovY * znear
67 | bottom = -top
68 | right = tanHalfFovX * znear
69 | left = -right
70 |
71 | P = torch.zeros(4, 4)
72 |
73 | z_sign = 1.0
74 |
75 | P[0, 0] = 2.0 * znear / (right - left)
76 | P[1, 1] = 2.0 * znear / (top - bottom)
77 | P[0, 2] = (right + left) / (right - left)
78 | P[1, 2] = (top + bottom) / (top - bottom)
79 | P[3, 2] = z_sign
80 | P[2, 2] = z_sign * zfar / (zfar - znear)
81 | P[2, 3] = -(zfar * znear) / (zfar - znear)
82 | return P
83 |
84 | def fov2focal(fov, pixels):
85 | return pixels / (2 * math.tan(fov / 2))
86 |
87 | def focal2fov(focal, pixels):
88 | return 2*math.atan(pixels/(2*focal))
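
The two conversions above are exact inverses: focal = pixels / (2·tan(fov/2)) and fov = 2·atan(pixels / (2·focal)). A quick round-trip check with illustrative numbers:

f = fov2focal(1.0, 800)        # ~732 px focal length for a 1 rad FoV over 800 px
print(f)
print(focal2fov(f, 800))       # recovers 1.0
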
--------------------------------------------------------------------------------
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 |
14 | def mse(img1, img2):
15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
16 |
17 | def psnr(img1, img2):
18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
19 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
20 |
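
With a peak value of 1.0, psnr above reduces to 10·log10(1 / MSE), so a uniform per-pixel error of 0.1 (MSE = 0.01) gives 20 dB. A quick check:

import torch

img1 = torch.zeros(1, 3, 4, 4)
img2 = torch.full((1, 3, 4, 4), 0.1)   # constant error of 0.1 -> MSE = 0.01
print(psnr(img1, img2))                # ~20 dB
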
--------------------------------------------------------------------------------
/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | from math import exp
16 | try:
17 | from diff_gaussian_rasterization._C import fusedssim, fusedssim_backward
18 | except:
19 | pass
20 |
21 | C1 = 0.01 ** 2
22 | C2 = 0.03 ** 2
23 |
24 | class FusedSSIMMap(torch.autograd.Function):
25 | @staticmethod
26 | def forward(ctx, C1, C2, img1, img2):
27 | ssim_map = fusedssim(C1, C2, img1, img2)
28 | ctx.save_for_backward(img1.detach(), img2)
29 | ctx.C1 = C1
30 | ctx.C2 = C2
31 | return ssim_map
32 |
33 | @staticmethod
34 | def backward(ctx, opt_grad):
35 | img1, img2 = ctx.saved_tensors
36 | C1, C2 = ctx.C1, ctx.C2
37 | grad = fusedssim_backward(C1, C2, img1, img2, opt_grad)
38 | return None, None, grad, None
39 |
40 | def l1_loss(network_output, gt):
41 | return torch.abs((network_output - gt)).mean()
42 |
43 | def l2_loss(network_output, gt):
44 | return ((network_output - gt) ** 2).mean()
45 |
46 | def gaussian(window_size, sigma):
47 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
48 | return gauss / gauss.sum()
49 |
50 | def create_window(window_size, channel):
51 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
52 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
53 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
54 | return window
55 |
56 | def ssim(img1, img2, window_size=11, size_average=True):
57 | channel = img1.size(-3)
58 | window = create_window(window_size, channel)
59 |
60 | if img1.is_cuda:
61 | window = window.cuda(img1.get_device())
62 | window = window.type_as(img1)
63 |
64 | return _ssim(img1, img2, window, window_size, channel, size_average)
65 |
66 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
67 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
68 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
69 |
70 | mu1_sq = mu1.pow(2)
71 | mu2_sq = mu2.pow(2)
72 | mu1_mu2 = mu1 * mu2
73 |
74 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
75 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
76 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
77 |
78 | C1 = 0.01 ** 2
79 | C2 = 0.03 ** 2
80 |
81 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
82 |
83 | if size_average:
84 | return ssim_map.mean()
85 | else:
86 | return ssim_map.mean(1).mean(1).mean(1)
87 |
88 |
89 | def fast_ssim(img1, img2):
90 | ssim_map = FusedSSIMMap.apply(C1, C2, img1, img2)
91 | return ssim_map.mean()
92 |
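
C1 and C2 above are the standard SSIM stabilizers (K1·L)² and (K2·L)² with K1 = 0.01, K2 = 0.03 and dynamic range L = 1 for images in [0, 1]. A minimal usage sketch of the pure-PyTorch ssim path (the fused CUDA variant additionally requires the rasterizer extension):

import torch

a = torch.rand(3, 64, 64)
print(ssim(a, a))                        # identical images -> 1.0
print(ssim(a, torch.rand(3, 64, 64)))    # unrelated noise -> much lower
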
--------------------------------------------------------------------------------
/utils/make_depth_scale.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | import cv2
4 | from joblib import delayed, Parallel
5 | import json
6 | from read_write_model import *
7 |
8 | def get_scales(key, cameras, images, points3d_ordered, args):
9 | image_meta = images[key]
10 | cam_intrinsic = cameras[image_meta.camera_id]
11 |
12 | pts_idx = images[key].point3D_ids
13 |
14 | mask = pts_idx >= 0
15 | mask *= pts_idx < len(points3d_ordered)
16 |
17 | pts_idx = pts_idx[mask]
18 | valid_xys = image_meta.xys[mask]
19 |
20 | if len(pts_idx) > 0:
21 | pts = points3d_ordered[pts_idx]
22 | else:
23 | pts = np.array([0, 0, 0])
24 |
25 | R = qvec2rotmat(image_meta.qvec)
26 | pts = np.dot(pts, R.T) + image_meta.tvec
27 |
28 | invcolmapdepth = 1. / pts[..., 2]
29 | n_remove = len(image_meta.name.split('.')[-1]) + 1
30 | invmonodepthmap = cv2.imread(f"{args.depths_dir}/{image_meta.name[:-n_remove]}.png", cv2.IMREAD_UNCHANGED)
31 |
32 | if invmonodepthmap is None:
33 | return None
34 |
35 | if invmonodepthmap.ndim != 2:
36 | invmonodepthmap = invmonodepthmap[..., 0]
37 |
38 | invmonodepthmap = invmonodepthmap.astype(np.float32) / (2**16)
39 | s = invmonodepthmap.shape[0] / cam_intrinsic.height
40 |
41 | maps = (valid_xys * s).astype(np.float32)
42 | valid = (
43 | (maps[..., 0] >= 0) *
44 | (maps[..., 1] >= 0) *
45 | (maps[..., 0] < cam_intrinsic.width * s) *
46 | (maps[..., 1] < cam_intrinsic.height * s) * (invcolmapdepth > 0))
47 |
48 | if valid.sum() > 10 and (invcolmapdepth.max() - invcolmapdepth.min()) > 1e-3:
49 | maps = maps[valid, :]
50 | invcolmapdepth = invcolmapdepth[valid]
51 | invmonodepth = cv2.remap(invmonodepthmap, maps[..., 0], maps[..., 1], interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)[..., 0]
52 |
53 | ## Median / dev
54 | t_colmap = np.median(invcolmapdepth)
55 | s_colmap = np.mean(np.abs(invcolmapdepth - t_colmap))
56 |
57 | t_mono = np.median(invmonodepth)
58 | s_mono = np.mean(np.abs(invmonodepth - t_mono))
59 | scale = s_colmap / s_mono
60 | offset = t_colmap - t_mono * scale
61 | else:
62 | scale = 0
63 | offset = 0
64 | return {"image_name": image_meta.name[:-n_remove], "scale": scale, "offset": offset}
65 |
66 | if __name__ == '__main__':
67 | parser = argparse.ArgumentParser()
68 | parser.add_argument('--base_dir', default="../data/big_gaussians/standalone_chunks/campus")
69 | parser.add_argument('--depths_dir', default="../data/big_gaussians/standalone_chunks/campus/depths_any")
70 | parser.add_argument('--model_type', default="bin")
71 | args = parser.parse_args()
72 |
73 |
74 | cam_intrinsics, images_metas, points3d = read_model(os.path.join(args.base_dir, "sparse", "0"), ext=f".{args.model_type}")
75 |
76 | pts_indices = np.array([points3d[key].id for key in points3d])
77 | pts_xyzs = np.array([points3d[key].xyz for key in points3d])
78 | points3d_ordered = np.zeros([pts_indices.max()+1, 3])
79 | points3d_ordered[pts_indices] = pts_xyzs
80 |
81 | # depth_param_list = [get_scales(key, cam_intrinsics, images_metas, points3d_ordered, args) for key in images_metas]
82 | depth_param_list = Parallel(n_jobs=-1, backend="threading")(
83 | delayed(get_scales)(key, cam_intrinsics, images_metas, points3d_ordered, args) for key in images_metas
84 | )
85 |
86 | depth_params = {
87 | depth_param["image_name"]: {"scale": depth_param["scale"], "offset": depth_param["offset"]}
88 | for depth_param in depth_param_list if depth_param != None
89 | }
90 |
91 | with open(f"{args.base_dir}/sparse/0/depth_params.json", "w") as f:
92 | json.dump(depth_params, f, indent=2)
93 |
94 | print(0)
95 |
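
The scale/offset pair written to depth_params.json aligns the monocular inverse depth with COLMAP's via the robust median / mean-absolute-deviation fit above, so that invmonodepth * scale + offset matches invcolmapdepth. A toy check with synthetic values (illustrative only):

import numpy as np

invcolmap = np.array([0.8, 1.0, 1.2, 2.0])   # inverse depths at SfM points
invmono = (invcolmap - 0.1) / 2.0            # mono estimate off by scale 2 and offset 0.1
t_c, t_m = np.median(invcolmap), np.median(invmono)
scale = np.mean(np.abs(invcolmap - t_c)) / np.mean(np.abs(invmono - t_m))
offset = t_c - t_m * scale
print(scale, offset)                         # 2.0, 0.1
print(invmono * scale + offset)              # recovers invcolmap
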
--------------------------------------------------------------------------------
/utils/pose_utils.py:
--------------------------------------------------------------------------------
1 | # Ideas adapted from [LightGaussian](https://github.com/VITA-Group/LightGaussian)
2 |
3 | import numpy as np
4 | import torch
5 | from icecream import ic
6 | from utils.graphics_utils import getWorld2View2
7 |
8 |
9 | def normalize(x):
10 | return x / np.linalg.norm(x)
11 |
12 | def viewmatrix(z, up, pos):
13 | vec2 = normalize(z)
14 | vec1_avg = up
15 | vec0 = normalize(np.cross(vec1_avg, vec2))
16 | vec1 = normalize(np.cross(vec2, vec0))
17 | m = np.stack([vec0, vec1, vec2, pos], 1)
18 | return m
19 |
20 | def poses_avg(poses):
21 | hwf = poses[0, :3, -1:]
22 |
23 | center = poses[:, :3, 3].mean(0)
24 | vec2 = normalize(poses[:, :3, 2].sum(0))
25 | up = poses[:, :3, 1].sum(0)
26 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
27 |
28 | return c2w
29 |
30 | def get_focal(camera):
31 | focal = camera.FoVx
32 | return focal
33 |
34 | def poses_avg_fixed_center(poses):
35 | hwf = poses[0, :3, -1:]
36 | center = poses[:, :3, 3].mean(0)
37 | vec2 = [1, 0, 0]
38 | up = [0, 0, 1]
39 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
40 | return c2w
41 |
42 | def integrate_weights_np(w):
43 | """Compute the cumulative sum of w, assuming all weight vectors sum to 1.
44 |
45 | The output's size on the last dimension is one greater than that of the input,
46 | because we're computing the integral corresponding to the endpoints of a step
47 | function, not the integral of the interior/bin values.
48 |
49 | Args:
50 | w: Tensor, which will be integrated along the last axis. This is assumed to
51 | sum to 1 along the last axis, and this function will (silently) break if
52 | that is not the case.
53 |
54 | Returns:
55 | cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
56 | """
57 | cw = np.minimum(1, np.cumsum(w[..., :-1], axis=-1))
58 | shape = cw.shape[:-1] + (1,)
59 | # Ensure that the CDF starts with exactly 0 and ends with exactly 1.
60 | cw0 = np.concatenate([np.zeros(shape), cw,
61 | np.ones(shape)], axis=-1)
62 | return cw0
63 |
64 | def invert_cdf_np(u, t, w_logits):
65 | """Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
66 | # Compute the PDF and CDF for each weight vector.
67 | w = np.exp(w_logits) / np.exp(w_logits).sum(axis=-1, keepdims=True)
68 | cw = integrate_weights_np(w)
69 | # Interpolate into the inverse CDF.
70 | interp_fn = np.interp
71 | t_new = interp_fn(u, cw, t)
72 | return t_new
73 |
74 | def sample_np(rand,
75 | t,
76 | w_logits,
77 | num_samples,
78 | single_jitter=False,
79 | deterministic_center=False):
80 | """
81 | numpy version of sample()
82 | """
83 | eps = np.finfo(np.float32).eps
84 |
85 | # Draw uniform samples.
86 | if not rand:
87 | if deterministic_center:
88 | pad = 1 / (2 * num_samples)
89 | u = np.linspace(pad, 1. - pad - eps, num_samples)
90 | else:
91 | u = np.linspace(0, 1. - eps, num_samples)
92 | u = np.broadcast_to(u, t.shape[:-1] + (num_samples,))
93 | else:
94 | # `u` is in [0, 1) --- it can be zero, but it can never be 1.
95 | u_max = eps + (1 - eps) / num_samples
96 | max_jitter = (1 - u_max) / (num_samples - 1) - eps
97 | d = 1 if single_jitter else num_samples
98 | u = np.linspace(0, 1 - u_max, num_samples) + \
99 | np.random.rand(*t.shape[:-1], d) * max_jitter
100 |
101 | return invert_cdf_np(u, t, w_logits)
102 |
103 |
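
sample_np above draws samples by inverting the CDF of the piecewise-constant distribution defined by bin edges t and unnormalized log-weights w_logits; generate_ellipse_path uses it further down to resample path angles toward constant speed. A small deterministic usage sketch (values are illustrative):

import numpy as np

t_edges = np.array([0.0, 1.0, 2.0, 3.0])       # bin edges
w_logits = np.log(np.array([1.0, 1.0, 8.0]))   # last bin carries 80% of the weight
samples = sample_np(rand=False, t=t_edges, w_logits=w_logits,
                    num_samples=5, deterministic_center=True)
print(samples)                                  # ~[1.0, 2.125, 2.375, 2.625, 2.875]
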
104 |
105 | def focus_point_fn(poses):
106 | """Calculate nearest point to all focal axes in poses."""
107 | directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
108 | m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
109 | mt_m = np.transpose(m, [0, 2, 1]) @ m
110 | focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
111 | return focus_pt
112 |
113 |
114 | def average_pose(poses: np.ndarray) -> np.ndarray:
115 | """New pose using average position, z-axis, and up vector of input poses."""
116 | position = poses[:, :3, 3].mean(0)
117 | z_axis = poses[:, :3, 2].mean(0)
118 | up = poses[:, :3, 1].mean(0)
119 | cam2world = viewmatrix(z_axis, up, position)
120 | return cam2world
121 |
122 | from typing import Tuple
123 | def recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
124 | """Recenter poses around the origin."""
125 | cam2world = average_pose(poses)
126 | transform = np.linalg.inv(pad_poses(cam2world))
127 | poses = transform @ pad_poses(poses)
128 | return unpad_poses(poses), transform
129 |
130 |
131 | NEAR_STRETCH = .9 # Push forward near bound for forward facing render path.
132 | FAR_STRETCH = 5. # Push back far bound for forward facing render path.
133 | FOCUS_DISTANCE = .75 # Relative weighting of near, far bounds for render path.
134 | def generate_spiral_path(views, bounds,
135 | n_frames: int = 180,
136 | n_rots: int = 2,
137 | zrate: float = .5) -> np.ndarray:
138 | """Calculates a forward facing spiral path for rendering."""
139 | # Find a reasonable 'focus depth' for this dataset as a weighted average
140 | # of conservative near and far bounds in disparity space.
141 | poses = []
142 | for view in views:
143 | tmp_view = np.eye(4)
144 | tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
145 | tmp_view = np.linalg.inv(tmp_view)
146 | tmp_view[:, 1:3] *= -1
147 | poses.append(tmp_view)
148 | poses = np.stack(poses, 0)
149 |
150 | print(poses.shape)
151 | bounds = bounds.repeat(poses.shape[0], 0) #np.array([[ 16.21311152, 153.86329729]])
152 | scale = 1. / (bounds.min() * .75)
153 | poses[:, :3, 3] *= scale
154 | bounds *= scale
155 | # Recenter poses.
156 | # tmp, _ = recenter_poses(poses)
157 | # poses[:, :3, :3] = tmp[:, :3, :3] @ np.diag(np.array([1, -1, -1]))
158 |
159 | near_bound = bounds.min() * NEAR_STRETCH
160 | far_bound = bounds.max() * FAR_STRETCH
161 | # All cameras will point towards the world space point (0, 0, -focal).
162 | focal = 1 / (((1 - FOCUS_DISTANCE) / near_bound + FOCUS_DISTANCE / far_bound))
163 |
164 | # Get radii for spiral path using 90th percentile of camera positions.
165 | positions = poses[:, :3, 3]
166 | radii = np.percentile(np.abs(positions), 90, 0)
167 | radii = np.concatenate([radii, [1.]])
168 |
169 | # Generate poses for spiral path.
170 | render_poses = []
171 | cam2world = average_pose(poses)
172 | up = poses[:, :3, 1].mean(0)
173 | for theta in np.linspace(0., 2. * np.pi * n_rots, n_frames, endpoint=False):
174 | t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]
175 | position = cam2world @ t
176 | lookat = cam2world @ [0, 0, -focal, 1.]
177 | z_axis = position - lookat
178 | render_pose = np.eye(4)
179 | render_pose[:3] = viewmatrix(z_axis, up, position)
180 | render_pose[:3, 1:3] *= -1
181 | render_poses.append(np.linalg.inv(render_pose))
182 | render_poses = np.stack(render_poses, axis=0)
183 | return render_poses
184 |
185 |
186 | def render_path_spiral(views, focal=50, zrate=0.5, rots=2, N=10):
187 | poses = []
188 | for view in views:
189 | tmp_view = np.eye(4)
190 | tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
191 | tmp_view = np.linalg.inv(tmp_view)
192 | tmp_view[:, 1:3] *= -1
193 | poses.append(tmp_view)
194 | poses = np.stack(poses, 0)
195 | # poses = np.stack([np.concatenate([view.R.T, view.T[:, None]], 1) for view in views], 0)
196 | c2w = poses_avg(poses)
197 | up = normalize(poses[:, :3, 1].sum(0))
198 |
199 | # Get radii for spiral path
200 | rads = np.percentile(np.abs(poses[:, :3, 3]), 90, 0)
201 | render_poses = []
202 | rads = np.array(list(rads) + [1.0])
203 |
204 | for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]:
205 | c = np.dot(
206 | c2w[:3, :4],
207 | np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]) * rads,
208 | )
209 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0])))
210 | render_pose = np.eye(4)
211 | render_pose[:3] = viewmatrix(z, up, c)
212 | render_pose[:3, 1:3] *= -1
213 | render_poses.append(np.linalg.inv(render_pose))
214 | return render_poses
215 |
216 | def pad_poses(p):
217 | """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
218 | bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
219 | return np.concatenate([p[..., :3, :4], bottom], axis=-2)
220 |
221 |
222 | def unpad_poses(p):
223 | """Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
224 | return p[..., :3, :4]
225 |
226 | def transform_poses_pca(poses):
227 | """Transforms poses so principal components lie on XYZ axes.
228 |
229 | Args:
230 | poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
231 |
232 | Returns:
233 | A tuple (poses, transform), with the transformed poses and the applied
234 | camera_to_world transforms.
235 | """
236 | t = poses[:, :3, 3]
237 | t_mean = t.mean(axis=0)
238 | t = t - t_mean
239 |
240 | eigval, eigvec = np.linalg.eig(t.T @ t)
241 | # Sort eigenvectors in order of largest to smallest eigenvalue.
242 | inds = np.argsort(eigval)[::-1]
243 | eigvec = eigvec[:, inds]
244 | rot = eigvec.T
245 | if np.linalg.det(rot) < 0:
246 | rot = np.diag(np.array([1, 1, -1])) @ rot
247 |
248 | transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
249 | poses_recentered = unpad_poses(transform @ pad_poses(poses))
250 | transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
251 |
252 | # Flip coordinate system if z component of y-axis is negative
253 | if poses_recentered.mean(axis=0)[2, 1] < 0:
254 | poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
255 | transform = np.diag(np.array([1, -1, -1, 1])) @ transform
256 |
257 | # Just make sure it's in the [-1, 1]^3 cube
258 | scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3]))
259 | poses_recentered[:, :3, 3] *= scale_factor
260 | transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
261 | return poses_recentered, transform
262 |
263 | def generate_ellipse_path(views, n_frames=600, const_speed=True, z_variation=0., z_phase=0.):
264 | poses = []
265 | for view in views:
266 | tmp_view = np.eye(4)
267 | tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
268 | tmp_view = np.linalg.inv(tmp_view)
269 | tmp_view[:, 1:3] *= -1
270 | poses.append(tmp_view)
271 | poses = np.stack(poses, 0)
272 | poses, transform = transform_poses_pca(poses)
273 |
274 |
275 | # Calculate the focal point for the path (cameras point toward this).
276 | center = focus_point_fn(poses)
277 |     offset = np.array([center[0], center[1], center[2] * 0])
278 | # Calculate scaling for ellipse axes based on input camera positions.
279 | sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
280 |
281 | # Use ellipse that is symmetric about the focal point in xy.
282 | low = -sc + offset
283 | high = sc + offset
284 | # Optional height variation need not be symmetric
285 | z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
286 | z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
287 |
288 |
289 | def get_positions(theta):
290 | # Interpolate between bounds with trig functions to get ellipse in x-y.
291 | # Optionally also interpolate in z to change camera height along path.
292 | return np.stack([
293 | (low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5)),
294 | (low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5)),
295 | z_variation * (z_low[2] + (z_high - z_low)[2] *
296 | (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)),
297 | ], -1)
298 |
299 | theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True)
300 | positions = get_positions(theta)
301 |
302 | if const_speed:
303 | # Resample theta angles so that the velocity is closer to constant.
304 | lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
305 | theta = sample_np(None, theta, np.log(lengths), n_frames + 1)
306 | positions = get_positions(theta)
307 |
308 | # Throw away duplicated last position.
309 | positions = positions[:-1]
310 |
311 | # Set path's up vector to axis closest to average of input pose up vectors.
312 | avg_up = poses[:, :3, 1].mean(0)
313 | avg_up = avg_up / np.linalg.norm(avg_up)
314 | ind_up = np.argmax(np.abs(avg_up))
315 | up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
316 |
317 | render_poses = []
318 | for p in positions:
319 | render_pose = np.eye(4)
320 | render_pose[:3] = viewmatrix(p - center, up, p)
321 | render_pose = np.linalg.inv(transform) @ render_pose
322 | render_pose[:3, 1:3] *= -1
323 | render_poses.append(np.linalg.inv(render_pose))
324 | return render_poses
325 |
326 |
327 | def generate_spherify_path(views):
328 | poses = []
329 | for view in views:
330 | tmp_view = np.eye(4)
331 | tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
332 | tmp_view = np.linalg.inv(tmp_view)
333 | tmp_view[:, 1:3] *= -1
334 | poses.append(tmp_view)
335 | poses = np.stack(poses, 0)
336 |
337 | p34_to_44 = lambda p: np.concatenate(
338 | [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
339 | )
340 |
341 | rays_d = poses[:, :3, 2:3]
342 | rays_o = poses[:, :3, 3:4]
343 |
344 | def min_line_dist(rays_o, rays_d):
345 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
346 | b_i = -A_i @ rays_o
347 | pt_mindist = np.squeeze(
348 | -np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0)
349 | )
350 | return pt_mindist
351 |
352 | pt_mindist = min_line_dist(rays_o, rays_d)
353 |
354 | center = pt_mindist
355 | up = (poses[:, :3, 3] - center).mean(0)
356 |
357 | vec0 = normalize(up)
358 | vec1 = normalize(np.cross([0.1, 0.2, 0.3], vec0))
359 | vec2 = normalize(np.cross(vec0, vec1))
360 | pos = center
361 | c2w = np.stack([vec1, vec2, vec0, pos], 1)
362 |
363 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])
364 |
365 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
366 |
367 | sc = 1.0 / rad
368 | poses_reset[:, :3, 3] *= sc
369 | rad *= sc
370 |
371 | centroid = np.mean(poses_reset[:, :3, 3], 0)
372 | zh = centroid[2]
373 | radcircle = np.sqrt(rad**2 - zh**2)
374 | new_poses = []
375 |
376 | for th in np.linspace(0.0, 2.0 * np.pi, 120):
377 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
378 | up = np.array([0, 0, -1.0])
379 |
380 | vec2 = normalize(camorigin)
381 | vec0 = normalize(np.cross(vec2, up))
382 | vec1 = normalize(np.cross(vec2, vec0))
383 | pos = camorigin
384 | p = np.stack([vec0, vec1, vec2, pos], 1)
385 |
386 | render_pose = np.eye(4)
387 | render_pose[:3] = p
388 | #render_pose[:3, 1:3] *= -1
389 | new_poses.append(render_pose)
390 |
391 | new_poses = np.stack(new_poses, 0)
392 | return new_poses
393 |
394 | # def gaussian_poses(viewpoint_cam, mean =0, std_dev = 0.03):
395 | # translate_x = np.random.normal(mean, std_dev)
396 | # translate_y = np.random.normal(mean, std_dev)
397 | # translate_z = np.random.normal(mean, std_dev)
398 | # translate = np.array([translate_x, translate_y, translate_z])
399 | # viewpoint_cam.world_view_transform = torch.tensor(getWorld2View2(viewpoint_cam.R, viewpoint_cam.T, translate)).transpose(0, 1).cuda()
400 | # viewpoint_cam.full_proj_transform = (viewpoint_cam.world_view_transform.unsqueeze(0).bmm(viewpoint_cam.projection_matrix.unsqueeze(0))).squeeze(0)
401 | # viewpoint_cam.camera_center = viewpoint_cam.world_view_transform.inverse()[3, :3]
402 | # return viewpoint_cam
403 |
404 | def get_rotation_matrix(axis, angle):
405 | """
406 | Create a rotation matrix for a given axis (x, y, or z) and angle.
407 | """
408 | axis = axis.lower()
409 | cos_angle = np.cos(angle)
410 | sin_angle = np.sin(angle)
411 |
412 | if axis == 'x':
413 | return np.array([
414 | [1, 0, 0],
415 | [0, cos_angle, -sin_angle],
416 | [0, sin_angle, cos_angle]
417 | ])
418 | elif axis == 'y':
419 | return np.array([
420 | [cos_angle, 0, sin_angle],
421 | [0, 1, 0],
422 | [-sin_angle, 0, cos_angle]
423 | ])
424 | elif axis == 'z':
425 | return np.array([
426 | [cos_angle, -sin_angle, 0],
427 | [sin_angle, cos_angle, 0],
428 | [0, 0, 1]
429 | ])
430 | else:
431 | raise ValueError("Invalid axis. Choose from 'x', 'y', 'z'.")
432 |
433 |
434 |
435 | def gaussian_poses(viewpoint_cam, mean=0, std_dev_translation=0.03, std_dev_rotation=0.01):
436 | # Translation Perturbation
437 | translate_x = np.random.normal(mean, std_dev_translation)
438 | translate_y = np.random.normal(mean, std_dev_translation)
439 | translate_z = np.random.normal(mean, std_dev_translation)
440 | translate = np.array([translate_x, translate_y, translate_z])
441 |
442 | # Rotation Perturbation
443 | angle_x = np.random.normal(mean, std_dev_rotation)
444 | angle_y = np.random.normal(mean, std_dev_rotation)
445 | angle_z = np.random.normal(mean, std_dev_rotation)
446 |
447 | rot_x = get_rotation_matrix('x', angle_x)
448 | rot_y = get_rotation_matrix('y', angle_y)
449 | rot_z = get_rotation_matrix('z', angle_z)
450 |
451 | # Combined Rotation Matrix
452 | combined_rot = np.matmul(rot_z, np.matmul(rot_y, rot_x))
453 |
454 | # Apply Rotation to Camera
455 | rotated_R = np.matmul(viewpoint_cam.R, combined_rot)
456 |
457 | # Update Camera Transformation
458 | viewpoint_cam.world_view_transform = torch.tensor(getWorld2View2(rotated_R, viewpoint_cam.T, translate)).transpose(0, 1).cuda()
459 | viewpoint_cam.full_proj_transform = (viewpoint_cam.world_view_transform.unsqueeze(0).bmm(viewpoint_cam.projection_matrix.unsqueeze(0))).squeeze(0)
460 | viewpoint_cam.camera_center = viewpoint_cam.world_view_transform.inverse()[3, :3]
461 |
462 | return viewpoint_cam
463 |
464 |
465 |
466 | def circular_poses(viewpoint_cam, radius, angle=0.0):
467 | translate_x = radius * np.cos(angle)
468 | translate_y = radius * np.sin(angle)
469 | translate_z = 0
470 | translate = np.array([translate_x, translate_y, translate_z])
471 | viewpoint_cam.world_view_transform = torch.tensor(getWorld2View2(viewpoint_cam.R, viewpoint_cam.T, translate)).transpose(0, 1).cuda()
472 | viewpoint_cam.full_proj_transform = (viewpoint_cam.world_view_transform.unsqueeze(0).bmm(viewpoint_cam.projection_matrix.unsqueeze(0))).squeeze(0)
473 | viewpoint_cam.camera_center = viewpoint_cam.world_view_transform.inverse()[3, :3]
474 |
475 | return viewpoint_cam
476 |
477 | def generate_spherical_sample_path(views, azimuthal_rots=1, polar_rots=0.75, N=10):
478 | poses = []
479 | for view in views:
480 | tmp_view = np.eye(4)
481 | tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
482 | tmp_view = np.linalg.inv(tmp_view)
483 | tmp_view[:, 1:3] *= -1
484 | poses.append(tmp_view)
485 | focal = get_focal(view)
486 | poses = np.stack(poses, 0)
487 | # ic(min_focal, max_focal)
488 |
489 | c2w = poses_avg(poses)
490 | up = normalize(poses[:, :3, 1].sum(0))
491 | rads = np.percentile(np.abs(poses[:, :3, 3]), 90, 0)
492 | rads = np.array(list(rads) + [1.0])
493 | ic(rads)
494 | render_poses = []
495 |     focal_range = np.linspace(0.5, 3, N ** 2 + 1)
496 | index = 0
497 | # Modify this loop to include phi
498 | for theta in np.linspace(0.0, 2.0 * np.pi * azimuthal_rots, N + 1)[:-1]:
499 | for phi in np.linspace(0.0, np.pi * polar_rots, N + 1)[:-1]:
500 | # Modify these lines to use spherical coordinates for c
501 | c = np.dot(
502 | c2w[:3, :4],
503 | rads * np.array([
504 | np.sin(phi) * np.cos(theta),
505 | np.sin(phi) * np.sin(theta),
506 | np.cos(phi),
507 | 1.0
508 | ])
509 | )
510 |
511 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal_range[index], 1.0])))
512 | render_pose = np.eye(4)
513 | render_pose[:3] = viewmatrix(z, up, c)
514 | render_pose[:3, 1:3] *= -1
515 | render_poses.append(np.linalg.inv(render_pose))
516 | index += 1
517 | return render_poses
518 |
519 |
520 | def generate_spiral_path(views, focal=1.5, zrate= 0, rots=1, N=600):
521 | poses = []
522 | focal = 0
523 | for view in views:
524 | tmp_view = np.eye(4)
525 | tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
526 | tmp_view = np.linalg.inv(tmp_view)
527 | tmp_view[:, 1:3] *= -1
528 | poses.append(tmp_view)
529 |         focal += get_focal(view)
530 | poses = np.stack(poses, 0)
531 |
532 |
533 | c2w = poses_avg(poses)
534 | up = normalize(poses[:, :3, 1].sum(0))
535 |
536 | # Get radii for spiral path
537 | rads = np.percentile(np.abs(poses[:, :3, 3]), 90, 0)
538 | render_poses = []
539 |
540 | rads = np.array(list(rads) + [1.0])
541 | focal /= len(views)
542 |
543 | for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]:
544 | c = np.dot(
545 | c2w[:3, :4],
546 |             np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]) * rads,
547 | )
548 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0])))
549 |
550 | render_pose = np.eye(4)
551 | render_pose[:3] = viewmatrix(z, up, c)
552 | render_pose[:3, 1:3] *= -1
553 | render_poses.append(np.linalg.inv(render_pose))
554 | return render_poses
--------------------------------------------------------------------------------
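
The helpers above build camera trajectories for offline video rendering. A minimal usage sketch (not part of the repo; `train_cameras` stands for the repo's `Camera` objects with `.R`/`.T` attributes, and the frame count is an arbitrary choice):

```python
# Hypothetical sketch: build a fly-through path from the training cameras.
from utils.pose_utils import generate_ellipse_path, generate_spiral_path

def make_video_path(train_cameras, kind="ellipse", n_frames=240):
    """Return a sequence of 4x4 pose matrices for video rendering."""
    if kind == "ellipse":
        # Orbit an ellipse fitted to the (PCA-recentred) training poses.
        return generate_ellipse_path(train_cameras, n_frames=n_frames)
    # Spiral around the average pose, looking `focal` units ahead of it.
    return generate_spiral_path(train_cameras, N=n_frames)
```
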
/utils/read_write_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, ETH Zurich and UNC Chapel Hill.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # * Redistributions of source code must retain the above copyright
8 | # notice, this list of conditions and the following disclaimer.
9 | #
10 | # * Redistributions in binary form must reproduce the above copyright
11 | # notice, this list of conditions and the following disclaimer in the
12 | # documentation and/or other materials provided with the distribution.
13 | #
14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15 | # its contributors may be used to endorse or promote products derived
16 | # from this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 | # POSSIBILITY OF SUCH DAMAGE.
29 |
30 |
31 | import os
32 | import collections
33 | import numpy as np
34 | import struct
35 | import argparse
36 |
37 |
38 | CameraModel = collections.namedtuple(
39 | "CameraModel", ["model_id", "model_name", "num_params"]
40 | )
41 | Camera = collections.namedtuple(
42 | "Camera", ["id", "model", "width", "height", "params"]
43 | )
44 | BaseImage = collections.namedtuple(
45 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]
46 | )
47 | Point3D = collections.namedtuple(
48 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]
49 | )
50 |
51 |
52 | class Image(BaseImage):
53 | def qvec2rotmat(self):
54 | return qvec2rotmat(self.qvec)
55 |
56 |
57 | CAMERA_MODELS = {
58 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
59 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
60 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
61 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
62 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
63 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
64 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
65 | CameraModel(model_id=7, model_name="FOV", num_params=5),
66 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
67 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
68 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12),
69 | }
70 | CAMERA_MODEL_IDS = dict(
71 | [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]
72 | )
73 | CAMERA_MODEL_NAMES = dict(
74 | [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS]
75 | )
76 |
77 |
78 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
79 | """Read and unpack the next bytes from a binary file.
80 | :param fid:
81 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
82 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
83 | :param endian_character: Any of {@, =, <, >, !}
84 | :return: Tuple of read and unpacked values.
85 | """
86 | data = fid.read(num_bytes)
87 | return struct.unpack(endian_character + format_char_sequence, data)
88 |
89 |
90 | def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
91 | """pack and write to a binary file.
92 | :param fid:
93 | :param data: data to send, if multiple elements are sent at the same time,
94 | they should be encapsuled either in a list or a tuple
95 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
96 | should be the same length as the data list or tuple
97 | :param endian_character: Any of {@, =, <, >, !}
98 | """
99 | if isinstance(data, (list, tuple)):
100 | bytes = struct.pack(endian_character + format_char_sequence, *data)
101 | else:
102 | bytes = struct.pack(endian_character + format_char_sequence, data)
103 | fid.write(bytes)
104 |
105 |
106 | def read_cameras_text(path):
107 | """
108 | see: src/colmap/scene/reconstruction.cc
109 | void Reconstruction::WriteCamerasText(const std::string& path)
110 | void Reconstruction::ReadCamerasText(const std::string& path)
111 | """
112 | cameras = {}
113 | with open(path, "r") as fid:
114 | while True:
115 | line = fid.readline()
116 | if not line:
117 | break
118 | line = line.strip()
119 | if len(line) > 0 and line[0] != "#":
120 | elems = line.split()
121 | camera_id = int(elems[0])
122 | model = elems[1]
123 | width = int(elems[2])
124 | height = int(elems[3])
125 | params = np.array(tuple(map(float, elems[4:])))
126 | cameras[camera_id] = Camera(
127 | id=camera_id,
128 | model=model,
129 | width=width,
130 | height=height,
131 | params=params,
132 | )
133 | return cameras
134 |
135 |
136 | def read_cameras_binary(path_to_model_file):
137 | """
138 | see: src/colmap/scene/reconstruction.cc
139 | void Reconstruction::WriteCamerasBinary(const std::string& path)
140 | void Reconstruction::ReadCamerasBinary(const std::string& path)
141 | """
142 | cameras = {}
143 | with open(path_to_model_file, "rb") as fid:
144 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
145 | for _ in range(num_cameras):
146 | camera_properties = read_next_bytes(
147 | fid, num_bytes=24, format_char_sequence="iiQQ"
148 | )
149 | camera_id = camera_properties[0]
150 | model_id = camera_properties[1]
151 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
152 | width = camera_properties[2]
153 | height = camera_properties[3]
154 | num_params = CAMERA_MODEL_IDS[model_id].num_params
155 | params = read_next_bytes(
156 | fid,
157 | num_bytes=8 * num_params,
158 | format_char_sequence="d" * num_params,
159 | )
160 | cameras[camera_id] = Camera(
161 | id=camera_id,
162 | model=model_name,
163 | width=width,
164 | height=height,
165 | params=np.array(params),
166 | )
167 | assert len(cameras) == num_cameras
168 | return cameras
169 |
170 |
171 | def write_cameras_text(cameras, path):
172 | """
173 | see: src/colmap/scene/reconstruction.cc
174 | void Reconstruction::WriteCamerasText(const std::string& path)
175 | void Reconstruction::ReadCamerasText(const std::string& path)
176 | """
177 | HEADER = (
178 | "# Camera list with one line of data per camera:\n"
179 | + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n"
180 | + "# Number of cameras: {}\n".format(len(cameras))
181 | )
182 | with open(path, "w") as fid:
183 | fid.write(HEADER)
184 | for _, cam in cameras.items():
185 | to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
186 | line = " ".join([str(elem) for elem in to_write])
187 | fid.write(line + "\n")
188 |
189 |
190 | def write_cameras_binary(cameras, path_to_model_file):
191 | """
192 | see: src/colmap/scene/reconstruction.cc
193 | void Reconstruction::WriteCamerasBinary(const std::string& path)
194 | void Reconstruction::ReadCamerasBinary(const std::string& path)
195 | """
196 | with open(path_to_model_file, "wb") as fid:
197 | write_next_bytes(fid, len(cameras), "Q")
198 | for _, cam in cameras.items():
199 | model_id = CAMERA_MODEL_NAMES[cam.model].model_id
200 | camera_properties = [cam.id, model_id, cam.width, cam.height]
201 | write_next_bytes(fid, camera_properties, "iiQQ")
202 | for p in cam.params:
203 | write_next_bytes(fid, float(p), "d")
204 | return cameras
205 |
206 |
207 | def read_images_text(path):
208 | """
209 | see: src/colmap/scene/reconstruction.cc
210 | void Reconstruction::ReadImagesText(const std::string& path)
211 | void Reconstruction::WriteImagesText(const std::string& path)
212 | """
213 | images = {}
214 | with open(path, "r") as fid:
215 | while True:
216 | line = fid.readline()
217 | if not line:
218 | break
219 | line = line.strip()
220 | if len(line) > 0 and line[0] != "#":
221 | elems = line.split()
222 | image_id = int(elems[0])
223 | qvec = np.array(tuple(map(float, elems[1:5])))
224 | tvec = np.array(tuple(map(float, elems[5:8])))
225 | camera_id = int(elems[8])
226 | image_name = elems[9]
227 | elems = fid.readline().split()
228 | xys = np.column_stack(
229 | [
230 | tuple(map(float, elems[0::3])),
231 | tuple(map(float, elems[1::3])),
232 | ]
233 | )
234 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
235 | images[image_id] = Image(
236 | id=image_id,
237 | qvec=qvec,
238 | tvec=tvec,
239 | camera_id=camera_id,
240 | name=image_name,
241 | xys=xys,
242 | point3D_ids=point3D_ids,
243 | )
244 | return images
245 |
246 |
247 | def read_images_binary(path_to_model_file):
248 | """
249 | see: src/colmap/scene/reconstruction.cc
250 | void Reconstruction::ReadImagesBinary(const std::string& path)
251 | void Reconstruction::WriteImagesBinary(const std::string& path)
252 | """
253 | images = {}
254 | with open(path_to_model_file, "rb") as fid:
255 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
256 | for _ in range(num_reg_images):
257 | binary_image_properties = read_next_bytes(
258 | fid, num_bytes=64, format_char_sequence="idddddddi"
259 | )
260 | image_id = binary_image_properties[0]
261 | qvec = np.array(binary_image_properties[1:5])
262 | tvec = np.array(binary_image_properties[5:8])
263 | camera_id = binary_image_properties[8]
264 | image_name = ""
265 | current_char = read_next_bytes(fid, 1, "c")[0]
266 | while current_char != b"\x00": # look for the ASCII 0 entry
267 | image_name += current_char.decode("utf-8")
268 | current_char = read_next_bytes(fid, 1, "c")[0]
269 | num_points2D = read_next_bytes(
270 | fid, num_bytes=8, format_char_sequence="Q"
271 | )[0]
272 | x_y_id_s = read_next_bytes(
273 | fid,
274 | num_bytes=24 * num_points2D,
275 | format_char_sequence="ddq" * num_points2D,
276 | )
277 | xys = np.column_stack(
278 | [
279 | tuple(map(float, x_y_id_s[0::3])),
280 | tuple(map(float, x_y_id_s[1::3])),
281 | ]
282 | )
283 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
284 | images[image_id] = Image(
285 | id=image_id,
286 | qvec=qvec,
287 | tvec=tvec,
288 | camera_id=camera_id,
289 | name=image_name,
290 | xys=xys,
291 | point3D_ids=point3D_ids,
292 | )
293 | return images
294 |
295 |
296 | def write_images_text(images, path):
297 | """
298 | see: src/colmap/scene/reconstruction.cc
299 | void Reconstruction::ReadImagesText(const std::string& path)
300 | void Reconstruction::WriteImagesText(const std::string& path)
301 | """
302 | if len(images) == 0:
303 | mean_observations = 0
304 | else:
305 | mean_observations = sum(
306 | (len(img.point3D_ids) for _, img in images.items())
307 | ) / len(images)
308 | HEADER = (
309 | "# Image list with two lines of data per image:\n"
310 | + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n"
311 | + "# POINTS2D[] as (X, Y, POINT3D_ID)\n"
312 | + "# Number of images: {}, mean observations per image: {}\n".format(
313 | len(images), mean_observations
314 | )
315 | )
316 |
317 | with open(path, "w") as fid:
318 | fid.write(HEADER)
319 | for _, img in images.items():
320 | image_header = [
321 | img.id,
322 | *img.qvec,
323 | *img.tvec,
324 | img.camera_id,
325 | img.name,
326 | ]
327 | first_line = " ".join(map(str, image_header))
328 | fid.write(first_line + "\n")
329 |
330 | points_strings = []
331 | for xy, point3D_id in zip(img.xys, img.point3D_ids):
332 | points_strings.append(" ".join(map(str, [*xy, point3D_id])))
333 | fid.write(" ".join(points_strings) + "\n")
334 |
335 |
336 | def write_images_binary(images, path_to_model_file):
337 | """
338 | see: src/colmap/scene/reconstruction.cc
339 | void Reconstruction::ReadImagesBinary(const std::string& path)
340 | void Reconstruction::WriteImagesBinary(const std::string& path)
341 | """
342 | with open(path_to_model_file, "wb") as fid:
343 | write_next_bytes(fid, len(images), "Q")
344 | for _, img in images.items():
345 | write_next_bytes(fid, img.id, "i")
346 | write_next_bytes(fid, img.qvec.tolist(), "dddd")
347 | write_next_bytes(fid, img.tvec.tolist(), "ddd")
348 | write_next_bytes(fid, img.camera_id, "i")
349 | for char in img.name:
350 | write_next_bytes(fid, char.encode("utf-8"), "c")
351 | write_next_bytes(fid, b"\x00", "c")
352 | write_next_bytes(fid, len(img.point3D_ids), "Q")
353 | for xy, p3d_id in zip(img.xys, img.point3D_ids):
354 | write_next_bytes(fid, [*xy, p3d_id], "ddq")
355 |
356 |
357 | def read_points3D_text(path):
358 | """
359 | see: src/colmap/scene/reconstruction.cc
360 | void Reconstruction::ReadPoints3DText(const std::string& path)
361 | void Reconstruction::WritePoints3DText(const std::string& path)
362 | """
363 | points3D = {}
364 | with open(path, "r") as fid:
365 | while True:
366 | line = fid.readline()
367 | if not line:
368 | break
369 | line = line.strip()
370 | if len(line) > 0 and line[0] != "#":
371 | elems = line.split()
372 | point3D_id = int(elems[0])
373 | xyz = np.array(tuple(map(float, elems[1:4])))
374 | rgb = np.array(tuple(map(int, elems[4:7])))
375 | error = float(elems[7])
376 | image_ids = np.array(tuple(map(int, elems[8::2])))
377 | point2D_idxs = np.array(tuple(map(int, elems[9::2])))
378 | points3D[point3D_id] = Point3D(
379 | id=point3D_id,
380 | xyz=xyz,
381 | rgb=rgb,
382 | error=error,
383 | image_ids=image_ids,
384 | point2D_idxs=point2D_idxs,
385 | )
386 | return points3D
387 |
388 |
389 | def read_points3D_binary(path_to_model_file):
390 | """
391 | see: src/colmap/scene/reconstruction.cc
392 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
393 | void Reconstruction::WritePoints3DBinary(const std::string& path)
394 | """
395 | points3D = {}
396 | with open(path_to_model_file, "rb") as fid:
397 | num_points = read_next_bytes(fid, 8, "Q")[0]
398 | for _ in range(num_points):
399 | binary_point_line_properties = read_next_bytes(
400 | fid, num_bytes=43, format_char_sequence="QdddBBBd"
401 | )
402 | point3D_id = binary_point_line_properties[0]
403 | xyz = np.array(binary_point_line_properties[1:4])
404 | rgb = np.array(binary_point_line_properties[4:7])
405 | error = np.array(binary_point_line_properties[7])
406 | track_length = read_next_bytes(
407 | fid, num_bytes=8, format_char_sequence="Q"
408 | )[0]
409 | track_elems = read_next_bytes(
410 | fid,
411 | num_bytes=8 * track_length,
412 | format_char_sequence="ii" * track_length,
413 | )
414 | image_ids = np.array(tuple(map(int, track_elems[0::2])))
415 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
416 | points3D[point3D_id] = Point3D(
417 | id=point3D_id,
418 | xyz=xyz,
419 | rgb=rgb,
420 | error=error,
421 | image_ids=image_ids,
422 | point2D_idxs=point2D_idxs,
423 | )
424 | return points3D
425 |
426 |
427 | def write_points3D_text(points3D, path):
428 | """
429 | see: src/colmap/scene/reconstruction.cc
430 | void Reconstruction::ReadPoints3DText(const std::string& path)
431 | void Reconstruction::WritePoints3DText(const std::string& path)
432 | """
433 | if len(points3D) == 0:
434 | mean_track_length = 0
435 | else:
436 | mean_track_length = sum(
437 | (len(pt.image_ids) for _, pt in points3D.items())
438 | ) / len(points3D)
439 | HEADER = (
440 | "# 3D point list with one line of data per point:\n"
441 | + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n"
442 | + "# Number of points: {}, mean track length: {}\n".format(
443 | len(points3D), mean_track_length
444 | )
445 | )
446 |
447 | with open(path, "w") as fid:
448 | fid.write(HEADER)
449 | for _, pt in points3D.items():
450 | point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
451 | fid.write(" ".join(map(str, point_header)) + " ")
452 | track_strings = []
453 | for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
454 | track_strings.append(" ".join(map(str, [image_id, point2D])))
455 | fid.write(" ".join(track_strings) + "\n")
456 |
457 |
458 | def write_points3D_binary(points3D, path_to_model_file):
459 | """
460 | see: src/colmap/scene/reconstruction.cc
461 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
462 | void Reconstruction::WritePoints3DBinary(const std::string& path)
463 | """
464 | with open(path_to_model_file, "wb") as fid:
465 | write_next_bytes(fid, len(points3D), "Q")
466 | for _, pt in points3D.items():
467 | write_next_bytes(fid, pt.id, "Q")
468 | write_next_bytes(fid, pt.xyz.tolist(), "ddd")
469 | write_next_bytes(fid, pt.rgb.tolist(), "BBB")
470 | write_next_bytes(fid, pt.error, "d")
471 | track_length = pt.image_ids.shape[0]
472 | write_next_bytes(fid, track_length, "Q")
473 | for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
474 | write_next_bytes(fid, [image_id, point2D_id], "ii")
475 |
476 |
477 | def detect_model_format(path, ext):
478 | if (
479 | os.path.isfile(os.path.join(path, "cameras" + ext))
480 | and os.path.isfile(os.path.join(path, "images" + ext))
481 | and os.path.isfile(os.path.join(path, "points3D" + ext))
482 | ):
483 | print("Detected model format: '" + ext + "'")
484 | return True
485 |
486 | return False
487 |
488 |
489 | def read_model(path, ext=""):
490 | # try to detect the extension automatically
491 | if ext == "":
492 | if detect_model_format(path, ".bin"):
493 | ext = ".bin"
494 | elif detect_model_format(path, ".txt"):
495 | ext = ".txt"
496 | else:
497 | print("Provide model format: '.bin' or '.txt'")
498 | return
499 |
500 | if ext == ".txt":
501 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
502 | images = read_images_text(os.path.join(path, "images" + ext))
503 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
504 | else:
505 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
506 | images = read_images_binary(os.path.join(path, "images" + ext))
507 | points3D = read_points3D_binary(os.path.join(path, "points3D") + ext)
508 | return cameras, images, points3D
509 |
510 |
511 | def write_model(cameras, images, points3D, path, ext=".bin"):
512 | if ext == ".txt":
513 | write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
514 | write_images_text(images, os.path.join(path, "images" + ext))
515 | write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
516 | else:
517 | write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
518 | write_images_binary(images, os.path.join(path, "images" + ext))
519 | write_points3D_binary(points3D, os.path.join(path, "points3D") + ext)
520 | return cameras, images, points3D
521 |
522 |
523 | def qvec2rotmat(qvec):
524 | return np.array(
525 | [
526 | [
527 | 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
528 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
529 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
530 | ],
531 | [
532 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
533 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
534 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
535 | ],
536 | [
537 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
538 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
539 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
540 | ],
541 | ]
542 | )
543 |
544 |
545 | def rotmat2qvec(R):
546 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
547 | K = (
548 | np.array(
549 | [
550 | [Rxx - Ryy - Rzz, 0, 0, 0],
551 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
552 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
553 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz],
554 | ]
555 | )
556 | / 3.0
557 | )
558 | eigvals, eigvecs = np.linalg.eigh(K)
559 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
560 | if qvec[0] < 0:
561 | qvec *= -1
562 | return qvec
563 |
564 |
565 | # def main():
566 | # parser = argparse.ArgumentParser(
567 | # description="Read and write COLMAP binary and text models"
568 | # )
569 | # parser.add_argument("--input_model", help="path to input model folder")
570 | # parser.add_argument(
571 | # "--input_format",
572 | # choices=[".bin", ".txt"],
573 | # help="input model format",
574 | # default="",
575 | # )
576 | # parser.add_argument("--output_model", help="path to output model folder")
577 | # parser.add_argument(
578 | # "--output_format",
579 | # choices=[".bin", ".txt"],
580 | #         help="output model format",
581 | # default=".txt",
582 | # )
583 | # args = parser.parse_args()
584 |
585 | # cameras, images, points3D = read_model(
586 | # path=args.input_model, ext=args.input_format
587 | # )
588 |
589 | # print("num_cameras:", len(cameras))
590 | # print("num_images:", len(images))
591 | # print("num_points3D:", len(points3D))
592 |
593 | # if args.output_model is not None:
594 | # write_model(
595 | # cameras,
596 | # images,
597 | # points3D,
598 | # path=args.output_model,
599 | # ext=args.output_format,
600 | # )
601 |
602 |
603 | # if __name__ == "__main__":
604 | # main()
605 |
--------------------------------------------------------------------------------
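
A minimal usage sketch for the COLMAP model I/O above, mirroring the commented-out `main()` (paths are placeholders; the output directory must already exist):

```python
from utils.read_write_model import read_model, write_model

# Auto-detects .bin vs .txt when ext is omitted.
cameras, images, points3D = read_model("data/scene/sparse/0")
print("cameras:", len(cameras), "images:", len(images), "points3D:", len(points3D))

# Re-export the reconstruction as text files.
write_model(cameras, images, points3D, "data/scene/sparse_txt", ext=".txt")
```
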
/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | import torch
25 |
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 | 1.0925484305920792,
30 | -1.0925484305920792,
31 | 0.31539156525252005,
32 | -1.0925484305920792,
33 | 0.5462742152960396
34 | ]
35 | C3 = [
36 | -0.5900435899266435,
37 | 2.890611442640554,
38 | -0.4570457994644658,
39 | 0.3731763325901154,
40 | -0.4570457994644658,
41 | 1.445305721320277,
42 | -0.5900435899266435
43 | ]
44 | C4 = [
45 | 2.5033429417967046,
46 | -1.7701307697799304,
47 | 0.9461746957575601,
48 | -0.6690465435572892,
49 | 0.10578554691520431,
50 | -0.6690465435572892,
51 | 0.47308734787878004,
52 | -1.7701307697799304,
53 | 0.6258357354491761,
54 | ]
55 |
56 |
57 | def eval_sh(deg, sh, dirs):
58 | """
59 | Evaluate spherical harmonics at unit directions
60 | using hardcoded SH polynomials.
61 | Works with torch/np/jnp.
62 | ... Can be 0 or more batch dimensions.
63 | Args:
64 |         deg: int SH deg. Currently, 0-4 supported
65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 | dirs: jnp.ndarray unit directions [..., 3]
67 | Returns:
68 | [..., C]
69 | """
70 | assert deg <= 4 and deg >= 0
71 | coeff = (deg + 1) ** 2
72 | assert sh.shape[-1] >= coeff
73 |
74 | result = C0 * sh[..., 0]
75 | if deg > 0:
76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 | result = (result -
78 | C1 * y * sh[..., 1] +
79 | C1 * z * sh[..., 2] -
80 | C1 * x * sh[..., 3])
81 |
82 | if deg > 1:
83 | xx, yy, zz = x * x, y * y, z * z
84 | xy, yz, xz = x * y, y * z, x * z
85 | result = (result +
86 | C2[0] * xy * sh[..., 4] +
87 | C2[1] * yz * sh[..., 5] +
88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89 | C2[3] * xz * sh[..., 7] +
90 | C2[4] * (xx - yy) * sh[..., 8])
91 |
92 | if deg > 2:
93 | result = (result +
94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95 | C3[1] * xy * z * sh[..., 10] +
96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99 | C3[5] * z * (xx - yy) * sh[..., 14] +
100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101 |
102 | if deg > 3:
103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112 | return result
113 |
114 | def RGB2SH(rgb):
115 | return (rgb - 0.5) / C0
116 |
117 | def SH2RGB(sh):
118 | return sh * C0 + 0.5
--------------------------------------------------------------------------------
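
A minimal sketch of how these helpers are typically used in a 3DGS-style renderer (the tensor shapes and random inputs below are assumptions for illustration; only `eval_sh`, `RGB2SH`, and the coefficient layout come from the file above):

```python
import torch
from utils.sh_utils import eval_sh, RGB2SH

deg = 3
num_points = 1000
sh = torch.zeros(num_points, 3, (deg + 1) ** 2)      # [N, channels, coeffs]
sh[..., 0] = RGB2SH(torch.rand(num_points, 3))       # DC term from a base RGB
dirs = torch.nn.functional.normalize(torch.randn(num_points, 3), dim=-1)

# eval_sh returns the SH expansion; the +0.5 shift mirrors SH2RGB's offset.
rgb = torch.clamp(eval_sh(deg, sh, dirs) + 0.5, 0.0, 1.0)   # [N, 3]
```
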
/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from errno import EEXIST
13 | from os import makedirs, path
14 | import os
15 |
16 | def mkdir_p(folder_path):
17 |     # Creates a directory. Equivalent to using mkdir -p on the command line.
18 | try:
19 | makedirs(folder_path)
20 | except OSError as exc: # Python >2.5
21 | if exc.errno == EEXIST and path.isdir(folder_path):
22 | pass
23 | else:
24 | raise
25 |
26 | def searchForMaxIteration(folder):
27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
28 | return max(saved_iters)
29 |
--------------------------------------------------------------------------------
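
A small usage sketch for these utilities (the `iteration_XXXX` checkpoint layout and paths are assumptions based on typical 3DGS outputs, not taken from this file):

```python
import os
from utils.system_utils import mkdir_p, searchForMaxIteration

ckpt_root = "output/run1/point_cloud"   # placeholder path
mkdir_p(ckpt_root)                      # no error if it already exists

# Find the newest "iteration_XXXX" subfolder, if any checkpoints were saved.
if os.listdir(ckpt_root):
    latest = searchForMaxIteration(ckpt_root)
    print(f"Resuming from iteration {latest}")
```
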