├── .gitignore
├── .gitmodules
├── Dockerfile
├── LICENSE.md
├── README.md
├── arguments
│   └── __init__.py
├── convert.py
├── distill_train.py
├── environment.yml
├── full_eval.py
├── gaussian_renderer
│   ├── __init__.py
│   ├── gaussian_count.py
│   └── network_gui.py
├── lpipsPyTorch
│   ├── __init__.py
│   └── modules
│       ├── lpips.py
│       ├── networks.py
│       └── utils.py
├── metrics.py
├── prune.py
├── prune_finetune.py
├── render.py
├── render_video.py
├── scene
│   ├── __init__.py
│   ├── cameras.py
│   ├── colmap_loader.py
│   ├── dataset_readers.py
│   └── gaussian_model.py
├── scripts
│   ├── run_distill_finetune.sh
│   ├── run_prune_finetune.sh
│   ├── run_prune_pt_finetune.sh
│   ├── run_train_densify_prune.sh
│   └── run_vectree_quantize.sh
├── static
│   ├── prune_ratio_vs_ssim.svg
│   └── table5.png
├── submodules
│   └── simple-knn
│       ├── ext.cpp
│       ├── setup.py
│       ├── simple_knn.cu
│       ├── simple_knn.egg-info
│       │   ├── PKG-INFO
│       │   ├── SOURCES.txt
│       │   ├── dependency_links.txt
│       │   └── top_level.txt
│       ├── simple_knn.h
│       ├── simple_knn
│       │   └── .gitkeep
│       ├── spatial.cu
│       └── spatial.h
├── train_densify_prune.py
├── utils
│   ├── camera_utils.py
│   ├── general_utils.py
│   ├── graphics_utils.py
│   ├── image.py
│   ├── image_utils.py
│   ├── logger_utils.py
│   ├── loss_utils.py
│   ├── pose_utils.py
│   ├── save_imp_score.py
│   ├── sh_utils.py
│   ├── system_utils.py
│   ├── tracker_utils.py
│   └── vgg.py
└── vectree
    ├── utils.py
    ├── vectree.py
    └── vq.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .vscode
3 | output
4 | build
5 | diff_rasterization/diff_rast.egg-info
6 | diff_rasterization/dist
7 | tensorboard_3d
8 | screenshots
9 | logs_train
10 | vectree/pruned_distilled
11 | vectree/output
12 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "submodules/simple-knn"]
2 | path = submodules/simple-knn
3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git
4 |
5 |
6 | [submodule "submodules/compress-diff-gaussian-rasterization"]
7 | path = submodules/compress-diff-gaussian-rasterization
8 | url = https://github.com/Kevin-2017/compress-diff-gaussian-rasterization.git
9 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:22.04-py3
2 | # The repository (environment.yml and the submodules it installs via pip) must be inside the image
3 | COPY . /workspace/LightGaussian
4 | WORKDIR /workspace/LightGaussian
5 | RUN conda env create --file environment.yml
6 | RUN bash -c "conda init bash"
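A typical way to build and enter the image once the submodules have been checked out (a sketch; the tag and mount path are placeholders, and `--gpus all` requires the NVIDIA Container Toolkit):

```shell
docker build -t lightgaussian .
docker run --gpus all -it --rm -v /path/to/datasets:/data lightgaussian bash
```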
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Gaussian-Splatting License
2 | ===========================
3 |
4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
5 | The *Software* is in the process of being registered with the Agence pour la Protection des
6 | Programmes (APP).
7 |
8 | The *Software* is still being developed by the *Licensor*.
9 |
10 | *Licensor*'s goal is to allow the research community to use, test and evaluate
11 | the *Software*.
12 |
13 | ## 1. Definitions
14 |
15 | *Licensee* means any person or entity that uses the *Software* and distributes
16 | its *Work*.
17 |
18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII
19 |
20 | *Software* means the original work of authorship made available under this
21 | License ie gaussian-splatting.
22 |
23 | *Work* means the *Software* and any additions to or derivative works of the
24 | *Software* that are made available under this License.
25 |
26 |
27 | ## 2. Purpose
28 | This license is intended to define the rights granted to the *Licensee* by
29 | Licensors under the *Software*.
30 |
31 | ## 3. Rights granted
32 |
33 | For the above reasons Licensors have decided to distribute the *Software*.
34 | Licensors grant non-exclusive rights to use the *Software* for research purposes
35 | to research users (both academic and industrial), free of charge, without right
36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research
37 | and/or evaluation purposes only.
38 |
39 | Subject to the terms and conditions of this License, you are granted a
40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
41 | publicly display, publicly perform and distribute its *Work* and any resulting
42 | derivative works in any form.
43 |
44 | ## 4. Limitations
45 |
46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
47 | so under this License, (b) you include a complete copy of this License with
48 | your distribution, and (c) you retain without modification any copyright,
49 | patent, trademark, or attribution notices that are present in the *Work*.
50 |
51 | **4.2 Derivative Works.** You may specify that additional or different terms apply
52 | to the use, reproduction, and distribution of your derivative works of the *Work*
53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in
54 | Section 2 applies to your derivative works, and (b) you identify the specific
55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms,
56 | this License (including the redistribution requirements in Section 3.1) will
57 | continue to apply to the *Work* itself.
58 |
59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research
60 | users explicitly acknowledge having received from Licensors all information
61 | allowing to appreciate the adequacy between of the *Software* and their needs and
62 | to undertake all necessary precautions for its execution and use.
63 |
64 | **4.4** The *Software* is provided both as a compiled library file and as source
65 | code. In case of using the *Software* for a publication or other results obtained
66 | through the use of the *Software*, users are strongly encouraged to cite the
67 | corresponding publications as explained in the documentation of the *Software*.
68 |
69 | ## 5. Disclaimer
70 |
71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
84 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LightGaussian: Unbounded 3D Gaussian Compression with 15x Reduction and 200+ FPS
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | ## User Guidance
16 | #### Gaussian Prune Ratio, Vector Quantization Ratio vs. FPS, SSIM
17 |
18 |
19 |
20 |
21 | #### Mild Compression Ratio, with Minimum Accuracy Degradation
22 |
23 |
24 |
25 |
26 |
27 | ## Setup
28 | #### Local Setup
29 | The codebase is based on [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting)
30 |
31 | The datasets used, MipNeRF360 and Tanks & Temples, are hosted by the paper authors [here](https://jonbarron.info/mipnerf360/).
32 |
33 | For installation:
34 | ```
35 | git clone --recursive https://github.com/VITA-Group/LightGaussian.git
36 | cd LightGaussian
37 | # if you have already cloned LightGaussian:
38 | # git submodule update --init --recursive
39 | ```
40 | ```shell
41 | conda env create --file environment.yml
42 | conda activate lightgaussian
43 | ```
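If you prefer to install into an existing environment instead, the two CUDA extensions listed in `environment.yml` can be installed directly with pip (a sketch; it assumes a CUDA toolkit compatible with your PyTorch build is available for compilation):

```shell
pip install submodules/compress-diff-gaussian-rasterization
pip install submodules/simple-knn
```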
44 | Note: we modified the "diff-gaussian-rasterization" submodule so that it also computes the Global Significance Score.
45 |
46 |
47 | ## Compress to Compact Representation
48 |
49 | LightGaussian includes **3 ways** to make the 3D Gaussians compact
50 |
51 |
52 |
53 | #### Option 1 Prune & Recovery
54 | Users can directly prune a trained 3D-GS checkpoint using the following command (default setting):
55 | ```
56 | bash scripts/run_prune_finetune.sh
57 | ```
58 |
59 | Users can also train from scratch and jointly prune redundant Gaussians during training using the following command (this setting differs from the paper):
60 | ```
61 | bash scripts/run_train_densify_prune.sh
62 | ```
63 | Note: the 3D-GS is trained for 20,000 iterations and then pruned. The resulting ply file is approximately 35% of the size of the original 3D-GS while maintaining a comparable quality level.
64 |
65 |
66 | #### Option 2 SH distillation
67 | Users can distill a trained 3D-GS checkpoint using the following command (default setting):
68 | ```
69 | bash scripts/run_distill_finetune.sh
70 | ```
71 |
72 | #### Option 3 VecTree Quantization
73 | Users can quantize a pruned and distilled 3D-GS checkpoint using the following command (default setting):
74 | ```
75 | bash scripts/run_vectree_quantize.sh
76 | ```
77 |
78 |
79 | ## Render
80 | Render along a camera trajectory. The default is an ellipse; you can switch to a spiral or another trajectory by calling the corresponding function.
81 | ```
82 | python render_video.py --source_path PATH/TO/DATASET --model_path PATH/TO/MODEL --skip_train --skip_test --video
83 | ```
84 | To render after the VecTree Quantization stage, run
85 | ```
86 | python render_video.py --load_vq
87 | ```
88 |
89 |
90 | ## Example
91 | An example checkpoint for the room scene can be downloaded [here](), which mainly includes the following parts:
92 |
93 | - point_cloud.ply —— Pruned, distilled and quantized 3D-GS checkpoint.
94 | - extreme_saving —— Relevant files obtained after vectree quantization.
95 | - imp_score.npz —— Global significance used in vectree quantization.
96 |
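The significance scores are written with a positional `np.savez` call (see the end of `distill_train.py`), so the archive holds a single unnamed array; a minimal sketch for inspecting it (the path is a placeholder):

```python
import numpy as np

# np.savez stores the single positional array under the default key "arr_0"
scores = np.load("output/room/imp_score.npz")["arr_0"]
print(scores.shape, scores.min(), scores.max())
```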
97 |
98 |
99 | ## TODO List
100 | - [x] Upload module 1: Prune & recovery
101 | - [x] Upload module 2: SH distillation
102 | - [x] Upload module 3: Vectree Quantization
103 | - [ ] Upload docker image
104 |
105 | ## Acknowledgements
106 | We would like to express our gratitude to [Yueyu Hu](https://huzi96.github.io/) from NYU for the invaluable discussion on our project.
107 |
108 |
109 | ## BibTeX
110 | If you find our work useful for your project, please consider citing the following paper.
111 |
112 |
113 | ```
114 | @misc{fan2023lightgaussian,
115 | title={LightGaussian: Unbounded 3D Gaussian Compression with 15x Reduction and 200+ FPS},
116 | author={Zhiwen Fan and Kevin Wang and Kairun Wen and Zehao Zhu and Dejia Xu and Zhangyang Wang},
117 | year={2023},
118 | eprint={2311.17245},
119 | archivePrefix={arXiv},
120 | primaryClass={cs.CV} }
121 | ```
122 |
--------------------------------------------------------------------------------
/arguments/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from argparse import ArgumentParser, Namespace
13 | import sys
14 | import os
15 |
16 |
17 | class GroupParams:
18 | pass
19 |
20 |
21 | class ParamGroup:
22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False):
23 | group = parser.add_argument_group(name)
24 | for key, value in vars(self).items():
25 | shorthand = False
26 | if key.startswith("_"):
27 | shorthand = True
28 | key = key[1:]
29 | t = type(value)
30 | value = value if not fill_none else None
31 | if shorthand:
32 | if t == bool:
33 | group.add_argument(
34 | "--" + key, ("-" + key[0:1]), default=value, action="store_true"
35 | )
36 | else:
37 | group.add_argument(
38 | "--" + key, ("-" + key[0:1]), default=value, type=t
39 | )
40 | else:
41 | if t == bool:
42 | group.add_argument("--" + key, default=value, action="store_true")
43 | else:
44 | group.add_argument("--" + key, default=value, type=t)
45 |
46 | def extract(self, args):
47 | group = GroupParams()
48 | for arg in vars(args).items():
49 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
50 | setattr(group, arg[0], arg[1])
51 | return group
52 |
53 |
54 | class ModelParams(ParamGroup):
55 | def __init__(self, parser, sentinel=False):
56 | self.sh_degree = 3
57 | self._source_path = ""
58 | self._model_path = ""
59 | self._images = "images"
60 | self._resolution = -1
61 | self._white_background = False
62 | self.data_device = "cuda"
63 | self.eval = False
64 | super().__init__(parser, "Loading Parameters", sentinel)
65 |
66 | def extract(self, args):
67 | g = super().extract(args)
68 | g.source_path = os.path.abspath(g.source_path)
69 | return g
70 |
71 |
72 | class PipelineParams(ParamGroup):
73 | def __init__(self, parser):
74 | self.convert_SHs_python = False
75 | self.compute_cov3D_python = False
76 | self.debug = False
77 | super().__init__(parser, "Pipeline Parameters")
78 |
79 |
80 | class OptimizationParams(ParamGroup):
81 | def __init__(self, parser):
82 | self.iterations = 30_000
83 | self.position_lr_init = 0.00016
84 | self.position_lr_final = 0.0000016
85 | self.position_lr_delay_mult = 0.01
86 | self.position_lr_max_steps = 30_000
87 | self.feature_lr = 0.0025
88 | self.opacity_lr = 0.05
89 | self.scaling_lr = 0.005
90 | self.rotation_lr = 0.001
91 | self.percent_dense = 0.01
92 | self.lambda_dssim = 0.2
93 | self.densification_interval = 100
94 | self.opacity_reset_interval = 3000
95 | self.densify_from_iter = 500
96 | self.densify_until_iter = 15_000
97 | self.densify_grad_threshold = 0.0002
98 | super().__init__(parser, "Optimization Parameters")
99 |
100 |
101 | def get_combined_args(parser: ArgumentParser):
102 | cmdlne_string = sys.argv[1:]
103 | cfgfile_string = "Namespace()"
104 | args_cmdline = parser.parse_args(cmdlne_string)
105 |
106 | try:
107 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
108 | print("Looking for config file in", cfgfilepath)
109 | with open(cfgfilepath) as cfg_file:
110 | print("Config file found: {}".format(cfgfilepath))
111 | cfgfile_string = cfg_file.read()
112 | except TypeError:
113 |         print("Config file not found; relying on command line arguments only")
114 | pass
115 | args_cfgfile = eval(cfgfile_string)
116 |
117 | merged_dict = vars(args_cfgfile).copy()
118 | for k, v in vars(args_cmdline).items():
119 | if v != None:
120 | merged_dict[k] = v
121 | return Namespace(**merged_dict)
122 |
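The intended usage pattern for these parameter groups, as seen in the training scripts in this repository (e.g. `distill_train.py`), is roughly the following sketch:

```python
from argparse import ArgumentParser
from arguments import ModelParams, OptimizationParams, PipelineParams

parser = ArgumentParser(description="Training script parameters")
lp = ModelParams(parser)         # registers --source_path/-s, --model_path/-m, --images/-i, ...
op = OptimizationParams(parser)  # registers --iterations, learning rates, densification options, ...
pp = PipelineParams(parser)      # registers --convert_SHs_python, --compute_cov3D_python, --debug
args = parser.parse_args()

dataset = lp.extract(args)  # GroupParams namespace holding only the model-related fields
opt = op.extract(args)
pipe = pp.extract(args)
```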
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import logging
14 | from argparse import ArgumentParser
15 | import shutil
16 |
17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository.
18 | parser = ArgumentParser("Colmap converter")
19 | parser.add_argument("--no_gpu", action='store_true')
20 | parser.add_argument("--skip_matching", action='store_true')
21 | parser.add_argument("--source_path", "-s", required=True, type=str)
22 | parser.add_argument("--camera", default="OPENCV", type=str)
23 | parser.add_argument("--colmap_executable", default="", type=str)
24 | parser.add_argument("--resize", action="store_true")
25 | parser.add_argument("--magick_executable", default="", type=str)
26 | args = parser.parse_args()
27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap"
28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick"
29 | use_gpu = 1 if not args.no_gpu else 0
30 |
31 | if not args.skip_matching:
32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True)
33 |
34 | ## Feature extraction
35 | feat_extracton_cmd = colmap_command + " feature_extractor "\
36 | "--database_path " + args.source_path + "/distorted/database.db \
37 | --image_path " + args.source_path + "/input \
38 | --ImageReader.single_camera 1 \
39 | --ImageReader.camera_model " + args.camera + " \
40 | --SiftExtraction.use_gpu " + str(use_gpu)
41 | exit_code = os.system(feat_extracton_cmd)
42 | if exit_code != 0:
43 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.")
44 | exit(exit_code)
45 |
46 | ## Feature matching
47 | feat_matching_cmd = colmap_command + " exhaustive_matcher \
48 | --database_path " + args.source_path + "/distorted/database.db \
49 | --SiftMatching.use_gpu " + str(use_gpu)
50 | exit_code = os.system(feat_matching_cmd)
51 | if exit_code != 0:
52 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.")
53 | exit(exit_code)
54 |
55 | ### Bundle adjustment
56 | # The default Mapper tolerance is unnecessarily large,
57 | # decreasing it speeds up bundle adjustment steps.
58 | mapper_cmd = (colmap_command + " mapper \
59 | --database_path " + args.source_path + "/distorted/database.db \
60 | --image_path " + args.source_path + "/input \
61 | --output_path " + args.source_path + "/distorted/sparse \
62 | --Mapper.ba_global_function_tolerance=0.000001")
63 | exit_code = os.system(mapper_cmd)
64 | if exit_code != 0:
65 | logging.error(f"Mapper failed with code {exit_code}. Exiting.")
66 | exit(exit_code)
67 |
68 | ### Image undistortion
69 | ## We need to undistort our images into ideal pinhole intrinsics.
70 | img_undist_cmd = (colmap_command + " image_undistorter \
71 | --image_path " + args.source_path + "/input \
72 | --input_path " + args.source_path + "/distorted/sparse/0 \
73 | --output_path " + args.source_path + "\
74 | --output_type COLMAP")
75 | exit_code = os.system(img_undist_cmd)
76 | if exit_code != 0:
77 |     logging.error(f"Image undistortion failed with code {exit_code}. Exiting.")
78 | exit(exit_code)
79 |
80 | files = os.listdir(args.source_path + "/sparse")
81 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
82 | # Copy each file from the source directory to the destination directory
83 | for file in files:
84 | if file == '0':
85 | continue
86 | source_file = os.path.join(args.source_path, "sparse", file)
87 | destination_file = os.path.join(args.source_path, "sparse", "0", file)
88 | shutil.move(source_file, destination_file)
89 |
90 | if(args.resize):
91 | print("Copying and resizing...")
92 |
93 | # Resize images.
94 | os.makedirs(args.source_path + "/images_2", exist_ok=True)
95 | os.makedirs(args.source_path + "/images_4", exist_ok=True)
96 | os.makedirs(args.source_path + "/images_8", exist_ok=True)
97 | # Get the list of files in the source directory
98 | files = os.listdir(args.source_path + "/images")
99 | # Copy each file from the source directory to the destination directory
100 | for file in files:
101 | source_file = os.path.join(args.source_path, "images", file)
102 |
103 | destination_file = os.path.join(args.source_path, "images_2", file)
104 | shutil.copy2(source_file, destination_file)
105 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file)
106 | if exit_code != 0:
107 | logging.error(f"50% resize failed with code {exit_code}. Exiting.")
108 | exit(exit_code)
109 |
110 | destination_file = os.path.join(args.source_path, "images_4", file)
111 | shutil.copy2(source_file, destination_file)
112 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file)
113 | if exit_code != 0:
114 | logging.error(f"25% resize failed with code {exit_code}. Exiting.")
115 | exit(exit_code)
116 |
117 | destination_file = os.path.join(args.source_path, "images_8", file)
118 | shutil.copy2(source_file, destination_file)
119 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file)
120 | if exit_code != 0:
121 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.")
122 | exit(exit_code)
123 |
124 | print("Done.")
125 |
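A typical invocation of `convert.py` on a raw capture (a sketch; the scene path is a placeholder, and COLMAP / ImageMagick must either be on the PATH or passed via `--colmap_executable` / `--magick_executable`):

```shell
# expects the input photos under /path/to/scene/input
python convert.py -s /path/to/scene --resize
```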
--------------------------------------------------------------------------------
/distill_train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import torch
14 | from random import randint
15 | from utils.loss_utils import l1_loss, ssim
16 | from gaussian_renderer import render, network_gui
17 | import sys
18 | from scene import Scene, GaussianModel
19 | from utils.general_utils import safe_state
20 | import uuid
21 | from os import makedirs
22 | from tqdm import tqdm
23 | from utils.image_utils import psnr
24 | from argparse import ArgumentParser, Namespace
25 | from arguments import ModelParams, PipelineParams, OptimizationParams
26 | from utils.graphics_utils import getWorld2View2
27 | from utils.pose_utils import gaussian_poses
28 | from icecream import ic
29 | import random
30 | import copy
31 | import json
32 | import numpy as np
33 | from utils.logger_utils import prepare_output_and_logger, training_report
34 | from torch.optim.lr_scheduler import ExponentialLR
35 | from prune import prune_list, calculate_v_imp_score
36 |
37 |
38 | try:
39 | from torch.utils.tensorboard import SummaryWriter
40 | TENSORBOARD_FOUND = True
41 | except ImportError:
42 | TENSORBOARD_FOUND = False
43 |
44 | class NumpyArrayEncoder(json.JSONEncoder):
45 | def default(self, obj):
46 | if isinstance(obj, np.integer):
47 | return int(obj)
48 | elif isinstance(obj, np.floating):
49 | return float(obj)
50 | elif isinstance(obj, np.ndarray):
51 | return obj.tolist()
52 | else:
53 | return super(NumpyArrayEncoder, self).default(obj)
54 |
55 |
56 | to_tensor = lambda x: x.to("cuda") if isinstance(
57 | x, torch.Tensor) else torch.Tensor(x).to("cuda")
58 | img2mse = lambda x, y: torch.mean((x - y)**2)
59 | mse2psnr = lambda x: -10. * torch.log(x) / torch.log(to_tensor([10.]))
60 |
61 | def training(args, dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, new_max_sh):
62 | first_iter = 0
63 | old_sh_degree = dataset.sh_degree
64 | dataset.sh_degree = new_max_sh
65 | tb_writer = prepare_output_and_logger(dataset)
66 | with torch.no_grad():
67 | teacher_gaussians = GaussianModel(old_sh_degree)
68 | # teacher_gaussians.training_setup(opt)
69 |
70 | student_gaussians = GaussianModel(old_sh_degree)
71 | student_scene = Scene(dataset, student_gaussians)
72 |
73 | if checkpoint:
74 | (teacher_model_params, _) = torch.load(args.teacher_model)
75 | (model_params, first_iter) = torch.load(checkpoint)
76 | teacher_gaussians.restore(teacher_model_params, copy.deepcopy(opt))
77 | student_gaussians.restore(model_params, opt)
78 | student_gaussians.max_sh_degree = new_max_sh
79 | student_gaussians.onedownSHdegree()
80 | student_gaussians.training_setup(opt)
81 | student_gaussians.scheduler = ExponentialLR(student_gaussians.optimizer, gamma=0.90)
82 |     # Freeze covariance (scaling/rotation) and opacity unless they are explicitly enabled for distillation
83 | if (not args.enable_covariance):
84 | student_gaussians._scaling.requires_grad = False
85 | student_gaussians._rotation.requires_grad = False
86 | if (not args.enable_opacity):
87 | student_gaussians._opacity.requires_grad = False
88 |
89 | teacher_gaussians.optimizer = None
90 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
91 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
92 | iter_start = torch.cuda.Event(enable_timing = True)
93 | iter_end = torch.cuda.Event(enable_timing = True)
94 | viewpoint_stack = None
95 | ema_loss_for_log = 0.0
96 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
97 | first_iter += 1
98 |
99 | # os.makedirs(student_scene.model_path + "/vis_data", exist_ok=True)
100 | for iteration in range(first_iter, opt.iterations + 1):
101 | if network_gui.conn == None:
102 | network_gui.try_connect()
103 | while network_gui.conn != None:
104 | try:
105 | net_image_bytes = None
106 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
107 | if custom_cam != None:
108 | net_image = render(custom_cam, student_gaussians, pipe, background, scaling_modifer)["render"]
109 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
110 | network_gui.send(net_image_bytes, dataset.source_path)
111 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
112 | break
113 | except Exception as e:
114 | network_gui.conn = None
115 |
116 | iter_start.record()
117 | student_gaussians.update_learning_rate(iteration)
118 |
119 | # Every 500 iterations step in scheduler
120 | if iteration % 500 == 0:
121 | # student_gaussians.oneupSHdegree()
122 | student_gaussians.scheduler.step()
123 |
124 | if not viewpoint_stack:
125 | viewpoint_stack = student_scene.getTrainCameras().copy()
126 | viewpoint_cam_org = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
127 | viewpoint_cam = copy.deepcopy(viewpoint_cam_org)
128 |
129 | if (iteration - 1) == debug_from:
130 | pipe.debug = True
131 |
132 | if args.augmented_view and iteration%3:
133 | viewpoint_cam = gaussian_poses(viewpoint_cam, mean= 0, std_dev_translation=0.05, std_dev_rotation=0)
134 | student_render_pkg = render(viewpoint_cam, student_gaussians, pipe, background)
135 | student_image = student_render_pkg["render"]
136 | teacher_render_pkg = render(viewpoint_cam, teacher_gaussians, pipe, background)
137 | teacher_image = teacher_render_pkg["render"].detach()
138 | else:
139 | render_pkg = render(viewpoint_cam, student_gaussians, pipe, background)
140 | student_image = render_pkg["render"]
141 | teacher_image = render(viewpoint_cam, teacher_gaussians, pipe, background)["render"].detach()
142 | Ll1 = l1_loss(student_image, teacher_image)
143 | # Ll1 = img2mse(student_image, teacher_image)
144 |
145 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(student_image, teacher_image))
146 | loss.backward()
147 | iter_end.record()
148 | with torch.no_grad():
149 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
150 | if iteration % 10 == 0:
151 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
152 | progress_bar.update(10)
153 | if iteration == opt.iterations:
154 | progress_bar.close()
155 |
156 | if (iteration in saving_iterations):
157 | print("\n[ITER {}] Saving Gaussians".format(iteration))
158 | ic(student_gaussians._features_rest.detach().shape)
159 | student_scene.save(iteration)
160 |
161 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, student_scene, render, (pipe, background))
162 |
163 | # Optimizer step
164 | if iteration < opt.iterations:
165 | student_gaussians.optimizer.step()
166 | student_gaussians.optimizer.zero_grad(set_to_none = True)
167 |
168 | if (iteration in checkpoint_iterations):
169 | print("\n[ITER {}] Saving Checkpoint".format(iteration))
170 | if not os.path.exists(student_scene.model_path):
171 | os.makedirs(student_scene.model_path)
172 | torch.save((student_gaussians.capture(), iteration), student_scene.model_path + "/chkpnt" + str(iteration) + ".pth")
173 |
174 | if iteration == checkpoint_iterations[-1]:
175 | print("Saving Imp_score")
176 | gaussian_list, imp_list = prune_list(
177 | student_gaussians, student_scene, pipe, background
178 | )
179 | v_list = calculate_v_imp_score(student_gaussians, imp_list, 0.1)
180 | np.savez(
181 | os.path.join(student_scene.model_path, "imp_score"),
182 | v_list.cpu().detach().numpy(),
183 | )
184 |
185 |
186 | if __name__ == "__main__":
187 | # Set up command line argument parser
188 | parser = ArgumentParser(description="Training script parameters")
189 | lp = ModelParams(parser)
190 | op = OptimizationParams(parser)
191 | pp = PipelineParams(parser)
192 | parser.add_argument('--ip', type=str, default="127.0.0.1")
193 | parser.add_argument('--port', type=int, default=6009)
194 | parser.add_argument('--debug_from', type=int, default=-1)
195 | parser.add_argument('--detect_anomaly', action='store_true', default=False)
196 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[35_001, 40_000])
197 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[40_000])
198 | parser.add_argument("--quiet", action="store_true")
199 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[40_000])
200 | parser.add_argument("--start_checkpoint", type=str, default = None)
201 | parser.add_argument("--new_max_sh", type=int, default = 2)
202 | parser.add_argument("--augmented_view", action="store_true")
203 | parser.add_argument("--enable_covariance", action="store_true")
204 | parser.add_argument("--enable_opacity", action="store_true")
205 | parser.add_argument("--opacity_prune", type=float, default = 0)
206 | parser.add_argument("--teacher_model", type=str)
207 |
208 | args = parser.parse_args(sys.argv[1:])
209 | args.save_iterations.append(args.iterations)
210 |
211 | print("Optimizing " + args.model_path)
212 |
213 | # Initialize system state (RNG)
214 | safe_state(args.quiet)
215 |
216 | # Start GUI server, configure and run training
217 | network_gui.init(args.ip, args.port)
218 | torch.autograd.set_detect_anomaly(args.detect_anomaly)
219 | training(args, lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args.new_max_sh)
220 |
221 | # All done
222 | print("\nTraining complete.")
223 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: lightgaussian
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - cudatoolkit=11.6
8 | - plyfile=0.8.1
9 | - python=3.9
10 | - pip=22.3.1
11 | - pytorch=1.12.1
12 | - torchaudio=0.12.1
13 | - torchvision=0.13.1
14 | - setuptools=69.5.1
15 | - tqdm
16 | - icecream
17 | - pip:
18 | - submodules/compress-diff-gaussian-rasterization
19 | - submodules/simple-knn
20 |
--------------------------------------------------------------------------------
/full_eval.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | from argparse import ArgumentParser
14 |
15 | mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"]
16 | mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"]
17 | tanks_and_temples_scenes = ["truck", "train"]
18 | deep_blending_scenes = ["drjohnson", "playroom"]
19 |
20 | parser = ArgumentParser(description="Full evaluation script parameters")
21 | parser.add_argument("--skip_training", action="store_true")
22 | parser.add_argument("--skip_rendering", action="store_true")
23 | parser.add_argument("--skip_metrics", action="store_true")
24 | parser.add_argument("--output_path", default="./eval")
25 | args, _ = parser.parse_known_args()
26 |
27 | all_scenes = []
28 | all_scenes.extend(mipnerf360_outdoor_scenes)
29 | all_scenes.extend(mipnerf360_indoor_scenes)
30 | all_scenes.extend(tanks_and_temples_scenes)
31 | all_scenes.extend(deep_blending_scenes)
32 |
33 | if not args.skip_training or not args.skip_rendering:
34 | parser.add_argument("--mipnerf360", "-m360", required=True, type=str)
35 | parser.add_argument("--tanksandtemples", "-tat", required=True, type=str)
36 | parser.add_argument("--deepblending", "-db", required=True, type=str)
37 | args = parser.parse_args()
38 |
39 | if not args.skip_training:
40 | common_args = " --quiet --eval --test_iterations -1 "
41 | for scene in mipnerf360_outdoor_scenes:
42 | source = args.mipnerf360 + "/" + scene
43 | os.system(
44 | "python train.py -s "
45 | + source
46 | + " -i images_4 -m "
47 | + args.output_path
48 | + "/"
49 | + scene
50 | + common_args
51 | )
52 | for scene in mipnerf360_indoor_scenes:
53 | source = args.mipnerf360 + "/" + scene
54 | os.system(
55 | "python train.py -s "
56 | + source
57 | + " -i images_2 -m "
58 | + args.output_path
59 | + "/"
60 | + scene
61 | + common_args
62 | )
63 | for scene in tanks_and_temples_scenes:
64 | source = args.tanksandtemples + "/" + scene
65 | os.system(
66 | "python train.py -s "
67 | + source
68 | + " -m "
69 | + args.output_path
70 | + "/"
71 | + scene
72 | + common_args
73 | )
74 | for scene in deep_blending_scenes:
75 | source = args.deepblending + "/" + scene
76 | os.system(
77 | "python train.py -s "
78 | + source
79 | + " -m "
80 | + args.output_path
81 | + "/"
82 | + scene
83 | + common_args
84 | )
85 |
86 | if not args.skip_rendering:
87 | all_sources = []
88 | for scene in mipnerf360_outdoor_scenes:
89 | all_sources.append(args.mipnerf360 + "/" + scene)
90 | for scene in mipnerf360_indoor_scenes:
91 | all_sources.append(args.mipnerf360 + "/" + scene)
92 | for scene in tanks_and_temples_scenes:
93 | all_sources.append(args.tanksandtemples + "/" + scene)
94 | for scene in deep_blending_scenes:
95 | all_sources.append(args.deepblending + "/" + scene)
96 |
97 | common_args = " --quiet --eval --skip_train"
98 | for scene, source in zip(all_scenes, all_sources):
99 | os.system(
100 | "python render.py --iteration 7000 -s "
101 | + source
102 | + " -m "
103 | + args.output_path
104 | + "/"
105 | + scene
106 | + common_args
107 | )
108 | os.system(
109 | "python render.py --iteration 30000 -s "
110 | + source
111 | + " -m "
112 | + args.output_path
113 | + "/"
114 | + scene
115 | + common_args
116 | )
117 |
118 | if not args.skip_metrics:
119 | scenes_string = ""
120 | for scene in all_scenes:
121 | scenes_string += '"' + args.output_path + "/" + scene + '" '
122 |
123 | os.system("python metrics.py -m " + scenes_string)
124 |
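A typical invocation (a sketch; the dataset roots are placeholders and are only required when training or rendering is not skipped):

```shell
python full_eval.py --mipnerf360 /data/mipnerf360 --tanksandtemples /data/tandt --deepblending /data/db --output_path ./eval
```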
--------------------------------------------------------------------------------
/gaussian_renderer/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | from diff_gaussian_rasterization import (
15 | GaussianRasterizationSettings,
16 | GaussianRasterizer,
17 | )
18 | from scene.gaussian_model import GaussianModel
19 | from utils.sh_utils import eval_sh
20 |
21 |
22 | def render(
23 | viewpoint_camera,
24 | pc: GaussianModel,
25 | pipe,
26 | bg_color: torch.Tensor,
27 | scaling_modifier=1.0,
28 | override_color=None,
29 | ):
30 | """
31 | Render the scene.
32 |
33 | Background tensor (bg_color) must be on GPU!
34 | """
35 |
36 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
37 | screenspace_points = (
38 | torch.zeros_like(
39 | pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
40 | )
41 | + 0
42 | )
43 | try:
44 | screenspace_points.retain_grad()
45 | except:
46 | pass
47 |
48 | # Set up rasterization configuration
49 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
50 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
51 |
52 | raster_settings = GaussianRasterizationSettings(
53 | image_height=int(viewpoint_camera.image_height),
54 | image_width=int(viewpoint_camera.image_width),
55 | tanfovx=tanfovx,
56 | tanfovy=tanfovy,
57 | bg=bg_color,
58 | scale_modifier=scaling_modifier,
59 | viewmatrix=viewpoint_camera.world_view_transform,
60 | projmatrix=viewpoint_camera.full_proj_transform,
61 | sh_degree=pc.active_sh_degree,
62 | campos=viewpoint_camera.camera_center,
63 | prefiltered=False,
64 | debug=pipe.debug,
65 | f_count=False,
66 | )
67 |
68 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
69 |
70 | means3D = pc.get_xyz
71 | means2D = screenspace_points
72 | opacity = pc.get_opacity
73 |
74 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
75 | # scaling / rotation by the rasterizer.
76 | scales = None
77 | rotations = None
78 | cov3D_precomp = None
79 | if pipe.compute_cov3D_python:
80 | cov3D_precomp = pc.get_covariance(scaling_modifier)
81 | else:
82 | scales = pc.get_scaling
83 | rotations = pc.get_rotation
84 |
85 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
86 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
87 | shs = None
88 | colors_precomp = None
89 | if override_color is None:
90 | if pipe.convert_SHs_python:
91 | shs_view = pc.get_features.transpose(1, 2).view(
92 | -1, 3, (pc.max_sh_degree + 1) ** 2
93 | )
94 | dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(
95 | pc.get_features.shape[0], 1
96 | )
97 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
98 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
99 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
100 | else:
101 | shs = pc.get_features
102 | else:
103 | colors_precomp = override_color
104 |
105 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
106 | rendered_image, radii = rasterizer(
107 | means3D=means3D,
108 | means2D=means2D,
109 | shs=shs,
110 | colors_precomp=colors_precomp,
111 | opacities=opacity,
112 | scales=scales,
113 | rotations=rotations,
114 | cov3D_precomp=cov3D_precomp,
115 | )
116 |
117 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
118 | # They will be excluded from value updates used in the splitting criteria.
119 | return {
120 | "render": rendered_image,
121 | "viewspace_points": screenspace_points,
122 | "visibility_filter": radii > 0,
123 | "radii": radii,
124 | }
125 |
126 |
127 | def count_render(
128 | viewpoint_camera,
129 | pc: GaussianModel,
130 | pipe,
131 | bg_color: torch.Tensor,
132 | scaling_modifier=1.0,
133 | override_color=None,
134 | ):
135 | """
136 | Render the scene.
137 |
138 | Background tensor (bg_color) must be on GPU!
139 | """
140 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
141 | screenspace_points = (
142 | torch.zeros_like(
143 | pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
144 | )
145 | + 0
146 | )
147 | try:
148 | screenspace_points.retain_grad()
149 | except:
150 | pass
151 |
152 | # Set up rasterization configuration
153 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
154 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
155 |
156 | raster_settings = GaussianRasterizationSettings(
157 | image_height=int(viewpoint_camera.image_height),
158 | image_width=int(viewpoint_camera.image_width),
159 | tanfovx=tanfovx,
160 | tanfovy=tanfovy,
161 | bg=bg_color,
162 | scale_modifier=scaling_modifier,
163 | viewmatrix=viewpoint_camera.world_view_transform,
164 | projmatrix=viewpoint_camera.full_proj_transform,
165 | sh_degree=pc.active_sh_degree,
166 | campos=viewpoint_camera.camera_center,
167 | prefiltered=False,
168 | debug=pipe.debug,
169 | f_count=True,
170 | )
171 |
172 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
173 | means3D = pc.get_xyz
174 | means2D = screenspace_points
175 | opacity = pc.get_opacity
176 |
177 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
178 | # scaling / rotation by the rasterizer.
179 | scales = None
180 | rotations = None
181 | cov3D_precomp = None
182 | if pipe.compute_cov3D_python:
183 | cov3D_precomp = pc.get_covariance(scaling_modifier)
184 | else:
185 | scales = pc.get_scaling
186 | rotations = pc.get_rotation
187 |
188 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
189 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
190 | shs = None
191 | colors_precomp = None
192 | if override_color is None:
193 | if pipe.convert_SHs_python:
194 | shs_view = pc.get_features.transpose(1, 2).view(
195 | -1, 3, (pc.max_sh_degree + 1) ** 2
196 | )
197 | dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(
198 | pc.get_features.shape[0], 1
199 | )
200 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
201 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
202 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
203 | else:
204 | shs = pc.get_features
205 | else:
206 | colors_precomp = override_color
207 |
208 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
209 | gaussians_count, important_score, rendered_image, radii = rasterizer(
210 | means3D=means3D,
211 | means2D=means2D,
212 | shs=shs,
213 | colors_precomp=colors_precomp,
214 | opacities=opacity,
215 | scales=scales,
216 | rotations=rotations,
217 | cov3D_precomp=cov3D_precomp,
218 | )
219 |
220 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
221 | # They will be excluded from value updates used in the splitting criteria.
222 | return {
223 | "render": rendered_image,
224 | "viewspace_points": screenspace_points,
225 | "visibility_filter": radii > 0,
226 | "radii": radii,
227 | "gaussians_count": gaussians_count,
228 | "important_score": important_score,
229 | }
230 |
--------------------------------------------------------------------------------
/gaussian_renderer/gaussian_count.py:
--------------------------------------------------------------------------------
1 | # Based on render() in gaussian_renderer/__init__.py
2 |
3 | #
4 | # Copyright (C) 2023, Inria
5 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
6 | # All rights reserved.
7 | #
8 | # This software is free for non-commercial, research and evaluation use
9 | # under the terms of the LICENSE.md file.
10 | #
11 | # For inquiries contact george.drettakis@inria.fr
12 | #
13 |
14 | import torch
15 | import math
16 | from diff_gaussian_rasterization import (
17 | GaussianRasterizationSettings,
18 | GaussianRasterizer,
19 | )
20 | from scene.gaussian_model import GaussianModel
21 | from utils.sh_utils import eval_sh
22 |
23 |
24 | def count_render(
25 | viewpoint_camera,
26 | pc: GaussianModel,
27 | pipe,
28 | bg_color: torch.Tensor,
29 | scaling_modifier=1.0,
30 | override_color=None,
31 | ):
32 | """
33 | Render the scene.
34 |
35 | Background tensor (bg_color) must be on GPU!
36 | """
37 |
38 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
39 | screenspace_points = (
40 | torch.zeros_like(
41 | pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
42 | )
43 | + 0
44 | )
45 | try:
46 | screenspace_points.retain_grad()
47 | except:
48 | pass
49 |
50 | # Set up rasterization configuration
51 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
52 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
53 |
54 | raster_settings = GaussianRasterizationSettings(
55 | image_height=int(viewpoint_camera.image_height),
56 | image_width=int(viewpoint_camera.image_width),
57 | tanfovx=tanfovx,
58 | tanfovy=tanfovy,
59 | bg=bg_color,
60 | scale_modifier=scaling_modifier,
61 | viewmatrix=viewpoint_camera.world_view_transform,
62 | projmatrix=viewpoint_camera.full_proj_transform,
63 | sh_degree=pc.active_sh_degree,
64 | campos=viewpoint_camera.camera_center,
65 | prefiltered=False,
66 | debug=pipe.debug,
67 | )
68 |
69 | rasterizer = GaussianRasterizer(raster_settings=raster_settings, f_count=True)
70 |
71 | means3D = pc.get_xyz
72 | means2D = screenspace_points
73 | opacity = pc.get_opacity
74 |
75 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
76 | # scaling / rotation by the rasterizer.
77 | scales = None
78 | rotations = None
79 | cov3D_precomp = None
80 | if pipe.compute_cov3D_python:
81 | cov3D_precomp = pc.get_covariance(scaling_modifier)
82 | else:
83 | scales = pc.get_scaling
84 | rotations = pc.get_rotation
85 |
86 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
87 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
88 | shs = None
89 | colors_precomp = None
90 | if override_color is None:
91 | if pipe.convert_SHs_python:
92 | shs_view = pc.get_features.transpose(1, 2).view(
93 | -1, 3, (pc.max_sh_degree + 1) ** 2
94 | )
95 | dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(
96 | pc.get_features.shape[0], 1
97 | )
98 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
99 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
100 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
101 | else:
102 | shs = pc.get_features
103 | else:
104 | colors_precomp = override_color
105 |
106 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
107 | (
108 | gaussians_count,
109 | important_score,
110 | rendered_image,
111 | radii,
112 | ) = rasterizer.forward_counter(
113 | means3D=means3D,
114 | means2D=means2D,
115 | shs=shs,
116 | colors_precomp=colors_precomp,
117 | opacities=opacity,
118 | scales=scales,
119 | rotations=rotations,
120 | cov3D_precomp=cov3D_precomp,
121 | )
122 |
123 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
124 | # They will be excluded from value updates used in the splitting criteria.
125 | return {
126 | "render": rendered_image,
127 | "viewspace_points": screenspace_points,
128 | "visibility_filter": radii > 0,
129 | "radii": radii,
130 | "gaussians_count": gaussians_count,
131 | "important_score": important_score,
132 | }
133 |
--------------------------------------------------------------------------------
/gaussian_renderer/network_gui.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import traceback
14 | import socket
15 | import json
16 | from scene.cameras import MiniCam
17 |
18 | host = "127.0.0.1"
19 | port = 6009
20 |
21 | conn = None
22 | addr = None
23 |
24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
25 |
26 |
27 | def init(wish_host, wish_port):
28 | global host, port, listener
29 | host = wish_host
30 | port = wish_port
31 | listener.bind((host, port))
32 | listener.listen()
33 | listener.settimeout(0)
34 |
35 |
36 | def try_connect():
37 | global conn, addr, listener
38 | try:
39 | conn, addr = listener.accept()
40 | print(f"\nConnected by {addr}")
41 | conn.settimeout(None)
42 | except Exception as inst:
43 | pass
44 |
45 |
46 | def read():
47 | global conn
48 | messageLength = conn.recv(4)
49 | messageLength = int.from_bytes(messageLength, "little")
50 | message = conn.recv(messageLength)
51 | return json.loads(message.decode("utf-8"))
52 |
53 |
54 | def send(message_bytes, verify):
55 | global conn
56 | if message_bytes != None:
57 | conn.sendall(message_bytes)
58 | conn.sendall(len(verify).to_bytes(4, "little"))
59 | conn.sendall(bytes(verify, "ascii"))
60 |
61 |
62 | def receive():
63 | message = read()
64 |
65 | width = message["resolution_x"]
66 | height = message["resolution_y"]
67 |
68 | if width != 0 and height != 0:
69 | try:
70 | do_training = bool(message["train"])
71 | fovy = message["fov_y"]
72 | fovx = message["fov_x"]
73 | znear = message["z_near"]
74 | zfar = message["z_far"]
75 | do_shs_python = bool(message["shs_python"])
76 | do_rot_scale_python = bool(message["rot_scale_python"])
77 | keep_alive = bool(message["keep_alive"])
78 | scaling_modifier = message["scaling_modifier"]
79 | world_view_transform = torch.reshape(
80 | torch.tensor(message["view_matrix"]), (4, 4)
81 | ).cuda()
82 | world_view_transform[:, 1] = -world_view_transform[:, 1]
83 | world_view_transform[:, 2] = -world_view_transform[:, 2]
84 | full_proj_transform = torch.reshape(
85 | torch.tensor(message["view_projection_matrix"]), (4, 4)
86 | ).cuda()
87 | full_proj_transform[:, 1] = -full_proj_transform[:, 1]
88 | custom_cam = MiniCam(
89 | width,
90 | height,
91 | fovy,
92 | fovx,
93 | znear,
94 | zfar,
95 | world_view_transform,
96 | full_proj_transform,
97 | )
98 | except Exception as e:
99 | print("")
100 | traceback.print_exc()
101 | raise e
102 | return (
103 | custom_cam,
104 | do_training,
105 | do_shs_python,
106 | do_rot_scale_python,
107 | keep_alive,
108 | scaling_modifier,
109 | )
110 | else:
111 | return None, None, None, None, None, None
112 |
--------------------------------------------------------------------------------
/lpipsPyTorch/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .modules.lpips import LPIPS
4 |
5 |
6 | def lpips(
7 | x: torch.Tensor, y: torch.Tensor, net_type: str = "alex", version: str = "0.1"
8 | ):
9 | r"""Function that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 |
12 | Arguments:
13 | x, y (torch.Tensor): the input tensors to compare.
14 | net_type (str): the network type to compare the features:
15 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
16 | version (str): the version of LPIPS. Default: 0.1.
17 | """
18 | device = x.device
19 | criterion = LPIPS(net_type, version).to(device)
20 | return criterion(x, y)
21 |
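Usage mirrors the call in `metrics.py`: both inputs are `[N, 3, H, W]` tensors on the same device (a minimal sketch with random data):

```python
import torch
from lpipsPyTorch import lpips

x = torch.rand(1, 3, 256, 256).cuda()  # e.g. a rendered image
y = torch.rand(1, 3, 256, 256).cuda()  # e.g. the ground-truth image
score = lpips(x, y, net_type="vgg")    # lower means more perceptually similar
print(score.item())
```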
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .networks import get_network, LinLayers
5 | from .utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 |
12 | Arguments:
13 | net_type (str): the network type to compare the features:
14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
15 | version (str): the version of LPIPS. Default: 0.1.
16 | """
17 |
18 | def __init__(self, net_type: str = "alex", version: str = "0.1"):
19 | assert version in ["0.1"], "v0.1 is only supported now"
20 |
21 | super(LPIPS, self).__init__()
22 |
23 | # pretrained network
24 | self.net = get_network(net_type)
25 |
26 | # linear layers
27 | self.lin = LinLayers(self.net.n_channels_list)
28 | self.lin.load_state_dict(get_state_dict(net_type, version))
29 |
30 | def forward(self, x: torch.Tensor, y: torch.Tensor):
31 | feat_x, feat_y = self.net(x), self.net(y)
32 |
33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
35 |
36 | return torch.sum(torch.cat(res, 0), 0, True)
37 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from .utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == "alex":
14 | return AlexNet()
15 | elif net_type == "squeeze":
16 | return SqueezeNet()
17 | elif net_type == "vgg":
18 | return VGG16()
19 | else:
20 | raise NotImplementedError("choose net_type from [alex, squeeze, vgg].")
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__(
26 | [
27 | nn.Sequential(nn.Identity(), nn.Conv2d(nc, 1, 1, 1, 0, bias=False))
28 | for nc in n_channels_list
29 | ]
30 | )
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | "mean", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
43 | )
44 | self.register_buffer(
45 | "std", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
46 | )
47 |
48 | def set_requires_grad(self, state: bool):
49 | for param in chain(self.parameters(), self.buffers()):
50 | param.requires_grad = state
51 |
52 | def z_score(self, x: torch.Tensor):
53 | return (x - self.mean) / self.std
54 |
55 | def forward(self, x: torch.Tensor):
56 | x = self.z_score(x)
57 |
58 | output = []
59 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
60 | x = layer(x)
61 | if i in self.target_layers:
62 | output.append(normalize_activation(x))
63 | if len(output) == len(self.target_layers):
64 | break
65 | return output
66 |
67 |
68 | class SqueezeNet(BaseNet):
69 | def __init__(self):
70 | super(SqueezeNet, self).__init__()
71 |
72 | self.layers = models.squeezenet1_1(True).features
73 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
74 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
75 |
76 | self.set_requires_grad(False)
77 |
78 |
79 | class AlexNet(BaseNet):
80 | def __init__(self):
81 | super(AlexNet, self).__init__()
82 |
83 | self.layers = models.alexnet(True).features
84 | self.target_layers = [2, 5, 8, 10, 12]
85 | self.n_channels_list = [64, 192, 384, 256, 256]
86 |
87 | self.set_requires_grad(False)
88 |
89 |
90 | class VGG16(BaseNet):
91 | def __init__(self):
92 | super(VGG16, self).__init__()
93 |
94 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
95 | self.target_layers = [4, 9, 16, 23, 30]
96 | self.n_channels_list = [64, 128, 256, 512, 512]
97 |
98 | self.set_requires_grad(False)
99 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = "alex", version: str = "0.1"):
12 | # build url
13 | url = (
14 | "https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/"
15 | + f"master/lpips/weights/v{version}/{net_type}.pth"
16 | )
17 |
18 | # download
19 | old_state_dict = torch.hub.load_state_dict_from_url(
20 | url,
21 | progress=True,
22 | map_location=None if torch.cuda.is_available() else torch.device("cpu"),
23 | )
24 |
25 | # rename keys
26 | new_state_dict = OrderedDict()
27 | for key, val in old_state_dict.items():
28 | new_key = key
29 | new_key = new_key.replace("lin", "")
30 | new_key = new_key.replace("model.", "")
31 | new_state_dict[new_key] = val
32 |
33 | return new_state_dict
34 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from pathlib import Path
13 | import os
14 | from PIL import Image
15 | import torch
16 | import torchvision.transforms.functional as tf
17 | from utils.loss_utils import ssim
18 | from lpipsPyTorch import lpips
19 | import json
20 | from tqdm import tqdm
21 | from utils.image_utils import psnr
22 | from argparse import ArgumentParser
23 |
24 |
25 | def readImages(renders_dir, gt_dir):
26 | renders = []
27 | gts = []
28 | image_names = []
29 | for fname in os.listdir(renders_dir):
30 | render = Image.open(renders_dir / fname)
31 | gt = Image.open(gt_dir / fname)
32 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
33 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
34 | image_names.append(fname)
35 | return renders, gts, image_names
36 |
37 |
38 | def evaluate(model_paths):
39 | full_dict = {}
40 | per_view_dict = {}
41 | full_dict_polytopeonly = {}
42 | per_view_dict_polytopeonly = {}
43 | print("")
44 |
45 | for scene_dir in model_paths:
46 | try:
47 | print("Scene:", scene_dir)
48 | full_dict[scene_dir] = {}
49 | per_view_dict[scene_dir] = {}
50 | full_dict_polytopeonly[scene_dir] = {}
51 | per_view_dict_polytopeonly[scene_dir] = {}
52 |
53 | test_dir = Path(scene_dir) / "test"
54 |
55 | for method in os.listdir(test_dir):
56 | print("Method:", method)
57 |
58 | full_dict[scene_dir][method] = {}
59 | per_view_dict[scene_dir][method] = {}
60 | full_dict_polytopeonly[scene_dir][method] = {}
61 | per_view_dict_polytopeonly[scene_dir][method] = {}
62 |
63 | method_dir = test_dir / method
64 | gt_dir = method_dir / "gt"
65 | renders_dir = method_dir / "renders"
66 | renders, gts, image_names = readImages(renders_dir, gt_dir)
67 |
68 | ssims = []
69 | psnrs = []
70 | lpipss = []
71 |
72 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
73 | ssims.append(ssim(renders[idx], gts[idx]))
74 | psnrs.append(psnr(renders[idx], gts[idx]))
75 | lpipss.append(lpips(renders[idx], gts[idx], net_type="vgg"))
76 |
77 |                 print("  SSIM : {:>12.7f}".format(torch.tensor(ssims).mean()))
78 |                 print("  PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean()))
79 |                 print("  LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean()))
80 | print("")
81 |
82 | full_dict[scene_dir][method].update(
83 | {
84 | "SSIM": torch.tensor(ssims).mean().item(),
85 | "PSNR": torch.tensor(psnrs).mean().item(),
86 | "LPIPS": torch.tensor(lpipss).mean().item(),
87 | }
88 | )
89 | per_view_dict[scene_dir][method].update(
90 | {
91 | "SSIM": {
92 | name: ssim
93 | for ssim, name in zip(
94 | torch.tensor(ssims).tolist(), image_names
95 | )
96 | },
97 | "PSNR": {
98 | name: psnr
99 | for psnr, name in zip(
100 | torch.tensor(psnrs).tolist(), image_names
101 | )
102 | },
103 | "LPIPS": {
104 | name: lp
105 | for lp, name in zip(
106 | torch.tensor(lpipss).tolist(), image_names
107 | )
108 | },
109 | }
110 | )
111 |
112 | with open(scene_dir + "/results.json", "w") as fp:
113 | json.dump(full_dict[scene_dir], fp, indent=True)
114 | with open(scene_dir + "/per_view.json", "w") as fp:
115 | json.dump(per_view_dict[scene_dir], fp, indent=True)
116 |         except Exception as e:
117 |             print("Unable to compute metrics for model", scene_dir, ":", e)
118 |
119 |
120 | if __name__ == "__main__":
121 | device = torch.device("cuda:0")
122 | torch.cuda.set_device(device)
123 |
124 | # Set up command line argument parser
125 |     parser = ArgumentParser(description="Metrics evaluation script parameters")
126 | parser.add_argument(
127 | "--model_paths", "-m", required=True, nargs="+", type=str, default=[]
128 | )
129 | args = parser.parse_args()
130 | evaluate(args.model_paths)
131 |
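
An illustrative sketch of the layout metrics.py expects under each --model_paths entry (the layout written by render.py); readImages pairs renders and ground truth by file name, so both folders must contain identically named images. The model path below is hypothetical.

# <model_path>/test/<method>/renders/00000.png   e.g. method = "ours_30000"
# <model_path>/test/<method>/gt/00000.png
from metrics import evaluate            # assumes the repository root is on sys.path
evaluate(["./output/my_scene"])         # equivalent to: python metrics.py -m ./output/my_scene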
--------------------------------------------------------------------------------
/prune.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import torch
14 | from random import randint
15 | from gaussian_renderer import render, count_render
16 | import sys
17 | from scene import Scene, GaussianModel
18 | from utils.general_utils import safe_state
19 | import uuid
20 | from tqdm import tqdm
21 | from utils.image_utils import psnr
22 | from argparse import ArgumentParser, Namespace
23 | from arguments import ModelParams, PipelineParams, OptimizationParams
24 | from utils.graphics_utils import getWorld2View2
25 | from icecream import ic
26 | import random
27 | import copy
28 | import gc
29 | import numpy as np
30 | from collections import defaultdict
31 |
32 | # from cuml.cluster import HDBSCAN
33 |
34 |
35 | # def HDBSCAN_prune(gaussians, score_list, prune_percent):
36 | # # Ensure the tensor is on the GPU and detached from the graph
37 | # s, d = gaussians.get_xyz.shape
38 | # X_gpu = cp.asarray(gaussians.get_xyz.detach().cuda())
39 |
40 | # scores_gpu = cp.asarray(score_list.detach().cuda())
41 | # hdbscan = HDBSCAN(min_cluster_size = 100)
42 | # cluster_labels = hdbscan.fit_predict(X_gpu)
43 | # points_by_centroid = {}
44 | # ic("cluster_labels")
45 | # ic(cluster_labels.shape)
46 | # ic(cluster_labels)
47 | # for i, label in enumerate(cluster_labels):
48 | # if label not in points_by_centroid:
49 | # points_by_centroid[label] = []
50 | # points_by_centroid[label].append(i)
51 | # points_to_prune = []
52 |
53 | # for centroid_idx, point_indices in points_by_centroid.items():
54 | # # Skip noise points with label -1
55 | # if centroid_idx == -1:
56 | # continue
57 | # num_to_prune = int(cp.ceil(prune_percent * len(point_indices)))
58 | # if num_to_prune <= 3:
59 | # continue
60 | # point_indices_cp = cp.array(point_indices)
61 | # distances = scores_gpu[point_indices_cp].squeeze()
62 | # indices_to_prune = point_indices_cp[cp.argsort(distances)[:num_to_prune]]
63 | # points_to_prune.extend(indices_to_prune)
64 | # points_to_prune = np.array(points_to_prune)
65 | # mask = np.zeros(s, dtype=bool)
66 | # mask[points_to_prune] = True
67 | # # points_to_prune now contains the indices of the points to be pruned
68 | # return mask
69 |
70 |
71 | # def uniform_prune(gaussians, k, score_list, prune_percent, sample = "k_mean"):
72 | # # get the farthest_point
73 | # D, I = None, None
74 | # s, d = gaussians.get_xyz.shape
75 |
76 | # if sample == "k_mean":
77 | # ic("k_mean")
78 | # n_iter = 200
79 | # verbose = False
80 | # kmeans = faiss.Kmeans(d, k=k, niter=n_iter, verbose=verbose, gpu=True)
81 | # kmeans.train(gaussians.get_xyz.detach().cpu().numpy())
82 | # # The cluster centroids can be accessed as follows
83 | # centroids = kmeans.centroids
84 | # D, I = kmeans.index.search(gaussians.get_xyz.detach().cpu().numpy(), 1)
85 | # else:
86 | # point_idx = farthest_point_sampler(torch.unsqueeze(gaussians.get_xyz, 0), k)
87 | # centroids = gaussians.get_xyz[point_idx,: ]
88 | # centroids = centroids.squeeze(0)
89 | # index = faiss.IndexFlatL2(d)
90 | # index.add(centroids.detach().cpu().numpy())
91 | # D, I = index.search(gaussians.get_xyz.detach().cpu().numpy(), 1)
92 | # points_to_prune = []
93 | # points_by_centroid = defaultdict(list)
94 | # for point_idx, centroid_idx in enumerate(I.flatten()):
95 | # points_by_centroid[centroid_idx.item()].append(point_idx)
96 | # for centroid_idx in points_by_centroid:
97 | # points_by_centroid[centroid_idx] = np.array(points_by_centroid[centroid_idx])
98 | # for centroid_idx, point_indices in points_by_centroid.items():
99 | # # Find the number of points to prune
100 | # num_to_prune = int(np.ceil(prune_percent * len(point_indices)))
101 | # if num_to_prune <= 3:
102 | # continue
103 | # distances = score_list[point_indices].squeeze().cpu().detach().numpy()
104 | # indices_to_prune = point_indices[np.argsort(distances)[:num_to_prune]]
105 | # points_to_prune.extend(indices_to_prune)
106 | # # Convert the list to an array
107 | # points_to_prune = np.array(points_to_prune)
108 | # mask = np.zeros(s, dtype=bool)
109 | # mask[points_to_prune] = True
110 | # return mask
111 |
112 | def calculate_v_imp_score(gaussians, imp_list, v_pow):
113 | """
114 | :param gaussians: A data structure containing Gaussian components with a get_scaling method.
115 | :param imp_list: The importance scores for each Gaussian component.
116 | :param v_pow: The power to which the volume ratios are raised.
117 | :return: A list of adjusted values (v_list) used for pruning.
118 | """
119 | # Calculate the volume of each Gaussian component
120 | volume = torch.prod(gaussians.get_scaling, dim=1)
121 | # Determine the kth_percent_largest value
122 | index = int(len(volume) * 0.9)
123 | sorted_volume, _ = torch.sort(volume, descending=True)
124 | kth_percent_largest = sorted_volume[index]
125 | # Calculate v_list
126 | v_list = torch.pow(volume / kth_percent_largest, v_pow)
127 | v_list = v_list * imp_list
128 | return v_list
129 |
130 |
131 |
132 |
133 | def prune_list(gaussians, scene, pipe, background):
134 | viewpoint_stack = scene.getTrainCameras().copy()
135 | gaussian_list, imp_list = None, None
136 | viewpoint_cam = viewpoint_stack.pop()
137 | render_pkg = count_render(viewpoint_cam, gaussians, pipe, background)
138 | gaussian_list, imp_list = (
139 | render_pkg["gaussians_count"],
140 | render_pkg["important_score"],
141 | )
142 |
143 | # ic(dataset.model_path)
144 | for iteration in range(len(viewpoint_stack)):
145 |         # Iterate over the remaining training cameras,
146 |         # accumulating per-Gaussian hit counts and importance scores
147 | viewpoint_cam = viewpoint_stack.pop()
148 | render_pkg = count_render(viewpoint_cam, gaussians, pipe, background)
149 | # image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
150 | gaussians_count, important_score = (
151 | render_pkg["gaussians_count"].detach(),
152 | render_pkg["important_score"].detach(),
153 | )
154 | gaussian_list += gaussians_count
155 | imp_list += important_score
156 | gc.collect()
157 | return gaussian_list, imp_list
158 |
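
A self-contained sketch (illustrative, not part of the repository) of the volume-aware score computed by calculate_v_imp_score above: each Gaussian's volume (the product of its three scales) is divided by a reference volume taken at the 90% position of the descending volume sort, raised to v_pow, and multiplied with the importance accumulated by count_render.

import torch

scaling = torch.tensor([[0.1, 0.1, 0.1],    # a small Gaussian
                        [0.5, 0.5, 0.5]])   # a large Gaussian
imp_list = torch.tensor([1.0, 1.0])         # equal accumulated importance
volume = torch.prod(scaling, dim=1)
sorted_volume, _ = torch.sort(volume, descending=True)
kth_percent_largest = sorted_volume[int(len(volume) * 0.9)]
v_list = torch.pow(volume / kth_percent_largest, 0.1) * imp_list   # v_pow = 0.1, the script default
print(v_list)   # the larger Gaussian receives the higher adjusted score (~1.62 vs 1.0)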
--------------------------------------------------------------------------------
/prune_finetune.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import torch
14 | from random import randint
15 | from utils.loss_utils import l1_loss, ssim
16 | from lpipsPyTorch import lpips
17 | from gaussian_renderer import render, network_gui, count_render
18 | import sys
19 | from scene import Scene, GaussianModel
20 | from utils.general_utils import safe_state
21 | import uuid
22 | from tqdm import tqdm
23 | from utils.image_utils import psnr
24 | from argparse import ArgumentParser, Namespace
25 | from arguments import ModelParams, PipelineParams, OptimizationParams
26 | import numpy as np
27 |
28 | try:
29 | from torch.utils.tensorboard import SummaryWriter
30 |
31 | TENSORBOARD_FOUND = True
32 | except ImportError:
33 | TENSORBOARD_FOUND = False
34 | from icecream import ic
35 | import random
36 | import copy
37 | import gc
38 | from os import makedirs
39 | from prune import prune_list, calculate_v_imp_score
40 | import torchvision
41 | from torch.optim.lr_scheduler import ExponentialLR
42 | import csv
43 | from utils.logger_utils import training_report, prepare_output_and_logger
44 |
45 |
46 | to_tensor = (
47 | lambda x: x.to("cuda")
48 | if isinstance(x, torch.Tensor)
49 | else torch.Tensor(x).to("cuda")
50 | )
51 | img2mse = lambda x, y: torch.mean((x - y) ** 2)
52 | mse2psnr = lambda x: -10.0 * torch.log(x) / torch.log(to_tensor([10.0]))
53 |
54 |
55 | def training(
56 | dataset,
57 | opt,
58 | pipe,
59 | testing_iterations,
60 | saving_iterations,
61 | checkpoint_iterations,
62 | checkpoint,
63 | debug_from,
64 | args,
65 | ):
66 | first_iter = 0
67 | tb_writer = prepare_output_and_logger(dataset)
68 | gaussians = GaussianModel(dataset.sh_degree)
69 | scene = Scene(dataset, gaussians)
70 | if checkpoint:
71 | gaussians.training_setup(opt)
72 | (model_params, first_iter) = torch.load(checkpoint)
73 | gaussians.restore(model_params, opt)
74 | elif args.start_pointcloud:
75 | gaussians.load_ply(args.start_pointcloud)
76 | ic(gaussians.get_xyz.shape)
77 | # ic(gaussians.optimizer.param_groups["xyz"].shape)
78 | gaussians.training_setup(opt)
79 | gaussians.max_radii2D = torch.zeros((gaussians.get_xyz.shape[0]), device="cuda")
80 |
81 | else:
82 | raise ValueError("A checkpoint file or a pointcloud is required to proceed.")
83 |
84 |
85 |
86 |
87 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
88 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
89 |
90 | iter_start = torch.cuda.Event(enable_timing=True)
91 | iter_end = torch.cuda.Event(enable_timing=True)
92 |
93 | viewpoint_stack = None
94 | ema_loss_for_log = 0.0
95 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
96 | first_iter += 1
97 | gaussians.scheduler = ExponentialLR(gaussians.optimizer, gamma=0.95)
98 |
99 | for iteration in range(first_iter, opt.iterations + 1):
100 | if network_gui.conn == None:
101 | network_gui.try_connect()
102 | while network_gui.conn != None:
103 | try:
104 | net_image_bytes = None
105 | (
106 | custom_cam,
107 | do_training,
108 | pipe.convert_SHs_python,
109 | pipe.compute_cov3D_python,
110 | keep_alive,
111 | scaling_modifer,
112 | ) = network_gui.receive()
113 | if custom_cam != None:
114 | net_image = render(
115 | custom_cam, gaussians, pipe, background, scaling_modifer
116 | )["render"]
117 | net_image_bytes = memoryview(
118 | (torch.clamp(net_image, min=0, max=1.0) * 255)
119 | .byte()
120 | .permute(1, 2, 0)
121 | .contiguous()
122 | .cpu()
123 | .numpy()
124 | )
125 | network_gui.send(net_image_bytes, dataset.source_path)
126 | if do_training and (
127 | (iteration < int(opt.iterations)) or not keep_alive
128 | ):
129 | break
130 | except Exception as e:
131 | network_gui.conn = None
132 |
133 | iter_start.record()
134 |
135 | gaussians.update_learning_rate(iteration)
136 |
137 | # Every 1000 its we increase the levels of SH up to a maximum degree
138 | if iteration % 1000 == 0:
139 | gaussians.oneupSHdegree()
140 | if iteration % 400 == 0:
141 | gaussians.scheduler.step()
142 |
143 | # Pick a random Camera
144 | if not viewpoint_stack:
145 | viewpoint_stack = scene.getTrainCameras().copy()
146 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
147 |
148 | # Render
149 | if (iteration - 1) == debug_from:
150 | pipe.debug = True
151 | render_pkg = render(viewpoint_cam, gaussians, pipe, background)
152 | image, viewspace_point_tensor, visibility_filter, radii = (
153 | render_pkg["render"],
154 | render_pkg["viewspace_points"],
155 | render_pkg["visibility_filter"],
156 | render_pkg["radii"],
157 | )
158 |
159 | # Loss
160 | gt_image = viewpoint_cam.original_image.cuda()
161 | Ll1 = l1_loss(image, gt_image)
162 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (
163 | 1.0 - ssim(image, gt_image)
164 | )
165 |
166 | loss.backward()
167 |
168 | iter_end.record()
169 |
170 | with torch.no_grad():
171 | # Progress bar
172 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
173 | if iteration % 1000 == 0:
174 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
175 | progress_bar.update(1000)
176 | if iteration == opt.iterations:
177 | progress_bar.close()
178 |
179 | # Log and save
180 |
181 | if iteration in saving_iterations:
182 | print("\n[ITER {}] Saving Gaussians".format(iteration))
183 | scene.save(iteration)
184 |
185 | if iteration in checkpoint_iterations:
186 | print("\n[ITER {}] Saving Checkpoint".format(iteration))
187 | if not os.path.exists(scene.model_path):
188 | os.makedirs(scene.model_path)
189 | torch.save(
190 | (gaussians.capture(), iteration),
191 | scene.model_path + "/chkpnt" + str(iteration) + ".pth",
192 | )
193 |
194 | if iteration == checkpoint_iterations[-1]:
195 | gaussian_list, imp_list = prune_list(gaussians, scene, pipe, background)
196 | v_list = calculate_v_imp_score(gaussians, imp_list, args.v_pow)
197 | np.savez(os.path.join(scene.model_path,"imp_score"), v_list.cpu().detach().numpy())
198 |
199 |
200 | training_report(
201 | tb_writer,
202 | iteration,
203 | Ll1,
204 | loss,
205 | l1_loss,
206 | iter_start.elapsed_time(iter_end),
207 | testing_iterations,
208 | scene,
209 | render,
210 | (pipe, background),
211 | )
212 |
213 | if iteration in args.prune_iterations:
214 | ic("Before prune iteration, number of gaussians: " + str(len(gaussians.get_xyz)))
215 | i = args.prune_iterations.index(iteration)
216 | gaussian_list, imp_list = prune_list(gaussians, scene, pipe, background)
217 |
218 | if args.prune_type == "important_score":
219 | gaussians.prune_gaussians(
220 | (args.prune_decay**i) * args.prune_percent, imp_list
221 | )
222 | elif args.prune_type == "v_important_score":
223 | # normalize scale
224 | v_list = calculate_v_imp_score(gaussians, imp_list, args.v_pow)
225 | gaussians.prune_gaussians(
226 | (args.prune_decay**i) * args.prune_percent, v_list
227 | )
228 | elif args.prune_type == "max_v_important_score":
229 | v_list = imp_list * torch.max(gaussians.get_scaling, dim=1)[0]
230 | gaussians.prune_gaussians(
231 | (args.prune_decay**i) * args.prune_percent, v_list
232 | )
233 | elif args.prune_type == "count":
234 | gaussians.prune_gaussians(
235 | (args.prune_decay**i) * args.prune_percent, gaussian_list
236 | )
237 | elif args.prune_type == "opacity":
238 | gaussians.prune_gaussians(
239 | (args.prune_decay**i) * args.prune_percent,
240 | gaussians.get_opacity.detach(),
241 | )
242 |                 # TODO: release the other pruning methods (commented out below)
243 | # elif args.prune_type == "HDBSCAN":
244 | # masks = HDBSCAN_prune(gaussians, imp_list, (args.prune_decay**i)*args.prune_percent)
245 | # gaussians.prune_points(masks)
246 | # # elif args.prune_type == "v_important_score":
247 | # # imp_list *
248 | # elif args.prune_type == "two_step":
249 | # if i == 0:
250 | # volume = torch.prod(gaussians.get_scaling, dim = 1)
251 | # index = int(len(volume) * 0.9)
252 | # sorted_volume, sorted_indices = torch.sort(volume, descending=True, dim=0)
253 | # kth_percent_largest = sorted_volume[index]
254 | # v_list = torch.pow(volume/kth_percent_largest, args.v_pow)
255 | # v_list = v_list * imp_list
256 | # gaussians.prune_gaussians((args.prune_decay**i)*args.prune_percent, v_list)
257 | # else:
258 | # k = 5^(1*i) * 100
259 | # masks = uniform_prune(gaussians, k, imp_list, 0.3, "k_mean")
260 | # gaussians.prune_points(masks)
261 | # else:
262 | # k = len(gaussians.get_xyz)//500 * i
263 | # masks = uniform_prune(gaussians, k, imp_list, (args.prune_decay**i)*args.prune_percent, args.prune_type)
264 | # gaussians.prune_points(masks)
265 | # gaussians.prune_gaussians(args.prune_percent, imp_list)
266 |                 # gaussians.optimizer.zero_grad(set_to_none = True)  # hacky way to maintain grad
267 | # if (iteration in args.opacity_prune_iterations):
268 | # gaussians.prune_opacity(0.05)
269 | else:
270 |                 raise Exception("Unsupported pruning method")
271 |
272 | ic("After prune iteration, number of gaussians: " + str(len(gaussians.get_xyz)))
273 |
274 | # if iteration in args.densify_iteration:
275 | # gaussians.max_radii2D[visibility_filter] = torch.max(
276 | # gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
277 | # )
278 | # gaussians.add_densification_stats(
279 | # viewspace_point_tensor, visibility_filter
280 | # )
281 | # gaussians.densify(opt.densify_grad_threshold, scene.cameras_extent)
282 |
283 | ic("after")
284 | ic(gaussians.get_xyz.shape)
285 | ic(len(gaussians.optimizer.param_groups[0]['params'][0]))
286 |
287 | if iteration < opt.iterations:
288 | gaussians.optimizer.step()
289 | gaussians.optimizer.zero_grad(set_to_none=True)
290 |
291 |
292 | if __name__ == "__main__":
293 | # Set up command line argument parser
294 | parser = ArgumentParser(description="Training script parameters")
295 | lp = ModelParams(parser)
296 | op = OptimizationParams(parser)
297 | pp = PipelineParams(parser)
298 | parser.add_argument("--ip", type=str, default="127.0.0.1")
299 | parser.add_argument("--port", type=int, default=6009)
300 | parser.add_argument("--debug_from", type=int, default=-1)
301 | parser.add_argument("--detect_anomaly", action="store_true", default=False)
302 | parser.add_argument(
303 | "--test_iterations", nargs="+", type=int, default=[30_001, 30_002, 35_000]
304 | )
305 | parser.add_argument(
306 | "--save_iterations", nargs="+", type=int, default=[35_000]
307 | )
308 | parser.add_argument("--quiet", action="store_true")
309 | parser.add_argument(
310 | "--checkpoint_iterations", nargs="+", type=int, default=[35_000]
311 | )
312 |
313 | parser.add_argument("--prune_iterations", nargs="+", type=int, default=[30_001])
314 | parser.add_argument("--start_checkpoint", type=str, default=None)
315 | parser.add_argument("--start_pointcloud", type=str, default=None)
316 | parser.add_argument("--prune_percent", type=float, default=0.1)
317 | parser.add_argument("--prune_decay", type=float, default=1)
318 | parser.add_argument(
319 | "--prune_type", type=str, default="important_score"
320 |     )  # options: important_score, v_important_score, max_v_important_score, count, opacity
321 | parser.add_argument("--v_pow", type=float, default=0.1)
322 | parser.add_argument("--densify_iteration", nargs="+", type=int, default=[-1])
323 | args = parser.parse_args(sys.argv[1:])
324 | args.save_iterations.append(args.iterations)
325 |
326 | print("Optimizing " + args.model_path)
327 |
328 | # Initialize system state (RNG)
329 | safe_state(args.quiet)
330 |
331 | # Start GUI server, configure and run training
332 | network_gui.init(args.ip, args.port)
333 | torch.autograd.set_detect_anomaly(args.detect_anomaly)
334 | training(
335 | lp.extract(args),
336 | op.extract(args),
337 | pp.extract(args),
338 | args.test_iterations,
339 | args.save_iterations,
340 | args.checkpoint_iterations,
341 | args.start_checkpoint,
342 | args.debug_from,
343 | args,
344 | )
345 |
346 | # All done
347 | print("\nTraining complete.")
348 |
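
A small sketch (illustrative, not part of the repository) of the decayed pruning schedule applied above: at the i-th entry of --prune_iterations, a fraction prune_decay**i * prune_percent of the Gaussians is removed, ranked by the score selected with --prune_type. The values below are hypothetical; the script defaults are prune_percent=0.1, prune_decay=1 and a single prune step.

prune_percent, prune_decay = 0.4, 0.6
prune_iterations = [30_001, 32_001]
for i, iteration in enumerate(prune_iterations):
    fraction = (prune_decay ** i) * prune_percent
    print(f"iteration {iteration}: prune {fraction:.0%} of the remaining Gaussians")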
--------------------------------------------------------------------------------
/render.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from scene import Scene
14 | import os
15 | from tqdm import tqdm
16 | from os import makedirs
17 | from gaussian_renderer import render
18 | import torchvision
19 | from utils.general_utils import safe_state
20 | from argparse import ArgumentParser
21 | from arguments import ModelParams, PipelineParams, get_combined_args
22 | from gaussian_renderer import GaussianModel
23 |
24 |
25 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
26 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
27 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
28 |
29 | makedirs(render_path, exist_ok=True)
30 | makedirs(gts_path, exist_ok=True)
31 |
32 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
33 | rendering = render(view, gaussians, pipeline, background)["render"]
34 | gt = view.original_image[0:3, :, :]
35 | torchvision.utils.save_image(
36 | rendering, os.path.join(render_path, "{0:05d}".format(idx) + ".png")
37 | )
38 | torchvision.utils.save_image(
39 | gt, os.path.join(gts_path, "{0:05d}".format(idx) + ".png")
40 | )
41 |
42 |
43 | def render_sets(
44 | dataset: ModelParams,
45 | iteration: int,
46 | pipeline: PipelineParams,
47 | skip_train: bool,
48 | skip_test: bool,
49 | load_vq: bool,
50 | ):
51 | with torch.no_grad():
52 | gaussians = GaussianModel(dataset.sh_degree)
53 |         scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False, load_vq=load_vq)
54 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
55 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
56 |
57 | if not skip_train:
58 | render_set(
59 | dataset.model_path,
60 | "train",
61 | scene.loaded_iter,
62 | scene.getTrainCameras(),
63 | gaussians,
64 | pipeline,
65 | background,
66 | )
67 |
68 | if not skip_test:
69 | render_set(
70 | dataset.model_path,
71 | "test",
72 | scene.loaded_iter,
73 | scene.getTestCameras(),
74 | gaussians,
75 | pipeline,
76 | background,
77 | )
78 |
79 |
80 | if __name__ == "__main__":
81 | # Set up command line argument parser
82 | parser = ArgumentParser(description="Testing script parameters")
83 | model = ModelParams(parser, sentinel=True)
84 | pipeline = PipelineParams(parser)
85 | parser.add_argument("--iteration", default=-1, type=int)
86 | parser.add_argument("--skip_train", action="store_true")
87 | parser.add_argument("--skip_test", action="store_true")
88 | parser.add_argument("--load_vq", action="store_true")
89 | parser.add_argument("--quiet", action="store_true")
90 | args = get_combined_args(parser)
91 | print("Rendering " + args.model_path)
92 |
93 | # Initialize system state (RNG)
94 | safe_state(args.quiet)
95 |
96 | render_sets(
97 | model.extract(args),
98 | args.iteration,
99 | pipeline.extract(args),
100 | args.skip_train,
101 | args.skip_test,
102 | args.load_vq
103 | )
104 |
--------------------------------------------------------------------------------
/render_video.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from scene import Scene
14 | import os
15 | from tqdm import tqdm
16 | import numpy as np
17 | from os import makedirs
18 | from gaussian_renderer import render
19 | import torchvision
20 | from utils.general_utils import safe_state
21 | from argparse import ArgumentParser
22 | from arguments import ModelParams, PipelineParams, get_combined_args
23 | from gaussian_renderer import GaussianModel
24 | from icecream import ic
25 | import copy
26 |
27 | from utils.graphics_utils import getWorld2View2
28 | from utils.pose_utils import generate_ellipse_path, generate_spherical_sample_path, generate_spiral_path, generate_spherify_path, gaussian_poses, circular_poses
29 | # import stepfun
30 |
31 |
32 |
33 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
34 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
35 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
36 |
37 | makedirs(render_path, exist_ok=True)
38 | makedirs(gts_path, exist_ok=True)
39 |
40 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
41 | rendering = render(view, gaussians, pipeline, background)["render"]
42 | gt = view.original_image[0:3, :, :]
43 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
44 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
45 |
46 |
47 | # def normalize(x):
48 | # return x / np.linalg.norm(x)
49 |
50 | # def viewmatrix(z, up, pos):
51 | # vec2 = normalize(z)
52 | # vec0 = normalize(np.cross(up, vec2))
53 | # vec1 = normalize(np.cross(vec2, vec0))
54 | # m = np.stack([vec0, vec1, vec2, pos], 1)
55 | # return m
56 |
57 | # def poses_avg(poses):
58 | # hwf = poses[0, :3, -1:]
59 | # center = poses[:, :3, 3].mean(0)
60 | # vec2 = normalize(poses[:, :3, 2].sum(0))
61 | # up = poses[:, :3, 1].sum(0)
62 | # c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
63 | # return c2w
64 |
65 | # def get_focal(camera):
66 | # focal = camera.FoVx
67 | # return focal
68 |
69 | # def poses_avg_fixed_center(poses):
70 | # hwf = poses[0, :3, -1:]
71 | # center = poses[:, :3, 3].mean(0)
72 | # vec2 = [1, 0, 0]
73 | # up = [0, 0, 1]
74 | # c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
75 | # return c2w
76 |
77 | # def focus_point_fn(poses):
78 | # """Calculate nearest point to all focal axes in poses."""
79 | # directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
80 | # m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
81 | # mt_m = np.transpose(m, [0, 2, 1]) @ m
82 | # focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
83 | # return focus_pt
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 | # xy circular
93 | def render_circular_video(model_path, iteration, views, gaussians, pipeline, background, radius=0.5, n_frames=240):
94 | render_path = os.path.join(model_path, 'circular', "ours_{}".format(iteration))
95 |     makedirs(render_path, exist_ok=True)
96 |
97 | # view = views[0]
98 | for idx in range(n_frames):
99 | view = copy.deepcopy(views[13])
100 | angle = 2 * np.pi * idx / n_frames
101 | cam = circular_poses(view, radius, angle)
102 | rendering = render(cam, gaussians, pipeline, background)["render"]
103 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
104 |
105 |
106 |
107 | def render_video(model_path, iteration, views, gaussians, pipeline, background):
108 | render_path = os.path.join(model_path, 'video', "ours_{}".format(iteration))
109 | makedirs(render_path, exist_ok=True)
110 | view = views[0]
111 | # render_path_spiral
112 | # render_path_spherical
113 | for idx, pose in enumerate(tqdm(generate_ellipse_path(views,n_frames=600), desc="Rendering progress")):
114 | view.world_view_transform = torch.tensor(getWorld2View2(pose[:3, :3].T, pose[:3, 3], view.trans, view.scale)).transpose(0, 1).cuda()
115 | view.full_proj_transform = (view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0)
116 | view.camera_center = view.world_view_transform.inverse()[3, :3]
117 | rendering = render(view, gaussians, pipeline, background)["render"]
118 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
119 |
120 |
121 |
122 |
123 | def gaussian_render(model_path, iteration, views, gaussians, pipeline, background, args):
124 |     views = views[:10]  # take the first 10 views and perturb each viewpoint with Gaussian noise
125 | render_path = os.path.join(model_path, 'video', "gaussians_{}_std{}".format(iteration, args.std))
126 | makedirs(render_path, exist_ok=True)
127 |
128 | for i, view in enumerate(views):
129 | rendering = render(view, gaussians, pipeline, background)["render"]
130 | sub_path = os.path.join(render_path,"view_"+str(i))
131 | makedirs(sub_path ,exist_ok=True)
132 | torchvision.utils.save_image(rendering, os.path.join(sub_path, "gt"+'{0:05d}'.format(i) + ".png"))
133 | for j in range(10):
134 | n_view = copy.deepcopy(view)
135 | g_view = gaussian_poses(n_view, args.mean, args.std)
136 | rendering = render(g_view, gaussians, pipeline, background)["render"]
137 | torchvision.utils.save_image(rendering, os.path.join(sub_path, '{0:05d}'.format(j) + ".png"))
138 |
139 |
140 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, video: bool, circular:bool, radius: float, args):
141 | with torch.no_grad():
142 | gaussians = GaussianModel(dataset.sh_degree)
143 |         scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False, load_vq=args.load_vq)
144 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
145 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
146 |
147 | if not skip_train:
148 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
149 |
150 | if not skip_test:
151 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)
152 | if circular:
153 | render_circular_video(dataset.model_path, scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background,radius)
154 |         # By default render_video follows an ellipse path; spiral, spherical and other trajectories
155 |         # are available via the generate_*_path functions in utils.pose_utils (swap the generator in render_video's loop)
156 | if video:
157 | render_video(dataset.model_path, scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
158 |         # Sample perturbed virtual viewpoints around the test cameras
159 | if args.gaussians:
160 | gaussian_render(dataset.model_path, scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background, args)
161 |
162 |
163 | if __name__ == "__main__":
164 | # Set up command line argument parser
165 | parser = ArgumentParser(description="Testing script parameters")
166 | model = ModelParams(parser, sentinel=True)
167 | pipeline = PipelineParams(parser)
168 | parser.add_argument("--iteration", default=-1, type=int)
169 | parser.add_argument("--skip_train", action="store_true")
170 | parser.add_argument("--skip_test", action="store_true")
171 | parser.add_argument("--quiet", action="store_true")
172 | parser.add_argument("--video", action="store_true")
173 | parser.add_argument("--circular", action="store_true")
174 | parser.add_argument("--radius", default=5, type=float)
175 | parser.add_argument("--gaussians", action="store_true")
176 | parser.add_argument("--mean", default=0, type=float)
177 | parser.add_argument("--std", default=0.03, type=float)
178 | parser.add_argument("--load_vq", action="store_true")
179 | args = get_combined_args(parser)
180 | print("Rendering " + args.model_path)
181 |
182 | # Initialize system state (RNG)
183 | safe_state(args.quiet)
184 |
185 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.video, args.circular, args.radius, args)
--------------------------------------------------------------------------------
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import random
14 | import json
15 | from utils.system_utils import searchForMaxIteration
16 | from scene.dataset_readers import sceneLoadTypeCallbacks
17 | from scene.gaussian_model import GaussianModel
18 | from arguments import ModelParams
19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
20 |
21 |
22 | class Scene:
23 | gaussians: GaussianModel
24 | # modified
25 | def __init__(
26 | self,
27 | args: ModelParams,
28 | gaussians: GaussianModel,
29 | load_iteration=None,
30 | shuffle=True,
31 | resolution_scales=[1.0],
32 | new_sh=0,
33 | load_vq=False
34 | ):
35 |         """
36 |         :param args: model parameters, including the path to the COLMAP scene's main folder.
37 |         """
38 | self.model_path = args.model_path
39 | self.loaded_iter = None
40 | self.gaussians = gaussians
41 |
42 | if load_iteration:
43 | if load_iteration == -1:
44 | self.loaded_iter = searchForMaxIteration(
45 | os.path.join(self.model_path, "point_cloud")
46 | )
47 | else:
48 | self.loaded_iter = load_iteration
49 | print("Loading trained model at iteration {}".format(self.loaded_iter))
50 |
51 | self.train_cameras = {}
52 | self.test_cameras = {}
53 | print(args.source_path)
54 | if os.path.exists(os.path.join(args.source_path, "sparse")):
55 | scene_info = sceneLoadTypeCallbacks["Colmap"](
56 | args.source_path, args.images, args.eval
57 | )
58 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
59 | print("Found transforms_train.json file, assuming Blender data set!")
60 | scene_info = sceneLoadTypeCallbacks["Blender"](
61 | args.source_path, args.white_background, args.eval
62 | )
63 | else:
64 | assert False, "Could not recognize scene type!"
65 |
66 | if not self.loaded_iter:
67 | with open(scene_info.ply_path, "rb") as src_file, open(
68 | os.path.join(self.model_path, "input.ply"), "wb"
69 | ) as dest_file:
70 | dest_file.write(src_file.read())
71 | json_cams = []
72 | camlist = []
73 | if scene_info.test_cameras:
74 | camlist.extend(scene_info.test_cameras)
75 | if scene_info.train_cameras:
76 | camlist.extend(scene_info.train_cameras)
77 | for id, cam in enumerate(camlist):
78 | json_cams.append(camera_to_JSON(id, cam))
79 | with open(os.path.join(self.model_path, "cameras.json"), "w") as file:
80 | json.dump(json_cams, file)
81 |
82 | if shuffle:
83 | random.shuffle(
84 | scene_info.train_cameras
85 | ) # Multi-res consistent random shuffling
86 | random.shuffle(
87 | scene_info.test_cameras
88 | ) # Multi-res consistent random shuffling
89 |
90 | self.cameras_extent = scene_info.nerf_normalization["radius"]
91 |
92 | for resolution_scale in resolution_scales:
93 |
94 | print("Loading Training Cameras")
95 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(
96 | scene_info.train_cameras, resolution_scale, args
97 | )
98 | print("Loading Test Cameras")
99 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(
100 | scene_info.test_cameras, resolution_scale, args
101 | )
102 | if load_vq:
103 | self.gaussians.load_vq(self.model_path)
104 |
105 | elif new_sh != 0 and self.loaded_iter:
106 | self.gaussians.load_ply_sh(
107 | os.path.join(
108 | self.model_path,
109 | "point_cloud",
110 | "iteration_" + str(self.loaded_iter),
111 | "point_cloud.ply",
112 | ),
113 | new_sh,
114 | )
115 | elif self.loaded_iter:
116 | self.gaussians.load_ply(
117 | os.path.join(
118 | self.model_path,
119 | "point_cloud",
120 | "iteration_" + str(self.loaded_iter),
121 | "point_cloud.ply",
122 | )
123 | )
124 | else:
125 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
126 |
127 | def save(self, iteration):
128 | point_cloud_path = os.path.join(
129 | self.model_path, "point_cloud/iteration_{}".format(iteration)
130 | )
131 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
132 |
133 | def getTrainCameras(self, scale=1.0):
134 | return self.train_cameras[scale]
135 |
136 | def getTestCameras(self, scale=1.0):
137 | return self.test_cameras[scale]
138 |
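
A schematic sketch (illustrative, not part of the repository) of the loading precedence implemented by Scene.__init__ above; pick_loader is a hypothetical helper used only to summarise the branches.

def pick_loader(load_vq: bool, new_sh: int, loaded_iter) -> str:
    if load_vq:
        return "gaussians.load_vq(model_path)"
    if new_sh != 0 and loaded_iter:
        return "gaussians.load_ply_sh(point_cloud.ply, new_sh)"
    if loaded_iter:
        return "gaussians.load_ply(point_cloud.ply)"
    return "gaussians.create_from_pcd(scene_info.point_cloud, cameras_extent)"

print(pick_loader(load_vq=False, new_sh=0, loaded_iter=30000))   # -> the load_ply branch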
--------------------------------------------------------------------------------
/scene/cameras.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from torch import nn
14 | import numpy as np
15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix
16 |
17 |
18 | class Camera(nn.Module):
19 | def __init__(
20 | self,
21 | colmap_id,
22 | R,
23 | T,
24 | FoVx,
25 | FoVy,
26 | image,
27 | gt_alpha_mask,
28 | image_name,
29 | uid,
30 | trans=np.array([0.0, 0.0, 0.0]),
31 | scale=1.0,
32 | data_device="cuda",
33 | ):
34 | super(Camera, self).__init__()
35 |
36 | self.uid = uid
37 | self.colmap_id = colmap_id
38 | self.R = R
39 | self.T = T
40 | self.FoVx = FoVx
41 | self.FoVy = FoVy
42 | self.image_name = image_name
43 |
44 | try:
45 | self.data_device = torch.device(data_device)
46 | except Exception as e:
47 | print(e)
48 | print(
49 | f"[Warning] Custom device {data_device} failed, fallback to default cuda device"
50 | )
51 | self.data_device = torch.device("cuda")
52 |
53 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
54 | self.image_width = self.original_image.shape[2]
55 | self.image_height = self.original_image.shape[1]
56 |
57 | if gt_alpha_mask is not None:
58 | self.original_image *= gt_alpha_mask.to(self.data_device)
59 | else:
60 | self.original_image *= torch.ones(
61 | (1, self.image_height, self.image_width), device=self.data_device
62 | )
63 |
64 | self.zfar = 100.0
65 | self.znear = 0.01
66 |
67 | self.trans = trans
68 | self.scale = scale
69 |
70 | self.world_view_transform = (
71 | torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
72 | )
73 | self.projection_matrix = (
74 | getProjectionMatrix(
75 | znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
76 | )
77 | .transpose(0, 1)
78 | .cuda()
79 | )
80 | self.full_proj_transform = (
81 | self.world_view_transform.unsqueeze(0).bmm(
82 | self.projection_matrix.unsqueeze(0)
83 | )
84 | ).squeeze(0)
85 | self.camera_center = self.world_view_transform.inverse()[3, :3]
86 |
87 |
88 | class MiniCam:
89 | def __init__(
90 | self,
91 | width,
92 | height,
93 | fovy,
94 | fovx,
95 | znear,
96 | zfar,
97 | world_view_transform,
98 | full_proj_transform,
99 | ):
100 | self.image_width = width
101 | self.image_height = height
102 | self.FoVy = fovy
103 | self.FoVx = fovx
104 | self.znear = znear
105 | self.zfar = zfar
106 | self.world_view_transform = world_view_transform
107 | self.full_proj_transform = full_proj_transform
108 | view_inv = torch.inverse(self.world_view_transform)
109 | self.camera_center = view_inv[3][:3]
110 |
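
A minimal sketch (illustrative, not part of the repository) of the row-vector convention used by Camera above: matrices are stored transposed, the full projection is world_view_transform @ projection_matrix, and the camera centre is read from the last row of the inverted view matrix. The pose below is hypothetical.

import torch

world_view = torch.eye(4)
world_view[3, :3] = torch.tensor([1.0, 2.0, 3.0])   # transposed view matrix: translation in the last row
projection = torch.eye(4)                           # identity stand-in for getProjectionMatrix(...)
full_proj = world_view.unsqueeze(0).bmm(projection.unsqueeze(0)).squeeze(0)
camera_center = world_view.inverse()[3, :3]
print(full_proj[3, :3])   # tensor([1., 2., 3.])
print(camera_center)      # tensor([-1., -2., -3.])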
--------------------------------------------------------------------------------
/scene/colmap_loader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import numpy as np
13 | import collections
14 | import struct
15 |
16 | CameraModel = collections.namedtuple(
17 | "CameraModel", ["model_id", "model_name", "num_params"]
18 | )
19 | Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"])
20 | BaseImage = collections.namedtuple(
21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]
22 | )
23 | Point3D = collections.namedtuple(
24 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]
25 | )
26 | CAMERA_MODELS = {
27 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
28 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
29 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
30 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
31 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
32 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
33 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
34 | CameraModel(model_id=7, model_name="FOV", num_params=5),
35 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
36 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
37 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12),
38 | }
39 | CAMERA_MODEL_IDS = dict(
40 | [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]
41 | )
42 | CAMERA_MODEL_NAMES = dict(
43 | [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS]
44 | )
45 |
46 |
47 | def qvec2rotmat(qvec):
48 | return np.array(
49 | [
50 | [
51 | 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
52 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
53 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
54 | ],
55 | [
56 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
57 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
58 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
59 | ],
60 | [
61 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
62 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
63 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
64 | ],
65 | ]
66 | )
67 |
68 |
69 | def rotmat2qvec(R):
70 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
71 | K = (
72 | np.array(
73 | [
74 | [Rxx - Ryy - Rzz, 0, 0, 0],
75 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
76 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
77 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz],
78 | ]
79 | )
80 | / 3.0
81 | )
82 | eigvals, eigvecs = np.linalg.eigh(K)
83 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
84 | if qvec[0] < 0:
85 | qvec *= -1
86 | return qvec
87 |
88 |
89 | class Image(BaseImage):
90 | def qvec2rotmat(self):
91 | return qvec2rotmat(self.qvec)
92 |
93 |
94 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
95 | """Read and unpack the next bytes from a binary file.
96 | :param fid:
97 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
98 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
99 | :param endian_character: Any of {@, =, <, >, !}
100 | :return: Tuple of read and unpacked values.
101 | """
102 | data = fid.read(num_bytes)
103 | return struct.unpack(endian_character + format_char_sequence, data)
104 |
105 |
106 | def read_points3D_text(path):
107 | """
108 | see: src/base/reconstruction.cc
109 | void Reconstruction::ReadPoints3DText(const std::string& path)
110 | void Reconstruction::WritePoints3DText(const std::string& path)
111 | """
112 | xyzs = None
113 | rgbs = None
114 | errors = None
115 | num_points = 0
116 | with open(path, "r") as fid:
117 | while True:
118 | line = fid.readline()
119 | if not line:
120 | break
121 | line = line.strip()
122 | if len(line) > 0 and line[0] != "#":
123 | num_points += 1
124 |
125 | xyzs = np.empty((num_points, 3))
126 | rgbs = np.empty((num_points, 3))
127 | errors = np.empty((num_points, 1))
128 | count = 0
129 | with open(path, "r") as fid:
130 | while True:
131 | line = fid.readline()
132 | if not line:
133 | break
134 | line = line.strip()
135 | if len(line) > 0 and line[0] != "#":
136 | elems = line.split()
137 | xyz = np.array(tuple(map(float, elems[1:4])))
138 | rgb = np.array(tuple(map(int, elems[4:7])))
139 | error = np.array(float(elems[7]))
140 | xyzs[count] = xyz
141 | rgbs[count] = rgb
142 | errors[count] = error
143 | count += 1
144 |
145 | return xyzs, rgbs, errors
146 |
147 |
148 | def read_points3D_binary(path_to_model_file):
149 | """
150 | see: src/base/reconstruction.cc
151 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
152 | void Reconstruction::WritePoints3DBinary(const std::string& path)
153 | """
154 |
155 | with open(path_to_model_file, "rb") as fid:
156 | num_points = read_next_bytes(fid, 8, "Q")[0]
157 |
158 | xyzs = np.empty((num_points, 3))
159 | rgbs = np.empty((num_points, 3))
160 | errors = np.empty((num_points, 1))
161 |
162 | for p_id in range(num_points):
163 | binary_point_line_properties = read_next_bytes(
164 | fid, num_bytes=43, format_char_sequence="QdddBBBd"
165 | )
166 | xyz = np.array(binary_point_line_properties[1:4])
167 | rgb = np.array(binary_point_line_properties[4:7])
168 | error = np.array(binary_point_line_properties[7])
169 | track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[
170 | 0
171 | ]
172 | track_elems = read_next_bytes(
173 | fid,
174 | num_bytes=8 * track_length,
175 | format_char_sequence="ii" * track_length,
176 | )
177 | xyzs[p_id] = xyz
178 | rgbs[p_id] = rgb
179 | errors[p_id] = error
180 | return xyzs, rgbs, errors
181 |
182 |
183 | def read_intrinsics_text(path):
184 | """
185 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
186 | """
187 | cameras = {}
188 | with open(path, "r") as fid:
189 | while True:
190 | line = fid.readline()
191 | if not line:
192 | break
193 | line = line.strip()
194 | if len(line) > 0 and line[0] != "#":
195 | elems = line.split()
196 | camera_id = int(elems[0])
197 | model = elems[1]
198 | assert (
199 | model == "PINHOLE"
200 |                 ), "While the loader supports other types, the rest of the code assumes PINHOLE"
201 | width = int(elems[2])
202 | height = int(elems[3])
203 | params = np.array(tuple(map(float, elems[4:])))
204 | cameras[camera_id] = Camera(
205 | id=camera_id, model=model, width=width, height=height, params=params
206 | )
207 | return cameras
208 |
209 |
210 | def read_extrinsics_binary(path_to_model_file):
211 | """
212 | see: src/base/reconstruction.cc
213 | void Reconstruction::ReadImagesBinary(const std::string& path)
214 | void Reconstruction::WriteImagesBinary(const std::string& path)
215 | """
216 | images = {}
217 | with open(path_to_model_file, "rb") as fid:
218 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
219 | for _ in range(num_reg_images):
220 | binary_image_properties = read_next_bytes(
221 | fid, num_bytes=64, format_char_sequence="idddddddi"
222 | )
223 | image_id = binary_image_properties[0]
224 | qvec = np.array(binary_image_properties[1:5])
225 | tvec = np.array(binary_image_properties[5:8])
226 | camera_id = binary_image_properties[8]
227 | image_name = ""
228 | current_char = read_next_bytes(fid, 1, "c")[0]
229 | while current_char != b"\x00": # look for the ASCII 0 entry
230 | image_name += current_char.decode("utf-8")
231 | current_char = read_next_bytes(fid, 1, "c")[0]
232 | num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[
233 | 0
234 | ]
235 | x_y_id_s = read_next_bytes(
236 | fid,
237 | num_bytes=24 * num_points2D,
238 | format_char_sequence="ddq" * num_points2D,
239 | )
240 | xys = np.column_stack(
241 | [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))]
242 | )
243 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
244 | images[image_id] = Image(
245 | id=image_id,
246 | qvec=qvec,
247 | tvec=tvec,
248 | camera_id=camera_id,
249 | name=image_name,
250 | xys=xys,
251 | point3D_ids=point3D_ids,
252 | )
253 | return images
254 |
255 |
256 | def read_intrinsics_binary(path_to_model_file):
257 | """
258 | see: src/base/reconstruction.cc
259 | void Reconstruction::WriteCamerasBinary(const std::string& path)
260 | void Reconstruction::ReadCamerasBinary(const std::string& path)
261 | """
262 | cameras = {}
263 | with open(path_to_model_file, "rb") as fid:
264 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
265 | for _ in range(num_cameras):
266 | camera_properties = read_next_bytes(
267 | fid, num_bytes=24, format_char_sequence="iiQQ"
268 | )
269 | camera_id = camera_properties[0]
270 | model_id = camera_properties[1]
271 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
272 | width = camera_properties[2]
273 | height = camera_properties[3]
274 | num_params = CAMERA_MODEL_IDS[model_id].num_params
275 | params = read_next_bytes(
276 | fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params
277 | )
278 | cameras[camera_id] = Camera(
279 | id=camera_id,
280 | model=model_name,
281 | width=width,
282 | height=height,
283 | params=np.array(params),
284 | )
285 | assert len(cameras) == num_cameras
286 | return cameras
287 |
288 |
289 | def read_extrinsics_text(path):
290 | """
291 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
292 | """
293 | images = {}
294 | with open(path, "r") as fid:
295 | while True:
296 | line = fid.readline()
297 | if not line:
298 | break
299 | line = line.strip()
300 | if len(line) > 0 and line[0] != "#":
301 | elems = line.split()
302 | image_id = int(elems[0])
303 | qvec = np.array(tuple(map(float, elems[1:5])))
304 | tvec = np.array(tuple(map(float, elems[5:8])))
305 | camera_id = int(elems[8])
306 | image_name = elems[9]
307 | elems = fid.readline().split()
308 | xys = np.column_stack(
309 | [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))]
310 | )
311 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
312 | images[image_id] = Image(
313 | id=image_id,
314 | qvec=qvec,
315 | tvec=tvec,
316 | camera_id=camera_id,
317 | name=image_name,
318 | xys=xys,
319 | point3D_ids=point3D_ids,
320 | )
321 | return images
322 |
323 |
324 | def read_colmap_bin_array(path):
325 | """
326 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
327 |
328 | :param path: path to the colmap binary file.
329 |     :return: nd array with the floating-point values stored in the file
330 | """
331 | with open(path, "rb") as fid:
332 | width, height, channels = np.genfromtxt(
333 | fid, delimiter="&", max_rows=1, usecols=(0, 1, 2), dtype=int
334 | )
335 | fid.seek(0)
336 | num_delimiter = 0
337 | byte = fid.read(1)
338 | while True:
339 | if byte == b"&":
340 | num_delimiter += 1
341 | if num_delimiter >= 3:
342 | break
343 | byte = fid.read(1)
344 | array = np.fromfile(fid, np.float32)
345 | array = array.reshape((width, height, channels), order="F")
346 | return np.transpose(array, (1, 0, 2)).squeeze()
347 |
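
A quick round-trip check (illustrative, not part of the repository; assumes the repository root is on sys.path) for the quaternion helpers above: COLMAP stores rotations as (w, x, y, z) quaternions, and rotmat2qvec(qvec2rotmat(q)) recovers q up to sign, normalised so that w >= 0.

import numpy as np
from scene.colmap_loader import qvec2rotmat, rotmat2qvec

q = np.array([0.5, 0.5, 0.5, 0.5])      # an arbitrary unit quaternion
R = qvec2rotmat(q)
print(np.allclose(rotmat2qvec(R), q))   # True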
--------------------------------------------------------------------------------
/scene/dataset_readers.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import sys
14 | from PIL import Image
15 | from typing import NamedTuple
16 | from scene.colmap_loader import (
17 | read_extrinsics_text,
18 | read_intrinsics_text,
19 | qvec2rotmat,
20 | read_extrinsics_binary,
21 | read_intrinsics_binary,
22 | read_points3D_binary,
23 | read_points3D_text,
24 | )
25 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
26 | import numpy as np
27 | import json
28 | from pathlib import Path
29 | from plyfile import PlyData, PlyElement
30 | from utils.sh_utils import SH2RGB
31 | from scene.gaussian_model import BasicPointCloud
32 |
33 |
34 | class CameraInfo(NamedTuple):
35 | uid: int
36 | R: np.array
37 | T: np.array
38 | FovY: np.array
39 | FovX: np.array
40 | image: np.array
41 | image_path: str
42 | image_name: str
43 | width: int
44 | height: int
45 |
46 |
47 | class SceneInfo(NamedTuple):
48 | point_cloud: BasicPointCloud
49 | train_cameras: list
50 | test_cameras: list
51 | nerf_normalization: dict
52 | ply_path: str
53 |
54 |
55 | def getNerfppNorm(cam_info):
56 | def get_center_and_diag(cam_centers):
57 | cam_centers = np.hstack(cam_centers)
58 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
59 | center = avg_cam_center
60 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
61 | diagonal = np.max(dist)
62 | return center.flatten(), diagonal
63 |
64 | cam_centers = []
65 |
66 | for cam in cam_info:
67 | W2C = getWorld2View2(cam.R, cam.T)
68 | C2W = np.linalg.inv(W2C)
69 | cam_centers.append(C2W[:3, 3:4])
70 |
71 | center, diagonal = get_center_and_diag(cam_centers)
72 | radius = diagonal * 1.1
73 |
74 | translate = -center
75 |
76 | return {"translate": translate, "radius": radius}
77 |
78 |
79 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
80 | cam_infos = []
81 | for idx, key in enumerate(cam_extrinsics):
82 | sys.stdout.write("\r")
83 |         # overwrite the previous progress line in place
84 | sys.stdout.write("Reading camera {}/{}".format(idx + 1, len(cam_extrinsics)))
85 | sys.stdout.flush()
86 |
87 | extr = cam_extrinsics[key]
88 | intr = cam_intrinsics[extr.camera_id]
89 | height = intr.height
90 | width = intr.width
91 |
92 | uid = intr.id
93 | R = np.transpose(qvec2rotmat(extr.qvec))
94 | T = np.array(extr.tvec)
95 |
96 | if intr.model == "SIMPLE_PINHOLE":
97 | focal_length_x = intr.params[0]
98 | FovY = focal2fov(focal_length_x, height)
99 | FovX = focal2fov(focal_length_x, width)
100 | elif intr.model == "PINHOLE":
101 | focal_length_x = intr.params[0]
102 | focal_length_y = intr.params[1]
103 | FovY = focal2fov(focal_length_y, height)
104 | FovX = focal2fov(focal_length_x, width)
105 | else:
106 | assert (
107 | False
108 | ), "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
109 |
110 | image_path = os.path.join(images_folder, os.path.basename(extr.name))
111 | image_name = os.path.basename(image_path).split(".")[0]
112 | image = Image.open(image_path)
113 |
114 | cam_info = CameraInfo(
115 | uid=uid,
116 | R=R,
117 | T=T,
118 | FovY=FovY,
119 | FovX=FovX,
120 | image=image,
121 | image_path=image_path,
122 | image_name=image_name,
123 | width=width,
124 | height=height,
125 | )
126 | cam_infos.append(cam_info)
127 | sys.stdout.write("\n")
128 | return cam_infos
129 |
130 |
131 | def fetchPly(path):
132 | plydata = PlyData.read(path)
133 | vertices = plydata["vertex"]
134 | positions = np.vstack([vertices["x"], vertices["y"], vertices["z"]]).T
135 | colors = np.vstack([vertices["red"], vertices["green"], vertices["blue"]]).T / 255.0
136 | normals = np.vstack([vertices["nx"], vertices["ny"], vertices["nz"]]).T
137 | return BasicPointCloud(points=positions, colors=colors, normals=normals)
138 |
139 |
140 | def storePly(path, xyz, rgb):
141 | # Define the dtype for the structured array
142 | dtype = [
143 | ("x", "f4"),
144 | ("y", "f4"),
145 | ("z", "f4"),
146 | ("nx", "f4"),
147 | ("ny", "f4"),
148 | ("nz", "f4"),
149 | ("red", "u1"),
150 | ("green", "u1"),
151 | ("blue", "u1"),
152 | ]
153 |
154 | normals = np.zeros_like(xyz)
155 |
156 | elements = np.empty(xyz.shape[0], dtype=dtype)
157 | attributes = np.concatenate((xyz, normals, rgb), axis=1)
158 | elements[:] = list(map(tuple, attributes))
159 |
160 | # Create the PlyData object and write to file
161 | vertex_element = PlyElement.describe(elements, "vertex")
162 | ply_data = PlyData([vertex_element])
163 | ply_data.write(path)
164 |
165 |
166 | def readColmapSceneInfo(path, images, eval, llffhold=8):
167 | try:
168 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
169 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
170 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
171 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
172 | except:
173 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
174 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
175 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
176 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
177 |
178 |     reading_dir = "images" if images is None else images
179 | cam_infos_unsorted = readColmapCameras(
180 | cam_extrinsics=cam_extrinsics,
181 | cam_intrinsics=cam_intrinsics,
182 | images_folder=os.path.join(path, reading_dir),
183 | )
184 | cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name)
185 |
186 | if eval:
187 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
188 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
189 | else:
190 | train_cam_infos = cam_infos
191 | test_cam_infos = []
192 |
193 | nerf_normalization = getNerfppNorm(train_cam_infos)
194 |
195 | ply_path = os.path.join(path, "sparse/0/points3D.ply")
196 | bin_path = os.path.join(path, "sparse/0/points3D.bin")
197 | txt_path = os.path.join(path, "sparse/0/points3D.txt")
198 | if not os.path.exists(ply_path):
199 | print(
200 | "Converting point3d.bin to .ply, will happen only the first time you open the scene."
201 | )
202 | try:
203 | xyz, rgb, _ = read_points3D_binary(bin_path)
204 | except:
205 | xyz, rgb, _ = read_points3D_text(txt_path)
206 | storePly(ply_path, xyz, rgb)
207 | try:
208 | pcd = fetchPly(ply_path)
209 | except:
210 | pcd = None
211 |
212 | scene_info = SceneInfo(
213 | point_cloud=pcd,
214 | train_cameras=train_cam_infos,
215 | test_cameras=test_cam_infos,
216 | nerf_normalization=nerf_normalization,
217 | ply_path=ply_path,
218 | )
219 | return scene_info
220 |
221 |
222 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
223 | cam_infos = []
224 |
225 | with open(os.path.join(path, transformsfile)) as json_file:
226 | contents = json.load(json_file)
227 | fovx = contents["camera_angle_x"]
228 |
229 | frames = contents["frames"]
230 | for idx, frame in enumerate(frames):
231 | cam_name = os.path.join(path, frame["file_path"] + extension)
232 |
233 | # NeRF 'transform_matrix' is a camera-to-world transform
234 | c2w = np.array(frame["transform_matrix"])
235 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
236 | c2w[:3, 1:3] *= -1
237 |
238 | # get the world-to-camera transform and set R, T
239 | w2c = np.linalg.inv(c2w)
240 | R = np.transpose(
241 | w2c[:3, :3]
242 | ) # R is stored transposed due to 'glm' in CUDA code
243 | T = w2c[:3, 3]
244 |
245 | image_path = os.path.join(path, cam_name)
246 | image_name = Path(cam_name).stem
247 | image = Image.open(image_path)
248 |
249 | im_data = np.array(image.convert("RGBA"))
250 |
251 | bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])
252 |
253 | norm_data = im_data / 255.0
254 | arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * (
255 | 1 - norm_data[:, :, 3:4]
256 | )
257 | image = Image.fromarray(np.array(arr * 255.0, dtype=np.byte), "RGB")
258 |
259 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
260 | FovY = fovy
261 | FovX = fovx
262 |
263 | cam_infos.append(
264 | CameraInfo(
265 | uid=idx,
266 | R=R,
267 | T=T,
268 | FovY=FovY,
269 | FovX=FovX,
270 | image=image,
271 | image_path=image_path,
272 | image_name=image_name,
273 | width=image.size[0],
274 | height=image.size[1],
275 | )
276 | )
277 |
278 | return cam_infos
279 |
280 |
281 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
282 | print("Reading Training Transforms")
283 | train_cam_infos = readCamerasFromTransforms(
284 | path, "transforms_train.json", white_background, extension
285 | )
286 | print("Reading Test Transforms")
287 | test_cam_infos = readCamerasFromTransforms(
288 | path, "transforms_test.json", white_background, extension
289 | )
290 |
291 | if not eval:
292 | train_cam_infos.extend(test_cam_infos)
293 | test_cam_infos = []
294 |
295 | nerf_normalization = getNerfppNorm(train_cam_infos)
296 |
297 | ply_path = os.path.join(path, "points3d.ply")
298 | if not os.path.exists(ply_path):
299 | # Since this data set has no colmap data, we start with random points
300 | num_pts = 100_000
301 | print(f"Generating random point cloud ({num_pts})...")
302 |
303 | # We create random points inside the bounds of the synthetic Blender scenes
304 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
305 | shs = np.random.random((num_pts, 3)) / 255.0
306 | pcd = BasicPointCloud(
307 | points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))
308 | )
309 |
310 | storePly(ply_path, xyz, SH2RGB(shs) * 255)
311 | try:
312 | pcd = fetchPly(ply_path)
313 | except:
314 | pcd = None
315 |
316 | scene_info = SceneInfo(
317 | point_cloud=pcd,
318 | train_cameras=train_cam_infos,
319 | test_cameras=test_cam_infos,
320 | nerf_normalization=nerf_normalization,
321 | ply_path=ply_path,
322 | )
323 | return scene_info
324 |
325 |
326 | sceneLoadTypeCallbacks = {
327 | "Colmap": readColmapSceneInfo,
328 | "Blender": readNerfSyntheticInfo,
329 | }
330 |
--------------------------------------------------------------------------------
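Note: `sceneLoadTypeCallbacks` above maps a dataset flavour to its loader. The sketch below is not part of the repository (`load_scene_info` is a hypothetical helper; the repository's own selection happens elsewhere, e.g. in scene/__init__.py); it only illustrates how a caller could dispatch on folder contents, since COLMAP exports ship a sparse/ reconstruction while Blender-style synthetic scenes ship a transforms_train.json:

    import os
    from scene.dataset_readers import sceneLoadTypeCallbacks

    def load_scene_info(source_path, images="images", white_background=False, eval=True):
        # COLMAP scenes carry a sparse/ reconstruction folder.
        if os.path.exists(os.path.join(source_path, "sparse")):
            return sceneLoadTypeCallbacks["Colmap"](source_path, images, eval)
        # Blender / NeRF-synthetic scenes carry a transforms_train.json file.
        if os.path.exists(os.path.join(source_path, "transforms_train.json")):
            return sceneLoadTypeCallbacks["Blender"](source_path, white_background, eval)
        raise ValueError(f"Could not recognise the scene type in {source_path}")
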
/scripts/run_distill_finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Function to get the id of an available GPU
4 | get_available_gpu() {
5 | local mem_threshold=500
6 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | awk -v threshold="$mem_threshold" -F', ' '
7 | $2 < threshold { print $1; exit }
8 | '
9 | }
10 |
11 | # Initial port number
12 | port=6025
13 |
14 | # Datasets
15 | declare -a run_args=(
16 | "bicycle"
17 | "bonsai"
18 | "counter"
19 | "kitchen"
20 | "room"
21 | "stump"
22 | "garden"
23 | "train"
24 | "truck"
25 | )
26 |
27 |
28 | # Enable augmented (pseudo) views for distillation; otherwise the training views are used
29 | declare -a virtue_view_arg=(
30 | "--augmented_view"
31 | )
32 | # compress_gaussian/output5_prune_final_result/bicycle_v_important_score_oneshot_prune_densify0.67_vpow0.1_try3_decay1
33 | # compress_gaussian/output2
34 | for arg in "${run_args[@]}"; do
35 | for view in "${virtue_view_arg[@]}"; do
36 | # Wait for an available GPU
37 | while true; do
38 | gpu_id=$(get_available_gpu)
39 | if [[ -n $gpu_id ]]; then
40 | echo "GPU $gpu_id is available. Starting distill_train.py with dataset '$arg' and options '$view' on port $port"
41 | CUDA_VISIBLE_DEVICES=$gpu_id nohup python distill_train.py \
42 | -s "PATH/TO/DATASET/$arg" \
43 | -m "OUTPUT/PATH/${arg}_${prune_percent}" \
44 | --start_checkpoint "PATH/TO/CHECKPOINT/$arg/chkpnt30000.pth" \
45 | --iteration 40000 \
46 | --eval \
47 | --teacher_model "PATH/TO/TEACHER_CHECKPOINT/${arg}/chkpnt30000.pth" \
48 | --new_max_sh 2 \
49 | --position_lr_max_steps 40000 \
50 | --enable_covariance \
51 | $view \
52 | --port $port > "logs/distill_${arg}${view}.log" 2>&1 &
53 |
54 | # Increment the port number for the next run
55 | ((port++))
56 | # Allow some time for the process to initialize and potentially use GPU memory
57 | sleep 60
58 | break
59 | else
60 | echo "No GPU available at the moment. Retrying in 1 minute."
61 | sleep 60
62 | fi
63 | done
64 | done
65 | done
66 | wait
67 | echo "All distill_train.py runs completed."
68 |
--------------------------------------------------------------------------------
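Note: the `get_available_gpu` helper in the scripts simply returns the index of the first GPU whose used memory is below a threshold. A rough Python equivalent (illustrative only, not a repository function; assumes `nvidia-smi` is on the PATH):

    import subprocess

    def get_available_gpu(mem_threshold_mb=500):
        # Query "index, memory.used" for every GPU and return the first one under the threshold.
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader,nounits"],
            capture_output=True, text=True, check=True,
        ).stdout
        for line in out.strip().splitlines():
            index, mem_used = (int(v.strip()) for v in line.split(","))
            if mem_used < mem_threshold_mb:
                return index
        return None  # no idle GPU right now
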
/scripts/run_prune_finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Function to get the id of an available GPU
4 | get_available_gpu() {
5 | local mem_threshold=10000
6 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | awk -v threshold="$mem_threshold" -F', ' '
7 | $2 < threshold { print $1; exit }
8 | '
9 | }
10 |
11 | # Initial port number
12 | port=6041
13 |
14 | # Only one dataset specified here, but you could run multiple
15 | declare -a run_args=(
16 | "bicycle"
17 | # "bonsai"
18 | # "counter"
19 | # "kitchen"
20 | # "room"
21 | # "stump"
22 | # "garden"
23 | # "train"
24 | # "truck"
25 | # "chair"
26 | # "drums"
27 | # "ficus"
28 | # "hotdog"
29 | # "lego"
30 | # "mic"
31 | # "materials"
32 | # "ship"
33 | )
34 |
35 |
36 | # Prune percentages and corresponding decays, volume power
37 | declare -a prune_percents=(0.66)
38 | # Decay rate for subsequent prunes: e.g. with a 0.6 prune percent and 0.5 decay, the 2nd prune removes 0.5 x 0.6 = 0.3 of the remaining Gaussians (see the sketch after this script)
39 | declare -a prune_decays=(1)
40 | # The volumetric importance power. The higher it is, the more weight volume carries in the global significance score
41 | declare -a v_pow=(0.1)
42 |
43 | # Prune type: defaults to the global significance score described in the paper, but there are other options you can experiment with
44 | declare -a prune_types=(
45 | "v_important_score"
46 | # "important_score"
47 | # "count"
48 | )
49 |
50 |
51 | # Check that prune_percents, prune_decays, and v_pow arrays have the same length
52 | if [ "${#prune_percents[@]}" -ne "${#prune_decays[@]}" ] || [ "${#prune_percents[@]}" -ne "${#v_pow[@]}" ]; then
53 | echo "The lengths of prune_percents, prune_decays, and v_pow arrays do not match."
54 | exit 1
55 | fi
56 |
57 | # Loop over the arguments array
58 | for arg in "${run_args[@]}"; do
59 | for i in "${!prune_percents[@]}"; do
60 | prune_percent="${prune_percents[i]}"
61 | prune_decay="${prune_decays[i]}"
62 | vp="${v_pow[i]}"
63 |
64 | for prune_type in "${prune_types[@]}"; do
65 | # Wait for an available GPU
66 | while true; do
67 | gpu_id=$(get_available_gpu)
68 | if [[ -n $gpu_id ]]; then
69 | echo "GPU $gpu_id is available. Starting prune_finetune.py with dataset '$arg', prune_percent '$prune_percent', prune_type '$prune_type', prune_decay '$prune_decay', and v_pow '$vp' on port $port"
70 |
71 | CUDA_VISIBLE_DEVICES=$gpu_id nohup python prune_finetune.py \
72 | -s "PATH/TO/DATASET/$arg" \
73 | -m "OUTPUT/PATH/${arg}_${prune_percent}" \
74 | --eval \
75 | --port $port \
76 | --start_checkpoint "PATH/TO/CHECKPOINT/$arg/chkpnt30000.pth" \
77 | --iteration 35000 \
78 | --prune_percent $prune_percent \
79 | --prune_type $prune_type \
80 | --prune_decay $prune_decay \
81 | --position_lr_max_steps 35000 \
82 | --v_pow $vp > "logs_prune/${arg}${prune_percent}prunned.log" 2>&1 &
83 |
84 | # Increment the port number for the next run
85 | ((port++))
86 | # Allow some time for the process to initialize and potentially use GPU memory
87 | sleep 60
88 | break
89 | else
90 | echo "No GPU available at the moment. Retrying in 1 minute."
91 | sleep 60
92 | fi
93 | done
94 | done
95 | done
96 | done
97 | wait
98 | echo "All prune_finetune.py runs completed."
99 |
--------------------------------------------------------------------------------
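Note: the prune_decay comment in the script above compounds multiplicatively, matching the (prune_decay**i) * prune_percent schedule applied in train_densify_prune.py. A small, self-contained illustration (example numbers only; `kept_fraction` is not a repository function):

    def kept_fraction(prune_percent, prune_decay, num_prunes):
        kept = 1.0
        for i in range(num_prunes):
            # The i-th prune removes prune_percent * prune_decay**i of what is still left.
            kept *= 1.0 - prune_percent * (prune_decay ** i)
        return kept

    # With a 0.6 prune percent and 0.5 decay: the 1st prune removes 60%, the 2nd removes
    # 0.5 x 0.6 = 30% of the remainder, so 0.4 x 0.7 = 28% of the Gaussians survive.
    print(kept_fraction(0.6, 0.5, 2))   # 0.28
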
/scripts/run_prune_pt_finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Function to get the id of an available GPU
4 | get_available_gpu() {
5 | local mem_threshold=10000
6 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | awk -v threshold="$mem_threshold" -F', ' '
7 | $2 < threshold { print $1; exit }
8 | '
9 | }
10 |
11 | # Initial port number
12 | port=6045
13 | # This is an example script to load from ply file.
14 | # Only one dataset specified here, but you could run multiple
15 | declare -a run_args=(
16 | "bicycle"
17 | # "bonsai"
18 | # "counter"
19 | # "kitchen"
20 | # "room"
21 | # "stump"
22 | # "garden"
23 | # "train"
24 | # "truck"
25 | # "chair"
26 | # "drums"
27 | # "ficus"
28 | # "hotdog"
29 | # "lego"
30 | # "mic"
31 | # "materials"
32 | # "ship"
33 | )
34 |
35 |
36 | # Prune percentages and corresponding decays, volume power
37 | declare -a prune_percents=(0.66)
38 | # decay rate for the following prune
39 | declare -a prune_decays=(1)
40 | # The volumetric importance power. The higher it is, the more weight volume carries in the global significance score
41 | declare -a v_pow=(0.1)
42 |
43 | # Prune type: defaults to the global significance score described in the paper, but there are other options you can experiment with
44 | declare -a prune_types=(
45 | "v_important_score"
46 | # "important_score"
47 | # "count"
48 | )
49 |
50 |
51 | # Check that prune_percents, prune_decays, and v_pow arrays have the same length
52 | if [ "${#prune_percents[@]}" -ne "${#prune_decays[@]}" ] || [ "${#prune_percents[@]}" -ne "${#v_pow[@]}" ]; then
53 | echo "The lengths of prune_percents, prune_decays, and v_pow arrays do not match."
54 | exit 1
55 | fi
56 | # /ssd1/zhiwen/projects/compress_gaussian/output2/bicycle/point_cloud/iteration_30000/point_cloud.ply
57 | # Loop over the arguments array
58 | for arg in "${run_args[@]}"; do
59 | for i in "${!prune_percents[@]}"; do
60 | prune_percent="${prune_percents[i]}"
61 | prune_decay="${prune_decays[i]}"
62 | vp="${v_pow[i]}"
63 |
64 | for prune_type in "${prune_types[@]}"; do
65 | # Wait for an available GPU
66 | while true; do
67 | gpu_id=$(get_available_gpu)
68 | if [[ -n $gpu_id ]]; then
69 | echo "GPU $gpu_id is available. Starting prune_finetune.py with dataset '$arg', prune_percent '$prune_percent', prune_type '$prune_type', prune_decay '$prune_decay', and v_pow '$vp' on port $port"
70 |
71 | CUDA_VISIBLE_DEVICES=$gpu_id python prune_finetune.py \
72 | -s "PATH/TO/DATASET/$arg" \
73 | -m "OUTPUT/PATH/${arg}_${prune_percent}" \
74 | --eval \
75 | --port $port \
76 | --start_pointcloud "PATH/TO/CHECKPOINT/$arg/point_cloud/iteration_30000/point_cloud.ply" \
77 | --iteration 5000 \
78 | --test_iterations 5000 \
79 | --save_iterations 5000 \
80 | --prune_iterations 2 \
81 | --prune_percent $prune_percent \
82 | --prune_type $prune_type \
83 | --prune_decay $prune_decay \
84 | --position_lr_init 0.000005 \
85 | --position_lr_max_steps 5000 \
86 | --v_pow $vp > "logs_prune/${arg}${prune_percent}_ply_prune2.log" 2>&1 &
87 |
88 | # Increment the port number for the next run
89 | ((port++))
90 | # Allow some time for the process to initialize and potentially use GPU memory
91 | sleep 60
92 | break
93 | else
94 | echo "No GPU available at the moment. Retrying in 1 minute."
95 | sleep 60
96 | fi
97 | done
98 | done
99 | done
100 | done
101 | wait
102 | echo "All prune_finetune.py runs completed."
103 |
--------------------------------------------------------------------------------
/scripts/run_train_densify_prune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Function to get the id of an available GPU
4 | get_available_gpu() {
5 | local mem_threshold=5000
6 | nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | \
7 | awk -v threshold="$mem_threshold" -F', ' '
8 | $2 < threshold { print $1; exit }
9 | '
10 | }
11 |
12 | port=6035
13 |
14 | # Only one dataset specified here
15 | declare -a run_args=(
16 | "bicycle"
17 | # "bonsai"
18 | # "counter"
19 | # "kitchen"
20 | # "room"
21 | # "stump"
22 | # "garden"
23 | # "train"
24 | # "truck"
25 | )
26 |
27 | # prune percentage for the first prune
28 | declare -a prune_percents=(0.6)
29 |
30 | # decay rate for the following prune
31 | declare -a prune_decays=(0.6)
32 |
33 | # The volumetric importance power
34 | declare -a v_pow=(0.1)
35 |
36 | # Prune types
37 | declare -a prune_types=(
38 | "v_important_score"
39 | )
40 |
41 | # Check that prune_percents and prune_decays arrays have the same length
42 | if [ "${#prune_percents[@]}" -ne "${#prune_decays[@]}" ]; then
43 | echo "The number of prune_percents does not match the number of prune_decays."
44 | exit 1
45 | fi
46 |
47 | # Loop over datasets
48 | for arg in "${run_args[@]}"; do
49 | # Loop over each index in prune_percents/decays/v_pow
50 | for i in "${!prune_percents[@]}"; do
51 | prune_percent="${prune_percents[i]}"
52 | prune_decay="${prune_decays[i]}"
53 | vp="${v_pow[i]}"
54 |
55 | # Loop over each prune type
56 | for prune_type in "${prune_types[@]}"; do
57 |
58 | # Wait for an available GPU
59 | while true; do
60 | gpu_id=$(get_available_gpu)
61 | if [[ -n $gpu_id ]]; then
62 | echo "GPU $gpu_id is available. Starting train_densify_prune.py with dataset '$arg', prune_percent '$prune_percent', prune_type '$prune_type', prune_decay '$prune_decay', and v_pow '$vp' on port $port"
63 |
64 | CUDA_VISIBLE_DEVICES=$gpu_id nohup python train_densify_prune.py \
65 | -s "PATH/TO/DATASET/$arg" \
66 | -m "OUTPUT/PATH/${arg}" \
67 | --prune_percent "$prune_percent" \
68 | --prune_decay "$prune_decay" \
69 | --prune_iterations 20000 \
70 | --v_pow "$vp" \
71 | --eval \
72 | --port "$port" \
73 | > "logs/train_${arg}.log" 2>&1 &
74 |
75 |           # Create the logs/ folder beforehand if it doesn't exist
76 | ((port++))
77 |
78 | # Give the process time to start using GPU memory
79 | sleep 60
80 | break
81 | else
82 | echo "No GPU available at the moment. Retrying in 1 minute."
83 | sleep 60
84 | fi
85 | done
86 |
87 | done # end for prune_type
88 | done # end for i
89 | done # end for arg
90 |
91 | # Wait for all background processes to finish
92 | wait
93 | echo "All train_densify_prune.py runs completed."
94 |
--------------------------------------------------------------------------------
/scripts/run_vectree_quantize.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # SCENES=(bicycle bonsai counter garden kitchen room stump train truck)
4 | SCENES=(room)
5 | VQ_RATIO=0.6
6 | CODEBOOK_SIZE=8192
7 |
8 | for SCENE in "${SCENES[@]}" # Add more scenes as needed
9 | do
10 | IMP_PATH=./vectree/pruned_distilled/${SCENE}
11 | INPUT_PLY_PATH=./vectree/pruned_distilled/${SCENE}/iteration_40000/point_cloud.ply
12 | SAVE_PATH=./vectree/output/${SCENE}
13 |
14 | CMD="CUDA_VISIBLE_DEVICES=0 python vectree/vectree.py \
15 | --important_score_npz_path ${IMP_PATH} \
16 | --input_path ${INPUT_PLY_PATH} \
17 | --save_path ${SAVE_PATH} \
18 | --vq_ratio ${VQ_RATIO} \
19 | --codebook_size ${CODEBOOK_SIZE} \
20 | "
21 | eval $CMD
22 | done
--------------------------------------------------------------------------------
/static/table5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VITA-Group/LightGaussian/6676b983e77baadd909effc56a6aaadafa964dcc/static/table5.png
--------------------------------------------------------------------------------
/submodules/simple-knn/ext.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #include <torch/extension.h>
13 | #include "spatial.h"
14 |
15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
16 | m.def("distCUDA2", &distCUDA2);
17 | }
18 |
--------------------------------------------------------------------------------
/submodules/simple-knn/setup.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from setuptools import setup
13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
14 | import os
15 |
16 | cxx_compiler_flags = []
17 |
18 | if os.name == "nt":
19 | cxx_compiler_flags.append("/wd4624")
20 |
21 | setup(
22 | name="simple_knn",
23 | ext_modules=[
24 | CUDAExtension(
25 | name="simple_knn._C",
26 | sources=["spatial.cu", "simple_knn.cu", "ext.cpp"],
27 | extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags},
28 | )
29 | ],
30 | cmdclass={"build_ext": BuildExtension},
31 | )
32 |
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #define BOX_SIZE 1024
13 |
14 | #include "cuda_runtime.h"
15 | #include "device_launch_parameters.h"
16 | #include "simple_knn.h"
17 | #include <cub/cub.cuh>
18 | #include <cub/device/device_radix_sort.cuh>
19 | #include <vector>
20 | #include <cuda_runtime_api.h>
21 | #include <thrust/device_vector.h>
22 | #include <thrust/sequence.h>
23 | #define __CUDACC__
24 | #include <cooperative_groups.h>
25 | #include <cooperative_groups/reduce.h>
26 |
27 | namespace cg = cooperative_groups;
28 |
29 | struct CustomMin
30 | {
31 | __device__ __forceinline__
32 | float3 operator()(const float3& a, const float3& b) const {
33 | return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) };
34 | }
35 | };
36 |
37 | struct CustomMax
38 | {
39 | __device__ __forceinline__
40 | float3 operator()(const float3& a, const float3& b) const {
41 | return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) };
42 | }
43 | };
44 |
45 | __host__ __device__ uint32_t prepMorton(uint32_t x)
46 | {
47 | x = (x | (x << 16)) & 0x030000FF;
48 | x = (x | (x << 8)) & 0x0300F00F;
49 | x = (x | (x << 4)) & 0x030C30C3;
50 | x = (x | (x << 2)) & 0x09249249;
51 | return x;
52 | }
53 |
54 | __host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx)
55 | {
56 | uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1));
57 | uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1));
58 | uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1));
59 |
60 | return x | (y << 1) | (z << 2);
61 | }
62 |
63 | __global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes)
64 | {
65 | auto idx = cg::this_grid().thread_rank();
66 | if (idx >= P)
67 | return;
68 |
69 | codes[idx] = coord2Morton(points[idx], minn, maxx);
70 | }
71 |
72 | struct MinMax
73 | {
74 | float3 minn;
75 | float3 maxx;
76 | };
77 |
78 | __global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes)
79 | {
80 | auto idx = cg::this_grid().thread_rank();
81 |
82 | MinMax me;
83 | if (idx < P)
84 | {
85 | me.minn = points[indices[idx]];
86 | me.maxx = points[indices[idx]];
87 | }
88 | else
89 | {
90 | me.minn = { FLT_MAX, FLT_MAX, FLT_MAX };
91 | me.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX };
92 | }
93 |
94 | __shared__ MinMax redResult[BOX_SIZE];
95 |
96 | for (int off = BOX_SIZE / 2; off >= 1; off /= 2)
97 | {
98 | if (threadIdx.x < 2 * off)
99 | redResult[threadIdx.x] = me;
100 | __syncthreads();
101 |
102 | if (threadIdx.x < off)
103 | {
104 | MinMax other = redResult[threadIdx.x + off];
105 | me.minn.x = min(me.minn.x, other.minn.x);
106 | me.minn.y = min(me.minn.y, other.minn.y);
107 | me.minn.z = min(me.minn.z, other.minn.z);
108 | me.maxx.x = max(me.maxx.x, other.maxx.x);
109 | me.maxx.y = max(me.maxx.y, other.maxx.y);
110 | me.maxx.z = max(me.maxx.z, other.maxx.z);
111 | }
112 | __syncthreads();
113 | }
114 |
115 | if (threadIdx.x == 0)
116 | boxes[blockIdx.x] = me;
117 | }
118 |
119 | __device__ __host__ float distBoxPoint(const MinMax& box, const float3& p)
120 | {
121 | float3 diff = { 0, 0, 0 };
122 | if (p.x < box.minn.x || p.x > box.maxx.x)
123 | diff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x));
124 | if (p.y < box.minn.y || p.y > box.maxx.y)
125 | diff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y));
126 | if (p.z < box.minn.z || p.z > box.maxx.z)
127 | diff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z));
128 | return diff.x * diff.x + diff.y * diff.y + diff.z * diff.z;
129 | }
130 |
131 | template<int K>
132 | __device__ void updateKBest(const float3& ref, const float3& point, float* knn)
133 | {
134 | float3 d = { point.x - ref.x, point.y - ref.y, point.z - ref.z };
135 | float dist = d.x * d.x + d.y * d.y + d.z * d.z;
136 | for (int j = 0; j < K; j++)
137 | {
138 | if (knn[j] > dist)
139 | {
140 | float t = knn[j];
141 | knn[j] = dist;
142 | dist = t;
143 | }
144 | }
145 | }
146 |
147 | __global__ void boxMeanDist(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes, float* dists)
148 | {
149 | int idx = cg::this_grid().thread_rank();
150 | if (idx >= P)
151 | return;
152 |
153 | float3 point = points[indices[idx]];
154 | float best[3] = { FLT_MAX, FLT_MAX, FLT_MAX };
155 |
156 | for (int i = max(0, idx - 3); i <= min(P - 1, idx + 3); i++)
157 | {
158 | if (i == idx)
159 | continue;
160 | updateKBest<3>(point, points[indices[i]], best);
161 | }
162 |
163 | float reject = best[2];
164 | best[0] = FLT_MAX;
165 | best[1] = FLT_MAX;
166 | best[2] = FLT_MAX;
167 |
168 | for (int b = 0; b < (P + BOX_SIZE - 1) / BOX_SIZE; b++)
169 | {
170 | MinMax box = boxes[b];
171 | float dist = distBoxPoint(box, point);
172 | if (dist > reject || dist > best[2])
173 | continue;
174 |
175 | for (int i = b * BOX_SIZE; i < min(P, (b + 1) * BOX_SIZE); i++)
176 | {
177 | if (i == idx)
178 | continue;
179 | updateKBest<3>(point, points[indices[i]], best);
180 | }
181 | }
182 | dists[indices[idx]] = (best[0] + best[1] + best[2]) / 3.0f;
183 | }
184 |
185 | void SimpleKNN::knn(int P, float3* points, float* meanDists)
186 | {
187 | float3* result;
188 | cudaMalloc(&result, sizeof(float3));
189 | size_t temp_storage_bytes;
190 |
191 | float3 init = { 0, 0, 0 }, minn, maxx;
192 |
193 | cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result, P, CustomMin(), init);
194 | thrust::device_vector<char> temp_storage(temp_storage_bytes);
195 |
196 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMin(), init);
197 | cudaMemcpy(&minn, result, sizeof(float3), cudaMemcpyDeviceToHost);
198 |
199 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMax(), init);
200 | cudaMemcpy(&maxx, result, sizeof(float3), cudaMemcpyDeviceToHost);
201 |
202 | thrust::device_vector<uint32_t> morton(P);
203 | thrust::device_vector<uint32_t> morton_sorted(P);
204 | coord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get());
205 |
206 | thrust::device_vector<uint32_t> indices(P);
207 | thrust::sequence(indices.begin(), indices.end());
208 | thrust::device_vector<uint32_t> indices_sorted(P);
209 |
210 | cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P);
211 | temp_storage.resize(temp_storage_bytes);
212 |
213 | cub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P);
214 |
215 | uint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE;
216 | thrust::device_vector<MinMax> boxes(num_boxes);
217 | boxMinMax << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get());
218 | boxMeanDist << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists);
219 |
220 | cudaFree(result);
221 | }
--------------------------------------------------------------------------------
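Note: prepMorton/coord2Morton above quantise each coordinate to 10 bits inside the scene's bounding box and interleave the bits of x, y, z so that spatially close points get numerically close codes. A small pure-Python illustration of the same bit tricks (for exposition only, not part of the repository):

    def prep_morton(x):
        # Spread the low 10 bits of x so that bit i lands at bit 3*i (same masks as prepMorton above).
        x = (x | (x << 16)) & 0x030000FF
        x = (x | (x << 8)) & 0x0300F00F
        x = (x | (x << 4)) & 0x030C30C3
        x = (x | (x << 2)) & 0x09249249
        return x

    def coord_to_morton(p, lo, hi):
        # Normalise each axis to [0, 1023], spread the bits, and interleave x | y << 1 | z << 2.
        code = 0
        for axis in range(3):
            q = int((p[axis] - lo[axis]) / (hi[axis] - lo[axis]) * 1023)
            code |= prep_morton(q) << axis
        return code

    print(coord_to_morton((0.5, 0.0, 1.0), lo=(0, 0, 0), hi=(1, 1, 1)))
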
/submodules/simple-knn/simple_knn.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: simple-knn
3 | Version: 0.0.0
4 |
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | ext.cpp
2 | setup.py
3 | simple_knn.cu
4 | spatial.cu
5 | simple_knn.egg-info/PKG-INFO
6 | simple_knn.egg-info/SOURCES.txt
7 | simple_knn.egg-info/dependency_links.txt
8 | simple_knn.egg-info/top_level.txt
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | simple_knn
2 |
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #ifndef SIMPLEKNN_H_INCLUDED
13 | #define SIMPLEKNN_H_INCLUDED
14 |
15 | class SimpleKNN
16 | {
17 | public:
18 | static void knn(int P, float3* points, float* meanDists);
19 | };
20 |
21 | #endif
--------------------------------------------------------------------------------
/submodules/simple-knn/simple_knn/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VITA-Group/LightGaussian/6676b983e77baadd909effc56a6aaadafa964dcc/submodules/simple-knn/simple_knn/.gitkeep
--------------------------------------------------------------------------------
/submodules/simple-knn/spatial.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #include "spatial.h"
13 | #include "simple_knn.h"
14 |
15 | torch::Tensor
16 | distCUDA2(const torch::Tensor& points)
17 | {
18 | const int P = points.size(0);
19 |
20 | auto float_opts = points.options().dtype(torch::kFloat32);
21 | torch::Tensor means = torch::full({P}, 0.0, float_opts);
22 |
23 | SimpleKNN::knn(P, (float3*)points.contiguous().data<float>(), means.contiguous().data<float>());
24 |
25 | return means;
26 | }
--------------------------------------------------------------------------------
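Note: distCUDA2 above is the binding exposed by the simple_knn extension; given an (N, 3) float32 point tensor on the GPU it returns, per point, the mean of the squared distances to its three nearest neighbours. A minimal usage sketch (assumes the submodule has been built via its setup.py and a CUDA device is available):

    import torch
    from simple_knn._C import distCUDA2  # extension built from submodules/simple-knn

    pts = torch.rand(10_000, 3, device="cuda", dtype=torch.float32)
    mean_sq_dist = distCUDA2(pts)        # shape (10_000,), mean squared distance to the 3 nearest neighbours
    print(mean_sq_dist.shape, mean_sq_dist.mean().item())
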
/submodules/simple-knn/spatial.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2023, Inria
3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved.
5 | *
6 | * This software is free for non-commercial, research and evaluation use
7 | * under the terms of the LICENSE.md file.
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 |
12 | #include <torch/extension.h>
13 |
14 | torch::Tensor distCUDA2(const torch::Tensor& points);
--------------------------------------------------------------------------------
/train_densify_prune.py:
--------------------------------------------------------------------------------
1 | #
2 | # This software is free for non-commercial, research and evaluation use
3 | # under the terms of the LICENSE.md file.
4 | #
5 | # For inquiries contact george.drettakis@inria.fr
6 | #
7 | import os
8 | import torch
9 | from random import randint
10 | from utils.loss_utils import l1_loss, ssim
11 | from gaussian_renderer import render, network_gui
12 | import sys
13 | from lpipsPyTorch import lpips
14 |
15 | from scene import Scene, GaussianModel
16 | from utils.general_utils import safe_state
17 | from utils.logger_utils import training_report, prepare_output_and_logger
18 |
19 | import uuid
20 | from tqdm import tqdm
21 | from utils.image_utils import psnr
22 | from argparse import ArgumentParser, Namespace
23 | from arguments import ModelParams, PipelineParams, OptimizationParams
24 |
25 | # from prune_train import prepare_output_and_logger, training_report
26 | from icecream import ic
27 | from os import makedirs
28 | from prune import prune_list, calculate_v_imp_score
29 | import torchvision
30 | from torch.optim.lr_scheduler import ExponentialLR
31 | import csv
32 | import numpy as np
33 |
34 |
35 | try:
36 | from torch.utils.tensorboard import SummaryWriter
37 |
38 | TENSORBOARD_FOUND = True
39 | except ImportError:
40 | TENSORBOARD_FOUND = False
41 |
42 |
43 | def training(
44 | dataset,
45 | opt,
46 | pipe,
47 | testing_iterations,
48 | saving_iterations,
49 | checkpoint_iterations,
50 | checkpoint,
51 | debug_from,
52 | args,
53 | ):
54 | first_iter = 0
55 | tb_writer = prepare_output_and_logger(dataset)
56 | gaussians = GaussianModel(dataset.sh_degree)
57 | scene = Scene(dataset, gaussians)
58 | gaussians.training_setup(opt)
59 | if checkpoint:
60 | (model_params, first_iter) = torch.load(checkpoint)
61 | gaussians.restore(model_params, opt)
62 |
63 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
64 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
65 |
66 | iter_start = torch.cuda.Event(enable_timing=True)
67 | iter_end = torch.cuda.Event(enable_timing=True)
68 |
69 | viewpoint_stack = None
70 | ema_loss_for_log = 0.0
71 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
72 | first_iter += 1
73 | gaussians.scheduler = ExponentialLR(gaussians.optimizer, gamma=0.97)
74 | for iteration in range(first_iter, opt.iterations + 1):
75 | if network_gui.conn == None:
76 | network_gui.try_connect()
77 | while network_gui.conn != None:
78 | try:
79 | net_image_bytes = None
80 | (
81 | custom_cam,
82 | do_training,
83 | pipe.convert_SHs_python,
84 | pipe.compute_cov3D_python,
85 | keep_alive,
86 | scaling_modifer,
87 | ) = network_gui.receive()
88 | if custom_cam != None:
89 | net_image = render(
90 | custom_cam, gaussians, pipe, background, scaling_modifer
91 | )["render"]
92 | net_image_bytes = memoryview(
93 | (torch.clamp(net_image, min=0, max=1.0) * 255)
94 | .byte()
95 | .permute(1, 2, 0)
96 | .contiguous()
97 | .cpu()
98 | .numpy()
99 | )
100 | network_gui.send(net_image_bytes, dataset.source_path)
101 | if do_training and (
102 | (iteration < int(opt.iterations)) or not keep_alive
103 | ):
104 | break
105 | except Exception as e:
106 | network_gui.conn = None
107 |
108 | iter_start.record()
109 |
110 | gaussians.update_learning_rate(iteration)
111 |
112 |         # Every 1000 iterations we increase the SH levels up to a maximum degree
113 | if iteration % 1000 == 0:
114 | gaussians.oneupSHdegree()
115 | gaussians.scheduler.step()
116 |
117 | # Pick a random Camera
118 | if not viewpoint_stack:
119 | viewpoint_stack = scene.getTrainCameras().copy()
120 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
121 |
122 | # Render
123 | if (iteration - 1) == debug_from:
124 | pipe.debug = True
125 | render_pkg = render(viewpoint_cam, gaussians, pipe, background)
126 | image, viewspace_point_tensor, visibility_filter, radii = (
127 | render_pkg["render"],
128 | render_pkg["viewspace_points"],
129 | render_pkg["visibility_filter"],
130 | render_pkg["radii"],
131 | )
132 |
133 | # Loss
134 | gt_image = viewpoint_cam.original_image.cuda()
135 | Ll1 = l1_loss(image, gt_image)
136 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (
137 | 1.0 - ssim(image, gt_image)
138 | )
139 | loss.backward()
140 |
141 | iter_end.record()
142 |
143 | with torch.no_grad():
144 | # Progress bar
145 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
146 | if iteration % 10 == 0:
147 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
148 | progress_bar.update(10)
149 | if iteration == opt.iterations:
150 | progress_bar.close()
151 |
152 | # Log and save
153 | if iteration in saving_iterations:
154 | print("\n[ITER {}] Saving Gaussians".format(iteration))
155 | scene.save(iteration)
156 | training_report(
157 | tb_writer,
158 | iteration,
159 | Ll1,
160 | loss,
161 | l1_loss,
162 | iter_start.elapsed_time(iter_end),
163 | testing_iterations,
164 | scene,
165 | render,
166 | (pipe, background),
167 | )
168 |
169 | # Densification
170 | if iteration < opt.densify_until_iter:
171 | # Keep track of max radii in image-space for pruning
172 | gaussians.max_radii2D[visibility_filter] = torch.max(
173 | gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
174 | )
175 | gaussians.add_densification_stats(
176 | viewspace_point_tensor, visibility_filter
177 | )
178 |
179 | if (
180 | iteration > opt.densify_from_iter
181 | and iteration % opt.densification_interval == 0
182 | ):
183 | size_threshold = (
184 | 20 if iteration > opt.opacity_reset_interval else None
185 | )
186 | gaussians.densify_and_prune(
187 | opt.densify_grad_threshold,
188 | 0.005,
189 | scene.cameras_extent,
190 | size_threshold,
191 | )
192 |
193 | if iteration % opt.opacity_reset_interval == 0 or (
194 | dataset.white_background and iteration == opt.densify_from_iter
195 | ):
196 | gaussians.reset_opacity()
197 |
198 | if iteration in args.prune_iterations:
199 |             # TODO: Add pruning types
200 | gaussian_list, imp_list = prune_list(gaussians, scene, pipe, background)
201 | i = args.prune_iterations.index(iteration)
202 | v_list = calculate_v_imp_score(gaussians, imp_list, args.v_pow)
203 | gaussians.prune_gaussians(
204 | (args.prune_decay**i) * args.prune_percent, v_list
205 | )
206 |
207 |
208 |
209 | # Optimizer step
210 | if iteration < opt.iterations:
211 | gaussians.optimizer.step()
212 | gaussians.optimizer.zero_grad(set_to_none=True)
213 |
214 | if iteration in checkpoint_iterations:
215 | print("\n[ITER {}] Saving Checkpoint".format(iteration))
216 | if not os.path.exists(scene.model_path):
217 | os.makedirs(scene.model_path)
218 | torch.save(
219 | (gaussians.capture(), iteration),
220 | scene.model_path + "/chkpnt" + str(iteration) + ".pth",
221 | )
222 | if iteration == checkpoint_iterations[-1]:
223 | gaussian_list, imp_list = prune_list(gaussians, scene, pipe, background)
224 | v_list = calculate_v_imp_score(gaussians, imp_list, args.v_pow)
225 |                 np.savez(os.path.join(scene.model_path, "imp_score"), v_list.cpu().detach().numpy())
226 |
227 |
228 | if __name__ == "__main__":
229 | # Set up command line argument parser
230 | parser = ArgumentParser(description="Training script parameters")
231 | lp = ModelParams(parser)
232 | op = OptimizationParams(parser)
233 | pp = PipelineParams(parser)
234 | parser.add_argument("--ip", type=str, default="127.0.0.1")
235 | parser.add_argument("--port", type=int, default=6009)
236 | parser.add_argument("--debug_from", type=int, default=-1)
237 | parser.add_argument("--detect_anomaly", action="store_true", default=False)
238 | parser.add_argument(
239 | "--test_iterations",
240 | nargs="+",
241 | type=int,
242 | default=[7_000, 30_000],
243 | )
244 | parser.add_argument(
245 | "--save_iterations", nargs="+", type=int, default=[7_000, 30_000]
246 | )
247 | parser.add_argument("--quiet", action="store_true")
248 | parser.add_argument(
249 | "--checkpoint_iterations", nargs="+", type=int, default=[7_000, 30_000]
250 | )
251 | parser.add_argument("--start_checkpoint", type=str, default=None)
252 |
253 | parser.add_argument(
254 | "--prune_iterations", nargs="+", type=int, default=[16_000, 24_000]
255 | )
256 | parser.add_argument("--prune_percent", type=float, default=0.5)
257 | parser.add_argument("--v_pow", type=float, default=0.1)
258 | parser.add_argument("--prune_decay", type=float, default=0.8)
259 | args = parser.parse_args(sys.argv[1:])
260 | args.save_iterations.append(args.iterations)
261 |
262 | print("Optimizing " + args.model_path)
263 | # Initialize system state (RNG)
264 | safe_state(args.quiet)
265 | # Start GUI server, configure and run training
266 | network_gui.init(args.ip, args.port)
267 | torch.autograd.set_detect_anomaly(args.detect_anomaly)
268 | training(
269 | lp.extract(args),
270 | op.extract(args),
271 | pp.extract(args),
272 | args.test_iterations,
273 | args.save_iterations,
274 | args.checkpoint_iterations,
275 | args.start_checkpoint,
276 | args.debug_from,
277 | args,
278 | )
279 |
280 | # All done
281 | print("\nTraining complete.")
282 |
--------------------------------------------------------------------------------
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from scene.cameras import Camera
13 | import numpy as np
14 | from utils.general_utils import PILtoTorch
15 | from utils.graphics_utils import fov2focal
16 |
17 | WARNED = False
18 |
19 |
20 | def loadCam(args, id, cam_info, resolution_scale):
21 | orig_w, orig_h = cam_info.image.size
22 |
23 | if args.resolution in [1, 2, 4, 8]:
24 | resolution = round(orig_w / (resolution_scale * args.resolution)), round(
25 | orig_h / (resolution_scale * args.resolution)
26 | )
27 | else: # should be a type that converts to float
28 | if args.resolution == -1:
29 | if orig_w > 1600:
30 | global WARNED
31 | if not WARNED:
32 | print(
33 | "[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
34 | "If this is not desired, please explicitly specify '--resolution/-r' as 1"
35 | )
36 | WARNED = True
37 | global_down = orig_w / 1600
38 | else:
39 | global_down = 1
40 | else:
41 | global_down = orig_w / args.resolution
42 |
43 | scale = float(global_down) * float(resolution_scale)
44 | resolution = (int(orig_w / scale), int(orig_h / scale))
45 |
46 | resized_image_rgb = PILtoTorch(cam_info.image, resolution)
47 |
48 | gt_image = resized_image_rgb[:3, ...]
49 | loaded_mask = None
50 |
51 | if resized_image_rgb.shape[1] == 4:
52 | loaded_mask = resized_image_rgb[3:4, ...]
53 |
54 | return Camera(
55 | colmap_id=cam_info.uid,
56 | R=cam_info.R,
57 | T=cam_info.T,
58 | FoVx=cam_info.FovX,
59 | FoVy=cam_info.FovY,
60 | image=gt_image,
61 | gt_alpha_mask=loaded_mask,
62 | image_name=cam_info.image_name,
63 | uid=id,
64 | data_device=args.data_device,
65 | )
66 |
67 |
68 | def cameraList_from_camInfos(cam_infos, resolution_scale, args):
69 | camera_list = []
70 |
71 | for id, c in enumerate(cam_infos):
72 | camera_list.append(loadCam(args, id, c, resolution_scale))
73 |
74 | return camera_list
75 |
76 |
77 | def camera_to_JSON(id, camera: Camera):
78 | Rt = np.zeros((4, 4))
79 | Rt[:3, :3] = camera.R.transpose()
80 | Rt[:3, 3] = camera.T
81 | Rt[3, 3] = 1.0
82 |
83 | W2C = np.linalg.inv(Rt)
84 | pos = W2C[:3, 3]
85 | rot = W2C[:3, :3]
86 | serializable_array_2d = [x.tolist() for x in rot]
87 | camera_entry = {
88 | "id": id,
89 | "img_name": camera.image_name,
90 | "width": camera.width,
91 | "height": camera.height,
92 | "position": pos.tolist(),
93 | "rotation": serializable_array_2d,
94 | "fy": fov2focal(camera.FovY, camera.height),
95 | "fx": fov2focal(camera.FovX, camera.width),
96 | }
97 | return camera_entry
98 |
--------------------------------------------------------------------------------
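Note: loadCam above decides the training resolution from --resolution/-r: 1/2/4/8 act as fixed divisors, -1 caps the width at 1600 px, and any other value is treated as a target width. A side-by-side sketch of that arithmetic (illustrative numbers; `target_resolution` is not a repository function):

    def target_resolution(orig_w, orig_h, resolution_arg, resolution_scale=1.0):
        if resolution_arg in (1, 2, 4, 8):
            # Fixed-divisor path.
            return (round(orig_w / (resolution_scale * resolution_arg)),
                    round(orig_h / (resolution_scale * resolution_arg)))
        if resolution_arg == -1:
            # Cap very large images at 1600 px width, otherwise keep them as-is.
            global_down = orig_w / 1600 if orig_w > 1600 else 1
        else:
            # Treat the argument as the requested width.
            global_down = orig_w / resolution_arg
        scale = float(global_down) * float(resolution_scale)
        return int(orig_w / scale), int(orig_h / scale)

    print(target_resolution(4946, 3286, -1))   # a ~5K-wide capture gets capped near 1600 px
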
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import sys
14 | from datetime import datetime
15 | import numpy as np
16 | import random
17 |
18 |
19 | def inverse_sigmoid(x):
20 | return torch.log(x / (1 - x))
21 |
22 |
23 | def PILtoTorch(pil_image, resolution):
24 | resized_image_PIL = pil_image.resize(resolution)
25 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
26 | if len(resized_image.shape) == 3:
27 | return resized_image.permute(2, 0, 1)
28 | else:
29 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
30 |
31 |
32 | def get_expon_lr_func(
33 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
34 | ):
35 | """
36 | Copied from Plenoxels
37 |
38 | Continuous learning rate decay function. Adapted from JaxNeRF
39 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
40 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
41 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
42 | function of lr_delay_mult, such that the initial learning rate is
43 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
44 | to the normal learning rate when steps>lr_delay_steps.
45 | :param conf: config subtree 'lr' or similar
46 | :param max_steps: int, the number of steps during optimization.
47 | :return HoF which takes step as input
48 | """
49 |
50 | def helper(step):
51 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
52 | # Disable this parameter
53 | return 0.0
54 | if lr_delay_steps > 0:
55 | # A kind of reverse cosine decay.
56 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
57 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
58 | )
59 | else:
60 | delay_rate = 1.0
61 | t = np.clip(step / max_steps, 0, 1)
62 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
63 | return delay_rate * log_lerp
64 |
65 | return helper
66 |
67 |
68 | def strip_lowerdiag(L):
69 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
70 |
71 | uncertainty[:, 0] = L[:, 0, 0]
72 | uncertainty[:, 1] = L[:, 0, 1]
73 | uncertainty[:, 2] = L[:, 0, 2]
74 | uncertainty[:, 3] = L[:, 1, 1]
75 | uncertainty[:, 4] = L[:, 1, 2]
76 | uncertainty[:, 5] = L[:, 2, 2]
77 | return uncertainty
78 |
79 |
80 | def strip_symmetric(sym):
81 | return strip_lowerdiag(sym)
82 |
83 |
84 | def build_rotation(r):
85 | norm = torch.sqrt(
86 | r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
87 | )
88 |
89 | q = r / norm[:, None]
90 |
91 | R = torch.zeros((q.size(0), 3, 3), device="cuda")
92 |
93 | r = q[:, 0]
94 | x = q[:, 1]
95 | y = q[:, 2]
96 | z = q[:, 3]
97 |
98 | R[:, 0, 0] = 1 - 2 * (y * y + z * z)
99 | R[:, 0, 1] = 2 * (x * y - r * z)
100 | R[:, 0, 2] = 2 * (x * z + r * y)
101 | R[:, 1, 0] = 2 * (x * y + r * z)
102 | R[:, 1, 1] = 1 - 2 * (x * x + z * z)
103 | R[:, 1, 2] = 2 * (y * z - r * x)
104 | R[:, 2, 0] = 2 * (x * z - r * y)
105 | R[:, 2, 1] = 2 * (y * z + r * x)
106 | R[:, 2, 2] = 1 - 2 * (x * x + y * y)
107 | return R
108 |
109 |
110 | def build_scaling_rotation(s, r):
111 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
112 | R = build_rotation(r)
113 |
114 | L[:, 0, 0] = s[:, 0]
115 | L[:, 1, 1] = s[:, 1]
116 | L[:, 2, 2] = s[:, 2]
117 |
118 | L = R @ L
119 | return L
120 |
121 |
122 | def safe_state(silent):
123 | old_f = sys.stdout
124 |
125 | class F:
126 | def __init__(self, silent):
127 | self.silent = silent
128 |
129 | def write(self, x):
130 | if not self.silent:
131 | if x.endswith("\n"):
132 | old_f.write(
133 | x.replace(
134 | "\n",
135 | " [{}]\n".format(
136 | str(datetime.now().strftime("%d/%m %H:%M:%S"))
137 | ),
138 | )
139 | )
140 | else:
141 | old_f.write(x)
142 |
143 | def flush(self):
144 | old_f.flush()
145 |
146 | sys.stdout = F(silent)
147 |
148 | random.seed(0)
149 | np.random.seed(0)
150 | torch.manual_seed(0)
151 | torch.cuda.set_device(torch.device("cuda:0"))
152 |
153 |
154 | class CircularTensor:
155 | def __init__(self, max_size):
156 | self.buffer = torch.empty(max_size)
157 | self.max_size = max_size
158 | self.current_pos = 0
159 | self.current_size = 0 # Tracks the number of elements added
160 |
161 | def add(self, element):
162 | self.buffer[self.current_pos] = element
163 | self.current_pos = (self.current_pos + 1) % self.max_size
164 | if self.current_size < self.max_size:
165 | self.current_size += 1
166 |
167 | def get(self, index):
168 | if index >= self.current_size:
169 | raise IndexError("Index out of bounds")
170 | return self.buffer[index]
171 |
172 | def size(self):
173 | return self.current_size
174 |
--------------------------------------------------------------------------------
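Note: a quick numerical check of get_expon_lr_func above. The returned schedule is lr_init at step 0, lr_final at max_steps, and the geometric (log-linear) interpolation in between; the values below are illustrative and assume the snippet is run from the repository root:

    from utils.general_utils import get_expon_lr_func

    lr_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)
    for step in (0, 15_000, 30_000):
        print(step, lr_fn(step))   # 1.6e-4, 1.6e-5 (geometric midpoint), 1.6e-6
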
/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | import numpy as np
15 | from typing import NamedTuple
16 |
17 |
18 | class BasicPointCloud(NamedTuple):
19 | points: np.array
20 | colors: np.array
21 | normals: np.array
22 |
23 |
24 | def geom_transform_points(points, transf_matrix):
25 | P, _ = points.shape
26 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
27 | points_hom = torch.cat([points, ones], dim=1)
28 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
29 |
30 | denom = points_out[..., 3:] + 0.0000001
31 | return (points_out[..., :3] / denom).squeeze(dim=0)
32 |
33 |
34 | def getWorld2View(R, t):
35 | Rt = np.zeros((4, 4))
36 | Rt[:3, :3] = R.transpose()
37 | Rt[:3, 3] = t
38 | Rt[3, 3] = 1.0
39 | return np.float32(Rt)
40 |
41 |
42 | def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0):
43 | Rt = np.zeros((4, 4))
44 | Rt[:3, :3] = R.transpose()
45 | Rt[:3, 3] = t
46 | Rt[3, 3] = 1.0
47 |
48 | C2W = np.linalg.inv(Rt)
49 | cam_center = C2W[:3, 3]
50 | cam_center = (cam_center + translate) * scale
51 | C2W[:3, 3] = cam_center
52 | Rt = np.linalg.inv(C2W)
53 | return np.float32(Rt)
54 |
55 |
56 | def getProjectionMatrix(znear, zfar, fovX, fovY):
57 | tanHalfFovY = math.tan((fovY / 2))
58 | tanHalfFovX = math.tan((fovX / 2))
59 |
60 | top = tanHalfFovY * znear
61 | bottom = -top
62 | right = tanHalfFovX * znear
63 | left = -right
64 |
65 | P = torch.zeros(4, 4)
66 |
67 | z_sign = 1.0
68 |
69 | P[0, 0] = 2.0 * znear / (right - left)
70 | P[1, 1] = 2.0 * znear / (top - bottom)
71 | P[0, 2] = (right + left) / (right - left)
72 | P[1, 2] = (top + bottom) / (top - bottom)
73 | P[3, 2] = z_sign
74 | P[2, 2] = z_sign * zfar / (zfar - znear)
75 | P[2, 3] = -(zfar * znear) / (zfar - znear)
76 | return P
77 |
78 |
79 | def fov2focal(fov, pixels):
80 | return pixels / (2 * math.tan(fov / 2))
81 |
82 |
83 | def focal2fov(focal, pixels):
84 | return 2 * math.atan(pixels / (2 * focal))
85 |
--------------------------------------------------------------------------------
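Note: fov2focal and focal2fov above are exact inverses of each other for a given image size in pixels. A short round-trip check (illustrative numbers; assumes it is run from the repository root):

    import math
    from utils.graphics_utils import fov2focal, focal2fov

    width = 1600
    fov_x = math.radians(70.0)
    focal_px = fov2focal(fov_x, width)            # width / (2 * tan(fov / 2))
    assert abs(focal2fov(focal_px, width) - fov_x) < 1e-9
    print(focal_px)
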
/utils/image.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import math, random, time
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | import imageio
10 | from pdb import set_trace as st
11 |
12 |
13 | mse2psnr = (
14 | lambda x: -10.0 * torch.log(x) / torch.log(torch.tensor([10.0], device=x.device))
15 | )
16 |
17 |
18 | def img2mse(x, y, mask=None):
19 | if mask is None:
20 | return torch.mean((x - y) ** 2)
21 | else:
22 | return torch.sum((x * mask - y * mask) ** 2) / (torch.sum(mask) + 1e-5)
23 |
24 |
25 | def img2mae(x, y, mask=None):
26 | if mask is None:
27 | return torch.mean(torch.abs(x - y))
28 | else:
29 | return torch.sum(torch.abs(x * mask - y * mask)) / (torch.sum(mask) + 1e-5)
30 |
--------------------------------------------------------------------------------
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 |
14 |
15 | def mse(img1, img2):
16 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
17 |
18 |
19 | def psnr(img1, img2):
20 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
21 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
22 |
--------------------------------------------------------------------------------
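Note: psnr above assumes images scaled to [0, 1] and computes 20 * log10(1 / sqrt(MSE)) per image in the batch. A toy example with illustrative tensors (run from the repository root):

    import torch
    from utils.image_utils import psnr

    img = torch.rand(1, 3, 64, 64)
    noisy = (img + 0.01 * torch.randn_like(img)).clamp(0, 1)
    print(psnr(noisy, img))   # roughly 40 dB for additive noise with std 0.01
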
/utils/logger_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from random import randint
4 | from utils.loss_utils import l1_loss, ssim
5 | from lpipsPyTorch import lpips
6 | from scene import Scene, GaussianModel
7 | from utils.general_utils import safe_state
8 | import uuid
9 | from utils.image_utils import psnr
10 | from argparse import Namespace
11 | from icecream import ic
12 | import csv
13 |
14 | try:
15 | from torch.utils.tensorboard import SummaryWriter
16 |
17 | TENSORBOARD_FOUND = True
18 | except ImportError:
19 | TENSORBOARD_FOUND = False
20 |
21 |
22 | def prepare_output_and_logger(args):
23 | if not args.model_path:
24 | if os.getenv("OAR_JOB_ID"):
25 | unique_str = os.getenv("OAR_JOB_ID")
26 | else:
27 | unique_str = str(uuid.uuid4())
28 | args.model_path = os.path.join("./output/", unique_str[0:10])
29 |
30 | # Set up output folder
31 | print("Output folder: {}".format(args.model_path))
32 | os.makedirs(args.model_path, exist_ok=True)
33 | with open(os.path.join(args.model_path, "cfg_args"), "w") as cfg_log_f:
34 | cfg_log_f.write(str(Namespace(**vars(args))))
35 |
36 | # Create Tensorboard writer
37 | tb_writer = None
38 | if TENSORBOARD_FOUND:
39 | tb_writer = SummaryWriter(args.model_path)
40 | else:
41 | print("Tensorboard not available: not logging progress")
42 | return tb_writer
43 |
44 |
45 | def training_report(
46 | tb_writer,
47 | iteration,
48 | Ll1,
49 | loss,
50 | l1_loss,
51 | elapsed,
52 | testing_iterations,
53 | scene: Scene,
54 | renderFunc,
55 | renderArgs,
56 | ):
57 | if tb_writer:
58 | tb_writer.add_scalar("train_loss_patches/l1_loss", Ll1.item(), iteration)
59 | tb_writer.add_scalar("train_loss_patches/total_loss", loss.item(), iteration)
60 | tb_writer.add_scalar("iter_time", elapsed, iteration)
61 |
62 | # Report test and samples of training set
63 | if iteration in testing_iterations:
64 | ic("report")
65 | headers = [
66 | "iteration",
67 | "set",
68 | "l1_loss",
69 | "psnr",
70 | "ssim",
71 | "lpips",
72 | "file_size",
73 | "elapsed",
74 | ]
75 | csv_path = os.path.join(scene.model_path, "metric.csv")
76 | # Check if the CSV file exists, if not, create it and write the header
77 | file_exists = os.path.isfile(csv_path)
78 | save_path = os.path.join(
79 | scene.model_path,
80 | "point_cloud/iteration_" + str(iteration),
81 | "point_cloud.ply",
82 | )
83 | # Check if the file exists
84 | if os.path.exists(save_path):
85 | # Get the size of the file
86 | file_size = os.path.getsize(save_path)
87 |             file_size_mb = file_size / 1024 / 1024 # Convert bytes to megabytes
88 | else:
89 | file_size_mb = None
90 |
91 | with open(csv_path, "a", newline="") as csvfile:
92 | writer = csv.DictWriter(csvfile, fieldnames=headers)
93 | if not file_exists:
94 | writer.writeheader() # file doesn't exist yet, write a header
95 |
96 | torch.cuda.empty_cache()
97 | validation_configs = ({"name": "test", "cameras": scene.getTestCameras()},)
98 | # {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
99 |
100 | for config in validation_configs:
101 | if config["cameras"] and len(config["cameras"]) > 0:
102 | l1_test = 0.0
103 | psnr_test = 0.0
104 | ssim_test = 0.0
105 | lpips_test = 0.0
106 | for idx, viewpoint in enumerate(config["cameras"]):
107 | image = torch.clamp(
108 | renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"],
109 | 0.0,
110 | 1.0,
111 | )
112 | gt_image = torch.clamp(
113 | viewpoint.original_image.to("cuda"), 0.0, 1.0
114 | )
115 | if tb_writer and (idx < 5):
116 | tb_writer.add_images(
117 | config["name"]
118 | + "_view_{}/render".format(viewpoint.image_name),
119 | image[None],
120 | global_step=iteration,
121 | )
122 | if iteration == testing_iterations[0]:
123 | tb_writer.add_images(
124 | config["name"]
125 | + "_view_{}/ground_truth".format(viewpoint.image_name),
126 | gt_image[None],
127 | global_step=iteration,
128 | )
129 | l1_test += l1_loss(image, gt_image).mean().double()
130 | psnr_test += psnr(image, gt_image).mean().double()
131 | ssim_test += ssim(image, gt_image).mean().double()
132 | lpips_test += lpips(image, gt_image, net_type="vgg").mean().double()
133 |
134 | psnr_test /= len(config["cameras"])
135 | l1_test /= len(config["cameras"])
136 | ssim_test /= len(config["cameras"])
137 | lpips_test /= len(config["cameras"])
138 | # sys.stderr.write(f"Iteration {iteration} Evaluating {config['name']}: L1 {l1_test} PSNR {psnr_test} SSIM {ssim_test} LPIPS {lpips_test}\n")
139 | # sys.stderr.flush()
140 | print(
141 | "\n[ITER {}] Evaluating {}: L1 {} PSNR {} SSIM {} LPIPS {}".format(
142 | iteration,
143 | config["name"],
144 | l1_test,
145 | psnr_test,
146 | ssim_test,
147 | lpips_test,
148 | )
149 | )
150 | if tb_writer:
151 | tb_writer.add_scalar(
152 | config["name"] + "/loss_viewpoint - l1_loss", l1_test, iteration
153 | )
154 | tb_writer.add_scalar(
155 | config["name"] + "/loss_viewpoint - psnr", psnr_test, iteration
156 | )
157 | tb_writer.add_scalar(
158 | config["name"] + "/loss_viewpoint - ssim", ssim_test, iteration
159 | )
160 | tb_writer.add_scalar(
161 | config["name"] + "/loss_viewpoint - lpips",
162 | lpips_test,
163 | iteration,
164 | )
165 | if config["name"] == "test":
166 | with open(csv_path, "a", newline="") as csvfile:
167 | writer = csv.DictWriter(csvfile, fieldnames=headers)
168 | writer.writerow(
169 | {
170 | "iteration": iteration,
171 | "set": config["name"],
172 | "l1_loss": l1_test.item(),
173 | "psnr": psnr_test.item(),
174 | "ssim": ssim_test.item(),
175 | "lpips": lpips_test.item(),
176 | "file_size": file_size_mb,
177 | "elapsed": elapsed,
178 | }
179 | )
180 |
181 | if tb_writer:
182 | tb_writer.add_histogram(
183 | "scene/opacity_histogram", scene.gaussians.get_opacity, iteration
184 | )
185 | tb_writer.add_scalar(
186 | "total_points", scene.gaussians.get_xyz.shape[0], iteration
187 | )
188 |
189 | torch.cuda.empty_cache()
190 |
--------------------------------------------------------------------------------
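
training_report appends one row per test evaluation to metric.csv under the model path, using the headers defined above (file_size is the size of the saved point_cloud.ply in MB, when that file exists). A minimal sketch for reading the log back, with <model_path> standing in for the run's output directory:

    import csv

    with open("<model_path>/metric.csv", newline="") as f:
        for row in csv.DictReader(f):
            print(row["iteration"], row["psnr"], row["ssim"], row["lpips"], row["file_size"])
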
/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | from math import exp
16 |
17 |
18 | def l1_loss(network_output, gt):
19 | return torch.abs((network_output - gt)).mean()
20 |
21 |
22 | def l2_loss(network_output, gt):
23 | return ((network_output - gt) ** 2).mean()
24 |
25 |
26 | def gaussian(window_size, sigma):
27 | gauss = torch.Tensor(
28 | [
29 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
30 | for x in range(window_size)
31 | ]
32 | )
33 | return gauss / gauss.sum()
34 |
35 |
36 | def create_window(window_size, channel):
37 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
38 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
39 | window = Variable(
40 | _2D_window.expand(channel, 1, window_size, window_size).contiguous()
41 | )
42 | return window
43 |
44 |
45 | def ssim(img1, img2, window_size=11, size_average=True):
46 | channel = img1.size(-3)
47 | window = create_window(window_size, channel)
48 |
49 | if img1.is_cuda:
50 | window = window.cuda(img1.get_device())
51 | window = window.type_as(img1)
52 |
53 | return _ssim(img1, img2, window, window_size, channel, size_average)
54 |
55 |
56 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
57 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
58 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
59 |
60 | mu1_sq = mu1.pow(2)
61 | mu2_sq = mu2.pow(2)
62 | mu1_mu2 = mu1 * mu2
63 |
64 | sigma1_sq = (
65 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
66 | )
67 | sigma2_sq = (
68 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
69 | )
70 | sigma12 = (
71 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
72 | - mu1_mu2
73 | )
74 |
75 | C1 = 0.01**2
76 | C2 = 0.03**2
77 |
78 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
79 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
80 | )
81 |
82 | if size_average:
83 | return ssim_map.mean()
84 | else:
85 | return ssim_map.mean(1).mean(1).mean(1)
86 |
87 |
88 | def img2mse(x, y, mask=None):
89 | if mask is None:
90 | return torch.mean((x - y) ** 2)
91 | else:
92 | return torch.sum((x * mask - y * mask) ** 2) / (torch.sum(mask) + 1e-5)
93 |
94 |
95 | def img2mae(x, y, mask=None):
96 | if mask is None:
97 | return torch.mean(torch.abs(x - y))
98 | else:
99 | return torch.sum(torch.abs(x * mask - y * mask)) / (torch.sum(mask) + 1e-5)
100 |
--------------------------------------------------------------------------------
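
l1_loss and ssim above are the two photometric terms of the usual 3DGS training objective, loss = (1 - lambda) * L1 + lambda * (1 - SSIM). A hedged sketch of that combination, assuming the conventional lambda_dssim of 0.2 rather than any value taken from this repository's training scripts:

    import torch
    from utils.loss_utils import l1_loss, ssim

    lambda_dssim = 0.2                        # assumed weight, the common 3DGS default
    image = torch.rand(1, 3, 128, 128)        # rendered image, values in [0, 1]
    gt_image = torch.rand(1, 3, 128, 128)     # ground-truth image

    Ll1 = l1_loss(image, gt_image)
    loss = (1.0 - lambda_dssim) * Ll1 + lambda_dssim * (1.0 - ssim(image, gt_image))
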
/utils/save_imp_score.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import torch
14 | from random import randint
15 | import sys
16 | from scene import Scene, GaussianModel
17 | from utils.general_utils import safe_state
18 | from tqdm import tqdm
19 | from utils.image_utils import psnr
20 | from argparse import ArgumentParser, Namespace
21 | from arguments import ModelParams, PipelineParams, OptimizationParams
22 | try:
23 | from torch.utils.tensorboard import SummaryWriter
24 | TENSORBOARD_FOUND = True
25 | except ImportError:
26 | TENSORBOARD_FOUND = False
27 | from icecream import ic
28 | import random
29 | import copy
30 | import gc
31 | from os import makedirs
32 | from prune import prune_list, calculate_v_imp_score
33 | import csv
34 | import numpy as np
35 |
36 | def save_imp_score(dataset, opt, pipe, checkpoint, args):
37 | gaussians = GaussianModel(dataset.sh_degree)
38 | scene = Scene(dataset, gaussians)
39 | gaussians.training_setup(opt)
40 | if checkpoint:
41 | (model_params, first_iter) = torch.load(checkpoint)
42 | gaussians.restore(model_params, opt)
43 |
44 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
45 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
46 | gaussian_list, imp_list = prune_list(gaussians, scene, pipe, background)
47 | v_list = calculate_v_imp_score(gaussians, imp_list, args.v_pow)
48 | np.savez(os.path.join(scene.model_path,"imp_score"), v_list)
49 |
50 | # If you want to print the imp_score:
51 | if args.show_imp_score:
52 | data = np.load(os.path.join(scene.model_path,"imp_score.npz"))
53 | lst = data.files
54 | for item in lst:
55 | ic(item)
56 | ic(data[item].shape)
57 |
58 |
59 |
60 | if __name__ == "__main__":
61 | # Set up command line argument parser
62 | parser = ArgumentParser(description="Training script parameters")
63 | lp = ModelParams(parser)
64 | op = OptimizationParams(parser)
65 | pp = PipelineParams(parser)
66 | parser.add_argument('--debug_from', type=int, default=-1)
67 | parser.add_argument('--detect_anomaly', action='store_true', default=False)
68 | parser.add_argument("--start_checkpoint", type=str, default = None)
69 | parser.add_argument("--show_imp_score", action='store_true', default=False)
70 | parser.add_argument("--get_fps",action='store_true', default=False)
71 | parser.add_argument("--quiet", action="store_true")
72 | parser.add_argument("--v_pow", type=float, default=0.1)
73 |
74 |
75 | args = parser.parse_args(sys.argv[1:])
76 |
77 | print("Optimizing " + args.model_path)
78 |
79 | # Initialize system state (RNG)
80 | safe_state(args.quiet)
81 | save_imp_score(lp.extract(args), op.extract(args), pp.extract(args), args.start_checkpoint, args)
82 | # All done
83 | print("\nTraining complete.")
84 |
--------------------------------------------------------------------------------
/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | import torch
25 |
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 | 1.0925484305920792,
30 | -1.0925484305920792,
31 | 0.31539156525252005,
32 | -1.0925484305920792,
33 | 0.5462742152960396,
34 | ]
35 | C3 = [
36 | -0.5900435899266435,
37 | 2.890611442640554,
38 | -0.4570457994644658,
39 | 0.3731763325901154,
40 | -0.4570457994644658,
41 | 1.445305721320277,
42 | -0.5900435899266435,
43 | ]
44 | C4 = [
45 | 2.5033429417967046,
46 | -1.7701307697799304,
47 | 0.9461746957575601,
48 | -0.6690465435572892,
49 | 0.10578554691520431,
50 | -0.6690465435572892,
51 | 0.47308734787878004,
52 | -1.7701307697799304,
53 | 0.6258357354491761,
54 | ]
55 |
56 |
57 | def eval_sh(deg, sh, dirs):
58 | """
59 | Evaluate spherical harmonics at unit directions
60 | using hardcoded SH polynomials.
61 | Works with torch/np/jnp.
62 | ... Can be 0 or more batch dimensions.
63 | Args:
64 |         deg: int SH deg. Currently, 0-4 supported
65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 | dirs: jnp.ndarray unit directions [..., 3]
67 | Returns:
68 | [..., C]
69 | """
70 | assert deg <= 4 and deg >= 0
71 | coeff = (deg + 1) ** 2
72 | assert sh.shape[-1] >= coeff
73 |
74 | result = C0 * sh[..., 0]
75 | if deg > 0:
76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 | result = (
78 | result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3]
79 | )
80 |
81 | if deg > 1:
82 | xx, yy, zz = x * x, y * y, z * z
83 | xy, yz, xz = x * y, y * z, x * z
84 | result = (
85 | result
86 | + C2[0] * xy * sh[..., 4]
87 | + C2[1] * yz * sh[..., 5]
88 | + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6]
89 | + C2[3] * xz * sh[..., 7]
90 | + C2[4] * (xx - yy) * sh[..., 8]
91 | )
92 |
93 | if deg > 2:
94 | result = (
95 | result
96 | + C3[0] * y * (3 * xx - yy) * sh[..., 9]
97 | + C3[1] * xy * z * sh[..., 10]
98 | + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11]
99 | + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12]
100 | + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13]
101 | + C3[5] * z * (xx - yy) * sh[..., 14]
102 | + C3[6] * x * (xx - 3 * yy) * sh[..., 15]
103 | )
104 |
105 | if deg > 3:
106 | result = (
107 | result
108 | + C4[0] * xy * (xx - yy) * sh[..., 16]
109 | + C4[1] * yz * (3 * xx - yy) * sh[..., 17]
110 | + C4[2] * xy * (7 * zz - 1) * sh[..., 18]
111 | + C4[3] * yz * (7 * zz - 3) * sh[..., 19]
112 | + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20]
113 | + C4[5] * xz * (7 * zz - 3) * sh[..., 21]
114 | + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22]
115 | + C4[7] * xz * (xx - 3 * yy) * sh[..., 23]
116 | + C4[8]
117 | * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
118 | * sh[..., 24]
119 | )
120 | return result
121 |
122 |
123 | def RGB2SH(rgb):
124 | return (rgb - 0.5) / C0
125 |
126 |
127 | def SH2RGB(sh):
128 | return sh * C0 + 0.5
129 |
--------------------------------------------------------------------------------
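
eval_sh expects the SH coefficients laid out as [..., C, (deg + 1)**2] and unit view directions; RGB2SH / SH2RGB convert only the DC band. A small self-check, assuming the module is importable from the repository root as utils.sh_utils: when only the DC coefficients are populated, the evaluated colour is view-independent and recovers the original RGB once the 0.5 offset is added back.

    import torch
    from utils.sh_utils import eval_sh, RGB2SH

    N = 4
    rgb = torch.rand(N, 3)
    sh = torch.zeros(N, 3, 9)                # degree-2 layout: (2 + 1) ** 2 = 9 coeffs per channel
    sh[..., 0] = RGB2SH(rgb)                 # populate the DC band only

    dirs = torch.nn.functional.normalize(torch.randn(N, 3), dim=-1)

    out = eval_sh(2, sh, dirs) + 0.5         # higher bands are zero, so direction has no effect
    print(torch.allclose(out, rgb, atol=1e-6))   # True
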
/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from errno import EEXIST
13 | from os import makedirs, path
14 | import os
15 |
16 |
17 | def mkdir_p(folder_path):
18 |     # Creates a directory; equivalent to using mkdir -p on the command line
19 | try:
20 | makedirs(folder_path)
21 | except OSError as exc: # Python >2.5
22 | if exc.errno == EEXIST and path.isdir(folder_path):
23 | pass
24 | else:
25 | raise
26 |
27 |
28 | def searchForMaxIteration(folder):
29 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
30 | return max(saved_iters)
31 |
--------------------------------------------------------------------------------
/utils/tracker_utils.py:
--------------------------------------------------------------------------------
1 | import heapq
2 | import random
3 |
4 | class HardestExamplesTracker:
5 | def __init__(self, max_size=10):
6 | self.max_size = max_size
7 | self.heap = []
8 | self.total_added = 0
9 |
10 | def add(self, loss, example, label):
11 | # Ensure the label is either "virtual" or "gt"
12 | # assert label in ["virtual", "gt"], "Label must be 'virtual' or 'gt'"
13 |
14 | if len(self.heap) < self.max_size:
15 | heapq.heappush(self.heap, (loss, example, label))
16 | self.total_added += 1
17 | elif loss > self.heap[0][0]:
18 | heapq.heappushpop(self.heap, (loss, example, label))
19 |
20 | def get_hardest_examples(self):
21 | # Sort by loss and return examples with their labels
22 | return [(example, label) for loss, example, label in sorted(self.heap, reverse=True)]
23 |
24 | def get_random_example(self):
25 | if not self.heap:
26 | return None
27 | _, example, label = random.choice(self.heap)
28 | return example, label
29 |
30 | def get_hardest_example(self):
31 | if not self.heap:
32 | return None
33 | _, example, label = max(self.heap, key=lambda x: x[0])
34 | return example, label
35 |
36 | def get_size(self):
37 |         return self.total_added  # only grows while the heap is filling, so this equals len(self.heap)
38 |
39 |
40 |
--------------------------------------------------------------------------------
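
HardestExamplesTracker keeps a fixed-size min-heap keyed on loss, so once it is full only examples harder than the current easiest entry are retained. A minimal sketch with hypothetical string payloads:

    from utils.tracker_utils import HardestExamplesTracker

    tracker = HardestExamplesTracker(max_size=3)
    for step, loss in enumerate([0.10, 0.42, 0.05, 0.31, 0.27]):
        tracker.add(loss, f"camera_{step}", "gt")

    print(tracker.get_hardest_example())     # ('camera_1', 'gt'), the largest loss seen
    print(tracker.get_hardest_examples())    # three (example, label) pairs, hardest first
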
/utils/vgg.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import torch
4 | from torchvision import models
5 |
6 |
7 | class Vgg16(torch.nn.Module):
8 | def __init__(self, requires_grad=False):
9 | super(Vgg16, self).__init__()
10 | vgg_pretrained_features = models.vgg16(weights="VGG16_Weights.DEFAULT").features
11 | self.slice1 = torch.nn.Sequential()
12 | self.slice2 = torch.nn.Sequential()
13 | self.slice3 = torch.nn.Sequential()
14 | self.slice4 = torch.nn.Sequential()
15 | for x in range(4):
16 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
17 | for x in range(4, 9):
18 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
19 | for x in range(9, 16):
20 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
21 | for x in range(16, 23):
22 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
23 | if not requires_grad:
24 | for param in self.parameters():
25 | param.requires_grad = False
26 |
27 | def forward(self, X):
28 | h = self.slice1(X)
29 | h_relu1_2 = h
30 | h = self.slice2(h)
31 | h_relu2_2 = h
32 | h = self.slice3(h)
33 | h_relu3_3 = h
34 | h = self.slice4(h)
35 | h_relu4_3 = h
36 | # vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
37 | # out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
38 | out = {}
39 | out["relu1_2"] = h_relu1_2
40 | out["relu2_2"] = h_relu2_2
41 | out["relu3_3"] = h_relu3_3
42 | out["relu4_3"] = h_relu4_3
43 | return out
44 |
--------------------------------------------------------------------------------
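
Vgg16 slices torchvision's pretrained VGG-16 into four frozen blocks and returns the relu1_2 / relu2_2 / relu3_3 / relu4_3 activations as a dict. A minimal sketch (it downloads the torchvision weights on first use; ImageNet normalisation of the input is omitted for brevity):

    import torch
    from utils.vgg import Vgg16

    model = Vgg16().eval()                  # parameters are frozen via requires_grad=False
    x = torch.rand(1, 3, 256, 256)          # hypothetical RGB batch

    with torch.no_grad():
        feats = model(x)

    for name, f in feats.items():
        print(name, tuple(f.shape))
    # relu1_2 (1, 64, 256, 256), relu2_2 (1, 128, 128, 128),
    # relu3_3 (1, 256, 64, 64),  relu4_3 (1, 512, 32, 32)
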
/vectree/utils.py:
--------------------------------------------------------------------------------
1 | import os, math, torch
2 | import numpy as np
3 | from plyfile import PlyData, PlyElement
4 |
5 | def load_vqgaussian(path, device='cuda'):
6 | def load_f(name, allow_pickle=False,array_name='arr_0'):
7 | return np.load(os.path.join(path,name),allow_pickle=allow_pickle)[array_name]
8 |
9 | metadata = load_f('metadata.npz',allow_pickle=True,array_name='metadata')
10 | metadata = metadata.item()
11 |
12 | ## load basic info
13 | codebook_size = metadata['codebook_size']
14 | codebook_dim = metadata['codebook_dim']
15 | bit_length = int(math.log2(codebook_size)) # log_2_K
16 | input_pc_num = metadata['input_pc_num'] # feats.shape[0]
17 | input_pc_dim = metadata['input_pc_dim'] # feats.shape[1]
18 |
19 | # ===================================================== load vq_SH ============================================
20 | ## loading the two masks
21 | non_vq_mask = load_f('non_vq_mask.npz')
22 | non_vq_mask = np.unpackbits(non_vq_mask)
23 | non_vq_mask = non_vq_mask[:input_pc_num]
24 | non_vq_mask = torch.from_numpy(non_vq_mask).bool().to(device) # non_vq_mask
25 | all_one_mask = torch.ones_like(non_vq_mask).bool().to(device) # all_one_mask
26 |
27 | ## loading codebook and vq indexes
28 | codebook = load_f('codebook.npz')
29 | codebook = torch.from_numpy(codebook).float().to(device)
30 | vq_mask = torch.logical_xor(non_vq_mask, all_one_mask) # vq_mask
31 | vq_elements = vq_mask.sum()
32 |
33 | vq_indexs = load_f('vq_indexs.npz')
34 | vq_indexs = np.unpackbits(vq_indexs)
35 | vq_indexs = vq_indexs[:vq_elements*bit_length].reshape(vq_elements,bit_length)
36 | vq_indexs = torch.from_numpy(vq_indexs).float()
37 | vq_indexs = bin2dec(vq_indexs, bits=bit_length)
38 | vq_indexs = vq_indexs.long().to(device) # vq_indexs
39 |
40 | # ===================================================== load non_vq_SH ==========================================
41 | non_vq_feats = load_f('non_vq_feats.npz')
42 | non_vq_feats = torch.from_numpy(non_vq_feats).float().to(device)
43 |
44 | # =========================================== load xyz & other attr(opacity + 3*scale + 4*rot) ===============
45 | other_attribute = load_f('other_attribute.npz')
46 | other_attribute = torch.from_numpy(other_attribute).float().to(device)
47 |
48 | xyz = load_f('xyz.npz')
49 | xyz = torch.from_numpy(xyz).float().to(device)
50 | # =========================================== build full features =============================================
51 | full_feats = torch.zeros(input_pc_num, input_pc_dim).to(device)
52 | # --- xyz & other attr---
53 | full_feats[:, 0:3] = xyz
54 | full_feats[:, -8:] = other_attribute
55 |
56 | # --- nx==ny==nz==0
57 |
58 | # --- vq_SH ---
59 | full_feats[vq_mask, 6:6+codebook_dim] = codebook[vq_indexs]
60 |
61 | # --- non_vq_SH ---
62 | # non_vq_mask = torch.logical_xor(vq_mask, all_one_mask)
63 | full_feats[non_vq_mask, 6:6+codebook_dim] = non_vq_feats
64 |
65 | return full_feats
66 |
67 |
68 |
69 | def read_ply_data(input_file):
70 | ply_data = PlyData.read(input_file)
71 | i = 0
72 | vertex = ply_data['vertex']
73 | for prop in vertex._property_lookup:
74 | tmp = vertex.data[prop].reshape(-1,1)
75 | if i == 0:
76 | data = tmp
77 | i += 1
78 | else:
79 | data = np.concatenate((data, tmp), axis=1)
80 | return data
81 |
82 |
83 | def write_ply_data(feats, save_ply_path, sh_dim):
84 | def construct_list_of_attributes():
85 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
86 | # All channels except the 3 DC
87 | for i in range(3):
88 | l.append('f_dc_{}'.format(i))
89 | for i in range(sh_dim-3-8 if sh_dim==24+3+8 else sh_dim-3):
90 | l.append('f_rest_{}'.format(i))
91 | l.append('opacity')
92 | for i in range(3):
93 | l.append('scale_{}'.format(i))
94 | for i in range(4):
95 | l.append('rot_{}'.format(i))
96 | return l
97 |
98 | path= save_ply_path+'/point_cloud.ply'
99 | dtype_full = [(attribute, 'f4') for attribute in construct_list_of_attributes()] # f4:float32,f2:float16
100 | elements = np.empty(feats.shape[0], dtype=dtype_full)
101 | elements[:] = list(map(tuple, feats))
102 | el = PlyElement.describe(elements, 'vertex')
103 | PlyData([el]).write(path)
104 |
105 | def dec2bin(x, bits):
106 | mask = 2 ** torch.arange(bits - 1, -1, -1).to(x.device, x.dtype)
107 | return x.unsqueeze(-1).bitwise_and(mask).ne(0).float()
108 |
109 | def bin2dec(b, bits):
110 | mask = 2 ** torch.arange(bits - 1, -1, -1).to(b.device, b.dtype)
111 | return torch.sum(mask * b, -1)
112 |
113 |
--------------------------------------------------------------------------------
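
dec2bin / bin2dec are the helpers that let the VQ indices be stored as packed bits: each index is spread over log2(codebook_size) binary digits (MSB first) and recovered exactly on load. A small round-trip check, assuming it is run from inside vectree/ so that "utils" resolves to this module rather than the top-level utils/ package:

    import torch
    from utils import dec2bin, bin2dec

    bits = 13                               # log2 of the default 2**13 = 8192 codebook
    idx = torch.tensor([0, 1, 5000, 8191])

    b = dec2bin(idx, bits)                  # [4, 13] matrix of 0./1. floats, MSB first
    assert torch.equal(bin2dec(b, bits).long(), idx)
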
/vectree/vectree.py:
--------------------------------------------------------------------------------
1 | import os, torch, argparse, math
2 | import numpy as np
3 | from copy import deepcopy
4 | from tqdm import tqdm, trange
5 |
6 | from vq import VectorQuantize
7 | from utils import read_ply_data, write_ply_data, load_vqgaussian
8 |
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser(description="vectree quantization")
12 | parser.add_argument("--important_score_npz_path", type=str, default='room')
13 | parser.add_argument("--input_path", type=str, default='room/iteration_40000/point_cloud.ply')
14 |
15 | parser.add_argument("--save_path", type=str, default='./output/room')
16 |     parser.add_argument("--no_load_data", type=bool, default=False)  # note: argparse's type=bool treats any non-empty string as True
17 | parser.add_argument("--no_save_ply", type=bool, default=False)
18 | parser.add_argument("--sh_degree", type=int, default=2)
19 |
20 |     parser.add_argument("--iteration_num", type=int, default=1000)  # must be an int: it is passed to trange()
21 | parser.add_argument("--vq_ratio", type=float, default=0.6)
22 | parser.add_argument("--codebook_size", type=int, default=2**13) # 2**13 = 8192
23 | parser.add_argument("--no_IS", type=bool, default=False)
24 | parser.add_argument("--vq_way", type=str, default='half')
25 | opt = parser.parse_args()
26 | return opt
27 |
28 |
29 | class Quantization():
30 | def __init__(self, opt):
31 |
32 | # ----- load ply data -----
33 | if opt.sh_degree == 3:
34 | self.sh_dim = 3+45
35 | elif opt.sh_degree == 2:
36 | self.sh_dim = 3+24
37 |
38 | self.feats = read_ply_data(opt.input_path)
39 | self.feats = torch.tensor(self.feats)
40 | self.feats_bak = self.feats.clone()
41 | self.feats = self.feats[:, 6:6+self.sh_dim]
42 |
43 | # ----- define model -----
44 | self.model_vq = VectorQuantize(
45 | dim = self.feats.shape[1],
46 | codebook_size = opt.codebook_size,
47 | decay = 0.8,
48 | commitment_weight = 1.0,
49 | use_cosine_sim = False,
50 | threshold_ema_dead_code=0,
51 | ).to(device)
52 |
53 | # ----- other -----
54 | self.save_path = opt.save_path
55 | self.ply_path = opt.save_path
56 | self.imp_path = opt.important_score_npz_path
57 | self.high = None
58 | self.VQ_CHUNK = 80000
59 | self.k_expire = 10
60 | self.vq_ratio = opt.vq_ratio
61 |
62 | self.no_IS = opt.no_IS
63 | self.no_load_data = opt.no_load_data
64 | self.no_save_ply = opt.no_save_ply
65 |
66 | self.codebook_size = opt.codebook_size
67 | self.iteration_num = opt.iteration_num
68 | self.vq_way = opt.vq_way
69 |
70 | # ----- print info -----
71 | print("\n================== Print Info ================== ")
72 | print("Input_feats_shape: ", self.feats_bak.shape)
73 | print("VQ_feats_shape: ", self.feats.shape)
74 | print("SH_degree: ", opt.sh_degree)
75 | print("Quantization_ratio: ", opt.vq_ratio)
76 | print("Add_important_score: ", opt.no_IS==False)
77 | print("Codebook_size: ", opt.codebook_size)
78 | print("================================================ ")
79 |
80 | @torch.no_grad()
81 | def calc_vector_quantized_feature(self):
82 | """
83 | apply vector quantize on gaussian attributes and return vq indexes
84 | """
85 | CHUNK = 8192
86 | feat_list = []
87 | indice_list = []
88 | self.model_vq.eval()
89 | self.model_vq._codebook.embed.half().float() #
90 | for i in tqdm(range(0, self.feats.shape[0], CHUNK)):
91 | feat, indices, commit = self.model_vq(self.feats[i:i+CHUNK,:].unsqueeze(0).to(device))
92 | indice_list.append(indices[0])
93 | feat_list.append(feat[0])
94 | self.model_vq.train()
95 | all_feat = torch.cat(feat_list).half().float() # [num_elements, feats_dim]
96 | all_indice = torch.cat(indice_list) # [num_elements, 1]
97 | return all_feat, all_indice
98 |
99 |
100 | @torch.no_grad()
101 | def fully_vq_reformat(self):
102 |
103 | print("\n=============== Start vector quantize ===============")
104 | all_feat, all_indice = self.calc_vector_quantized_feature()
105 |
106 | if self.save_path is not None:
107 | save_path = self.save_path
108 | os.makedirs(f'{save_path}/extreme_saving', exist_ok=True)
109 |
110 | # ----- save basic info -----
111 | metadata = dict()
112 | metadata['input_pc_num'] = self.feats_bak.shape[0]
113 | metadata['input_pc_dim'] = self.feats_bak.shape[1]
114 | metadata['codebook_size'] = self.codebook_size
115 | metadata['codebook_dim'] = self.sh_dim
116 | np.savez_compressed(f'{save_path}/extreme_saving/metadata.npz', metadata=metadata)
117 |
118 | # ===================================================== save vq_SH =============================================
119 | # ----- save mapping_index (vq_index) -----
120 | def dec2bin(x, bits):
121 | mask = 2 ** torch.arange(bits - 1, -1, -1).to(x.device, x.dtype)
122 | return x.unsqueeze(-1).bitwise_and(mask).ne(0).float()
123 |             # vq indices are packed bitwise according to the codebook bit length
124 | self.codebook_vq_index = all_indice[torch.logical_xor(self.all_one_mask,self.non_vq_mask)] # vq_index
125 | bin_indices = dec2bin(self.codebook_vq_index, int(math.log2(self.codebook_size))).bool().cpu().numpy() # mapping_index
126 | np.savez_compressed(f'{save_path}/extreme_saving/vq_indexs.npz',np.packbits(bin_indices.reshape(-1)))
127 |
128 | # ----- save codebook -----
129 | codebook = self.model_vq._codebook.embed.cpu().half().numpy().squeeze(0)
130 | np.savez_compressed(f'{save_path}/extreme_saving/codebook.npz', codebook)
131 |
132 | # ----- save keep mask (non_vq_feats_index)-----
133 | np.savez_compressed(f'{save_path}/extreme_saving/non_vq_mask.npz',np.packbits(self.non_vq_mask.reshape(-1).cpu().numpy()))
134 |
135 | # ===================================================== save non_vq_SH =============================================
136 | non_vq_feats = self.feats_bak[self.non_vq_mask, 6:6+self.sh_dim]
137 | wage_non_vq_feats = self.wage_vq(non_vq_feats)
138 | np.savez_compressed(f'{save_path}/extreme_saving/non_vq_feats.npz', wage_non_vq_feats)
139 |
140 | # =========================================== save xyz & other attr(opacity + 3*scale + 4*rot) ====================================
141 | other_attribute = self.feats_bak[:, -8:]
142 | wage_other_attribute = self.wage_vq(other_attribute)
143 | np.savez_compressed(f'{save_path}/extreme_saving/other_attribute.npz', wage_other_attribute)
144 |
145 | xyz = self.feats_bak[:, 0:3]
146 | np.savez_compressed(f'{save_path}/extreme_saving/xyz.npz', xyz)
147 |
148 |
149 | # zip everything together to get final size
150 | os.system(f"zip -r {save_path}/extreme_saving.zip {save_path}/extreme_saving")
151 | size = os.path.getsize(f'{save_path}/extreme_saving.zip')
152 | size_MB = size / 1024.0 / 1024.0
153 | print("Size = {:.2f} MB".format(size_MB))
154 |
155 | return all_feat, all_indice
156 |
157 | def load_f(self, path, name, allow_pickle=False,array_name='arr_0'):
158 | return np.load(os.path.join(path, name),allow_pickle=allow_pickle)[array_name]
159 |
160 | def wage_vq(self, feats):
161 | if self.vq_way == 'half':
162 | return feats.half()
163 | else:
164 | return feats
165 |
166 | def quantize(self):
167 | if self.no_IS: # no important score
168 | importance = np.ones((self.feats.shape[0]))
169 | else:
170 | importance = self.load_f(self.imp_path, 'imp_score.npz')
171 |
172 | ###################################################
173 | only_vq_some_vector = True
174 | if only_vq_some_vector:
175 | tensor_importance = torch.tensor(importance)
176 | large_val, large_index = torch.topk(tensor_importance, k=int(tensor_importance.shape[0] * (1-self.vq_ratio)), largest=True)
177 | self.all_one_mask = torch.ones_like(tensor_importance).bool()
178 | self.non_vq_mask = torch.zeros_like(tensor_importance).bool()
179 | self.non_vq_mask[large_index] = True
180 | self.non_vq_index = large_index
181 |
182 | IS_non_vq_point = large_val.sum()
183 | IS_all_point = tensor_importance.sum()
184 | IS_percent = IS_non_vq_point/IS_all_point
185 | print("IS_percent: ", IS_percent)
186 |
187 | #=================== Codebook initialization & Update codebook ====================
188 | self.model_vq.train()
189 | with torch.no_grad():
190 | self.vq_mask = torch.logical_xor(self.all_one_mask, self.non_vq_mask)
191 | feats_needs_vq = self.feats[self.vq_mask].clone()
192 | imp = tensor_importance[self.vq_mask].float()
193 | k = self.k_expire
194 | if k > self.model_vq.codebook_size:
195 | k = 0
196 | for i in trange(self.iteration_num):
197 | indexes = torch.randint(low=0, high=feats_needs_vq.shape[0], size=[self.VQ_CHUNK])
198 | vq_weight = imp[indexes].to(device)
199 | vq_feature = feats_needs_vq[indexes,:].to(device)
200 | quantize, embed, loss = self.model_vq(vq_feature.unsqueeze(0), weight=vq_weight.reshape(1,-1,1))
201 |
202 | replace_val, replace_index = torch.topk(self.model_vq._codebook.cluster_size, k=k, largest=False)
203 | _, most_important_index = torch.topk(vq_weight, k=k, largest=True)
204 | self.model_vq._codebook.embed[:,replace_index,:] = vq_feature[most_important_index,:]
205 |
206 | #=================== Apply vector quantization ====================
207 | all_feat, all_indices = self.fully_vq_reformat()
208 |
209 | def dequantize(self):
210 | print("\n==================== Load saved data & Dequantize ==================== ")
211 | dequantized_feats = load_vqgaussian(os.path.join(self.save_path,'extreme_saving'), device=device)
212 |
213 | if self.no_save_ply == False:
214 | os.makedirs(f'{self.ply_path}/', exist_ok=True)
215 | write_ply_data(dequantized_feats.cpu().numpy(), self.ply_path, self.sh_dim)
216 |
217 |
218 | if __name__=='__main__':
219 | opt = parse_args()
220 | device = torch.device('cuda')
221 | vq = Quantization(opt)
222 |
223 | vq.quantize()
224 | vq.dequantize()
225 |
226 | print("All done!")
227 |
228 |
--------------------------------------------------------------------------------
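
In Quantization.quantize(), vq_ratio controls how many Gaussians bypass the codebook: the top (1 - vq_ratio) fraction by importance score keeps its exact SH coefficients (non_vq_mask), and only the rest is vector-quantized. A minimal sketch of that mask construction with stand-in scores, mirroring the topk logic above:

    import torch

    vq_ratio = 0.6
    importance = torch.rand(10)                          # stand-in importance scores

    k = int(importance.shape[0] * (1 - vq_ratio))        # Gaussians kept at full precision
    _, keep_index = torch.topk(importance, k=k, largest=True)

    non_vq_mask = torch.zeros_like(importance).bool()
    non_vq_mask[keep_index] = True                       # kept exactly (not quantized)
    vq_mask = ~non_vq_mask                                # routed through the shared codebook
    print(non_vq_mask.sum().item(), vq_mask.sum().item())  # 4 kept, 6 quantized
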