├── .gitmodules ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── arguments └── __init__.py ├── convert.py ├── docker ├── Dockerfile ├── build_gaussian_pro_docker.sh ├── entrypoint.sh ├── environment.yml └── run_gaussian_pro_docker.sh ├── environment.yml ├── figs ├── comparison.gif ├── effel_tower.mp4 ├── jianzhu_final_demo.mp4 ├── jiaotang_final_demo.mp4 ├── motivation.png ├── output.gif ├── output1.gif ├── output2.gif └── pipeline.png ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── metrics.py ├── render.py ├── results ├── DeepBlending │ ├── drjohnson.csv │ └── playroom.csv ├── Eth3D │ ├── delivery_area.csv │ ├── electro.csv │ ├── kicker.csv │ ├── meadow.csv │ ├── office.csv │ ├── playground.csv │ ├── relief.csv │ ├── relief2.csv │ └── terrace.csv ├── MipNeRF360 │ ├── bicycle.csv │ ├── bonsai.csv │ ├── counter.csv │ ├── flowers.csv │ ├── garden.csv │ ├── kitchen.csv │ ├── room.csv │ ├── stump.csv │ └── treehill.csv └── TanksAndTemples │ ├── train.csv │ └── truck.csv ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py └── gaussian_model.py ├── scripts ├── demo.sh └── waymo.sh ├── submodules └── Propagation │ ├── PatchMatch.cpp │ ├── PatchMatch.h │ ├── Propagation.cu │ ├── main.h │ ├── pro.cpp │ └── setup.py ├── train.py └── utils ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── sh_utils.py └── system_utils.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/diff-gaussian-rasterization"] 2 | path = submodules/diff-gaussian-rasterization 3 | url = https://github.com/graphdeco-inria/diff-gaussian-rasterization 4 | [submodule "submodules/simple-knn"] 5 | path = submodules/simple-knn 6 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 7 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "workbench.colorCustomizations": { 3 | "activityBar.background": "#22312D", 4 | "titleBar.activeBackground": "#30443F", 5 | "titleBar.activeForeground": "#F8FAF9" 6 | } 7 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 kcheng1021. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # GaussianPro: 3D Gaussian Splatting with Progressive Propagation 3 | 4 | ### ICML 2024 5 | 6 | ### [Project Page](https://kcheng1021.github.io/gaussianpro.github.io/) | [Paper](https://arxiv.org/abs/2402.14650) 7 | 8 | 9 | 10 |
11 | image 12 |
13 | 14 | ## 📖 Abstract 15 | 16 | The advent of 3D Gaussian Splatting (3DGS) has recently brought about a revolution in the field of neural rendering, facilitating high-quality renderings at real-time speed. However, 3DGS heavily depends on the initialized point cloud produced by Structure-from-Motion (SfM) techniques. 17 | When tackling large-scale scenes that unavoidably contain texture-less surfaces, SfM techniques always fail to produce enough points on these surfaces and cannot provide a good initialization for 3DGS. As a result, 3DGS suffers from difficult optimization and low-quality renderings. 18 | In this paper, inspired by classical multi-view stereo (MVS) techniques, we propose **GaussianPro**, a novel method that applies a progressive propagation strategy to guide the densification of the 3D Gaussians. 19 | Compared to the simple split and clone strategies used in 3DGS, our method leverages the priors of the scene's existing reconstructed geometry and patch matching techniques to produce new Gaussians with accurate positions and orientations. 20 | Experiments on both large-scale and small-scale scenes validate the effectiveness of our method: it significantly surpasses 3DGS on the Waymo dataset, exhibiting an improvement of 1.15 dB in PSNR. 21 | 22 | ## 🗓️ News 23 | 24 | [2024.10.10] Many thanks to [Caio Viturino](https://github.com/caioviturinofs): the project now provides a Docker environment! 25 | 26 | [2024.9.28] Many thanks to [Chongjie Ye](https://github.com/hugoycj): the project now avoids the dependency on the OpenCV C++ library, making it more convenient to install! 27 | 28 | Some amazing enhancements will also come out this year. 29 | 30 | ## 🗓️ TODO 31 | - [✔] Code pre-release -- Beta version. 32 | - [✔] Demo Scenes. 33 | - [✔] Pybinding & CUDA acceleration. 34 | - [ ] Support for unordered sets of images. 35 | 36 | Some amazing enhancements are under development. We warmly welcome anyone to collaborate in improving this repository. Please send me an email if you are interested! 37 | 38 | ## 🚀 Pipeline 39 | 40 |
41 | image 42 |
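To make the pipeline above concrete, here is a minimal schematic sketch, not the authors' implementation (that lives in `train.py` and `submodules/Propagation`), of how the propagation options defined later in `arguments/__init__.py` (`propagation_interval`, `propagated_iteration_begin`, `propagated_iteration_after`) would typically gate a propagation step inside a 3DGS-style training loop; `propagate_gaussians` is a hypothetical placeholder:

```python
# Schematic sketch only: shows the role of the propagation options from
# arguments/__init__.py, not the actual GaussianPro training code.
def should_propagate(iteration: int, opt) -> bool:
    # Propagate only within the configured iteration window...
    inside_window = opt.propagated_iteration_begin <= iteration <= opt.propagated_iteration_after
    # ...and only every `propagation_interval` iterations.
    return inside_window and iteration % opt.propagation_interval == 0

# Hypothetical use inside the training loop; propagate_gaussians is a placeholder
# for the patch-match propagation that produces new, well-placed Gaussians.
# if should_propagate(iteration, opt):
#     propagate_gaussians(gaussians, viewpoint_cam, opt)
```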
43 | 44 | 45 | ## 🚀 Setup 46 | #### Tested Environment 47 | Ubuntu 20.04.1 LTS, GeForce 3090, CUDA 11.3 (tested) / 11.7 (tested), C++17 48 | 49 | #### Clone the repo. 50 | ``` 51 | git clone https://github.com/kcheng1021/GaussianPro.git --recursive 52 | ``` 53 | 54 | #### Environment setup 55 | ``` 56 | conda env create --file environment.yml 57 | 58 | # Install the propagation package. 59 | # The GPU compute architecture is specified as sm_86 in setup.py. Please replace it with a value suitable for your GPU. 60 | # Replace the OpenCV and CUDA include/lib paths with your own (ignore in the latest version). 61 | # The C++ OpenCV is best installed in the conda environment via conda install -c conda-forge opencv (ignore in the latest version). 62 | pip install ./submodules/Propagation 63 | 64 | ``` 65 | 66 | #### Docker install (Alternative) 67 | 68 | To build the GaussianPro Docker image, execute the following command: 69 | ```bash 70 | sh docker/build_gaussian_pro_docker.sh 71 | ``` 72 | 73 | To run the container, execute: 74 | ```bash 75 | # Please remember to substitute the dataset path with your desired path 76 | sh docker/run_gaussian_pro_docker.sh 77 | ``` 78 | 79 | #### Download the Waymo scenes: Segment-102751,100613,132384,144248,148697,405841,164701,150623,113792 80 | ``` 81 | wget https://drive.google.com/file/d/1DXQRBcUIrnIC33WNq8pVLKZ_W1VwON3k/view?usp=sharing 82 | https://drive.google.com/file/d/1DEDt8sNshAlmcwbp_KleeNYf_Jq0fy4u/view?usp=sharing 83 | https://drive.google.com/file/d/1J7_IA2w4-u51lCmtmMA5CDxXR4Dbkeoq/view?usp=sharing 84 | https://rec.ustc.edu.cn/share/d34a0370-2bb2-11f0-b128-73c0ccb2577f password:ux3p 85 | ``` 86 | 87 | #### Besides the public datasets, we also test GaussianPro on randomly selected YouTube videos and find consistent improvements. The processed data is provided below. 88 | 89 | ``` 90 | #youtube01: Park. 91 | wget https://drive.google.com/file/d/1iHYTnI76Zx9VTKbMu1zUE7gVKP4UpBan/view?usp=sharing 92 | 93 | #youtube02: Church. 94 | wget https://drive.google.com/file/d/1i2ReAJYkeLHBBbs_8Zn560Tke2F8yR1X/view?usp=sharing 95 | 96 | #youtube03: The Forbidden City. 97 | wget https://drive.google.com/file/d/1PZ_917Oq0Y45_5dJ504RxRmpUnewYmyn/view?usp=sharing 98 | 99 | #youtube04: Eiffel Tower. 100 | wget https://drive.google.com/file/d/1JoYyfAu3RNnj12C2gPvfljHLUKlUsSr1/view?usp=sharing 101 | ``` 102 | 103 | ![image](https://github.com/kcheng1021/GaussianPro/blob/main/figs/output.gif) 104 | ![image](https://github.com/kcheng1021/GaussianPro/blob/main/figs/output2.gif) 105 | 106 | #### Run the codes 107 | ``` 108 | # Run 3DGS; we modify the default parameters of 3DGS to better learn large scenes. A description of the GaussianPro parameters will come out later. 109 | 110 | # To run the Waymo scenes (3DGS and GaussianPro) 111 | bash scripts/waymo.sh 112 | 113 | # Run the YouTube scenes above 114 | bash scripts/demo.sh 115 | ``` 116 | 117 | To ensure reproducibility, we present reference results (PSNR / SSIM / LPIPS) on the provided demo scenes based on the current code.
118 | | Method | Waymo-1002751 | Youtube-01 | Youtube-02 | Youtube-03 | Youtube-04 | 119 | | :--- | :---: | :---: | :---: | :---: | :---: | 120 | | 3DGS | 35.22,0.950,0.234 | 34.40,0.964,0.092 | 34.67,0.954,0.072 | 37.81,0.971,0.081 | 33.05,0.950,0.079 | 121 | | GaussianPro | **35.97,0.959,0.207** | **35.29,0.969,0.076** | **35.08,0.959,0.064** | **38.27,0.974,0.072** | **33.66,0.956,0.072** | 122 | 123 | ### Try your scenes 124 | 125 | **If you want to try your own scenes, make sure your images are sorted in temporal order, i.e. video data. The current version does not support unordered image sets, but it 126 | will be updated in the next version. Then you can try the commands in demo.sh to run your own scenes.** 127 | 128 | **Please ensure that your neighboring images have sufficient overlap.** 129 | 130 | ## 🎫 License 131 | 132 | For non-commercial use, this code is released under the [LICENSE](LICENSE). 133 | For commercial use, please contact Xuejin Chen. 134 | 135 | ## 🎫 Acknowledgment 136 | This project largely references [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting) and [ACMH/ACMM](https://github.com/GhiXu/ACMH). Thanks for their amazing work! 137 | 138 | ## 🖊️ Citation 139 | 140 | 141 | If you find this project useful in your research, please consider citing: 142 | 143 | 144 | ```BibTeX 145 | @article{cheng2024gaussianpro, 146 | title={GaussianPro: 3D Gaussian Splatting with Progressive Propagation}, 147 | author={Cheng, Kai and Long, Xiaoxiao and Yang, Kaizhi and Yao, Yao and Yin, Wei and Ma, Yuexin and Wang, Wenping and Chen, Xuejin}, 148 | journal={arXiv preprint arXiv:2402.14650}, 149 | year={2024} 150 | } 151 | ``` 152 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file.
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self._source_path = "" 51 | self._model_path = "" 52 | self._images = "images" 53 | self._resolution = -1 54 | self._white_background = False 55 | self.data_device = "cuda" 56 | self.sky_seg = False 57 | self.load_normal = False 58 | self.load_depth = False 59 | self.eval = False 60 | super().__init__(parser, "Loading Parameters", sentinel) 61 | 62 | def extract(self, args): 63 | g = super().extract(args) 64 | g.source_path = os.path.abspath(g.source_path) 65 | return g 66 | 67 | class PipelineParams(ParamGroup): 68 | def __init__(self, parser): 69 | self.convert_SHs_python = False 70 | self.compute_cov3D_python = False 71 | self.debug = False 72 | super().__init__(parser, "Pipeline Parameters") 73 | 74 | class OptimizationParams(ParamGroup): 75 | def __init__(self, parser): 76 | self.iterations = 30_000 77 | self.position_lr_init = 0.00016 78 | self.position_lr_final = 0.0000016 79 | self.position_lr_delay_mult = 0.01 80 | self.position_lr_max_steps = 30_000 81 | self.feature_lr = 0.0025 82 | self.opacity_lr = 0.05 83 | self.scaling_lr = 0.005 84 | self.rotation_lr = 0.001 85 | self.percent_dense = 0.01 86 | self.normal_loss = False 87 | self.sparse_loss = False 88 | self.flatten_loss = False 89 | self.depth_loss = False 90 | self.depth2normal_loss = False 91 | self.lambda_l1_normal = 0.01 92 | self.lambda_cos_normal = 0.01 93 | self.lambda_flatten = 100.0 94 | self.lambda_dssim = 0.2 95 | self.lambda_sparse = 0.001 96 | self.lambda_depth = 0.1 97 | self.lambda_depth2normal = 0.05 98 | self.densification_interval = 100 99 | self.opacity_reset_interval = 3000 100 | self.densify_from_iter = 500 101 | self.densify_until_iter = 15_000 102 | self.densify_grad_threshold = 0.0002 103 | self.random_background = False 104 | 105 | #propagation parameters 106 | self.dataset = 'waymo' 107 | self.propagation_interval = 20 108 | self.depth_error_min_threshold = 1.0 109 | self.depth_error_max_threshold = 1.0 110 | self.propagated_iteration_begin = 1000 111 | self.propagated_iteration_after = 12000 112 | self.patch_size = 20 113 | self.pair_path = '' 114 | super().__init__(parser, "Optimization Parameters") 115 | 116 | def get_combined_args(parser : ArgumentParser): 117 | cmdlne_string = sys.argv[1:] 118 | cfgfile_string = "Namespace()" 119 | args_cmdline = 
parser.parse_args(cmdlne_string) 120 | 121 | try: 122 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 123 | print("Looking for config file in", cfgfilepath) 124 | with open(cfgfilepath) as cfg_file: 125 | print("Config file found: {}".format(cfgfilepath)) 126 | cfgfile_string = cfg_file.read() 127 | except TypeError: 128 | print("Config file not found at") 129 | pass 130 | args_cfgfile = eval(cfgfile_string) 131 | 132 | merged_dict = vars(args_cfgfile).copy() 133 | for k,v in vars(args_cmdline).items(): 134 | if v != None: 135 | merged_dict[k] = v 136 | return Namespace(**merged_dict) 137 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | args = parser.parse_args() 27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 29 | use_gpu = 1 if not args.no_gpu else 0 30 | 31 | if not args.skip_matching: 32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 33 | 34 | ## Feature extraction 35 | feat_extracton_cmd = colmap_command + " feature_extractor "\ 36 | "--database_path " + args.source_path + "/distorted/database.db \ 37 | --image_path " + args.source_path + "/input \ 38 | --ImageReader.single_camera 1 \ 39 | --ImageReader.camera_model " + args.camera + " \ 40 | --ImageReader.mask_path " + args.source_path + "/mask" + " \ 41 | --SiftExtraction.use_gpu " + str(use_gpu) 42 | exit_code = os.system(feat_extracton_cmd) 43 | if exit_code != 0: 44 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 45 | exit(exit_code) 46 | 47 | ## Feature matching 48 | feat_matching_cmd = colmap_command + " exhaustive_matcher \ 49 | --database_path " + args.source_path + "/distorted/database.db \ 50 | --SiftMatching.use_gpu " + str(use_gpu) 51 | exit_code = os.system(feat_matching_cmd) 52 | if exit_code != 0: 53 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 54 | exit(exit_code) 55 | 56 | ### Bundle adjustment 57 | # The default Mapper tolerance is unnecessarily large, 58 | # decreasing it speeds up bundle adjustment steps. 
59 | mapper_cmd = (colmap_command + " mapper \ 60 | --database_path " + args.source_path + "/distorted/database.db \ 61 | --image_path " + args.source_path + "/input \ 62 | --output_path " + args.source_path + "/distorted/sparse \ 63 | --Mapper.ba_global_function_tolerance=0.000001") 64 | exit_code = os.system(mapper_cmd) 65 | if exit_code != 0: 66 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 67 | exit(exit_code) 68 | 69 | ### Image undistortion 70 | ## We need to undistort our images into ideal pinhole intrinsics. 71 | img_undist_cmd = (colmap_command + " image_undistorter \ 72 | --image_path " + args.source_path + "/input \ 73 | --input_path " + args.source_path + "/distorted/sparse/0 \ 74 | --output_path " + args.source_path + "\ 75 | --output_type COLMAP") 76 | exit_code = os.system(img_undist_cmd) 77 | if exit_code != 0: 78 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 79 | exit(exit_code) 80 | 81 | files = os.listdir(args.source_path + "/sparse") 82 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 83 | # Copy each file from the source directory to the destination directory 84 | for file in files: 85 | if file == '0': 86 | continue 87 | source_file = os.path.join(args.source_path, "sparse", file) 88 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 89 | shutil.move(source_file, destination_file) 90 | 91 | if(args.resize): 92 | print("Copying and resizing...") 93 | 94 | # Resize images. 95 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 96 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 97 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 98 | # Get the list of files in the source directory 99 | files = os.listdir(args.source_path + "/images") 100 | # Copy each file from the source directory to the destination directory 101 | for file in files: 102 | source_file = os.path.join(args.source_path, "images", file) 103 | 104 | destination_file = os.path.join(args.source_path, "images_2", file) 105 | shutil.copy2(source_file, destination_file) 106 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) 107 | if exit_code != 0: 108 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 109 | exit(exit_code) 110 | 111 | destination_file = os.path.join(args.source_path, "images_4", file) 112 | shutil.copy2(source_file, destination_file) 113 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 114 | if exit_code != 0: 115 | logging.error(f"25% resize failed with code {exit_code}. Exiting.") 116 | exit(exit_code) 117 | 118 | destination_file = os.path.join(args.source_path, "images_8", file) 119 | shutil.copy2(source_file, destination_file) 120 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 121 | if exit_code != 0: 122 | logging.error(f"12.5% resize failed with code {exit_code}. 
Exiting.") 123 | exit(exit_code) 124 | 125 | print("Done.") 126 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Use an official CUDA runtime as the base image 2 | FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04 3 | 4 | # Set environment variables 5 | ENV DEBIAN_FRONTEND=noninteractive \ 6 | CONDA_DIR=/opt/conda \ 7 | CUDA_HOME=/usr/local/cuda \ 8 | TORCH_CUDA_ARCH_LIST="7.5" 9 | 10 | # Add Conda to PATH 11 | ENV PATH=$CONDA_DIR/bin:$PATH 12 | 13 | # Install system dependencies 14 | RUN apt-get update && apt-get install -y \ 15 | git \ 16 | wget \ 17 | build-essential \ 18 | libgl1-mesa-glx \ 19 | libglib2.0-0 \ 20 | python3-dev \ 21 | python3-pip \ 22 | && rm -rf /var/lib/apt/lists/* 23 | 24 | # Install Miniconda 25 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && \ 26 | bash /tmp/miniconda.sh -b -p $CONDA_DIR && \ 27 | rm /tmp/miniconda.sh && \ 28 | $CONDA_DIR/bin/conda clean -afy 29 | 30 | # Initialize Conda 31 | RUN $CONDA_DIR/bin/conda init bash 32 | 33 | # Clone the GaussianPro repository 34 | RUN git clone https://github.com/kcheng1021/GaussianPro.git --recursive 35 | 36 | # Set the working directory 37 | WORKDIR /GaussianPro 38 | 39 | # Copy environment.yml into the Docker image 40 | COPY environment.yml . 41 | 42 | # Create the Conda environment and install additional packages 43 | # RUN /opt/conda/bin/conda env create -f environment.yml 44 | 45 | # Activate the environment 46 | RUN echo "source /opt/conda/etc/profile.d/conda.sh" >> /root/.bashrc && \ 47 | echo "conda activate gaussianpro" >> /root/.bashrc 48 | 49 | # Create the Conda environment, install packages, and clean up in one RUN command 50 | RUN /bin/bash -c "source $CONDA_DIR/etc/profile.d/conda.sh && \ 51 | conda env create -f environment.yml && \ 52 | conda activate gaussianpro && \ 53 | conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.6 -c pytorch -c conda-forge && \ 54 | pip install --upgrade pip && \ 55 | pip install ./submodules/Propagation && \ 56 | pip install ./submodules/diff-gaussian-rasterization && \ 57 | pip install ./submodules/simple-knn && \ 58 | conda clean -afy" 59 | 60 | # Copy the entrypoint script into the Docker image 61 | COPY entrypoint.sh /entrypoint.sh 62 | 63 | # Make the entrypoint script executable 64 | RUN chmod +x /entrypoint.sh 65 | 66 | # Set the entrypoint to the script that activates the Conda environment 67 | ENTRYPOINT ["/entrypoint.sh"] 68 | 69 | # Set the default command to bash to keep the container running 70 | CMD ["/bin/bash"] 71 | -------------------------------------------------------------------------------- /docker/build_gaussian_pro_docker.sh: -------------------------------------------------------------------------------- 1 | docker build --no-cache -t gaussian-pro . -------------------------------------------------------------------------------- /docker/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | echo "Initializing Conda..." 5 | source /opt/conda/etc/profile.d/conda.sh 6 | 7 | echo "Activating Conda environment 'gaussianpro'..." 
8 | conda activate gaussianpro 9 | 10 | echo "Conda environment activated: $(conda info --envs | grep '*' )" 11 | 12 | if [ "$#" -gt 0 ]; then 13 | echo "Executing command: $@" 14 | exec "$@" 15 | else 16 | echo "Starting interactive bash shell..." 17 | exec bash 18 | fi 19 | -------------------------------------------------------------------------------- /docker/environment.yml: -------------------------------------------------------------------------------- 1 | name: gaussianpro 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.7 8 | - plyfile=0.8.1 9 | - python=3.7.13 10 | - pip=22.3.1 11 | - tqdm 12 | - ninja 13 | - opencv-python=4.10.0.84 14 | - matplotlib=3.5.3 15 | - open3d=0.17.0 16 | - imageio=2.31.2 17 | -------------------------------------------------------------------------------- /docker/run_gaussian_pro_docker.sh: -------------------------------------------------------------------------------- 1 | docker run --rm --gpus all -it \ 2 | --entrypoint /bin/bash \ 3 | -v your_dataset_path:/GaussianPro/datasets \ 4 | gaussian-pro 5 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gaussianpro 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.6 8 | - plyfile=0.8.1 9 | - python=3.7.13 10 | - pip=22.3.1 11 | - pytorch=1.12.1 12 | - torchaudio=0.12.1 13 | - torchvision=0.13.1 14 | - tqdm 15 | - pip: 16 | - submodules/diff-gaussian-rasterization 17 | - submodules/simple-knn 18 | -------------------------------------------------------------------------------- /figs/comparison.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/comparison.gif -------------------------------------------------------------------------------- /figs/effel_tower.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/effel_tower.mp4 -------------------------------------------------------------------------------- /figs/jianzhu_final_demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/jianzhu_final_demo.mp4 -------------------------------------------------------------------------------- /figs/jiaotang_final_demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/jiaotang_final_demo.mp4 -------------------------------------------------------------------------------- /figs/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/motivation.png -------------------------------------------------------------------------------- /figs/output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/output.gif 
-------------------------------------------------------------------------------- /figs/output1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/output1.gif -------------------------------------------------------------------------------- /figs/output2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/output2.gif -------------------------------------------------------------------------------- /figs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/pipeline.png -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from utils.sh_utils import eval_sh 17 | from utils.general_utils import build_rotation 18 | import torch.nn.functional as F 19 | 20 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None, 21 | return_depth = False, return_normal = False, return_opacity = False): 22 | """ 23 | Render the scene. 24 | 25 | Background tensor (bg_color) must be on GPU! 26 | """ 27 | 28 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 29 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 30 | try: 31 | screenspace_points.retain_grad() 32 | except: 33 | pass 34 | 35 | # Set up rasterization configuration 36 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 37 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 38 | 39 | raster_settings = GaussianRasterizationSettings( 40 | image_height=int(viewpoint_camera.image_height), 41 | image_width=int(viewpoint_camera.image_width), 42 | tanfovx=tanfovx, 43 | tanfovy=tanfovy, 44 | bg=bg_color, 45 | scale_modifier=scaling_modifier, 46 | viewmatrix=viewpoint_camera.world_view_transform, 47 | projmatrix=viewpoint_camera.full_proj_transform, 48 | sh_degree=pc.active_sh_degree, 49 | campos=viewpoint_camera.camera_center, 50 | prefiltered=False, 51 | debug=pipe.debug 52 | ) 53 | 54 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 55 | 56 | means3D = pc.get_xyz 57 | means2D = screenspace_points 58 | opacity = pc.get_opacity 59 | 60 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 61 | # scaling / rotation by the rasterizer. 
62 | scales = None 63 | rotations = None 64 | cov3D_precomp = None 65 | if pipe.compute_cov3D_python: 66 | cov3D_precomp = pc.get_covariance(scaling_modifier) 67 | else: 68 | scales = pc.get_scaling 69 | rotations = pc.get_rotation 70 | 71 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 72 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 73 | shs = None 74 | colors_precomp = None 75 | if override_color is None: 76 | if pipe.convert_SHs_python: 77 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 78 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 79 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 80 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 81 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 82 | else: 83 | shs = pc.get_features 84 | else: 85 | colors_precomp = override_color 86 | 87 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 88 | rendered_image, radii = rasterizer( 89 | means3D = means3D, 90 | means2D = means2D, 91 | shs = shs, 92 | colors_precomp = colors_precomp, 93 | opacities = opacity, 94 | scales = scales, 95 | rotations = rotations, 96 | cov3D_precomp = cov3D_precomp) 97 | 98 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 99 | # They will be excluded from value updates used in the splitting criteria. 100 | return_dict = {"render": rendered_image, 101 | "viewspace_points": screenspace_points, 102 | "visibility_filter" : radii > 0, 103 | "radii": radii} 104 | 105 | if return_depth: 106 | projvect1 = viewpoint_camera.world_view_transform[:,2][:3].detach() 107 | projvect2 = viewpoint_camera.world_view_transform[:,2][-1].detach() 108 | means3D_depth = (means3D * projvect1.unsqueeze(0)).sum(dim=-1,keepdim=True) + projvect2 109 | means3D_depth = means3D_depth.repeat(1,3) 110 | render_depth, _ = rasterizer( 111 | means3D = means3D, 112 | means2D = means2D, 113 | shs = None, 114 | colors_precomp = means3D_depth, 115 | opacities = opacity, 116 | scales = scales, 117 | rotations = rotations, 118 | cov3D_precomp = cov3D_precomp) 119 | render_depth = render_depth.mean(dim=0) 120 | return_dict.update({'render_depth': render_depth}) 121 | 122 | if return_normal: 123 | rotations_mat = build_rotation(rotations) 124 | scales = pc.get_scaling 125 | min_scales = torch.argmin(scales, dim=1) 126 | indices = torch.arange(min_scales.shape[0]) 127 | normal = rotations_mat[indices, :, min_scales] 128 | 129 | # convert normal direction to the camera; calculate the normal in the camera coordinate 130 | view_dir = means3D - viewpoint_camera.camera_center 131 | normal = normal * ((((view_dir * normal).sum(dim=-1) < 0) * 1 - 0.5) * 2)[...,None] 132 | 133 | R_w2c = torch.tensor(viewpoint_camera.R.T).cuda().to(torch.float32) 134 | normal = (R_w2c @ normal.transpose(0, 1)).transpose(0, 1) 135 | 136 | render_normal, _ = rasterizer( 137 | means3D = means3D, 138 | means2D = means2D, 139 | shs = None, 140 | colors_precomp = normal, 141 | opacities = opacity, 142 | scales = scales, 143 | rotations = rotations, 144 | cov3D_precomp = cov3D_precomp) 145 | render_normal = F.normalize(render_normal, dim = 0) 146 | return_dict.update({'render_normal': render_normal}) 147 | 148 | if return_opacity: 149 | density = torch.ones_like(means3D) 150 | 151 | render_opacity, _ = rasterizer( 152 | means3D = means3D, 153 | means2D = means2D, 154 | shs = 
None, 155 | colors_precomp = density, 156 | opacities = opacity, 157 | scales = scales, 158 | rotations = rotations, 159 | cov3D_precomp = cov3D_precomp) 160 | return_dict.update({'render_opacity': render_opacity.mean(dim=0)}) 161 | 162 | return return_dict 163 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"\nConnected by {addr}") 39 | conn.settimeout(None) 40 | except Exception as inst: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes != None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = 
'0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in 
enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | from lpipsPyTorch import lpips 19 | import json 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser 23 | 24 | def readImages(renders_dir, gt_dir): 25 | renders = [] 26 | gts = [] 27 | image_names = [] 28 | for fname in os.listdir(renders_dir): 29 | render = Image.open(renders_dir / fname) 30 | gt = Image.open(gt_dir / fname) 31 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 32 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 33 | image_names.append(fname) 34 | return renders, gts, image_names 35 | 36 | def evaluate(model_paths): 37 | 38 | full_dict = {} 39 | per_view_dict = {} 40 | full_dict_polytopeonly = {} 41 | per_view_dict_polytopeonly = {} 42 | print("") 43 | 44 | for scene_dir in model_paths: 45 | # try: 46 | print("Scene:", scene_dir) 47 | full_dict[scene_dir] = {} 48 | per_view_dict[scene_dir] = {} 49 | full_dict_polytopeonly[scene_dir] = {} 50 | per_view_dict_polytopeonly[scene_dir] = {} 51 | 52 | test_dir = Path(scene_dir) / "test" 53 | 54 | for method in os.listdir(test_dir): 55 | print("Method:", method) 56 | 57 | full_dict[scene_dir][method] = {} 58 | per_view_dict[scene_dir][method] = {} 59 | full_dict_polytopeonly[scene_dir][method] = {} 60 | per_view_dict_polytopeonly[scene_dir][method] = {} 61 | 62 | method_dir = test_dir / method 63 | gt_dir = method_dir/ "gt" 64 | renders_dir = method_dir / "renders" 65 | renders, gts, image_names = readImages(renders_dir, gt_dir) 66 | 67 | ssims = [] 68 | psnrs = [] 69 | lpipss = [] 70 | 71 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 72 | ssims.append(ssim(renders[idx], gts[idx])) 73 | psnrs.append(psnr(renders[idx], gts[idx])) 74 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) 75 | 76 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 77 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 78 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 79 | print("") 80 | 81 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 82 | "PSNR": torch.tensor(psnrs).mean().item(), 83 | "LPIPS": torch.tensor(lpipss).mean().item()}) 84 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 85 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 86 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 87 | 88 | with open(scene_dir + "/results.json", 'w') as fp: 89 | json.dump(full_dict[scene_dir], fp, indent=True) 90 | with open(scene_dir + "/per_view.json", 'w') as fp: 91 | json.dump(per_view_dict[scene_dir], fp, indent=True) 92 | # except: 93 | # print("Unable to compute metrics for model", scene_dir) 94 | 95 | if __name__ == "__main__": 96 | device = torch.device("cuda:0") 97 | torch.cuda.set_device(device) 98 | 99 | # Set up command line argument parser 100 | parser = ArgumentParser(description="Training script parameters") 101 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 102 | args = parser.parse_args() 103 | evaluate(args.model_paths) 104 | 
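A minimal usage sketch for this evaluation script together with `render.py` below, assuming a hypothetical trained model directory: `render.py` fills `<model>/test/ours_<iteration>/renders` and `gt`, which `metrics.py` then scores and summarizes into `results.json` and `per_view.json`:

```bash
# Hypothetical model directory; substitute your own trained output path.
python render.py -m output/waymo_demo --skip_train    # writes <model>/test/ours_<iter>/{renders,gt}
python metrics.py -m output/waymo_demo                # writes results.json and per_view.json in <model>
```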
-------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene 14 | import os 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render 18 | import torchvision 19 | from utils.general_utils import safe_state, vis_depth 20 | from argparse import ArgumentParser 21 | from arguments import ModelParams, PipelineParams, get_combined_args 22 | from gaussian_renderer import GaussianModel 23 | import imageio 24 | import numpy as np 25 | 26 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background): 27 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 28 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 29 | depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "render_depth") 30 | normal_path = os.path.join(model_path, name, "ours_{}".format(iteration), "render_normal") 31 | 32 | makedirs(render_path, exist_ok=True) 33 | makedirs(gts_path, exist_ok=True) 34 | makedirs(depth_path , exist_ok=True) 35 | makedirs(normal_path, exist_ok=True) 36 | 37 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 38 | renders = render(view, gaussians, pipeline, background, return_depth=True, return_normal=True) 39 | rendering = renders["render"] 40 | gt = view.original_image[0:3, :, :] 41 | 42 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 43 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 44 | 45 | render_depth = renders["render_depth"] 46 | if view.sky_mask is not None: 47 | render_depth[~(view.sky_mask.to(render_depth.device).to(torch.bool))] = 300 48 | render_depth = vis_depth(render_depth.detach().cpu().numpy())[0] 49 | imageio.imwrite(os.path.join(depth_path , '{0:05d}'.format(idx) + ".png"), render_depth) 50 | 51 | render_normal = (renders["render_normal"] + 1.0) / 2.0 52 | if view.sky_mask is not None: 53 | render_normal[~(view.sky_mask.to(rendering.device).to(torch.bool).unsqueeze(0).repeat(3, 1, 1))] = -10 54 | # render_normal = renders["render_normal"] 55 | np.save(os.path.join(normal_path, '{0:05d}'.format(idx) + ".png"), renders["render_normal"].detach().cpu().numpy()) 56 | torchvision.utils.save_image(render_normal, os.path.join(normal_path, '{0:05d}'.format(idx) + ".png")) 57 | # normal_gt = torch.nn.functional.normalize(view.normal, p=2, dim=0) 58 | # render_normal_gt = (normal_gt + 1.0) / 2.0 59 | # torchvision.utils.save_image(render_normal_gt, os.path.join(normal_path, '{0:05d}'.format(idx) + "_normalgt.png")) 60 | # exit() 61 | 62 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool): 63 | with torch.no_grad(): 64 | gaussians = GaussianModel(dataset.sh_degree) 65 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 66 | 67 | # gaussians._scaling[:, 0] = 0.001 68 | # gaussians._scaling[:, 1] = 0.0005 69 | # gaussians._scaling[:, 2] = -10000.0 70 | # 
gaussians._rotation[:, 0] = 1 71 | # gaussians._rotation[:, 1:] = 0 72 | scales = gaussians.get_scaling 73 | 74 | # min_scale, _ = torch.min(scales, dim=1) 75 | # max_scale, _ = torch.max(scales, dim=1) 76 | # median_scale, _ = torch.median(scales, dim=1) 77 | # print(min_scale) 78 | # print(max_scale) 79 | 80 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 81 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 82 | 83 | if not skip_train: 84 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) 85 | 86 | if not skip_test: 87 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background) 88 | 89 | if __name__ == "__main__": 90 | # Set up command line argument parser 91 | parser = ArgumentParser(description="Testing script parameters") 92 | model = ModelParams(parser, sentinel=True) 93 | pipeline = PipelineParams(parser) 94 | parser.add_argument("--iteration", default=-1, type=int) 95 | parser.add_argument("--skip_train", action="store_true") 96 | parser.add_argument("--skip_test", action="store_true") 97 | parser.add_argument("--quiet", action="store_true") 98 | args = get_combined_args(parser) 99 | print("Rendering " + args.model_path) 100 | 101 | # Initialize system state (RNG) 102 | safe_state(args.quiet) 103 | 104 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) -------------------------------------------------------------------------------- /results/DeepBlending/drjohnson.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,29.25110626,0.907489359,0.223799944,784796221.4,3164513 3 | ,29.27627754,0.906062722,0.227002591,749050265.6,3020344 4 | -------------------------------------------------------------------------------- /results/DeepBlending/playroom.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,30.33006096,0.918174982,0.219789669,496238592,2000956 3 | ,30.40084267,0.918608427,0.222872108,466092032,1879412 4 | -------------------------------------------------------------------------------- /results/Eth3D/delivery_area.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,17.53661728,0.779447019,0.347055882,884998144,3568545 3 | ,19.40202713,0.817175806,0.3138583,844963512.3,3407113 4 | -------------------------------------------------------------------------------- /results/Eth3D/electro.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,15.60790539,0.703940392,0.398457259,835337584.6,3368290 3 | ,16.02513695,0.704342365,0.393627644,849703075.8,3426235 4 | -------------------------------------------------------------------------------- /results/Eth3D/kicker.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,22.60378647,0.765782118,0.37088266,341647032.3,1377615 3 | ,22.70046806,0.761263132,0.372863382,344551587.8,1389313 4 | -------------------------------------------------------------------------------- /results/Eth3D/meadow.csv: 
-------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,14.37556458,0.396413594,0.495654911,800986234.9,3229780 3 | ,14.46836758,0.379970253,0.487953991,763195555.8,3077388 4 | -------------------------------------------------------------------------------- /results/Eth3D/office.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,17.37435532,0.805293322,0.322832406,173151354.9,698192 3 | ,17.59784698,0.809933305,0.325274765,175080734.7,705952 4 | -------------------------------------------------------------------------------- /results/Eth3D/playground.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,15.42733574,0.490227938,0.435589731,1236292076,4985058 3 | ,15.45620441,0.499574512,0.431878805,1214292951,4896345 4 | -------------------------------------------------------------------------------- /results/Eth3D/relief.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,26.46071815,0.829667032,0.286081493,355635036.2,1434025 3 | ,26.71376228,0.833274424,0.281712294,361140060.2,1456216 4 | -------------------------------------------------------------------------------- /results/Eth3D/relief2.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,27.4475174,0.876014948,0.25743258,309319434.2,1247246 3 | ,27.19592476,0.872983813,0.258035332,313052364.8,1262298 4 | -------------------------------------------------------------------------------- /results/Eth3D/terrace.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,20.75988007,0.780296981,0.276560545,409951273,1653027 3 | ,20.40294075,0.776602566,0.272293329,407738777.6,1644092 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/bicycle.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,25.06400108,0.747040689,0.240248889,1556464271,6276054 3 | ,25.12158394,0.747914374,0.244340554,1408059310,5677644 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/bonsai.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,32.31475449,0.946509063,0.18023923,319847137.3,1289692 3 | ,32.47714996,0.95010221,0.16658856,329525493.8,1328746 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/counter.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,29.10391808,0.915232599,0.181373119,299924193.3,1209366 3 | ,29.15148544,0.917220533,0.174561828,292699504.6,1180219 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/flowers.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | 
Baseline,21.44571877,0.5914886,0.351740956,943477227.5,3804355 3 | ,21.35391426,0.589819193,0.352616757,808902983.7,3261679 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/garden.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,27.23822975,0.855030119,0.121553876,1459837993,5886426 3 | ,27.12190247,0.854436398,0.119427539,1228364841,4953084 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/kitchen.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,31.22276497,0.931057751,0.117309771,429842759.7,1733247 3 | ,31.52673721,0.932370305,0.115075439,431960883.2,1741790 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/room.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,31.28110695,0.922523379,0.200879946,377676103.7,1522889 3 | ,31.69297218,0.924869835,0.197009131,329515008,1328679 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/stump.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,26.64076614,0.772574306,0.234375477,1183664046,4772826 3 | ,26.60910416,0.770324588,0.239168838,1112067277,4484153 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/treehill.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,22.56472206,0.634812593,0.339756131,1025916273,4136754 3 | ,22.11814117,0.629052699,0.346725404,814082949.1,3282588 4 | -------------------------------------------------------------------------------- /results/TanksAndTemples/train.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,21.91499519,0.823102057,0.230635419,266768220,1075666 3 | ,21.49312401,0.815916955,0.239080429,259522560,1046465 4 | -------------------------------------------------------------------------------- /results/TanksAndTemples/truck.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,26.26991844,0.901777327,0.138400272,448087982.1,1806784 3 | ,25.92964554,0.897741735,0.132989123,439835689,1773540 4 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
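# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# The Scene class defined below ties a GaussianModel to the COLMAP/Blender loaders
# and hands back per-resolution camera lists. A minimal, hypothetical driver (the
# paths and sh_degree value are placeholders; argument handling loosely follows train.py):
#
#   from argparse import ArgumentParser
#   from arguments import ModelParams
#   from scene import Scene
#   from scene.gaussian_model import GaussianModel
#
#   parser = ArgumentParser()
#   model = ModelParams(parser)
#   args = parser.parse_args(["--source_path", "./data/scene", "--model_path", "./output/scene"])
#   gaussians = GaussianModel(sh_degree=3)
#   scene = Scene(model.extract(args), gaussians, load_iteration=-1, shuffle=False)
#   train_cams = scene.getTrainCameras()   # Camera objects at resolution scale 1.0
#   test_cams = scene.getTestCameras()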
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from arguments import ModelParams 19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 20 | 21 | class Scene: 22 | 23 | gaussians : GaussianModel 24 | 25 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): 26 | """b 27 | :param path: Path to colmap scene main folder. 28 | """ 29 | self.model_path = args.model_path 30 | self.loaded_iter = None 31 | self.gaussians = gaussians 32 | 33 | if load_iteration: 34 | if load_iteration == -1: 35 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 36 | else: 37 | self.loaded_iter = load_iteration 38 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 39 | 40 | self.train_cameras = {} 41 | self.test_cameras = {} 42 | 43 | if os.path.exists(os.path.join(args.source_path, "sparse")): 44 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, 45 | sky_seg=args.sky_seg, load_normal=args.load_normal, load_depth=args.load_depth) 46 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 47 | print("Found transforms_train.json file, assuming Blender data set!") 48 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 49 | else: 50 | assert False, "Could not recognize scene type!" 51 | 52 | if not self.loaded_iter: 53 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 54 | dest_file.write(src_file.read()) 55 | json_cams = [] 56 | camlist = [] 57 | if scene_info.test_cameras: 58 | camlist.extend(scene_info.test_cameras) 59 | if scene_info.train_cameras: 60 | camlist.extend(scene_info.train_cameras) 61 | for id, cam in enumerate(camlist): 62 | json_cams.append(camera_to_JSON(id, cam)) 63 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 64 | json.dump(json_cams, file) 65 | 66 | # if shuffle: 67 | # random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 68 | # random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 69 | 70 | self.cameras_extent = scene_info.nerf_normalization["radius"] 71 | 72 | for resolution_scale in resolution_scales: 73 | print("Loading Training Cameras") 74 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 75 | print("Loading Test Cameras") 76 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 77 | 78 | if self.loaded_iter: 79 | self.gaussians.load_ply(os.path.join(self.model_path, 80 | "point_cloud", 81 | "iteration_" + str(self.loaded_iter), 82 | "point_cloud.ply")) 83 | else: 84 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 85 | 86 | def save(self, iteration): 87 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 88 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 89 | 90 | def getTrainCameras(self, scale=1.0): 91 | return self.train_cameras[scale] 92 | 93 | def getTestCameras(self, scale=1.0): 94 | return 
self.test_cameras[scale] -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | class Camera(nn.Module): 18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 19 | image_name, uid, 20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", K=None, 21 | sky_mask=None, normal=None, depth=None 22 | ): 23 | super(Camera, self).__init__() 24 | 25 | self.uid = uid 26 | self.colmap_id = colmap_id 27 | self.R = R 28 | self.T = T 29 | self.FoVx = FoVx 30 | self.FoVy = FoVy 31 | self.image_name = image_name 32 | self.sky_mask = sky_mask 33 | self.normal = normal 34 | self.depth = depth 35 | 36 | try: 37 | self.data_device = torch.device(data_device) 38 | except Exception as e: 39 | print(e) 40 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 41 | self.data_device = torch.device("cuda") 42 | 43 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 44 | self.image_width = self.original_image.shape[2] 45 | self.image_height = self.original_image.shape[1] 46 | 47 | if gt_alpha_mask is not None: 48 | self.original_image *= gt_alpha_mask.to(self.data_device) 49 | else: 50 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 51 | 52 | self.K = torch.tensor([[K[0], 0, K[2]], 53 | [0, K[1], K[3]], 54 | [0, 0, 1]]).to(self.data_device).to(torch.float32) 55 | 56 | self.zfar = 100.0 57 | self.znear = 0.01 58 | 59 | self.trans = trans 60 | self.scale = scale 61 | 62 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 63 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 64 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 65 | self.camera_center = self.world_view_transform.inverse()[3, :3] 66 | 67 | class MiniCam: 68 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 69 | self.image_width = width 70 | self.image_height = height 71 | self.FoVy = fovy 72 | self.FoVx = fovx 73 | self.znear = znear 74 | self.zfar = zfar 75 | self.world_view_transform = world_view_transform 76 | self.full_proj_transform = full_proj_transform 77 | view_inv = torch.inverse(self.world_view_transform) 78 | self.camera_center = view_inv[3][:3] 79 | 80 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
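# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# The functions below parse COLMAP's sparse reconstruction (cameras, images, points).
# A minimal, hypothetical driver mirroring how dataset_readers.readColmapSceneInfo
# consumes them (the sparse/ directory path is a placeholder):
#
#   import os
#   import numpy as np
#   from scene.colmap_loader import (read_extrinsics_binary, read_intrinsics_binary,
#                                    read_points3D_binary, qvec2rotmat)
#
#   sparse_dir = "./data/scene/sparse/0"
#   cam_extrinsics = read_extrinsics_binary(os.path.join(sparse_dir, "images.bin"))
#   cam_intrinsics = read_intrinsics_binary(os.path.join(sparse_dir, "cameras.bin"))
#   xyz, rgb, err = read_points3D_binary(os.path.join(sparse_dir, "points3D.bin"))
#
#   extr = next(iter(cam_extrinsics.values()))
#   R = np.transpose(qvec2rotmat(extr.qvec))   # stored transposed, as in readColmapCameras
#   T = np.array(extr.tvec)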
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 54 | 55 | def rotmat2qvec(R): 56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 57 | K = np.array([ 58 | [Rxx - Ryy - Rzz, 0, 0, 0], 59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 62 | eigvals, eigvecs = np.linalg.eigh(K) 63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 64 | if qvec[0] < 0: 65 | qvec *= -1 66 | return qvec 67 | 68 | class Image(BaseImage): 69 | def qvec2rotmat(self): 70 | return qvec2rotmat(self.qvec) 71 | 72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 73 | """Read and unpack the next bytes from a binary file. 74 | :param fid: 75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 77 | :param endian_character: Any of {@, =, <, >, !} 78 | :return: Tuple of read and unpacked values. 
79 | """ 80 | data = fid.read(num_bytes) 81 | return struct.unpack(endian_character + format_char_sequence, data) 82 | 83 | def read_points3D_text(path): 84 | """ 85 | see: src/base/reconstruction.cc 86 | void Reconstruction::ReadPoints3DText(const std::string& path) 87 | void Reconstruction::WritePoints3DText(const std::string& path) 88 | """ 89 | xyzs = None 90 | rgbs = None 91 | errors = None 92 | num_points = 0 93 | with open(path, "r") as fid: 94 | while True: 95 | line = fid.readline() 96 | if not line: 97 | break 98 | line = line.strip() 99 | if len(line) > 0 and line[0] != "#": 100 | num_points += 1 101 | 102 | 103 | xyzs = np.empty((num_points, 3)) 104 | rgbs = np.empty((num_points, 3)) 105 | errors = np.empty((num_points, 1)) 106 | count = 0 107 | with open(path, "r") as fid: 108 | while True: 109 | line = fid.readline() 110 | if not line: 111 | break 112 | line = line.strip() 113 | if len(line) > 0 and line[0] != "#": 114 | elems = line.split() 115 | xyz = np.array(tuple(map(float, elems[1:4]))) 116 | rgb = np.array(tuple(map(int, elems[4:7]))) 117 | error = np.array(float(elems[7])) 118 | xyzs[count] = xyz 119 | rgbs[count] = rgb 120 | errors[count] = error 121 | count += 1 122 | 123 | return xyzs, rgbs, errors 124 | 125 | def read_points3D_binary(path_to_model_file): 126 | """ 127 | see: src/base/reconstruction.cc 128 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 129 | void Reconstruction::WritePoints3DBinary(const std::string& path) 130 | """ 131 | 132 | 133 | with open(path_to_model_file, "rb") as fid: 134 | num_points = read_next_bytes(fid, 8, "Q")[0] 135 | 136 | xyzs = np.empty((num_points, 3)) 137 | rgbs = np.empty((num_points, 3)) 138 | errors = np.empty((num_points, 1)) 139 | 140 | for p_id in range(num_points): 141 | binary_point_line_properties = read_next_bytes( 142 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 143 | xyz = np.array(binary_point_line_properties[1:4]) 144 | rgb = np.array(binary_point_line_properties[4:7]) 145 | error = np.array(binary_point_line_properties[7]) 146 | track_length = read_next_bytes( 147 | fid, num_bytes=8, format_char_sequence="Q")[0] 148 | track_elems = read_next_bytes( 149 | fid, num_bytes=8*track_length, 150 | format_char_sequence="ii"*track_length) 151 | xyzs[p_id] = xyz 152 | rgbs[p_id] = rgb 153 | errors[p_id] = error 154 | return xyzs, rgbs, errors 155 | 156 | def read_intrinsics_text(path): 157 | """ 158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 159 | """ 160 | cameras = {} 161 | with open(path, "r") as fid: 162 | while True: 163 | line = fid.readline() 164 | if not line: 165 | break 166 | line = line.strip() 167 | if len(line) > 0 and line[0] != "#": 168 | elems = line.split() 169 | camera_id = int(elems[0]) 170 | model = elems[1] 171 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 172 | width = int(elems[2]) 173 | height = int(elems[3]) 174 | params = np.array(tuple(map(float, elems[4:]))) 175 | cameras[camera_id] = Camera(id=camera_id, model=model, 176 | width=width, height=height, 177 | params=params) 178 | return cameras 179 | 180 | def read_extrinsics_binary(path_to_model_file): 181 | """ 182 | see: src/base/reconstruction.cc 183 | void Reconstruction::ReadImagesBinary(const std::string& path) 184 | void Reconstruction::WriteImagesBinary(const std::string& path) 185 | """ 186 | images = {} 187 | with open(path_to_model_file, "rb") as fid: 188 | num_reg_images = read_next_bytes(fid, 8, 
"Q")[0] 189 | for _ in range(num_reg_images): 190 | binary_image_properties = read_next_bytes( 191 | fid, num_bytes=64, format_char_sequence="idddddddi") 192 | image_id = binary_image_properties[0] 193 | qvec = np.array(binary_image_properties[1:5]) 194 | tvec = np.array(binary_image_properties[5:8]) 195 | camera_id = binary_image_properties[8] 196 | image_name = "" 197 | current_char = read_next_bytes(fid, 1, "c")[0] 198 | while current_char != b"\x00": # look for the ASCII 0 entry 199 | image_name += current_char.decode("utf-8") 200 | current_char = read_next_bytes(fid, 1, "c")[0] 201 | num_points2D = read_next_bytes(fid, num_bytes=8, 202 | format_char_sequence="Q")[0] 203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 204 | format_char_sequence="ddq"*num_points2D) 205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 206 | tuple(map(float, x_y_id_s[1::3]))]) 207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 208 | images[image_id] = Image( 209 | id=image_id, qvec=qvec, tvec=tvec, 210 | camera_id=camera_id, name=image_name, 211 | xys=xys, point3D_ids=point3D_ids) 212 | return images 213 | 214 | 215 | def read_intrinsics_binary(path_to_model_file): 216 | """ 217 | see: src/base/reconstruction.cc 218 | void Reconstruction::WriteCamerasBinary(const std::string& path) 219 | void Reconstruction::ReadCamerasBinary(const std::string& path) 220 | """ 221 | cameras = {} 222 | with open(path_to_model_file, "rb") as fid: 223 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 224 | for _ in range(num_cameras): 225 | camera_properties = read_next_bytes( 226 | fid, num_bytes=24, format_char_sequence="iiQQ") 227 | camera_id = camera_properties[0] 228 | model_id = camera_properties[1] 229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 230 | width = camera_properties[2] 231 | height = camera_properties[3] 232 | num_params = CAMERA_MODEL_IDS[model_id].num_params 233 | params = read_next_bytes(fid, num_bytes=8*num_params, 234 | format_char_sequence="d"*num_params) 235 | cameras[camera_id] = Camera(id=camera_id, 236 | model=model_name, 237 | width=width, 238 | height=height, 239 | params=np.array(params)) 240 | assert len(cameras) == num_cameras 241 | return cameras 242 | 243 | 244 | def read_extrinsics_text(path): 245 | """ 246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 247 | """ 248 | images = {} 249 | with open(path, "r") as fid: 250 | while True: 251 | line = fid.readline() 252 | if not line: 253 | break 254 | line = line.strip() 255 | if len(line) > 0 and line[0] != "#": 256 | elems = line.split() 257 | image_id = int(elems[0]) 258 | qvec = np.array(tuple(map(float, elems[1:5]))) 259 | tvec = np.array(tuple(map(float, elems[5:8]))) 260 | camera_id = int(elems[8]) 261 | image_name = elems[9] 262 | elems = fid.readline().split() 263 | xys = np.column_stack([tuple(map(float, elems[0::3])), 264 | tuple(map(float, elems[1::3]))]) 265 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 266 | images[image_id] = Image( 267 | id=image_id, qvec=qvec, tvec=tvec, 268 | camera_id=camera_id, name=image_name, 269 | xys=xys, point3D_ids=point3D_ids) 270 | return images 271 | 272 | 273 | def read_colmap_bin_array(path): 274 | """ 275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 276 | 277 | :param path: path to the colmap binary file. 
278 | :return: nd array with the floating point values in the value 279 | """ 280 | with open(path, "rb") as fid: 281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 282 | usecols=(0, 1, 2), dtype=int) 283 | fid.seek(0) 284 | num_delimiter = 0 285 | byte = fid.read(1) 286 | while True: 287 | if byte == b"&": 288 | num_delimiter += 1 289 | if num_delimiter >= 3: 290 | break 291 | byte = fid.read(1) 292 | array = np.fromfile(fid, np.float32) 293 | array = array.reshape((width, height, channels), order="F") 294 | return np.transpose(array, (1, 0, 2)).squeeze() 295 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple 16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ 17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 19 | import numpy as np 20 | import json 21 | from pathlib import Path 22 | from plyfile import PlyData, PlyElement 23 | from utils.sh_utils import SH2RGB 24 | from scene.gaussian_model import BasicPointCloud 25 | import open3d as o3d 26 | 27 | class CameraInfo(NamedTuple): 28 | uid: int 29 | R: np.array 30 | T: np.array 31 | FovY: np.array 32 | FovX: np.array 33 | image: np.array 34 | image_path: str 35 | image_name: str 36 | width: int 37 | height: int 38 | K: np.array 39 | sky_mask: np.array 40 | normal: np.array 41 | depth: np.array 42 | 43 | class SceneInfo(NamedTuple): 44 | point_cloud: BasicPointCloud 45 | train_cameras: list 46 | test_cameras: list 47 | nerf_normalization: dict 48 | ply_path: str 49 | 50 | def getNerfppNorm(cam_info): 51 | def get_center_and_diag(cam_centers): 52 | cam_centers = np.hstack(cam_centers) 53 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 54 | center = avg_cam_center 55 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 56 | diagonal = np.max(dist) 57 | return center.flatten(), diagonal 58 | 59 | cam_centers = [] 60 | 61 | for cam in cam_info: 62 | W2C = getWorld2View2(cam.R, cam.T) 63 | C2W = np.linalg.inv(W2C) 64 | cam_centers.append(C2W[:3, 3:4]) 65 | 66 | center, diagonal = get_center_and_diag(cam_centers) 67 | radius = diagonal * 1.1 68 | 69 | translate = -center 70 | 71 | return {"translate": translate, "radius": radius} 72 | 73 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder, sky_seg=False, load_normal=False, load_depth=False): 74 | cam_infos = [] 75 | for idx, key in enumerate(cam_extrinsics): 76 | sys.stdout.write('\r') 77 | # the exact output you're looking for: 78 | sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) 79 | sys.stdout.flush() 80 | 81 | extr = cam_extrinsics[key] 82 | intr = cam_intrinsics[extr.camera_id] 83 | 84 | height = intr.height 85 | width = intr.width 86 | 87 | uid = intr.id 88 | R = np.transpose(qvec2rotmat(extr.qvec)) 89 | T = np.array(extr.tvec) 90 | 91 | if intr.model=="SIMPLE_PINHOLE": 92 | 
focal_length_x = intr.params[0] 93 | FovY = focal2fov(focal_length_x, height) 94 | FovX = focal2fov(focal_length_x, width) 95 | elif intr.model=="PINHOLE": 96 | focal_length_x = intr.params[0] 97 | focal_length_y = intr.params[1] 98 | FovY = focal2fov(focal_length_y, height) 99 | FovX = focal2fov(focal_length_x, width) 100 | else: 101 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 102 | 103 | image_path = os.path.join(images_folder, os.path.basename(extr.name)) 104 | image_name = os.path.basename(image_path).split(".")[0] 105 | 106 | image = Image.open(image_path) 107 | 108 | # #sky mask 109 | if sky_seg: 110 | sky_path = image_path.replace("images", "mask")[:-4]+".npy" 111 | sky_mask = np.load(sky_path).astype(np.uint8) 112 | else: 113 | sky_mask = None 114 | 115 | if load_normal: 116 | normal_path = image_path.replace("images", "normals")[:-4]+".npy" 117 | normal = np.load(normal_path).astype(np.float32) 118 | normal = (normal - 0.5) * 2.0 119 | else: 120 | normal = None 121 | 122 | if load_depth: 123 | # depth_path = image_path.replace("images", "monodepth")[:-4]+".npy" 124 | depth_path = image_path.replace("images", "metricdepth")[:-4]+".npy" 125 | depth = np.load(depth_path).astype(np.float32) 126 | else: 127 | depth = None 128 | 129 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 130 | image_path=image_path, image_name=image_name, width=width, height=height, 131 | K=intr.params, sky_mask=sky_mask, normal=normal, depth=depth) 132 | cam_infos.append(cam_info) 133 | sys.stdout.write('\n') 134 | return cam_infos 135 | 136 | def fetchPly(path): 137 | plydata = PlyData.read(path) 138 | vertices = plydata['vertex'] 139 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 140 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 141 | # normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 142 | normals = np.zeros_like(positions) 143 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 144 | 145 | def storePly(path, xyz, rgb): 146 | # Define the dtype for the structured array 147 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 148 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 149 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 150 | 151 | normals = np.zeros_like(xyz) 152 | 153 | elements = np.empty(xyz.shape[0], dtype=dtype) 154 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 155 | elements[:] = list(map(tuple, attributes)) 156 | 157 | # Create the PlyData object and write to file 158 | vertex_element = PlyElement.describe(elements, 'vertex') 159 | ply_data = PlyData([vertex_element]) 160 | ply_data.write(path) 161 | 162 | def readColmapSceneInfo(path, images, eval, llffhold=8, sky_seg=False, load_normal=False, load_depth=False): 163 | try: 164 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 165 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 166 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 167 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 168 | except: 169 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 170 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 171 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 172 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 173 | 174 | reading_dir = "images" if images == None else 
images 175 | 176 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir), 177 | sky_seg=sky_seg, load_normal=load_normal, load_depth=load_depth) 178 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) 179 | 180 | if eval: 181 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] 182 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] 183 | if 'waymo' in path: 184 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != (llffhold-1)] 185 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == (llffhold-1)] 186 | # train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % (llffhold * 3) >= 3] 187 | # test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % (llffhold * 3) < 3] 188 | else: 189 | train_cam_infos = cam_infos 190 | test_cam_infos = [] 191 | 192 | nerf_normalization = getNerfppNorm(train_cam_infos) 193 | 194 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 195 | 196 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 197 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 198 | if not os.path.exists(ply_path): 199 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") 200 | try: 201 | xyz, rgb, _ = read_points3D_binary(bin_path) 202 | except: 203 | xyz, rgb, _ = read_points3D_text(txt_path) 204 | storePly(ply_path, xyz, rgb) 205 | try: 206 | pcd = fetchPly(ply_path) 207 | except: 208 | pcd = None 209 | 210 | scene_info = SceneInfo(point_cloud=pcd, 211 | train_cameras=train_cam_infos, 212 | test_cameras=test_cam_infos, 213 | nerf_normalization=nerf_normalization, 214 | ply_path=ply_path) 215 | return scene_info 216 | 217 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png", is_train=True): 218 | cam_infos = [] 219 | 220 | with open(os.path.join(path, transformsfile)) as json_file: 221 | contents = json.load(json_file) 222 | fovx = contents["camera_angle_x"] 223 | 224 | frames = contents["frames"] 225 | for idx, frame in enumerate(frames): 226 | cam_name = os.path.join(path, frame["file_path"] + extension) 227 | 228 | # NeRF 'transform_matrix' is a camera-to-world transform 229 | c2w = np.array(frame["transform_matrix"]) 230 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) 231 | c2w[:3, 1:3] *= -1 232 | 233 | # get the world-to-camera transform and set R, T 234 | w2c = np.linalg.inv(c2w) 235 | R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code 236 | T = w2c[:3, 3] 237 | 238 | image_path = os.path.join(path, cam_name) 239 | image_name = Path(cam_name).stem 240 | image = Image.open(image_path) 241 | 242 | im_data = np.array(image.convert("RGBA")) 243 | 244 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) 245 | 246 | norm_data = im_data / 255.0 247 | arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 248 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") 249 | 250 | sky_mask = np.ones_like(image)[:, :, 0].astype(np.uint8) 251 | 252 | if is_train: 253 | normal_path = image_path.replace("train", "normals")[:-4]+".npy" 254 | normal = np.load(normal_path).astype(np.float32) 255 | normal = (normal - 0.5) * 2.0 256 | # normal[2, :, :] *= -1 257 | else: 258 | normal = np.zeros_like(image).transpose(2, 0, 1) 259 | 260 | fovy = 
focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 261 | FovY = fovy 262 | FovX = fovx 263 | 264 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 265 | image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1], 266 | K=np.array([1, 2, 3, 4]), sky_mask=sky_mask, normal=normal)) 267 | 268 | return cam_infos 269 | 270 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): 271 | print("Reading Training Transforms") 272 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension) 273 | print("Reading Test Transforms") 274 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension, is_train=False) 275 | 276 | if not eval: 277 | train_cam_infos.extend(test_cam_infos) 278 | test_cam_infos = [] 279 | 280 | nerf_normalization = getNerfppNorm(train_cam_infos) 281 | 282 | ply_path = os.path.join(path, "points3d.ply") 283 | if not os.path.exists(ply_path): 284 | # Since this data set has no colmap data, we start with random points 285 | num_pts = 100_000 286 | print(f"Generating random point cloud ({num_pts})...") 287 | 288 | # We create random points inside the bounds of the synthetic Blender scenes 289 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 290 | shs = np.random.random((num_pts, 3)) / 255.0 291 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 292 | 293 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 294 | try: 295 | pcd = fetchPly(ply_path) 296 | except: 297 | pcd = None 298 | 299 | scene_info = SceneInfo(point_cloud=pcd, 300 | train_cameras=train_cam_infos, 301 | test_cameras=test_cam_infos, 302 | nerf_normalization=nerf_normalization, 303 | ply_path=ply_path) 304 | return scene_info 305 | 306 | sceneLoadTypeCallbacks = { 307 | "Colmap": readColmapSceneInfo, 308 | "Blender" : readNerfSyntheticInfo 309 | } -------------------------------------------------------------------------------- /scene/gaussian_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
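# --- Illustrative note (added for clarity; not part of the original file) ---
# GaussianModel below stores log-scales and unnormalised quaternions per point; the
# per-Gaussian covariance is assembled as Sigma = R S S^T R^T (see
# build_covariance_from_scaling_rotation in setup_functions). A minimal, hypothetical
# sketch of that construction for a single Gaussian, with a helper name that is an
# assumption made here rather than a function of this file:
#
#   import torch
#
#   def covariance_from_params(log_scale, quat):
#       s = torch.exp(log_scale)                          # scaling_activation
#       w, x, y, z = torch.nn.functional.normalize(quat, dim=0).tolist()  # rotation_activation
#       R = torch.tensor([[1 - 2*(y*y + z*z), 2*(x*y - w*z),     2*(x*z + w*y)],
#                         [2*(x*y + w*z),     1 - 2*(x*x + z*z), 2*(y*z - w*x)],
#                         [2*(x*z - w*y),     2*(y*z + w*x),     1 - 2*(x*x + y*y)]])
#       L = R @ torch.diag(s)                             # build_scaling_rotation
#       return L @ L.T                                    # symmetric 3x3 covariance
#
#   # Identity quaternion and zero log-scale give the unit covariance:
#   sigma = covariance_from_params(torch.zeros(3), torch.tensor([1.0, 0.0, 0.0, 0.0]))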
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import numpy as np 14 | from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation 15 | from torch import nn 16 | import os 17 | from utils.system_utils import mkdir_p 18 | from plyfile import PlyData, PlyElement 19 | from utils.sh_utils import RGB2SH 20 | from simple_knn._C import distCUDA2 21 | from utils.graphics_utils import BasicPointCloud, getWorld2View2 22 | from utils.general_utils import strip_symmetric, build_scaling_rotation 23 | 24 | class GaussianModel: 25 | 26 | def setup_functions(self): 27 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): 28 | L = build_scaling_rotation(scaling_modifier * scaling, rotation) 29 | actual_covariance = L @ L.transpose(1, 2) 30 | symm = strip_symmetric(actual_covariance) 31 | return symm 32 | 33 | self.scaling_activation = torch.exp 34 | self.scaling_inverse_activation = torch.log 35 | 36 | self.covariance_activation = build_covariance_from_scaling_rotation 37 | 38 | self.opacity_activation = torch.sigmoid 39 | self.inverse_opacity_activation = inverse_sigmoid 40 | 41 | self.rotation_activation = torch.nn.functional.normalize 42 | 43 | 44 | def __init__(self, sh_degree : int): 45 | self.active_sh_degree = 0 46 | self.max_sh_degree = sh_degree 47 | self._xyz = torch.empty(0) 48 | self._features_dc = torch.empty(0) 49 | self._features_rest = torch.empty(0) 50 | self._scaling = torch.empty(0) 51 | self._rotation = torch.empty(0) 52 | self._opacity = torch.empty(0) 53 | self.max_radii2D = torch.empty(0) 54 | self.xyz_gradient_accum = torch.empty(0) 55 | self.denom = torch.empty(0) 56 | self.optimizer = None 57 | self.percent_dense = 0 58 | self.spatial_lr_scale = 0 59 | self.setup_functions() 60 | 61 | def capture(self): 62 | return ( 63 | self.active_sh_degree, 64 | self._xyz, 65 | self._features_dc, 66 | self._features_rest, 67 | self._scaling, 68 | self._rotation, 69 | self._opacity, 70 | self.max_radii2D, 71 | self.xyz_gradient_accum, 72 | self.denom, 73 | self.optimizer.state_dict(), 74 | self.spatial_lr_scale, 75 | ) 76 | 77 | def restore(self, model_args, training_args): 78 | (self.active_sh_degree, 79 | self._xyz, 80 | self._features_dc, 81 | self._features_rest, 82 | self._scaling, 83 | self._rotation, 84 | self._opacity, 85 | self.max_radii2D, 86 | xyz_gradient_accum, 87 | denom, 88 | opt_dict, 89 | self.spatial_lr_scale) = model_args 90 | self.training_setup(training_args) 91 | self.xyz_gradient_accum = xyz_gradient_accum 92 | self.denom = denom 93 | self.optimizer.load_state_dict(opt_dict) 94 | 95 | @property 96 | def get_scaling(self): 97 | return self.scaling_activation(self._scaling) 98 | 99 | @property 100 | def get_rotation(self): 101 | return self.rotation_activation(self._rotation) 102 | 103 | @property 104 | def get_xyz(self): 105 | return self._xyz 106 | 107 | @property 108 | def get_features(self): 109 | features_dc = self._features_dc 110 | features_rest = self._features_rest 111 | return torch.cat((features_dc, features_rest), dim=1) 112 | 113 | @property 114 | def get_opacity(self): 115 | return self.opacity_activation(self._opacity) 116 | 117 | def get_covariance(self, scaling_modifier = 1): 118 | return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) 119 | 120 | def oneupSHdegree(self): 121 | if self.active_sh_degree < self.max_sh_degree: 122 | self.active_sh_degree += 1 123 | 124 | def create_from_pcd(self, pcd : BasicPointCloud, 
spatial_lr_scale : float): 125 | self.spatial_lr_scale = spatial_lr_scale 126 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() 127 | fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) 128 | features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() 129 | features[:, :3, 0 ] = fused_color 130 | features[:, 3:, 1:] = 0.0 131 | 132 | print("Number of points at initialisation : ", fused_point_cloud.shape[0]) 133 | 134 | dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) 135 | scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) 136 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") 137 | rots[:, 0] = 1 138 | 139 | opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) 140 | 141 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) 142 | self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) 143 | self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) 144 | self._scaling = nn.Parameter(scales.requires_grad_(True)) 145 | self._rotation = nn.Parameter(rots.requires_grad_(True)) 146 | self._opacity = nn.Parameter(opacities.requires_grad_(True)) 147 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 148 | 149 | def training_setup(self, training_args): 150 | self.percent_dense = training_args.percent_dense 151 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 152 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 153 | 154 | l = [ 155 | {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, 156 | {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, 157 | {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, 158 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, 159 | {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"}, 160 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} 161 | ] 162 | 163 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 164 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale, 165 | lr_final=training_args.position_lr_final*self.spatial_lr_scale, 166 | lr_delay_mult=training_args.position_lr_delay_mult, 167 | max_steps=training_args.position_lr_max_steps) 168 | 169 | def update_learning_rate(self, iteration): 170 | ''' Learning rate scheduling per step ''' 171 | for param_group in self.optimizer.param_groups: 172 | if param_group["name"] == "xyz": 173 | lr = self.xyz_scheduler_args(iteration) 174 | param_group['lr'] = lr 175 | return lr 176 | 177 | def construct_list_of_attributes(self): 178 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] 179 | # All channels except the 3 DC 180 | for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): 181 | l.append('f_dc_{}'.format(i)) 182 | for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): 183 | l.append('f_rest_{}'.format(i)) 184 | l.append('opacity') 185 | for i in range(self._scaling.shape[1]): 186 | l.append('scale_{}'.format(i)) 187 | for i in range(self._rotation.shape[1]): 188 | l.append('rot_{}'.format(i)) 189 | return l 190 | 191 | def 
save_ply(self, path): 192 | mkdir_p(os.path.dirname(path)) 193 | 194 | xyz = self._xyz.detach().cpu().numpy() 195 | normals = np.zeros_like(xyz) 196 | f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 197 | f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 198 | opacities = self._opacity.detach().cpu().numpy() 199 | scale = self._scaling.detach().cpu().numpy() 200 | rotation = self._rotation.detach().cpu().numpy() 201 | 202 | dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] 203 | 204 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 205 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) 206 | elements[:] = list(map(tuple, attributes)) 207 | el = PlyElement.describe(elements, 'vertex') 208 | PlyData([el]).write(path) 209 | 210 | def reset_opacity(self): 211 | opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01)) 212 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") 213 | self._opacity = optimizable_tensors["opacity"] 214 | 215 | def load_ply(self, path): 216 | plydata = PlyData.read(path) 217 | 218 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 219 | np.asarray(plydata.elements[0]["y"]), 220 | np.asarray(plydata.elements[0]["z"])), axis=1) 221 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 222 | 223 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 224 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 225 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 226 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) 227 | 228 | extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] 229 | extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) 230 | assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 231 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) 232 | for idx, attr_name in enumerate(extra_f_names): 233 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) 234 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) 235 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) 236 | 237 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 238 | scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) 239 | scales = np.zeros((xyz.shape[0], len(scale_names))) 240 | for idx, attr_name in enumerate(scale_names): 241 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 242 | 243 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] 244 | rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) 245 | rots = np.zeros((xyz.shape[0], len(rot_names))) 246 | for idx, attr_name in enumerate(rot_names): 247 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 248 | 249 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) 250 | self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) 251 | self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) 252 | self._opacity = 
nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) 253 | self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) 254 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) 255 | 256 | self.active_sh_degree = self.max_sh_degree 257 | 258 | def replace_tensor_to_optimizer(self, tensor, name): 259 | optimizable_tensors = {} 260 | for group in self.optimizer.param_groups: 261 | if group["name"] == name: 262 | stored_state = self.optimizer.state.get(group['params'][0], None) 263 | stored_state["exp_avg"] = torch.zeros_like(tensor) 264 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor) 265 | 266 | del self.optimizer.state[group['params'][0]] 267 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) 268 | self.optimizer.state[group['params'][0]] = stored_state 269 | 270 | optimizable_tensors[group["name"]] = group["params"][0] 271 | return optimizable_tensors 272 | 273 | def _prune_optimizer(self, mask): 274 | optimizable_tensors = {} 275 | for group in self.optimizer.param_groups: 276 | stored_state = self.optimizer.state.get(group['params'][0], None) 277 | if stored_state is not None: 278 | stored_state["exp_avg"] = stored_state["exp_avg"][mask] 279 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] 280 | 281 | del self.optimizer.state[group['params'][0]] 282 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) 283 | self.optimizer.state[group['params'][0]] = stored_state 284 | 285 | optimizable_tensors[group["name"]] = group["params"][0] 286 | else: 287 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) 288 | optimizable_tensors[group["name"]] = group["params"][0] 289 | return optimizable_tensors 290 | 291 | def prune_points(self, mask): 292 | valid_points_mask = ~mask 293 | optimizable_tensors = self._prune_optimizer(valid_points_mask) 294 | 295 | self._xyz = optimizable_tensors["xyz"] 296 | self._features_dc = optimizable_tensors["f_dc"] 297 | self._features_rest = optimizable_tensors["f_rest"] 298 | self._opacity = optimizable_tensors["opacity"] 299 | self._scaling = optimizable_tensors["scaling"] 300 | self._rotation = optimizable_tensors["rotation"] 301 | 302 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] 303 | 304 | self.denom = self.denom[valid_points_mask] 305 | self.max_radii2D = self.max_radii2D[valid_points_mask] 306 | 307 | def cat_tensors_to_optimizer(self, tensors_dict): 308 | optimizable_tensors = {} 309 | for group in self.optimizer.param_groups: 310 | assert len(group["params"]) == 1 311 | extension_tensor = tensors_dict[group["name"]] 312 | stored_state = self.optimizer.state.get(group['params'][0], None) 313 | if stored_state is not None: 314 | 315 | stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) 316 | stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) 317 | 318 | del self.optimizer.state[group['params'][0]] 319 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 320 | self.optimizer.state[group['params'][0]] = stored_state 321 | 322 | optimizable_tensors[group["name"]] = group["params"][0] 323 | else: 324 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 325 | 
optimizable_tensors[group["name"]] = group["params"][0] 326 | 327 | return optimizable_tensors 328 | 329 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation): 330 | d = {"xyz": new_xyz, 331 | "f_dc": new_features_dc, 332 | "f_rest": new_features_rest, 333 | "opacity": new_opacities, 334 | "scaling" : new_scaling, 335 | "rotation" : new_rotation} 336 | 337 | optimizable_tensors = self.cat_tensors_to_optimizer(d) 338 | self._xyz = optimizable_tensors["xyz"] 339 | self._features_dc = optimizable_tensors["f_dc"] 340 | self._features_rest = optimizable_tensors["f_rest"] 341 | self._opacity = optimizable_tensors["opacity"] 342 | self._scaling = optimizable_tensors["scaling"] 343 | self._rotation = optimizable_tensors["rotation"] 344 | 345 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 346 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 347 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 348 | 349 | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): 350 | n_init_points = self.get_xyz.shape[0] 351 | # Extract points that satisfy the gradient condition 352 | padded_grad = torch.zeros((n_init_points), device="cuda") 353 | padded_grad[:grads.shape[0]] = grads.squeeze() 354 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) 355 | selected_pts_mask = torch.logical_and(selected_pts_mask, 356 | torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent) 357 | 358 | stds = self.get_scaling[selected_pts_mask].repeat(N,1) 359 | means =torch.zeros((stds.size(0), 3),device="cuda") 360 | samples = torch.normal(mean=means, std=stds) 361 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) 362 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) 363 | new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) 364 | new_rotation = self._rotation[selected_pts_mask].repeat(N,1) 365 | new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) 366 | new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) 367 | new_opacity = self._opacity[selected_pts_mask].repeat(N,1) 368 | 369 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation) 370 | 371 | prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) 372 | self.prune_points(prune_filter) 373 | 374 | def densify_and_clone(self, grads, grad_threshold, scene_extent): 375 | # Extract points that satisfy the gradient condition 376 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) 377 | selected_pts_mask = torch.logical_and(selected_pts_mask, 378 | torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) 379 | 380 | new_xyz = self._xyz[selected_pts_mask] 381 | new_features_dc = self._features_dc[selected_pts_mask] 382 | new_features_rest = self._features_rest[selected_pts_mask] 383 | new_opacities = self._opacity[selected_pts_mask] 384 | new_scaling = self._scaling[selected_pts_mask] 385 | new_rotation = self._rotation[selected_pts_mask] 386 | 387 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation) 388 | 389 | def densify_and_prune(self, max_grad, min_opacity, extent, 
max_screen_size): 390 | grads = self.xyz_gradient_accum / self.denom 391 | grads[grads.isnan()] = 0.0 392 | 393 | self.densify_and_clone(grads, max_grad, extent) 394 | self.densify_and_split(grads, max_grad, extent) 395 | 396 | prune_mask = (self.get_opacity < min_opacity).squeeze() 397 | if max_screen_size: 398 | big_points_vs = self.max_radii2D > max_screen_size 399 | big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent 400 | prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) 401 | self.prune_points(prune_mask) 402 | 403 | torch.cuda.empty_cache() 404 | 405 | def add_densification_stats(self, viewspace_point_tensor, update_filter): 406 | self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True) 407 | self.denom[update_filter] += 1 408 | 409 | def densify_from_depth_propagation(self, viewpoint_cam, propagated_depth, filter_mask, gt_image): 410 | # inverse project pixels into 3D scenes 411 | K = viewpoint_cam.K 412 | cam2world = viewpoint_cam.world_view_transform.transpose(0, 1).inverse() 413 | 414 | # Get the shape of the depth image 415 | height, width = propagated_depth.shape 416 | # Create a grid of 2D pixel coordinates 417 | y, x = torch.meshgrid(torch.arange(0, height), torch.arange(0, width)) 418 | # Stack the 2D and depth coordinates to create 3D homogeneous coordinates 419 | coordinates = torch.stack([x.to(propagated_depth.device), y.to(propagated_depth.device), torch.ones_like(propagated_depth)], dim=-1) 420 | # Reshape the coordinates to (height * width, 3) 421 | coordinates = coordinates.view(-1, 3).to(K.device).to(torch.float32) 422 | # Reproject the 2D coordinates to 3D coordinates 423 | coordinates_3D = (K.inverse() @ coordinates.T).T 424 | 425 | # Multiply by depth 426 | coordinates_3D *= propagated_depth.view(-1, 1) 427 | 428 | # convert to the world coordinate 429 | world_coordinates_3D = (cam2world[:3, :3] @ coordinates_3D.T).T + cam2world[:3, 3] 430 | 431 | # import open3d as o3d 432 | # point_cloud = o3d.geometry.PointCloud() 433 | # point_cloud.points = o3d.utility.Vector3dVector(world_coordinates_3D.detach().cpu().numpy()) 434 | # o3d.io.write_point_cloud("partpc.ply", point_cloud) 435 | # exit() 436 | 437 | #mask the points below the confidence threshold 438 | #downsample the pixels; 1/4 439 | world_coordinates_3D = world_coordinates_3D.view(height, width, 3) 440 | world_coordinates_3D_downsampled = world_coordinates_3D[::8, ::8] 441 | filter_mask_downsampled = filter_mask[::8, ::8] 442 | gt_image_downsampled = gt_image.permute(1, 2, 0)[::8, ::8] 443 | 444 | world_coordinates_3D_downsampled = world_coordinates_3D_downsampled[filter_mask_downsampled] 445 | color_downsampled = gt_image_downsampled[filter_mask_downsampled] 446 | 447 | # initialize gaussians 448 | fused_point_cloud = world_coordinates_3D_downsampled 449 | fused_color = RGB2SH(color_downsampled) 450 | features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).to(fused_color.device) 451 | features[:, :3, 0 ] = fused_color 452 | features[:, 3:, 1:] = 0.0 453 | 454 | original_point_cloud = self.get_xyz 455 | # initialize the scale from the mode, if using the distance to calculate, there are outliers, if using the whole gaussians, it is memory consuming 456 | # quantile_scale = torch.quantile(self.get_scaling, 0.5, dim=0) 457 | # scales = self.scaling_inverse_activation(quantile_scale.unsqueeze(0).repeat(fused_point_cloud.shape[0], 1)) 458 | fused_shape = 
fused_point_cloud.shape[0] 459 | all_point_cloud = torch.concat([fused_point_cloud, original_point_cloud], dim=0) 460 | all_dist2 = torch.clamp_min(distCUDA2(all_point_cloud), 0.0000001) 461 | dist2 = all_dist2[:fused_shape] 462 | scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) 463 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") 464 | rots[:, 0] = 1 465 | 466 | opacities = inverse_sigmoid(1.0 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) 467 | 468 | new_xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) 469 | new_features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) 470 | new_features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) 471 | new_scaling = nn.Parameter(scales.requires_grad_(True)) 472 | new_rotation = nn.Parameter(rots.requires_grad_(True)) 473 | new_opacity = nn.Parameter(opacities.requires_grad_(True)) 474 | 475 | #update gaussians 476 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation) -------------------------------------------------------------------------------- /scripts/demo.sh: -------------------------------------------------------------------------------- 1 | python train.py -s $path/to/data$ -m $save_path$ \ 2 | --eval --position_lr_init 0.000016 --scaling_lr 0.001 --percent_dense 0.0005 --port 1021 3 | 4 | python render.py -m $save_path$ 5 | python metrics.py -m $save_path$ 6 | 7 | python train.py -s $path/to/data$ -m $save_path$ \ 8 | --eval --flatten_loss --position_lr_init 0.000016 --scaling_lr 0.001 --percent_dense 0.0005 --port 1021 \ 9 | --normal_loss --depth_loss --propagation_interval 50 --depth_error_min_threshold 0.8 --depth_error_max_threshold 1.0 \ 10 | --propagated_iteration_begin 1000 --propagated_iteration_after 6000 --patch_size 20 --lambda_l1_normal 0.001 --lambda_cos_normal 0.001 11 | 12 | python render.py -m $save_path$ 13 | python metrics.py -m $save_path$ 14 | 15 | # normal_loss -- whether using planar-constrained loss 16 | # depth_loss -- whether using propagation 17 | # propagation_interval -- the frequency for activating propagation 18 | # depth_error_min_threshold -- the final threshold of relative depth error between rendered depth and propagated depth for initializing new gaussians 19 | # depth_error_max_threshold -- the initial threshold of relative depth error between rendered depth and propagated depth for initializing new gaussians 20 | # patch size for patchmatching, make it bigger if your scenes are consisted of many large textureless planes, smaller otherwise 21 | # lambda_xx_normal normal loss weight 22 | -------------------------------------------------------------------------------- /scripts/waymo.sh: -------------------------------------------------------------------------------- 1 | python train.py -s $path/to/data$ -m $save_path$ \ 2 | --eval --position_lr_init 0.000016 --scaling_lr 0.001 --percent_dense 0.0005 --port 1021 --dataset waymo 3 | 4 | python render.py -m $save_path$ 5 | python metrics.py -m $save_path$ 6 | 7 | python train.py -s $path/to/data$ -m $save_path$ \ 8 | --eval --flatten_loss --position_lr_init 0.000016 --scaling_lr 0.001 --percent_dense 0.0005 --port 1021 --dataset waymo \ 9 | --sky_seg --normal_loss --depth_loss --propagation_interval 30 --depth_error_min_threshold 0.8 --depth_error_max_threshold 1.0 \ 10 | --propagated_iteration_begin 1000 --propagated_iteration_after 12000 
--patch_size 20 --lambda_l1_normal 0.001 --lambda_cos_normal 0.001 11 | 12 | python render.py -m $save_path$ 13 | python metrics.py -m $save_path$ 14 | -------------------------------------------------------------------------------- /submodules/Propagation/PatchMatch.cpp: -------------------------------------------------------------------------------- 1 | #include "PatchMatch.h" 2 | #include 3 | #include 4 | 5 | #include 6 | 7 | void StringAppendV(std::string* dst, const char* format, va_list ap) { 8 | // First try with a small fixed size buffer. 9 | static const int kFixedBufferSize = 1024; 10 | char fixed_buffer[kFixedBufferSize]; 11 | 12 | // It is possible for methods that use a va_list to invalidate 13 | // the data in it upon use. The fix is to make a copy 14 | // of the structure before using it and use that copy instead. 15 | va_list backup_ap; 16 | va_copy(backup_ap, ap); 17 | int result = vsnprintf(fixed_buffer, kFixedBufferSize, format, backup_ap); 18 | va_end(backup_ap); 19 | 20 | if (result < kFixedBufferSize) { 21 | if (result >= 0) { 22 | // Normal case - everything fits. 23 | dst->append(fixed_buffer, result); 24 | return; 25 | } 26 | 27 | #ifdef _MSC_VER 28 | // Error or MSVC running out of space. MSVC 8.0 and higher 29 | // can be asked about space needed with the special idiom below: 30 | va_copy(backup_ap, ap); 31 | result = vsnprintf(nullptr, 0, format, backup_ap); 32 | va_end(backup_ap); 33 | #endif 34 | 35 | if (result < 0) { 36 | // Just an error. 37 | return; 38 | } 39 | } 40 | 41 | // Increase the buffer size to the size requested by vsnprintf, 42 | // plus one for the closing \0. 43 | const int variable_buffer_size = result + 1; 44 | std::unique_ptr variable_buffer(new char[variable_buffer_size]); 45 | 46 | // Restore the va_list before we use it again. 47 | va_copy(backup_ap, ap); 48 | result = 49 | vsnprintf(variable_buffer.get(), variable_buffer_size, format, backup_ap); 50 | va_end(backup_ap); 51 | 52 | if (result >= 0 && result < variable_buffer_size) { 53 | dst->append(variable_buffer.get(), result); 54 | } 55 | } 56 | 57 | std::string StringPrintf(const char* format, ...) { 58 | va_list ap; 59 | va_start(ap, format); 60 | std::string result; 61 | StringAppendV(&result, format, ap); 62 | va_end(ap); 63 | return result; 64 | } 65 | 66 | void CudaSafeCall(const cudaError_t error, const std::string& file, 67 | const int line) { 68 | if (error != cudaSuccess) { 69 | std::cerr << StringPrintf("%s in %s at line %i", cudaGetErrorString(error), 70 | file.c_str(), line) 71 | << std::endl; 72 | exit(EXIT_FAILURE); 73 | } 74 | } 75 | 76 | void CudaCheckError(const char* file, const int line) { 77 | cudaError error = cudaGetLastError(); 78 | if (error != cudaSuccess) { 79 | std::cerr << StringPrintf("cudaCheckError() failed at %s:%i : %s", file, 80 | line, cudaGetErrorString(error)) 81 | << std::endl; 82 | exit(EXIT_FAILURE); 83 | } 84 | 85 | // More careful checking. However, this will affect performance. 86 | // Comment away if needed. 87 | error = cudaDeviceSynchronize(); 88 | if (cudaSuccess != error) { 89 | std::cerr << StringPrintf("cudaCheckError() with sync failed at %s:%i : %s", 90 | file, line, cudaGetErrorString(error)) 91 | << std::endl; 92 | std::cerr 93 | << "This error is likely caused by the graphics card timeout " 94 | "detection mechanism of your operating system. Please refer to " 95 | "the FAQ in the documentation on how to solve this problem." 
96 | << std::endl; 97 | exit(EXIT_FAILURE); 98 | } 99 | } 100 | 101 | PatchMatch::PatchMatch() {} 102 | 103 | PatchMatch::~PatchMatch() 104 | { 105 | delete[] plane_hypotheses_host; 106 | delete[] costs_host; 107 | 108 | for (int i = 0; i < num_images; ++i) { 109 | cudaDestroyTextureObject(texture_objects_host.images[i]); 110 | cudaFreeArray(cuArray[i]); 111 | } 112 | cudaFree(texture_objects_cuda); 113 | cudaFree(cameras_cuda); 114 | cudaFree(plane_hypotheses_cuda); 115 | cudaFree(costs_cuda); 116 | cudaFree(rand_states_cuda); 117 | cudaFree(selected_views_cuda); 118 | cudaFree(depths_cuda); 119 | 120 | if (params.geom_consistency) { 121 | for (int i = 0; i < num_images; ++i) { 122 | cudaDestroyTextureObject(texture_depths_host.images[i]); 123 | cudaFreeArray(cuDepthArray[i]); 124 | } 125 | cudaFree(texture_depths_cuda); 126 | } 127 | } 128 | 129 | Camera ReadCamera(torch::Tensor intrinsic, torch::Tensor pose, torch::Tensor depth_interval) 130 | { 131 | Camera camera; 132 | 133 | for (int i = 0; i < 3; ++i) { 134 | camera.R[3 * i + 0] = pose[i][0].item(); 135 | camera.R[3 * i + 1] = pose[i][1].item(); 136 | camera.R[3 * i + 2] = pose[i][2].item(); 137 | camera.t[i] = pose[i][3].item(); 138 | } 139 | 140 | for (int i = 0; i < 3; ++i) { 141 | camera.K[3 * i + 0] = intrinsic[i][0].item(); 142 | camera.K[3 * i + 1] = intrinsic[i][1].item(); 143 | camera.K[3 * i + 2] = intrinsic[i][2].item(); 144 | } 145 | 146 | camera.depth_min = depth_interval[0].item(); 147 | camera.depth_max = depth_interval[3].item(); 148 | 149 | return camera; 150 | } 151 | 152 | void RescaleImageAndCamera(torch::Tensor &src, torch::Tensor &dst, torch::Tensor &depth, Camera &camera) 153 | { 154 | const int cols = depth.size(1); 155 | const int rows = depth.size(0); 156 | 157 | if (cols == src.size(1) && rows == src.size(0)) { 158 | dst = src.clone(); 159 | return; 160 | } 161 | 162 | const float scale_x = cols / static_cast(src.size(1)); 163 | const float scale_y = rows / static_cast(src.size(0)); 164 | dst = torch::nn::functional::interpolate(src.unsqueeze(0), torch::nn::functional::InterpolateFuncOptions().size(std::vector({rows, cols})).mode(torch::kBilinear)).squeeze(0); 165 | 166 | camera.K[0] *= scale_x; 167 | camera.K[2] *= scale_x; 168 | camera.K[4] *= scale_y; 169 | camera.K[5] *= scale_y; 170 | camera.width = cols; 171 | camera.height = rows; 172 | } 173 | 174 | float3 Get3DPointonWorld(const int x, const int y, const float depth, const Camera camera) 175 | { 176 | float3 pointX; 177 | float3 tmpX; 178 | // Reprojection 179 | pointX.x = depth * (x - camera.K[2]) / camera.K[0]; 180 | pointX.y = depth * (y - camera.K[5]) / camera.K[4]; 181 | pointX.z = depth; 182 | 183 | // Rotation 184 | tmpX.x = camera.R[0] * pointX.x + camera.R[3] * pointX.y + camera.R[6] * pointX.z; 185 | tmpX.y = camera.R[1] * pointX.x + camera.R[4] * pointX.y + camera.R[7] * pointX.z; 186 | tmpX.z = camera.R[2] * pointX.x + camera.R[5] * pointX.y + camera.R[8] * pointX.z; 187 | 188 | // Transformation 189 | float3 C; 190 | C.x = -(camera.R[0] * camera.t[0] + camera.R[3] * camera.t[1] + camera.R[6] * camera.t[2]); 191 | C.y = -(camera.R[1] * camera.t[0] + camera.R[4] * camera.t[1] + camera.R[7] * camera.t[2]); 192 | C.z = -(camera.R[2] * camera.t[0] + camera.R[5] * camera.t[1] + camera.R[8] * camera.t[2]); 193 | pointX.x = tmpX.x + C.x; 194 | pointX.y = tmpX.y + C.y; 195 | pointX.z = tmpX.z + C.z; 196 | 197 | return pointX; 198 | } 199 | 200 | void ProjectonCamera(const float3 PointX, const Camera camera, float2 &point, float &depth) 201 | { 
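// Project a world-space point into a camera view: tmp = R * X + t moves the point into
// camera coordinates (R and K are stored row-major as 9-element arrays); the third row of
// K * tmp gives the depth used to dehomogenize the pixel coordinates computed below.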
202 | float3 tmp; 203 | tmp.x = camera.R[0] * PointX.x + camera.R[1] * PointX.y + camera.R[2] * PointX.z + camera.t[0]; 204 | tmp.y = camera.R[3] * PointX.x + camera.R[4] * PointX.y + camera.R[5] * PointX.z + camera.t[1]; 205 | tmp.z = camera.R[6] * PointX.x + camera.R[7] * PointX.y + camera.R[8] * PointX.z + camera.t[2]; 206 | 207 | depth = camera.K[6] * tmp.x + camera.K[7] * tmp.y + camera.K[8] * tmp.z; 208 | point.x = (camera.K[0] * tmp.x + camera.K[1] * tmp.y + camera.K[2] * tmp.z) / depth; 209 | point.y = (camera.K[3] * tmp.x + camera.K[4] * tmp.y + camera.K[5] * tmp.z) / depth; 210 | } 211 | 212 | float GetAngle(const torch::Tensor &v1, const torch::Tensor &v2) 213 | { 214 | float dot_product = v1[0].item() * v2[0].item() + v1[1].item() * v2[1].item() + v1[2].item() * v2[2].item(); 215 | float angle = acosf(dot_product); 216 | //if angle is not a number the dot product was 1 and thus the two vectors should be identical --> return 0 217 | if ( angle != angle ) 218 | return 0.0f; 219 | 220 | return angle; 221 | } 222 | 223 | void StoreColorPlyFileBinaryPointCloud (const std::string &plyFilePath, const std::vector &pc) 224 | { 225 | std::cout << "store 3D points to ply file" << std::endl; 226 | 227 | FILE *outputPly; 228 | outputPly=fopen(plyFilePath.c_str(), "wb"); 229 | 230 | /*write header*/ 231 | fprintf(outputPly, "ply\n"); 232 | fprintf(outputPly, "format binary_little_endian 1.0\n"); 233 | fprintf(outputPly, "element vertex %d\n",pc.size()); 234 | fprintf(outputPly, "property float x\n"); 235 | fprintf(outputPly, "property float y\n"); 236 | fprintf(outputPly, "property float z\n"); 237 | fprintf(outputPly, "property float nx\n"); 238 | fprintf(outputPly, "property float ny\n"); 239 | fprintf(outputPly, "property float nz\n"); 240 | fprintf(outputPly, "property uchar red\n"); 241 | fprintf(outputPly, "property uchar green\n"); 242 | fprintf(outputPly, "property uchar blue\n"); 243 | fprintf(outputPly, "end_header\n"); 244 | 245 | //write data 246 | #pragma omp parallel for 247 | for(size_t i = 0; i < pc.size(); i++) { 248 | const PointList &p = pc[i]; 249 | float3 X = p.coord; 250 | const float3 normal = p.normal; 251 | const float3 color = p.color; 252 | const char b_color = (int)color.x; 253 | const char g_color = (int)color.y; 254 | const char r_color = (int)color.z; 255 | 256 | if(!(X.x < FLT_MAX && X.x > -FLT_MAX) || !(X.y < FLT_MAX && X.y > -FLT_MAX) || !(X.z < FLT_MAX && X.z >= -FLT_MAX)){ 257 | X.x = 0.0f; 258 | X.y = 0.0f; 259 | X.z = 0.0f; 260 | } 261 | #pragma omp critical 262 | { 263 | fwrite(&X.x, sizeof(X.x), 1, outputPly); 264 | fwrite(&X.y, sizeof(X.y), 1, outputPly); 265 | fwrite(&X.z, sizeof(X.z), 1, outputPly); 266 | fwrite(&normal.x, sizeof(normal.x), 1, outputPly); 267 | fwrite(&normal.y, sizeof(normal.y), 1, outputPly); 268 | fwrite(&normal.z, sizeof(normal.z), 1, outputPly); 269 | fwrite(&r_color, sizeof(char), 1, outputPly); 270 | fwrite(&g_color, sizeof(char), 1, outputPly); 271 | fwrite(&b_color, sizeof(char), 1, outputPly); 272 | } 273 | 274 | } 275 | fclose(outputPly); 276 | } 277 | 278 | static float GetDisparity(const Camera &camera, const int2 &p, const float &depth) 279 | { 280 | float point3D[3]; 281 | point3D[0] = depth * (p.x - camera.K[2]) / camera.K[0]; 282 | point3D[1] = depth * (p.y - camera.K[5]) / camera.K[4]; 283 | point3D[2] = depth; 284 | 285 | return std::sqrt(point3D[0] * point3D[0] + point3D[1] * point3D[1] + point3D[2] * point3D[2]); 286 | } 287 | 288 | void PatchMatch::SetGeomConsistencyParams() 289 | { 290 | 
params.geom_consistency = true; 291 | params.max_iterations = 2; 292 | } 293 | 294 | void PatchMatch::InuputInitialization(torch::Tensor images_cuda, torch::Tensor intrinsics_cuda, torch::Tensor poses_cuda, 295 | torch::Tensor depth_cuda, torch::Tensor normal_cuda, torch::Tensor depth_intervals) 296 | { 297 | images.clear(); 298 | cameras.clear(); 299 | 300 | torch::Tensor image_color = images_cuda[0]; 301 | torch::Tensor image_float = torch::mean(image_color, /*dim=*/2, /*keepdim=*/true).squeeze(); 302 | image_float = image_float.to(torch::kFloat32); 303 | images.push_back(image_float); 304 | 305 | Camera camera = ReadCamera(intrinsics_cuda[0], poses_cuda[0], depth_intervals[0]); 306 | camera.height = image_float.size(0); 307 | camera.width = image_float.size(1); 308 | cameras.push_back(camera); 309 | 310 | torch::Tensor ref_depth = depth_cuda; 311 | depths.push_back(ref_depth); 312 | 313 | int num_src_images = images_cuda.size(0); 314 | for (int i = 1; i < num_src_images; ++i) { 315 | torch::Tensor src_image_color = images_cuda[i]; 316 | torch::Tensor src_image_float = torch::mean(src_image_color, /*dim=*/2, /*keepdim=*/true).squeeze(); 317 | src_image_float = src_image_float.to(torch::kFloat32); 318 | images.push_back(src_image_float); 319 | 320 | Camera camera = ReadCamera(intrinsics_cuda[i], poses_cuda[i], depth_intervals[i]); 321 | camera.height = src_image_float.size(0); 322 | camera.width = src_image_float.size(1); 323 | cameras.push_back(camera); 324 | } 325 | 326 | // Scale cameras and images 327 | for (size_t i = 0; i < images.size(); ++i) { 328 | if (images[i].size(1) <= params.max_image_size && images[i].size(0) <= params.max_image_size) { 329 | continue; 330 | } 331 | 332 | const float factor_x = static_cast(params.max_image_size) / images[i].size(1); 333 | const float factor_y = static_cast(params.max_image_size) / images[i].size(0); 334 | const float factor = std::min(factor_x, factor_y); 335 | 336 | const int new_cols = std::round(images[i].size(1) * factor); 337 | const int new_rows = std::round(images[i].size(0) * factor); 338 | 339 | const float scale_x = new_cols / static_cast(images[i].size(1)); 340 | const float scale_y = new_rows / static_cast(images[i].size(0)); 341 | 342 | torch::Tensor scaled_image_float = torch::nn::functional::interpolate(images[i].unsqueeze(0), torch::nn::functional::InterpolateFuncOptions().size(std::vector({new_rows, new_cols})).mode(torch::kBilinear)).squeeze(0); 343 | images[i] = scaled_image_float.clone(); 344 | 345 | cameras[i].K[0] *= scale_x; 346 | cameras[i].K[2] *= scale_x; 347 | cameras[i].K[4] *= scale_y; 348 | cameras[i].K[5] *= scale_y; 349 | cameras[i].height = scaled_image_float.size(0); 350 | cameras[i].width = scaled_image_float.size(1); 351 | } 352 | 353 | params.depth_min = cameras[0].depth_min * 0.6f; 354 | params.depth_max = cameras[0].depth_max * 1.2f; 355 | params.num_images = (int)images.size(); 356 | params.disparity_min = cameras[0].K[0] * params.baseline / params.depth_max; 357 | params.disparity_max = cameras[0].K[0] * params.baseline / params.depth_min; 358 | 359 | } 360 | 361 | void PatchMatch::CudaSpaceInitialization() 362 | { 363 | num_images = (int)images.size(); 364 | 365 | for (int i = 0; i < num_images; ++i) { 366 | int rows = images[i].size(0); 367 | int cols = images[i].size(1); 368 | 369 | cudaChannelFormatDesc channelDesc = cudaCreateChannelDesc(32, 0, 0, 0, cudaChannelFormatKindFloat); 370 | cudaMallocArray(&cuArray[i], &channelDesc, cols, rows); 371 | 372 | cudaMemcpy2DToArray(cuArray[i], 0, 0, 
images[i].data_ptr(), images[i].stride(0) * sizeof(float), cols * sizeof(float), rows, cudaMemcpyHostToDevice); 373 | 374 | struct cudaResourceDesc resDesc; 375 | memset(&resDesc, 0, sizeof(cudaResourceDesc)); 376 | resDesc.resType = cudaResourceTypeArray; 377 | resDesc.res.array.array = cuArray[i]; 378 | 379 | struct cudaTextureDesc texDesc; 380 | memset(&texDesc, 0, sizeof(cudaTextureDesc)); 381 | texDesc.addressMode[0] = cudaAddressModeWrap; 382 | texDesc.addressMode[1] = cudaAddressModeWrap; 383 | texDesc.filterMode = cudaFilterModeLinear; 384 | texDesc.readMode = cudaReadModeElementType; 385 | texDesc.normalizedCoords = 0; 386 | 387 | cudaCreateTextureObject(&(texture_objects_host.images[i]), &resDesc, &texDesc, NULL); 388 | } 389 | 390 | cudaMalloc((void**)&texture_objects_cuda, sizeof(cudaTextureObjects)); 391 | cudaMemcpy(texture_objects_cuda, &texture_objects_host, sizeof(cudaTextureObjects), cudaMemcpyHostToDevice); 392 | 393 | cudaMalloc((void**)&cameras_cuda, sizeof(Camera) * (num_images)); 394 | cudaMemcpy(cameras_cuda, &cameras[0], sizeof(Camera) * (num_images), cudaMemcpyHostToDevice); 395 | 396 | int total_pixels = cameras[0].height * cameras[0].width; 397 | plane_hypotheses_host = new float4[total_pixels]; 398 | cudaMalloc((void**)&plane_hypotheses_cuda, sizeof(float4) * total_pixels); 399 | 400 | costs_host = new float[cameras[0].height * cameras[0].width]; 401 | cudaMalloc((void**)&costs_cuda, sizeof(float) * (cameras[0].height * cameras[0].width)); 402 | 403 | cudaMalloc((void**)&rand_states_cuda, sizeof(curandState) * (cameras[0].height * cameras[0].width)); 404 | cudaMalloc((void**)&selected_views_cuda, sizeof(unsigned int) * (cameras[0].height * cameras[0].width)); 405 | 406 | cudaMalloc((void**)&depths_cuda, sizeof(float) * (cameras[0].height * cameras[0].width)); 407 | cudaMemcpy(depths_cuda, depths[0].data_ptr(), sizeof(float) * cameras[0].height * cameras[0].width, cudaMemcpyHostToDevice); 408 | } 409 | 410 | int PatchMatch::GetReferenceImageWidth() 411 | { 412 | return cameras[0].width; 413 | } 414 | 415 | int PatchMatch::GetReferenceImageHeight() 416 | { 417 | return cameras[0].height; 418 | } 419 | 420 | torch::Tensor PatchMatch::GetReferenceImage() 421 | { 422 | return images[0]; 423 | } 424 | 425 | float4 PatchMatch::GetPlaneHypothesis(const int index) 426 | { 427 | return plane_hypotheses_host[index]; 428 | } 429 | 430 | float4* PatchMatch::GetPlaneHypotheses() 431 | { 432 | return plane_hypotheses_host; 433 | } 434 | 435 | float PatchMatch::GetCost(const int index) 436 | { 437 | return costs_host[index]; 438 | } 439 | 440 | void PatchMatch::SetPatchSize(int patch_size) 441 | { 442 | params.patch_size = patch_size; 443 | } 444 | 445 | int PatchMatch::GetPatchSize() 446 | { 447 | return params.patch_size; 448 | } 449 | 450 | 451 | -------------------------------------------------------------------------------- /submodules/Propagation/PatchMatch.h: -------------------------------------------------------------------------------- 1 | #ifndef _PatchMatch_H_ 2 | #define _PatchMatch_H_ 3 | 4 | #include "main.h" 5 | #include 6 | 7 | Camera ReadCamera(torch::Tensor intrinsic, torch::Tensor pose, torch::Tensor depth_interval); 8 | void RescaleImageAndCamera(torch::Tensor &src, torch::Tensor &dst, torch::Tensor &depth, Camera &camera); 9 | float3 Get3DPointonWorld(const int x, const int y, const float depth, const Camera camera); 10 | void ProjectonCamera(const float3 PointX, const Camera camera, float2 &point, float &depth); 11 | float GetAngle(const torch::Tensor 
&v1, const torch::Tensor &v2); 12 | void StoreColorPlyFileBinaryPointCloud(const std::string &plyFilePath, const std::vector &pc); 13 | 14 | #define CUDA_SAFE_CALL(error) CudaSafeCall(error, __FILE__, __LINE__) 15 | #define CUDA_CHECK_ERROR() CudaCheckError(__FILE__, __LINE__) 16 | 17 | void CudaSafeCall(const cudaError_t error, const std::string& file, const int line); 18 | void CudaCheckError(const char* file, const int line); 19 | 20 | struct cudaTextureObjects { 21 | cudaTextureObject_t images[MAX_IMAGES]; 22 | }; 23 | 24 | struct PatchMatchParams { 25 | int max_iterations = 6; 26 | int patch_size = 11; 27 | int num_images = 5; 28 | int max_image_size=3200; 29 | int radius_increment = 2; 30 | float sigma_spatial = 5.0f; 31 | float sigma_color = 3.0f; 32 | int top_k = 4; 33 | float baseline = 0.54f; 34 | float depth_min = 0.0f; 35 | float depth_max = 1.0f; 36 | float disparity_min = 0.0f; 37 | float disparity_max = 1.0f; 38 | bool geom_consistency = false; 39 | }; 40 | 41 | class PatchMatch { 42 | public: 43 | PatchMatch(); 44 | ~PatchMatch(); 45 | 46 | void InuputInitialization(torch::Tensor images_cuda, torch::Tensor intrinsics_cuda, torch::Tensor poses_cuda, torch::Tensor depth_cuda, torch::Tensor normal_cuda, torch::Tensor depth_intervals); 47 | void Colmap2MVS(const std::string &dense_folder, std::vector &problems); 48 | void CudaSpaceInitialization(); 49 | void RunPatchMatch(); 50 | void SetGeomConsistencyParams(); 51 | void SetPatchSize(int patch_size); 52 | int GetPatchSize(); 53 | int GetReferenceImageWidth(); 54 | int GetReferenceImageHeight(); 55 | torch::Tensor GetReferenceImage(); 56 | float4 GetPlaneHypothesis(const int index); 57 | float GetCost(const int index); 58 | float4* GetPlaneHypotheses(); 59 | 60 | private: 61 | int num_images; 62 | std::vector images; 63 | std::vector depths; 64 | std::vector cameras; 65 | cudaTextureObjects texture_objects_host; 66 | cudaTextureObjects texture_depths_host; 67 | float4 *plane_hypotheses_host; 68 | float *costs_host; 69 | PatchMatchParams params; 70 | 71 | Camera *cameras_cuda; 72 | cudaArray *cuArray[MAX_IMAGES]; 73 | cudaArray *cuDepthArray[MAX_IMAGES]; 74 | cudaTextureObjects *texture_objects_cuda; 75 | cudaTextureObjects *texture_depths_cuda; 76 | float4 *plane_hypotheses_cuda; 77 | float *costs_cuda; 78 | curandState *rand_states_cuda; 79 | unsigned int *selected_views_cuda; 80 | float *depths_cuda; 81 | }; 82 | 83 | #endif // _PatchMatch_H_ 84 | -------------------------------------------------------------------------------- /submodules/Propagation/main.h: -------------------------------------------------------------------------------- 1 | #ifndef _MAIN_H_ 2 | #define _MAIN_H_ 3 | 4 | // Includes CUDA 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include "iomanip" 21 | 22 | #include // mkdir 23 | #include // mkdir 24 | 25 | #define MAX_IMAGES 256 26 | 27 | struct Camera { 28 | float K[9]; 29 | float R[9]; 30 | float t[3]; 31 | int height; 32 | int width; 33 | float depth_min; 34 | float depth_max; 35 | }; 36 | 37 | struct Problem { 38 | int ref_image_id; 39 | std::vector src_image_ids; 40 | }; 41 | 42 | struct PointList { 43 | float3 coord; 44 | float3 normal; 45 | float3 color; 46 | }; 47 | 48 | #endif // _MAIN_H_ 49 | -------------------------------------------------------------------------------- /submodules/Propagation/pro.cpp: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | torch::Tensor propagate_cuda( 5 | torch::Tensor images, 6 | torch::Tensor intrinsics, 7 | torch::Tensor poses, 8 | torch::Tensor depth, 9 | torch::Tensor normal, 10 | torch::Tensor depth_intervals, 11 | int patch_size); 12 | 13 | torch::Tensor propagate( 14 | torch::Tensor images, 15 | torch::Tensor intrinsics, 16 | torch::Tensor poses, 17 | torch::Tensor depth, 18 | torch::Tensor normal, 19 | torch::Tensor depth_intervals, 20 | int patch_size) { 21 | 22 | return propagate_cuda(images, intrinsics, poses, depth, normal, depth_intervals, patch_size); 23 | } 24 | 25 | 26 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 27 | // bundle adjustment kernels 28 | m.def("propagate", &propagate, "plane propagation"); 29 | } -------------------------------------------------------------------------------- /submodules/Propagation/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | import os.path as osp 5 | ROOT = osp.dirname(osp.abspath(__file__)) 6 | 7 | setup( 8 | name='gaussianpro', 9 | ext_modules=[ 10 | CUDAExtension('gaussianpro', 11 | sources=[ 12 | 'PatchMatch.cpp', 13 | 'Propagation.cu', 14 | 'pro.cpp' 15 | ], 16 | extra_compile_args={ 17 | 'cxx': ['-O3'], 18 | 'nvcc': ['-O3', 19 | '-gencode=arch=compute_86,code=sm_86', 20 | ] 21 | }), 22 | ], 23 | cmdclass={ 'build_ext' : BuildExtension } 24 | ) 25 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from utils.loss_utils import l1_loss, ssim, compute_scale_and_shift, ScaleAndShiftInvariantLoss 16 | from utils.general_utils import vis_depth, read_propagted_depth 17 | from gaussian_renderer import render, network_gui 18 | from utils.graphics_utils import surface_normal_from_depth, img_warping, depth_propagation, check_geometric_consistency, generate_edge_mask 19 | import sys 20 | from scene import Scene, GaussianModel 21 | from utils.general_utils import safe_state, load_pairs_relation 22 | import uuid 23 | from tqdm import tqdm 24 | from utils.image_utils import psnr 25 | from argparse import ArgumentParser, Namespace 26 | from arguments import ModelParams, PipelineParams, OptimizationParams 27 | import imageio 28 | import numpy as np 29 | import torchvision 30 | import cv2 31 | try: 32 | from torch.utils.tensorboard import SummaryWriter 33 | TENSORBOARD_FOUND = True 34 | except ImportError: 35 | TENSORBOARD_FOUND = False 36 | 37 | 38 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 39 | first_iter = 0 40 | tb_writer = prepare_output_and_logger(dataset) 41 | gaussians = GaussianModel(dataset.sh_degree) 42 | scene = Scene(dataset, gaussians) 43 | 44 | #read the overlapping txt 45 | if opt.dataset == '360' and opt.depth_loss: 46 | pairs = load_pairs_relation(opt.pair_path) 47 | 48 | gaussians.training_setup(opt) 49 | if checkpoint: 50 | (model_params, first_iter) = torch.load(checkpoint) 51 | gaussians.restore(model_params, opt) 52 | 53 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 54 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 55 | 56 | iter_start = torch.cuda.Event(enable_timing = True) 57 | iter_end = torch.cuda.Event(enable_timing = True) 58 | 59 | viewpoint_stack = scene.getTrainCameras().copy() 60 | ema_loss_for_log = 0.0 61 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 62 | first_iter += 1 63 | # depth_loss_fn = ScaleAndShiftInvariantLoss(alpha=0.1, scales=1) 64 | propagated_iteration_begin = opt.propagated_iteration_begin 65 | propagated_iteration_after = opt.propagated_iteration_after 66 | after_propagated = False 67 | propagation_dict = {} 68 | for i in range(0, len(viewpoint_stack), 1): 69 | propagation_dict[viewpoint_stack[i].image_name] = False 70 | 71 | for iteration in range(first_iter, opt.iterations + 1): 72 | if network_gui.conn == None: 73 | network_gui.try_connect() 74 | while network_gui.conn != None: 75 | try: 76 | net_image_bytes = None 77 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 78 | if custom_cam != None: 79 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] 80 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 81 | network_gui.send(net_image_bytes, dataset.source_path) 82 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 83 | break 84 | except Exception as e: 85 | network_gui.conn = None 86 | 87 | iter_start.record() 88 | 89 | gaussians.update_learning_rate(iteration) 90 | 91 | # Every 1000 its we increase the levels of SH up to a maximum degree 92 | if iteration % 1000 == 0: 93 | gaussians.oneupSHdegree() 94 | 95 | # Pick a random 
Camera 96 | # if not viewpoint_stack: 97 | # viewpoint_stack = scene.getTrainCameras().copy() 98 | randidx = randint(0, len(viewpoint_stack)-1) 99 | # if iteration > propagated_iteration_begin and iteration < propagated_iteration_after and after_propagated: 100 | # randidx = propagated_view_index 101 | viewpoint_cam = viewpoint_stack[randidx] 102 | 103 | if opt.depth_loss: 104 | if opt.dataset == '360': 105 | src_idxs = pairs[randidx] 106 | else: 107 | # intervals = [-6, -3, 3, 6] 108 | if opt.dataset == 'waymo': 109 | intervals = [-2, -1, 1, 2] 110 | elif opt.dataset == 'scannet': 111 | intervals = [-10, -5, 5, 10] 112 | elif opt.dataset == 'free': 113 | intervals = [-2, -1, 1, 2] 114 | src_idxs = [randidx+itv for itv in intervals if ((itv + randidx > 0) and (itv + randidx < len(viewpoint_stack)))] 115 | 116 | #propagate the gaussians first 117 | with torch.no_grad(): 118 | if opt.depth_loss and iteration > propagated_iteration_begin and iteration < propagated_iteration_after and (iteration % opt.propagation_interval == 0 and not propagation_dict[viewpoint_cam.image_name]): 119 | # if opt.depth_loss and iteration > propagated_iteration_begin and iteration < propagated_iteration_after and (iteration % opt.propagation_interval == 0): 120 | propagation_dict[viewpoint_cam.image_name] = True 121 | 122 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg, 123 | return_normal=opt.normal_loss, return_opacity=False, return_depth=opt.depth_loss or opt.depth2normal_loss) 124 | 125 | projected_depth = render_pkg["render_depth"] 126 | 127 | # get the opacity that less than the threshold, propagate depth in these region 128 | if viewpoint_cam.sky_mask is not None: 129 | sky_mask = viewpoint_cam.sky_mask.to(opacity_mask.device).to(torch.bool) 130 | else: 131 | sky_mask = None 132 | torchvision.utils.save_image(viewpoint_cam.original_image, "cost/"+viewpoint_cam.image_name+"_"+str(iteration)+"gt.png") 133 | 134 | # get the propagated depth 135 | propagated_depth, normal = depth_propagation(viewpoint_cam, projected_depth, viewpoint_stack, src_idxs, opt.dataset, opt.patch_size) 136 | 137 | # cache the propagated_depth 138 | viewpoint_cam.depth = propagated_depth 139 | 140 | #transform normal to camera coordinate 141 | R_w2c = torch.tensor(viewpoint_cam.R.T).cuda().to(torch.float32) 142 | # R_w2c[:, 1:] *= -1 143 | normal = (R_w2c @ normal.view(-1, 3).permute(1, 0)).view(3, viewpoint_cam.image_height, viewpoint_cam.image_width) 144 | valid_mask = propagated_depth != 300 145 | 146 | # calculate the abs rel depth error of the propagated depth and rendered depth & render color error 147 | render_depth = render_pkg['render_depth'] 148 | abs_rel_error = torch.abs(propagated_depth - render_depth) / propagated_depth 149 | abs_rel_error_threshold = opt.depth_error_max_threshold - (opt.depth_error_max_threshold - opt.depth_error_min_threshold) * (iteration - propagated_iteration_begin) / (propagated_iteration_after - propagated_iteration_begin) 150 | # color error 151 | render_color = render_pkg['render'] 152 | torchvision.utils.save_image(render_color, "cost/"+viewpoint_cam.image_name+"_"+str(iteration)+"color.png") 153 | 154 | color_error = torch.abs(render_color - viewpoint_cam.original_image) 155 | color_error = color_error.mean(dim=0).squeeze() 156 | error_mask = (abs_rel_error > abs_rel_error_threshold) 157 | 158 | # # calculate the photometric consistency 159 | ref_K = viewpoint_cam.K 160 | #c2w 161 | ref_pose = viewpoint_cam.world_view_transform.transpose(0, 1).inverse() 162 | 163 | # calculate the 
geometric consistency 164 | geometric_counts = None 165 | for idx, src_idx in enumerate(src_idxs): 166 | src_viewpoint = viewpoint_stack[src_idx] 167 | #c2w 168 | src_pose = src_viewpoint.world_view_transform.transpose(0, 1).inverse() 169 | src_K = src_viewpoint.K 170 | 171 | if src_viewpoint.depth is None: 172 | src_render_pkg = render(src_viewpoint, gaussians, pipe, bg, 173 | return_normal=opt.normal_loss, return_opacity=False, return_depth=opt.depth_loss or opt.depth2normal_loss) 174 | src_projected_depth = src_render_pkg['render_depth'] 175 | 176 | #get the src_depth first 177 | src_depth, src_normal = depth_propagation(src_viewpoint, src_projected_depth, viewpoint_stack, src_idxs, opt.dataset, opt.patch_size) 178 | src_viewpoint.depth = src_depth 179 | else: 180 | src_depth = src_viewpoint.depth 181 | 182 | mask, depth_reprojected, x2d_src, y2d_src, relative_depth_diff = check_geometric_consistency(propagated_depth.unsqueeze(0), ref_K.unsqueeze(0), 183 | ref_pose.unsqueeze(0), src_depth.unsqueeze(0), 184 | src_K.unsqueeze(0), src_pose.unsqueeze(0), thre1=2, thre2=0.01) 185 | 186 | if geometric_counts is None: 187 | geometric_counts = mask.to(torch.uint8) 188 | else: 189 | geometric_counts += mask.to(torch.uint8) 190 | 191 | cost = geometric_counts.squeeze() 192 | cost_mask = cost >= 2 193 | 194 | normal[~(cost_mask.unsqueeze(0).repeat(3, 1, 1))] = -10 195 | viewpoint_cam.normal = normal 196 | 197 | propagated_mask = valid_mask & error_mask & cost_mask 198 | if sky_mask is not None: 199 | propagated_mask = propagated_mask & sky_mask 200 | 201 | propagated_depth[~cost_mask] = 300 202 | # propagated_mask = propagated_mask & edge_mask 203 | propagated_depth[~propagated_mask] = 300 204 | 205 | if propagated_mask.sum() > 100: 206 | gaussians.densify_from_depth_propagation(viewpoint_cam, propagated_depth, propagated_mask.to(torch.bool), gt_image) 207 | 208 | # Render 209 | if (iteration - 1) == debug_from: 210 | pipe.debug = True 211 | 212 | bg = torch.rand((3), device="cuda") if opt.random_background else background 213 | 214 | #render_pkg = render(viewpoint_cam, gaussians, pipe, bg, return_normal=args.normal_loss) 215 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg, 216 | return_normal=opt.normal_loss, return_opacity=True, return_depth=opt.depth_loss or opt.depth2normal_loss) 217 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 218 | 219 | # opacity mask 220 | if iteration < opt.propagated_iteration_begin and opt.depth_loss: 221 | opacity_mask = render_pkg['render_opacity'] > 0.999 222 | opacity_mask = opacity_mask.unsqueeze(0).repeat(3, 1, 1) 223 | else: 224 | opacity_mask = render_pkg['render_opacity'] > 0.0 225 | opacity_mask = opacity_mask.unsqueeze(0).repeat(3, 1, 1) 226 | 227 | # Loss 228 | gt_image = viewpoint_cam.original_image.cuda() 229 | Ll1 = l1_loss(image[opacity_mask], gt_image[opacity_mask]) 230 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image, mask=opacity_mask)) 231 | 232 | # flatten loss 233 | if opt.flatten_loss: 234 | scales = gaussians.get_scaling 235 | min_scale, _ = torch.min(scales, dim=1) 236 | min_scale = torch.clamp(min_scale, 0, 30) 237 | flatten_loss = torch.abs(min_scale).mean() 238 | loss += opt.lambda_flatten * flatten_loss 239 | 240 | # opacity loss 241 | if opt.sparse_loss: 242 | opacity = gaussians.get_opacity 243 | opacity = opacity.clamp(1e-6, 1-1e-6) 244 | log_opacity = opacity * 
torch.log(opacity) 245 | log_one_minus_opacity = (1-opacity) * torch.log(1 - opacity) 246 | sparse_loss = -1 * (log_opacity + log_one_minus_opacity)[visibility_filter].mean() 247 | loss += opt.lambda_sparse * sparse_loss 248 | 249 | if opt.normal_loss: 250 | rendered_normal = render_pkg['render_normal'] 251 | if viewpoint_cam.normal is not None: 252 | normal_gt = viewpoint_cam.normal.cuda() 253 | if viewpoint_cam.sky_mask is not None: 254 | filter_mask = viewpoint_cam.sky_mask.to(normal_gt.device).to(torch.bool) 255 | normal_gt[~(filter_mask.unsqueeze(0).repeat(3, 1, 1))] = -10 256 | filter_mask = (normal_gt != -10)[0, :, :].to(torch.bool) 257 | 258 | l1_normal = torch.abs(rendered_normal - normal_gt).sum(dim=0)[filter_mask].mean() 259 | cos_normal = (1. - torch.sum(rendered_normal * normal_gt, dim = 0))[filter_mask].mean() 260 | loss += opt.lambda_l1_normal * l1_normal + opt.lambda_cos_normal * cos_normal 261 | 262 | loss.backward() 263 | iter_end.record() 264 | 265 | with torch.no_grad(): 266 | # Progress bar 267 | if not torch.isnan(loss): 268 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 269 | if iteration % 10 == 0: 270 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 271 | progress_bar.update(10) 272 | if iteration == opt.iterations: 273 | progress_bar.close() 274 | 275 | # Log and save 276 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) 277 | if (iteration in saving_iterations): 278 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 279 | scene.save(iteration) 280 | 281 | # Densification 282 | if iteration < opt.densify_until_iter: 283 | # Keep track of max radii in image-space for pruning 284 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 285 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 286 | 287 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 288 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 289 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) 290 | 291 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): 292 | gaussians.reset_opacity() 293 | 294 | # Optimizer step 295 | if iteration < opt.iterations: 296 | gaussians.optimizer.step() 297 | gaussians.optimizer.zero_grad(set_to_none = True) 298 | 299 | if (iteration in checkpoint_iterations): 300 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 301 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 302 | 303 | def prepare_output_and_logger(args): 304 | if not args.model_path: 305 | if os.getenv('OAR_JOB_ID'): 306 | unique_str=os.getenv('OAR_JOB_ID') 307 | else: 308 | unique_str = str(uuid.uuid4()) 309 | args.model_path = os.path.join("./output/", unique_str[0:10]) 310 | 311 | # Set up output folder 312 | print("Output folder: {}".format(args.model_path)) 313 | os.makedirs(args.model_path, exist_ok = True) 314 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 315 | cfg_log_f.write(str(Namespace(**vars(args)))) 316 | 317 | # Create Tensorboard writer 318 | tb_writer = None 319 | if TENSORBOARD_FOUND: 320 | tb_writer = SummaryWriter(args.model_path) 321 | else: 322 | print("Tensorboard not available: not logging 
progress") 323 | return tb_writer 324 | 325 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): 326 | if tb_writer: 327 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 328 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 329 | tb_writer.add_scalar('iter_time', elapsed, iteration) 330 | 331 | # Report test and samples of training set 332 | if iteration in testing_iterations: 333 | torch.cuda.empty_cache() 334 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 335 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 336 | 337 | for config in validation_configs: 338 | if config['cameras'] and len(config['cameras']) > 0: 339 | l1_test = 0.0 340 | psnr_test = 0.0 341 | for idx, viewpoint in enumerate(config['cameras']): 342 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0) 343 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 344 | if tb_writer and (idx < 5): 345 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) 346 | if iteration == testing_iterations[0]: 347 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) 348 | l1_test += l1_loss(image, gt_image).mean().double() 349 | psnr_test += psnr(image, gt_image).mean().double() 350 | psnr_test /= len(config['cameras']) 351 | l1_test /= len(config['cameras']) 352 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 353 | if tb_writer: 354 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 355 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 356 | 357 | if tb_writer: 358 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 359 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 360 | torch.cuda.empty_cache() 361 | 362 | if __name__ == "__main__": 363 | # Set up command line argument parser 364 | parser = ArgumentParser(description="Training script parameters") 365 | lp = ModelParams(parser) 366 | op = OptimizationParams(parser) 367 | pp = PipelineParams(parser) 368 | parser.add_argument('--ip', type=str, default="127.0.0.1") 369 | parser.add_argument('--port', type=int, default=6009) 370 | parser.add_argument('--debug_from', type=int, default=-1) 371 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 372 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[1, 2000, 7000, 20000, 50000]) 373 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[1, 7000, 20000, 50000]) 374 | parser.add_argument("--quiet", action="store_true") 375 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 376 | parser.add_argument("--start_checkpoint", type=str, default = None) 377 | 378 | args = parser.parse_args(sys.argv[1:]) 379 | args.save_iterations.append(args.iterations) 380 | 381 | print("Optimizing " + args.model_path) 382 | 383 | # Initialize system state (RNG) 384 | safe_state(args.quiet) 385 | 386 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 387 | training(lp.extract(args), op.extract(args), pp.extract(args), 
args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 388 | 389 | # All done 390 | print("\nTraining complete.") 391 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch 15 | from utils.graphics_utils import fov2focal 16 | import cv2 17 | import torch 18 | 19 | WARNED = False 20 | 21 | def loadCam(args, id, cam_info, resolution_scale): 22 | orig_w, orig_h = cam_info.image.size 23 | 24 | if args.resolution in [1, 2, 4, 8]: 25 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 26 | K = cam_info.K / (resolution_scale * args.resolution) 27 | else: # should be a type that converts to float 28 | if args.resolution == -1: 29 | if orig_w > 1600: 30 | global WARNED 31 | if not WARNED: 32 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 33 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 34 | WARNED = True 35 | global_down = orig_w / 1600 36 | else: 37 | global_down = 1 38 | else: 39 | global_down = orig_w / args.resolution 40 | 41 | scale = float(global_down) * float(resolution_scale) 42 | resolution = (int(orig_w / scale), int(orig_h / scale)) 43 | K = cam_info.K / scale 44 | 45 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 46 | if cam_info.sky_mask is not None: 47 | resized_sky_mask = torch.tensor(cv2.resize(cam_info.sky_mask, resolution, interpolation=cv2.INTER_NEAREST)).to(resized_image_rgb.device) 48 | else: 49 | resized_sky_mask = None 50 | if cam_info.normal is not None: 51 | resized_normal = torch.tensor(cv2.resize(cam_info.normal.transpose((1, 2, 0)), resolution, interpolation=cv2.INTER_NEAREST)).to(resized_image_rgb.device) 52 | resized_normal = resized_normal.permute((2, 0, 1)) 53 | else: 54 | resized_normal = None 55 | 56 | if cam_info.depth is not None: 57 | resized_depth = torch.tensor(cv2.resize(cam_info.depth.squeeze(), resolution, interpolation=cv2.INTER_NEAREST)).to(resized_image_rgb.device) 58 | else: 59 | resized_depth = None 60 | 61 | gt_image = resized_image_rgb[:3, ...] 62 | loaded_mask = None 63 | 64 | if resized_image_rgb.shape[1] == 4: 65 | loaded_mask = resized_image_rgb[3:4, ...] 
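    # When the resized image carries a fourth (alpha) channel, it is kept as the ground-truth
    # alpha mask; the Camera below is then built with the resized sky mask, normal and depth maps attached.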
66 | 67 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 68 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 69 | image=gt_image, gt_alpha_mask=loaded_mask, 70 | image_name=cam_info.image_name, uid=id, data_device=args.data_device, K=K, 71 | sky_mask=resized_sky_mask, normal=resized_normal, depth=resized_depth) 72 | 73 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 74 | camera_list = [] 75 | 76 | for id, c in enumerate(cam_infos): 77 | camera_list.append(loadCam(args, id, c, resolution_scale)) 78 | 79 | return camera_list 80 | 81 | def camera_to_JSON(id, camera : Camera): 82 | Rt = np.zeros((4, 4)) 83 | Rt[:3, :3] = camera.R.transpose() 84 | Rt[:3, 3] = camera.T 85 | Rt[3, 3] = 1.0 86 | 87 | W2C = np.linalg.inv(Rt) 88 | pos = W2C[:3, 3] 89 | rot = W2C[:3, :3] 90 | serializable_array_2d = [x.tolist() for x in rot] 91 | camera_entry = { 92 | 'id' : id, 93 | 'img_name' : camera.image_name, 94 | 'width' : camera.width, 95 | 'height' : camera.height, 96 | 'position': pos.tolist(), 97 | 'rotation': serializable_array_2d, 98 | 'fy' : fov2focal(camera.FovY, camera.height), 99 | 'fx' : fov2focal(camera.FovX, camera.width) 100 | } 101 | return camera_entry 102 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | import cv2 18 | import matplotlib.pyplot as plt 19 | from matplotlib import cm 20 | import matplotlib as mpl 21 | import struct 22 | import os 23 | 24 | def inverse_sigmoid(x): 25 | return torch.log(x/(1-x)) 26 | 27 | def PILtoTorch(pil_image, resolution): 28 | resized_image_PIL = pil_image.resize(resolution) 29 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 30 | if len(resized_image.shape) == 3: 31 | return resized_image.permute(2, 0, 1) 32 | else: 33 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 34 | 35 | def get_expon_lr_func( 36 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 37 | ): 38 | """ 39 | Copied from Plenoxels 40 | 41 | Continuous learning rate decay function. Adapted from JaxNeRF 42 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 43 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 44 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 45 | function of lr_delay_mult, such that the initial learning rate is 46 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 47 | to the normal learning rate when steps>lr_delay_steps. 48 | :param conf: config subtree 'lr' or similar 49 | :param max_steps: int, the number of steps during optimization. 50 | :return HoF which takes step as input 51 | """ 52 | 53 | def helper(step): 54 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 55 | # Disable this parameter 56 | return 0.0 57 | if lr_delay_steps > 0: 58 | # A kind of reverse cosine decay. 
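            # delay_rate ramps smoothly from lr_delay_mult at step 0 up to 1 once step >= lr_delay_steps.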
59 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 60 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 61 | ) 62 | else: 63 | delay_rate = 1.0 64 | t = np.clip(step / max_steps, 0, 1) 65 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 66 | return delay_rate * log_lerp 67 | 68 | return helper 69 | 70 | def strip_lowerdiag(L): 71 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 72 | 73 | uncertainty[:, 0] = L[:, 0, 0] 74 | uncertainty[:, 1] = L[:, 0, 1] 75 | uncertainty[:, 2] = L[:, 0, 2] 76 | uncertainty[:, 3] = L[:, 1, 1] 77 | uncertainty[:, 4] = L[:, 1, 2] 78 | uncertainty[:, 5] = L[:, 2, 2] 79 | return uncertainty 80 | 81 | def strip_symmetric(sym): 82 | return strip_lowerdiag(sym) 83 | 84 | def build_rotation(r): 85 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 86 | 87 | q = r / norm[:, None] 88 | 89 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 90 | 91 | r = q[:, 0] 92 | x = q[:, 1] 93 | y = q[:, 2] 94 | z = q[:, 3] 95 | 96 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 97 | R[:, 0, 1] = 2 * (x*y - r*z) 98 | R[:, 0, 2] = 2 * (x*z + r*y) 99 | R[:, 1, 0] = 2 * (x*y + r*z) 100 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 101 | R[:, 1, 2] = 2 * (y*z - r*x) 102 | R[:, 2, 0] = 2 * (x*z - r*y) 103 | R[:, 2, 1] = 2 * (y*z + r*x) 104 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 105 | return R 106 | 107 | def build_scaling_rotation(s, r): 108 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 109 | R = build_rotation(r) 110 | 111 | L[:,0,0] = s[:,0] 112 | L[:,1,1] = s[:,1] 113 | L[:,2,2] = s[:,2] 114 | 115 | L = R @ L 116 | return L 117 | 118 | def safe_state(silent): 119 | old_f = sys.stdout 120 | class F: 121 | def __init__(self, silent): 122 | self.silent = silent 123 | 124 | def write(self, x): 125 | if not self.silent: 126 | if x.endswith("\n"): 127 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 128 | else: 129 | old_f.write(x) 130 | 131 | def flush(self): 132 | old_f.flush() 133 | 134 | sys.stdout = F(silent) 135 | 136 | random.seed(0) 137 | np.random.seed(0) 138 | torch.manual_seed(0) 139 | torch.cuda.set_device(torch.device("cuda:0")) 140 | 141 | def vis_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET, constant_max=120): 142 | """ 143 | depth: (H, W) 144 | """ 145 | depthmap = np.nan_to_num(depth) # change nan to 0 146 | 147 | # x_ = (255 - x)[:,:,None].repeat(3,axis=-1) 148 | 149 | # x = x[] 150 | # threshold 151 | # constant_max = np.percentile(depthmap, 90) 152 | depthmap_valid_count = (depthmap < 300).sum() 153 | constant_max = np.percentile(depthmap[depthmap<300], 99) if depthmap_valid_count > 10 else 60 154 | # constant_max = 1 155 | # constant_min = 0 156 | constant_min = np.percentile(depthmap, 1) if np.percentile(depthmap, 1) < constant_max else 0 157 | normalizer = mpl.colors.Normalize(vmin=constant_min, vmax=constant_max) 158 | mapper = cm.ScalarMappable(norm=normalizer, cmap='magma_r') 159 | depth_vis_color = (mapper.to_rgba(depthmap)[:, :, :3] * 255).astype(np.uint8) 160 | # all_white = np.ones_like(x_) * 255 161 | # x_ = x_ * (1-bg_mask)[:,:,None] + all_white * bg_mask[:,:,None] 162 | # x_ = x_.astype(np.uint8) 163 | # x_ = cv2.cvtColor(x_, cv2.COLOR_BGR2RGB) 164 | return depth_vis_color, [np.percentile(depthmap, 0), np.percentile(depthmap, 99)] 165 | 166 | def vis_depth1(depth, minmax=None, cmap=cv2.COLORMAP_JET, constant_max=120): 167 | """ 168 | depth: (H, W) 169 | """ 170 | depthmap = np.nan_to_num(depth) # change nan to 
0 171 | 172 | # x_ = (255 - x)[:,:,None].repeat(3,axis=-1) 173 | 174 | # x = x[] 175 | # threshold 176 | # constant_max = np.percentile(depthmap, 90) 177 | depthmap_valid_count = (depthmap < 300).sum() 178 | # constant_max = np.percentile(depthmap[depthmap<300], 99) if depthmap_valid_count > 10 else 60 179 | constant_max = 10 180 | constant_min = 0.5 181 | # constant_min = np.percentile(depthmap, 1) if np.percentile(depthmap, 1) < constant_max else 0 182 | normalizer = mpl.colors.Normalize(vmin=constant_min, vmax=constant_max) 183 | mapper = cm.ScalarMappable(norm=normalizer, cmap='magma_r') 184 | depth_vis_color = (mapper.to_rgba(depthmap)[:, :, :3] * 255).astype(np.uint8) 185 | # all_white = np.ones_like(x_) * 255 186 | # x_ = x_ * (1-bg_mask)[:,:,None] + all_white * bg_mask[:,:,None] 187 | # x_ = x_.astype(np.uint8) 188 | # x_ = cv2.cvtColor(x_, cv2.COLOR_BGR2RGB) 189 | return depth_vis_color, [np.percentile(depthmap, 0), np.percentile(depthmap, 99)] 190 | 191 | def readDepthDmb(file_path): 192 | inimage = open(file_path, "rb") 193 | if not inimage: 194 | print("Error opening file", file_path) 195 | return -1 196 | 197 | type = -1 198 | 199 | type = struct.unpack("i", inimage.read(4))[0] 200 | h = struct.unpack("i", inimage.read(4))[0] 201 | w = struct.unpack("i", inimage.read(4))[0] 202 | nb = struct.unpack("i", inimage.read(4))[0] 203 | 204 | if type != 1: 205 | inimage.close() 206 | return -1 207 | 208 | dataSize = h * w * nb 209 | 210 | depth = np.zeros((h, w), dtype=np.float32) 211 | depth_data = np.frombuffer(inimage.read(dataSize * 4), dtype=np.float32) 212 | depth_data = depth_data.reshape((h, w)) 213 | np.copyto(depth, depth_data) 214 | 215 | inimage.close() 216 | return depth 217 | 218 | def readNormalDmb(file_path): 219 | try: 220 | with open(file_path, 'rb') as inimage: 221 | type = np.fromfile(inimage, dtype=np.int32, count=1)[0] 222 | h = np.fromfile(inimage, dtype=np.int32, count=1)[0] 223 | w = np.fromfile(inimage, dtype=np.int32, count=1)[0] 224 | nb = np.fromfile(inimage, dtype=np.int32, count=1)[0] 225 | 226 | if type != 1: 227 | print("Error: Invalid file type") 228 | return -1 229 | 230 | dataSize = h * w * nb 231 | 232 | normal = np.zeros((h, w, 3), dtype=np.float32) 233 | normal_data = np.fromfile(inimage, dtype=np.float32, count=dataSize) 234 | normal_data = normal_data.reshape((h, w, nb)) 235 | normal[:, :, :] = normal_data[:, :, :3] 236 | 237 | return normal 238 | 239 | except IOError: 240 | print("Error opening file", file_path) 241 | return -1 242 | 243 | def read_propagted_depth(path): 244 | cost = readDepthDmb(os.path.join(path, 'costs.dmb')) 245 | cost[cost==np.nan] = 2 246 | cost[cost < 0] = 2 247 | # mask = cost > 0.5 248 | 249 | depth = readDepthDmb(os.path.join(path, 'depths.dmb')) 250 | # depth[mask] = 300 251 | depth[np.isnan(depth)] = 300 252 | depth[depth < 0] = 300 253 | depth[depth > 300] = 300 254 | 255 | normal = readNormalDmb(os.path.join(path, 'normals.dmb')) 256 | 257 | return depth, cost, normal 258 | 259 | def load_pairs_relation(path): 260 | pairs_relation = [] 261 | num = 0 262 | with open(path, 'r') as file: 263 | num_images = int(file.readline()) 264 | for i in range(num_images): 265 | 266 | ref_image_id = int(file.readline()) 267 | if i != ref_image_id: 268 | print(ref_image_id) 269 | print(i) 270 | 271 | src_images_infos = file.readline().split() 272 | num_src_images = int(src_images_infos[0]) 273 | src_images_infos = src_images_infos[1:] 274 | 275 | pairs = [] 276 | #only fetch the first 4 src images 277 | for j in 
range(num_src_images): 278 | id, score = int(src_images_infos[2*j]), int(src_images_infos[2*j+1]) 279 | #the idx needs to align to the training images 280 | if score <= 0.0 or id % 8 == 0: 281 | continue 282 | id = (id // 8) * 7 + (id % 8) - 1 283 | pairs.append(id) 284 | 285 | if len(pairs) > 3: 286 | break 287 | 288 | if ref_image_id % 8 != 0: 289 | #only load the training images 290 | pairs_relation.append(pairs) 291 | else: 292 | num = num + 1 293 | 294 | return pairs_relation -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | import cv2 17 | import os 18 | from gaussianpro import propagate 19 | 20 | class BasicPointCloud(NamedTuple): 21 | points : np.array 22 | colors : np.array 23 | normals : np.array 24 | 25 | def write_cam_txt(cam_path, K, w2c, depth_range): 26 | with open(cam_path, "w") as file: 27 | file.write("extrinsic\n") 28 | for row in w2c: 29 | file.write(" ".join(str(element) for element in row)) 30 | file.write("\n") 31 | 32 | file.write("\nintrinsic\n") 33 | for row in K: 34 | file.write(" ".join(str(element) for element in row)) 35 | file.write("\n") 36 | 37 | file.write("\n") 38 | 39 | file.write(" ".join(str(element) for element in depth_range)) 40 | file.write("\n") 41 | 42 | def geom_transform_points(points, transf_matrix): 43 | P, _ = points.shape 44 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 45 | points_hom = torch.cat([points, ones], dim=1) 46 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 47 | 48 | denom = points_out[..., 3:] + 0.0000001 49 | return (points_out[..., :3] / denom).squeeze(dim=0) 50 | 51 | def getWorld2View(R, t): 52 | Rt = np.zeros((4, 4)) 53 | Rt[:3, :3] = R.transpose() 54 | Rt[:3, 3] = t 55 | Rt[3, 3] = 1.0 56 | return np.float32(Rt) 57 | 58 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 59 | Rt = np.zeros((4, 4)) 60 | Rt[:3, :3] = R.transpose() 61 | Rt[:3, 3] = t 62 | Rt[3, 3] = 1.0 63 | 64 | C2W = np.linalg.inv(Rt) 65 | cam_center = C2W[:3, 3] 66 | cam_center = (cam_center + translate) * scale 67 | C2W[:3, 3] = cam_center 68 | 69 | Rt = np.linalg.inv(C2W) 70 | return np.float32(Rt) 71 | 72 | def getProjectionMatrix(znear, zfar, fovX, fovY): 73 | tanHalfFovY = math.tan((fovY / 2)) 74 | tanHalfFovX = math.tan((fovX / 2)) 75 | 76 | top = tanHalfFovY * znear 77 | bottom = -top 78 | right = tanHalfFovX * znear 79 | left = -right 80 | 81 | P = torch.zeros(4, 4) 82 | 83 | z_sign = 1.0 84 | 85 | P[0, 0] = 2.0 * znear / (right - left) 86 | P[1, 1] = 2.0 * znear / (top - bottom) 87 | P[0, 2] = (right + left) / (right - left) 88 | P[1, 2] = (top + bottom) / (top - bottom) 89 | P[3, 2] = z_sign 90 | P[2, 2] = z_sign * zfar / (zfar - znear) 91 | P[2, 3] = -(zfar * znear) / (zfar - znear) 92 | return P 93 | 94 | def fov2focal(fov, pixels): 95 | return pixels / (2 * math.tan(fov / 2)) 96 | 97 | def focal2fov(focal, pixels): 98 | return 2*math.atan(pixels/(2*focal)) 99 | 100 | def init_image_coord(height, width): 101 | x_row = 
np.arange(0, width) 102 | x = np.tile(x_row, (height, 1)) 103 | x = x[np.newaxis, :, :] 104 | x = x.astype(np.float32) 105 | x = torch.from_numpy(x.copy()).cuda() 106 | u_u0 = x - width/2.0 107 | 108 | y_col = np.arange(0, height) # y_col = np.arange(0, height) 109 | y = np.tile(y_col, (width, 1)).T 110 | y = y[np.newaxis, :, :] 111 | y = y.astype(np.float32) 112 | y = torch.from_numpy(y.copy()).cuda() 113 | v_v0 = y - height/2.0 114 | return u_u0, v_v0 115 | 116 | def depth_to_xyz(depth, intrinsic): 117 | b, c, h, w = depth.shape 118 | u_u0, v_v0 = init_image_coord(h, w) 119 | x = (u_u0 - intrinsic[0][2]) * depth / intrinsic[0][0] 120 | y = (v_v0 - intrinsic[1][2]) * depth / intrinsic[1][1] 121 | z = depth 122 | pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c] 123 | return pw 124 | 125 | def get_surface_normalv2(xyz, patch_size=5): 126 | """ 127 | xyz: xyz coordinates 128 | patch: [p1, p2, p3, 129 | p4, p5, p6, 130 | p7, p8, p9] 131 | surface_normal = [(p9-p1) x (p3-p7)] + [(p6-p4) - (p8-p2)] 132 | return: normal [h, w, 3, b] 133 | """ 134 | b, h, w, c = xyz.shape 135 | half_patch = patch_size // 2 136 | xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device) 137 | xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz 138 | 139 | # xyz_left_top = xyz_pad[:, :h, :w, :] # p1 140 | # xyz_right_bottom = xyz_pad[:, -h:, -w:, :]# p9 141 | # xyz_left_bottom = xyz_pad[:, -h:, :w, :] # p7 142 | # xyz_right_top = xyz_pad[:, :h, -w:, :] # p3 143 | # xyz_cross1 = xyz_left_top - xyz_right_bottom # p1p9 144 | # xyz_cross2 = xyz_left_bottom - xyz_right_top # p7p3 145 | 146 | xyz_left = xyz_pad[:, half_patch:half_patch + h, :w, :] # p4 147 | xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :] # p6 148 | xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :] # p2 149 | xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :] # p8 150 | xyz_horizon = xyz_left - xyz_right # p4p6 151 | xyz_vertical = xyz_top - xyz_bottom # p2p8 152 | 153 | xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w+1, :] # p4 154 | xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size-1:patch_size-1+w, :] # p6 155 | xyz_top_in = xyz_pad[:, 1:h+1, half_patch:half_patch + w, :] # p2 156 | xyz_bottom_in = xyz_pad[:, patch_size-1:patch_size-1+h, half_patch:half_patch + w, :] # p8 157 | xyz_horizon_in = xyz_left_in - xyz_right_in # p4p6 158 | xyz_vertical_in = xyz_top_in - xyz_bottom_in # p2p8 159 | 160 | n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3) 161 | n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3) 162 | 163 | # re-orient normals consistently 164 | orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0 165 | n_img_1[orient_mask] *= -1 166 | orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0 167 | n_img_2[orient_mask] *= -1 168 | 169 | n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True)) 170 | n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8) 171 | 172 | n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True)) 173 | n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8) 174 | 175 | # average 2 norms 176 | n_img_aver = n_img1_norm + n_img2_norm 177 | n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True)) 178 | n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8) 179 | # re-orient normals consistently 180 | orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0 181 | n_img_aver_norm[orient_mask] *= -1 182 | n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0)) # [h, w, c, b] 183 
| 184 | # a = torch.sum(n_img1_norm_out*n_img2_norm_out, dim=2).cpu().numpy().squeeze() 185 | # plt.imshow(np.abs(a), cmap='rainbow') 186 | # plt.show() 187 | return n_img_aver_norm_out#n_img1_norm.permute((1, 2, 3, 0)) 188 | 189 | def surface_normal_from_depth(depth, intrinsic, valid_mask=None): 190 | # para depth: depth map, [b, c, h, w] 191 | b, c, h, w = depth.shape 192 | # focal_length = focal_length[:, None, None, None] 193 | depth_filter = torch.nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1) 194 | depth_filter = torch.nn.functional.avg_pool2d(depth_filter, kernel_size=3, stride=1, padding=1) 195 | xyz = depth_to_xyz(depth_filter, intrinsic) 196 | sn_batch = [] 197 | for i in range(b): 198 | xyz_i = xyz[i, :][None, :, :, :] 199 | normal = get_surface_normalv2(xyz_i) 200 | sn_batch.append(normal) 201 | sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1)) # [b, c, h, w] 202 | if valid_mask is not None: 203 | mask_invalid = (~valid_mask).repeat(1, 3, 1, 1) 204 | sn_batch[mask_invalid] = 0.0 205 | 206 | return sn_batch 207 | 208 | def img_warping(ref_pose, src_pose, virtual_pose_ref_depth, virtual_intrinsic, src_img): 209 | ref_depth = virtual_pose_ref_depth 210 | ref_pose = ref_pose 211 | src_pose = src_pose 212 | intrinsic = virtual_intrinsic 213 | 214 | mask = ref_depth > 0 215 | 216 | ht, wd = ref_depth.shape 217 | fx, fy, cx, cy = intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2] 218 | 219 | y, x = torch.meshgrid(torch.arange(ht).float(), torch.arange(wd).float()) 220 | y = y.to(ref_depth.device) 221 | x = x.to(ref_depth.device) 222 | 223 | i = torch.ones_like(ref_depth).to(ref_depth.device) 224 | X = (x - cx) / fx 225 | Y = (y - cy) / fy 226 | pts_in_norm = torch.stack([X, Y, i], dim=-1) 227 | pts_in_3D = pts_in_norm * ref_depth.unsqueeze(-1).repeat(1, 1, 3) 228 | 229 | rel_pose = src_pose.inverse() @ ref_pose 230 | 231 | pts_in_3D_tgt = rel_pose[:3, :3] @ pts_in_3D.view(-1, 3).permute(1, 0) + rel_pose[:3, 3].unsqueeze(-1).repeat(1, ht*wd) 232 | pts_in_norm_tgt = pts_in_3D_tgt / pts_in_3D_tgt[2:, :] 233 | 234 | pts_in_tgt = intrinsic @ pts_in_norm_tgt 235 | pts_in_tgt = pts_in_tgt.permute(1, 0).view(ht, wd, 3)[:, :, :2] 236 | 237 | pts_in_tgt[:, :, 0] = (pts_in_tgt[:, :, 0] / wd - 0.5) * 2 238 | pts_in_tgt[:, :, 1] = (pts_in_tgt[:, :, 1] / ht - 0.5) * 2 239 | warped_ref_img = torch.nn.functional.grid_sample(src_img.unsqueeze(0), pts_in_tgt.unsqueeze(0), mode='nearest', padding_mode="zeros") 240 | 241 | return warped_ref_img 242 | 243 | def get_proj_matrix(K, image_size, znear=.01, zfar=1000.): 244 | fx = K[0,0] 245 | fy = K[1,1] 246 | cx = K[0,2] 247 | cy = K[1,2] 248 | width, height = image_size 249 | m = np.zeros((4, 4)) 250 | m[0][0] = 2.0 * fx / width 251 | m[0][1] = 0.0 252 | m[0][2] = 0.0 253 | m[0][3] = 0.0 254 | 255 | m[1][0] = 0.0 256 | m[1][1] = 2.0 * fy / height 257 | m[1][2] = 0.0 258 | m[1][3] = 0.0 259 | 260 | m[2][0] = 1.0 - 2.0 * cx / width 261 | m[2][1] = 2.0 * cy / height - 1.0 262 | m[2][2] = (zfar + znear) / (znear - zfar) 263 | m[2][3] = -1.0 264 | 265 | m[3][0] = 0.0 266 | m[3][1] = 0.0 267 | m[3][2] = 2.0 * zfar * znear / (znear - zfar) 268 | m[3][3] = 0.0 269 | 270 | return m.T 271 | 272 | 273 | def bilinear_sampler(img, coords, mask=False): 274 | """ Wrapper for grid_sample, uses pixel coordinates """ 275 | H, W = img.shape[-2:] 276 | xgrid, ygrid = coords.split([1,1], dim=-1) 277 | xgrid = 2*xgrid/(W-1) - 1 278 | ygrid = 2*ygrid/(H-1) - 1 279 | 280 | grid = torch.cat([xgrid, ygrid], dim=-1) 281 | img = 
torch.nn.functional.grid_sample(img, grid, align_corners=True) 282 | 283 | if mask: 284 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 285 | return img, mask.float() 286 | 287 | return img 288 | 289 | 290 | # def sparse_depth_from_projection(gaussians, viewpoint_cam): 291 | # pc = gaussians.get_xyz.contiguous() 292 | # K = viewpoint_cam.K 293 | # img_height = viewpoint_cam.image_height 294 | # img_width = viewpoint_cam.image_width 295 | # znear = 0.1 296 | # zfar = 1000 297 | # proj_matrix = get_proj_matrix(K, (img_width, img_height), znear, zfar) 298 | # proj_matrix = torch.tensor(proj_matrix).cuda().to(torch.float32) 299 | # w2c = viewpoint_cam.world_view_transform.transpose(0, 1) 300 | # c2w = w2c.inverse() 301 | # c2w = c2w @ torch.tensor(np.diag([1., -1., -1., 1.]).astype(np.float32)).cuda() 302 | # w2c = c2w.inverse() 303 | # total_m = proj_matrix @ w2c 304 | # index_buffer, _ = pcpr.forward(pc, total_m.unsqueeze(0), img_width, img_height, 512) 305 | # sh = index_buffer.shape 306 | # ind = index_buffer.view(-1).long().cuda() 307 | 308 | # xyz = pc.unsqueeze(0).permute(2,0,1) 309 | # xyz = xyz.view(xyz.shape[0],-1) 310 | # proj_xyz_world = torch.index_select(xyz, 1, ind) 311 | # Rot, Trans = w2c[:3, :3], w2c[:3, 3][..., None] 312 | 313 | # proj_xyz_cam = Rot @ proj_xyz_world + Trans 314 | # proj_depth = proj_xyz_cam[2,:][None,] 315 | # proj_depth = proj_depth.view(proj_depth.shape[0], sh[0], sh[1], sh[2]) #[1, 4, 256, 256] 316 | # proj_depth = proj_depth.permute(1, 0, 2, 3) 317 | # proj_depth *= -1 318 | 319 | # ##mask获取 320 | # mask = ind.clone() 321 | # mask[mask>0] = 1 322 | # mask = mask.view(1, sh[0], sh[1], sh[2]) 323 | # mask = mask.permute(1,0,2,3) 324 | 325 | # proj_depth = proj_depth * mask 326 | 327 | # return proj_depth.squeeze() 328 | 329 | # project the reference point cloud into the source view, then project back 330 | #extrinsics here refers c2w 331 | def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): 332 | batch, height, width = depth_ref.shape 333 | 334 | ## step1. project reference pixels to the source view 335 | # reference view x, y 336 | y_ref, x_ref = torch.meshgrid(torch.arange(0, height).to(depth_ref.device), torch.arange(0, width).to(depth_ref.device)) 337 | x_ref = x_ref.unsqueeze(0).repeat(batch, 1, 1) 338 | y_ref = y_ref.unsqueeze(0).repeat(batch, 1, 1) 339 | x_ref, y_ref = x_ref.reshape(batch, -1), y_ref.reshape(batch, -1) 340 | # reference 3D space 341 | 342 | A = torch.inverse(intrinsics_ref) 343 | B = torch.stack((x_ref, y_ref, torch.ones_like(x_ref).to(x_ref.device)), dim=1) * depth_ref.reshape(batch, 1, -1) 344 | xyz_ref = torch.matmul(A, B) 345 | 346 | # source 3D space 347 | xyz_src = torch.matmul(torch.matmul(torch.inverse(extrinsics_src), extrinsics_ref), 348 | torch.cat((xyz_ref, torch.ones_like(x_ref).to(x_ref.device).unsqueeze(1)), dim=1))[:, :3] 349 | # source view x, y 350 | K_xyz_src = torch.matmul(intrinsics_src, xyz_src) 351 | xy_src = K_xyz_src[:, :2] / K_xyz_src[:, 2:3] 352 | 353 | ## step2. 
reproject the source view points with source view depth estimation 354 | # find the depth estimation of the source view 355 | x_src = xy_src[:, 0].reshape([batch, height, width]).float() 356 | y_src = xy_src[:, 1].reshape([batch, height, width]).float() 357 | 358 | # print(x_src, y_src) 359 | sampled_depth_src = bilinear_sampler(depth_src.view(batch, 1, height, width), torch.stack((x_src, y_src), dim=-1).view(batch, height, width, 2)) 360 | 361 | # source 3D space 362 | # NOTE that we should use sampled source-view depth_here to project back 363 | xyz_src = torch.matmul(torch.inverse(intrinsics_src), 364 | torch.cat((xy_src, torch.ones_like(x_ref).to(x_ref.device).unsqueeze(1)), dim=1) * sampled_depth_src.reshape(batch, 1, -1)) 365 | # reference 3D space 366 | xyz_reprojected = torch.matmul(torch.matmul(torch.inverse(extrinsics_ref), extrinsics_src), 367 | torch.cat((xyz_src, torch.ones_like(x_ref).to(x_ref.device).unsqueeze(1)), dim=1))[:, :3] 368 | # source view x, y, depth 369 | depth_reprojected = xyz_reprojected[:, 2].reshape([batch, height, width]).float() 370 | K_xyz_reprojected = torch.matmul(intrinsics_ref, xyz_reprojected) 371 | xy_reprojected = K_xyz_reprojected[:, :2] / K_xyz_reprojected[:, 2:3] 372 | x_reprojected = xy_reprojected[:, 0].reshape([batch, height, width]).float() 373 | y_reprojected = xy_reprojected[:, 1].reshape([batch, height, width]).float() 374 | 375 | return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src 376 | 377 | 378 | def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src, thre1=1, thre2=0.01): 379 | batch, height, width = depth_ref.shape 380 | y_ref, x_ref = torch.meshgrid(torch.arange(0, height).to(depth_ref.device), torch.arange(0, width).to(depth_ref.device)) 381 | x_ref = x_ref.unsqueeze(0).repeat(batch, 1, 1) 382 | y_ref = y_ref.unsqueeze(0).repeat(batch, 1, 1) 383 | inputs = [depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src] 384 | outputs = reproject_with_depth(*inputs) 385 | depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = outputs 386 | # check |p_reproj-p_1| < 1 387 | dist = torch.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2) 388 | 389 | # check |d_reproj-d_1| / d_1 < 0.01 390 | depth_diff = torch.abs(depth_reprojected - depth_ref) 391 | relative_depth_diff = depth_diff / depth_ref 392 | 393 | mask = torch.logical_and(dist < thre1, relative_depth_diff < thre2) 394 | depth_reprojected[~mask] = 0 395 | 396 | return mask, depth_reprojected, x2d_src, y2d_src, relative_depth_diff 397 | 398 | def depth_propagation(viewpoint_cam, rendered_depth, viewpoint_stack, src_idxs, dataset, patch_size): 399 | 400 | depth_min = 0.1 401 | if dataset == 'waymo': 402 | depth_max = 80 403 | elif dataset == '360': 404 | depth_max = 20 405 | else: 406 | depth_max = 20 407 | 408 | images = list() 409 | intrinsics = list() 410 | poses = list() 411 | depth_intervals = list() 412 | 413 | images.append((viewpoint_cam.original_image * 255).permute((1, 2, 0)).to(torch.uint8)) 414 | intrinsics.append(viewpoint_cam.K) 415 | poses.append(viewpoint_cam.world_view_transform.transpose(0, 1)) 416 | depth_interval = torch.tensor([depth_min, (depth_max-depth_min)/192.0, 192.0, depth_max]) 417 | depth_intervals.append(depth_interval) 418 | 419 | depth = rendered_depth.unsqueeze(-1) 420 | normal = torch.zeros_like(depth) 421 | 422 | for idx, src_idx in enumerate(src_idxs): 423 | src_viewpoint = viewpoint_stack[src_idx] 424 | 
images.append((src_viewpoint.original_image * 255).permute((1, 2, 0)).to(torch.uint8)) 425 | intrinsics.append(src_viewpoint.K) 426 | poses.append(src_viewpoint.world_view_transform.transpose(0, 1)) 427 | depth_intervals.append(depth_interval) 428 | 429 | images = torch.stack(images) 430 | intrinsics = torch.stack(intrinsics) 431 | poses = torch.stack(poses) 432 | depth_intervals = torch.stack(depth_intervals) 433 | 434 | results = propagate(images, intrinsics, poses, depth, normal, depth_intervals, patch_size) 435 | propagated_depth = results[0].to(rendered_depth.device) 436 | propagated_normal = results[1:4].to(rendered_depth.device).permute(1, 2, 0) 437 | 438 | return propagated_depth, propagated_normal 439 | 440 | 441 | def generate_edge_mask(propagated_depth, patch_size): 442 | # img gradient 443 | x_conv = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).view(1, 1, 3, 3).float().cuda() 444 | y_conv = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).view(1, 1, 3, 3).float().cuda() 445 | gradient_x = torch.abs(torch.nn.functional.conv2d(propagated_depth.unsqueeze(0).unsqueeze(0), x_conv, padding=1)) 446 | gradient_y = torch.abs(torch.nn.functional.conv2d(propagated_depth.unsqueeze(0).unsqueeze(0), y_conv, padding=1)) 447 | gradient = gradient_x + gradient_y 448 | 449 | # edge mask 450 | edge_mask = (gradient > 5).float() 451 | 452 | # dilation 453 | kernel = torch.ones(1, 1, patch_size, patch_size).float().cuda() 454 | dilated_mask = torch.nn.functional.conv2d(edge_mask, kernel, padding=(patch_size-1)//2) 455 | dilated_mask = torch.round(dilated_mask).squeeze().to(torch.bool) 456 | dilated_mask = ~dilated_mask 457 | 458 | return dilated_mask -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | 17 | def psnr(img1, img2): 18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 20 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def l1_loss(network_output, gt): 18 | return torch.abs((network_output - gt)).mean() 19 | 20 | def l2_loss(network_output, gt): 21 | return ((network_output - gt) ** 2).mean() 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 25 | return gauss / gauss.sum() 26 | 27 | def create_window(window_size, channel): 28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 31 | return window 32 | 33 | def ssim(img1, img2, window_size=11, size_average=True, mask=None): 34 | channel = img1.size(-3) 35 | window = create_window(window_size, channel) 36 | 37 | if img1.is_cuda: 38 | window = window.cuda(img1.get_device()) 39 | window = window.type_as(img1) 40 | 41 | return _ssim(img1, img2, window, window_size, channel, size_average, mask=mask) 42 | 43 | def _ssim(img1, img2, window, window_size, channel, size_average=True, mask=None): 44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 46 | 47 | mu1_sq = mu1.pow(2) 48 | mu2_sq = mu2.pow(2) 49 | mu1_mu2 = mu1 * mu2 50 | 51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 54 | 55 | C1 = 0.01 ** 2 56 | C2 = 0.03 ** 2 57 | 58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 59 | 60 | if mask is not None: 61 | ssim_map = ssim_map[mask] 62 | 63 | if size_average: 64 | return ssim_map.mean() 65 | else: 66 | return ssim_map.mean(1).mean(1).mean(1) 67 | 68 | def reduction_batch_based(image_loss, M): 69 | # average of all valid pixels of the batch 70 | 71 | # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0) 72 | divisor = torch.sum(M) 73 | 74 | if divisor == 0: 75 | return 0 76 | else: 77 | return torch.sum(image_loss) / divisor 78 | 79 | def reduction_image_based(image_loss, M): 80 | # mean of average of valid pixels of an image 81 | 82 | # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0) 83 | valid = M.nonzero() 84 | 85 | image_loss[valid] = image_loss[valid] / M[valid] 86 | 87 | return torch.mean(image_loss) 88 | 89 | def mse_loss(prediction, target, mask, reduction=reduction_batch_based): 90 | 91 | M = torch.sum(mask, (1, 2)) 92 | res = prediction - target 93 | image_loss = torch.sum(mask * res * res, (1, 2)) 94 | 95 | return reduction(image_loss, 2 * M) 96 | 97 | class MSELoss(torch.nn.Module): 98 | def __init__(self, reduction='image-based'): 99 | super().__init__() 100 | 101 | if reduction == 'image-based': 102 | self.__reduction = reduction_batch_based 103 | else: 104 | self.__reduction = reduction_image_based 105 | 106 | def forward(self, prediction, target, mask): 107 | return mse_loss(prediction, target, mask, reduction=self.__reduction) 108 | 109 | class GradientLoss(torch.nn.Module): 110 | def __init__(self, scales=4, 
reduction='image-based'): 111 | super().__init__() 112 | 113 | if reduction == 'image-based': 114 | self.__reduction = reduction_batch_based 115 | else: 116 | self.__reduction = reduction_image_based 117 | 118 | self.__scales = scales 119 | 120 | def forward(self, prediction, target, mask): 121 | total = 0 122 | 123 | for scale in range(self.__scales): 124 | step = pow(2, scale) 125 | 126 | total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step], 127 | mask[:, ::step, ::step], reduction=self.__reduction) 128 | 129 | return total 130 | 131 | def gradient_loss(prediction, target, mask, reduction=reduction_batch_based): 132 | 133 | M = torch.sum(mask, (1, 2)) 134 | 135 | diff = prediction - target 136 | diff = torch.mul(mask, diff) 137 | 138 | grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) 139 | mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) 140 | grad_x = torch.mul(mask_x, grad_x) 141 | 142 | grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) 143 | mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) 144 | grad_y = torch.mul(mask_y, grad_y) 145 | 146 | image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2)) 147 | 148 | return reduction(image_loss, M) 149 | 150 | # copy from MiDaS 151 | def compute_scale_and_shift(prediction, target, mask): 152 | # system matrix: A = [[a_00, a_01], [a_10, a_11]] 153 | a_00 = torch.sum(mask * prediction * prediction, (1, 2)) 154 | a_01 = torch.sum(mask * prediction, (1, 2)) 155 | a_11 = torch.sum(mask, (1, 2)) 156 | 157 | # right hand side: b = [b_0, b_1] 158 | b_0 = torch.sum(mask * prediction * target, (1, 2)) 159 | b_1 = torch.sum(mask * target, (1, 2)) 160 | 161 | # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b 162 | x_0 = torch.zeros_like(b_0) 163 | x_1 = torch.zeros_like(b_1) 164 | 165 | det = a_00 * a_11 - a_01 * a_01 166 | valid = det.nonzero() 167 | 168 | x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] 169 | x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] 170 | 171 | return x_0, x_1 172 | 173 | 174 | class ScaleAndShiftInvariantLoss(torch.nn.Module): 175 | def __init__(self, alpha=0.5, scales=4, reduction='image-based'): 176 | super().__init__() 177 | 178 | self.__data_loss = MSELoss(reduction=reduction) 179 | self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction) 180 | self.__alpha = alpha 181 | 182 | self.__prediction_ssi = None 183 | 184 | def forward(self, prediction, target, mask): 185 | 186 | scale, shift = compute_scale_and_shift(prediction, target, mask) 187 | self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1) 188 | 189 | total = self.__data_loss(self.__prediction_ssi, target, mask) 190 | if self.__alpha > 0: 191 | total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask) 192 | 193 | return total 194 | 195 | def __get_prediction_ssi(self): 196 | return self.__prediction_ssi 197 | 198 | prediction_ssi = property(__get_prediction_ssi) 199 | 200 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. 
Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. 
Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | --------------------------------------------------------------------------------
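
The ScaleAndShiftInvariantLoss in utils/loss_utils.py (adapted from MiDaS) first solves, per image, a closed-form least-squares system for the scale and shift that best align the prediction to the target under the given mask, then applies the masked MSE data term and the multi-scale gradient regularizer to the aligned prediction. A minimal usage sketch follows; the tensors, shapes, and values are illustrative placeholders, not taken from this repository:

    import torch
    from utils.loss_utils import ScaleAndShiftInvariantLoss

    # Toy inputs: a batch of two depth maps where the prediction differs from
    # the target only by an arbitrary global scale and shift.
    target = torch.rand(2, 32, 32)      # [B, H, W]
    mask = torch.ones_like(target)      # every pixel treated as valid
    prediction = 3.0 * target + 0.7

    loss_fn = ScaleAndShiftInvariantLoss(alpha=0.5, scales=4)
    loss = loss_fn(prediction, target, mask)
    # compute_scale_and_shift recovers scale ~ 1/3 and shift ~ -0.7/3, so the
    # aligned prediction matches the target and the loss is close to zero.

Because the alignment is recomputed in closed form for every image, the loss is insensitive to global scale and shift errors in the prediction, which makes it suitable for supervision with relative (e.g. monocular) depth estimates.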