├── .gitmodules
├── .vscode
│   └── settings.json
├── LICENSE
├── README.md
├── arguments
│   └── __init__.py
├── convert.py
├── docker
│   ├── Dockerfile
│   ├── build_gaussian_pro_docker.sh
│   ├── entrypoint.sh
│   ├── environment.yml
│   └── run_gaussian_pro_docker.sh
├── environment.yml
├── figs
│   ├── comparison.gif
│   ├── effel_tower.mp4
│   ├── jianzhu_final_demo.mp4
│   ├── jiaotang_final_demo.mp4
│   ├── motivation.png
│   ├── output.gif
│   ├── output1.gif
│   ├── output2.gif
│   └── pipeline.png
├── gaussian_renderer
│   ├── __init__.py
│   └── network_gui.py
├── lpipsPyTorch
│   ├── __init__.py
│   └── modules
│       ├── lpips.py
│       ├── networks.py
│       └── utils.py
├── metrics.py
├── render.py
├── results
│   ├── DeepBlending
│   │   ├── drjohnson.csv
│   │   └── playroom.csv
│   ├── Eth3D
│   │   ├── delivery_area.csv
│   │   ├── electro.csv
│   │   ├── kicker.csv
│   │   ├── meadow.csv
│   │   ├── office.csv
│   │   ├── playground.csv
│   │   ├── relief.csv
│   │   ├── relief2.csv
│   │   └── terrace.csv
│   ├── MipNeRF360
│   │   ├── bicycle.csv
│   │   ├── bonsai.csv
│   │   ├── counter.csv
│   │   ├── flowers.csv
│   │   ├── garden.csv
│   │   ├── kitchen.csv
│   │   ├── room.csv
│   │   ├── stump.csv
│   │   └── treehill.csv
│   └── TanksAndTemples
│       ├── train.csv
│       └── truck.csv
├── scene
│   ├── __init__.py
│   ├── cameras.py
│   ├── colmap_loader.py
│   ├── dataset_readers.py
│   └── gaussian_model.py
├── scripts
│   ├── demo.sh
│   └── waymo.sh
├── submodules
│   └── Propagation
│       ├── PatchMatch.cpp
│       ├── PatchMatch.h
│       ├── Propagation.cu
│       ├── main.h
│       ├── pro.cpp
│       └── setup.py
├── train.py
└── utils
    ├── camera_utils.py
    ├── general_utils.py
    ├── graphics_utils.py
    ├── image_utils.py
    ├── loss_utils.py
    ├── sh_utils.py
    └── system_utils.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "submodules/diff-gaussian-rasterization"]
2 | path = submodules/diff-gaussian-rasterization
3 | url = https://github.com/graphdeco-inria/diff-gaussian-rasterization
4 | [submodule "submodules/simple-knn"]
5 | path = submodules/simple-knn
6 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git
7 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "workbench.colorCustomizations": {
3 | "activityBar.background": "#22312D",
4 | "titleBar.activeBackground": "#30443F",
5 | "titleBar.activeForeground": "#F8FAF9"
6 | }
7 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 kcheng1021.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GaussianPro: 3D Gaussian Splatting with Progressive Propagation
2 |
3 | ICML 2024
4 |
5 | ### [Project Page](https://kcheng1021.github.io/gaussianpro.github.io/) | [Paper](https://arxiv.org/abs/2402.14650)
6 |
7 |
8 |
9 |
10 |
11 | 
12 |
13 |
14 | ## 📖 Abstract
15 |
16 | The advent of 3D Gaussian Splatting (3DGS) has recently brought about a revolution in the field of neural rendering, facilitating high-quality renderings at real-time speed. However, 3DGS heavily depends on the initial point cloud produced by Structure-from-Motion (SfM) techniques.
17 | When tackling large-scale scenes that unavoidably contain texture-less surfaces, SfM techniques often fail to produce enough points on these surfaces and cannot provide a good initialization for 3DGS. As a result, 3DGS suffers from difficult optimization and low-quality renderings.
18 | In this paper, inspired by classical multi-view stereo (MVS) techniques, we propose **GaussianPro**, a novel method that applies a progressive propagation strategy to guide the densification of the 3D Gaussians.
19 | Compared to the simple split and clone strategies used in 3DGS, our method leverages the priors of the existing reconstructed geometry of the scene and patch-matching techniques to produce new Gaussians with accurate positions and orientations.
20 | Experiments on both large-scale and small-scale scenes validate the effectiveness of our method, which significantly surpasses 3DGS on the Waymo dataset, exhibiting an improvement of 1.15 dB in PSNR.
21 |
22 | ## 🗓️ News
23 |
24 | [2024.10.10] Many thanks to [Caio Viturino](https://github.com/caioviturinofs), the project now provides a Docker environment!
25 |
26 | [2024.9.28] Many thanks to [Chongjie Ye](https://github.com/hugoycj), the project now avoids the dependency on the OpenCV C++ library, making it more convenient to install!
27 |
28 | Some amazing enhancements will also come out this year.
29 |
30 | ## 🗓️ TODO
31 | - [✔] Code pre-release -- Beta version.
32 | - [✔] Demo Scenes.
33 | - [✔] Pybinding & CUDA acceleration.
34 | - [ ] Support for unordered set of images.
35 |
36 | Some amazing enhancements are under development. We warmly welcome anyone to collaborate in improving this repository. Please send me an email if you are interested!
37 |
38 | ## 🚀 Pipeline
39 |
40 |
41 |

42 |
43 |
44 |
45 | ## 🚀 Setup
46 | #### Tested Environment
47 | Ubuntu 20.04.1 LTS, GeForce RTX 3090, CUDA 11.3 / 11.7 (both tested), C++17
48 |
49 | #### Clone the repo.
50 | ```
51 | git clone https://github.com/kcheng1021/GaussianPro.git --recursive
52 | ```
53 |
54 | #### Environment setup
55 | ```
56 | conda env create --file environment.yml
57 |
58 | # install the propagation package
59 | # The GPU compute architecture is specified as sm_86 in setup.py. Please replace it with the value suitable for your GPU.
60 | # Replace the OpenCV and CUDA include/lib paths with your own (ignore in the latest version)
61 | # C++ OpenCV is best installed in the conda environment via conda install -c conda-forge opencv (ignore in the latest version)
62 | pip install ./submodules/Propagation
63 |
64 | ```
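
If you are not sure which sm_XX value your GPU needs, you can query it directly (a minimal check; assumes PyTorch with CUDA support is already installed in the environment):

```python
# Print the compute capability of GPU 0, e.g. (8, 6) -> sm_86.
import torch

major, minor = torch.cuda.get_device_capability(0)
print(f"Use sm_{major}{minor} in submodules/Propagation/setup.py")
```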
65 |
66 | #### Docker install (Alternative)
67 |
68 | To build GaussianPro with Docker, execute the following command:
69 | ```bash
70 | sh docker/build_gaussian_pro_docker.sh
71 | ```
72 |
73 | To execute the container, run:
74 | ```bash
75 | # Please remember to substitute the dataset path in the script with your own path
76 | sh docker/run_gaussian_pro_docker.sh
77 | ```
78 |
79 | #### Download the Waymo scenes: Segment-102751, 100613, 132384, 144248, 148697, 405841, 164701, 150623, 113792
80 | ```
81 | # Google Drive / USTC cloud share links (download via browser):
82 | https://drive.google.com/file/d/1DXQRBcUIrnIC33WNq8pVLKZ_W1VwON3k/view?usp=sharing
83 | https://drive.google.com/file/d/1DEDt8sNshAlmcwbp_KleeNYf_Jq0fy4u/view?usp=sharing
84 | https://drive.google.com/file/d/1J7_IA2w4-u51lCmtmMA5CDxXR4Dbkeoq/view?usp=sharing
85 | https://rec.ustc.edu.cn/share/d34a0370-2bb2-11f0-b128-73c0ccb2577f (password: ux3p)
86 | ```
86 |
87 | #### Besides the public datasets, we also test GaussianPro on randomly selected YouTube videos and find consistent improvements. The processed data is provided below.
88 |
89 | ```
90 | #youtube01: Park.
91 | wget https://drive.google.com/file/d/1iHYTnI76Zx9VTKbMu1zUE7gVKP4UpBan/view?usp=sharing
92 |
93 | #youtube02: Church
94 | wget https://drive.google.com/file/d/1i2ReAJYkeLHBBbs_8Zn560Tke2F8yR1X/view?usp=sharing
95 |
96 | #youtube03: The forbidden city.
97 | wget https://drive.google.com/file/d/1PZ_917Oq0Y45_5dJ504RxRmpUnewYmyn/view?usp=sharing
98 |
99 | #youtube04: Eiffel tower.
100 | wget https://drive.google.com/file/d/1JoYyfAu3RNnj12C2gPvfljHLUKlUsSr1/view?usp=sharing
101 | ```
102 |
103 | 
104 | 
105 |
106 | #### Run the codes
107 | ```
108 | # Run 3DGS; we modified its default parameters to better learn large scenes. A description of the GaussianPro parameters will come out later.
109 |
110 | # To run the Waymo scenes (3DGS and GaussianPro)
111 | bash scripts/waymo.sh
112 |
113 | # Run the Youtube scenes above
114 | bash scripts/demo.sh
115 | ```
116 |
117 | To ensure reproducibility, we provide reference results (PSNR, SSIM, LPIPS) for the demo scenes based on the current code.
118 | | | Waymo-1002751 | Youtube-01 | Youtube-02 | Youtube-03 | Youtube-04 |
119 | | :--- | :---: | :---: | :---: | :---: | :---: |
120 | | 3DGS | 35.22,0.950,0.234 | 34.40,0.964,0.092 | 34.67,0.954,0.072 | 37.81,0.971,0.081 | 33.05,0.950,0.079 |
121 | | GaussianPro | **35.97,0.959,0.207** | **35.29,0.969,0.076** | **35.08,0.959,0.064** | **38.27,0.974,0.072** | **33.66,0.956,0.072** |
122 |
123 | ### Try your scenes
124 |
125 | **If you want to try your own scenes, make sure your images are sorted in temporal order, i.e. video data. The current version does not support unordered image sets; support will be added in the next version. You can then adapt the commands in demo.sh to run your own scenes.**
127 |
128 | **Please ensure that your neighboring images have sufficient overlap.**
129 |
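As a quick sanity check before training on your own data, you can confirm that the frames sort into capture order (a minimal sketch; `your_scene/input` is a hypothetical path, adjust it to your layout):

```python
# For video-style data, the sorted filename order should match the
# temporal order of capture.
import os

frames = sorted(os.listdir("your_scene/input"))
print(frames[:5], "...", frames[-5:])
```
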
130 | ## 🎫 License
131 |
132 | For non-commercial use, this code is released under the [LICENSE](LICENSE).
133 | For commercial use, please contact Xuejin Chen.
134 |
135 | ## 🎫 Acknowledgment
136 | This project largely references [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting) and [ACMH/ACMM](https://github.com/GhiXu/ACMH). Thanks for their amazing works!
137 |
138 | ## 🖊️ Citation
139 |
140 |
141 | If you find this project useful in your research, please consider citing:
142 |
143 |
144 | ```BibTeX
145 | @article{cheng2024gaussianpro,
146 | title={GaussianPro: 3D Gaussian Splatting with Progressive Propagation},
147 | author={Cheng, Kai and Long, Xiaoxiao and Yang, Kaizhi and Yao, Yao and Yin, Wei and Ma, Yuexin and Wang, Wenping and Chen, Xuejin},
148 | journal={arXiv preprint arXiv:2402.14650},
149 | year={2024}
150 | }
151 | ```
152 |
--------------------------------------------------------------------------------
/arguments/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from argparse import ArgumentParser, Namespace
13 | import sys
14 | import os
15 |
16 | class GroupParams:
17 | pass
18 |
19 | class ParamGroup:
20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
21 | group = parser.add_argument_group(name)
22 | for key, value in vars(self).items():
23 | shorthand = False
24 | if key.startswith("_"):
25 | shorthand = True
26 | key = key[1:]
27 | t = type(value)
28 | value = value if not fill_none else None
29 | if shorthand:
30 | if t == bool:
31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
32 | else:
33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
34 | else:
35 | if t == bool:
36 | group.add_argument("--" + key, default=value, action="store_true")
37 | else:
38 | group.add_argument("--" + key, default=value, type=t)
39 |
40 | def extract(self, args):
41 | group = GroupParams()
42 | for arg in vars(args).items():
43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
44 | setattr(group, arg[0], arg[1])
45 | return group
46 |
47 | class ModelParams(ParamGroup):
48 | def __init__(self, parser, sentinel=False):
49 | self.sh_degree = 3
50 | self._source_path = ""
51 | self._model_path = ""
52 | self._images = "images"
53 | self._resolution = -1
54 | self._white_background = False
55 | self.data_device = "cuda"
56 | self.sky_seg = False
57 | self.load_normal = False
58 | self.load_depth = False
59 | self.eval = False
60 | super().__init__(parser, "Loading Parameters", sentinel)
61 |
62 | def extract(self, args):
63 | g = super().extract(args)
64 | g.source_path = os.path.abspath(g.source_path)
65 | return g
66 |
67 | class PipelineParams(ParamGroup):
68 | def __init__(self, parser):
69 | self.convert_SHs_python = False
70 | self.compute_cov3D_python = False
71 | self.debug = False
72 | super().__init__(parser, "Pipeline Parameters")
73 |
74 | class OptimizationParams(ParamGroup):
75 | def __init__(self, parser):
76 | self.iterations = 30_000
77 | self.position_lr_init = 0.00016
78 | self.position_lr_final = 0.0000016
79 | self.position_lr_delay_mult = 0.01
80 | self.position_lr_max_steps = 30_000
81 | self.feature_lr = 0.0025
82 | self.opacity_lr = 0.05
83 | self.scaling_lr = 0.005
84 | self.rotation_lr = 0.001
85 | self.percent_dense = 0.01
86 | self.normal_loss = False
87 | self.sparse_loss = False
88 | self.flatten_loss = False
89 | self.depth_loss = False
90 | self.depth2normal_loss = False
91 | self.lambda_l1_normal = 0.01
92 | self.lambda_cos_normal = 0.01
93 | self.lambda_flatten = 100.0
94 | self.lambda_dssim = 0.2
95 | self.lambda_sparse = 0.001
96 | self.lambda_depth = 0.1
97 | self.lambda_depth2normal = 0.05
98 | self.densification_interval = 100
99 | self.opacity_reset_interval = 3000
100 | self.densify_from_iter = 500
101 | self.densify_until_iter = 15_000
102 | self.densify_grad_threshold = 0.0002
103 | self.random_background = False
104 |
105 | # Propagation parameters
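# A reading of the parameter names (not verified against train.py): propagation
# runs every `propagation_interval` iterations, only between iterations
# `propagated_iteration_begin` and `propagated_iteration_after`, and the two
# depth_error thresholds gate which propagated depths are accepted.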
106 | self.dataset = 'waymo'
107 | self.propagation_interval = 20
108 | self.depth_error_min_threshold = 1.0
109 | self.depth_error_max_threshold = 1.0
110 | self.propagated_iteration_begin = 1000
111 | self.propagated_iteration_after = 12000
112 | self.patch_size = 20
113 | self.pair_path = ''
114 | super().__init__(parser, "Optimization Parameters")
115 |
116 | def get_combined_args(parser : ArgumentParser):
117 | cmdline_string = sys.argv[1:]
118 | cfgfile_string = "Namespace()"
119 | args_cmdline = parser.parse_args(cmdline_string)
120 |
121 | try:
122 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
123 | print("Looking for config file in", cfgfilepath)
124 | with open(cfgfilepath) as cfg_file:
125 | print("Config file found: {}".format(cfgfilepath))
126 | cfgfile_string = cfg_file.read()
127 | except (TypeError, FileNotFoundError):
128 | print("Config file not found, relying on command line arguments")
130 | args_cfgfile = eval(cfgfile_string)
131 |
132 | merged_dict = vars(args_cfgfile).copy()
133 | for k,v in vars(args_cmdline).items():
134 | if v is not None:
135 | merged_dict[k] = v
136 | return Namespace(**merged_dict)
137 |
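# Example of how ParamGroup turns attributes into CLI flags (a sketch, not part
# of the original file; the flag names follow the attribute definitions above):
#   parser = ArgumentParser()
#   lp = ModelParams(parser)              # registers --source_path/-s, --model_path/-m, ...
#   args = parser.parse_args(["-s", "/data/scene"])
#   dataset = lp.extract(args)            # GroupParams; source_path is made absolute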
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import logging
14 | from argparse import ArgumentParser
15 | import shutil
16 |
17 | # This Python script is based on the shell converter script provided in the MipNeRF 360 repository.
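# Expected layout under --source_path, as implied by the commands below:
#   <source_path>/input - raw input images
#   <source_path>/mask  - masks passed to COLMAP feature extraction
# Outputs are written to <source_path>/distorted (sparse model) and <source_path>/sparse/0.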
18 | parser = ArgumentParser("Colmap converter")
19 | parser.add_argument("--no_gpu", action='store_true')
20 | parser.add_argument("--skip_matching", action='store_true')
21 | parser.add_argument("--source_path", "-s", required=True, type=str)
22 | parser.add_argument("--camera", default="OPENCV", type=str)
23 | parser.add_argument("--colmap_executable", default="", type=str)
24 | parser.add_argument("--resize", action="store_true")
25 | parser.add_argument("--magick_executable", default="", type=str)
26 | args = parser.parse_args()
27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap"
28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick"
29 | use_gpu = 1 if not args.no_gpu else 0
30 |
31 | if not args.skip_matching:
32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True)
33 |
34 | ## Feature extraction
35 | feat_extraction_cmd = colmap_command + " feature_extractor "\
36 | "--database_path " + args.source_path + "/distorted/database.db \
37 | --image_path " + args.source_path + "/input \
38 | --ImageReader.single_camera 1 \
39 | --ImageReader.camera_model " + args.camera + " \
40 | --ImageReader.mask_path " + args.source_path + "/mask" + " \
41 | --SiftExtraction.use_gpu " + str(use_gpu)
42 | exit_code = os.system(feat_extraction_cmd)
43 | if exit_code != 0:
44 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.")
45 | exit(exit_code)
46 |
47 | ## Feature matching
48 | feat_matching_cmd = colmap_command + " exhaustive_matcher \
49 | --database_path " + args.source_path + "/distorted/database.db \
50 | --SiftMatching.use_gpu " + str(use_gpu)
51 | exit_code = os.system(feat_matching_cmd)
52 | if exit_code != 0:
53 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.")
54 | exit(exit_code)
55 |
56 | ### Bundle adjustment
57 | # The default Mapper tolerance is unnecessarily large,
58 | # decreasing it speeds up bundle adjustment steps.
59 | mapper_cmd = (colmap_command + " mapper \
60 | --database_path " + args.source_path + "/distorted/database.db \
61 | --image_path " + args.source_path + "/input \
62 | --output_path " + args.source_path + "/distorted/sparse \
63 | --Mapper.ba_global_function_tolerance=0.000001")
64 | exit_code = os.system(mapper_cmd)
65 | if exit_code != 0:
66 | logging.error(f"Mapper failed with code {exit_code}. Exiting.")
67 | exit(exit_code)
68 |
69 | ### Image undistortion
70 | ## We need to undistort our images into ideal pinhole intrinsics.
71 | img_undist_cmd = (colmap_command + " image_undistorter \
72 | --image_path " + args.source_path + "/input \
73 | --input_path " + args.source_path + "/distorted/sparse/0 \
74 | --output_path " + args.source_path + "\
75 | --output_type COLMAP")
76 | exit_code = os.system(img_undist_cmd)
77 | if exit_code != 0:
78 | logging.error(f"Mapper failed with code {exit_code}. Exiting.")
79 | exit(exit_code)
80 |
81 | files = os.listdir(args.source_path + "/sparse")
82 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
83 | # Move each file (except the '0' directory itself) from sparse/ into sparse/0
84 | for file in files:
85 | if file == '0':
86 | continue
87 | source_file = os.path.join(args.source_path, "sparse", file)
88 | destination_file = os.path.join(args.source_path, "sparse", "0", file)
89 | shutil.move(source_file, destination_file)
90 |
91 | if(args.resize):
92 | print("Copying and resizing...")
93 |
94 | # Resize images.
95 | os.makedirs(args.source_path + "/images_2", exist_ok=True)
96 | os.makedirs(args.source_path + "/images_4", exist_ok=True)
97 | os.makedirs(args.source_path + "/images_8", exist_ok=True)
98 | # Get the list of files in the source directory
99 | files = os.listdir(args.source_path + "/images")
100 | # Copy each file from the source directory to the destination directory
101 | for file in files:
102 | source_file = os.path.join(args.source_path, "images", file)
103 |
104 | destination_file = os.path.join(args.source_path, "images_2", file)
105 | shutil.copy2(source_file, destination_file)
106 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file)
107 | if exit_code != 0:
108 | logging.error(f"50% resize failed with code {exit_code}. Exiting.")
109 | exit(exit_code)
110 |
111 | destination_file = os.path.join(args.source_path, "images_4", file)
112 | shutil.copy2(source_file, destination_file)
113 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file)
114 | if exit_code != 0:
115 | logging.error(f"25% resize failed with code {exit_code}. Exiting.")
116 | exit(exit_code)
117 |
118 | destination_file = os.path.join(args.source_path, "images_8", file)
119 | shutil.copy2(source_file, destination_file)
120 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file)
121 | if exit_code != 0:
122 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.")
123 | exit(exit_code)
124 |
125 | print("Done.")
126 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use an official CUDA devel image as the base (nvcc is needed to build the CUDA submodules)
2 | FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04
3 |
4 | # Set environment variables
5 | ENV DEBIAN_FRONTEND=noninteractive \
6 | CONDA_DIR=/opt/conda \
7 | CUDA_HOME=/usr/local/cuda \
8 | TORCH_CUDA_ARCH_LIST="7.5"
9 |
10 | # Add Conda to PATH
11 | ENV PATH=$CONDA_DIR/bin:$PATH
12 |
13 | # Install system dependencies
14 | RUN apt-get update && apt-get install -y \
15 | git \
16 | wget \
17 | build-essential \
18 | libgl1-mesa-glx \
19 | libglib2.0-0 \
20 | python3-dev \
21 | python3-pip \
22 | && rm -rf /var/lib/apt/lists/*
23 |
24 | # Install Miniconda
25 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && \
26 | bash /tmp/miniconda.sh -b -p $CONDA_DIR && \
27 | rm /tmp/miniconda.sh && \
28 | $CONDA_DIR/bin/conda clean -afy
29 |
30 | # Initialize Conda
31 | RUN $CONDA_DIR/bin/conda init bash
32 |
33 | # Clone the GaussianPro repository
34 | RUN git clone https://github.com/kcheng1021/GaussianPro.git --recursive
35 |
36 | # Set the working directory
37 | WORKDIR /GaussianPro
38 |
39 | # Copy environment.yml into the Docker image
40 | COPY environment.yml .
41 |
42 | # Create the Conda environment and install additional packages
43 | # RUN /opt/conda/bin/conda env create -f environment.yml
44 |
45 | # Activate the environment
46 | RUN echo "source /opt/conda/etc/profile.d/conda.sh" >> /root/.bashrc && \
47 | echo "conda activate gaussianpro" >> /root/.bashrc
48 |
49 | # Create the Conda environment, install packages, and clean up in one RUN command
50 | RUN /bin/bash -c "source $CONDA_DIR/etc/profile.d/conda.sh && \
51 | conda env create -f environment.yml && \
52 | conda activate gaussianpro && \
53 | conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.6 -c pytorch -c conda-forge && \
54 | pip install --upgrade pip && \
55 | pip install ./submodules/Propagation && \
56 | pip install ./submodules/diff-gaussian-rasterization && \
57 | pip install ./submodules/simple-knn && \
58 | conda clean -afy"
59 |
60 | # Copy the entrypoint script into the Docker image
61 | COPY entrypoint.sh /entrypoint.sh
62 |
63 | # Make the entrypoint script executable
64 | RUN chmod +x /entrypoint.sh
65 |
66 | # Set the entrypoint to the script that activates the Conda environment
67 | ENTRYPOINT ["/entrypoint.sh"]
68 |
69 | # Set the default command to bash to keep the container running
70 | CMD ["/bin/bash"]
71 |
--------------------------------------------------------------------------------
/docker/build_gaussian_pro_docker.sh:
--------------------------------------------------------------------------------
1 | docker build --no-cache -t gaussian-pro .
--------------------------------------------------------------------------------
/docker/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | echo "Initializing Conda..."
5 | source /opt/conda/etc/profile.d/conda.sh
6 |
7 | echo "Activating Conda environment 'gaussianpro'..."
8 | conda activate gaussianpro
9 |
10 | echo "Conda environment activated: $(conda info --envs | grep '*' )"
11 |
12 | if [ "$#" -gt 0 ]; then
13 | echo "Executing command: $@"
14 | exec "$@"
15 | else
16 | echo "Starting interactive bash shell..."
17 | exec bash
18 | fi
19 |
--------------------------------------------------------------------------------
/docker/environment.yml:
--------------------------------------------------------------------------------
1 | name: gaussianpro
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - cudatoolkit=11.7
8 | - plyfile=0.8.1
9 | - python=3.7.13
10 | - pip=22.3.1
11 | - tqdm
12 | - ninja
13 | - opencv-python=4.10.0.84
14 | - matplotlib=3.5.3
15 | - open3d=0.17.0
16 | - imageio=2.31.2
17 |
--------------------------------------------------------------------------------
/docker/run_gaussian_pro_docker.sh:
--------------------------------------------------------------------------------
1 | docker run --rm --gpus all -it \
2 | --entrypoint /bin/bash \
3 | -v your_dataset_path:/GaussianPro/datasets \
4 | gaussian-pro
5 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: gaussianpro
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - cudatoolkit=11.6
8 | - plyfile=0.8.1
9 | - python=3.7.13
10 | - pip=22.3.1
11 | - pytorch=1.12.1
12 | - torchaudio=0.12.1
13 | - torchvision=0.13.1
14 | - tqdm
15 | - pip:
16 | - submodules/diff-gaussian-rasterization
17 | - submodules/simple-knn
18 |
--------------------------------------------------------------------------------
/figs/comparison.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/comparison.gif
--------------------------------------------------------------------------------
/figs/effel_tower.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/effel_tower.mp4
--------------------------------------------------------------------------------
/figs/jianzhu_final_demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/jianzhu_final_demo.mp4
--------------------------------------------------------------------------------
/figs/jiaotang_final_demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/jiaotang_final_demo.mp4
--------------------------------------------------------------------------------
/figs/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/motivation.png
--------------------------------------------------------------------------------
/figs/output.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/output.gif
--------------------------------------------------------------------------------
/figs/output1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/output1.gif
--------------------------------------------------------------------------------
/figs/output2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/output2.gif
--------------------------------------------------------------------------------
/figs/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcheng1021/GaussianPro/b13a32329551d34219cd1be0375907a0954a898b/figs/pipeline.png
--------------------------------------------------------------------------------
/gaussian_renderer/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
15 | from scene.gaussian_model import GaussianModel
16 | from utils.sh_utils import eval_sh
17 | from utils.general_utils import build_rotation
18 | import torch.nn.functional as F
19 |
20 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None,
21 | return_depth = False, return_normal = False, return_opacity = False):
22 | """
23 | Render the scene.
24 |
25 | Background tensor (bg_color) must be on GPU!
26 | """
27 |
28 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
29 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
30 | try:
31 | screenspace_points.retain_grad()
32 | except Exception:
33 | pass
34 |
35 | # Set up rasterization configuration
36 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
37 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
38 |
39 | raster_settings = GaussianRasterizationSettings(
40 | image_height=int(viewpoint_camera.image_height),
41 | image_width=int(viewpoint_camera.image_width),
42 | tanfovx=tanfovx,
43 | tanfovy=tanfovy,
44 | bg=bg_color,
45 | scale_modifier=scaling_modifier,
46 | viewmatrix=viewpoint_camera.world_view_transform,
47 | projmatrix=viewpoint_camera.full_proj_transform,
48 | sh_degree=pc.active_sh_degree,
49 | campos=viewpoint_camera.camera_center,
50 | prefiltered=False,
51 | debug=pipe.debug
52 | )
53 |
54 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
55 |
56 | means3D = pc.get_xyz
57 | means2D = screenspace_points
58 | opacity = pc.get_opacity
59 |
60 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
61 | # scaling / rotation by the rasterizer.
62 | scales = None
63 | rotations = None
64 | cov3D_precomp = None
65 | if pipe.compute_cov3D_python:
66 | cov3D_precomp = pc.get_covariance(scaling_modifier)
67 | else:
68 | scales = pc.get_scaling
69 | rotations = pc.get_rotation
70 |
71 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
72 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
73 | shs = None
74 | colors_precomp = None
75 | if override_color is None:
76 | if pipe.convert_SHs_python:
77 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
78 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
79 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
80 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
81 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
82 | else:
83 | shs = pc.get_features
84 | else:
85 | colors_precomp = override_color
86 |
87 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
88 | rendered_image, radii = rasterizer(
89 | means3D = means3D,
90 | means2D = means2D,
91 | shs = shs,
92 | colors_precomp = colors_precomp,
93 | opacities = opacity,
94 | scales = scales,
95 | rotations = rotations,
96 | cov3D_precomp = cov3D_precomp)
97 |
98 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
99 | # They will be excluded from value updates used in the splitting criteria.
100 | return_dict = {"render": rendered_image,
101 | "viewspace_points": screenspace_points,
102 | "visibility_filter" : radii > 0,
103 | "radii": radii}
104 |
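# Depth below is rendered by reusing the color pathway: each Gaussian's
# camera-space depth (means3D dotted with the z column of the view matrix)
# is splatted as a constant "color", so alpha blending yields a depth map.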
105 | if return_depth:
106 | projvect1 = viewpoint_camera.world_view_transform[:,2][:3].detach()
107 | projvect2 = viewpoint_camera.world_view_transform[:,2][-1].detach()
108 | means3D_depth = (means3D * projvect1.unsqueeze(0)).sum(dim=-1,keepdim=True) + projvect2
109 | means3D_depth = means3D_depth.repeat(1,3)
110 | render_depth, _ = rasterizer(
111 | means3D = means3D,
112 | means2D = means2D,
113 | shs = None,
114 | colors_precomp = means3D_depth,
115 | opacities = opacity,
116 | scales = scales,
117 | rotations = rotations,
118 | cov3D_precomp = cov3D_precomp)
119 | render_depth = render_depth.mean(dim=0)
120 | return_dict.update({'render_depth': render_depth})
121 |
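# The per-Gaussian normal below is approximated by the rotated axis with the
# smallest scaling (the flattest direction of the ellipsoid), flipped to face
# the camera and moved into camera coordinates before being splatted.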
122 | if return_normal:
123 | rotations_mat = build_rotation(rotations)
124 | scales = pc.get_scaling
125 | min_scales = torch.argmin(scales, dim=1)
126 | indices = torch.arange(min_scales.shape[0])
127 | normal = rotations_mat[indices, :, min_scales]
128 |
129 | # convert normal direction to the camera; calculate the normal in the camera coordinate
130 | view_dir = means3D - viewpoint_camera.camera_center
131 | normal = normal * ((((view_dir * normal).sum(dim=-1) < 0) * 1 - 0.5) * 2)[...,None]
132 |
133 | R_w2c = torch.tensor(viewpoint_camera.R.T).cuda().to(torch.float32)
134 | normal = (R_w2c @ normal.transpose(0, 1)).transpose(0, 1)
135 |
136 | render_normal, _ = rasterizer(
137 | means3D = means3D,
138 | means2D = means2D,
139 | shs = None,
140 | colors_precomp = normal,
141 | opacities = opacity,
142 | scales = scales,
143 | rotations = rotations,
144 | cov3D_precomp = cov3D_precomp)
145 | render_normal = F.normalize(render_normal, dim = 0)
146 | return_dict.update({'render_normal': render_normal})
147 |
148 | if return_opacity:
149 | density = torch.ones_like(means3D)
150 |
151 | render_opacity, _ = rasterizer(
152 | means3D = means3D,
153 | means2D = means2D,
154 | shs = None,
155 | colors_precomp = density,
156 | opacities = opacity,
157 | scales = scales,
158 | rotations = rotations,
159 | cov3D_precomp = cov3D_precomp)
160 | return_dict.update({'render_opacity': render_opacity.mean(dim=0)})
161 |
162 | return return_dict
163 |
--------------------------------------------------------------------------------
/gaussian_renderer/network_gui.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import traceback
14 | import socket
15 | import json
16 | from scene.cameras import MiniCam
17 |
18 | host = "127.0.0.1"
19 | port = 6009
20 |
21 | conn = None
22 | addr = None
23 |
24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
25 |
26 | def init(wish_host, wish_port):
27 | global host, port, listener
28 | host = wish_host
29 | port = wish_port
30 | listener.bind((host, port))
31 | listener.listen()
32 | listener.settimeout(0)
33 |
34 | def try_connect():
35 | global conn, addr, listener
36 | try:
37 | conn, addr = listener.accept()
38 | print(f"\nConnected by {addr}")
39 | conn.settimeout(None)
40 | except Exception as inst:
41 | pass
42 |
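# Wire format: read() expects a 4-byte little-endian length prefix followed by
# a UTF-8 JSON payload; send() replies with optional raw bytes plus a
# length-prefixed ASCII verification string.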
43 | def read():
44 | global conn
45 | messageLength = conn.recv(4)
46 | messageLength = int.from_bytes(messageLength, 'little')
47 | message = conn.recv(messageLength)
48 | return json.loads(message.decode("utf-8"))
49 |
50 | def send(message_bytes, verify):
51 | global conn
52 | if message_bytes is not None:
53 | conn.sendall(message_bytes)
54 | conn.sendall(len(verify).to_bytes(4, 'little'))
55 | conn.sendall(bytes(verify, 'ascii'))
56 |
57 | def receive():
58 | message = read()
59 |
60 | width = message["resolution_x"]
61 | height = message["resolution_y"]
62 |
63 | if width != 0 and height != 0:
64 | try:
65 | do_training = bool(message["train"])
66 | fovy = message["fov_y"]
67 | fovx = message["fov_x"]
68 | znear = message["z_near"]
69 | zfar = message["z_far"]
70 | do_shs_python = bool(message["shs_python"])
71 | do_rot_scale_python = bool(message["rot_scale_python"])
72 | keep_alive = bool(message["keep_alive"])
73 | scaling_modifier = message["scaling_modifier"]
74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
75 | world_view_transform[:,1] = -world_view_transform[:,1]
76 | world_view_transform[:,2] = -world_view_transform[:,2]
77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
78 | full_proj_transform[:,1] = -full_proj_transform[:,1]
79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
80 | except Exception as e:
81 | print("")
82 | traceback.print_exc()
83 | raise e
84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
85 | else:
86 | return None, None, None, None, None, None
--------------------------------------------------------------------------------
/lpipsPyTorch/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .modules.lpips import LPIPS
4 |
5 |
6 | def lpips(x: torch.Tensor,
7 | y: torch.Tensor,
8 | net_type: str = 'alex',
9 | version: str = '0.1'):
10 | r"""Function that measures
11 | Learned Perceptual Image Patch Similarity (LPIPS).
12 |
13 | Arguments:
14 | x, y (torch.Tensor): the input tensors to compare.
15 | net_type (str): the network type to compare the features:
16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
17 | version (str): the version of LPIPS. Default: 0.1.
18 | """
19 | device = x.device
20 | criterion = LPIPS(net_type, version).to(device)
21 | return criterion(x, y)
22 |
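# Minimal usage sketch (illustrative shapes; inputs should be RGB image batches
# on the same device, scaled as the pretrained backbones expect):
#   x = torch.rand(1, 3, 256, 256).cuda()
#   y = torch.rand(1, 3, 256, 256).cuda()
#   score = lpips(x, y, net_type='vgg')  # lower = more perceptually similar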
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .networks import get_network, LinLayers
5 | from .utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 |
12 | Arguments:
13 | net_type (str): the network type to compare the features:
14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
15 | version (str): the version of LPIPS. Default: 0.1.
16 | """
17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
18 |
19 | assert version in ['0.1'], 'v0.1 is only supported now'
20 |
21 | super(LPIPS, self).__init__()
22 |
23 | # pretrained network
24 | self.net = get_network(net_type)
25 |
26 | # linear layers
27 | self.lin = LinLayers(self.net.n_channels_list)
28 | self.lin.load_state_dict(get_state_dict(net_type, version))
29 |
30 | def forward(self, x: torch.Tensor, y: torch.Tensor):
31 | feat_x, feat_y = self.net(x), self.net(y)
32 |
33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
35 |
36 | return torch.sum(torch.cat(res, 0), 0, True)
37 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from .utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 | self.layers = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 | self.layers = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
97 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
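# e.g. 'lin0.model.1.weight' -> '0.1.weight', matching LinLayers' ModuleList indexing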
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from pathlib import Path
13 | import os
14 | from PIL import Image
15 | import torch
16 | import torchvision.transforms.functional as tf
17 | from utils.loss_utils import ssim
18 | from lpipsPyTorch import lpips
19 | import json
20 | from tqdm import tqdm
21 | from utils.image_utils import psnr
22 | from argparse import ArgumentParser
23 |
24 | def readImages(renders_dir, gt_dir):
25 | renders = []
26 | gts = []
27 | image_names = []
28 | for fname in os.listdir(renders_dir):
29 | render = Image.open(renders_dir / fname)
30 | gt = Image.open(gt_dir / fname)
31 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
32 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
33 | image_names.append(fname)
34 | return renders, gts, image_names
35 |
36 | def evaluate(model_paths):
37 |
38 | full_dict = {}
39 | per_view_dict = {}
40 | full_dict_polytopeonly = {}
41 | per_view_dict_polytopeonly = {}
42 | print("")
43 |
44 | for scene_dir in model_paths:
45 | # try:
46 | print("Scene:", scene_dir)
47 | full_dict[scene_dir] = {}
48 | per_view_dict[scene_dir] = {}
49 | full_dict_polytopeonly[scene_dir] = {}
50 | per_view_dict_polytopeonly[scene_dir] = {}
51 |
52 | test_dir = Path(scene_dir) / "test"
53 |
54 | for method in os.listdir(test_dir):
55 | print("Method:", method)
56 |
57 | full_dict[scene_dir][method] = {}
58 | per_view_dict[scene_dir][method] = {}
59 | full_dict_polytopeonly[scene_dir][method] = {}
60 | per_view_dict_polytopeonly[scene_dir][method] = {}
61 |
62 | method_dir = test_dir / method
63 | gt_dir = method_dir/ "gt"
64 | renders_dir = method_dir / "renders"
65 | renders, gts, image_names = readImages(renders_dir, gt_dir)
66 |
67 | ssims = []
68 | psnrs = []
69 | lpipss = []
70 |
71 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
72 | ssims.append(ssim(renders[idx], gts[idx]))
73 | psnrs.append(psnr(renders[idx], gts[idx]))
74 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))
75 |
76 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5"))
77 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5"))
78 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5"))
79 | print("")
80 |
81 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
82 | "PSNR": torch.tensor(psnrs).mean().item(),
83 | "LPIPS": torch.tensor(lpipss).mean().item()})
84 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
85 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
86 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}})
87 |
88 | with open(scene_dir + "/results.json", 'w') as fp:
89 | json.dump(full_dict[scene_dir], fp, indent=True)
90 | with open(scene_dir + "/per_view.json", 'w') as fp:
91 | json.dump(per_view_dict[scene_dir], fp, indent=True)
92 | # except:
93 | # print("Unable to compute metrics for model", scene_dir)
94 |
95 | if __name__ == "__main__":
96 | device = torch.device("cuda:0")
97 | torch.cuda.set_device(device)
98 |
99 | # Set up command line argument parser
100 | parser = ArgumentParser(description="Training script parameters")
101 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
102 | args = parser.parse_args()
103 | evaluate(args.model_paths)
104 |
--------------------------------------------------------------------------------
/render.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from scene import Scene
14 | import os
15 | from tqdm import tqdm
16 | from os import makedirs
17 | from gaussian_renderer import render
18 | import torchvision
19 | from utils.general_utils import safe_state, vis_depth
20 | from argparse import ArgumentParser
21 | from arguments import ModelParams, PipelineParams, get_combined_args
22 | from gaussian_renderer import GaussianModel
23 | import imageio
24 | import numpy as np
25 |
26 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
27 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
28 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
29 | depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "render_depth")
30 | normal_path = os.path.join(model_path, name, "ours_{}".format(iteration), "render_normal")
31 |
32 | makedirs(render_path, exist_ok=True)
33 | makedirs(gts_path, exist_ok=True)
34 | makedirs(depth_path , exist_ok=True)
35 | makedirs(normal_path, exist_ok=True)
36 |
37 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
38 | renders = render(view, gaussians, pipeline, background, return_depth=True, return_normal=True)
39 | rendering = renders["render"]
40 | gt = view.original_image[0:3, :, :]
41 |
42 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
43 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
44 |
45 | render_depth = renders["render_depth"]
46 | if view.sky_mask is not None:
47 | render_depth[~(view.sky_mask.to(render_depth.device).to(torch.bool))] = 300
48 | render_depth = vis_depth(render_depth.detach().cpu().numpy())[0]
49 | imageio.imwrite(os.path.join(depth_path , '{0:05d}'.format(idx) + ".png"), render_depth)
50 |
51 | render_normal = (renders["render_normal"] + 1.0) / 2.0
52 | if view.sky_mask is not None:
53 | render_normal[~(view.sky_mask.to(rendering.device).to(torch.bool).unsqueeze(0).repeat(3, 1, 1))] = -10
54 | # render_normal = renders["render_normal"]
55 | np.save(os.path.join(normal_path, '{0:05d}'.format(idx) + ".npy"), renders["render_normal"].detach().cpu().numpy())
56 | torchvision.utils.save_image(render_normal, os.path.join(normal_path, '{0:05d}'.format(idx) + ".png"))
57 | # normal_gt = torch.nn.functional.normalize(view.normal, p=2, dim=0)
58 | # render_normal_gt = (normal_gt + 1.0) / 2.0
59 | # torchvision.utils.save_image(render_normal_gt, os.path.join(normal_path, '{0:05d}'.format(idx) + "_normalgt.png"))
60 | # exit()
61 |
62 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
63 | with torch.no_grad():
64 | gaussians = GaussianModel(dataset.sh_degree)
65 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
66 |
67 | # gaussians._scaling[:, 0] = 0.001
68 | # gaussians._scaling[:, 1] = 0.0005
69 | # gaussians._scaling[:, 2] = -10000.0
70 | # gaussians._rotation[:, 0] = 1
71 | # gaussians._rotation[:, 1:] = 0
72 | scales = gaussians.get_scaling
73 |
74 | # min_scale, _ = torch.min(scales, dim=1)
75 | # max_scale, _ = torch.max(scales, dim=1)
76 | # median_scale, _ = torch.median(scales, dim=1)
77 | # print(min_scale)
78 | # print(max_scale)
79 |
80 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
81 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
82 |
83 | if not skip_train:
84 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
85 |
86 | if not skip_test:
87 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)
88 |
89 | if __name__ == "__main__":
90 | # Set up command line argument parser
91 | parser = ArgumentParser(description="Testing script parameters")
92 | model = ModelParams(parser, sentinel=True)
93 | pipeline = PipelineParams(parser)
94 | parser.add_argument("--iteration", default=-1, type=int)
95 | parser.add_argument("--skip_train", action="store_true")
96 | parser.add_argument("--skip_test", action="store_true")
97 | parser.add_argument("--quiet", action="store_true")
98 | args = get_combined_args(parser)
99 | print("Rendering " + args.model_path)
100 |
101 | # Initialize system state (RNG)
102 | safe_state(args.quiet)
103 |
104 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)
--------------------------------------------------------------------------------
/results/DeepBlending/drjohnson.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,29.25110626,0.907489359,0.223799944,784796221.4,3164513
3 | ,29.27627754,0.906062722,0.227002591,749050265.6,3020344
4 |
--------------------------------------------------------------------------------
/results/DeepBlending/playroom.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,30.33006096,0.918174982,0.219789669,496238592,2000956
3 | ,30.40084267,0.918608427,0.222872108,466092032,1879412
4 |
--------------------------------------------------------------------------------
/results/Eth3D/delivery_area.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,17.53661728,0.779447019,0.347055882,884998144,3568545
3 | ,19.40202713,0.817175806,0.3138583,844963512.3,3407113
4 |
--------------------------------------------------------------------------------
/results/Eth3D/electro.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,15.60790539,0.703940392,0.398457259,835337584.6,3368290
3 | ,16.02513695,0.704342365,0.393627644,849703075.8,3426235
4 |
--------------------------------------------------------------------------------
/results/Eth3D/kicker.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,22.60378647,0.765782118,0.37088266,341647032.3,1377615
3 | ,22.70046806,0.761263132,0.372863382,344551587.8,1389313
4 |
--------------------------------------------------------------------------------
/results/Eth3D/meadow.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,14.37556458,0.396413594,0.495654911,800986234.9,3229780
3 | ,14.46836758,0.379970253,0.487953991,763195555.8,3077388
4 |
--------------------------------------------------------------------------------
/results/Eth3D/office.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,17.37435532,0.805293322,0.322832406,173151354.9,698192
3 | ,17.59784698,0.809933305,0.325274765,175080734.7,705952
4 |
--------------------------------------------------------------------------------
/results/Eth3D/playground.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,15.42733574,0.490227938,0.435589731,1236292076,4985058
3 | ,15.45620441,0.499574512,0.431878805,1214292951,4896345
4 |
--------------------------------------------------------------------------------
/results/Eth3D/relief.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,26.46071815,0.829667032,0.286081493,355635036.2,1434025
3 | ,26.71376228,0.833274424,0.281712294,361140060.2,1456216
4 |
--------------------------------------------------------------------------------
/results/Eth3D/relief2.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,27.4475174,0.876014948,0.25743258,309319434.2,1247246
3 | ,27.19592476,0.872983813,0.258035332,313052364.8,1262298
4 |
--------------------------------------------------------------------------------
/results/Eth3D/terrace.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,20.75988007,0.780296981,0.276560545,409951273,1653027
3 | ,20.40294075,0.776602566,0.272293329,407738777.6,1644092
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/bicycle.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,25.06400108,0.747040689,0.240248889,1556464271,6276054
3 | ,25.12158394,0.747914374,0.244340554,1408059310,5677644
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/bonsai.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,32.31475449,0.946509063,0.18023923,319847137.3,1289692
3 | ,32.47714996,0.95010221,0.16658856,329525493.8,1328746
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/counter.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,29.10391808,0.915232599,0.181373119,299924193.3,1209366
3 | ,29.15148544,0.917220533,0.174561828,292699504.6,1180219
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/flowers.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,21.44571877,0.5914886,0.351740956,943477227.5,3804355
3 | ,21.35391426,0.589819193,0.352616757,808902983.7,3261679
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/garden.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,27.23822975,0.855030119,0.121553876,1459837993,5886426
3 | ,27.12190247,0.854436398,0.119427539,1228364841,4953084
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/kitchen.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,31.22276497,0.931057751,0.117309771,429842759.7,1733247
3 | ,31.52673721,0.932370305,0.115075439,431960883.2,1741790
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/room.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,31.28110695,0.922523379,0.200879946,377676103.7,1522889
3 | ,31.69297218,0.924869835,0.197009131,329515008,1328679
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/stump.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,26.64076614,0.772574306,0.234375477,1183664046,4772826
3 | ,26.60910416,0.770324588,0.239168838,1112067277,4484153
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/treehill.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,22.56472206,0.634812593,0.339756131,1025916273,4136754
3 | ,22.11814117,0.629052699,0.346725404,814082949.1,3282588
4 |
--------------------------------------------------------------------------------
/results/TanksAndTemples/train.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,21.91499519,0.823102057,0.230635419,266768220,1075666
3 | ,21.49312401,0.815916955,0.239080429,259522560,1046465
4 |
--------------------------------------------------------------------------------
/results/TanksAndTemples/truck.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | Baseline,26.26991844,0.901777327,0.138400272,448087982.1,1806784
3 | ,25.92964554,0.897741735,0.132989123,439835689,1773540
4 |
--------------------------------------------------------------------------------
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import random
14 | import json
15 | from utils.system_utils import searchForMaxIteration
16 | from scene.dataset_readers import sceneLoadTypeCallbacks
17 | from scene.gaussian_model import GaussianModel
18 | from arguments import ModelParams
19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
20 |
21 | class Scene:
22 |
23 | gaussians : GaussianModel
24 |
25 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
26 |         """
27 | :param path: Path to colmap scene main folder.
28 | """
29 | self.model_path = args.model_path
30 | self.loaded_iter = None
31 | self.gaussians = gaussians
32 |
33 | if load_iteration:
34 | if load_iteration == -1:
35 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
36 | else:
37 | self.loaded_iter = load_iteration
38 | print("Loading trained model at iteration {}".format(self.loaded_iter))
39 |
40 | self.train_cameras = {}
41 | self.test_cameras = {}
42 |
43 | if os.path.exists(os.path.join(args.source_path, "sparse")):
44 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval,
45 | sky_seg=args.sky_seg, load_normal=args.load_normal, load_depth=args.load_depth)
46 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
47 | print("Found transforms_train.json file, assuming Blender data set!")
48 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
49 | else:
50 | assert False, "Could not recognize scene type!"
51 |
52 | if not self.loaded_iter:
53 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
54 | dest_file.write(src_file.read())
55 | json_cams = []
56 | camlist = []
57 | if scene_info.test_cameras:
58 | camlist.extend(scene_info.test_cameras)
59 | if scene_info.train_cameras:
60 | camlist.extend(scene_info.train_cameras)
61 | for id, cam in enumerate(camlist):
62 | json_cams.append(camera_to_JSON(id, cam))
63 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
64 | json.dump(json_cams, file)
65 |
66 | # if shuffle:
67 | # random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
68 | # random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
69 |
70 | self.cameras_extent = scene_info.nerf_normalization["radius"]
71 |
72 | for resolution_scale in resolution_scales:
73 | print("Loading Training Cameras")
74 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
75 | print("Loading Test Cameras")
76 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
77 |
78 | if self.loaded_iter:
79 | self.gaussians.load_ply(os.path.join(self.model_path,
80 | "point_cloud",
81 | "iteration_" + str(self.loaded_iter),
82 | "point_cloud.ply"))
83 | else:
84 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
85 |
86 | def save(self, iteration):
87 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
88 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
89 |
90 | def getTrainCameras(self, scale=1.0):
91 | return self.train_cameras[scale]
92 |
93 | def getTestCameras(self, scale=1.0):
94 | return self.test_cameras[scale]
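
# Illustrative usage sketch (comments only; `args` would come from
# arguments.ModelParams and the model directory is hypothetical):
#
#   from scene import Scene
#   from scene.gaussian_model import GaussianModel
#
#   gaussians = GaussianModel(sh_degree=3)
#   scene = Scene(args, gaussians, load_iteration=-1)  # -1 resumes from the latest saved iteration
#   for viewpoint in scene.getTrainCameras():
#       pass  # render and compute losses per camera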
--------------------------------------------------------------------------------
/scene/cameras.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from torch import nn
14 | import numpy as np
15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix
16 |
17 | class Camera(nn.Module):
18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
19 | image_name, uid,
20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", K=None,
21 | sky_mask=None, normal=None, depth=None
22 | ):
23 | super(Camera, self).__init__()
24 |
25 | self.uid = uid
26 | self.colmap_id = colmap_id
27 | self.R = R
28 | self.T = T
29 | self.FoVx = FoVx
30 | self.FoVy = FoVy
31 | self.image_name = image_name
32 | self.sky_mask = sky_mask
33 | self.normal = normal
34 | self.depth = depth
35 |
36 | try:
37 | self.data_device = torch.device(data_device)
38 | except Exception as e:
39 | print(e)
40 |             print(f"[Warning] Custom device {data_device} failed, falling back to default CUDA device")
41 | self.data_device = torch.device("cuda")
42 |
43 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
44 | self.image_width = self.original_image.shape[2]
45 | self.image_height = self.original_image.shape[1]
46 |
47 | if gt_alpha_mask is not None:
48 | self.original_image *= gt_alpha_mask.to(self.data_device)
49 | else:
50 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
51 |
52 | self.K = torch.tensor([[K[0], 0, K[2]],
53 | [0, K[1], K[3]],
54 | [0, 0, 1]]).to(self.data_device).to(torch.float32)
55 |
56 | self.zfar = 100.0
57 | self.znear = 0.01
58 |
59 | self.trans = trans
60 | self.scale = scale
61 |
62 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
63 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
64 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
65 | self.camera_center = self.world_view_transform.inverse()[3, :3]
66 |
67 | class MiniCam:
68 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
69 | self.image_width = width
70 | self.image_height = height
71 | self.FoVy = fovy
72 | self.FoVx = fovx
73 | self.znear = znear
74 | self.zfar = zfar
75 | self.world_view_transform = world_view_transform
76 | self.full_proj_transform = full_proj_transform
77 | view_inv = torch.inverse(self.world_view_transform)
78 | self.camera_center = view_inv[3][:3]
79 |
80 |
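# Note: world_view_transform and full_proj_transform are stored transposed to
# match the row-vector (glm) convention of the CUDA rasterizer, which is why
# the camera center is read from row 3 of the inverted view matrix.
#
# Illustrative sketch (hypothetical `cam`, an existing Camera): a lightweight
# MiniCam can mirror it for rendering-only use:
#
#   mini = MiniCam(cam.image_width, cam.image_height, cam.FoVy, cam.FoVx,
#                  cam.znear, cam.zfar, cam.world_view_transform,
#                  cam.full_proj_transform)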
--------------------------------------------------------------------------------
/scene/colmap_loader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import numpy as np
13 | import collections
14 | import struct
15 |
16 | CameraModel = collections.namedtuple(
17 | "CameraModel", ["model_id", "model_name", "num_params"])
18 | Camera = collections.namedtuple(
19 | "Camera", ["id", "model", "width", "height", "params"])
20 | BaseImage = collections.namedtuple(
21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
22 | Point3D = collections.namedtuple(
23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
24 | CAMERA_MODELS = {
25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
32 | CameraModel(model_id=7, model_name="FOV", num_params=5),
33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
36 | }
37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
38 | for camera_model in CAMERA_MODELS])
39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
40 | for camera_model in CAMERA_MODELS])
41 |
42 |
43 | def qvec2rotmat(qvec):
44 | return np.array([
45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
54 |
55 | def rotmat2qvec(R):
56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
57 | K = np.array([
58 | [Rxx - Ryy - Rzz, 0, 0, 0],
59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
62 | eigvals, eigvecs = np.linalg.eigh(K)
63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
64 | if qvec[0] < 0:
65 | qvec *= -1
66 | return qvec
67 |
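# Sanity-check sketch (illustrative, not used by the loader): qvec2rotmat and
# rotmat2qvec are mutual inverses up to the sign normalization above, so a
# round trip reproduces the quaternion. Hypothetical 90-degree rotation about
# +Z, in COLMAP's (w, x, y, z) order:
#
#   q = np.array([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)])
#   assert np.allclose(rotmat2qvec(qvec2rotmat(q)), q, atol=1e-6)
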
68 | class Image(BaseImage):
69 | def qvec2rotmat(self):
70 | return qvec2rotmat(self.qvec)
71 |
72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
73 | """Read and unpack the next bytes from a binary file.
74 | :param fid:
75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
77 | :param endian_character: Any of {@, =, <, >, !}
78 | :return: Tuple of read and unpacked values.
79 | """
80 | data = fid.read(num_bytes)
81 | return struct.unpack(endian_character + format_char_sequence, data)
82 |
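# Example (illustrative): the binary readers below chain this helper, e.g. a
# uint64 element count followed by a 3-vector of doubles would be read as
#
#   num = read_next_bytes(fid, 8, "Q")[0]
#   xyz = read_next_bytes(fid, num_bytes=24, format_char_sequence="ddd")
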
83 | def read_points3D_text(path):
84 | """
85 | see: src/base/reconstruction.cc
86 | void Reconstruction::ReadPoints3DText(const std::string& path)
87 | void Reconstruction::WritePoints3DText(const std::string& path)
88 | """
89 | xyzs = None
90 | rgbs = None
91 | errors = None
92 | num_points = 0
93 | with open(path, "r") as fid:
94 | while True:
95 | line = fid.readline()
96 | if not line:
97 | break
98 | line = line.strip()
99 | if len(line) > 0 and line[0] != "#":
100 | num_points += 1
101 |
102 |
103 | xyzs = np.empty((num_points, 3))
104 | rgbs = np.empty((num_points, 3))
105 | errors = np.empty((num_points, 1))
106 | count = 0
107 | with open(path, "r") as fid:
108 | while True:
109 | line = fid.readline()
110 | if not line:
111 | break
112 | line = line.strip()
113 | if len(line) > 0 and line[0] != "#":
114 | elems = line.split()
115 | xyz = np.array(tuple(map(float, elems[1:4])))
116 | rgb = np.array(tuple(map(int, elems[4:7])))
117 | error = np.array(float(elems[7]))
118 | xyzs[count] = xyz
119 | rgbs[count] = rgb
120 | errors[count] = error
121 | count += 1
122 |
123 | return xyzs, rgbs, errors
124 |
125 | def read_points3D_binary(path_to_model_file):
126 | """
127 | see: src/base/reconstruction.cc
128 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
129 | void Reconstruction::WritePoints3DBinary(const std::string& path)
130 | """
131 |
132 |
133 | with open(path_to_model_file, "rb") as fid:
134 | num_points = read_next_bytes(fid, 8, "Q")[0]
135 |
136 | xyzs = np.empty((num_points, 3))
137 | rgbs = np.empty((num_points, 3))
138 | errors = np.empty((num_points, 1))
139 |
140 | for p_id in range(num_points):
141 | binary_point_line_properties = read_next_bytes(
142 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
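            # 43 bytes = 8 (Q: point3D id) + 3*8 (ddd: xyz) + 3*1 (BBB: rgb) + 8 (d: reprojection error)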
143 | xyz = np.array(binary_point_line_properties[1:4])
144 | rgb = np.array(binary_point_line_properties[4:7])
145 | error = np.array(binary_point_line_properties[7])
146 | track_length = read_next_bytes(
147 | fid, num_bytes=8, format_char_sequence="Q")[0]
148 | track_elems = read_next_bytes(
149 | fid, num_bytes=8*track_length,
150 | format_char_sequence="ii"*track_length)
151 | xyzs[p_id] = xyz
152 | rgbs[p_id] = rgb
153 | errors[p_id] = error
154 | return xyzs, rgbs, errors
155 |
156 | def read_intrinsics_text(path):
157 | """
158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
159 | """
160 | cameras = {}
161 | with open(path, "r") as fid:
162 | while True:
163 | line = fid.readline()
164 | if not line:
165 | break
166 | line = line.strip()
167 | if len(line) > 0 and line[0] != "#":
168 | elems = line.split()
169 | camera_id = int(elems[0])
170 | model = elems[1]
171 |                 assert model == "PINHOLE", "While the loader supports other camera models, the rest of the code assumes PINHOLE"
172 | width = int(elems[2])
173 | height = int(elems[3])
174 | params = np.array(tuple(map(float, elems[4:])))
175 | cameras[camera_id] = Camera(id=camera_id, model=model,
176 | width=width, height=height,
177 | params=params)
178 | return cameras
179 |
180 | def read_extrinsics_binary(path_to_model_file):
181 | """
182 | see: src/base/reconstruction.cc
183 | void Reconstruction::ReadImagesBinary(const std::string& path)
184 | void Reconstruction::WriteImagesBinary(const std::string& path)
185 | """
186 | images = {}
187 | with open(path_to_model_file, "rb") as fid:
188 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
189 | for _ in range(num_reg_images):
190 | binary_image_properties = read_next_bytes(
191 | fid, num_bytes=64, format_char_sequence="idddddddi")
192 | image_id = binary_image_properties[0]
193 | qvec = np.array(binary_image_properties[1:5])
194 | tvec = np.array(binary_image_properties[5:8])
195 | camera_id = binary_image_properties[8]
196 | image_name = ""
197 | current_char = read_next_bytes(fid, 1, "c")[0]
198 | while current_char != b"\x00": # look for the ASCII 0 entry
199 | image_name += current_char.decode("utf-8")
200 | current_char = read_next_bytes(fid, 1, "c")[0]
201 | num_points2D = read_next_bytes(fid, num_bytes=8,
202 | format_char_sequence="Q")[0]
203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
204 | format_char_sequence="ddq"*num_points2D)
205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
206 | tuple(map(float, x_y_id_s[1::3]))])
207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
208 | images[image_id] = Image(
209 | id=image_id, qvec=qvec, tvec=tvec,
210 | camera_id=camera_id, name=image_name,
211 | xys=xys, point3D_ids=point3D_ids)
212 | return images
213 |
214 |
215 | def read_intrinsics_binary(path_to_model_file):
216 | """
217 | see: src/base/reconstruction.cc
218 | void Reconstruction::WriteCamerasBinary(const std::string& path)
219 | void Reconstruction::ReadCamerasBinary(const std::string& path)
220 | """
221 | cameras = {}
222 | with open(path_to_model_file, "rb") as fid:
223 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
224 | for _ in range(num_cameras):
225 | camera_properties = read_next_bytes(
226 | fid, num_bytes=24, format_char_sequence="iiQQ")
227 | camera_id = camera_properties[0]
228 | model_id = camera_properties[1]
229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
230 | width = camera_properties[2]
231 | height = camera_properties[3]
232 | num_params = CAMERA_MODEL_IDS[model_id].num_params
233 | params = read_next_bytes(fid, num_bytes=8*num_params,
234 | format_char_sequence="d"*num_params)
235 | cameras[camera_id] = Camera(id=camera_id,
236 | model=model_name,
237 | width=width,
238 | height=height,
239 | params=np.array(params))
240 | assert len(cameras) == num_cameras
241 | return cameras
242 |
243 |
244 | def read_extrinsics_text(path):
245 | """
246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
247 | """
248 | images = {}
249 | with open(path, "r") as fid:
250 | while True:
251 | line = fid.readline()
252 | if not line:
253 | break
254 | line = line.strip()
255 | if len(line) > 0 and line[0] != "#":
256 | elems = line.split()
257 | image_id = int(elems[0])
258 | qvec = np.array(tuple(map(float, elems[1:5])))
259 | tvec = np.array(tuple(map(float, elems[5:8])))
260 | camera_id = int(elems[8])
261 | image_name = elems[9]
262 | elems = fid.readline().split()
263 | xys = np.column_stack([tuple(map(float, elems[0::3])),
264 | tuple(map(float, elems[1::3]))])
265 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
266 | images[image_id] = Image(
267 | id=image_id, qvec=qvec, tvec=tvec,
268 | camera_id=camera_id, name=image_name,
269 | xys=xys, point3D_ids=point3D_ids)
270 | return images
271 |
272 |
273 | def read_colmap_bin_array(path):
274 | """
275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
276 |
277 | :param path: path to the colmap binary file.
278 |     :return: nd array with the floating point values in the file
279 | """
280 | with open(path, "rb") as fid:
281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
282 | usecols=(0, 1, 2), dtype=int)
283 | fid.seek(0)
284 | num_delimiter = 0
285 | byte = fid.read(1)
286 | while True:
287 | if byte == b"&":
288 | num_delimiter += 1
289 | if num_delimiter >= 3:
290 | break
291 | byte = fid.read(1)
292 | array = np.fromfile(fid, np.float32)
293 | array = array.reshape((width, height, channels), order="F")
294 | return np.transpose(array, (1, 0, 2)).squeeze()
295 |
--------------------------------------------------------------------------------
/scene/dataset_readers.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import sys
14 | from PIL import Image
15 | from typing import NamedTuple
16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
19 | import numpy as np
20 | import json
21 | from pathlib import Path
22 | from plyfile import PlyData, PlyElement
23 | from utils.sh_utils import SH2RGB
24 | from scene.gaussian_model import BasicPointCloud
25 | import open3d as o3d
26 |
27 | class CameraInfo(NamedTuple):
28 | uid: int
29 | R: np.array
30 | T: np.array
31 | FovY: np.array
32 | FovX: np.array
33 | image: np.array
34 | image_path: str
35 | image_name: str
36 | width: int
37 | height: int
38 | K: np.array
39 | sky_mask: np.array
40 | normal: np.array
41 | depth: np.array
42 |
43 | class SceneInfo(NamedTuple):
44 | point_cloud: BasicPointCloud
45 | train_cameras: list
46 | test_cameras: list
47 | nerf_normalization: dict
48 | ply_path: str
49 |
50 | def getNerfppNorm(cam_info):
51 | def get_center_and_diag(cam_centers):
52 | cam_centers = np.hstack(cam_centers)
53 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
54 | center = avg_cam_center
55 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
56 | diagonal = np.max(dist)
57 | return center.flatten(), diagonal
58 |
59 | cam_centers = []
60 |
61 | for cam in cam_info:
62 | W2C = getWorld2View2(cam.R, cam.T)
63 | C2W = np.linalg.inv(W2C)
64 | cam_centers.append(C2W[:3, 3:4])
65 |
66 | center, diagonal = get_center_and_diag(cam_centers)
67 | radius = diagonal * 1.1
68 |
69 | translate = -center
70 |
71 | return {"translate": translate, "radius": radius}
72 |
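# Worked example (illustrative): two cameras with centers at x = -1 and x = +1
# give an average center of (0, 0, 0) and a maximum distance of 1, hence a
# radius of 1.1. This radius later serves as the scene extent for densification
# and for scaling the position learning rate.
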
73 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder, sky_seg=False, load_normal=False, load_depth=False):
74 | cam_infos = []
75 | for idx, key in enumerate(cam_extrinsics):
76 | sys.stdout.write('\r')
77 |         # overwrite the previous line to show reading progress
78 | sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
79 | sys.stdout.flush()
80 |
81 | extr = cam_extrinsics[key]
82 | intr = cam_intrinsics[extr.camera_id]
83 |
84 | height = intr.height
85 | width = intr.width
86 |
87 | uid = intr.id
88 | R = np.transpose(qvec2rotmat(extr.qvec))
89 | T = np.array(extr.tvec)
90 |
91 | if intr.model=="SIMPLE_PINHOLE":
92 | focal_length_x = intr.params[0]
93 | FovY = focal2fov(focal_length_x, height)
94 | FovX = focal2fov(focal_length_x, width)
95 | elif intr.model=="PINHOLE":
96 | focal_length_x = intr.params[0]
97 | focal_length_y = intr.params[1]
98 | FovY = focal2fov(focal_length_y, height)
99 | FovX = focal2fov(focal_length_x, width)
100 | else:
101 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
102 |
103 | image_path = os.path.join(images_folder, os.path.basename(extr.name))
104 | image_name = os.path.basename(image_path).split(".")[0]
105 |
106 | image = Image.open(image_path)
107 |
108 | # #sky mask
109 | if sky_seg:
110 | sky_path = image_path.replace("images", "mask")[:-4]+".npy"
111 | sky_mask = np.load(sky_path).astype(np.uint8)
112 | else:
113 | sky_mask = None
114 |
115 | if load_normal:
116 | normal_path = image_path.replace("images", "normals")[:-4]+".npy"
117 | normal = np.load(normal_path).astype(np.float32)
118 | normal = (normal - 0.5) * 2.0
119 | else:
120 | normal = None
121 |
122 | if load_depth:
123 | # depth_path = image_path.replace("images", "monodepth")[:-4]+".npy"
124 | depth_path = image_path.replace("images", "metricdepth")[:-4]+".npy"
125 | depth = np.load(depth_path).astype(np.float32)
126 | else:
127 | depth = None
128 |
129 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
130 | image_path=image_path, image_name=image_name, width=width, height=height,
131 | K=intr.params, sky_mask=sky_mask, normal=normal, depth=depth)
132 | cam_infos.append(cam_info)
133 | sys.stdout.write('\n')
134 | return cam_infos
135 |
136 | def fetchPly(path):
137 | plydata = PlyData.read(path)
138 | vertices = plydata['vertex']
139 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
140 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
141 | # normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
142 | normals = np.zeros_like(positions)
143 | return BasicPointCloud(points=positions, colors=colors, normals=normals)
144 |
145 | def storePly(path, xyz, rgb):
146 | # Define the dtype for the structured array
147 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
148 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
149 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
150 |
151 | normals = np.zeros_like(xyz)
152 |
153 | elements = np.empty(xyz.shape[0], dtype=dtype)
154 | attributes = np.concatenate((xyz, normals, rgb), axis=1)
155 | elements[:] = list(map(tuple, attributes))
156 |
157 | # Create the PlyData object and write to file
158 | vertex_element = PlyElement.describe(elements, 'vertex')
159 | ply_data = PlyData([vertex_element])
160 | ply_data.write(path)
161 |
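# Round-trip sketch (illustrative; the path is hypothetical). storePly writes
# zeroed normals and fetchPly likewise discards any stored normals:
#
#   storePly("/tmp/points.ply", xyz, rgb)   # xyz float (N, 3), rgb in [0, 255] (N, 3)
#   pcd = fetchPly("/tmp/points.ply")       # pcd.colors rescaled back to [0, 1]
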
162 | def readColmapSceneInfo(path, images, eval, llffhold=8, sky_seg=False, load_normal=False, load_depth=False):
163 | try:
164 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
165 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
166 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
167 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
168 | except:
169 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
170 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
171 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
172 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
173 |
174 |     reading_dir = "images" if images is None else images
175 |
176 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir),
177 | sky_seg=sky_seg, load_normal=load_normal, load_depth=load_depth)
178 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
179 |
180 | if eval:
181 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
182 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
183 | if 'waymo' in path:
184 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != (llffhold-1)]
185 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == (llffhold-1)]
186 | # train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % (llffhold * 3) >= 3]
187 | # test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % (llffhold * 3) < 3]
188 | else:
189 | train_cam_infos = cam_infos
190 | test_cam_infos = []
191 |
192 | nerf_normalization = getNerfppNorm(train_cam_infos)
193 |
194 | ply_path = os.path.join(path, "sparse/0/points3D.ply")
195 |
196 | bin_path = os.path.join(path, "sparse/0/points3D.bin")
197 | txt_path = os.path.join(path, "sparse/0/points3D.txt")
198 | if not os.path.exists(ply_path):
199 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
200 | try:
201 | xyz, rgb, _ = read_points3D_binary(bin_path)
202 | except:
203 | xyz, rgb, _ = read_points3D_text(txt_path)
204 | storePly(ply_path, xyz, rgb)
205 | try:
206 | pcd = fetchPly(ply_path)
207 | except:
208 | pcd = None
209 |
210 | scene_info = SceneInfo(point_cloud=pcd,
211 | train_cameras=train_cam_infos,
212 | test_cameras=test_cam_infos,
213 | nerf_normalization=nerf_normalization,
214 | ply_path=ply_path)
215 | return scene_info
216 |
217 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png", is_train=True):
218 | cam_infos = []
219 |
220 | with open(os.path.join(path, transformsfile)) as json_file:
221 | contents = json.load(json_file)
222 | fovx = contents["camera_angle_x"]
223 |
224 | frames = contents["frames"]
225 | for idx, frame in enumerate(frames):
226 | cam_name = os.path.join(path, frame["file_path"] + extension)
227 |
228 | # NeRF 'transform_matrix' is a camera-to-world transform
229 | c2w = np.array(frame["transform_matrix"])
230 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
231 | c2w[:3, 1:3] *= -1
232 |
233 | # get the world-to-camera transform and set R, T
234 | w2c = np.linalg.inv(c2w)
235 | R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code
236 | T = w2c[:3, 3]
237 |
238 | image_path = os.path.join(path, cam_name)
239 | image_name = Path(cam_name).stem
240 | image = Image.open(image_path)
241 |
242 | im_data = np.array(image.convert("RGBA"))
243 |
244 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
245 |
246 | norm_data = im_data / 255.0
247 | arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
248 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
249 |
250 | sky_mask = np.ones_like(image)[:, :, 0].astype(np.uint8)
251 |
252 | if is_train:
253 | normal_path = image_path.replace("train", "normals")[:-4]+".npy"
254 | normal = np.load(normal_path).astype(np.float32)
255 | normal = (normal - 0.5) * 2.0
256 | # normal[2, :, :] *= -1
257 | else:
258 | normal = np.zeros_like(image).transpose(2, 0, 1)
259 |
260 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
261 | FovY = fovy
262 | FovX = fovx
263 |
264 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
265 | image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1],
266 | K=np.array([1, 2, 3, 4]), sky_mask=sky_mask, normal=normal))
267 |
268 | return cam_infos
269 |
270 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
271 | print("Reading Training Transforms")
272 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
273 | print("Reading Test Transforms")
274 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension, is_train=False)
275 |
276 | if not eval:
277 | train_cam_infos.extend(test_cam_infos)
278 | test_cam_infos = []
279 |
280 | nerf_normalization = getNerfppNorm(train_cam_infos)
281 |
282 | ply_path = os.path.join(path, "points3d.ply")
283 | if not os.path.exists(ply_path):
284 | # Since this data set has no colmap data, we start with random points
285 | num_pts = 100_000
286 | print(f"Generating random point cloud ({num_pts})...")
287 |
288 | # We create random points inside the bounds of the synthetic Blender scenes
289 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
290 | shs = np.random.random((num_pts, 3)) / 255.0
291 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
292 |
293 | storePly(ply_path, xyz, SH2RGB(shs) * 255)
294 | try:
295 | pcd = fetchPly(ply_path)
296 | except:
297 | pcd = None
298 |
299 | scene_info = SceneInfo(point_cloud=pcd,
300 | train_cameras=train_cam_infos,
301 | test_cameras=test_cam_infos,
302 | nerf_normalization=nerf_normalization,
303 | ply_path=ply_path)
304 | return scene_info
305 |
306 | sceneLoadTypeCallbacks = {
307 | "Colmap": readColmapSceneInfo,
308 | "Blender" : readNerfSyntheticInfo
309 | }
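
# Usage sketch (illustrative; the scene path is hypothetical):
#
#   info = sceneLoadTypeCallbacks["Colmap"]("/data/scene", None, True)  # images=None, eval=True
#   print(len(info.train_cameras), len(info.test_cameras), info.ply_path)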
--------------------------------------------------------------------------------
/scene/gaussian_model.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import numpy as np
14 | from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
15 | from torch import nn
16 | import os
17 | from utils.system_utils import mkdir_p
18 | from plyfile import PlyData, PlyElement
19 | from utils.sh_utils import RGB2SH
20 | from simple_knn._C import distCUDA2
21 | from utils.graphics_utils import BasicPointCloud, getWorld2View2
22 | from utils.general_utils import strip_symmetric, build_scaling_rotation
23 |
24 | class GaussianModel:
25 |
26 | def setup_functions(self):
27 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
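            # With S = diag(scaling) and R the rotation from the quaternion,
            # build_scaling_rotation returns L = R S, so the covariance
            # Sigma = L L^T = R S S^T R^T is symmetric positive semi-definite
            # by construction; strip_symmetric keeps only its unique entries.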
28 | L = build_scaling_rotation(scaling_modifier * scaling, rotation)
29 | actual_covariance = L @ L.transpose(1, 2)
30 | symm = strip_symmetric(actual_covariance)
31 | return symm
32 |
33 | self.scaling_activation = torch.exp
34 | self.scaling_inverse_activation = torch.log
35 |
36 | self.covariance_activation = build_covariance_from_scaling_rotation
37 |
38 | self.opacity_activation = torch.sigmoid
39 | self.inverse_opacity_activation = inverse_sigmoid
40 |
41 | self.rotation_activation = torch.nn.functional.normalize
42 |
43 |
44 | def __init__(self, sh_degree : int):
45 | self.active_sh_degree = 0
46 | self.max_sh_degree = sh_degree
47 | self._xyz = torch.empty(0)
48 | self._features_dc = torch.empty(0)
49 | self._features_rest = torch.empty(0)
50 | self._scaling = torch.empty(0)
51 | self._rotation = torch.empty(0)
52 | self._opacity = torch.empty(0)
53 | self.max_radii2D = torch.empty(0)
54 | self.xyz_gradient_accum = torch.empty(0)
55 | self.denom = torch.empty(0)
56 | self.optimizer = None
57 | self.percent_dense = 0
58 | self.spatial_lr_scale = 0
59 | self.setup_functions()
60 |
61 | def capture(self):
62 | return (
63 | self.active_sh_degree,
64 | self._xyz,
65 | self._features_dc,
66 | self._features_rest,
67 | self._scaling,
68 | self._rotation,
69 | self._opacity,
70 | self.max_radii2D,
71 | self.xyz_gradient_accum,
72 | self.denom,
73 | self.optimizer.state_dict(),
74 | self.spatial_lr_scale,
75 | )
76 |
77 | def restore(self, model_args, training_args):
78 | (self.active_sh_degree,
79 | self._xyz,
80 | self._features_dc,
81 | self._features_rest,
82 | self._scaling,
83 | self._rotation,
84 | self._opacity,
85 | self.max_radii2D,
86 | xyz_gradient_accum,
87 | denom,
88 | opt_dict,
89 | self.spatial_lr_scale) = model_args
90 | self.training_setup(training_args)
91 | self.xyz_gradient_accum = xyz_gradient_accum
92 | self.denom = denom
93 | self.optimizer.load_state_dict(opt_dict)
94 |
95 | @property
96 | def get_scaling(self):
97 | return self.scaling_activation(self._scaling)
98 |
99 | @property
100 | def get_rotation(self):
101 | return self.rotation_activation(self._rotation)
102 |
103 | @property
104 | def get_xyz(self):
105 | return self._xyz
106 |
107 | @property
108 | def get_features(self):
109 | features_dc = self._features_dc
110 | features_rest = self._features_rest
111 | return torch.cat((features_dc, features_rest), dim=1)
112 |
113 | @property
114 | def get_opacity(self):
115 | return self.opacity_activation(self._opacity)
116 |
117 | def get_covariance(self, scaling_modifier = 1):
118 | return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
119 |
120 | def oneupSHdegree(self):
121 | if self.active_sh_degree < self.max_sh_degree:
122 | self.active_sh_degree += 1
123 |
124 | def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
125 | self.spatial_lr_scale = spatial_lr_scale
126 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
127 | fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
128 | features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
129 | features[:, :3, 0 ] = fused_color
130 | features[:, 3:, 1:] = 0.0
131 |
132 | print("Number of points at initialisation : ", fused_point_cloud.shape[0])
133 |
134 | dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
135 | scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
136 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
137 | rots[:, 0] = 1
138 |
139 | opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
140 |
141 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
142 | self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
143 | self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
144 | self._scaling = nn.Parameter(scales.requires_grad_(True))
145 | self._rotation = nn.Parameter(rots.requires_grad_(True))
146 | self._opacity = nn.Parameter(opacities.requires_grad_(True))
147 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
148 |
149 | def training_setup(self, training_args):
150 | self.percent_dense = training_args.percent_dense
151 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
152 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
153 |
154 | l = [
155 | {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
156 | {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
157 | {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
158 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
159 | {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
160 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}
161 | ]
162 |
163 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
164 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
165 | lr_final=training_args.position_lr_final*self.spatial_lr_scale,
166 | lr_delay_mult=training_args.position_lr_delay_mult,
167 | max_steps=training_args.position_lr_max_steps)
168 |
169 | def update_learning_rate(self, iteration):
170 | ''' Learning rate scheduling per step '''
171 | for param_group in self.optimizer.param_groups:
172 | if param_group["name"] == "xyz":
173 | lr = self.xyz_scheduler_args(iteration)
174 | param_group['lr'] = lr
175 | return lr
176 |
177 | def construct_list_of_attributes(self):
178 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
179 | # All channels except the 3 DC
180 | for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
181 | l.append('f_dc_{}'.format(i))
182 | for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
183 | l.append('f_rest_{}'.format(i))
184 | l.append('opacity')
185 | for i in range(self._scaling.shape[1]):
186 | l.append('scale_{}'.format(i))
187 | for i in range(self._rotation.shape[1]):
188 | l.append('rot_{}'.format(i))
189 | return l
190 |
191 | def save_ply(self, path):
192 | mkdir_p(os.path.dirname(path))
193 |
194 | xyz = self._xyz.detach().cpu().numpy()
195 | normals = np.zeros_like(xyz)
196 | f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
197 | f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
198 | opacities = self._opacity.detach().cpu().numpy()
199 | scale = self._scaling.detach().cpu().numpy()
200 | rotation = self._rotation.detach().cpu().numpy()
201 |
202 | dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
203 |
204 | elements = np.empty(xyz.shape[0], dtype=dtype_full)
205 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
206 | elements[:] = list(map(tuple, attributes))
207 | el = PlyElement.describe(elements, 'vertex')
208 | PlyData([el]).write(path)
209 |
210 | def reset_opacity(self):
211 | opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
212 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
213 | self._opacity = optimizable_tensors["opacity"]
214 |
215 | def load_ply(self, path):
216 | plydata = PlyData.read(path)
217 |
218 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
219 | np.asarray(plydata.elements[0]["y"]),
220 | np.asarray(plydata.elements[0]["z"])), axis=1)
221 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
222 |
223 | features_dc = np.zeros((xyz.shape[0], 3, 1))
224 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
225 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
226 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
227 |
228 | extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
229 | extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
230 | assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
231 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
232 | for idx, attr_name in enumerate(extra_f_names):
233 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
234 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
235 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
236 |
237 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
238 | scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
239 | scales = np.zeros((xyz.shape[0], len(scale_names)))
240 | for idx, attr_name in enumerate(scale_names):
241 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
242 |
243 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
244 | rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
245 | rots = np.zeros((xyz.shape[0], len(rot_names)))
246 | for idx, attr_name in enumerate(rot_names):
247 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
248 |
249 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
250 | self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
251 | self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
252 | self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
253 | self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
254 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
255 |
256 | self.active_sh_degree = self.max_sh_degree
257 |
258 | def replace_tensor_to_optimizer(self, tensor, name):
259 | optimizable_tensors = {}
260 | for group in self.optimizer.param_groups:
261 | if group["name"] == name:
262 | stored_state = self.optimizer.state.get(group['params'][0], None)
263 | stored_state["exp_avg"] = torch.zeros_like(tensor)
264 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
265 |
266 | del self.optimizer.state[group['params'][0]]
267 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
268 | self.optimizer.state[group['params'][0]] = stored_state
269 |
270 | optimizable_tensors[group["name"]] = group["params"][0]
271 | return optimizable_tensors
272 |
273 | def _prune_optimizer(self, mask):
274 | optimizable_tensors = {}
275 | for group in self.optimizer.param_groups:
276 | stored_state = self.optimizer.state.get(group['params'][0], None)
277 | if stored_state is not None:
278 | stored_state["exp_avg"] = stored_state["exp_avg"][mask]
279 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
280 |
281 | del self.optimizer.state[group['params'][0]]
282 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
283 | self.optimizer.state[group['params'][0]] = stored_state
284 |
285 | optimizable_tensors[group["name"]] = group["params"][0]
286 | else:
287 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
288 | optimizable_tensors[group["name"]] = group["params"][0]
289 | return optimizable_tensors
290 |
291 | def prune_points(self, mask):
292 | valid_points_mask = ~mask
293 | optimizable_tensors = self._prune_optimizer(valid_points_mask)
294 |
295 | self._xyz = optimizable_tensors["xyz"]
296 | self._features_dc = optimizable_tensors["f_dc"]
297 | self._features_rest = optimizable_tensors["f_rest"]
298 | self._opacity = optimizable_tensors["opacity"]
299 | self._scaling = optimizable_tensors["scaling"]
300 | self._rotation = optimizable_tensors["rotation"]
301 |
302 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
303 |
304 | self.denom = self.denom[valid_points_mask]
305 | self.max_radii2D = self.max_radii2D[valid_points_mask]
306 |
307 | def cat_tensors_to_optimizer(self, tensors_dict):
308 | optimizable_tensors = {}
309 | for group in self.optimizer.param_groups:
310 | assert len(group["params"]) == 1
311 | extension_tensor = tensors_dict[group["name"]]
312 | stored_state = self.optimizer.state.get(group['params'][0], None)
313 | if stored_state is not None:
314 |
315 | stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
316 | stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
317 |
318 | del self.optimizer.state[group['params'][0]]
319 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
320 | self.optimizer.state[group['params'][0]] = stored_state
321 |
322 | optimizable_tensors[group["name"]] = group["params"][0]
323 | else:
324 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
325 | optimizable_tensors[group["name"]] = group["params"][0]
326 |
327 | return optimizable_tensors
328 |
329 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
330 | d = {"xyz": new_xyz,
331 | "f_dc": new_features_dc,
332 | "f_rest": new_features_rest,
333 | "opacity": new_opacities,
334 | "scaling" : new_scaling,
335 | "rotation" : new_rotation}
336 |
337 | optimizable_tensors = self.cat_tensors_to_optimizer(d)
338 | self._xyz = optimizable_tensors["xyz"]
339 | self._features_dc = optimizable_tensors["f_dc"]
340 | self._features_rest = optimizable_tensors["f_rest"]
341 | self._opacity = optimizable_tensors["opacity"]
342 | self._scaling = optimizable_tensors["scaling"]
343 | self._rotation = optimizable_tensors["rotation"]
344 |
345 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
346 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
347 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
348 |
349 | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
350 | n_init_points = self.get_xyz.shape[0]
351 | # Extract points that satisfy the gradient condition
352 | padded_grad = torch.zeros((n_init_points), device="cuda")
353 | padded_grad[:grads.shape[0]] = grads.squeeze()
354 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
355 | selected_pts_mask = torch.logical_and(selected_pts_mask,
356 | torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
357 |
358 | stds = self.get_scaling[selected_pts_mask].repeat(N,1)
359 |         means = torch.zeros((stds.size(0), 3), device="cuda")
360 | samples = torch.normal(mean=means, std=stds)
361 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
362 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
363 | new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
364 | new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
365 | new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
366 | new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
367 | new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
368 |
369 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
370 |
371 | prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
372 | self.prune_points(prune_filter)
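        # Net effect: each selected gaussian is replaced by N = 2 children whose
        # positions are sampled from the parent's own anisotropic distribution,
        # with scales shrunk by a factor of 0.8 * N; the parents were pruned
        # above, so the point count grows by N - 1 per split.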
373 |
374 | def densify_and_clone(self, grads, grad_threshold, scene_extent):
375 | # Extract points that satisfy the gradient condition
376 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
377 | selected_pts_mask = torch.logical_and(selected_pts_mask,
378 | torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
379 |
380 | new_xyz = self._xyz[selected_pts_mask]
381 | new_features_dc = self._features_dc[selected_pts_mask]
382 | new_features_rest = self._features_rest[selected_pts_mask]
383 | new_opacities = self._opacity[selected_pts_mask]
384 | new_scaling = self._scaling[selected_pts_mask]
385 | new_rotation = self._rotation[selected_pts_mask]
386 |
387 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
388 |
389 | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
390 | grads = self.xyz_gradient_accum / self.denom
391 | grads[grads.isnan()] = 0.0
392 |
393 | self.densify_and_clone(grads, max_grad, extent)
394 | self.densify_and_split(grads, max_grad, extent)
395 |
396 | prune_mask = (self.get_opacity < min_opacity).squeeze()
397 | if max_screen_size:
398 | big_points_vs = self.max_radii2D > max_screen_size
399 | big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
400 | prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
401 | self.prune_points(prune_mask)
402 |
403 | torch.cuda.empty_cache()
404 |
405 | def add_densification_stats(self, viewspace_point_tensor, update_filter):
406 | self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
407 | self.denom[update_filter] += 1
408 |
409 | def densify_from_depth_propagation(self, viewpoint_cam, propagated_depth, filter_mask, gt_image):
410 | # inverse project pixels into 3D scenes
411 | K = viewpoint_cam.K
412 | cam2world = viewpoint_cam.world_view_transform.transpose(0, 1).inverse()
413 |
414 | # Get the shape of the depth image
415 | height, width = propagated_depth.shape
416 | # Create a grid of 2D pixel coordinates
417 | y, x = torch.meshgrid(torch.arange(0, height), torch.arange(0, width))
418 | # Stack the 2D and depth coordinates to create 3D homogeneous coordinates
419 | coordinates = torch.stack([x.to(propagated_depth.device), y.to(propagated_depth.device), torch.ones_like(propagated_depth)], dim=-1)
420 | # Reshape the coordinates to (height * width, 3)
421 | coordinates = coordinates.view(-1, 3).to(K.device).to(torch.float32)
422 | # Reproject the 2D coordinates to 3D coordinates
423 | coordinates_3D = (K.inverse() @ coordinates.T).T
424 |
425 | # Multiply by depth
426 | coordinates_3D *= propagated_depth.view(-1, 1)
427 |
428 | # convert to the world coordinate
429 | world_coordinates_3D = (cam2world[:3, :3] @ coordinates_3D.T).T + cam2world[:3, 3]
430 |
431 | # import open3d as o3d
432 | # point_cloud = o3d.geometry.PointCloud()
433 | # point_cloud.points = o3d.utility.Vector3dVector(world_coordinates_3D.detach().cpu().numpy())
434 | # o3d.io.write_point_cloud("partpc.ply", point_cloud)
435 | # exit()
436 |
437 |         # mask out points below the confidence threshold
438 |         # downsample the pixels by 8x along each axis (keeping 1/64 of them)
439 | world_coordinates_3D = world_coordinates_3D.view(height, width, 3)
440 | world_coordinates_3D_downsampled = world_coordinates_3D[::8, ::8]
441 | filter_mask_downsampled = filter_mask[::8, ::8]
442 | gt_image_downsampled = gt_image.permute(1, 2, 0)[::8, ::8]
443 |
444 | world_coordinates_3D_downsampled = world_coordinates_3D_downsampled[filter_mask_downsampled]
445 | color_downsampled = gt_image_downsampled[filter_mask_downsampled]
446 |
447 | # initialize gaussians
448 | fused_point_cloud = world_coordinates_3D_downsampled
449 | fused_color = RGB2SH(color_downsampled)
450 | features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).to(fused_color.device)
451 | features[:, :3, 0 ] = fused_color
452 | features[:, 3:, 1:] = 0.0
453 |
454 | original_point_cloud = self.get_xyz
455 |         # initialize scales from nearest-neighbor distances over the union of new and existing points; distances over the new points alone yield outliers, while copying scale statistics from all gaussians is memory-consuming
456 | # quantile_scale = torch.quantile(self.get_scaling, 0.5, dim=0)
457 | # scales = self.scaling_inverse_activation(quantile_scale.unsqueeze(0).repeat(fused_point_cloud.shape[0], 1))
458 | fused_shape = fused_point_cloud.shape[0]
459 | all_point_cloud = torch.concat([fused_point_cloud, original_point_cloud], dim=0)
460 | all_dist2 = torch.clamp_min(distCUDA2(all_point_cloud), 0.0000001)
461 | dist2 = all_dist2[:fused_shape]
462 | scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
463 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
464 | rots[:, 0] = 1
465 |
466 | opacities = inverse_sigmoid(1.0 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
467 |
468 | new_xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
469 | new_features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
470 | new_features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
471 | new_scaling = nn.Parameter(scales.requires_grad_(True))
472 | new_rotation = nn.Parameter(rots.requires_grad_(True))
473 | new_opacity = nn.Parameter(opacities.requires_grad_(True))
474 |
475 | #update gaussians
476 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
--------------------------------------------------------------------------------
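Aside: the core of densify_from_depth_propagation above is a standard depth unprojection (pixel grid -> K^-1 rays -> scale by depth -> rigid transform into world space). A minimal self-contained sketch of that step, with a hypothetical 3x3 intrinsic K and 4x4 cam2world standing in for the attributes that viewpoint_cam provides:

    import torch

    def unproject_depth(depth, K, cam2world):
        # Lift an (H, W) depth map to world-space points, mirroring the
        # inverse projection performed in densify_from_depth_propagation.
        H, W = depth.shape
        y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
        # Homogeneous pixel coordinates (u, v, 1), flattened to (H*W, 3).
        pix = torch.stack([x, y, torch.ones_like(x)], dim=-1).reshape(-1, 3).float()
        # Back-project to camera-space rays, then scale by per-pixel depth.
        cam_pts = (K.inverse() @ pix.T).T * depth.reshape(-1, 1)
        # Rigid transform into world coordinates.
        world = (cam2world[:3, :3] @ cam_pts.T).T + cam2world[:3, 3]
        return world.reshape(H, W, 3)

    # Toy check: constant depth 2.0 under an identity pose gives z == depth.
    K = torch.tensor([[100.0, 0.0, 32.0], [0.0, 100.0, 24.0], [0.0, 0.0, 1.0]])
    pts = unproject_depth(torch.full((48, 64), 2.0), K, torch.eye(4))
    assert torch.allclose(pts[..., 2], torch.tensor(2.0))
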
/scripts/demo.sh:
--------------------------------------------------------------------------------
1 | python train.py -s $path/to/data$ -m $save_path$ \
2 | --eval --position_lr_init 0.000016 --scaling_lr 0.001 --percent_dense 0.0005 --port 1021
3 |
4 | python render.py -m $save_path$
5 | python metrics.py -m $save_path$
6 |
7 | python train.py -s $path/to/data$ -m $save_path$ \
8 | --eval --flatten_loss --position_lr_init 0.000016 --scaling_lr 0.001 --percent_dense 0.0005 --port 1021 \
9 | --normal_loss --depth_loss --propagation_interval 50 --depth_error_min_threshold 0.8 --depth_error_max_threshold 1.0 \
10 | --propagated_iteration_begin 1000 --propagated_iteration_after 6000 --patch_size 20 --lambda_l1_normal 0.001 --lambda_cos_normal 0.001
11 |
12 | python render.py -m $save_path$
13 | python metrics.py -m $save_path$
14 |
15 | # normal_loss -- whether to use the planar-constrained (normal) loss
16 | # depth_loss -- whether to use depth propagation
17 | # propagation_interval -- how often (in iterations) propagation is activated
18 | # depth_error_min_threshold -- final threshold on the relative error between rendered and propagated depth for initializing new gaussians
19 | # depth_error_max_threshold -- initial threshold on the relative error between rendered and propagated depth for initializing new gaussians
20 | # patch_size -- patch size for PatchMatch; increase it if your scenes consist of many large textureless planes, decrease it otherwise
21 | # lambda_l1_normal / lambda_cos_normal -- weights of the normal-loss terms
22 |
--------------------------------------------------------------------------------
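Note on the two depth-error thresholds documented above: the threshold is annealed from depth_error_max_threshold (initial) down to depth_error_min_threshold (final) over the propagation window. The exact schedule lives in train.py, outside this excerpt; a plausible linear decay, purely as an illustration:

    def depth_error_threshold(iteration, begin=1000, after=6000,
                              max_thr=1.0, min_thr=0.8):
        # Hypothetical linear decay from the initial (max) threshold to the
        # final (min) threshold between --propagated_iteration_begin and
        # --propagated_iteration_after; the real schedule is defined in train.py.
        if iteration <= begin:
            return max_thr
        if iteration >= after:
            return min_thr
        t = (iteration - begin) / (after - begin)
        return max_thr + t * (min_thr - max_thr)
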
/scripts/waymo.sh:
--------------------------------------------------------------------------------
1 | python train.py -s $path/to/data$ -m $save_path$ \
2 | --eval --position_lr_init 0.000016 --scaling_lr 0.001 --percent_dense 0.0005 --port 1021 --dataset waymo
3 |
4 | python render.py -m $save_path$
5 | python metrics.py -m $save_path$
6 |
7 | python train.py -s $path/to/data$ -m $save_path$ \
8 | --eval --flatten_loss --position_lr_init 0.000016 --scaling_lr 0.001 --percent_dense 0.0005 --port 1021 --dataset waymo \
9 | --sky_seg --normal_loss --depth_loss --propagation_interval 30 --depth_error_min_threshold 0.8 --depth_error_max_threshold 1.0 \
10 | --propagated_iteration_begin 1000 --propagated_iteration_after 12000 --patch_size 20 --lambda_l1_normal 0.001 --lambda_cos_normal 0.001
11 |
12 | python render.py -m $save_path$
13 | python metrics.py -m $save_path$
14 |
--------------------------------------------------------------------------------
/submodules/Propagation/PatchMatch.cpp:
--------------------------------------------------------------------------------
1 | #include "PatchMatch.h"
2 | #include <cstdarg>
3 | #include <cstdio>
4 |
5 | #include <memory>
6 |
7 | void StringAppendV(std::string* dst, const char* format, va_list ap) {
8 | // First try with a small fixed size buffer.
9 | static const int kFixedBufferSize = 1024;
10 | char fixed_buffer[kFixedBufferSize];
11 |
12 | // It is possible for methods that use a va_list to invalidate
13 | // the data in it upon use. The fix is to make a copy
14 | // of the structure before using it and use that copy instead.
15 | va_list backup_ap;
16 | va_copy(backup_ap, ap);
17 | int result = vsnprintf(fixed_buffer, kFixedBufferSize, format, backup_ap);
18 | va_end(backup_ap);
19 |
20 | if (result < kFixedBufferSize) {
21 | if (result >= 0) {
22 | // Normal case - everything fits.
23 | dst->append(fixed_buffer, result);
24 | return;
25 | }
26 |
27 | #ifdef _MSC_VER
28 | // Error or MSVC running out of space. MSVC 8.0 and higher
29 | // can be asked about space needed with the special idiom below:
30 | va_copy(backup_ap, ap);
31 | result = vsnprintf(nullptr, 0, format, backup_ap);
32 | va_end(backup_ap);
33 | #endif
34 |
35 | if (result < 0) {
36 | // Just an error.
37 | return;
38 | }
39 | }
40 |
41 | // Increase the buffer size to the size requested by vsnprintf,
42 | // plus one for the closing \0.
43 | const int variable_buffer_size = result + 1;
44 |   std::unique_ptr<char[]> variable_buffer(new char[variable_buffer_size]);
45 |
46 | // Restore the va_list before we use it again.
47 | va_copy(backup_ap, ap);
48 | result =
49 | vsnprintf(variable_buffer.get(), variable_buffer_size, format, backup_ap);
50 | va_end(backup_ap);
51 |
52 | if (result >= 0 && result < variable_buffer_size) {
53 | dst->append(variable_buffer.get(), result);
54 | }
55 | }
56 |
57 | std::string StringPrintf(const char* format, ...) {
58 | va_list ap;
59 | va_start(ap, format);
60 | std::string result;
61 | StringAppendV(&result, format, ap);
62 | va_end(ap);
63 | return result;
64 | }
65 |
66 | void CudaSafeCall(const cudaError_t error, const std::string& file,
67 | const int line) {
68 | if (error != cudaSuccess) {
69 | std::cerr << StringPrintf("%s in %s at line %i", cudaGetErrorString(error),
70 | file.c_str(), line)
71 | << std::endl;
72 | exit(EXIT_FAILURE);
73 | }
74 | }
75 |
76 | void CudaCheckError(const char* file, const int line) {
77 | cudaError error = cudaGetLastError();
78 | if (error != cudaSuccess) {
79 | std::cerr << StringPrintf("cudaCheckError() failed at %s:%i : %s", file,
80 | line, cudaGetErrorString(error))
81 | << std::endl;
82 | exit(EXIT_FAILURE);
83 | }
84 |
85 | // More careful checking. However, this will affect performance.
86 | // Comment away if needed.
87 | error = cudaDeviceSynchronize();
88 | if (cudaSuccess != error) {
89 | std::cerr << StringPrintf("cudaCheckError() with sync failed at %s:%i : %s",
90 | file, line, cudaGetErrorString(error))
91 | << std::endl;
92 | std::cerr
93 | << "This error is likely caused by the graphics card timeout "
94 | "detection mechanism of your operating system. Please refer to "
95 | "the FAQ in the documentation on how to solve this problem."
96 | << std::endl;
97 | exit(EXIT_FAILURE);
98 | }
99 | }
100 |
101 | PatchMatch::PatchMatch() {}
102 |
103 | PatchMatch::~PatchMatch()
104 | {
105 | delete[] plane_hypotheses_host;
106 | delete[] costs_host;
107 |
108 | for (int i = 0; i < num_images; ++i) {
109 | cudaDestroyTextureObject(texture_objects_host.images[i]);
110 | cudaFreeArray(cuArray[i]);
111 | }
112 | cudaFree(texture_objects_cuda);
113 | cudaFree(cameras_cuda);
114 | cudaFree(plane_hypotheses_cuda);
115 | cudaFree(costs_cuda);
116 | cudaFree(rand_states_cuda);
117 | cudaFree(selected_views_cuda);
118 | cudaFree(depths_cuda);
119 |
120 | if (params.geom_consistency) {
121 | for (int i = 0; i < num_images; ++i) {
122 | cudaDestroyTextureObject(texture_depths_host.images[i]);
123 | cudaFreeArray(cuDepthArray[i]);
124 | }
125 | cudaFree(texture_depths_cuda);
126 | }
127 | }
128 |
129 | Camera ReadCamera(torch::Tensor intrinsic, torch::Tensor pose, torch::Tensor depth_interval)
130 | {
131 | Camera camera;
132 |
133 | for (int i = 0; i < 3; ++i) {
134 |         camera.R[3 * i + 0] = pose[i][0].item<float>();
135 |         camera.R[3 * i + 1] = pose[i][1].item<float>();
136 |         camera.R[3 * i + 2] = pose[i][2].item<float>();
137 |         camera.t[i] = pose[i][3].item<float>();
138 | }
139 |
140 | for (int i = 0; i < 3; ++i) {
141 |         camera.K[3 * i + 0] = intrinsic[i][0].item<float>();
142 |         camera.K[3 * i + 1] = intrinsic[i][1].item<float>();
143 |         camera.K[3 * i + 2] = intrinsic[i][2].item<float>();
144 | }
145 |
146 |     camera.depth_min = depth_interval[0].item<float>();
147 |     camera.depth_max = depth_interval[3].item<float>();
148 |
149 | return camera;
150 | }
151 |
152 | void RescaleImageAndCamera(torch::Tensor &src, torch::Tensor &dst, torch::Tensor &depth, Camera &camera)
153 | {
154 | const int cols = depth.size(1);
155 | const int rows = depth.size(0);
156 |
157 | if (cols == src.size(1) && rows == src.size(0)) {
158 | dst = src.clone();
159 | return;
160 | }
161 |
162 |     const float scale_x = cols / static_cast<float>(src.size(1));
163 |     const float scale_y = rows / static_cast<float>(src.size(0));
164 |     dst = torch::nn::functional::interpolate(src.unsqueeze(0), torch::nn::functional::InterpolateFuncOptions().size(std::vector<int64_t>({rows, cols})).mode(torch::kBilinear)).squeeze(0);
165 |
166 | camera.K[0] *= scale_x;
167 | camera.K[2] *= scale_x;
168 | camera.K[4] *= scale_y;
169 | camera.K[5] *= scale_y;
170 | camera.width = cols;
171 | camera.height = rows;
172 | }
173 |
174 | float3 Get3DPointonWorld(const int x, const int y, const float depth, const Camera camera)
175 | {
176 | float3 pointX;
177 | float3 tmpX;
178 | // Reprojection
179 | pointX.x = depth * (x - camera.K[2]) / camera.K[0];
180 | pointX.y = depth * (y - camera.K[5]) / camera.K[4];
181 | pointX.z = depth;
182 |
183 | // Rotation
184 | tmpX.x = camera.R[0] * pointX.x + camera.R[3] * pointX.y + camera.R[6] * pointX.z;
185 | tmpX.y = camera.R[1] * pointX.x + camera.R[4] * pointX.y + camera.R[7] * pointX.z;
186 | tmpX.z = camera.R[2] * pointX.x + camera.R[5] * pointX.y + camera.R[8] * pointX.z;
187 |
188 | // Transformation
189 | float3 C;
190 | C.x = -(camera.R[0] * camera.t[0] + camera.R[3] * camera.t[1] + camera.R[6] * camera.t[2]);
191 | C.y = -(camera.R[1] * camera.t[0] + camera.R[4] * camera.t[1] + camera.R[7] * camera.t[2]);
192 | C.z = -(camera.R[2] * camera.t[0] + camera.R[5] * camera.t[1] + camera.R[8] * camera.t[2]);
193 | pointX.x = tmpX.x + C.x;
194 | pointX.y = tmpX.y + C.y;
195 | pointX.z = tmpX.z + C.z;
196 |
197 | return pointX;
198 | }
199 |
200 | void ProjectonCamera(const float3 PointX, const Camera camera, float2 &point, float &depth)
201 | {
202 | float3 tmp;
203 | tmp.x = camera.R[0] * PointX.x + camera.R[1] * PointX.y + camera.R[2] * PointX.z + camera.t[0];
204 | tmp.y = camera.R[3] * PointX.x + camera.R[4] * PointX.y + camera.R[5] * PointX.z + camera.t[1];
205 | tmp.z = camera.R[6] * PointX.x + camera.R[7] * PointX.y + camera.R[8] * PointX.z + camera.t[2];
206 |
207 | depth = camera.K[6] * tmp.x + camera.K[7] * tmp.y + camera.K[8] * tmp.z;
208 | point.x = (camera.K[0] * tmp.x + camera.K[1] * tmp.y + camera.K[2] * tmp.z) / depth;
209 | point.y = (camera.K[3] * tmp.x + camera.K[4] * tmp.y + camera.K[5] * tmp.z) / depth;
210 | }
211 |
212 | float GetAngle(const torch::Tensor &v1, const torch::Tensor &v2)
213 | {
214 |     float dot_product = v1[0].item<float>() * v2[0].item<float>() + v1[1].item<float>() * v2[1].item<float>() + v1[2].item<float>() * v2[2].item<float>();
215 | float angle = acosf(dot_product);
216 | //if angle is not a number the dot product was 1 and thus the two vectors should be identical --> return 0
217 | if ( angle != angle )
218 | return 0.0f;
219 |
220 | return angle;
221 | }
222 |
223 | void StoreColorPlyFileBinaryPointCloud (const std::string &plyFilePath, const std::vector<PointList> &pc)
224 | {
225 | std::cout << "store 3D points to ply file" << std::endl;
226 |
227 | FILE *outputPly;
228 | outputPly=fopen(plyFilePath.c_str(), "wb");
229 |
230 | /*write header*/
231 | fprintf(outputPly, "ply\n");
232 | fprintf(outputPly, "format binary_little_endian 1.0\n");
233 |     fprintf(outputPly, "element vertex %zu\n", pc.size());
234 | fprintf(outputPly, "property float x\n");
235 | fprintf(outputPly, "property float y\n");
236 | fprintf(outputPly, "property float z\n");
237 | fprintf(outputPly, "property float nx\n");
238 | fprintf(outputPly, "property float ny\n");
239 | fprintf(outputPly, "property float nz\n");
240 | fprintf(outputPly, "property uchar red\n");
241 | fprintf(outputPly, "property uchar green\n");
242 | fprintf(outputPly, "property uchar blue\n");
243 | fprintf(outputPly, "end_header\n");
244 |
245 | //write data
246 | #pragma omp parallel for
247 | for(size_t i = 0; i < pc.size(); i++) {
248 | const PointList &p = pc[i];
249 | float3 X = p.coord;
250 | const float3 normal = p.normal;
251 | const float3 color = p.color;
252 | const char b_color = (int)color.x;
253 | const char g_color = (int)color.y;
254 | const char r_color = (int)color.z;
255 |
256 | if(!(X.x < FLT_MAX && X.x > -FLT_MAX) || !(X.y < FLT_MAX && X.y > -FLT_MAX) || !(X.z < FLT_MAX && X.z >= -FLT_MAX)){
257 | X.x = 0.0f;
258 | X.y = 0.0f;
259 | X.z = 0.0f;
260 | }
261 | #pragma omp critical
262 | {
263 | fwrite(&X.x, sizeof(X.x), 1, outputPly);
264 | fwrite(&X.y, sizeof(X.y), 1, outputPly);
265 | fwrite(&X.z, sizeof(X.z), 1, outputPly);
266 | fwrite(&normal.x, sizeof(normal.x), 1, outputPly);
267 | fwrite(&normal.y, sizeof(normal.y), 1, outputPly);
268 | fwrite(&normal.z, sizeof(normal.z), 1, outputPly);
269 | fwrite(&r_color, sizeof(char), 1, outputPly);
270 | fwrite(&g_color, sizeof(char), 1, outputPly);
271 | fwrite(&b_color, sizeof(char), 1, outputPly);
272 | }
273 |
274 | }
275 | fclose(outputPly);
276 | }
277 |
278 | static float GetDisparity(const Camera &camera, const int2 &p, const float &depth)
279 | {
280 | float point3D[3];
281 | point3D[0] = depth * (p.x - camera.K[2]) / camera.K[0];
282 | point3D[1] = depth * (p.y - camera.K[5]) / camera.K[4];
283 | point3D[2] = depth;
284 |
285 | return std::sqrt(point3D[0] * point3D[0] + point3D[1] * point3D[1] + point3D[2] * point3D[2]);
286 | }
287 |
288 | void PatchMatch::SetGeomConsistencyParams()
289 | {
290 | params.geom_consistency = true;
291 | params.max_iterations = 2;
292 | }
293 |
294 | void PatchMatch::InuputInitialization(torch::Tensor images_cuda, torch::Tensor intrinsics_cuda, torch::Tensor poses_cuda,
295 | torch::Tensor depth_cuda, torch::Tensor normal_cuda, torch::Tensor depth_intervals)
296 | {
297 | images.clear();
298 | cameras.clear();
299 |
300 | torch::Tensor image_color = images_cuda[0];
301 | torch::Tensor image_float = torch::mean(image_color, /*dim=*/2, /*keepdim=*/true).squeeze();
302 | image_float = image_float.to(torch::kFloat32);
303 | images.push_back(image_float);
304 |
305 | Camera camera = ReadCamera(intrinsics_cuda[0], poses_cuda[0], depth_intervals[0]);
306 | camera.height = image_float.size(0);
307 | camera.width = image_float.size(1);
308 | cameras.push_back(camera);
309 |
310 | torch::Tensor ref_depth = depth_cuda;
311 | depths.push_back(ref_depth);
312 |
313 | int num_src_images = images_cuda.size(0);
314 | for (int i = 1; i < num_src_images; ++i) {
315 | torch::Tensor src_image_color = images_cuda[i];
316 | torch::Tensor src_image_float = torch::mean(src_image_color, /*dim=*/2, /*keepdim=*/true).squeeze();
317 | src_image_float = src_image_float.to(torch::kFloat32);
318 | images.push_back(src_image_float);
319 |
320 | Camera camera = ReadCamera(intrinsics_cuda[i], poses_cuda[i], depth_intervals[i]);
321 | camera.height = src_image_float.size(0);
322 | camera.width = src_image_float.size(1);
323 | cameras.push_back(camera);
324 | }
325 |
326 | // Scale cameras and images
327 | for (size_t i = 0; i < images.size(); ++i) {
328 | if (images[i].size(1) <= params.max_image_size && images[i].size(0) <= params.max_image_size) {
329 | continue;
330 | }
331 |
332 |         const float factor_x = static_cast<float>(params.max_image_size) / images[i].size(1);
333 |         const float factor_y = static_cast<float>(params.max_image_size) / images[i].size(0);
334 | const float factor = std::min(factor_x, factor_y);
335 |
336 | const int new_cols = std::round(images[i].size(1) * factor);
337 | const int new_rows = std::round(images[i].size(0) * factor);
338 |
339 |         const float scale_x = new_cols / static_cast<float>(images[i].size(1));
340 |         const float scale_y = new_rows / static_cast<float>(images[i].size(0));
341 |
342 |         torch::Tensor scaled_image_float = torch::nn::functional::interpolate(images[i].unsqueeze(0), torch::nn::functional::InterpolateFuncOptions().size(std::vector<int64_t>({new_rows, new_cols})).mode(torch::kBilinear)).squeeze(0);
343 | images[i] = scaled_image_float.clone();
344 |
345 | cameras[i].K[0] *= scale_x;
346 | cameras[i].K[2] *= scale_x;
347 | cameras[i].K[4] *= scale_y;
348 | cameras[i].K[5] *= scale_y;
349 | cameras[i].height = scaled_image_float.size(0);
350 | cameras[i].width = scaled_image_float.size(1);
351 | }
352 |
353 | params.depth_min = cameras[0].depth_min * 0.6f;
354 | params.depth_max = cameras[0].depth_max * 1.2f;
355 | params.num_images = (int)images.size();
356 | params.disparity_min = cameras[0].K[0] * params.baseline / params.depth_max;
357 | params.disparity_max = cameras[0].K[0] * params.baseline / params.depth_min;
358 |
359 | }
360 |
361 | void PatchMatch::CudaSpaceInitialization()
362 | {
363 | num_images = (int)images.size();
364 |
365 | for (int i = 0; i < num_images; ++i) {
366 | int rows = images[i].size(0);
367 | int cols = images[i].size(1);
368 |
369 | cudaChannelFormatDesc channelDesc = cudaCreateChannelDesc(32, 0, 0, 0, cudaChannelFormatKindFloat);
370 | cudaMallocArray(&cuArray[i], &channelDesc, cols, rows);
371 |
372 | cudaMemcpy2DToArray(cuArray[i], 0, 0, images[i].data_ptr(), images[i].stride(0) * sizeof(float), cols * sizeof(float), rows, cudaMemcpyHostToDevice);
373 |
374 | struct cudaResourceDesc resDesc;
375 | memset(&resDesc, 0, sizeof(cudaResourceDesc));
376 | resDesc.resType = cudaResourceTypeArray;
377 | resDesc.res.array.array = cuArray[i];
378 |
379 | struct cudaTextureDesc texDesc;
380 | memset(&texDesc, 0, sizeof(cudaTextureDesc));
381 | texDesc.addressMode[0] = cudaAddressModeWrap;
382 | texDesc.addressMode[1] = cudaAddressModeWrap;
383 | texDesc.filterMode = cudaFilterModeLinear;
384 | texDesc.readMode = cudaReadModeElementType;
385 | texDesc.normalizedCoords = 0;
386 |
387 | cudaCreateTextureObject(&(texture_objects_host.images[i]), &resDesc, &texDesc, NULL);
388 | }
389 |
390 | cudaMalloc((void**)&texture_objects_cuda, sizeof(cudaTextureObjects));
391 | cudaMemcpy(texture_objects_cuda, &texture_objects_host, sizeof(cudaTextureObjects), cudaMemcpyHostToDevice);
392 |
393 | cudaMalloc((void**)&cameras_cuda, sizeof(Camera) * (num_images));
394 | cudaMemcpy(cameras_cuda, &cameras[0], sizeof(Camera) * (num_images), cudaMemcpyHostToDevice);
395 |
396 | int total_pixels = cameras[0].height * cameras[0].width;
397 | plane_hypotheses_host = new float4[total_pixels];
398 | cudaMalloc((void**)&plane_hypotheses_cuda, sizeof(float4) * total_pixels);
399 |
400 | costs_host = new float[cameras[0].height * cameras[0].width];
401 | cudaMalloc((void**)&costs_cuda, sizeof(float) * (cameras[0].height * cameras[0].width));
402 |
403 | cudaMalloc((void**)&rand_states_cuda, sizeof(curandState) * (cameras[0].height * cameras[0].width));
404 | cudaMalloc((void**)&selected_views_cuda, sizeof(unsigned int) * (cameras[0].height * cameras[0].width));
405 |
406 | cudaMalloc((void**)&depths_cuda, sizeof(float) * (cameras[0].height * cameras[0].width));
407 | cudaMemcpy(depths_cuda, depths[0].data_ptr(), sizeof(float) * cameras[0].height * cameras[0].width, cudaMemcpyHostToDevice);
408 | }
409 |
410 | int PatchMatch::GetReferenceImageWidth()
411 | {
412 | return cameras[0].width;
413 | }
414 |
415 | int PatchMatch::GetReferenceImageHeight()
416 | {
417 | return cameras[0].height;
418 | }
419 |
420 | torch::Tensor PatchMatch::GetReferenceImage()
421 | {
422 | return images[0];
423 | }
424 |
425 | float4 PatchMatch::GetPlaneHypothesis(const int index)
426 | {
427 | return plane_hypotheses_host[index];
428 | }
429 |
430 | float4* PatchMatch::GetPlaneHypotheses()
431 | {
432 | return plane_hypotheses_host;
433 | }
434 |
435 | float PatchMatch::GetCost(const int index)
436 | {
437 | return costs_host[index];
438 | }
439 |
440 | void PatchMatch::SetPatchSize(int patch_size)
441 | {
442 | params.patch_size = patch_size;
443 | }
444 |
445 | int PatchMatch::GetPatchSize()
446 | {
447 | return params.patch_size;
448 | }
449 |
450 |
451 |
--------------------------------------------------------------------------------
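Get3DPointonWorld and ProjectonCamera above are mutual inverses: the former back-projects through K and computes X_world = R^T (X_cam - t); the latter re-applies X_cam = R X_world + t followed by the perspective divide. A small numpy sketch (with made-up K, R, t) that verifies the round trip:

    import numpy as np

    # Hypothetical world-to-camera extrinsics (R, t) and pinhole intrinsics K,
    # laid out like the row-major Camera struct above.
    a = 0.3
    R = np.array([[np.cos(a), -np.sin(a), 0.0],
                  [np.sin(a),  np.cos(a), 0.0],
                  [0.0,        0.0,       1.0]])
    t = np.array([0.1, -0.2, 0.5])
    K = np.array([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])

    def get_3d_point_on_world(x, y, depth):
        # Mirrors Get3DPointonWorld: back-project, then X_world = R^T (X_cam - t).
        x_cam = depth * np.array([(x - K[0, 2]) / K[0, 0], (y - K[1, 2]) / K[1, 1], 1.0])
        return R.T @ (x_cam - t)

    def project_on_camera(x_world):
        # Mirrors ProjectonCamera: X_cam = R X_world + t, then perspective divide.
        uvw = K @ (R @ x_world + t)
        return uvw[:2] / uvw[2], uvw[2]

    X = get_3d_point_on_world(100.0, 80.0, depth=2.5)
    (u, v), d = project_on_camera(X)
    assert np.allclose([u, v, d], [100.0, 80.0, 2.5])
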
/submodules/Propagation/PatchMatch.h:
--------------------------------------------------------------------------------
1 | #ifndef _PatchMatch_H_
2 | #define _PatchMatch_H_
3 |
4 | #include "main.h"
5 | #include <torch/extension.h>
6 |
7 | Camera ReadCamera(torch::Tensor intrinsic, torch::Tensor pose, torch::Tensor depth_interval);
8 | void RescaleImageAndCamera(torch::Tensor &src, torch::Tensor &dst, torch::Tensor &depth, Camera &camera);
9 | float3 Get3DPointonWorld(const int x, const int y, const float depth, const Camera camera);
10 | void ProjectonCamera(const float3 PointX, const Camera camera, float2 &point, float &depth);
11 | float GetAngle(const torch::Tensor &v1, const torch::Tensor &v2);
12 | void StoreColorPlyFileBinaryPointCloud(const std::string &plyFilePath, const std::vector<PointList> &pc);
13 |
14 | #define CUDA_SAFE_CALL(error) CudaSafeCall(error, __FILE__, __LINE__)
15 | #define CUDA_CHECK_ERROR() CudaCheckError(__FILE__, __LINE__)
16 |
17 | void CudaSafeCall(const cudaError_t error, const std::string& file, const int line);
18 | void CudaCheckError(const char* file, const int line);
19 |
20 | struct cudaTextureObjects {
21 | cudaTextureObject_t images[MAX_IMAGES];
22 | };
23 |
24 | struct PatchMatchParams {
25 | int max_iterations = 6;
26 | int patch_size = 11;
27 | int num_images = 5;
28 | int max_image_size=3200;
29 | int radius_increment = 2;
30 | float sigma_spatial = 5.0f;
31 | float sigma_color = 3.0f;
32 | int top_k = 4;
33 | float baseline = 0.54f;
34 | float depth_min = 0.0f;
35 | float depth_max = 1.0f;
36 | float disparity_min = 0.0f;
37 | float disparity_max = 1.0f;
38 | bool geom_consistency = false;
39 | };
40 |
41 | class PatchMatch {
42 | public:
43 | PatchMatch();
44 | ~PatchMatch();
45 |
46 | void InuputInitialization(torch::Tensor images_cuda, torch::Tensor intrinsics_cuda, torch::Tensor poses_cuda, torch::Tensor depth_cuda, torch::Tensor normal_cuda, torch::Tensor depth_intervals);
47 |     void Colmap2MVS(const std::string &dense_folder, std::vector<Problem> &problems);
48 | void CudaSpaceInitialization();
49 | void RunPatchMatch();
50 | void SetGeomConsistencyParams();
51 | void SetPatchSize(int patch_size);
52 | int GetPatchSize();
53 | int GetReferenceImageWidth();
54 | int GetReferenceImageHeight();
55 | torch::Tensor GetReferenceImage();
56 | float4 GetPlaneHypothesis(const int index);
57 | float GetCost(const int index);
58 | float4* GetPlaneHypotheses();
59 |
60 | private:
61 | int num_images;
62 |     std::vector<torch::Tensor> images;
63 |     std::vector<torch::Tensor> depths;
64 |     std::vector<Camera> cameras;
65 | cudaTextureObjects texture_objects_host;
66 | cudaTextureObjects texture_depths_host;
67 | float4 *plane_hypotheses_host;
68 | float *costs_host;
69 | PatchMatchParams params;
70 |
71 | Camera *cameras_cuda;
72 | cudaArray *cuArray[MAX_IMAGES];
73 | cudaArray *cuDepthArray[MAX_IMAGES];
74 | cudaTextureObjects *texture_objects_cuda;
75 | cudaTextureObjects *texture_depths_cuda;
76 | float4 *plane_hypotheses_cuda;
77 | float *costs_cuda;
78 | curandState *rand_states_cuda;
79 | unsigned int *selected_views_cuda;
80 | float *depths_cuda;
81 | };
82 |
83 | #endif // _PatchMatch_H_
84 |
--------------------------------------------------------------------------------
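The Propagation module is compiled as a PyTorch C++/CUDA extension (see submodules/Propagation/setup.py; the Python-facing bindings live in pro.cpp, outside this excerpt). A JIT-style build sketch using torch.utils.cpp_extension; the module name below is an assumption for illustration, not the name registered by setup.py:

    from torch.utils.cpp_extension import load

    # Hypothetical JIT build of the Propagation sources in this submodule;
    # the package built by setup.py may expose a different module name.
    propagation = load(
        name="propagation_ext",  # assumed name, for illustration only
        sources=[
            "submodules/Propagation/pro.cpp",
            "submodules/Propagation/PatchMatch.cpp",
            "submodules/Propagation/Propagation.cu",
        ],
        verbose=True,
    )
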
/submodules/Propagation/main.h:
--------------------------------------------------------------------------------
1 | #ifndef _MAIN_H_
2 | #define _MAIN_H_
3 |
4 | // Includes CUDA
5 | #include <cuda_runtime.h>
6 | #include <cuda.h>
7 | #include <cuda_texture_types.h>
8 | #include <curand_kernel.h>
9 | #include <vector_types.h>
10 | #include <device_launch_parameters.h>
11 |
12 | #include <cstdio>
13 | #include <cstring>
14 | #include <cfloat>
15 | #include <cmath>
16 | #include <iostream>
17 | #include <string>
18 | #include <vector>