├── .gitignore ├── .gitmodules ├── LICENSE.txt ├── README.md ├── analyze.py ├── analyze_statistic.py ├── arguments └── __init__.py ├── assets └── teaser.png ├── convert.py ├── densification.py ├── environment.yml ├── examples ├── mip360 │ ├── 1g_1b.sh │ ├── 4g_1b.sh │ ├── 4g_4b.sh │ ├── analyze_results.py │ ├── eval_all_mip360.sh │ ├── render_and_metrics.sh │ └── render_and_metrics_gpus.sh ├── mip360_4k │ ├── 1g_1b_4k.sh │ ├── 4g_1b_4k.sh │ ├── analyze_results.py │ └── eval_mip360_4k.sh └── train_truck_1k │ ├── analyze_results.py │ ├── eval_train_truck_1k.sh │ └── train_truck_1k.sh ├── gaussian_renderer ├── __init__.py ├── distribution_config.py ├── loss_distribution.py ├── network_gui.py └── workload_division.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── metrics.py ├── render.py ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py └── gaussian_model.py ├── train.py ├── train_internal.py └── utils ├── camera_utils.py ├── debug_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── sh_utils.py ├── system_utils.py └── timer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots 9 | expe* 10 | old_files 11 | # *.sh 12 | *.ipynb 13 | *.txt 14 | *.log 15 | *.out 16 | __pycache__ 17 | *.npy 18 | scripts/ 19 | data 20 | results 21 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | [submodule "submodules/diff-gaussian-rasterization"] 5 | path = submodules/diff-gaussian-rasterization 6 | url = git@github.com:nyu-systems/diff-gaussian-rasterization.git 7 | [submodule "submodules/gsplat"] 8 | path = submodules/gsplat 9 | url = git@github.com:alexis-mmm/gsplat.git 10 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | from gaussian_renderer.distribution_config import init_image_distribution_config 16 | import utils.general_utils as utils 17 | import diff_gaussian_rasterization 18 | 19 | 20 | class GroupParams: 21 | pass 22 | 23 | 24 | class ParamGroup: 25 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False): 26 | group = parser.add_argument_group(name) 27 | for key, value in vars(self).items(): 28 | shorthand = False 29 | if key.startswith("_"): 30 | shorthand = True 31 | key = key[1:] 32 | t = type(value) 33 | value = value if not fill_none else None 34 | if shorthand: 35 | if t == bool: 36 | group.add_argument( 37 | "--" + key, ("-" + key[0:1]), default=value, action="store_true" 38 | ) 39 | else: 40 | group.add_argument( 41 | "--" + key, ("-" + key[0:1]), default=value, type=t 42 | ) 43 | else: 44 | if t == bool: 45 | group.add_argument("--" + key, default=value, action="store_true") 46 | elif t == list: 47 | type_to_use = int 48 | if len(value) > 0: 49 | type_to_use = type(value[0]) 50 | group.add_argument( 51 | "--" + key, default=value, nargs="+", type=type_to_use 52 | ) 53 | else: 54 | group.add_argument("--" + key, default=value, type=t) 55 | 56 | def extract(self, args): 57 | group = GroupParams() 58 | for arg in vars(args).items(): 59 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 60 | setattr(group, arg[0], arg[1]) 61 | return group 62 | 63 | 64 | class AuxiliaryParams(ParamGroup): 65 | def __init__(self, parser, sentinel=False): 66 | self.debug_from = -1 67 | self.detect_anomaly = False 68 | self.test_iterations = [7_000, 30_000] 69 | self.save_iterations = [7_000, 30_000] 70 | self.quiet = False 71 | self.checkpoint_iterations = [] 72 | self.start_checkpoint = "" 73 | self.auto_start_checkpoint = False 74 | self.log_folder = "/tmp/gaussian_splatting" 75 | self.log_interval = 250 76 | self.llffhold = 8 77 | self.backend = "default" # "default", "gsplat" 78 | super().__init__(parser, "Loading Parameters", sentinel) 79 | 80 | def extract(self, args): 81 | g = super().extract(args) 82 | return g 83 | 84 | 85 | class ModelParams(ParamGroup): 86 | def __init__(self, parser, sentinel=False): 87 | self.sh_degree = 3 88 | self._source_path = "" 89 | self._model_path = "/tmp/gaussian_splatting" 90 | self._images = "images" 91 | self._white_background = False 92 | self.eval = False 93 | super().__init__(parser, "Loading Parameters", sentinel) 94 | 95 | def extract(self, args): 96 | g = super().extract(args) 97 | g.source_path = os.path.abspath(g.source_path) 98 | return g 99 | 100 | 101 | class PipelineParams(ParamGroup): 102 | def __init__(self, parser): 103 | self.debug = False 104 | super().__init__(parser, "Pipeline Parameters") 105 | 106 | 107 | class OptimizationParams(ParamGroup): 108 | def __init__(self, parser): 109 | self.iterations = 30_000 110 | self.position_lr_init = 0.00016 111 | self.position_lr_final = 0.0000016 112 | self.position_lr_delay_mult = 0.01 113 
| self.position_lr_max_steps = 30_000 114 | self.feature_lr = 0.0025 115 | self.opacity_lr = 0.05 116 | self.scaling_lr = 0.005 117 | self.lr_scale_loss = 1.0 118 | self.lr_scale_pos_and_scale = 1.0 119 | self.rotation_lr = 0.001 120 | self.percent_dense = 0.01 121 | self.lambda_dssim = 0.2 122 | self.densification_interval = 100 123 | self.opacity_reset_interval = 3000 124 | self.densify_from_iter = 500 125 | self.densify_until_iter = 15_000 126 | self.densify_grad_threshold = 0.0002 127 | self.densify_memory_limit_percentage = 0.9 128 | self.disable_auto_densification = False 129 | self.opacity_reset_until_iter = -1 130 | self.random_background = False 131 | self.min_opacity = 0.005 132 | self.lr_scale_mode = "sqrt" # can be "linear", "sqrt", or "accumu" 133 | super().__init__(parser, "Optimization Parameters") 134 | 135 | 136 | class DistributionParams(ParamGroup): 137 | def __init__(self, parser): 138 | # Distribution for pixel-wise workloads. 139 | self.image_distribution = True 140 | self.image_distribution_mode = "final" 141 | self.heuristic_decay = 0.0 142 | self.no_heuristics_update = False 143 | self.border_divpos_coeff = 1.0 144 | self.adjust_strategy_warmp_iterations = -1 145 | self.save_strategy_history = False 146 | 147 | # Distribution for 3DGS-wise workloads. 148 | self.gaussians_distribution = True 149 | self.redistribute_gaussians_mode = "random_redistribute" # "no_redistribute" 150 | self.redistribute_gaussians_frequency = ( 151 | 10 # redistribution frequency for 3DGS storage location. 152 | ) 153 | self.redistribute_gaussians_threshold = ( 154 | 1.1 # threshold to apply redistribution for 3DGS storage location 155 | ) 156 | self.sync_grad_mode = "dense" # "dense", "sparse", "fused_dense", "fused_sparse" gradient synchronization. Only use when gaussians_distribution is False. 157 | self.grad_normalization_mode = "none" # "divide_by_visible_count", "square_multiply_by_visible_count", "multiply_by_visible_count", "none" gradient normalization mode. 158 | 159 | # Dataset and Model save 160 | self.bsz = 1 # batch size. 161 | self.distributed_dataset_storage = True # if True, we store dataset only on rank 0 and broadcast to other ranks. 162 | self.distributed_save = True 163 | self.local_sampling = False 164 | self.preload_dataset_to_gpu = ( 165 | False # By default, we do not preload dataset to GPU. 166 | ) 167 | self.preload_dataset_to_gpu_threshold = ( 168 | 10 # unit is GB, by default 10GB memory limit for dataset. 169 | ) 170 | self.multiprocesses_image_loading = True 171 | self.num_train_cameras = -1 172 | self.num_test_cameras = -1 173 | 174 | super().__init__(parser, "Distribution Parameters") 175 | 176 | 177 | class BenchmarkParams(ParamGroup): 178 | def __init__(self, parser): 179 | self.enable_timer = False # Log running time from python side. 180 | self.end2end_time = True # Log end2end training time. 181 | self.zhx_time = False # Log running time from gpu side. 182 | self.check_gpu_memory = False # check gpu memory usage. 183 | self.check_cpu_memory = False # check cpu memory usage. 184 | self.log_memory_summary = False 185 | 186 | super().__init__(parser, "Benchmark Parameters") 187 | 188 | 189 | class DebugParams(ParamGroup): 190 | def __init__(self, parser): 191 | self.zhx_debug = False # log debug information that zhx needs. 192 | self.stop_update_param = ( 193 | False # stop updating parameters. No optimizer.step() will be called. 194 | ) 195 | self.time_image_loading = False # Log image loading time. 
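        # (As with every ParamGroup subclass, the attributes above are turned into CLI
        #  flags automatically by ParamGroup.__init__: bool defaults such as
        #  `self.zhx_debug = False` become store_true switches like `--zhx_debug`, and a
        #  leading underscore (e.g. `_source_path` in ModelParams) additionally registers
        #  a one-letter shorthand such as `-s`.)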
196 | 197 | self.nsys_profile = False # profile with nsys. 198 | self.drop_initial_3dgs_p = 0.0 # profile with nsys. 199 | self.drop_duplicate_gaussians_coeff = 1.0 200 | 201 | super().__init__(parser, "Debug Parameters") 202 | 203 | 204 | def get_combined_args(parser: ArgumentParser, auto_find_cfg_args_path=False): 205 | cmdlne_string = sys.argv[1:] 206 | cfgfile_string = "Namespace()" 207 | args_cmdline = parser.parse_args(cmdlne_string) 208 | 209 | try: 210 | if auto_find_cfg_args_path: 211 | if hasattr(args_cmdline, "load_ply_path"): 212 | path = args_cmdline.load_ply_path 213 | while not os.path.exists( 214 | os.path.join(path, "cfg_args") 215 | ) and os.path.exists(path): 216 | path = os.path.join(path, "..") 217 | cfgfilepath = os.path.join(path, "cfg_args") 218 | else: 219 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 220 | print("Looking for config file in", cfgfilepath) 221 | with open(cfgfilepath) as cfg_file: 222 | print("Config file found: {}".format(cfgfilepath)) 223 | cfgfile_string = cfg_file.read() 224 | except TypeError: 225 | print("Config file not found at") 226 | pass 227 | args_cfgfile = eval(cfgfile_string) 228 | 229 | merged_dict = vars(args_cfgfile).copy() 230 | for k, v in vars(args_cmdline).items(): 231 | if v != None: 232 | merged_dict[k] = v 233 | return Namespace(**merged_dict) 234 | 235 | 236 | def print_all_args(args, log_file): 237 | # print all arguments in a readable format, each argument in a line. 238 | log_file.write("arguments:\n") 239 | log_file.write("-" * 30 + "\n") 240 | for arg in vars(args): 241 | log_file.write("{}: {}\n".format(arg, getattr(args, arg))) 242 | log_file.write("-" * 30 + "\n\n") 243 | log_file.write( 244 | "world_size: " 245 | + str(utils.WORLD_SIZE) 246 | + " rank: " 247 | + str(utils.GLOBAL_RANK) 248 | + "; bsz: " 249 | + str(args.bsz) 250 | + "\n" 251 | ) 252 | 253 | # Make sure block size match between python and cuda code. 254 | cuda_block_x, cuda_block_y, one_dim_block_size = ( 255 | diff_gaussian_rasterization._C.get_block_XY() 256 | ) 257 | utils.set_block_size(cuda_block_x, cuda_block_y, one_dim_block_size) 258 | log_file.write( 259 | "cuda_block_x: {}; cuda_block_y: {}; one_dim_block_size: {};\n".format( 260 | cuda_block_x, cuda_block_y, one_dim_block_size 261 | ) 262 | ) 263 | 264 | 265 | def find_latest_checkpoint(log_folder): 266 | checkpoint_folder = os.path.join(log_folder, "checkpoints") 267 | if os.path.exists(checkpoint_folder): 268 | all_sub_folders = os.listdir(checkpoint_folder) 269 | if len(all_sub_folders) > 0: 270 | all_sub_folders.sort(key=lambda x: int(x), reverse=True) 271 | return os.path.join(checkpoint_folder, all_sub_folders[0]) 272 | return "" 273 | 274 | 275 | def init_args(args): 276 | 277 | if args.opacity_reset_until_iter == -1: 278 | args.opacity_reset_until_iter = args.densify_until_iter + args.bsz 279 | 280 | # Logging are saved with where model is saved. 
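    # (i.e. log files and checkpoints are written under the --model_path directory.)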
281 | args.log_folder = args.model_path 282 | 283 | if args.auto_start_checkpoint: 284 | args.start_checkpoint = find_latest_checkpoint(args.log_folder) 285 | 286 | if utils.DEFAULT_GROUP.size() == 1: 287 | args.gaussians_distribution = False 288 | args.image_distribution = False 289 | args.image_distribution_mode = "" 290 | args.distributed_dataset_storage = False 291 | args.distributed_save = False 292 | args.local_sampling = False 293 | 294 | if args.preload_dataset_to_gpu: 295 | args.distributed_dataset_storage = False 296 | args.local_sampling = False 297 | # TODO: args.preload_dataset_to_gpu should be independent of args.local_sampling and args.distributed_dataset_storage 298 | # We can distributedly save dataset and preload every shard to GPU at the same time. 299 | 300 | if args.local_sampling: 301 | assert args.distributed_dataset_storage, "local_sampling works only when distributed_dataset_storage==True" 302 | 303 | if not args.gaussians_distribution: 304 | args.distributed_save = False 305 | 306 | # sort test_iterations 307 | args.test_iterations.sort() 308 | args.save_iterations.sort() 309 | if len(args.save_iterations) > 0 and args.iterations not in args.save_iterations: 310 | args.save_iterations.append(args.iterations) 311 | args.checkpoint_iterations.sort() 312 | 313 | # Set up global args 314 | utils.set_args(args) 315 | # TODO: handle the warning: https://github.com/pytorch/pytorch/blob/bae409388cfc20cce656bf7b671e45aaf81dd1c8/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1849-L1852 316 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-systems/Grendel-GS/bff47670c914f7d68990df5e38123ecee19c2a08/assets/teaser.png -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 
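# A typical invocation (sketch; assumes the COLMAP and ImageMagick `magick` binaries are
# on PATH and that the raw photos live in <source_path>/input):
#   python convert.py -s <source_path> --resize
# Use --colmap_executable / --magick_executable to point at specific binaries, --no_gpu to
# run SIFT extraction/matching on the CPU, and --skip_matching to reuse an existing
# reconstruction in <source_path>/distorted/sparse/0.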
18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action="store_true") 20 | parser.add_argument("--skip_matching", action="store_true") 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | args = parser.parse_args() 27 | colmap_command = ( 28 | '"{}"'.format(args.colmap_executable) 29 | if len(args.colmap_executable) > 0 30 | else "colmap" 31 | ) 32 | magick_command = ( 33 | '"{}"'.format(args.magick_executable) 34 | if len(args.magick_executable) > 0 35 | else "magick" 36 | ) 37 | use_gpu = 1 if not args.no_gpu else 0 38 | 39 | if not args.skip_matching: 40 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 41 | 42 | ## Feature extraction 43 | feat_extracton_cmd = ( 44 | colmap_command + " feature_extractor " 45 | "--database_path " 46 | + args.source_path 47 | + "/distorted/database.db \ 48 | --image_path " 49 | + args.source_path 50 | + "/input \ 51 | --ImageReader.single_camera 1 \ 52 | --ImageReader.camera_model " 53 | + args.camera 54 | + " \ 55 | --SiftExtraction.use_gpu " 56 | + str(use_gpu) 57 | ) 58 | exit_code = os.system(feat_extracton_cmd) 59 | if exit_code != 0: 60 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 61 | exit(exit_code) 62 | 63 | ## Feature matching 64 | feat_matching_cmd = ( 65 | colmap_command 66 | + " exhaustive_matcher \ 67 | --database_path " 68 | + args.source_path 69 | + "/distorted/database.db \ 70 | --SiftMatching.use_gpu " 71 | + str(use_gpu) 72 | ) 73 | exit_code = os.system(feat_matching_cmd) 74 | if exit_code != 0: 75 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 76 | exit(exit_code) 77 | 78 | ### Bundle adjustment 79 | # The default Mapper tolerance is unnecessarily large, 80 | # decreasing it speeds up bundle adjustment steps. 81 | mapper_cmd = ( 82 | colmap_command 83 | + " mapper \ 84 | --database_path " 85 | + args.source_path 86 | + "/distorted/database.db \ 87 | --image_path " 88 | + args.source_path 89 | + "/input \ 90 | --output_path " 91 | + args.source_path 92 | + "/distorted/sparse \ 93 | --Mapper.ba_global_function_tolerance=0.000001" 94 | ) 95 | exit_code = os.system(mapper_cmd) 96 | if exit_code != 0: 97 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 98 | exit(exit_code) 99 | 100 | ### Image undistortion 101 | ## We need to undistort our images into ideal pinhole intrinsics. 102 | img_undist_cmd = ( 103 | colmap_command 104 | + " image_undistorter \ 105 | --image_path " 106 | + args.source_path 107 | + "/input \ 108 | --input_path " 109 | + args.source_path 110 | + "/distorted/sparse/0 \ 111 | --output_path " 112 | + args.source_path 113 | + "\ 114 | --output_type COLMAP" 115 | ) 116 | exit_code = os.system(img_undist_cmd) 117 | if exit_code != 0: 118 | logging.error(f"Mapper failed with code {exit_code}. 
Exiting.") 119 | exit(exit_code) 120 | 121 | files = os.listdir(args.source_path + "/sparse") 122 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 123 | # Copy each file from the source directory to the destination directory 124 | for file in files: 125 | if file == "0": 126 | continue 127 | source_file = os.path.join(args.source_path, "sparse", file) 128 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 129 | shutil.move(source_file, destination_file) 130 | 131 | if args.resize: 132 | print("Copying and resizing...") 133 | 134 | # Resize images. 135 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 136 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 137 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 138 | # Get the list of files in the source directory 139 | files = os.listdir(args.source_path + "/images") 140 | # Copy each file from the source directory to the destination directory 141 | for file in files: 142 | source_file = os.path.join(args.source_path, "images", file) 143 | 144 | destination_file = os.path.join(args.source_path, "images_2", file) 145 | shutil.copy2(source_file, destination_file) 146 | exit_code = os.system( 147 | magick_command + " mogrify -resize 50% " + destination_file 148 | ) 149 | if exit_code != 0: 150 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 151 | exit(exit_code) 152 | 153 | destination_file = os.path.join(args.source_path, "images_4", file) 154 | shutil.copy2(source_file, destination_file) 155 | exit_code = os.system( 156 | magick_command + " mogrify -resize 25% " + destination_file 157 | ) 158 | if exit_code != 0: 159 | logging.error(f"25% resize failed with code {exit_code}. Exiting.") 160 | exit(exit_code) 161 | 162 | destination_file = os.path.join(args.source_path, "images_8", file) 163 | shutil.copy2(source_file, destination_file) 164 | exit_code = os.system( 165 | magick_command + " mogrify -resize 12.5% " + destination_file 166 | ) 167 | if exit_code != 0: 168 | logging.error(f"12.5% resize failed with code {exit_code}. 
Exiting.") 169 | exit(exit_code) 170 | 171 | print("Done.") 172 | -------------------------------------------------------------------------------- /densification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils.general_utils as utils 3 | 4 | 5 | def densification(iteration, scene, gaussians, batched_screenspace_pkg): 6 | args = utils.get_args() 7 | timers = utils.get_timers() 8 | log_file = utils.get_log_file() 9 | 10 | # Densification 11 | if not args.disable_auto_densification and iteration <= args.densify_until_iter: 12 | # Keep track of max radii in image-space for pruning 13 | timers.start("densification") 14 | 15 | timers.start("densification_update_stats") 16 | for radii, visibility_filter, screenspace_mean2D in zip( 17 | batched_screenspace_pkg["batched_locally_preprocessed_radii"], 18 | batched_screenspace_pkg["batched_locally_preprocessed_visibility_filter"], 19 | batched_screenspace_pkg["batched_locally_preprocessed_mean2D"], 20 | ): 21 | gaussians.max_radii2D[visibility_filter] = torch.max( 22 | gaussians.max_radii2D[visibility_filter], radii[visibility_filter] 23 | ) 24 | gaussians.add_densification_stats(screenspace_mean2D, visibility_filter) 25 | timers.stop("densification_update_stats") 26 | 27 | if iteration > args.densify_from_iter and utils.check_update_at_this_iter( 28 | iteration, args.bsz, args.densification_interval, 0 29 | ): 30 | assert ( 31 | args.stop_update_param == False 32 | ), "stop_update_param must be false for densification; because it is a flag for debugging." 33 | # utils.print_rank_0("iteration: {}, bsz: {}, update_interval: {}, update_residual: {}".format(iteration, args.bsz, args.densification_interval, 0)) 34 | 35 | timers.start("densify_and_prune") 36 | size_threshold = 20 if iteration > args.opacity_reset_interval else None 37 | gaussians.densify_and_prune( 38 | args.densify_grad_threshold, 39 | args.min_opacity, 40 | scene.cameras_extent, 41 | size_threshold, 42 | ) 43 | timers.stop("densify_and_prune") 44 | 45 | # redistribute after densify_and_prune, because we have new gaussians to distribute evenly. 46 | if utils.get_denfify_iter() % args.redistribute_gaussians_frequency == 0: 47 | num_3dgs_before_redistribute = gaussians.get_xyz.shape[0] 48 | timers.start("redistribute_gaussians") 49 | gaussians.redistribute_gaussians() 50 | timers.stop("redistribute_gaussians") 51 | num_3dgs_after_redistribute = gaussians.get_xyz.shape[0] 52 | 53 | log_file.write( 54 | "iteration[{},{}) redistribute. Now num of 3dgs before redistribute: {}. Now num of 3dgs after redistribute: {}. 
\n".format( 55 | iteration, 56 | iteration + args.bsz, 57 | num_3dgs_before_redistribute, 58 | num_3dgs_after_redistribute, 59 | ) 60 | ) 61 | 62 | utils.check_memory_usage( 63 | log_file, args, iteration, gaussians, before_densification_stop=True 64 | ) 65 | 66 | utils.inc_densify_iter() 67 | 68 | if ( 69 | utils.check_update_at_this_iter( 70 | iteration, args.bsz, args.opacity_reset_interval, 0 71 | ) 72 | and iteration + args.bsz <= args.opacity_reset_until_iter 73 | ): 74 | timers.start("reset_opacity") 75 | gaussians.reset_opacity() 76 | timers.stop("reset_opacity") 77 | 78 | timers.stop("densification") 79 | else: 80 | if iteration > args.densify_from_iter and utils.check_update_at_this_iter( 81 | iteration, args.bsz, args.densification_interval, 0 82 | ): 83 | utils.check_memory_usage( 84 | log_file, args, iteration, gaussians, before_densification_stop=False 85 | ) 86 | 87 | 88 | def gsplat_densification(iteration, scene, gaussians, batched_screenspace_pkg): 89 | args = utils.get_args() 90 | timers = utils.get_timers() 91 | log_file = utils.get_log_file() 92 | 93 | # Densification 94 | if not args.disable_auto_densification and iteration <= args.densify_until_iter: 95 | # Keep track of max radii in image-space for pruning 96 | timers.start("densification") 97 | 98 | timers.start("densification_update_stats") 99 | image_width = batched_screenspace_pkg["image_width"] 100 | image_height = batched_screenspace_pkg["image_height"] 101 | batched_screenspace_mean2D_grad = batched_screenspace_pkg[ 102 | "batched_locally_preprocessed_mean2D" 103 | ].grad 104 | for i, (radii, visibility_filter) in enumerate( 105 | zip( 106 | batched_screenspace_pkg["batched_locally_preprocessed_radii"], 107 | batched_screenspace_pkg[ 108 | "batched_locally_preprocessed_visibility_filter" 109 | ], 110 | ) 111 | ): 112 | gaussians.max_radii2D[visibility_filter] = torch.max( 113 | gaussians.max_radii2D[visibility_filter], radii[visibility_filter] 114 | ) 115 | gaussians.gsplat_add_densification_stats( 116 | batched_screenspace_mean2D_grad[i], 117 | visibility_filter, 118 | image_width, 119 | image_height, 120 | ) 121 | timers.stop("densification_update_stats") 122 | 123 | if iteration > args.densify_from_iter and utils.check_update_at_this_iter( 124 | iteration, args.bsz, args.densification_interval, 0 125 | ): 126 | assert ( 127 | args.stop_update_param == False 128 | ), "stop_update_param must be false for densification; because it is a flag for debugging." 129 | # utils.print_rank_0("iteration: {}, bsz: {}, update_interval: {}, update_residual: {}".format(iteration, args.bsz, args.densification_interval, 0)) 130 | 131 | timers.start("densify_and_prune") 132 | size_threshold = 20 if iteration > args.opacity_reset_interval else None 133 | gaussians.densify_and_prune( 134 | args.densify_grad_threshold, 135 | args.min_opacity, 136 | scene.cameras_extent, 137 | size_threshold, 138 | ) 139 | timers.stop("densify_and_prune") 140 | 141 | # redistribute after densify_and_prune, because we have new gaussians to distribute evenly. 142 | if utils.get_denfify_iter() % args.redistribute_gaussians_frequency == 0: 143 | num_3dgs_before_redistribute = gaussians.get_xyz.shape[0] 144 | timers.start("redistribute_gaussians") 145 | gaussians.redistribute_gaussians() 146 | timers.stop("redistribute_gaussians") 147 | num_3dgs_after_redistribute = gaussians.get_xyz.shape[0] 148 | 149 | log_file.write( 150 | "iteration[{},{}) redistribute. Now num of 3dgs before redistribute: {}. Now num of 3dgs after redistribute: {}. 
\n".format( 151 | iteration, 152 | iteration + args.bsz, 153 | num_3dgs_before_redistribute, 154 | num_3dgs_after_redistribute, 155 | ) 156 | ) 157 | 158 | utils.check_memory_usage( 159 | log_file, args, iteration, gaussians, before_densification_stop=True 160 | ) 161 | 162 | utils.inc_densify_iter() 163 | 164 | if ( 165 | utils.check_update_at_this_iter( 166 | iteration, args.bsz, args.opacity_reset_interval, 0 167 | ) 168 | and iteration + args.bsz <= args.opacity_reset_until_iter 169 | ): 170 | timers.start("reset_opacity") 171 | gaussians.reset_opacity() 172 | timers.stop("reset_opacity") 173 | 174 | timers.stop("densification") 175 | else: 176 | if iteration > args.densify_from_iter and utils.check_update_at_this_iter( 177 | iteration, args.bsz, args.densification_interval, 0 178 | ): 179 | utils.check_memory_usage( 180 | log_file, args, iteration, gaussians, before_densification_stop=False 181 | ) 182 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gaussian_splatting 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - plyfile 9 | - python=3.8 10 | - pip=23.0.1 11 | - tqdm 12 | - psutil 13 | - pytorch=2.0.1 14 | - pytorch-cuda=11.7 15 | - torchvision=0.15.2 16 | - setuptools=69.1.1 17 | - mkl==2024.0 18 | - pandas 19 | - pip: 20 | - submodules/diff-gaussian-rasterization 21 | - submodules/simple-knn 22 | - submodules/gsplat -------------------------------------------------------------------------------- /examples/mip360/1g_1b.sh: -------------------------------------------------------------------------------- 1 | 2 | if [ $# -ne 2 ]; then 3 | echo "Please specify exactly two arguments: the folder to save the experiments' log and checkpoints, and the folder of the dataset." 4 | exit 1 5 | fi 6 | 7 | expe_folder=$1 8 | echo "The experiments will be saved in $expe_folder" 9 | dataset_folder=$2 10 | echo "The dataset is in $dataset_folder" 11 | 12 | # scenes to be trained 13 | SCENE=(counter bicycle stump garden room bonsai kitchen) 14 | # batch size 15 | BSZ=1 16 | 17 | # Monitoring Settings 18 | monitor_opts="--enable_timer \ 19 | --end2end_time \ 20 | --check_gpu_memory \ 21 | --check_cpu_memory" 22 | 23 | for scene in ${SCENE[@]}; do 24 | expe_name="e_${scene}" 25 | 26 | # the following is to match the experiments setting in original gaussian splatting repository 27 | if [ "$scene" = "bicycle" ] || [ "$scene" = "stump" ] || [ "$scene" = "garden" ]; then 28 | image_folder="images_4" 29 | else 30 | image_folder="images_2" 31 | fi 32 | 33 | torchrun --standalone --nnodes=1 --nproc-per-node=1 train.py \ 34 | -s ${dataset_folder}/${scene} \ 35 | --images ${image_folder} \ 36 | --llffhold 8 \ 37 | --iterations 30000 \ 38 | --log_interval 250 \ 39 | --model_path ${expe_folder}/1g_1b/${expe_name} \ 40 | --bsz $BSZ \ 41 | $monitor_opts \ 42 | --test_iterations 7000 15000 30000 \ 43 | --save_iterations 7000 30000 \ 44 | --eval 45 | done 46 | -------------------------------------------------------------------------------- /examples/mip360/4g_1b.sh: -------------------------------------------------------------------------------- 1 | 2 | if [ $# -ne 2 ]; then 3 | echo "Please specify exactly two arguments: the folder to save the experiments' log and checkpoints, and the folder of the dataset." 
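    # Usage sketch: bash examples/mip360/4g_1b.sh <experiments_output_folder> <mip360_dataset_folder>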
4 | exit 1 5 | fi 6 | 7 | expe_folder=$1 8 | echo "The experiments will be saved in $expe_folder" 9 | dataset_folder=$2 10 | echo "The dataset is in $dataset_folder" 11 | 12 | # scenes to be trained 13 | SCENE=(counter bicycle stump garden room bonsai kitchen) 14 | # batch size 15 | BSZ=1 16 | 17 | # Monitoring Settings 18 | monitor_opts="--enable_timer \ 19 | --end2end_time \ 20 | --check_gpu_memory \ 21 | --check_cpu_memory" 22 | 23 | for scene in ${SCENE[@]}; do 24 | expe_name="e_${scene}" 25 | 26 | # the following is to match the experiments setting in original gaussian splatting repository 27 | if [ "$scene" = "bicycle" ] || [ "$scene" = "stump" ] || [ "$scene" = "garden" ]; then 28 | image_folder="images_4" 29 | else 30 | image_folder="images_2" 31 | fi 32 | 33 | torchrun --standalone --nnodes=1 --nproc-per-node=4 train.py \ 34 | -s ${dataset_folder}/${scene} \ 35 | --images ${image_folder} \ 36 | --llffhold 8 \ 37 | --iterations 30000 \ 38 | --log_interval 250 \ 39 | --model_path ${expe_folder}/4g_1b/${expe_name} \ 40 | --bsz $BSZ \ 41 | $monitor_opts \ 42 | --test_iterations 7000 15000 30000 \ 43 | --save_iterations 7000 30000 \ 44 | --eval 45 | done 46 | -------------------------------------------------------------------------------- /examples/mip360/4g_4b.sh: -------------------------------------------------------------------------------- 1 | 2 | if [ $# -ne 2 ]; then 3 | echo "Please specify exactly two arguments: the folder to save the experiments' log and checkpoints, and the folder of the dataset." 4 | exit 1 5 | fi 6 | 7 | expe_folder=$1 8 | echo "The experiments will be saved in $expe_folder" 9 | dataset_folder=$2 10 | echo "The dataset is in $dataset_folder" 11 | 12 | # scenes to be trained 13 | SCENE=(counter bicycle stump garden room bonsai kitchen) 14 | # batch size 15 | BSZ=4 16 | 17 | # Monitoring Settings 18 | monitor_opts="--enable_timer \ 19 | --end2end_time \ 20 | --check_gpu_memory \ 21 | --check_cpu_memory" 22 | 23 | for scene in ${SCENE[@]}; do 24 | expe_name="e_${scene}" 25 | 26 | # the following is to match the experiments setting in original gaussian splatting repository 27 | if [ "$scene" = "bicycle" ] || [ "$scene" = "stump" ] || [ "$scene" = "garden" ]; then 28 | image_folder="images_4" 29 | else 30 | image_folder="images_2" 31 | fi 32 | 33 | torchrun --standalone --nnodes=1 --nproc-per-node=4 train.py \ 34 | -s ${dataset_folder}/${scene} \ 35 | --images ${image_folder} \ 36 | --llffhold 8 \ 37 | --iterations 30000 \ 38 | --log_interval 250 \ 39 | --model_path ${expe_folder}/4g_4b/${expe_name} \ 40 | --bsz $BSZ \ 41 | $monitor_opts \ 42 | --test_iterations 7000 15000 30000 \ 43 | --save_iterations 7000 30000 \ 44 | --eval 45 | done 46 | -------------------------------------------------------------------------------- /examples/mip360/analyze_results.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import os 4 | import sys 5 | 6 | def get_suffix_in_folder(folder): 7 | 8 | if not os.path.exists(folder): 9 | return None 10 | 11 | if not folder.endswith("/"): 12 | folder += "/" 13 | 14 | suffix_list_candidates = [] 15 | for ws in [1,2,4,8,16,32]: 16 | for rk in range(ws): 17 | suffix_list_candidates.append(f"ws={ws}_rk={rk}") 18 | 19 | suffix_list = [] 20 | for suffix in suffix_list_candidates: 21 | if os.path.exists(folder + "python_" + suffix + ".log"): 22 | suffix_list.append(suffix) 23 | 24 | return suffix_list 25 | 26 | 27 | def get_running_time_at_iterations(expe_folder, 
iterations): 28 | a_suffix = get_suffix_in_folder(expe_folder)[0] 29 | a_log_file = os.path.join(expe_folder, f"python_{a_suffix}.log") 30 | lines = open(a_log_file, "r").readlines() 31 | results = [] 32 | bsz = 1 33 | for line in lines: 34 | # bsz: 1 35 | if "bsz: " in line: 36 | bsz = int(line.split("bsz: ")[1]) 37 | # end2end total_time: 443.026 s, iterations: 7001, throughput 15.80 it/s 38 | if "end2end total_time:" not in line: 39 | continue 40 | iteration = int(line.split("iterations: ")[1].split(",")[0]) 41 | running_time = int(float(line.split("end2end total_time: ")[1].split(" s")[0])) 42 | 43 | for r in range(iteration-bsz, iteration): 44 | if r in iterations: 45 | results.append(running_time) 46 | break 47 | return results 48 | 49 | def get_test_psnr_at_iterations(expe_folder, iterations): 50 | a_suffix = get_suffix_in_folder(expe_folder)[0] 51 | a_log_file = os.path.join(expe_folder, f"python_{a_suffix}.log") 52 | lines = open(a_log_file, "r").readlines() 53 | results = [] 54 | bsz = 1 55 | for line in lines: 56 | # bsz: 1 57 | if "bsz: " in line: 58 | bsz = int(line.split("bsz: ")[1]) 59 | # [ITER 50000] Evaluating test: L1 0.01809605024755001 PSNR 29.30947494506836 60 | if "Evaluating test:" not in line: 61 | continue 62 | iteration = int(line.split("[ITER ")[1].split("]")[0]) 63 | L1 = float(line.split("L1 ")[1].split(" PSNR")[0]) 64 | PSNR = float(line.split("PSNR ")[1]) 65 | for r in range(iteration, iteration+bsz): 66 | if r in iterations: 67 | results.append(round(PSNR, 2)) 68 | return results 69 | 70 | def get_test_psnr_list_from_logfile(expe_folder): 71 | a_suffix = get_suffix_in_folder(expe_folder)[0] 72 | a_log_file = os.path.join(expe_folder, f"python_{a_suffix}.log") 73 | lines = open(a_log_file, "r").readlines() 74 | results = [] 75 | for line in lines: 76 | # [ITER 50000] Evaluating test: L1 0.01809605024755001 PSNR 29.30947494506836 77 | if "Evaluating test:" not in line: 78 | continue 79 | iteration = int(line.split("[ITER ")[1].split("]")[0]) 80 | L1 = float(line.split("L1 ")[1].split(" PSNR")[0]) 81 | PSNR = float(line.split("PSNR ")[1]) 82 | results.append({ 83 | "iteration": iteration, 84 | "L1": round(L1, 2), 85 | "PSNR": round(PSNR, 2) 86 | }) 87 | return results 88 | 89 | def extract_from_mip360_all9scene(folder): 90 | # if os.path.exists(os.path.join(folder, "mip360_all9scene.json")): 91 | # print("mip360_all9scene.json already exists for ", folder) 92 | # return 93 | 94 | # counter kitchen room stump bicycle garden bonsai flowers treehill 95 | scene_names = [ 96 | "counter", 97 | "kitchen", 98 | "room", 99 | "stump", 100 | "bicycle", 101 | "garden", 102 | "bonsai", 103 | ] 104 | check_iterations = [7000, 30000] 105 | results = {} 106 | for scene in scene_names: 107 | scene_folder = os.path.join(folder, "e_"+scene) 108 | if not os.path.exists(scene_folder): 109 | continue 110 | running_time_all = get_running_time_at_iterations(scene_folder, check_iterations) 111 | psnr_all = get_test_psnr_at_iterations(scene_folder, check_iterations) 112 | results[scene] = {} 113 | for iteration, running_time, psnr in zip(check_iterations, running_time_all, psnr_all): 114 | results[scene][iteration] = { 115 | "running_time": running_time, 116 | "psnr": psnr, 117 | "throughput": round(iteration/running_time, 2) 118 | } 119 | 120 | json.dump(results, open(os.path.join(folder, "mip360_all9scene.json"), "w"), indent=4) 121 | print("Generated mip360_all9scene.json for ", folder) 122 | 123 | def plot_release_mip360(expe_folder_base): 124 | 125 | expe_sets = [ 126 | "1g_1b", 
127 | "4g_1b", 128 | "4g_4b", 129 | ] 130 | for expe_set in expe_sets: 131 | extract_from_mip360_all9scene(f"{expe_folder_base}/{expe_set}/") 132 | 133 | all_scenes = [ 134 | "stump", 135 | "bicycle", 136 | "kitchen", 137 | "room", 138 | "counter", 139 | "garden", 140 | "bonsai", 141 | ] 142 | 143 | compare_iterations = ["7000", "30000"] 144 | unit_map = { 145 | "throughput": "its", 146 | "psnr": "dB", 147 | "running_time": "min" 148 | } 149 | 150 | for metric in ["running_time", "psnr"]: 151 | first_col_name = "30k Train Time(min)" if metric == "running_time" else "30k Test PSNR" 152 | for iter in compare_iterations: 153 | df = pd.DataFrame(columns=[first_col_name]+all_scenes) 154 | for i, expe in enumerate(["1 GPU + Batch Size=1", "4 GPU + Batch Size=1", "4 GPU + Batch Size=4"]): 155 | row = [expe] 156 | expe_folder = os.path.join(expe_folder_base, expe_sets[i]) 157 | results = json.load(open(os.path.join(expe_folder, "mip360_all9scene.json"), "r")) 158 | for scene in all_scenes: 159 | if metric == "running_time": 160 | row.append(round(results[scene][iter]["running_time"]/60, 2)) 161 | else: 162 | row.append(results[scene][iter][metric]) 163 | df.loc[len(df)] = row 164 | df.to_markdown(os.path.join(expe_folder, f"mip360_compare_{metric}_{iter}.md"), index=False) 165 | 166 | if __name__ == "__main__": 167 | # read args from command line 168 | expe_folder_base = sys.argv[1] 169 | print("Expe base folder: ", expe_folder_base) 170 | 171 | plot_release_mip360(expe_folder_base) 172 | 173 | 174 | -------------------------------------------------------------------------------- /examples/mip360/eval_all_mip360.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | expe_folder_base=$1 # specify the folder to save these experiments log and checkpoints. 4 | dataset_folder=$2 # specify the dataset folder 5 | 6 | # Train all scenes on single GPU with batch size 1 7 | bash examples/mip360/1g_1b.sh $expe_folder_base $dataset_folder 8 | 9 | # Train all scenes on 4 GPU distributed with batch size 1 10 | bash examples/mip360/4g_1b.sh $expe_folder_base $dataset_folder 11 | 12 | # Train all scenes on 4 GPU distributed with batch size 4 13 | bash examples/mip360/4g_4b.sh $expe_folder_base $dataset_folder 14 | 15 | # Render and calculate metrics for all experiments 16 | bash examples/mip360/render_and_metrics.sh $expe_folder_base $dataset_folder 17 | 18 | # analyze the results from logs and generate the result table 19 | python examples/mip360/analyze_results.py $expe_folder_base -------------------------------------------------------------------------------- /examples/mip360/render_and_metrics.sh: -------------------------------------------------------------------------------- 1 | 2 | if [ $# -ne 2 ]; then 3 | echo "Please specify exactly two arguments: the folder to save the experiments' log and checkpoints, and the folder of the dataset." 
4 | exit 1 5 | fi 6 | 7 | expe_folder=$1 8 | echo "The experiments are saved in $expe_folder" 9 | dataset_folder=$2 10 | echo "The dataset is in $dataset_folder" 11 | 12 | SCENE=(counter bicycle stump garden room bonsai kitchen) 13 | EXPE_sets=("1g_1b" "4g_1b" "4g_4b") 14 | BSZ_for_expesets=(1 1 4) 15 | 16 | for i in ${!EXPE_sets[@]}; do 17 | expe_set=${EXPE_sets[$i]} 18 | bsz=${BSZ_for_expesets[$i]} 19 | 20 | for scene in ${SCENE[@]}; do 21 | 22 | if [ "$scene" = "bicycle" ] || [ "$scene" = "stump" ] || [ "$scene" = "garden" ]; then 23 | image_folder="images_4" 24 | else 25 | image_folder="images_2" 26 | fi 27 | 28 | EXPE_NAME="${expe_folder}/${expe_set}/e_${scene}" 29 | CHECKPOINTS=(7000 30000) 30 | 31 | for iter in ${CHECKPOINTS[@]}; do 32 | python render.py \ 33 | -s ${dataset_folder}/${scene} \ 34 | --images ${image_folder} \ 35 | --model_path ${EXPE_NAME} \ 36 | --iteration $iter \ 37 | --skip_train \ 38 | --llffhold 8 39 | done 40 | 41 | python metrics.py \ 42 | --mode test \ 43 | --model_paths ${EXPE_NAME} 44 | 45 | sleep 2 46 | done 47 | 48 | done 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /examples/mip360/render_and_metrics_gpus.sh: -------------------------------------------------------------------------------- 1 | 2 | if [ $# -ne 2 ]; then 3 | echo "Please specify exactly two arguments: the folder to save the experiments' log and checkpoints, and the folder of the dataset." 4 | exit 1 5 | fi 6 | 7 | expe_folder=$1 8 | echo "The experiments are saved in $expe_folder" 9 | dataset_folder=$2 10 | echo "The dataset is in $dataset_folder" 11 | 12 | # SCENE=(counter bicycle stump garden room bonsai kitchen) 13 | # EXPE_sets=("1g_1b" "4g_1b" "4g_4b") 14 | # BSZ_for_expesets=(1 1 4) 15 | 16 | SCENE=(garden) 17 | EXPE_sets=("4g_4b") 18 | # BSZ_for_expesets=(1 1 4) 19 | # NPROC_for_expesets=(1 4 4) 20 | 21 | for i in ${!EXPE_sets[@]}; do 22 | expe_set=${EXPE_sets[$i]} 23 | # bsz=${BSZ_for_expesets[$i]} 24 | # nproc=${NPROC_for_expesets[$i]} 25 | bsz=4 26 | nproc=4 27 | 28 | for scene in ${SCENE[@]}; do 29 | 30 | if [ "$scene" = "bicycle" ] || [ "$scene" = "stump" ] || [ "$scene" = "garden" ]; then 31 | image_folder="images_4" 32 | elif [ "$scene" = "train" ]; then 33 | image_folder="images" 34 | else 35 | image_folder="images_2" 36 | fi 37 | 38 | EXPE_NAME="${expe_folder}/${expe_set}/e_${scene}" 39 | 40 | # you should change the checkpoints as it may depends on the batch size 41 | CHECKPOINTS=(29993) 42 | 43 | for iter in ${CHECKPOINTS[@]}; do 44 | torchrun --standalone --nnodes=1 --nproc-per-node=${nproc} render.py \ 45 | -s ${dataset_folder}/${scene} \ 46 | --images ${image_folder} \ 47 | --model_path ${EXPE_NAME} \ 48 | --iteration $iter \ 49 | --llffhold 8 \ 50 | --skip_train \ 51 | --bsz ${bsz} 52 | done 53 | 54 | python metrics.py \ 55 | --mode test \ 56 | --model_paths ${EXPE_NAME} 57 | 58 | rm -rf ${EXPE_NAME}/test 59 | done 60 | 61 | done 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /examples/mip360_4k/1g_1b_4k.sh: -------------------------------------------------------------------------------- 1 | 2 | if [ $# -ne 2 ]; then 3 | echo "Please specify exactly two arguments: the folder to save the experiments' log and checkpoints, and the folder of the dataset." 
4 | exit 1 5 | fi 6 | 7 | expe_folder=$1 8 | echo "The experiments will be saved in $expe_folder" 9 | dataset_folder=$2 10 | echo "The dataset is in $dataset_folder" 11 | 12 | SCENE=(garden bicycle) 13 | BSZ=1 14 | expe_set_name="1g_1b_4k" 15 | 16 | # Monitoring Settings 17 | monitor_opts="--enable_timer \ 18 | --end2end_time \ 19 | --check_gpu_memory \ 20 | --check_cpu_memory" 21 | 22 | # 4k reconstruction requires even denser sampling 23 | densify_opts="--densify_grad_threshold 0.0001 --percent_dense 0.002" 24 | 25 | for scene in ${SCENE[@]}; do 26 | expe_name="e_${scene}" 27 | image_folder="images" # train original resolution 28 | 29 | torchrun --standalone --nnodes=1 --nproc-per-node=1 train.py \ 30 | -s ${dataset_folder}/${scene} \ 31 | --images ${image_folder} \ 32 | --llffhold 8 \ 33 | --iterations 50000 \ 34 | --log_interval 250 \ 35 | --model_path ${expe_folder}/1g_1b_4k/${expe_name} \ 36 | --bsz $BSZ \ 37 | $monitor_opts \ 38 | --test_iterations 7000 15000 30000 50000 \ 39 | --save_iterations 7000 30000 50000 \ 40 | --preload_dataset_to_gpu_threshold 0 \ 41 | --eval 42 | done 43 | 44 | -------------------------------------------------------------------------------- /examples/mip360_4k/4g_1b_4k.sh: -------------------------------------------------------------------------------- 1 | 2 | if [ $# -ne 2 ]; then 3 | echo "Please specify exactly two arguments: the folder to save the experiments' log and checkpoints, and the folder of the dataset." 4 | exit 1 5 | fi 6 | 7 | expe_folder=$1 8 | echo "The experiments will be saved in $expe_folder" 9 | dataset_folder=$2 10 | echo "The dataset is in $dataset_folder" 11 | 12 | SCENE=(garden bicycle) 13 | BSZ=1 14 | expe_set_name="4g_1b_4k" 15 | 16 | # Monitoring Settings 17 | monitor_opts="--enable_timer \ 18 | --end2end_time \ 19 | --check_gpu_memory \ 20 | --check_cpu_memory" 21 | 22 | # 4k reconstruction requires even denser sampling 23 | densify_opts="--densify_grad_threshold 0.0001 --percent_dense 0.002" 24 | 25 | for scene in ${SCENE[@]}; do 26 | expe_name="e_${scene}" 27 | image_folder="images" # train original resolution 28 | 29 | torchrun --standalone --nnodes=1 --nproc-per-node=4 train.py \ 30 | -s ${dataset_folder}/${scene} \ 31 | --images ${image_folder} \ 32 | --llffhold 8 \ 33 | --iterations 50000 \ 34 | --log_interval 250 \ 35 | --model_path ${expe_folder}/4g_1b_4k/${expe_name} \ 36 | --bsz $BSZ \ 37 | $monitor_opts \ 38 | --test_iterations 7000 15000 30000 50000 \ 39 | --save_iterations 7000 30000 50000 \ 40 | --preload_dataset_to_gpu_threshold 0 \ 41 | --eval 42 | done 43 | 44 | -------------------------------------------------------------------------------- /examples/mip360_4k/analyze_results.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import os 4 | import sys 5 | 6 | def get_suffix_in_folder(folder): 7 | 8 | if not os.path.exists(folder): 9 | return None 10 | 11 | if not folder.endswith("/"): 12 | folder += "/" 13 | 14 | suffix_list_candidates = [] 15 | for ws in [1,2,4,8,16,32]: 16 | for rk in range(ws): 17 | suffix_list_candidates.append(f"ws={ws}_rk={rk}") 18 | 19 | suffix_list = [] 20 | for suffix in suffix_list_candidates: 21 | if os.path.exists(folder + "python_" + suffix + ".log"): 22 | suffix_list.append(suffix) 23 | 24 | return suffix_list 25 | 26 | 27 | def get_running_time_at_iterations(expe_folder, iterations): 28 | a_suffix = get_suffix_in_folder(expe_folder)[0] 29 | a_log_file = os.path.join(expe_folder, 
f"python_{a_suffix}.log") 30 | lines = open(a_log_file, "r").readlines() 31 | results = [] 32 | bsz = 1 33 | for line in lines: 34 | # bsz: 1 35 | if "bsz: " in line: 36 | bsz = int(line.split("bsz: ")[1]) 37 | # end2end total_time: 443.026 s, iterations: 7001, throughput 15.80 it/s 38 | if "end2end total_time:" not in line: 39 | continue 40 | iteration = int(line.split("iterations: ")[1].split(",")[0]) 41 | running_time = int(float(line.split("end2end total_time: ")[1].split(" s")[0])) 42 | 43 | for r in range(iteration-bsz, iteration): 44 | if r in iterations: 45 | results.append(running_time) 46 | break 47 | return results 48 | 49 | def get_test_psnr_at_iterations(expe_folder, iterations): 50 | a_suffix = get_suffix_in_folder(expe_folder)[0] 51 | a_log_file = os.path.join(expe_folder, f"python_{a_suffix}.log") 52 | lines = open(a_log_file, "r").readlines() 53 | results = [] 54 | bsz = 1 55 | for line in lines: 56 | # bsz: 1 57 | if "bsz: " in line: 58 | bsz = int(line.split("bsz: ")[1]) 59 | # [ITER 50000] Evaluating test: L1 0.01809605024755001 PSNR 29.30947494506836 60 | if "Evaluating test:" not in line: 61 | continue 62 | iteration = int(line.split("[ITER ")[1].split("]")[0]) 63 | L1 = float(line.split("L1 ")[1].split(" PSNR")[0]) 64 | PSNR = float(line.split("PSNR ")[1]) 65 | for r in range(iteration, iteration+bsz): 66 | if r in iterations: 67 | results.append(round(PSNR, 2)) 68 | return results 69 | 70 | def get_test_psnr_list_from_logfile(expe_folder): 71 | a_suffix = get_suffix_in_folder(expe_folder)[0] 72 | a_log_file = os.path.join(expe_folder, f"python_{a_suffix}.log") 73 | lines = open(a_log_file, "r").readlines() 74 | results = [] 75 | for line in lines: 76 | # [ITER 50000] Evaluating test: L1 0.01809605024755001 PSNR 29.30947494506836 77 | if "Evaluating test:" not in line: 78 | continue 79 | iteration = int(line.split("[ITER ")[1].split("]")[0]) 80 | L1 = float(line.split("L1 ")[1].split(" PSNR")[0]) 81 | PSNR = float(line.split("PSNR ")[1]) 82 | results.append({ 83 | "iteration": iteration, 84 | "L1": round(L1, 2), 85 | "PSNR": round(PSNR, 2) 86 | }) 87 | return results 88 | 89 | def get_max_memory_at_iterations(expe_folder, check_iterations): 90 | a_suffix = get_suffix_in_folder(expe_folder)[0] 91 | a_log_file = os.path.join(expe_folder, f"python_{a_suffix}.log") 92 | lines = open(a_log_file, "r").readlines() 93 | results = [] 94 | for line in lines: 95 | # iteration[50000,50001) densify_and_prune. Now num of 3dgs: 1322628. Now Memory usage: 1.4657483100891113 GB. Max Memory usage: 4.159056186676025 GB. Max Reserved Memory: 8.357421875 GB. Now Reserved Memory: 2.44921875 GB. 
96 | if "Max Reserved Memory: " not in line: 97 | continue 98 | iteration_start, iteration_end = line.split("iteration[")[1].split(")")[0].split(",") 99 | iteration_start, iteration_end = int(iteration_start), int(iteration_end) 100 | max_memory = float(line.split("Max Reserved Memory: ")[1].split(" GB")[0]) 101 | for r in range(iteration_start, iteration_end): 102 | if r in check_iterations: 103 | results.append(round(max_memory, 2)) 104 | return results 105 | 106 | def extract_from_mip360_all9scene(folder): 107 | # if os.path.exists(os.path.join(folder, "mip360_all9scene.json")): 108 | # print("mip360_all9scene.json already exists for ", folder) 109 | # return 110 | 111 | scene_names = [ 112 | "bicycle", 113 | "garden", 114 | ] 115 | check_iterations = [7000, 30000] 116 | results = {} 117 | for scene in scene_names: 118 | scene_folder = os.path.join(folder, "e_"+scene) 119 | if not os.path.exists(scene_folder): 120 | continue 121 | running_time_all = get_running_time_at_iterations(scene_folder, check_iterations) 122 | psnr_all = get_test_psnr_at_iterations(scene_folder, check_iterations) 123 | memory_all = get_max_memory_at_iterations(scene_folder, check_iterations) 124 | results[scene] = {} 125 | for iteration, running_time, psnr, max_memory in zip(check_iterations, running_time_all, psnr_all, memory_all): 126 | results[scene][iteration] = { 127 | "running_time": running_time, 128 | "psnr": psnr, 129 | "throughput": round(iteration/running_time, 2), 130 | "max_memory": max_memory 131 | } 132 | 133 | json.dump(results, open(os.path.join(folder, "mip360_all9scene.json"), "w"), indent=4) 134 | print("Generated mip360_all9scene.json for ", folder) 135 | 136 | def convert2readable(seconds): 137 | # 3h 30min 138 | hours = seconds // 3600 139 | minutes = (seconds % 3600) // 60 140 | return f"{hours}h {minutes}min" 141 | 142 | def plot_release_mip360_4k(expe_folder_base): 143 | 144 | expe_sets = [ 145 | "1g_1b_4k", 146 | "4g_1b_4k", 147 | ] 148 | for expe_set in expe_sets: 149 | extract_from_mip360_all9scene(f"{expe_folder_base}/{expe_set}/") 150 | 151 | all_scenes = [ 152 | "bicycle", 153 | "garden", 154 | ] 155 | 156 | df = pd.DataFrame(columns=["Configuration", "50k Training Time", "Memory Per GPU", "PSNR"]) 157 | 158 | for scene in all_scenes: 159 | for i, expe in enumerate(["1 GPU + Batch Size=1", "4 GPU + Batch Size=1"]): 160 | row = [scene + " + " + expe] 161 | 162 | expe_folder = os.path.join(expe_folder_base, expe_sets[i]) 163 | results = json.load(open(os.path.join(expe_folder, "mip360_all9scene.json"), "r")) 164 | 165 | row.append(convert2readable(results[scene]["30000"]["running_time"])) 166 | row.append(results[scene]["30000"]["max_memory"]) 167 | row.append(results[scene]["30000"]["psnr"]) 168 | df.loc[len(df)] = row 169 | df.to_markdown(os.path.join(expe_folder_base, f"mip360_4k_compare_{scene}.md"), index=False) 170 | 171 | if __name__ == "__main__": 172 | # read args from command line 173 | expe_folder_base = sys.argv[1] 174 | print("Expe base folder: ", expe_folder_base) 175 | 176 | plot_release_mip360_4k(expe_folder_base) 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /examples/mip360_4k/eval_mip360_4k.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | expe_folder_base=$1 # specify the folder to save these experiments log and checkpoints 4 | dataset_folder=$2 # specify the dataset folder 5 | 6 | # Train all scenes on single GPU with batch size 1 7 | bash 
examples/mip360_4k/1g_1b_4k.sh $expe_folder_base $dataset_folder 8 | 9 | # Train all scenes distributed across 4 GPUs with batch size 1 10 | bash examples/mip360_4k/4g_1b_4k.sh $expe_folder_base $dataset_folder 11 | 12 | # analyze the results from logs and generate the result table 13 | python examples/mip360_4k/analyze_results.py $expe_folder_base
-------------------------------------------------------------------------------- /examples/train_truck_1k/analyze_results.py: --------------------------------------------------------------------------------
1 | import json 2 | import pandas as pd 3 | import os 4 | import sys 5 | 6 | def get_suffix_in_folder(folder): 7 | 8 | if not os.path.exists(folder): 9 | return None 10 | 11 | if not folder.endswith("/"): 12 | folder += "/" 13 | 14 | suffix_list_candidates = [] 15 | for ws in [1,2,4,8,16,32]: 16 | for rk in range(ws): 17 | suffix_list_candidates.append(f"ws={ws}_rk={rk}") 18 | 19 | suffix_list = [] 20 | for suffix in suffix_list_candidates: 21 | if os.path.exists(folder + "python_" + suffix + ".log"): 22 | suffix_list.append(suffix) 23 | 24 | return suffix_list 25 | 26 | 27 | def get_running_time_at_iterations(expe_folder, iterations): 28 | a_suffix = get_suffix_in_folder(expe_folder)[0] 29 | a_log_file = os.path.join(expe_folder, f"python_{a_suffix}.log") 30 | lines = open(a_log_file, "r").readlines() 31 | results = [] 32 | bsz = 1 33 | for line in lines: 34 | # bsz: 1 35 | if "bsz: " in line: 36 | bsz = int(line.split("bsz: ")[1]) 37 | # end2end total_time: 443.026 s, iterations: 7001, throughput 15.80 it/s 38 | if "end2end total_time:" not in line: 39 | continue 40 | iteration = int(line.split("iterations: ")[1].split(",")[0]) 41 | running_time = int(float(line.split("end2end total_time: ")[1].split(" s")[0])) 42 | 43 | for r in range(iteration-bsz, iteration): 44 | if r in iterations: 45 | results.append(running_time) 46 | break 47 | return results 48 |
49 | def get_test_psnr_at_iterations(expe_folder, iterations): 50 | a_suffix = get_suffix_in_folder(expe_folder)[0] 51 | a_log_file = os.path.join(expe_folder, f"python_{a_suffix}.log") 52 | lines = open(a_log_file, "r").readlines() 53 | results = [] 54 | bsz = 1 55 | for line in lines: 56 | # bsz: 1 57 | if "bsz: " in line: 58 | bsz = int(line.split("bsz: ")[1]) 59 | # [ITER 50000] Evaluating test: L1 0.01809605024755001 PSNR 29.30947494506836 60 | if "Evaluating test:" not in line: 61 | continue 62 | iteration = int(line.split("[ITER ")[1].split("]")[0]) 63 | L1 = float(line.split("L1 ")[1].split(" PSNR")[0]) 64 | PSNR = float(line.split("PSNR ")[1]) 65 | for r in range(iteration, iteration+bsz): 66 | if r in iterations: 67 | results.append(round(PSNR, 2)) 68 | return results 69 |
70 | def get_test_psnr_list_from_logfile(expe_folder): 71 | a_suffix = get_suffix_in_folder(expe_folder)[0] 72 | a_log_file = os.path.join(expe_folder, f"python_{a_suffix}.log") 73 | lines = open(a_log_file, "r").readlines() 74 | results = [] 75 | for line in lines: 76 | # [ITER 50000] Evaluating test: L1 0.01809605024755001 PSNR 29.30947494506836 77 | if "Evaluating test:" not in line: 78 | continue 79 | iteration = int(line.split("[ITER ")[1].split("]")[0]) 80 | L1 = float(line.split("L1 ")[1].split(" PSNR")[0]) 81 | PSNR = float(line.split("PSNR ")[1]) 82 | results.append({ 83 | "iteration": iteration, 84 | "L1": round(L1, 2), 85 | "PSNR": round(PSNR, 2) 86 | }) 87 | return results 88 |
89 | def get_max_memory_at_iterations(expe_folder, check_iterations): 90 | a_suffix = get_suffix_in_folder(expe_folder)[0] 91 |
a_log_file = os.path.join(expe_folder, f"python_{a_suffix}.log") 92 | lines = open(a_log_file, "r").readlines() 93 | results = [] 94 | for line in lines: 95 | # iteration[50000,50001) densify_and_prune. Now num of 3dgs: 1322628. Now Memory usage: 1.4657483100891113 GB. Max Memory usage: 4.159056186676025 GB. Max Reserved Memory: 8.357421875 GB. Now Reserved Memory: 2.44921875 GB. 96 | if "Max Reserved Memory: " not in line: 97 | continue 98 | iteration_start, iteration_end = line.split("iteration[")[1].split(")")[0].split(",") 99 | iteration_start, iteration_end = int(iteration_start), int(iteration_end) 100 | max_memory = float(line.split("Max Reserved Memory: ")[1].split(" GB")[0]) 101 | for r in range(iteration_start, iteration_end): 102 | if r in check_iterations: 103 | results.append(round(max_memory, 2)) 104 | return results 105 | 106 | def extract_from_tandb_scene(folder): 107 | 108 | scene_names = [ 109 | "train", 110 | "truck", 111 | ] 112 | check_iterations = [7000, 30000] 113 | results = {} 114 | for scene in scene_names: 115 | scene_folder = os.path.join(folder, "e_"+scene) 116 | if not os.path.exists(scene_folder): 117 | continue 118 | running_time_all = get_running_time_at_iterations(scene_folder, check_iterations) 119 | psnr_all = get_test_psnr_at_iterations(scene_folder, check_iterations) 120 | memory_all = get_max_memory_at_iterations(scene_folder, check_iterations) 121 | results[scene] = {} 122 | for iteration, running_time, psnr, max_memory in zip(check_iterations, running_time_all, psnr_all, memory_all): 123 | results[scene][iteration] = { 124 | "running_time": running_time, 125 | "psnr": psnr, 126 | "throughput": round(iteration/running_time, 2), 127 | "max_memory": max_memory 128 | } 129 | 130 | json.dump(results, open(os.path.join(folder, "tandb_scenes.json"), "w"), indent=4) 131 | print("Generated tandb_scenes.json for ", folder) 132 | 133 | def convert2readable(seconds): 134 | if seconds < 60: 135 | return f"{seconds}s" 136 | if seconds < 3600: 137 | return f"{seconds//60}min {seconds%60}s" 138 | # 3h 30min 139 | hours = seconds // 3600 140 | minutes = (seconds % 3600) // 60 141 | return f"{hours}h {minutes}min" 142 | 143 | def plot_release_tandb(expe_folder_base): 144 | 145 | expe_sets = [ 146 | # "1g_1b", 147 | "4g_8b", 148 | # "4g_16b", 149 | ] 150 | for expe_set in expe_sets: 151 | extract_from_tandb_scene(f"{expe_folder_base}/{expe_set}/") 152 | 153 | all_scenes = [ 154 | "train", 155 | "truck", 156 | ] 157 | 158 | df = pd.DataFrame(columns=["Configuration", "7k Training Time", "7k test PSNR", "30k Training Time", "30k test PSNR"]) 159 | 160 | iterations_to_check = ["7000", "30000"] 161 | 162 | for scene in all_scenes: 163 | for i, expe in enumerate(["4 GPU + Batch Size=8"]): 164 | # for i, expe in enumerate(["1 GPU + Batch Size=1", "4 GPU + Batch Size=8", "4 GPU + Batch Size=16"]): 165 | row = [scene + " + " + expe] 166 | expe_folder = os.path.join(expe_folder_base, expe_sets[i]) 167 | results = json.load(open(os.path.join(expe_folder, "tandb_scenes.json"), "r")) 168 | 169 | for iter in iterations_to_check: 170 | row.append(convert2readable(results[scene][iter]["running_time"])) 171 | row.append(results[scene][iter]["psnr"]) 172 | df.loc[len(df)] = row 173 | df.to_markdown(os.path.join(expe_folder_base, f"tandb_compare_{scene}.md"), index=False) 174 | 175 | if __name__ == "__main__": 176 | # read args from command line 177 | expe_folder_base = sys.argv[1] 178 | print("Expe base folder: ", expe_folder_base) 179 | 180 | plot_release_tandb(expe_folder_base) 181 
| 182 | 183 | 184 | -------------------------------------------------------------------------------- /examples/train_truck_1k/eval_train_truck_1k.sh: --------------------------------------------------------------------------------
1 | 2 | 3 | expe_folder_base=$1 # specify the folder to save these experiments' logs and checkpoints 4 | dataset_folder=$2 # specify the dataset folder 5 | 6 | # Train both scenes (train, truck) on 4 GPUs with batch size 8 (see train_truck_1k.sh) 7 | bash examples/train_truck_1k/train_truck_1k.sh $expe_folder_base ${dataset_folder}/tandt 8 | 9 | # analyze the results from logs and generate the result table 10 | python examples/train_truck_1k/analyze_results.py $expe_folder_base
-------------------------------------------------------------------------------- /examples/train_truck_1k/train_truck_1k.sh: --------------------------------------------------------------------------------
1 | 2 | if [ $# -ne 2 ]; then 3 | echo "Please specify exactly two arguments: the folder to save the experiments' log and checkpoints, and the folder of the dataset." 4 | exit 1 5 | fi 6 | 7 | expe_folder=$1 8 | echo "The experiments will be saved in $expe_folder" 9 | dataset_folder=$2 10 | echo "The dataset is in $dataset_folder" 11 | 12 | # scenes to be trained 13 | SCENE=(train truck) 14 | # batch size 15 | BSZ=8 16 | image_folder="images" # image folder to train on (assumption: the original-resolution images, as in the other example scripts) 17 | # Monitoring Settings 18 | monitor_opts="--enable_timer \ 19 | --end2end_time \ 20 | --check_gpu_memory \ 21 | --check_cpu_memory" 22 | 23 | for scene in ${SCENE[@]}; do 24 | expe_name="e_${scene}" 25 | 26 | torchrun --standalone --nnodes=1 --nproc-per-node=4 train.py \ 27 | -s ${dataset_folder}/${scene} \ 28 | --images ${image_folder} \ 29 | --llffhold 8 \ 30 | --iterations 30000 \ 31 | --log_interval 250 \ 32 | --model_path ${expe_folder}/4g_8b/${expe_name} \ 33 | --bsz $BSZ \ 34 | $monitor_opts \ 35 | --test_iterations 7000 15000 30000 \ 36 | --save_iterations 7000 30000 \ 37 | --eval 38 | done 39 |
-------------------------------------------------------------------------------- /gaussian_renderer/distribution_config.py: --------------------------------------------------------------------------------
1 | from collections import namedtuple 2 | import utils.general_utils as utils 3 | 4 | ImageDistributionConfig = namedtuple( 5 | "ImageDistributionConfig", 6 | [ 7 | "loss_distribution_mode", 8 | "workloads_division_mode", 9 | "avoid_pixels_all2all", 10 | "local_running_time_mode", 11 | ], 12 | ) 13 | 14 | 15 | def init_image_distribution_config(args): 16 | 17 | if args.image_distribution_mode == "0": 18 | args.image_distribution_config = ImageDistributionConfig( 19 | loss_distribution_mode="replicated_loss_computation", 20 | workloads_division_mode="DivisionStrategyUniform", 21 | avoid_pixels_all2all=False, 22 | local_running_time_mode=["backward_render_time"], 23 | ) 24 | 25 | elif args.image_distribution_mode == "1": 26 | args.image_distribution_config = ImageDistributionConfig( 27 | loss_distribution_mode="general_distributed_loss_computation", 28 | workloads_division_mode="DivisionStrategyUniform", 29 | avoid_pixels_all2all=False, 30 | local_running_time_mode=["backward_render_time"], 31 | ) 32 | 33 | elif args.image_distribution_mode == "2": 34 | args.image_distribution_config = ImageDistributionConfig( 35 | loss_distribution_mode="general_distributed_loss_computation", 36 | workloads_division_mode="DivisionStrategyDynamicAdjustment", 37 | avoid_pixels_all2all=False, 38 | local_running_time_mode=["backward_render_time"], 39 | ) 40 | 41 | elif args.image_distribution_mode == "3": 
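# Mode 3 (configured below): distributed loss computation that avoids the pixel-wise all2all exchange, with the per-GPU workload division adjusted dynamically from the measured render/loss timings listed in local_running_time_mode.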
args.image_distribution_config = ImageDistributionConfig( 43 | loss_distribution_mode="avoid_pixel_all2all_loss_computation", 44 | workloads_division_mode="DivisionStrategyDynamicAdjustment", 45 | avoid_pixels_all2all=True, 46 | local_running_time_mode=[ 47 | "backward_render_time", 48 | "forward_render_time", 49 | "forward_loss_time", 50 | "forward_loss_time", 51 | ], 52 | ) 53 | 54 | elif args.image_distribution_mode == "4": 55 | args.image_distribution_config = ImageDistributionConfig( 56 | loss_distribution_mode="avoid_pixel_all2all_loss_computation_adjust_mode6", 57 | workloads_division_mode="DivisionStrategyAsGrid", 58 | avoid_pixels_all2all=True, 59 | local_running_time_mode=[ 60 | "backward_render_time", 61 | "forward_render_time", 62 | "forward_loss_time", 63 | "forward_loss_time", 64 | ], 65 | ) 66 | 67 | else: 68 | raise ValueError( 69 | f"Unknown image_distribution_mode: {args.image_distribution_mode}" 70 | ) 71 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | 27 | def init(wish_host, wish_port): 28 | global host, port, listener 29 | host = wish_host 30 | port = wish_port 31 | listener.bind((host, port)) 32 | listener.listen() 33 | listener.settimeout(0) 34 | 35 | 36 | def try_connect(): 37 | global conn, addr, listener 38 | try: 39 | conn, addr = listener.accept() 40 | print(f"\nConnected by {addr}") 41 | conn.settimeout(None) 42 | except Exception as inst: 43 | pass 44 | 45 | 46 | def read(): 47 | global conn 48 | messageLength = conn.recv(4) 49 | messageLength = int.from_bytes(messageLength, "little") 50 | message = conn.recv(messageLength) 51 | return json.loads(message.decode("utf-8")) 52 | 53 | 54 | def send(message_bytes, verify): 55 | global conn 56 | if message_bytes != None: 57 | conn.sendall(message_bytes) 58 | conn.sendall(len(verify).to_bytes(4, "little")) 59 | conn.sendall(bytes(verify, "ascii")) 60 | 61 | 62 | def receive(): 63 | message = read() 64 | 65 | width = message["resolution_x"] 66 | height = message["resolution_y"] 67 | 68 | if width != 0 and height != 0: 69 | try: 70 | do_training = bool(message["train"]) 71 | fovy = message["fov_y"] 72 | fovx = message["fov_x"] 73 | znear = message["z_near"] 74 | zfar = message["z_far"] 75 | do_shs_python = bool(message["shs_python"]) 76 | do_rot_scale_python = bool(message["rot_scale_python"]) 77 | keep_alive = bool(message["keep_alive"]) 78 | scaling_modifier = message["scaling_modifier"] 79 | world_view_transform = torch.reshape( 80 | torch.tensor(message["view_matrix"]), (4, 4) 81 | ).cuda() 82 | world_view_transform[:, 1] = -world_view_transform[:, 1] 83 | world_view_transform[:, 2] = -world_view_transform[:, 2] 84 | full_proj_transform = torch.reshape( 85 | torch.tensor(message["view_projection_matrix"]), (4, 4) 86 | ).cuda() 87 | full_proj_transform[:, 1] = 
-full_proj_transform[:, 1] 88 | custom_cam = MiniCam( 89 | width, 90 | height, 91 | fovy, 92 | fovx, 93 | znear, 94 | zfar, 95 | world_view_transform, 96 | full_proj_transform, 97 | ) 98 | except Exception as e: 99 | print("") 100 | traceback.print_exc() 101 | raise e 102 | return ( 103 | custom_cam, 104 | do_training, 105 | do_shs_python, 106 | do_rot_scale_python, 107 | keep_alive, 108 | scaling_modifier, 109 | ) 110 | else: 111 | return None, None, None, None, None, None 112 | -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips( 7 | x: torch.Tensor, y: torch.Tensor, net_type: str = "alex", version: str = "0.1" 8 | ): 9 | r"""Function that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | x, y (torch.Tensor): the input tensors to compare. 14 | net_type (str): the network type to compare the features: 15 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 16 | version (str): the version of LPIPS. Default: 0.1. 17 | """ 18 | device = x.device 19 | criterion = LPIPS(net_type, version).to(device) 20 | return criterion(x, y) 21 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 
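Example (illustrative usage): criterion = LPIPS(net_type='vgg').to(device), then dist = criterion(x, y), where x and y are (N, 3, H, W) image batches on the same device.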
16 | """ 17 | 18 | def __init__(self, net_type: str = "alex", version: str = "0.1"): 19 | 20 | assert version in ["0.1"], "v0.1 is only supported now" 21 | 22 | super(LPIPS, self).__init__() 23 | 24 | # pretrained network 25 | self.net = get_network(net_type) 26 | 27 | # linear layers 28 | self.lin = LinLayers(self.net.n_channels_list) 29 | self.lin.load_state_dict(get_state_dict(net_type, version)) 30 | 31 | def forward(self, x: torch.Tensor, y: torch.Tensor): 32 | feat_x, feat_y = self.net(x), self.net(y) 33 | 34 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 35 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 36 | 37 | return torch.sum(torch.cat(res, 0), 0, True) 38 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == "alex": 14 | return AlexNet() 15 | elif net_type == "squeeze": 16 | return SqueezeNet() 17 | elif net_type == "vgg": 18 | return VGG16() 19 | else: 20 | raise NotImplementedError("choose net_type from [alex, squeeze, vgg].") 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__( 26 | [ 27 | nn.Sequential(nn.Identity(), nn.Conv2d(nc, 1, 1, 1, 0, bias=False)) 28 | for nc in n_channels_list 29 | ] 30 | ) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | "mean", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 43 | ) 44 | self.register_buffer( 45 | "std", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 46 | ) 47 | 48 | def set_requires_grad(self, state: bool): 49 | for param in chain(self.parameters(), self.buffers()): 50 | param.requires_grad = state 51 | 52 | def z_score(self, x: torch.Tensor): 53 | return (x - self.mean) / self.std 54 | 55 | def forward(self, x: torch.Tensor): 56 | x = self.z_score(x) 57 | 58 | output = [] 59 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 60 | x = layer(x) 61 | if i in self.target_layers: 62 | output.append(normalize_activation(x)) 63 | if len(output) == len(self.target_layers): 64 | break 65 | return output 66 | 67 | 68 | class SqueezeNet(BaseNet): 69 | def __init__(self): 70 | super(SqueezeNet, self).__init__() 71 | 72 | self.layers = models.squeezenet1_1(True).features 73 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 74 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 75 | 76 | self.set_requires_grad(False) 77 | 78 | 79 | class AlexNet(BaseNet): 80 | def __init__(self): 81 | super(AlexNet, self).__init__() 82 | 83 | self.layers = models.alexnet(True).features 84 | self.target_layers = [2, 5, 8, 10, 12] 85 | self.n_channels_list = [64, 192, 384, 256, 256] 86 | 87 | self.set_requires_grad(False) 88 | 89 | 90 | class VGG16(BaseNet): 91 | def __init__(self): 92 | super(VGG16, self).__init__() 93 | 94 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 95 | self.target_layers = [4, 9, 16, 23, 30] 96 | self.n_channels_list = [64, 128, 256, 512, 512] 97 | 98 | 
self.set_requires_grad(False) 99 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = "alex", version: str = "0.1"): 12 | # build url 13 | url = ( 14 | "https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/" 15 | + f"master/lpips/weights/v{version}/{net_type}.pth" 16 | ) 17 | 18 | # download 19 | old_state_dict = torch.hub.load_state_dict_from_url( 20 | url, 21 | progress=True, 22 | map_location=None if torch.cuda.is_available() else torch.device("cpu"), 23 | ) 24 | 25 | # rename keys 26 | new_state_dict = OrderedDict() 27 | for key, val in old_state_dict.items(): 28 | new_key = key 29 | new_key = new_key.replace("lin", "") 30 | new_key = new_key.replace("model.", "") 31 | new_state_dict[new_key] = val 32 | 33 | return new_state_dict 34 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | from utils.general_utils import set_args 19 | from lpipsPyTorch import lpips 20 | import json 21 | from tqdm import tqdm 22 | from utils.image_utils import psnr 23 | from argparse import ArgumentParser 24 | 25 | 26 | def readImages(renders_dir, gt_dir): 27 | print("Reading images from", renders_dir) 28 | renders = [] 29 | gts = [] 30 | image_names = [] 31 | for fname in os.listdir(renders_dir): 32 | render = Image.open(renders_dir / fname) 33 | gt = Image.open(gt_dir / fname) 34 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 35 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 36 | image_names.append(fname) 37 | return renders, gts, image_names 38 | 39 | 40 | def evaluate(model_paths, mode): 41 | 42 | full_dict = {} 43 | per_view_dict = {} 44 | full_dict_polytopeonly = {} 45 | per_view_dict_polytopeonly = {} 46 | print("") 47 | 48 | for scene_dir in model_paths: 49 | try: 50 | print("Scene:", scene_dir) 51 | full_dict[scene_dir] = {} 52 | per_view_dict[scene_dir] = {} 53 | full_dict_polytopeonly[scene_dir] = {} 54 | per_view_dict_polytopeonly[scene_dir] = {} 55 | 56 | test_dir = Path(scene_dir) / mode 57 | 58 | for method in os.listdir(test_dir): 59 | print("Method:", method) 60 | 61 | full_dict[scene_dir][method] = {} 62 | per_view_dict[scene_dir][method] = {} 63 | full_dict_polytopeonly[scene_dir][method] = {} 64 | per_view_dict_polytopeonly[scene_dir][method] = {} 65 | 66 | method_dir = test_dir / method 67 | gt_dir = method_dir / "gt" 68 | renders_dir = method_dir / "renders" 69 | renders, gts, image_names = readImages(renders_dir, gt_dir) 70 | 71 | print("Number of renders images:", len(renders)) 72 | print("Number of gt 
images:", len(gts)) 73 | ssims = [] 74 | psnrs = [] 75 | lpipss = [] 76 | 77 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 78 | ssims.append(ssim(renders[idx], gts[idx])) 79 | psnrs.append(psnr(renders[idx], gts[idx])) 80 | lpipss.append(lpips(renders[idx], gts[idx], net_type="vgg")) 81 | 82 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 83 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 84 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 85 | print("") 86 | 87 | full_dict[scene_dir][method].update( 88 | { 89 | "SSIM": torch.tensor(ssims).mean().item(), 90 | "PSNR": torch.tensor(psnrs).mean().item(), 91 | "LPIPS": torch.tensor(lpipss).mean().item(), 92 | } 93 | ) 94 | per_view_dict[scene_dir][method].update( 95 | { 96 | "SSIM": { 97 | name: ssim 98 | for ssim, name in zip( 99 | torch.tensor(ssims).tolist(), image_names 100 | ) 101 | }, 102 | "PSNR": { 103 | name: psnr 104 | for psnr, name in zip( 105 | torch.tensor(psnrs).tolist(), image_names 106 | ) 107 | }, 108 | "LPIPS": { 109 | name: lp 110 | for lp, name in zip( 111 | torch.tensor(lpipss).tolist(), image_names 112 | ) 113 | }, 114 | } 115 | ) 116 | 117 | with open(scene_dir + f"/results_{mode}.json", "w") as fp: 118 | json.dump(full_dict[scene_dir], fp, indent=True) 119 | with open(scene_dir + f"/per_view_{mode}.json", "w") as fp: 120 | json.dump(per_view_dict[scene_dir], fp, indent=True) 121 | except: 122 | print("Unable to compute metrics for model", scene_dir) 123 | 124 | 125 | if __name__ == "__main__": 126 | device = torch.device("cuda:0") 127 | torch.cuda.set_device(device) 128 | 129 | # Set up command line argument parser 130 | parser = ArgumentParser(description="Training script parameters") 131 | parser.add_argument( 132 | "--model_paths", "-m", required=True, nargs="+", type=str, default=[] 133 | ) 134 | parser.add_argument( 135 | "--mode", 136 | type=str, 137 | choices=["train", "test"], 138 | default="train", 139 | help="train or test", 140 | ) 141 | args = parser.parse_args() 142 | 143 | set_args(args) 144 | evaluate(args.model_paths, args.mode) 145 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.distributed as dist 14 | from scene import Scene, SceneDataset 15 | import os 16 | from tqdm import tqdm 17 | from os import makedirs 18 | from gaussian_renderer import ( 19 | # preprocess3dgs_and_all2all, 20 | # render 21 | distributed_preprocess3dgs_and_all2all_final, 22 | render_final, 23 | ) 24 | import torchvision 25 | from utils.general_utils import ( 26 | safe_state, 27 | set_args, 28 | init_distributed, 29 | set_log_file, 30 | set_cur_iter, 31 | ) 32 | from argparse import ArgumentParser 33 | from arguments import ModelParams, PipelineParams, get_combined_args 34 | from gaussian_renderer import GaussianModel 35 | from gaussian_renderer.loss_distribution import load_camera_from_cpu_to_all_gpu_for_eval 36 | from gaussian_renderer.workload_division import ( 37 | start_strategy_final, 38 | DivisionStrategyHistoryFinal, 39 | ) 40 | from arguments import ( 41 | AuxiliaryParams, 42 | ModelParams, 43 | PipelineParams, 44 | OptimizationParams, 45 | DistributionParams, 46 | BenchmarkParams, 47 | DebugParams, 48 | print_all_args, 49 | init_args, 50 | ) 51 | import utils.general_utils as utils 52 | 53 | 54 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background): 55 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 56 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 57 | 58 | makedirs(render_path, exist_ok=True) 59 | makedirs(gts_path, exist_ok=True) 60 | 61 | dataset = SceneDataset(views) 62 | 63 | set_cur_iter(iteration) 64 | generated_cnt = 0 65 | 66 | num_cameras = len(views) 67 | strategy_history = DivisionStrategyHistoryFinal( 68 | dataset, utils.DEFAULT_GROUP.size(), utils.DEFAULT_GROUP.rank() 69 | ) 70 | progress_bar = tqdm( 71 | range(1, num_cameras + 1), 72 | desc="Rendering progress", 73 | disable=(utils.LOCAL_RANK != 0), 74 | ) 75 | for idx in range(1, num_cameras + 1, args.bsz): 76 | progress_bar.update(args.bsz) 77 | 78 | num_camera_to_load = min(args.bsz, num_cameras - idx + 1) 79 | batched_cameras = dataset.get_batched_cameras(num_camera_to_load) 80 | batched_strategies, gpuid2tasks = start_strategy_final( 81 | batched_cameras, strategy_history 82 | ) 83 | load_camera_from_cpu_to_all_gpu_for_eval( 84 | batched_cameras, batched_strategies, gpuid2tasks 85 | ) 86 | 87 | batched_screenspace_pkg = distributed_preprocess3dgs_and_all2all_final( 88 | batched_cameras, 89 | gaussians, 90 | pipeline, 91 | background, 92 | batched_strategies=batched_strategies, 93 | mode="test", 94 | ) 95 | batched_image, _ = render_final(batched_screenspace_pkg, batched_strategies) 96 | 97 | for camera_id, (image, gt_camera) in enumerate( 98 | zip(batched_image, batched_cameras) 99 | ): 100 | actual_idx = idx + camera_id 101 | if args.sample_freq != -1 and actual_idx % args.sample_freq != 0: 102 | continue 103 | if generated_cnt == args.generate_num: 104 | break 105 | if os.path.exists( 106 | os.path.join(render_path, "{0:05d}".format(actual_idx) + ".png") 107 | ): 108 | continue 109 | if args.l != -1 and args.r != -1: 110 | if actual_idx < args.l or actual_idx >= args.r: 111 | continue 112 | 113 | generated_cnt += 1 114 | 115 | if ( 116 | image is None or len(image.shape) == 0 117 | ): # The image is not rendered locally. 
118 | image = torch.zeros( 119 | gt_camera.original_image.shape, device="cuda", dtype=torch.float32 120 | ) 121 | 122 | if utils.DEFAULT_GROUP.size() > 1: 123 | torch.distributed.all_reduce( 124 | image, op=dist.ReduceOp.SUM, group=utils.DEFAULT_GROUP 125 | ) 126 | 127 | image = torch.clamp(image, 0.0, 1.0) 128 | gt_image = torch.clamp(gt_camera.original_image / 255.0, 0.0, 1.0) 129 | 130 | if utils.GLOBAL_RANK == 0: 131 | torchvision.utils.save_image( 132 | image, 133 | os.path.join(render_path, "{0:05d}".format(actual_idx) + ".png"), 134 | ) 135 | torchvision.utils.save_image( 136 | gt_image, 137 | os.path.join(gts_path, "{0:05d}".format(actual_idx) + ".png"), 138 | ) 139 | 140 | gt_camera.original_image = None 141 | 142 | if generated_cnt == args.generate_num: 143 | break 144 | 145 | 146 | def render_sets( 147 | dataset: ModelParams, 148 | iteration: int, 149 | pipeline: PipelineParams, 150 | skip_train: bool, 151 | skip_test: bool, 152 | ): 153 | with torch.no_grad(): 154 | args = utils.get_args() 155 | gaussians = GaussianModel(dataset.sh_degree) 156 | scene = Scene(args, gaussians, load_iteration=iteration, shuffle=False) 157 | 158 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 159 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 160 | 161 | if not skip_train: 162 | render_set( 163 | dataset.model_path, 164 | "train", 165 | scene.loaded_iter, 166 | scene.getTrainCameras(), 167 | gaussians, 168 | pipeline, 169 | background, 170 | ) 171 | 172 | if not skip_test: 173 | render_set( 174 | dataset.model_path, 175 | "test", 176 | scene.loaded_iter, 177 | scene.getTestCameras(), 178 | gaussians, 179 | pipeline, 180 | background, 181 | ) 182 | 183 | 184 | if __name__ == "__main__": 185 | # Set up command line argument parser 186 | parser = ArgumentParser(description="Training script parameters") 187 | ap = AuxiliaryParams(parser) 188 | lp = ModelParams(parser, sentinel=True) 189 | op = OptimizationParams(parser) 190 | pp = PipelineParams(parser) 191 | dist_p = DistributionParams(parser) 192 | bench_p = BenchmarkParams(parser) 193 | debug_p = DebugParams(parser) 194 | parser.add_argument("--iteration", default=-1, type=int) 195 | parser.add_argument("--skip_train", action="store_true") 196 | parser.add_argument("--skip_test", action="store_true") 197 | parser.add_argument("--generate_num", default=-1, type=int) 198 | parser.add_argument("--sample_freq", default=-1, type=int) 199 | parser.add_argument("--distributed_load", action="store_true") # TODO: delete this. 200 | parser.add_argument("--l", default=-1, type=int) 201 | parser.add_argument("--r", default=-1, type=int) 202 | args = get_combined_args(parser) 203 | print("Rendering " + args.model_path) 204 | init_distributed(args) 205 | # This script only supports single-gpu rendering. 206 | # I need to put the flags here because the render() function need it. 207 | # However, disable them during render.py because they are only needed during training. 208 | 209 | log_file = open( 210 | args.model_path 211 | + f"/render_ws={utils.DEFAULT_GROUP.size()}_rk_{utils.DEFAULT_GROUP.rank()}.log", 212 | "w", 213 | ) 214 | set_log_file(log_file) 215 | 216 | ## Prepare arguments. 
217 | # Check arguments 218 | init_args(args) 219 | if args.skip_train: 220 | args.num_train_cameras = 0 221 | if args.skip_test: 222 | args.num_test_cameras = 0 223 | # Set up global args 224 | set_args(args) 225 | 226 | print_all_args(args, log_file) 227 | 228 | if utils.WORLD_SIZE > 1: 229 | torch.distributed.barrier(group=utils.DEFAULT_GROUP) 230 | # Initialize system state (RNG) 231 | safe_state(args.quiet) 232 | 233 | render_sets( 234 | lp.extract(args), 235 | args.iteration, 236 | pp.extract(args), 237 | args.skip_train, 238 | args.skip_test, 239 | ) 240 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from random import randint 16 | from utils.system_utils import searchForMaxIteration 17 | from scene.dataset_readers import sceneLoadTypeCallbacks 18 | from scene.gaussian_model import GaussianModel 19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 20 | import utils.general_utils as utils 21 | import torch 22 | 23 | 24 | class Scene: 25 | 26 | gaussians: GaussianModel 27 | 28 | def __init__( 29 | self, args, gaussians: GaussianModel, load_iteration=None, shuffle=True 30 | ): 31 | """b 32 | :param path: Path to colmap scene main folder. 33 | """ 34 | self.model_path = args.model_path 35 | self.loaded_iter = None 36 | self.gaussians = gaussians 37 | log_file = utils.get_log_file() 38 | 39 | if load_iteration: 40 | if load_iteration == -1: 41 | self.loaded_iter = searchForMaxIteration( 42 | os.path.join(self.model_path, "point_cloud") 43 | ) 44 | else: 45 | self.loaded_iter = load_iteration 46 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 47 | 48 | utils.log_cpu_memory_usage("before loading images meta data") 49 | 50 | if os.path.exists( 51 | os.path.join(args.source_path, "sparse") 52 | ): # This is the format from colmap. 
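# (COLMAP reconstructions are detected by the presence of a "sparse/" subdirectory holding the cameras/images/points3D files.)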
53 | scene_info = sceneLoadTypeCallbacks["Colmap"]( 54 | args.source_path, args.images, args.eval, args.llffhold 55 | ) 56 | elif "matrixcity" in args.source_path: # This is for matrixcity 57 | scene_info = sceneLoadTypeCallbacks["City"]( 58 | args.source_path, 59 | args.random_background, 60 | args.white_background, 61 | llffhold=args.llffhold, 62 | ) 63 | else: 64 | raise ValueError("No valid dataset found in the source path") 65 | 66 | if not self.loaded_iter: 67 | with open(scene_info.ply_path, "rb") as src_file, open( 68 | os.path.join(self.model_path, "input.ply"), "wb" 69 | ) as dest_file: 70 | dest_file.write(src_file.read()) 71 | json_cams = [] 72 | camlist = [] 73 | if scene_info.test_cameras: 74 | camlist.extend(scene_info.test_cameras) 75 | if scene_info.train_cameras: 76 | camlist.extend(scene_info.train_cameras) 77 | for id, cam in enumerate(camlist): 78 | json_cams.append(camera_to_JSON(id, cam)) 79 | with open(os.path.join(self.model_path, "cameras.json"), "w") as file: 80 | json.dump(json_cams, file) 81 | 82 | if shuffle: 83 | random.shuffle( 84 | scene_info.train_cameras 85 | ) # Multi-res consistent random shuffling 86 | random.shuffle( 87 | scene_info.test_cameras 88 | ) # Multi-res consistent random shuffling 89 | 90 | utils.log_cpu_memory_usage("before decoding images") 91 | 92 | self.cameras_extent = scene_info.nerf_normalization["radius"] 93 | 94 | # Set image size to global variable 95 | orig_w, orig_h = ( 96 | scene_info.train_cameras[0].width, 97 | scene_info.train_cameras[0].height, 98 | ) 99 | utils.set_img_size(orig_h, orig_w) 100 | # Dataset size in GB 101 | dataset_size_in_GB = ( 102 | 1.0 103 | * (len(scene_info.train_cameras) + len(scene_info.test_cameras)) 104 | * orig_w 105 | * orig_h 106 | * 3 107 | / 1e9 108 | ) 109 | log_file.write(f"Dataset size: {dataset_size_in_GB} GB\n") 110 | if ( 111 | dataset_size_in_GB < args.preload_dataset_to_gpu_threshold 112 | ): # 10GB memory limit for dataset 113 | log_file.write( 114 | f"[NOTE]: Preloading dataset({dataset_size_in_GB}GB) to GPU. Disable local_sampling and distributed_dataset_storage.\n" 115 | ) 116 | print( 117 | f"[NOTE]: Preloading dataset({dataset_size_in_GB}GB) to GPU. Disable local_sampling and distributed_dataset_storage." 118 | ) 119 | args.preload_dataset_to_gpu = True 120 | args.local_sampling = False # TODO: Preloading dataset to GPU is not compatible with local_sampling and distributed_dataset_storage for now. Fix this. 121 | args.distributed_dataset_storage = False 122 | 123 | # Train on original resolution, no downsampling in our implementation. 
124 | utils.print_rank_0("Decoding Training Cameras") 125 | self.train_cameras = None 126 | self.test_cameras = None 127 | if args.num_train_cameras >= 0: 128 | train_cameras = scene_info.train_cameras[: args.num_train_cameras] 129 | else: 130 | train_cameras = scene_info.train_cameras 131 | self.train_cameras = cameraList_from_camInfos(train_cameras, args) 132 | # output the number of cameras in the training set and image size to the log file 133 | log_file.write( 134 | "Number of local training cameras: {}\n".format(len(self.train_cameras)) 135 | ) 136 | if len(self.train_cameras) > 0: 137 | log_file.write( 138 | "Image size: {}x{}\n".format( 139 | self.train_cameras[0].image_height, 140 | self.train_cameras[0].image_width, 141 | ) 142 | ) 143 | 144 | if args.eval: 145 | utils.print_rank_0("Decoding Test Cameras") 146 | if args.num_test_cameras >= 0: 147 | test_cameras = scene_info.test_cameras[: args.num_test_cameras] 148 | else: 149 | test_cameras = scene_info.test_cameras 150 | self.test_cameras = cameraList_from_camInfos(test_cameras, args) 151 | # output the number of cameras in the training set and image size to the log file 152 | log_file.write( 153 | "Number of local test cameras: {}\n".format(len(self.test_cameras)) 154 | ) 155 | if len(self.test_cameras) > 0: 156 | log_file.write( 157 | "Image size: {}x{}\n".format( 158 | self.test_cameras[0].image_height, 159 | self.test_cameras[0].image_width, 160 | ) 161 | ) 162 | 163 | utils.check_initial_gpu_memory_usage("after Loading all images") 164 | utils.log_cpu_memory_usage("after decoding images") 165 | 166 | if self.loaded_iter: 167 | self.gaussians.load_ply( 168 | os.path.join( 169 | self.model_path, "point_cloud", "iteration_" + str(self.loaded_iter) 170 | ) 171 | ) 172 | elif hasattr(args, "load_ply_path"): 173 | self.gaussians.load_ply(args.load_ply_path) 174 | else: 175 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 176 | 177 | utils.check_initial_gpu_memory_usage("after initializing point cloud") 178 | utils.log_cpu_memory_usage("after loading initial 3dgs points") 179 | 180 | def save(self, iteration): 181 | point_cloud_path = os.path.join( 182 | self.model_path, "point_cloud/iteration_{}".format(iteration) 183 | ) 184 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 185 | 186 | def getTrainCameras(self): 187 | return self.train_cameras 188 | 189 | def getTestCameras(self): 190 | return self.test_cameras 191 | 192 | def log_scene_info_to_file(self, log_file, prefix_str=""): 193 | 194 | # Print shape of gaussians parameters. 
195 | log_file.write("xyz shape: {}\n".format(self.gaussians._xyz.shape)) 196 | log_file.write("f_dc shape: {}\n".format(self.gaussians._features_dc.shape)) 197 | log_file.write("f_rest shape: {}\n".format(self.gaussians._features_rest.shape)) 198 | log_file.write("opacity shape: {}\n".format(self.gaussians._opacity.shape)) 199 | log_file.write("scaling shape: {}\n".format(self.gaussians._scaling.shape)) 200 | log_file.write("rotation shape: {}\n".format(self.gaussians._rotation.shape)) 201 | 202 | 203 | class SceneDataset: 204 | def __init__(self, cameras): 205 | self.cameras = cameras 206 | self.camera_size = len(self.cameras) 207 | self.sample_camera_idx = [] 208 | for i in range(self.camera_size): 209 | if self.cameras[i].original_image_backup is not None: 210 | self.sample_camera_idx.append(i) 211 | # print("Number of cameras with sample images: ", len(self.sample_camera_idx)) 212 | 213 | self.cur_epoch_cameras = [] 214 | self.cur_iteration = 0 215 | 216 | self.iteration_loss = [] 217 | self.epoch_loss = [] 218 | 219 | self.log_file = utils.get_log_file() 220 | self.args = utils.get_args() 221 | 222 | self.last_time_point = None 223 | self.epoch_time = [] 224 | self.epoch_n_sample = [] 225 | 226 | @property 227 | def cur_epoch(self): 228 | return len(self.epoch_loss) 229 | 230 | @property 231 | def cur_iteration_in_epoch(self): 232 | return len(self.iteration_loss) 233 | 234 | def get_one_camera(self, batched_cameras_uid): 235 | args = utils.get_args() 236 | if len(self.cur_epoch_cameras) == 0: 237 | # start a new epoch 238 | if args.local_sampling: 239 | self.cur_epoch_cameras = self.sample_camera_idx.copy() 240 | else: 241 | self.cur_epoch_cameras = list(range(self.camera_size)) 242 | # random.shuffle(self.cur_epoch_cameras) 243 | indices = torch.randperm(len(self.cur_epoch_cameras)) 244 | self.cur_epoch_cameras = [self.cur_epoch_cameras[i] for i in indices] 245 | 246 | self.cur_iteration += 1 247 | 248 | idx = 0 249 | while self.cameras[self.cur_epoch_cameras[idx]].uid in batched_cameras_uid: 250 | idx += 1 251 | camera_idx = self.cur_epoch_cameras.pop(idx) 252 | viewpoint_cam = self.cameras[camera_idx] 253 | return camera_idx, viewpoint_cam 254 | 255 | def get_batched_cameras(self, batch_size): 256 | assert ( 257 | batch_size <= self.camera_size 258 | ), "Batch size is larger than the number of cameras in the scene." 259 | batched_cameras = [] 260 | batched_cameras_uid = [] 261 | for i in range(batch_size): 262 | _, camera = self.get_one_camera(batched_cameras_uid) 263 | batched_cameras.append(camera) 264 | batched_cameras_uid.append(camera.uid) 265 | 266 | return batched_cameras 267 | 268 | def get_batched_cameras_idx(self, batch_size): 269 | assert ( 270 | batch_size <= self.camera_size 271 | ), "Batch size is larger than the number of cameras in the scene." 
272 | batched_cameras_idx = [] 273 | batched_cameras_uid = [] 274 | for i in range(batch_size): 275 | idx, camera = self.get_one_camera(batched_cameras_uid) 276 | batched_cameras_uid.append(camera.uid) 277 | batched_cameras_idx.append(idx) 278 | 279 | return batched_cameras_idx 280 | 281 | def get_batched_cameras_from_idx(self, idx_list): 282 | return [self.cameras[i] for i in idx_list] 283 | 284 | def update_losses(self, losses): 285 | for loss in losses: 286 | self.iteration_loss.append(loss) 287 | if len(self.iteration_loss) % self.camera_size == 0: 288 | self.epoch_loss.append( 289 | sum(self.iteration_loss[-self.camera_size :]) / self.camera_size 290 | ) 291 | self.log_file.write( 292 | "epoch {} loss: {}\n".format( 293 | len(self.epoch_loss), self.epoch_loss[-1] 294 | ) 295 | ) 296 | self.iteration_loss = [] 297 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | from utils.general_utils import get_args, get_log_file 17 | import utils.general_utils as utils 18 | import time 19 | 20 | 21 | class Camera(nn.Module): 22 | def __init__( 23 | self, 24 | colmap_id, 25 | R, 26 | T, 27 | FoVx, 28 | FoVy, 29 | image, 30 | gt_alpha_mask, 31 | image_name, 32 | uid, 33 | trans=np.array([0.0, 0.0, 0.0]), 34 | scale=1.0, 35 | ): 36 | super(Camera, self).__init__() 37 | 38 | self.uid = uid 39 | self.colmap_id = colmap_id 40 | self.R = R 41 | self.T = T 42 | self.FoVx = FoVx 43 | self.FoVy = FoVy 44 | self.image_name = image_name 45 | 46 | args = get_args() 47 | log_file = get_log_file() 48 | 49 | if args.time_image_loading: 50 | start_time = time.time() 51 | 52 | if ( 53 | ( 54 | args.local_sampling 55 | and args.distributed_dataset_storage 56 | and utils.GLOBAL_RANK == uid % utils.WORLD_SIZE 57 | ) 58 | or ( 59 | not args.local_sampling 60 | and args.distributed_dataset_storage 61 | and utils.LOCAL_RANK == 0 62 | ) 63 | or (not args.distributed_dataset_storage) 64 | ): 65 | # load to cpu 66 | self.original_image_backup = image.contiguous() 67 | if args.preload_dataset_to_gpu: 68 | self.original_image_backup = self.original_image_backup.to("cuda") 69 | self.image_width = self.original_image_backup.shape[2] 70 | self.image_height = self.original_image_backup.shape[1] 71 | else: 72 | self.original_image_backup = None 73 | self.image_height, self.image_width = utils.get_img_size() 74 | 75 | if args.time_image_loading: 76 | log_file.write(f"Image processing in {time.time() - start_time} seconds\n") 77 | 78 | self.zfar = 100.0 79 | self.znear = 0.01 80 | 81 | self.trans = trans 82 | self.scale = scale 83 | 84 | self.world_view_transform = ( 85 | torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 86 | ) 87 | self.world_view_transform_backup = self.world_view_transform.clone().detach() 88 | self.projection_matrix = ( 89 | getProjectionMatrix( 90 | znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy 91 | ) 92 | .transpose(0, 1) 93 | .cuda() 94 | ) 95 | self.full_proj_transform = 
( 96 | self.world_view_transform.unsqueeze(0).bmm( 97 | self.projection_matrix.unsqueeze(0) 98 | ) 99 | ).squeeze(0) 100 | self.camera_center = self.world_view_transform.inverse()[3, :3] 101 | 102 | def get_camera2world(self): 103 | return self.world_view_transform_backup.t().inverse() 104 | 105 | def update(self, dx, dy, dz): 106 | # Update the position of this camera pose. TODO: support updating rotation of camera pose. 107 | with torch.no_grad(): 108 | c2w = self.get_camera2world() 109 | c2w[0, 3] += dx 110 | c2w[1, 3] += dy 111 | c2w[2, 3] += dz 112 | 113 | t_prime = c2w[:3, 3] 114 | self.T = (-c2w[:3, :3].t() @ t_prime).cpu().numpy() 115 | # import pdb; pdb.set_trace() 116 | 117 | self.world_view_transform = ( 118 | torch.tensor(getWorld2View2(self.R, self.T, self.trans, self.scale)) 119 | .transpose(0, 1) 120 | .cuda() 121 | ) 122 | self.projection_matrix = ( 123 | getProjectionMatrix( 124 | znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy 125 | ) 126 | .transpose(0, 1) 127 | .cuda() 128 | ) 129 | self.full_proj_transform = ( 130 | self.world_view_transform.unsqueeze(0).bmm( 131 | self.projection_matrix.unsqueeze(0) 132 | ) 133 | ).squeeze(0) 134 | self.camera_center = self.world_view_transform.inverse()[3, :3] 135 | 136 | 137 | class MiniCam: 138 | def __init__( 139 | self, 140 | width, 141 | height, 142 | fovy, 143 | fovx, 144 | znear, 145 | zfar, 146 | world_view_transform, 147 | full_proj_transform, 148 | ): 149 | self.image_width = width 150 | self.image_height = height 151 | self.FoVy = fovy 152 | self.FoVx = fovx 153 | self.znear = znear 154 | self.zfar = zfar 155 | self.world_view_transform = world_view_transform 156 | self.full_proj_transform = full_proj_transform 157 | view_inv = torch.inverse(self.world_view_transform) 158 | self.camera_center = view_inv[3][:3] 159 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"] 18 | ) 19 | Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] 22 | ) 23 | Point3D = collections.namedtuple( 24 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] 25 | ) 26 | CAMERA_MODELS = { 27 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 28 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 29 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 30 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 31 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 32 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 33 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 34 | CameraModel(model_id=7, model_name="FOV", num_params=5), 35 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 36 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 37 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), 38 | } 39 | CAMERA_MODEL_IDS = dict( 40 | [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] 41 | ) 42 | CAMERA_MODEL_NAMES = dict( 43 | [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS] 44 | ) 45 | 46 | 47 | def qvec2rotmat(qvec): 48 | return np.array( 49 | [ 50 | [ 51 | 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 52 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 53 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], 54 | ], 55 | [ 56 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 57 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 58 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], 59 | ], 60 | [ 61 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 62 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 63 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, 64 | ], 65 | ] 66 | ) 67 | 68 | 69 | def rotmat2qvec(R): 70 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 71 | K = ( 72 | np.array( 73 | [ 74 | [Rxx - Ryy - Rzz, 0, 0, 0], 75 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 76 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 77 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], 78 | ] 79 | ) 80 | / 3.0 81 | ) 82 | eigvals, eigvecs = np.linalg.eigh(K) 83 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 84 | if qvec[0] < 0: 85 | qvec *= -1 86 | return qvec 87 | 88 | 89 | class Image(BaseImage): 90 | def qvec2rotmat(self): 91 | return qvec2rotmat(self.qvec) 92 | 93 | 94 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 95 | """Read and unpack the next bytes from a binary file. 96 | :param fid: 97 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 98 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 99 | :param endian_character: Any of {@, =, <, >, !} 100 | :return: Tuple of read and unpacked values. 
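Example: read_next_bytes(fid, 8, "Q") returns a 1-tuple holding a single little-endian uint64 (e.g. the number of points or cameras).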
101 | """ 102 | data = fid.read(num_bytes) 103 | return struct.unpack(endian_character + format_char_sequence, data) 104 | 105 | 106 | def read_points3D_text(path): 107 | """ 108 | see: src/base/reconstruction.cc 109 | void Reconstruction::ReadPoints3DText(const std::string& path) 110 | void Reconstruction::WritePoints3DText(const std::string& path) 111 | """ 112 | xyzs = None 113 | rgbs = None 114 | errors = None 115 | num_points = 0 116 | with open(path, "r") as fid: 117 | while True: 118 | line = fid.readline() 119 | if not line: 120 | break 121 | line = line.strip() 122 | if len(line) > 0 and line[0] != "#": 123 | num_points += 1 124 | 125 | xyzs = np.empty((num_points, 3)) 126 | rgbs = np.empty((num_points, 3)) 127 | errors = np.empty((num_points, 1)) 128 | count = 0 129 | with open(path, "r") as fid: 130 | while True: 131 | line = fid.readline() 132 | if not line: 133 | break 134 | line = line.strip() 135 | if len(line) > 0 and line[0] != "#": 136 | elems = line.split() 137 | xyz = np.array(tuple(map(float, elems[1:4]))) 138 | rgb = np.array(tuple(map(int, elems[4:7]))) 139 | error = np.array(float(elems[7])) 140 | xyzs[count] = xyz 141 | rgbs[count] = rgb 142 | errors[count] = error 143 | count += 1 144 | 145 | return xyzs, rgbs, errors 146 | 147 | 148 | def read_points3D_binary(path_to_model_file): 149 | """ 150 | see: src/base/reconstruction.cc 151 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 152 | void Reconstruction::WritePoints3DBinary(const std::string& path) 153 | """ 154 | 155 | with open(path_to_model_file, "rb") as fid: 156 | num_points = read_next_bytes(fid, 8, "Q")[0] 157 | 158 | xyzs = np.empty((num_points, 3)) 159 | rgbs = np.empty((num_points, 3)) 160 | errors = np.empty((num_points, 1)) 161 | 162 | for p_id in range(num_points): 163 | binary_point_line_properties = read_next_bytes( 164 | fid, num_bytes=43, format_char_sequence="QdddBBBd" 165 | ) 166 | xyz = np.array(binary_point_line_properties[1:4]) 167 | rgb = np.array(binary_point_line_properties[4:7]) 168 | error = np.array(binary_point_line_properties[7]) 169 | track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 170 | 0 171 | ] 172 | track_elems = read_next_bytes( 173 | fid, 174 | num_bytes=8 * track_length, 175 | format_char_sequence="ii" * track_length, 176 | ) 177 | xyzs[p_id] = xyz 178 | rgbs[p_id] = rgb 179 | errors[p_id] = error 180 | return xyzs, rgbs, errors 181 | 182 | 183 | def read_intrinsics_text(path): 184 | """ 185 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 186 | """ 187 | cameras = {} 188 | with open(path, "r") as fid: 189 | while True: 190 | line = fid.readline() 191 | if not line: 192 | break 193 | line = line.strip() 194 | if len(line) > 0 and line[0] != "#": 195 | elems = line.split() 196 | camera_id = int(elems[0]) 197 | model = elems[1] 198 | assert ( 199 | model == "PINHOLE" 200 | ), "While the loader support other types, the rest of the code assumes PINHOLE" 201 | width = int(elems[2]) 202 | height = int(elems[3]) 203 | params = np.array(tuple(map(float, elems[4:]))) 204 | cameras[camera_id] = Camera( 205 | id=camera_id, model=model, width=width, height=height, params=params 206 | ) 207 | return cameras 208 | 209 | 210 | def read_extrinsics_binary(path_to_model_file): 211 | """ 212 | see: src/base/reconstruction.cc 213 | void Reconstruction::ReadImagesBinary(const std::string& path) 214 | void Reconstruction::WriteImagesBinary(const std::string& path) 215 | """ 216 | images = {} 217 | with 
open(path_to_model_file, "rb") as fid: 218 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 219 | for _ in range(num_reg_images): 220 | binary_image_properties = read_next_bytes( 221 | fid, num_bytes=64, format_char_sequence="idddddddi" 222 | ) 223 | image_id = binary_image_properties[0] 224 | qvec = np.array(binary_image_properties[1:5]) 225 | tvec = np.array(binary_image_properties[5:8]) 226 | camera_id = binary_image_properties[8] 227 | image_name = "" 228 | current_char = read_next_bytes(fid, 1, "c")[0] 229 | while current_char != b"\x00": # look for the ASCII 0 entry 230 | image_name += current_char.decode("utf-8") 231 | current_char = read_next_bytes(fid, 1, "c")[0] 232 | num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 233 | 0 234 | ] 235 | x_y_id_s = read_next_bytes( 236 | fid, 237 | num_bytes=24 * num_points2D, 238 | format_char_sequence="ddq" * num_points2D, 239 | ) 240 | xys = np.column_stack( 241 | [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))] 242 | ) 243 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 244 | images[image_id] = Image( 245 | id=image_id, 246 | qvec=qvec, 247 | tvec=tvec, 248 | camera_id=camera_id, 249 | name=image_name, 250 | xys=xys, 251 | point3D_ids=point3D_ids, 252 | ) 253 | return images 254 | 255 | 256 | def read_intrinsics_binary(path_to_model_file): 257 | """ 258 | see: src/base/reconstruction.cc 259 | void Reconstruction::WriteCamerasBinary(const std::string& path) 260 | void Reconstruction::ReadCamerasBinary(const std::string& path) 261 | """ 262 | cameras = {} 263 | with open(path_to_model_file, "rb") as fid: 264 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 265 | for _ in range(num_cameras): 266 | camera_properties = read_next_bytes( 267 | fid, num_bytes=24, format_char_sequence="iiQQ" 268 | ) 269 | camera_id = camera_properties[0] 270 | model_id = camera_properties[1] 271 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 272 | width = camera_properties[2] 273 | height = camera_properties[3] 274 | num_params = CAMERA_MODEL_IDS[model_id].num_params 275 | params = read_next_bytes( 276 | fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params 277 | ) 278 | cameras[camera_id] = Camera( 279 | id=camera_id, 280 | model=model_name, 281 | width=width, 282 | height=height, 283 | params=np.array(params), 284 | ) 285 | assert len(cameras) == num_cameras 286 | return cameras 287 | 288 | 289 | def read_extrinsics_text(path): 290 | """ 291 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 292 | """ 293 | images = {} 294 | with open(path, "r") as fid: 295 | while True: 296 | line = fid.readline() 297 | if not line: 298 | break 299 | line = line.strip() 300 | if len(line) > 0 and line[0] != "#": 301 | elems = line.split() 302 | image_id = int(elems[0]) 303 | qvec = np.array(tuple(map(float, elems[1:5]))) 304 | tvec = np.array(tuple(map(float, elems[5:8]))) 305 | camera_id = int(elems[8]) 306 | image_name = elems[9] 307 | elems = fid.readline().split() 308 | xys = np.column_stack( 309 | [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))] 310 | ) 311 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 312 | images[image_id] = Image( 313 | id=image_id, 314 | qvec=qvec, 315 | tvec=tvec, 316 | camera_id=camera_id, 317 | name=image_name, 318 | xys=xys, 319 | point3D_ids=point3D_ids, 320 | ) 321 | return images 322 | 323 | 324 | def read_colmap_bin_array(path): 325 | """ 326 | Taken from 
https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 327 | 328 | :param path: path to the colmap binary file. 329 | :return: nd array with the floating point values in the value 330 | """ 331 | with open(path, "rb") as fid: 332 | width, height, channels = np.genfromtxt( 333 | fid, delimiter="&", max_rows=1, usecols=(0, 1, 2), dtype=int 334 | ) 335 | fid.seek(0) 336 | num_delimiter = 0 337 | byte = fid.read(1) 338 | while True: 339 | if byte == b"&": 340 | num_delimiter += 1 341 | if num_delimiter >= 3: 342 | break 343 | byte = fid.read(1) 344 | array = np.fromfile(fid, np.float32) 345 | array = array.reshape((width, height, channels), order="F") 346 | return np.transpose(array, (1, 0, 2)).squeeze() 347 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | import glob 15 | from PIL import Image 16 | from typing import NamedTuple 17 | from scene.colmap_loader import ( 18 | read_extrinsics_text, 19 | read_intrinsics_text, 20 | qvec2rotmat, 21 | read_extrinsics_binary, 22 | read_intrinsics_binary, 23 | read_points3D_binary, 24 | read_points3D_text, 25 | ) 26 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 27 | import utils.general_utils as utils 28 | from tqdm import tqdm 29 | import numpy as np 30 | import json 31 | from pathlib import Path 32 | from plyfile import PlyData, PlyElement 33 | from utils.sh_utils import SH2RGB 34 | from scene.gaussian_model import BasicPointCloud 35 | import torch 36 | 37 | 38 | class CameraInfo(NamedTuple): 39 | uid: int 40 | R: np.array 41 | T: np.array 42 | FovY: np.array 43 | FovX: np.array 44 | image: np.array 45 | image_path: str 46 | image_name: str 47 | width: int 48 | height: int 49 | 50 | 51 | class SceneInfo(NamedTuple): 52 | point_cloud: BasicPointCloud 53 | train_cameras: list 54 | test_cameras: list 55 | nerf_normalization: dict 56 | ply_path: str 57 | 58 | 59 | def getNerfppNorm(cam_info): 60 | def get_center_and_diag(cam_centers): 61 | cam_centers = np.hstack(cam_centers) 62 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 63 | center = avg_cam_center 64 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 65 | diagonal = np.max(dist) 66 | return center.flatten(), diagonal 67 | 68 | cam_centers = [] 69 | 70 | for cam in cam_info: 71 | W2C = getWorld2View2(cam.R, cam.T) 72 | C2W = np.linalg.inv(W2C) 73 | cam_centers.append(C2W[:3, 3:4]) 74 | 75 | center, diagonal = get_center_and_diag(cam_centers) 76 | radius = diagonal * 1.1 77 | 78 | translate = -center 79 | 80 | return {"translate": translate, "radius": radius} 81 | 82 | 83 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): 84 | args = utils.get_args() 85 | cam_infos = [] 86 | utils.print_rank_0("Loading cameras from disk...") 87 | for idx, key in tqdm( 88 | enumerate(cam_extrinsics), 89 | total=len(cam_extrinsics), 90 | disable=(utils.LOCAL_RANK != 0), 91 | ): 92 | 93 | extr = cam_extrinsics[key] 94 | intr = cam_intrinsics[extr.camera_id] 95 | height = intr.height 96 | width = intr.width 97 | 98 | uid = intr.id 
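# Note on the two lines that follow: COLMAP stores each image's pose as a
# world-to-camera rotation quaternion `qvec` plus a translation `tvec`, and this
# loader keeps R transposed because the downstream CUDA/glm code expects that
# layout (see the "R is stored transposed" comment further below). A minimal,
# self-contained sketch of the same convention, using only numpy, the
# qvec2rotmat helper imported above, and made-up pose values:
#
#     import numpy as np
#     from scene.colmap_loader import qvec2rotmat
#
#     qvec = np.array([1.0, 0.0, 0.0, 0.0])  # identity rotation (w, x, y, z)
#     tvec = np.array([0.0, 0.0, 4.0])        # camera placed 4 units along +z
#     R = np.transpose(qvec2rotmat(qvec))     # stored transposed, as below
#     T = np.array(tvec)
#     # getWorld2View2(R, T) rebuilds the world-to-camera matrix from this pair:
#     W2C = np.eye(4)
#     W2C[:3, :3] = R.transpose()             # back to the COLMAP rotation
#     W2C[:3, 3] = T
#     # and the camera center in world space is -R_w2c^T @ tvec -> [0, 0, -4]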
99 | R = np.transpose(qvec2rotmat(extr.qvec)) 100 | T = np.array(extr.tvec) 101 | 102 | if intr.model == "SIMPLE_PINHOLE": 103 | focal_length_x = intr.params[0] 104 | FovY = focal2fov(focal_length_x, height) 105 | FovX = focal2fov(focal_length_x, width) 106 | elif intr.model == "PINHOLE": 107 | focal_length_x = intr.params[0] 108 | focal_length_y = intr.params[1] 109 | FovY = focal2fov(focal_length_y, height) 110 | FovX = focal2fov(focal_length_x, width) 111 | elif intr.model == "OPENCV": 112 | # we're ignoring the 4 distortion 113 | focal_length_x = intr.params[0] 114 | focal_length_y = intr.params[1] 115 | FovY = focal2fov(focal_length_y, height) 116 | FovX = focal2fov(focal_length_x, width) 117 | else: 118 | assert ( 119 | False 120 | ), "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 121 | 122 | image_path = os.path.join(images_folder, os.path.basename(extr.name)) 123 | image_name = os.path.basename(image_path).split(".")[0] 124 | image = Image.open( 125 | image_path 126 | ) # this is a lazy load, the image is not loaded yet 127 | width, height = image.size 128 | 129 | cam_info = CameraInfo( 130 | uid=uid, 131 | R=R, 132 | T=T, 133 | FovY=FovY, 134 | FovX=FovX, 135 | image=None, 136 | image_path=image_path, 137 | image_name=image_name, 138 | width=width, 139 | height=height, 140 | ) 141 | 142 | # release memory 143 | image.close() 144 | image = None 145 | 146 | cam_infos.append(cam_info) 147 | return cam_infos 148 | 149 | 150 | def fetchPly(path): 151 | plydata = PlyData.read(path) 152 | vertices = plydata["vertex"] 153 | positions = np.vstack([vertices["x"], vertices["y"], vertices["z"]]).T 154 | try: 155 | colors = ( 156 | np.vstack([vertices["red"], vertices["green"], vertices["blue"]]).T / 255.0 157 | ) 158 | except: 159 | colors = np.random.rand(positions.shape[0], positions.shape[1]) 160 | try: 161 | normals = np.vstack([vertices["nx"], vertices["ny"], vertices["nz"]]).T 162 | except: 163 | normals = np.random.rand(positions.shape[0], positions.shape[1]) 164 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 165 | 166 | 167 | def storePly(path, xyz, rgb): 168 | # Define the dtype for the structured array 169 | dtype = [ 170 | ("x", "f4"), 171 | ("y", "f4"), 172 | ("z", "f4"), 173 | ("nx", "f4"), 174 | ("ny", "f4"), 175 | ("nz", "f4"), 176 | ("red", "u1"), 177 | ("green", "u1"), 178 | ("blue", "u1"), 179 | ] 180 | 181 | normals = np.zeros_like(xyz) 182 | 183 | elements = np.empty(xyz.shape[0], dtype=dtype) 184 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 185 | elements[:] = list(map(tuple, attributes)) 186 | 187 | # Create the PlyData object and write to file 188 | vertex_element = PlyElement.describe(elements, "vertex") 189 | ply_data = PlyData([vertex_element]) 190 | ply_data.write(path) 191 | 192 | 193 | def readColmapSceneInfo(path, images, eval, llffhold=10): 194 | try: 195 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 196 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 197 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 198 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 199 | except: 200 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 201 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 202 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 203 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 204 | 205 | reading_dir = 
"images" if images == None else images 206 | cam_infos_unsorted = readColmapCameras( 207 | cam_extrinsics=cam_extrinsics, 208 | cam_intrinsics=cam_intrinsics, 209 | images_folder=os.path.join(path, reading_dir), 210 | ) 211 | cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name) 212 | 213 | if eval: 214 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] 215 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] 216 | else: 217 | train_cam_infos = cam_infos 218 | test_cam_infos = [] 219 | 220 | nerf_normalization = getNerfppNorm(train_cam_infos) 221 | 222 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 223 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 224 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 225 | if not os.path.exists(ply_path): 226 | if utils.GLOBAL_RANK == 0: 227 | print( 228 | "Converting point3d.bin to .ply, will happen only the first time you open the scene." 229 | ) 230 | try: 231 | xyz, rgb, _ = read_points3D_binary(bin_path) 232 | except: 233 | xyz, rgb, _ = read_points3D_text(txt_path) 234 | storePly(ply_path, xyz, rgb) 235 | if utils.DEFAULT_GROUP.size() > 1: 236 | torch.distributed.barrier() 237 | else: 238 | if utils.DEFAULT_GROUP.size() > 1: 239 | torch.distributed.barrier() 240 | try: 241 | pcd = fetchPly(ply_path) 242 | except: 243 | pcd = None 244 | 245 | scene_info = SceneInfo( 246 | point_cloud=pcd, 247 | train_cameras=train_cam_infos, 248 | test_cameras=test_cam_infos, 249 | nerf_normalization=nerf_normalization, 250 | ply_path=ply_path, 251 | ) 252 | return scene_info 253 | 254 | 255 | def readCamerasFromTransformsCity( 256 | path, 257 | transformsfile, 258 | random_background, 259 | white_background, 260 | extension=".png", 261 | undistorted=False, 262 | is_debug=False, 263 | ): 264 | cam_infos = [] 265 | if undistorted: 266 | print("Undistortion the images!!!") 267 | # TODO: Support undistortion here. Please refer to octree-gs implementation. 
268 | with open(os.path.join(path, transformsfile)) as json_file: 269 | contents = json.load(json_file) 270 | try: 271 | fovx = contents["camera_angle_x"] 272 | except: 273 | fovx = None 274 | 275 | frames = contents["frames"] 276 | # check if filename already contain postfix 277 | if frames[0]["file_path"].split(".")[-1] in ["jpg", "jpeg", "JPG", "png"]: 278 | extension = "" 279 | 280 | c2ws = np.array([frame["transform_matrix"] for frame in frames]) 281 | 282 | Ts = c2ws[:, :3, 3] 283 | 284 | ct = 0 285 | 286 | progress_bar = tqdm(frames, desc="Loading dataset") 287 | 288 | for idx, frame in enumerate(frames): 289 | # cam_name = os.path.join(path, frame["file_path"] + extension) 290 | cam_name = frame["file_path"] 291 | if not os.path.exists(cam_name): 292 | print(f"File {cam_name} not found, skipping...") 293 | continue 294 | # NeRF 'transform_matrix' is a camera-to-world transform 295 | c2w = np.array(frame["transform_matrix"]) 296 | 297 | if idx % 10 == 0: 298 | progress_bar.set_postfix({"num": f"{ct}/{len(frames)}"}) 299 | progress_bar.update(10) 300 | if idx == len(frames) - 1: 301 | progress_bar.close() 302 | 303 | ct += 1 304 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) 305 | c2w[:3, 1:3] *= -1 306 | 307 | # get the world-to-camera transform and set R, T 308 | w2c = np.linalg.inv(c2w) 309 | 310 | R = np.transpose( 311 | w2c[:3, :3] 312 | ) # R is stored transposed due to 'glm' in CUDA code 313 | T = w2c[:3, 3] 314 | 315 | image_path = os.path.join(path, cam_name) 316 | image_name = cam_name[-17:] # Path(cam_name).stem 317 | image = Image.open(image_path) 318 | 319 | if fovx is not None: 320 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 321 | FovY = fovy 322 | FovX = fovx 323 | else: 324 | # given focal in pixel unit 325 | FovY = focal2fov(frame["fl_y"], image.size[1]) 326 | FovX = focal2fov(frame["fl_x"], image.size[0]) 327 | 328 | cam_infos.append( 329 | CameraInfo( 330 | uid=idx, 331 | R=R, 332 | T=T, 333 | FovY=FovY, 334 | FovX=FovX, 335 | image=None, 336 | image_path=image_path, 337 | image_name=image_name, 338 | width=image.size[0], 339 | height=image.size[1], 340 | ) 341 | ) 342 | 343 | # release memory 344 | image.close() 345 | image = None 346 | 347 | if is_debug and idx > 50: 348 | break 349 | return cam_infos 350 | 351 | 352 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): 353 | cam_infos = [] 354 | 355 | with open(os.path.join(path, transformsfile)) as json_file: 356 | contents = json.load(json_file) 357 | fovx = contents["camera_angle_x"] 358 | 359 | frames = contents["frames"] 360 | for idx, frame in enumerate(frames): 361 | cam_name = os.path.join(path, frame["file_path"] + extension) 362 | 363 | # NeRF 'transform_matrix' is a camera-to-world transform 364 | c2w = np.array(frame["transform_matrix"]) 365 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) 366 | c2w[:3, 1:3] *= -1 367 | 368 | # get the world-to-camera transform and set R, T 369 | w2c = np.linalg.inv(c2w) 370 | R = np.transpose( 371 | w2c[:3, :3] 372 | ) # R is stored transposed due to 'glm' in CUDA code 373 | T = w2c[:3, 3] 374 | 375 | image_path = os.path.join(path, cam_name) 376 | image_name = Path(cam_name).stem 377 | image = Image.open(image_path) 378 | 379 | im_data = np.array(image.convert("RGBA")) 380 | 381 | bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0]) 382 | 383 | norm_data = im_data / 255.0 384 | arr = norm_data[:, :, :3] * 
norm_data[:, :, 3:4] + bg * ( 385 | 1 - norm_data[:, :, 3:4] 386 | ) 387 | image = Image.fromarray(np.array(arr * 255.0, dtype=np.byte), "RGB") 388 | 389 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 390 | FovY = fovy 391 | FovX = fovx 392 | 393 | cam_infos.append( 394 | CameraInfo( 395 | uid=idx, 396 | R=R, 397 | T=T, 398 | FovY=FovY, 399 | FovX=FovX, 400 | image=image, 401 | image_path=image_path, 402 | image_name=image_name, 403 | width=image.size[0], 404 | height=image.size[1], 405 | ) 406 | ) 407 | 408 | return cam_infos 409 | 410 | 411 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): 412 | print("Reading Training Transforms") 413 | train_cam_infos = readCamerasFromTransforms( 414 | path, "transforms_train.json", white_background, extension 415 | ) 416 | print("Reading Test Transforms") 417 | test_cam_infos = readCamerasFromTransforms( 418 | path, "transforms_test.json", white_background, extension 419 | ) 420 | 421 | if not eval: 422 | train_cam_infos.extend(test_cam_infos) 423 | test_cam_infos = [] 424 | 425 | nerf_normalization = getNerfppNorm(train_cam_infos) 426 | 427 | ply_path = os.path.join(path, "points3d.ply") 428 | if not os.path.exists(ply_path): 429 | # Since this data set has no colmap data, we start with random points 430 | num_pts = 100_000 431 | print(f"Generating random point cloud ({num_pts})...") 432 | 433 | # We create random points inside the bounds of the synthetic Blender scenes 434 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 435 | shs = np.random.random((num_pts, 3)) / 255.0 436 | pcd = BasicPointCloud( 437 | points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)) 438 | ) 439 | 440 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 441 | try: 442 | pcd = fetchPly(ply_path) 443 | except: 444 | pcd = None 445 | 446 | scene_info = SceneInfo( 447 | point_cloud=pcd, 448 | train_cameras=train_cam_infos, 449 | test_cameras=test_cam_infos, 450 | nerf_normalization=nerf_normalization, 451 | ply_path=ply_path, 452 | ) 453 | return scene_info 454 | 455 | 456 | def readCityInfo( 457 | path, 458 | random_background, 459 | white_background, 460 | extension=".tif", 461 | llffhold=8, 462 | undistorted=False, 463 | ): 464 | 465 | train_json_path = os.path.join(path, f"transforms_train.json") 466 | test_json_path = os.path.join(path, f"transforms_test.json") 467 | print( 468 | "Reading Training Transforms from {} {}".format(train_json_path, test_json_path) 469 | ) 470 | 471 | train_cam_infos = readCamerasFromTransformsCity( 472 | path, 473 | train_json_path, 474 | random_background, 475 | white_background, 476 | extension, 477 | undistorted, 478 | ) 479 | test_cam_infos = readCamerasFromTransformsCity( 480 | path, 481 | test_json_path, 482 | random_background, 483 | white_background, 484 | extension, 485 | undistorted, 486 | ) 487 | print("Load Cameras(train, test): ", len(train_cam_infos), len(test_cam_infos)) 488 | 489 | nerf_normalization = getNerfppNorm(train_cam_infos) 490 | 491 | ply_path = glob.glob(os.path.join(path, "*.ply"))[0] 492 | if os.path.exists(ply_path): 493 | try: 494 | pcd = fetchPly(ply_path) 495 | except: 496 | raise ValueError("must have tiepoints!") 497 | else: 498 | assert False, "No ply file found!" 
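# Note on the compositing in readCamerasFromTransforms above: Blender-style RGBA
# frames are blended over a solid background before training, i.e.
# rgb_out = rgb * alpha + bg * (1 - alpha), with bg = 1 for white_background and
# bg = 0 otherwise. A minimal numpy sketch of that formula with illustrative values:
#
#     import numpy as np
#     rgba = np.zeros((2, 2, 4))          # a hypothetical 2x2 RGBA image in [0, 1]
#     rgba[..., :3] = 0.5                 # mid-grey foreground colour
#     rgba[..., 3] = 0.25                 # mostly transparent alpha
#     bg = np.array([1.0, 1.0, 1.0])      # white background
#     rgb_out = rgba[..., :3] * rgba[..., 3:4] + bg * (1.0 - rgba[..., 3:4])
#     # every pixel becomes 0.5 * 0.25 + 1.0 * 0.75 = 0.875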
499 | 500 | scene_info = SceneInfo( 501 | point_cloud=pcd, 502 | train_cameras=train_cam_infos, 503 | test_cameras=test_cam_infos, 504 | nerf_normalization=nerf_normalization, 505 | ply_path=ply_path, 506 | ) 507 | return scene_info 508 | 509 | 510 | sceneLoadTypeCallbacks = { 511 | "Colmap": readColmapSceneInfo, 512 | "Blender": readNerfSyntheticInfo, 513 | "City": readCityInfo, 514 | } 515 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | import sys 15 | import json 16 | from utils.general_utils import safe_state, init_distributed 17 | import utils.general_utils as utils 18 | from argparse import ArgumentParser 19 | from arguments import ( 20 | AuxiliaryParams, 21 | ModelParams, 22 | PipelineParams, 23 | OptimizationParams, 24 | DistributionParams, 25 | BenchmarkParams, 26 | DebugParams, 27 | print_all_args, 28 | init_args, 29 | ) 30 | import train_internal 31 | 32 | if __name__ == "__main__": 33 | # Set up command line argument parser 34 | parser = ArgumentParser(description="Training script parameters") 35 | ap = AuxiliaryParams(parser) 36 | lp = ModelParams(parser) 37 | op = OptimizationParams(parser) 38 | pp = PipelineParams(parser) 39 | dist_p = DistributionParams(parser) 40 | bench_p = BenchmarkParams(parser) 41 | debug_p = DebugParams(parser) 42 | args = parser.parse_args(sys.argv[1:]) 43 | 44 | # Set up distributed training 45 | init_distributed(args) 46 | 47 | ## Prepare arguments. 48 | # Check arguments 49 | init_args(args) 50 | 51 | args = utils.get_args() 52 | 53 | # create log folder 54 | if utils.GLOBAL_RANK == 0: 55 | os.makedirs(args.log_folder, exist_ok=True) 56 | os.makedirs(args.model_path, exist_ok=True) 57 | if utils.WORLD_SIZE > 1: 58 | torch.distributed.barrier( 59 | group=utils.DEFAULT_GROUP 60 | ) # log_folder is created before other ranks start writing log. 
61 | if utils.GLOBAL_RANK == 0: 62 | with open(args.log_folder + "/args.json", "w") as f: 63 | json.dump(vars(args), f) 64 | 65 | # Initialize system state (RNG) 66 | safe_state(args.quiet) 67 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 68 | 69 | # Initialize log file and print all args 70 | log_file = open( 71 | args.log_folder 72 | + "/python_ws=" 73 | + str(utils.WORLD_SIZE) 74 | + "_rk=" 75 | + str(utils.GLOBAL_RANK) 76 | + ".log", 77 | "a" if args.auto_start_checkpoint else "w", 78 | ) 79 | utils.set_log_file(log_file) 80 | print_all_args(args, log_file) 81 | 82 | train_internal.training( 83 | lp.extract(args), op.extract(args), pp.extract(args), args, log_file 84 | ) 85 | 86 | # All done 87 | if utils.WORLD_SIZE > 1: 88 | torch.distributed.barrier(group=utils.DEFAULT_GROUP) 89 | utils.print_rank_0("\nTraining complete.") 90 | -------------------------------------------------------------------------------- /train_internal.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | from utils.loss_utils import l1_loss 5 | from gaussian_renderer import ( 6 | distributed_preprocess3dgs_and_all2all_final, 7 | render_final, 8 | gsplat_distributed_preprocess3dgs_and_all2all_final, 9 | gsplat_render_final, 10 | ) 11 | from torch.cuda import nvtx 12 | from scene import Scene, GaussianModel, SceneDataset 13 | from gaussian_renderer.workload_division import ( 14 | start_strategy_final, 15 | finish_strategy_final, 16 | DivisionStrategyHistoryFinal, 17 | ) 18 | from gaussian_renderer.loss_distribution import ( 19 | load_camera_from_cpu_to_all_gpu, 20 | load_camera_from_cpu_to_all_gpu_for_eval, 21 | batched_loss_computation, 22 | ) 23 | from utils.general_utils import prepare_output_and_logger, globally_sync_for_timer 24 | import utils.general_utils as utils 25 | from utils.timer import Timer, End2endTimer 26 | from tqdm import tqdm 27 | from utils.image_utils import psnr 28 | import torch.distributed as dist 29 | from densification import densification, gsplat_densification 30 | 31 | 32 | def training(dataset_args, opt_args, pipe_args, args, log_file): 33 | 34 | # Init auxiliary tools 35 | 36 | timers = Timer(args) 37 | utils.set_timers(timers) 38 | prepare_output_and_logger(dataset_args) 39 | utils.log_cpu_memory_usage("at the beginning of training") 40 | start_from_this_iteration = 1 41 | 42 | # Init parameterized scene 43 | gaussians = GaussianModel(dataset_args.sh_degree) 44 | 45 | with torch.no_grad(): 46 | scene = Scene(args, gaussians) 47 | gaussians.training_setup(opt_args) 48 | 49 | if args.start_checkpoint != "": 50 | model_params, start_from_this_iteration = utils.load_checkpoint(args) 51 | gaussians.restore(model_params, opt_args) 52 | utils.print_rank_0( 53 | "Restored from checkpoint: {}".format(args.start_checkpoint) 54 | ) 55 | log_file.write( 56 | "Restored from checkpoint: {}\n".format(args.start_checkpoint) 57 | ) 58 | 59 | scene.log_scene_info_to_file(log_file, "Scene Info Before Training") 60 | utils.check_initial_gpu_memory_usage("after init and before training loop") 61 | 62 | # Init dataset 63 | train_dataset = SceneDataset(scene.getTrainCameras()) 64 | if args.adjust_strategy_warmp_iterations == -1: 65 | args.adjust_strategy_warmp_iterations = len(train_dataset.cameras) 66 | # use one epoch to warm up. do not use the first epoch's running time for adjustment of strategy. 
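# Note on iteration counting in the training loop below: `iteration` advances by
# args.bsz per pass because each pass renders a batch of bsz cameras, so
# opt_args.iterations counts images processed rather than optimizer steps. A small
# illustrative calculation with made-up values (not taken from any config file):
#
#     iterations = 30_000  # total images to process
#     bsz = 4              # cameras per batch, i.e. per optimizer step
#     optimizer_steps = (iterations + bsz - 1) // bsz  # -> 7500 parameter updates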
67 | 68 | # Init distribution strategy history 69 | strategy_history = DivisionStrategyHistoryFinal( 70 | train_dataset, utils.DEFAULT_GROUP.size(), utils.DEFAULT_GROUP.rank() 71 | ) 72 | 73 | # Init background 74 | background = None 75 | if args.backend == "gsplat": 76 | bg_color = [1, 1, 1] if dataset_args.white_background else None 77 | else: 78 | bg_color = [1, 1, 1] if dataset_args.white_background else [0, 0, 0] 79 | 80 | if bg_color is not None: 81 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 82 | 83 | # Training Loop 84 | end2end_timers = End2endTimer(args) 85 | end2end_timers.start() 86 | progress_bar = tqdm( 87 | range(1, opt_args.iterations + 1), 88 | desc="Training progress", 89 | disable=(utils.LOCAL_RANK != 0), 90 | ) 91 | progress_bar.update(start_from_this_iteration - 1) 92 | num_trained_batches = 0 93 | 94 | ema_loss_for_log = 0 95 | for iteration in range( 96 | start_from_this_iteration, opt_args.iterations + 1, args.bsz 97 | ): 98 | # Step Initialization 99 | if iteration // args.bsz % 30 == 0: 100 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 101 | progress_bar.update(args.bsz) 102 | utils.set_cur_iter(iteration) 103 | gaussians.update_learning_rate(iteration) 104 | num_trained_batches += 1 105 | timers.clear() 106 | if args.nsys_profile: 107 | nvtx.range_push(f"iteration[{iteration},{iteration+args.bsz})") 108 | # Every 1000 its we increase the levels of SH up to a maximum degree 109 | if utils.check_update_at_this_iter(iteration, args.bsz, 1000, 0): 110 | gaussians.oneupSHdegree() 111 | 112 | # Prepare data: Pick random Cameras for training 113 | if args.local_sampling: 114 | assert ( 115 | args.bsz % utils.WORLD_SIZE == 0 116 | ), "Batch size should be divisible by the number of GPUs." 
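# Note on the SH-degree schedule above: with a batched step of args.bsz, no single
# `iteration` value may land exactly on a multiple of 1000, so
# utils.check_update_at_this_iter(iteration, bsz, 1000, 0) is used to fire the event
# when the interval [iteration, iteration + bsz) crosses such a multiple. The helper
# itself lives in utils/general_utils.py, which is not shown in this section; the
# predicate below is only a plausible equivalent written for illustration (the name
# fires_in_this_batch is ours):
#
#     def fires_in_this_batch(iteration, bsz, interval, offset):
#         # True iff some k with k % interval == offset falls in [iteration, iteration + bsz)
#         first = iteration + (offset - iteration) % interval
#         return first < iteration + bsz
#
#     # e.g. with bsz = 4 the multiple 1000 falls inside [997, 1001), so
#     # fires_in_this_batch(997, 4, 1000, 0) is True while the neighbouring
#     # batches starting at 993 and 1001 both return False.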
117 | batched_cameras_idx = train_dataset.get_batched_cameras_idx( 118 | args.bsz // utils.WORLD_SIZE 119 | ) 120 | batched_all_cameras_idx = torch.zeros( 121 | (utils.WORLD_SIZE, len(batched_cameras_idx)), device="cuda", dtype=int 122 | ) 123 | batched_cameras_idx = torch.tensor( 124 | batched_cameras_idx, device="cuda", dtype=int 125 | ) 126 | torch.distributed.all_gather_into_tensor( 127 | batched_all_cameras_idx, batched_cameras_idx, group=utils.DEFAULT_GROUP 128 | ) 129 | batched_all_cameras_idx = batched_all_cameras_idx.cpu().numpy().squeeze() 130 | batched_cameras = train_dataset.get_batched_cameras_from_idx( 131 | batched_all_cameras_idx 132 | ) 133 | else: 134 | batched_cameras = train_dataset.get_batched_cameras(args.bsz) 135 | 136 | with torch.no_grad(): 137 | # Prepare Workload division strategy 138 | timers.start("prepare_strategies") 139 | batched_strategies, gpuid2tasks = start_strategy_final( 140 | batched_cameras, strategy_history 141 | ) 142 | timers.stop("prepare_strategies") 143 | 144 | # Load ground-truth images to GPU 145 | timers.start("load_cameras") 146 | load_camera_from_cpu_to_all_gpu( 147 | batched_cameras, batched_strategies, gpuid2tasks 148 | ) 149 | timers.stop("load_cameras") 150 | 151 | if args.backend == "gsplat": 152 | batched_screenspace_pkg = ( 153 | gsplat_distributed_preprocess3dgs_and_all2all_final( 154 | batched_cameras, 155 | gaussians, 156 | pipe_args, 157 | background, 158 | batched_strategies=batched_strategies, 159 | mode="train", 160 | ) 161 | ) 162 | batched_image, batched_compute_locally = gsplat_render_final( 163 | batched_screenspace_pkg, batched_strategies 164 | ) 165 | batch_statistic_collector = [ 166 | cuda_args["stats_collector"] 167 | for cuda_args in batched_screenspace_pkg["batched_cuda_args"] 168 | ] 169 | else: 170 | batched_screenspace_pkg = distributed_preprocess3dgs_and_all2all_final( 171 | batched_cameras, 172 | gaussians, 173 | pipe_args, 174 | background, 175 | batched_strategies=batched_strategies, 176 | mode="train", 177 | ) 178 | batched_image, batched_compute_locally = render_final( 179 | batched_screenspace_pkg, batched_strategies 180 | ) 181 | batch_statistic_collector = [ 182 | cuda_args["stats_collector"] 183 | for cuda_args in batched_screenspace_pkg["batched_cuda_args"] 184 | ] 185 | 186 | loss_sum, batched_losses = batched_loss_computation( 187 | batched_image, 188 | batched_cameras, 189 | batched_compute_locally, 190 | batched_strategies, 191 | batch_statistic_collector, 192 | ) 193 | 194 | timers.start("backward") 195 | loss_sum.backward() 196 | timers.stop("backward") 197 | utils.check_initial_gpu_memory_usage("after backward") 198 | 199 | with torch.no_grad(): 200 | # Adjust workload division strategy. 
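# Note on the loss bookkeeping in the block below: after the per-image (L1, SSIM)
# pairs are summed across ranks, each image's scalar loss is recombined as
# (1 - lambda_dssim) * L1 + lambda_dssim * (1 - SSIM), and the progress-bar value is
# an exponential moving average with weights 0.6 / 0.4. A minimal torch sketch of
# both formulas with hypothetical numbers:
#
#     import torch
#     lambda_dssim = 0.2
#     batched_losses = torch.tensor([[0.05, 0.90],   # [L1, SSIM] for image 0
#                                    [0.08, 0.85]])  # [L1, SSIM] for image 1
#     batched_loss = (1.0 - lambda_dssim) * batched_losses[:, 0] \
#                    + lambda_dssim * (1.0 - batched_losses[:, 1])
#     # -> tensor([0.0600, 0.0940])
#     ema = 0.10                                    # previous running value
#     ema = 0.6 * ema + 0.4 * batched_loss.mean()   # -> about 0.0908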
201 | globally_sync_for_timer() 202 | timers.start("finish_strategy_final") 203 | finish_strategy_final( 204 | batched_cameras, 205 | strategy_history, 206 | batched_strategies, 207 | batch_statistic_collector, 208 | ) 209 | timers.stop("finish_strategy_final") 210 | 211 | # Sync losses in the batch 212 | timers.start("sync_loss_and_log") 213 | batched_losses = torch.tensor(batched_losses, device="cuda") 214 | if utils.DEFAULT_GROUP.size() > 1: 215 | dist.all_reduce( 216 | batched_losses, op=dist.ReduceOp.SUM, group=utils.DEFAULT_GROUP 217 | ) 218 | batched_loss = (1.0 - args.lambda_dssim) * batched_losses[ 219 | :, 0 220 | ] + args.lambda_dssim * (1.0 - batched_losses[:, 1]) 221 | batched_loss_cpu = batched_loss.cpu().numpy() 222 | ema_loss_for_log = ( 223 | batched_loss_cpu.mean() 224 | if ema_loss_for_log is None 225 | else 0.6 * ema_loss_for_log + 0.4 * batched_loss_cpu.mean() 226 | ) 227 | # Update Epoch Statistics 228 | train_dataset.update_losses(batched_loss_cpu) 229 | # Logging 230 | batched_loss_cpu = [round(loss, 6) for loss in batched_loss_cpu] 231 | log_string = "iteration[{},{}) loss: {} image: {}\n".format( 232 | iteration, 233 | iteration + args.bsz, 234 | batched_loss_cpu, 235 | [viewpoint_cam.image_name for viewpoint_cam in batched_cameras], 236 | ) 237 | log_file.write(log_string) 238 | timers.stop("sync_loss_and_log") 239 | 240 | # Evaluation 241 | end2end_timers.stop() 242 | training_report( 243 | iteration, 244 | l1_loss, 245 | args.test_iterations, 246 | scene, 247 | pipe_args, 248 | background, 249 | args.backend, 250 | ) 251 | end2end_timers.start() 252 | 253 | # Densification 254 | if args.backend == "gsplat": 255 | gsplat_densification( 256 | iteration, scene, gaussians, batched_screenspace_pkg 257 | ) 258 | else: 259 | densification(iteration, scene, gaussians, batched_screenspace_pkg) 260 | 261 | # Save Gaussians 262 | if any( 263 | [ 264 | iteration <= save_iteration < iteration + args.bsz 265 | for save_iteration in args.save_iterations 266 | ] 267 | ): 268 | end2end_timers.stop() 269 | end2end_timers.print_time(log_file, iteration + args.bsz) 270 | utils.print_rank_0("\n[ITER {}] Saving Gaussians".format(iteration)) 271 | log_file.write("[ITER {}] Saving Gaussians\n".format(iteration)) 272 | scene.save(iteration) 273 | 274 | if args.save_strategy_history: 275 | with open( 276 | args.log_folder 277 | + "/strategy_history_ws=" 278 | + str(utils.WORLD_SIZE) 279 | + "_rk=" 280 | + str(utils.GLOBAL_RANK) 281 | + ".json", 282 | "w", 283 | ) as f: 284 | json.dump(strategy_history.to_json(), f) 285 | end2end_timers.start() 286 | 287 | # Save Checkpoints 288 | if any( 289 | [ 290 | iteration <= checkpoint_iteration < iteration + args.bsz 291 | for checkpoint_iteration in args.checkpoint_iterations 292 | ] 293 | ): 294 | end2end_timers.stop() 295 | utils.print_rank_0("\n[ITER {}] Saving Checkpoint".format(iteration)) 296 | log_file.write("[ITER {}] Saving Checkpoint\n".format(iteration)) 297 | save_folder = scene.model_path + "/checkpoints/" + str(iteration) + "/" 298 | if utils.DEFAULT_GROUP.rank() == 0: 299 | os.makedirs(save_folder, exist_ok=True) 300 | if utils.DEFAULT_GROUP.size() > 1: 301 | torch.distributed.barrier(group=utils.DEFAULT_GROUP) 302 | elif utils.DEFAULT_GROUP.size() > 1: 303 | torch.distributed.barrier(group=utils.DEFAULT_GROUP) 304 | torch.save( 305 | (gaussians.capture(), iteration + args.bsz), 306 | save_folder 307 | + "/chkpnt_ws=" 308 | + str(utils.WORLD_SIZE) 309 | + "_rk=" 310 | + str(utils.GLOBAL_RANK) 311 | + ".pth", 312 | ) 313 | 
end2end_timers.start() 314 | 315 | # Optimizer step 316 | if iteration < opt_args.iterations: 317 | timers.start("optimizer_step") 318 | 319 | if ( 320 | args.lr_scale_mode != "accumu" 321 | ): # we scale the learning rate rather than accumulate the gradients. 322 | for param in gaussians.all_parameters(): 323 | if param.grad is not None: 324 | param.grad /= args.bsz 325 | 326 | if not args.stop_update_param: 327 | gaussians.optimizer.step() 328 | gaussians.optimizer.zero_grad(set_to_none=True) 329 | timers.stop("optimizer_step") 330 | utils.check_initial_gpu_memory_usage("after optimizer step") 331 | 332 | # Finish a iteration and clean up 333 | torch.cuda.synchronize() 334 | for ( 335 | viewpoint_cam 336 | ) in batched_cameras: # Release memory of locally rendered original_image 337 | viewpoint_cam.original_image = None 338 | if args.nsys_profile: 339 | nvtx.range_pop() 340 | if utils.check_enable_python_timer(): 341 | timers.printTimers(iteration, mode="sum") 342 | log_file.flush() 343 | 344 | # Finish training 345 | if opt_args.iterations not in args.save_iterations: 346 | end2end_timers.print_time(log_file, opt_args.iterations) 347 | log_file.write( 348 | "Max Memory usage: {} GB.\n".format( 349 | torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 350 | ) 351 | ) 352 | progress_bar.close() 353 | 354 | 355 | def training_report( 356 | iteration, l1_loss, testing_iterations, scene: Scene, pipe_args, background, backend 357 | ): 358 | args = utils.get_args() 359 | log_file = utils.get_log_file() 360 | # Report test and samples of training set 361 | while len(testing_iterations) > 0 and iteration > testing_iterations[0]: 362 | testing_iterations.pop(0) 363 | if len(testing_iterations) > 0 and utils.check_update_at_this_iter( 364 | iteration, utils.get_args().bsz, testing_iterations[0], 0 365 | ): 366 | testing_iterations.pop(0) 367 | utils.print_rank_0("\n[ITER {}] Start Testing".format(iteration)) 368 | 369 | validation_configs = ( 370 | {"name": "test", "cameras": scene.getTestCameras(), "num_cameras": len(scene.getTestCameras())}, 371 | { 372 | "name": "train", 373 | "cameras": scene.getTrainCameras(), 374 | "num_cameras": max(len(scene.getTrainCameras()) // args.llffhold, args.bsz), 375 | }, 376 | ) 377 | 378 | # init workload division strategy 379 | for config in validation_configs: 380 | if config["cameras"] and len(config["cameras"]) > 0: 381 | l1_test = torch.scalar_tensor(0.0, device="cuda") 382 | psnr_test = torch.scalar_tensor(0.0, device="cuda") 383 | 384 | # TODO: if not divisible by world size 385 | num_cameras = config["num_cameras"] // args.bsz * args.bsz 386 | eval_dataset = SceneDataset(config["cameras"]) 387 | strategy_history = DivisionStrategyHistoryFinal( 388 | eval_dataset, utils.DEFAULT_GROUP.size(), utils.DEFAULT_GROUP.rank() 389 | ) 390 | for idx in range(1, num_cameras + 1, args.bsz): 391 | num_camera_to_load = min(args.bsz, num_cameras - idx + 1) 392 | if args.local_sampling: 393 | # TODO: if not divisible by world size 394 | batched_cameras_idx = eval_dataset.get_batched_cameras_idx( 395 | args.bsz // utils.WORLD_SIZE 396 | ) 397 | batched_all_cameras_idx = torch.zeros( 398 | (utils.WORLD_SIZE, len(batched_cameras_idx)), 399 | device="cuda", 400 | dtype=int, 401 | ) 402 | batched_cameras_idx = torch.tensor( 403 | batched_cameras_idx, device="cuda", dtype=int 404 | ) 405 | torch.distributed.all_gather_into_tensor( 406 | batched_all_cameras_idx, 407 | batched_cameras_idx, 408 | group=utils.DEFAULT_GROUP, 409 | ) 410 | batched_all_cameras_idx = ( 411 | 
batched_all_cameras_idx.cpu().numpy().squeeze() 412 | ) 413 | batched_cameras = eval_dataset.get_batched_cameras_from_idx( 414 | batched_all_cameras_idx 415 | ) 416 | else: 417 | batched_cameras = eval_dataset.get_batched_cameras( 418 | num_camera_to_load 419 | ) 420 | batched_strategies, gpuid2tasks = start_strategy_final( 421 | batched_cameras, strategy_history 422 | ) 423 | load_camera_from_cpu_to_all_gpu_for_eval( 424 | batched_cameras, batched_strategies, gpuid2tasks 425 | ) 426 | if backend == "gsplat": 427 | batched_screenspace_pkg = ( 428 | gsplat_distributed_preprocess3dgs_and_all2all_final( 429 | batched_cameras, 430 | scene.gaussians, 431 | pipe_args, 432 | background, 433 | batched_strategies=batched_strategies, 434 | mode="test", 435 | ) 436 | ) 437 | batched_image, _ = gsplat_render_final( 438 | batched_screenspace_pkg, batched_strategies 439 | ) 440 | else: 441 | batched_screenspace_pkg = ( 442 | distributed_preprocess3dgs_and_all2all_final( 443 | batched_cameras, 444 | scene.gaussians, 445 | pipe_args, 446 | background, 447 | batched_strategies=batched_strategies, 448 | mode="test", 449 | ) 450 | ) 451 | batched_image, _ = render_final( 452 | batched_screenspace_pkg, batched_strategies 453 | ) 454 | for camera_id, (image, gt_camera) in enumerate( 455 | zip(batched_image, batched_cameras) 456 | ): 457 | if ( 458 | image is None or len(image.shape) == 0 459 | ): # The image is not rendered locally. 460 | image = torch.zeros( 461 | gt_camera.original_image.shape, 462 | device="cuda", 463 | dtype=torch.float32, 464 | ) 465 | 466 | if utils.DEFAULT_GROUP.size() > 1: 467 | torch.distributed.all_reduce( 468 | image, op=dist.ReduceOp.SUM, group=utils.DEFAULT_GROUP 469 | ) 470 | 471 | image = torch.clamp(image, 0.0, 1.0) 472 | gt_image = torch.clamp( 473 | gt_camera.original_image / 255.0, 0.0, 1.0 474 | ) 475 | 476 | if idx + camera_id < num_cameras + 1: 477 | l1_test += l1_loss(image, gt_image).mean().double() 478 | psnr_test += psnr(image, gt_image).mean().double() 479 | gt_camera.original_image = None 480 | psnr_test /= num_cameras 481 | l1_test /= num_cameras 482 | utils.print_rank_0( 483 | "\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format( 484 | iteration, config["name"], l1_test, psnr_test 485 | ) 486 | ) 487 | log_file.write( 488 | "[ITER {}] Evaluating {}: L1 {} PSNR {}\n".format( 489 | iteration, config["name"], l1_test, psnr_test 490 | ) 491 | ) 492 | 493 | torch.cuda.empty_cache() 494 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch, get_args, get_log_file 15 | import utils.general_utils as utils 16 | from tqdm import tqdm 17 | from utils.graphics_utils import fov2focal 18 | import time 19 | import multiprocessing 20 | from multiprocessing import shared_memory 21 | import torch 22 | from PIL import Image 23 | 24 | 25 | def loadCam(args, id, cam_info, decompressed_image=None, return_image=False): 26 | orig_w, orig_h = cam_info.width, cam_info.height 27 | assert ( 28 | orig_w == utils.get_img_width() and orig_h == utils.get_img_height() 29 | ), "All images should have the same size. " 30 | 31 | args = get_args() 32 | log_file = get_log_file() 33 | resolution = orig_w, orig_h 34 | # NOTE: we do not support downsampling here. 35 | 36 | # may use cam_info.uid 37 | if ( 38 | ( 39 | args.local_sampling 40 | and args.distributed_dataset_storage 41 | and utils.GLOBAL_RANK == id % utils.WORLD_SIZE 42 | ) 43 | or ( 44 | not args.local_sampling 45 | and args.distributed_dataset_storage 46 | and utils.LOCAL_RANK == 0 47 | ) 48 | or (not args.distributed_dataset_storage) 49 | ): 50 | if args.time_image_loading: 51 | start_time = time.time() 52 | image = Image.open(cam_info.image_path) 53 | resized_image_rgb = PILtoTorch( 54 | image, resolution, args, log_file, decompressed_image=decompressed_image 55 | ) 56 | if args.time_image_loading: 57 | log_file.write(f"PILtoTorch image in {time.time() - start_time} seconds\n") 58 | 59 | # assert resized_image_rgb.shape[0] == 3, "Image should have exactly 3 channels!" 60 | gt_image = resized_image_rgb[:3, ...].contiguous() 61 | loaded_mask = None 62 | 63 | # Free the memory: because the PIL image has been converted to torch tensor, we don't need it anymore. And it takes up lots of cpu memory. 
64 | image.close() 65 | image = None 66 | else: 67 | gt_image = None 68 | loaded_mask = None 69 | 70 | if return_image: 71 | return gt_image 72 | 73 | return Camera( 74 | colmap_id=cam_info.uid, 75 | R=cam_info.R, 76 | T=cam_info.T, 77 | FoVx=cam_info.FovX, 78 | FoVy=cam_info.FovY, 79 | image=gt_image, 80 | gt_alpha_mask=loaded_mask, 81 | image_name=cam_info.image_name, 82 | uid=id, 83 | ) 84 | 85 | 86 | def load_decompressed_image(params): 87 | args, id, cam_info = params 88 | return loadCam(args, id, cam_info, decompressed_image=None, return_image=True) 89 | 90 | 91 | # Modify this code to support shared_memory.SharedMemory to make inter-process communication faster 92 | def decompressed_images_from_camInfos_multiprocess(cam_infos, args): 93 | args = get_args() 94 | decompressed_images = [] 95 | total_cameras = len(cam_infos) 96 | 97 | # Create a pool of processes 98 | with multiprocessing.Pool(processes=2) as pool: 99 | # Prepare data for processing 100 | tasks = [(args, id, cam_info) for id, cam_info in enumerate(cam_infos)] 101 | 102 | # Map load_camera_data to the tasks 103 | # results = pool.map(load_decompressed_image, tasks) 104 | results = list( 105 | tqdm( 106 | pool.imap(load_decompressed_image, tasks), 107 | total=total_cameras, 108 | disable=(utils.LOCAL_RANK != 0), 109 | ) 110 | ) 111 | 112 | for id, result in enumerate(results): 113 | decompressed_images.append(result) 114 | 115 | return decompressed_images 116 | 117 | 118 | def decompress_and_scale_image(cam_info): 119 | pil_image = cam_info.image 120 | resolution = cam_info.image.size # (w, h) 121 | # print("cam_info.image.size: ", cam_info.image.size) 122 | pil_image.load() 123 | resized_image_PIL = pil_image.resize(resolution) 124 | resized_image = np.array(resized_image_PIL) # (h, w, 3) 125 | # print("resized_image.shape: ", resized_image.shape) 126 | if len(resized_image.shape) == 3: 127 | return resized_image.transpose(2, 0, 1) 128 | else: 129 | return resized_image[..., np.newaxis].transpose(2, 0, 1) 130 | 131 | 132 | def load_decompressed_image_shared(params): 133 | shared_mem_name, args, id, cam_info, resolution_scale = params 134 | # Retrieve the shared memory block 135 | existing_shm = shared_memory.SharedMemory(name=shared_mem_name) 136 | 137 | # Assume each image will be stored as a flat array in shared memory 138 | # Example: using numpy for manipulation; adjust size and dtype as needed 139 | resolution_width, resolution_height = cam_info.image.size 140 | image_shape = (3, resolution_height, resolution_width) # Set appropriate values 141 | dtype = np.uint8 # Adjust as per your image data type 142 | 143 | # Calculate the offset for this particular image 144 | offset = id * np.prod(image_shape) 145 | np_image_array = np.ndarray( 146 | image_shape, dtype=dtype, buffer=existing_shm.buf, offset=offset 147 | ) 148 | 149 | # Decompress image into the numpy array directly 150 | decompressed_image = decompress_and_scale_image(cam_info) # Implement this 151 | np_image_array[:] = decompressed_image 152 | 153 | # Clean up 154 | existing_shm.close() 155 | 156 | 157 | def decompressed_images_from_camInfos_multiprocess_sharedmem( 158 | cam_infos, resolution_scale, args 159 | ): 160 | args = get_args() 161 | decompressed_images = [] 162 | total_cameras = len(cam_infos) 163 | 164 | # Assume each image shape and dtype 165 | resolution_width, resolution_height = cam_infos[0].image.size 166 | image_shape = ( 167 | 3, 168 | resolution_height, 169 | resolution_width, 170 | ) # Define these as per your data 171 | dtype = np.uint8 
172 | image_size = np.prod(image_shape) * np.dtype(dtype).itemsize 173 | 174 | # Create shared memory 175 | total_size = image_size * total_cameras 176 | shm = shared_memory.SharedMemory(create=True, size=total_size) 177 | 178 | # Create a pool of processes 179 | with multiprocessing.Pool(16) as pool: 180 | # Prepare data for processing 181 | tasks = [ 182 | (shm.name, args, id, cam_info, resolution_scale) 183 | for id, cam_info in enumerate(cam_infos) 184 | ] 185 | 186 | # print("Start Parallel loading...") 187 | # Map load_camera_data to the tasks 188 | list( 189 | tqdm(pool.imap(load_decompressed_image_shared, tasks), total=total_cameras) 190 | ) 191 | 192 | # Read images from shared memory 193 | decompressed_images = [] 194 | for id in range(total_cameras): 195 | offset = id * np.prod(image_shape) 196 | np_image_array = np.ndarray( 197 | image_shape, dtype=dtype, buffer=shm.buf, offset=offset 198 | ) 199 | decompressed_images.append( 200 | torch.from_numpy(np_image_array) 201 | ) # Make a copy if necessary 202 | 203 | # Clean up shared memory 204 | shm.close() 205 | shm.unlink() 206 | 207 | return decompressed_images 208 | 209 | 210 | def cameraList_from_camInfos(cam_infos, args): 211 | args = get_args() 212 | 213 | if args.multiprocesses_image_loading: 214 | decompressed_images = decompressed_images_from_camInfos_multiprocess( 215 | cam_infos, args 216 | ) 217 | # decompressed_images = decompressed_images_from_camInfos_multiprocess_sharedmem(cam_infos, resolution_scale, args) 218 | else: 219 | decompressed_images = [None for _ in cam_infos] 220 | 221 | camera_list = [] 222 | for id, c in tqdm( 223 | enumerate(cam_infos), total=len(cam_infos), disable=(utils.LOCAL_RANK != 0) 224 | ): 225 | camera_list.append( 226 | loadCam( 227 | args, 228 | id, 229 | c, 230 | decompressed_image=decompressed_images[id], 231 | return_image=False, 232 | ) 233 | ) 234 | 235 | if utils.DEFAULT_GROUP.size() > 1: 236 | torch.distributed.barrier(group=utils.DEFAULT_GROUP) 237 | 238 | return camera_list 239 | 240 | 241 | def camera_to_JSON(id, camera: Camera): 242 | Rt = np.zeros((4, 4)) 243 | Rt[:3, :3] = camera.R.transpose() 244 | Rt[:3, 3] = camera.T 245 | Rt[3, 3] = 1.0 246 | 247 | W2C = np.linalg.inv(Rt) 248 | pos = W2C[:3, 3] 249 | rot = W2C[:3, :3] 250 | serializable_array_2d = [x.tolist() for x in rot] 251 | camera_entry = { 252 | "id": id, 253 | "img_name": camera.image_name, 254 | "width": camera.width, 255 | "height": camera.height, 256 | "position": pos.tolist(), 257 | "rotation": serializable_array_2d, 258 | "fy": fov2focal(camera.FovY, camera.height), 259 | "fx": fov2focal(camera.FovX, camera.width), 260 | } 261 | return camera_entry 262 | -------------------------------------------------------------------------------- /utils/debug_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import utils.general_utils as utils 4 | 5 | 6 | def save_image_for_debug(image, file_name, keep_digits=8): 7 | # save image for visualization 8 | 9 | file = open(file_name, "w") 10 | 11 | image_cpu = image.detach().cpu().numpy() 12 | 13 | for c in range(3): 14 | for i in range(image.shape[1]): 15 | for j in range(image.shape[2]): 16 | float_value = image_cpu[c, i, j] 17 | file.write(f"{float_value:.{keep_digits}f} ") 18 | file.write("\n") 19 | file.write("\n") 20 | 21 | 22 | def save_image_tiles_for_debug(image_tiles, file_name, keep_digits=3): 23 | # save image for visualization 24 | 25 | file = open(file_name, "w") 26 | 27 | 
image_tiles_cpu = image_tiles.detach().cpu().numpy() 28 | 29 | for tile_idx in range(image_tiles.shape[0]): 30 | file.write(f"tile_idx " + str(tile_idx) + "\n") 31 | for c in range(3): 32 | file.write(f"channel " + str(c) + "\n") 33 | for i in range(utils.BLOCK_X): 34 | for j in range(utils.BLOCK_Y): 35 | float_value = image_tiles_cpu[tile_idx, c, i, j] 36 | file.write(f"{float_value:.{keep_digits}f} ") 37 | file.write("\n") 38 | 39 | 40 | def save_all_pos_for_debug(all_pos, file_name): 41 | 42 | file = open(file_name, "w") 43 | 44 | all_pos_cpu = all_pos.detach().cpu().numpy() 45 | 46 | for i in range(all_pos.shape[0]): 47 | for j in range(all_pos.shape[1]): 48 | int_value = all_pos_cpu[i, j] 49 | file.write(f"{int_value} ") 50 | file.write("\n") 51 | 52 | 53 | def save_compute_locally_for_debug(compute_locally, file_name): 54 | file = open(file_name, "w") 55 | 56 | compute_locally_cpu = compute_locally.detach().cpu().numpy() 57 | 58 | for i in range(compute_locally.shape[0]): 59 | for j in range(compute_locally.shape[1]): 60 | int_value = int(compute_locally_cpu[i, j]) 61 | file.write(f"{int_value}") 62 | file.write("\n") 63 | 64 | 65 | def save_pixels_compute_locally_for_debug(pixels_compute_locally, file_name): 66 | file = open(file_name, "w") 67 | 68 | pixels_compute_locally_cpu = pixels_compute_locally.detach().cpu().numpy() 69 | 70 | for i in range(pixels_compute_locally.shape[0]): 71 | for j in range(pixels_compute_locally.shape[1]): 72 | int_value = int(pixels_compute_locally_cpu[i, j]) 73 | file.write(f"{int_value}") 74 | file.write("\n") 75 | 76 | 77 | def save_pixel_loss_for_debug(pixel_loss, file_name, keep_digits=3): 78 | file = open(file_name, "w") 79 | 80 | pixel_loss_cpu = pixel_loss.detach().cpu().numpy() 81 | 82 | for i in range(pixel_loss.shape[0]): 83 | for j in range(pixel_loss.shape[1]): 84 | float_value = pixel_loss_cpu[i, j] 85 | file.write(f"{float_value:.{keep_digits}f} ") 86 | file.write("\n") 87 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | 18 | class BasicPointCloud(NamedTuple): 19 | points: np.array 20 | colors: np.array 21 | normals: np.array 22 | 23 | 24 | def geom_transform_points(points, transf_matrix): 25 | P, _ = points.shape 26 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 27 | points_hom = torch.cat([points, ones], dim=1) 28 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 29 | 30 | denom = points_out[..., 3:] + 0.0000001 31 | return (points_out[..., :3] / denom).squeeze(dim=0) 32 | 33 | 34 | def getWorld2View(R, t): 35 | Rt = np.zeros((4, 4)) 36 | Rt[:3, :3] = R.transpose() 37 | Rt[:3, 3] = t 38 | Rt[3, 3] = 1.0 39 | return np.float32(Rt) 40 | 41 | 42 | def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0): 43 | Rt = np.zeros((4, 4)) 44 | Rt[:3, :3] = R.transpose() 45 | Rt[:3, 3] = t 46 | Rt[3, 3] = 1.0 47 | 48 | C2W = np.linalg.inv(Rt) 49 | cam_center = C2W[:3, 3] 50 | cam_center = (cam_center + translate) * scale 51 | C2W[:3, 3] = cam_center 52 | Rt = np.linalg.inv(C2W) 53 | return np.float32(Rt) 54 | 55 | 56 | def getProjectionMatrix(znear, zfar, fovX, fovY): 57 | tanHalfFovY = math.tan((fovY / 2)) 58 | tanHalfFovX = math.tan((fovX / 2)) 59 | 60 | top = tanHalfFovY * znear 61 | bottom = -top 62 | right = tanHalfFovX * znear 63 | left = -right 64 | 65 | P = torch.zeros(4, 4) 66 | 67 | z_sign = 1.0 68 | 69 | P[0, 0] = 2.0 * znear / (right - left) 70 | P[1, 1] = 2.0 * znear / (top - bottom) 71 | P[0, 2] = (right + left) / (right - left) 72 | P[1, 2] = (top + bottom) / (top - bottom) 73 | P[3, 2] = z_sign 74 | P[2, 2] = z_sign * zfar / (zfar - znear) 75 | P[2, 3] = -(zfar * znear) / (zfar - znear) 76 | return P 77 | 78 | 79 | def fov2focal(fov, pixels): 80 | return pixels / (2 * math.tan(fov / 2)) 81 | 82 | 83 | def focal2fov(focal, pixels): 84 | return 2 * math.atan(pixels / (2 * focal)) 85 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | 15 | def mse(img1, img2): 16 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 17 | 18 | 19 | def psnr(img1, img2): 20 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 21 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 22 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | 18 | def l1_loss(network_output, gt): 19 | return torch.abs((network_output - gt)).mean() 20 | 21 | 22 | def l2_loss(network_output, gt): 23 | return ((network_output - gt) ** 2).mean() 24 | 25 | 26 | def gaussian(window_size, sigma): 27 | gauss = torch.Tensor( 28 | [ 29 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) 30 | for x in range(window_size) 31 | ] 32 | ) 33 | return gauss / gauss.sum() 34 | 35 | 36 | def create_window(window_size, channel): 37 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 38 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 39 | window = Variable( 40 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 41 | ) 42 | return window 43 | 44 | 45 | def ssim(img1, img2, window_size=11, size_average=True): 46 | channel = img1.size(-3) 47 | window = create_window(window_size, channel) 48 | 49 | if img1.is_cuda: 50 | window = window.cuda(img1.get_device()) 51 | window = window.type_as(img1) 52 | 53 | return _ssim(img1, img2, window, window_size, channel, size_average) 54 | 55 | 56 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 57 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 58 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 59 | 60 | mu1_sq = mu1.pow(2) 61 | mu2_sq = mu2.pow(2) 62 | mu1_mu2 = mu1 * mu2 63 | 64 | sigma1_sq = ( 65 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 66 | ) 67 | sigma2_sq = ( 68 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 69 | ) 70 | sigma12 = ( 71 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 72 | - mu1_mu2 73 | ) 74 | 75 | C1 = 0.01**2 76 | C2 = 0.03**2 77 | 78 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 79 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 80 | ) 81 | 82 | if size_average: 83 | return ssim_map.mean() 84 | else: 85 | return ssim_map.mean(1).mean(1).mean(1) 86 | 87 | 88 | def pixelwise_l1_with_mask(img1, img2, pixel_mask): 89 | # img1, img2: (3, H, W) 90 | # pixel_mask: (H, W) bool torch tensor as mask. 
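    #   (in this repo the mask typically marks the pixels whose tiles are rendered
    #    locally on this rank, cf. save_pixels_compute_locally_for_debug in
    #    utils/debug_utils.py; an all-True mask reduces this to a plain per-pixel L1 map)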
91 | # only compute l1 loss for the pixels that are touched 92 | 93 | pixelwise_l1_loss = torch.abs((img1 - img2)) * pixel_mask.unsqueeze(0) 94 | return pixelwise_l1_loss 95 | 96 | 97 | def pixelwise_ssim_with_mask(img1, img2, pixel_mask): 98 | window_size = 11 99 | 100 | channel = img1.size(-3) 101 | window = create_window(window_size, channel) 102 | if img1.is_cuda: 103 | window = window.cuda(img1.get_device()) 104 | window = window.type_as(img1) 105 | 106 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 107 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 108 | 109 | mu1_sq = mu1.pow(2) 110 | mu2_sq = mu2.pow(2) 111 | mu1_mu2 = mu1 * mu2 112 | 113 | sigma1_sq = ( 114 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 115 | ) 116 | sigma2_sq = ( 117 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 118 | ) 119 | sigma12 = ( 120 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 121 | - mu1_mu2 122 | ) 123 | 124 | C1 = 0.01**2 125 | C2 = 0.03**2 126 | 127 | pixelwise_ssim_loss = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 128 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 129 | ) 130 | pixelwise_ssim_loss = pixelwise_ssim_loss * pixel_mask.unsqueeze(0) 131 | 132 | return pixelwise_ssim_loss 133 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 
23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396, 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435, 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = ( 78 | result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3] 79 | ) 80 | 81 | if deg > 1: 82 | xx, yy, zz = x * x, y * y, z * z 83 | xy, yz, xz = x * y, y * z, x * z 84 | result = ( 85 | result 86 | + C2[0] * xy * sh[..., 4] 87 | + C2[1] * yz * sh[..., 5] 88 | + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] 89 | + C2[3] * xz * sh[..., 7] 90 | + C2[4] * (xx - yy) * sh[..., 8] 91 | ) 92 | 93 | if deg > 2: 94 | result = ( 95 | result 96 | + C3[0] * y * (3 * xx - yy) * sh[..., 9] 97 | + C3[1] * xy * z * sh[..., 10] 98 | + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] 99 | + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] 100 | + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] 101 | + C3[5] * z * (xx - yy) * sh[..., 14] 102 | + C3[6] * x * (xx - 3 * yy) * sh[..., 15] 103 | ) 104 | 105 | if deg > 3: 106 | result = ( 107 | result 108 | + C4[0] * xy * (xx - yy) * sh[..., 16] 109 | + C4[1] * yz * (3 * xx - yy) * sh[..., 17] 110 | + C4[2] * xy * (7 * zz - 1) * sh[..., 18] 111 | + C4[3] * yz * (7 * zz - 3) * sh[..., 19] 112 | + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] 113 | + C4[5] * xz * (7 * zz - 3) * sh[..., 21] 114 | + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] 115 | + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] 116 | + C4[8] 117 | * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) 118 | * sh[..., 24] 119 | ) 120 | return result 121 | 122 | 123 | def RGB2SH(rgb): 124 | return (rgb - 0.5) / C0 125 | 126 | 127 | def SH2RGB(sh): 128 | return sh * C0 + 0.5 129 | -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | 17 | def mkdir_p(folder_path): 18 | # Creates a directory. 
equivalent to using mkdir -p on the command line 19 | try: 20 | makedirs(folder_path) 21 | except OSError as exc: # Python >2.5 22 | if exc.errno == EEXIST and path.isdir(folder_path): 23 | pass 24 | else: 25 | raise 26 | 27 | 28 | def searchForMaxIteration(folder): 29 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 30 | return max(saved_iters) 31 | -------------------------------------------------------------------------------- /utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import utils.general_utils as utils 3 | import torch 4 | 5 | 6 | class Timer: 7 | def __init__(self, args, file=None): 8 | self.timers = {} 9 | self.args = args 10 | if args.enable_timer: 11 | # Enable time measure evaluated on python side. 12 | self.file = open( 13 | args.log_folder 14 | + "/python_time_ws=" 15 | + str(utils.WORLD_SIZE) 16 | + "_rk=" 17 | + str(utils.GLOBAL_RANK) 18 | + ".log", 19 | "w", 20 | ) 21 | else: 22 | self.file = None 23 | 24 | def start(self, key): 25 | if not utils.check_enable_python_timer(): 26 | return 27 | """Start timer for the given key""" 28 | if key not in self.timers: 29 | self.timers[key] = {"start_time": None, "cnt": 0, "all_time": []} 30 | 31 | torch.cuda.synchronize() 32 | 33 | self.timers[key]["start_time"] = time.time() 34 | 35 | def stop(self, key, print_elapsed=False): 36 | if not utils.check_enable_python_timer(): 37 | return 38 | 39 | """Stop the timer for the given key, and report the elapsed time""" 40 | if key not in self.timers or self.timers[key]["start_time"] is None: 41 | raise ValueError(f"Timer with key '{key}' is not running.") 42 | 43 | torch.cuda.synchronize() 44 | 45 | cur_time = time.time() 46 | duration = cur_time - self.timers[key]["start_time"] 47 | self.timers[key]["cnt"] += 1 48 | self.timers[key]["all_time"].append(duration) 49 | self.timers[key]["start_time"] = None 50 | if print_elapsed: 51 | print(f"Time for '{key}': {duration:.6f} seconds") 52 | return duration 53 | 54 | def printTimers( 55 | self, iteration, mode="this_iteration" 56 | ): # this_iteration, average, sum 57 | """Get the elapsed time for the given key without stopping the timer""" 58 | if not utils.check_enable_python_timer(): 59 | return 60 | 61 | for x in range(self.args.bsz): 62 | if (iteration + x) % self.args.log_interval == 1: 63 | iteration += x 64 | break 65 | 66 | for key in self.timers: 67 | if mode == "this_iteration": 68 | # print(f"iter {iteration}, TimeFor '{key}': {self.timers[key]['all_time'][-1]*1000:.6f} ms") 69 | self.file.write( 70 | f"iter {iteration}, TimeFor '{key}': {self.timers[key]['all_time'][-1]*1000:.6f} ms\n" 71 | ) 72 | elif mode == "average": 73 | average_time = ( 74 | sum(self.timers[key]["all_time"]) / self.timers[key]["cnt"] 75 | ) 76 | # print(f"iter {iteration}, AverageTimeFor '{key}': {average_time*1000:.6f} ms") 77 | self.file.write( 78 | f"iter {iteration}, AverageTimeFor '{key}': {average_time*1000:.6f} ms\n" 79 | ) 80 | elif mode == "sum": 81 | sum_time = sum(self.timers[key]["all_time"]) 82 | self.file.write( 83 | f"iter {iteration}, TimeFor '{key}': {sum_time*1000:.6f} ms\n" 84 | ) 85 | self.file.write("\n") 86 | self.file.flush() 87 | 88 | def clear(self): 89 | self.timers = {} 90 | 91 | 92 | class End2endTimer: 93 | def __init__(self, args, file=None): 94 | self.total_time = 0 95 | self.last_time_point = None 96 | self.args = args 97 | 98 | def start(self): 99 | torch.cuda.synchronize() 100 | self.last_time_point = time.time() 101 | 102 | def 
stop(self): 103 | torch.cuda.synchronize() 104 | new_time_point = time.time() 105 | duration = new_time_point - self.last_time_point 106 | self.total_time += duration 107 | self.last_time_point = None 108 | 109 | def print_time(self, log_file, n_iterations): 110 | if self.last_time_point is not None: 111 | self.stop() 112 | log_file.write( 113 | "end2end total_time: {:.3f} s, iterations: {}, throughput {:.2f} it/s\n".format( 114 | self.total_time, n_iterations, n_iterations / self.total_time 115 | ) 116 | ) 117 | --------------------------------------------------------------------------------
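
The listings above are self-contained utilities; the short sketches below are not part of the repository and only illustrate, under stated assumptions, how a few of them fit together. First, the masked per-pixel terms from utils/loss_utils.py can be reduced to a scalar training loss; the helper name masked_loss, the tensors rendered and gt_image, and the 0.2 DSSIM weight are assumptions chosen to mirror the usual 3DGS-style objective, not something this file prescribes.

# Illustrative sketch only (not repository code): reduce the masked per-pixel
# L1/SSIM maps to a scalar loss. The 0.2 DSSIM weight is an assumption.
from utils.loss_utils import pixelwise_l1_with_mask, pixelwise_ssim_with_mask

def masked_loss(rendered, gt_image, pixel_mask, lambda_dssim=0.2):
    # rendered, gt_image: (3, H, W) tensors; pixel_mask: (H, W) bool tensor.
    l1_map = pixelwise_l1_with_mask(rendered, gt_image, pixel_mask)      # (3, H, W)
    ssim_map = pixelwise_ssim_with_mask(rendered, gt_image, pixel_mask)  # (3, H, W)
    n = pixel_mask.sum().clamp(min=1) * rendered.shape[0]                # masked value count
    l1 = l1_map.sum() / n
    ssim_val = ssim_map.sum() / n
    return (1.0 - lambda_dssim) * l1 + lambda_dssim * (1.0 - ssim_val)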
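
Next, a minimal sketch of evaluating view-dependent color with eval_sh from utils/sh_utils.py. The shapes (N points, 3 channels, degree 3) and the final +0.5 shift with clamping are assumptions for illustration, not the training pipeline's actual call site.

# Illustrative sketch only: evaluate per-point RGB from SH coefficients,
# following the shape convention documented in eval_sh().
import torch
from utils.sh_utils import eval_sh

N, deg = 1024, 3
sh_coeffs = torch.randn(N, 3, (deg + 1) ** 2)                          # [..., C, (deg + 1) ** 2]
view_dirs = torch.nn.functional.normalize(torch.randn(N, 3), dim=-1)   # unit directions, [..., 3]
rgb = eval_sh(deg, sh_coeffs, view_dirs)                                # -> (N, 3)
rgb = (rgb + 0.5).clamp(min=0.0)                                        # shift/clamp commonly paired with SH2RGB; assumed here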
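
Finally, a sketch of driving the two timers from utils/timer.py. Only methods and fields visible above are used (enable_timer, log_folder, bsz, and log_interval on args; start/stop/printTimers on Timer; start/stop/print_time on End2endTimer). It assumes a CUDA device, since both timers call torch.cuda.synchronize(), and assumes utils.WORLD_SIZE, utils.GLOBAL_RANK, and check_enable_python_timer() in utils.general_utils have been initialized by the training setup, which is not shown here.

# Illustrative sketch only: time per-iteration phases and the end-to-end run.
from argparse import Namespace
from utils.timer import Timer, End2endTimer

args = Namespace(enable_timer=True, log_folder="output/logs", bsz=1, log_interval=50)

timer = Timer(args)          # opens output/logs/python_time_ws=..._rk=....log
e2e = End2endTimer(args)

e2e.start()
for iteration in range(1, 101):
    timer.start("forward")
    # ... render the batch and compute the loss here ...
    timer.stop("forward")

    timer.start("backward")
    # ... backward pass and optimizer step here ...
    timer.stop("backward")

    timer.printTimers(iteration, mode="this_iteration")
e2e.stop()

with open(args.log_folder + "/e2e.log", "w") as log_file:
    e2e.print_time(log_file, n_iterations=100)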